cos-gpu-installer: Switch to using cos-nvidia-gpu-drivers

Use cos-nvidia-gpu-drivers(-asia/eu) for sourcing generic datacenter
Linux NVIDIA drivers used for installing matching userspace components
for OSS kernel modules.

BUG=b/335514880

Change-Id: I752cdcd5bdd8abc5421f91add46f56b3415da80f
Reviewed-on: https://cos-review.googlesource.com/c/cos/tools/+/71059
Tested-by: Arnav Kansal <rnv@google.com>
Reviewed-by: Robert Kolchmeyer <rkolchmeyer@google.com>
Cloud-Build: GCB Service account <228075978874@cloudbuild.gserviceaccount.com>
Reviewed-by: Shuo Yang <gshuoy@google.com>
diff --git a/src/cmd/cos_gpu_installer/internal/commands/install.go b/src/cmd/cos_gpu_installer/internal/commands/install.go
index 267a3ae..f9ad8f0 100644
--- a/src/cmd/cos_gpu_installer/internal/commands/install.go
+++ b/src/cmd/cos_gpu_installer/internal/commands/install.go
@@ -166,10 +166,10 @@
 	f.StringVar(&c.nvidiaInstallerURLOpen, "nvidia-installer-url-open", "", "This can be used to specify the location of the GSP firmware and user-space NVIDIA GPU driver components from a corresponding driver release of the OSS kernel modules. This flag is only for debugging and testing.")
 	f.StringVar(&c.gcsDownloadBucketNvidia, "gcs-download-bucket-nvidia", "", "The GCS bucket containing NVIDIA generic driver packages. "+
 		"Must be used in conjunction with `-gcs-download-prefix-nvidia`."+
-		"If not set, then it would be one of 'nvidia-drivers-us-public', 'nvidia-drivers-asia-public', and 'nvidia-drivers-eu-public' based on where the VM is running.")
+		"If not set, then it would be one of 'cos-nvidia-gpu-drivers', 'cos-nvidia-gpu-drivers-asia', and 'cos-nvidia-gpu-drivers-eu' based on where the VM is running.")
 	f.StringVar(&c.gcsDownloadPrefixNvidia, "gcs-download-prefix-nvidia", "", "The GCS prefix where NVIDIA generic driver packages are located within the bucket. "+
 		"Must be used in conjunction with `-gcs-download-bucket-nvidia`."+
-		"If not set, defaults to 'tesla/{driver-version}' (e.g., 'tesla/535.129.03').")
+		"If not set, defaults to ''.")
 	f.BoolVar(&c.debug, "debug", false,
 		"Enable debug mode.")
 	f.BoolVar(&c.test, "test", false,
@@ -263,7 +263,7 @@
 	}
 	hostInstallDir := filepath.Join(hostRootPath, c.hostInstallDir)
 
-	downloader := cos.NewGCSDownloader(nil, envReader, c.gcsDownloadBucket, c.gcsDownloadPrefix)
+	downloader := cos.NewGCSDownloader(nil, envReader, c.gcsDownloadBucket, c.gcsDownloadPrefix, c.gcsDownloadBucketNvidia, c.gcsDownloadPrefixNvidia)
 	if c.nvidiaInstallerURL == "" {
 		versionInput := c.driverVersion
 		useFallback := useFallbackMechanism(c)
@@ -511,7 +511,7 @@
 
 	var installerFile string
 	if c.nvidiaInstallerURLOpen == "" {
-		installerFile, err = installer.DownloadGenericDriverInstaller(ctx, downloader, c.driverVersion, c.gcsDownloadBucketNvidia, c.gcsDownloadPrefixNvidia, "")
+		installerFile, err = installer.DownloadGenericDriverInstaller(ctx, downloader, c.driverVersion)
 	} else {
 		installerFile, err = installer.DownloadToInstallDir(c.nvidiaInstallerURLOpen, "Unofficial GPU driver installer")
 	}
diff --git a/src/cmd/cos_gpu_installer/internal/commands/list.go b/src/cmd/cos_gpu_installer/internal/commands/list.go
index 665053b..c7bdb78 100644
--- a/src/cmd/cos_gpu_installer/internal/commands/list.go
+++ b/src/cmd/cos_gpu_installer/internal/commands/list.go
@@ -68,7 +68,7 @@
 		log.Errorf("Unable to init feature flags: %v", err)
 		return subcommands.ExitFailure
 	}
-	downloader := cos.NewGCSDownloader(nil, envReader, c.gcsDownloadBucket, c.gcsDownloadPrefix)
+	downloader := cos.NewGCSDownloader(nil, envReader, c.gcsDownloadBucket, c.gcsDownloadPrefix, "", "")
 	if featureConfig.PerGPULabelFeatureFlag {
 		var gpuProtoCacheDir = filepath.Join(hostRootPath, c.gpuProtoCacheDir)
 		if c.gpuProtoCacheDir == "" {
diff --git a/src/cmd/cos_gpu_installer/internal/installer/installer.go b/src/cmd/cos_gpu_installer/internal/installer/installer.go
index 02c1416..8d7f42d 100644
--- a/src/cmd/cos_gpu_installer/internal/installer/installer.go
+++ b/src/cmd/cos_gpu_installer/internal/installer/installer.go
@@ -17,7 +17,6 @@
 
 	"cos.googlesource.com/cos/tools.git/src/cmd/cos_gpu_installer/internal/signing"
 	"cos.googlesource.com/cos/tools.git/src/pkg/cos"
-	"cos.googlesource.com/cos/tools.git/src/pkg/gcs"
 	"cos.googlesource.com/cos/tools.git/src/pkg/modules"
 	"cos.googlesource.com/cos/tools.git/src/pkg/utils"
 
@@ -35,9 +34,6 @@
 	defaultFilePermission        = 0755
 	signedURLKey                 = "Expires"
 	prebuiltModuleTemplate       = "nvidia-drivers-%s.tgz"
-	gcsBucketNvidiaTemplate      = "nvidia-drivers-%s-public"
-	gcsPrefixNvidiaTemplate      = "tesla/%s"
-	nvidiaRunFileTemplate        = "NVIDIA-Linux-x86_64-%s.run"
 	DefaultVersion               = "default"
 	LatestVersion                = "latest"
 	MajorGPUDriverArtifactPrefix = "gpu_"
@@ -506,21 +502,6 @@
 	return nil
 }
 
-func getInstallerDownloadLocation(metadataZone string) string {
-	fields := strings.Split(metadataZone, "/")
-	zone := fields[len(fields)-1]
-	locationMapping := map[string]string{
-		"us":     "us",
-		"asia":   "asia",
-		"europe": "eu",
-	}
-	location, ok := locationMapping[strings.Split(zone, "-")[0]]
-	if !ok {
-		location = "us"
-	}
-	return location
-}
-
 func createHostDirBindMount(hostDir, bindMountPath string) error {
 	if err := os.MkdirAll(hostDir, defaultFilePermission); err != nil {
 		return errors.Wrapf(err, "failed to create dir %s", hostDir)
@@ -700,43 +681,8 @@
 	return downloader.ArtifactExists(ctx, prebuiltModulesArtifactPath)
 }
 
-// getGenericDriverInstallerGCSVariables returns the GCSBucket, GCSPath, errorMsg for the NVIDIA driver installer.
-func getGenericDriverInstallerGCSVariables(driverVersion, bucketNvidia, prefixNvidia string) (string, string, error) {
-	metadataZone, err := utils.GetGCEMetadata("zone")
-	if err != nil {
-		return "", "", errors.Wrap(err, "failed to get GCE metadata zone")
-	}
-	downloadLocation := getInstallerDownloadLocation(metadataZone)
-
-	gcsBucketNvidia := fmt.Sprintf(gcsBucketNvidiaTemplate, downloadLocation)
-	gcsPrefixNvidia := fmt.Sprintf(gcsPrefixNvidiaTemplate, driverVersion)
-	if bucketNvidia != "" {
-		gcsBucketNvidia = bucketNvidia
-		gcsPrefixNvidia = prefixNvidia
-	}
-	runFileName := fmt.Sprintf(nvidiaRunFileTemplate, driverVersion)
-	gcsPathNvidia := filepath.Join(gcsPrefixNvidia, runFileName)
-	return gcsBucketNvidia, gcsPathNvidia, nil
-}
-
 // DownloadGenericDriverInstaller downloads the generic GPU driver installer given driver version.
-func DownloadGenericDriverInstaller(ctx context.Context, downloader *cos.GCSDownloader, driverVersion, bucketNvidia, prefixNvidia, installerDir string) (string, error) {
+func DownloadGenericDriverInstaller(ctx context.Context, downloader *cos.GCSDownloader, driverVersion string) (string, error) {
 	log.Infof("Downloading GPU driver installer version %s", driverVersion)
-	gcsClient, err := downloader.GetGCSClient(ctx)
-	if err != nil {
-		return "", err
-	}
-	gcsBucketNvidia, gcsPathNvidia, err := getGenericDriverInstallerGCSVariables(driverVersion, bucketNvidia, prefixNvidia)
-	if err != nil {
-		return "", fmt.Errorf("failed to get generic NVIDIA driver installer GCS variables with error: %v", err)
-	}
-	if installerDir == "" {
-		installerDir = gpuInstallDirContainer
-	}
-	outputPath := filepath.Join(installerDir, path.Base(gcsPathNvidia))
-	err = gcs.DownloadGCSObject(ctx, gcsClient, gcsBucketNvidia, gcsPathNvidia, outputPath)
-	if err != nil {
-		return "", fmt.Errorf("failed to download gpu driver installer file from bucket: %s object path: %s with error: %v", gcsBucketNvidia, gcsPathNvidia, err)
-	}
-	return path.Base(gcsPathNvidia), nil
+	return downloader.DownloadGenericNvidiaDriver(ctx, gpuInstallDirContainer, driverVersion)
 }
diff --git a/src/cmd/cos_gpu_installer/internal/installer/installer_test.go b/src/cmd/cos_gpu_installer/internal/installer/installer_test.go
index 64d511d..597faa3 100644
--- a/src/cmd/cos_gpu_installer/internal/installer/installer_test.go
+++ b/src/cmd/cos_gpu_installer/internal/installer/installer_test.go
@@ -3,7 +3,6 @@
 import (
 	"bytes"
 	"context"
-	"fmt"
 	"os"
 	"path"
 	"path/filepath"
@@ -15,204 +14,6 @@
 	"github.com/golang/protobuf/proto"
 )
 
-func TestGetInstallerDownloadLocation(t *testing.T) {
-	for _, tc := range []struct {
-		testName         string
-		metadataZone     string
-		expectedLocation string
-	}{
-		{
-			"us-west1-b",
-			"projects/123456789/zones/us-west1-b",
-			"us",
-		},
-		{
-			"asia-east1-a",
-			"projects/123456789/zones/asia-east1-a",
-			"asia",
-		},
-		{
-			"europe-west1-b",
-			"projects/123456789/zones/europe-west1-b",
-			"eu",
-		},
-		{
-			"australia-southeast1-a",
-			"projects/123456789/zones/australia-southeast1-a",
-			"us",
-		},
-	} {
-		location := getInstallerDownloadLocation(tc.metadataZone)
-		if location != tc.expectedLocation {
-			t.Errorf("%s: expect location: %s, got: %s", tc.testName, tc.expectedLocation, location)
-		}
-	}
-}
-
-func TestGetGenericDriverInstallerGCSVariables(t *testing.T) {
-
-	var testCases = []struct {
-		driverVersion  string
-		bucketNvidia   string
-		prefixNvidia   string
-		expectedBucket string
-		expectedPath   string
-	}{
-		{
-			"535.125.09",
-			"",
-			"",
-			"nvidia-drivers-us-public",
-			"tesla/535.125.09/NVIDIA-Linux-x86_64-535.125.09.run",
-		},
-		{
-			"535.125.09",
-			"cos-nvidia-bucket",
-			"tesla/temp",
-			"cos-nvidia-bucket",
-			"tesla/temp/NVIDIA-Linux-x86_64-535.125.09.run",
-		},
-		{
-			"470.129.23",
-			"",
-			"tesla/temp",
-			"nvidia-drivers-us-public",
-			"tesla/470.129.23/NVIDIA-Linux-x86_64-470.129.23.run",
-		},
-		{
-			"470.129.23",
-			"cos-nvidia-bucket",
-			"",
-			"cos-nvidia-bucket",
-			"NVIDIA-Linux-x86_64-470.129.23.run",
-		},
-	}
-	for index, tc := range testCases {
-		t.Run(fmt.Sprintf("Test %d: TestGetGenericDriverInstallerGCSVariables", index), func(t *testing.T) {
-			actualBucket, actualPath, err := getGenericDriverInstallerGCSVariables(tc.driverVersion, tc.bucketNvidia, tc.prefixNvidia)
-			if err != nil {
-				t.Errorf("Unexpected err, want: nil, got: %v", err)
-			}
-			if actualBucket != tc.expectedBucket {
-				t.Errorf("Unexpected bucket result, want: %s, got: %s", tc.expectedBucket, actualBucket)
-			}
-			if actualPath != tc.expectedPath {
-				t.Errorf("Unexpected aritifact path result, want: %s, got: %s", tc.expectedPath, actualPath)
-			}
-		})
-	}
-}
-
-func TestDownloadGenericDriverInstaller(t *testing.T) {
-	fakeGCS := fakes.GCSForTest(t)
-	fakeBucket := "cos-tools"
-	fakePrefix := "10000.00.00/lakitu"
-	fakeGCSClient := fakeGCS.Client
-	ctx := context.Background()
-	tempDir, err := os.MkdirTemp("", "installerDir")
-	if err != nil {
-		t.Fatalf("Failed to create tempdir: %v", err)
-	}
-	defer os.RemoveAll(tempDir)
-	var mockData = []struct {
-		bucket     string
-		prefix     string
-		objectName string
-		content    string
-	}{
-		{
-			"nvidia-drivers-us-public",
-			"tesla/550.121.43",
-			"NVIDIA-Linux-x86_64-550.121.43.run",
-			"This is 550.121.43 run file content",
-		},
-		{
-			"cos-tools",
-			"10000.00.00/lakitu",
-			"NVIDIA-Linux-x86_64-535.125.43.run",
-			"This is 535.125.43 run file content",
-		},
-		{
-			"testBucket",
-			"testPrefix",
-			"NVIDIA-Linux-x86_64-470.121.43.run",
-			"This is 470.121.43 run file content",
-		},
-		{
-			"testBucket1",
-			"",
-			"NVIDIA-Linux-x86_64-525.121.43.run",
-			"This is 525.121.43 run file content",
-		},
-	}
-	for _, testData := range mockData {
-		fakeGCS.Objects[path.Join("/", testData.bucket, testData.prefix, testData.objectName)] = []byte(testData.content)
-	}
-	var testCases = []struct {
-		driverVersion   string
-		bucketNvidia    string
-		prefixNvidia    string
-		expectedContent string
-		expectedPath    string
-	}{
-		{
-			"535.125.43",
-			"cos-tools",
-			"10000.00.00/lakitu",
-			"This is 535.125.43 run file content",
-			"NVIDIA-Linux-x86_64-535.125.43.run",
-		},
-		{
-			"470.121.43",
-			"testBucket",
-			"testPrefix",
-			"This is 470.121.43 run file content",
-			"NVIDIA-Linux-x86_64-470.121.43.run",
-		},
-		{
-			"525.121.43",
-			"testBucket1",
-			"",
-			"This is 525.121.43 run file content",
-			"NVIDIA-Linux-x86_64-525.121.43.run",
-		},
-		{
-			"550.121.43",
-			"",
-			"",
-			"This is 550.121.43 run file content",
-			"NVIDIA-Linux-x86_64-550.121.43.run",
-		},
-		{
-			"550.121.43",
-			"",
-			"invalidPrefix",
-			"This is 550.121.43 run file content",
-			"NVIDIA-Linux-x86_64-550.121.43.run",
-		},
-	}
-	var FakeGCSDownloader = cos.NewGCSDownloader(fakeGCSClient, nil, fakeBucket, fakePrefix)
-	for index, tc := range testCases {
-		t.Run(fmt.Sprintf("Test %d: TestDownloadGenericDriverInstaller", index), func(t *testing.T) {
-			pathName, err := DownloadGenericDriverInstaller(ctx, FakeGCSDownloader, tc.driverVersion, tc.bucketNvidia, tc.prefixNvidia, tempDir)
-			if err != nil {
-				t.Errorf("Unexpected err, want: nil, got: %v", err)
-			}
-			if pathName != tc.expectedPath {
-				t.Errorf("Unexpected path result, want: %s, got: %s", tc.expectedPath, pathName)
-			}
-			runFileContent, err := os.ReadFile(filepath.Join(tempDir, pathName))
-			if err != nil {
-				t.Errorf("Unexpected err, want: nil, got: %v", err)
-			}
-			if string(runFileContent) != tc.expectedContent {
-				t.Errorf("Unexpected content, want: %s, got: %s", tc.expectedContent, string(runFileContent))
-			}
-		})
-	}
-
-}
-
 func TestDownloadGPUDriverVersionsProto(t *testing.T) {
 	fakeGCS := fakes.GCSForTest(t)
 	fakeBucket := "cos-tools"
@@ -268,7 +69,7 @@
 	}
 	defer os.RemoveAll(tempDir)
 	expectedFilePath := filepath.Join(tempDir, gpuDriverProtoBin)
-	var FakeGCSDownloader = cos.NewGCSDownloader(fakeGCSClient, nil, fakeBucket, fakePrefix)
+	var FakeGCSDownloader = cos.NewGCSDownloader(fakeGCSClient, nil, fakeBucket, fakePrefix, "", "")
 	actualData, err := DownloadGPUDriverVersionsProto(ctx, FakeGCSDownloader, tempDir)
 	if err != nil {
 		t.Fatalf("DownloadGPUDriverVersionsProto returned an error: %v", err)
diff --git a/src/pkg/cos/artifacts.go b/src/pkg/cos/artifacts.go
index 9a055d1..f1d46cb 100644
--- a/src/pkg/cos/artifacts.go
+++ b/src/pkg/cos/artifacts.go
@@ -18,26 +18,50 @@
 )
 
 const (
-	cosToolsGCS      = "cos-tools"
-	cosToolsGCSAsia  = "cos-tools-asia"
-	cosToolsGCSEU    = "cos-tools-eu"
-	kernelInfo       = "kernel_info"
-	kernelSrcArchive = "kernel-src.tar.gz"
-	kernelHeaders    = "kernel-headers.tgz"
-	toolchainURL     = "toolchain_url"
-	toolchainArchive = "toolchain.tar.xz"
-	toolchainEnv     = "toolchain_env"
-	crosKernelRepo   = "https://chromium.googlesource.com/chromiumos/third_party/kernel"
+	cosArtifacts                = "cos"
+	nvidiaArtifacts             = "nvidia"
+	cosToolsGCS                 = "cos-tools"
+	cosToolsGCSAsia             = "cos-tools-asia"
+	cosToolsGCSEU               = "cos-tools-eu"
+	nvidiaDriversGCS            = "cos-nvidia-gpu-drivers"
+	nvidiaDriversGCSAsia        = "cos-nvidia-gpu-drivers-asia"
+	nvidiaDriversGCSEU          = "cos-nvidia-gpu-drivers-eu"
+	kernelInfo                  = "kernel_info"
+	kernelSrcArchive            = "kernel-src.tar.gz"
+	kernelHeaders               = "kernel-headers.tgz"
+	toolchainURL                = "toolchain_url"
+	toolchainArchive            = "toolchain.tar.xz"
+	toolchainEnv                = "toolchain_env"
+	nvidiaGenericDriverTemplate = "NVIDIA-Linux-x86_64-%s.run"
+	crosKernelRepo              = "https://chromium.googlesource.com/chromiumos/third_party/kernel"
 )
 
-// Map VM zone prefix to specific cos-tools bucket for geo-redundancy.
-var cosToolsPrefixMap = map[string]string{
-	"us":           cosToolsGCS,
-	"northamerica": cosToolsGCS,
-	"southamerica": cosToolsGCS,
-	"europe":       cosToolsGCSEU,
-	"asia":         cosToolsGCSAsia,
-	"australia":    cosToolsGCSAsia,
+// Map VM zone prefix to GCS buckets for cos-tools and NVIDIA drivers for geo-redundancy.
+var geoBucketMap = map[string]map[string]string{
+	"us": {
+		cosArtifacts:    cosToolsGCS,
+		nvidiaArtifacts: nvidiaDriversGCS,
+	},
+	"northamerica": {
+		cosArtifacts:    cosToolsGCS,
+		nvidiaArtifacts: nvidiaDriversGCS,
+	},
+	"southamerica": {
+		cosArtifacts:    cosToolsGCS,
+		nvidiaArtifacts: nvidiaDriversGCS,
+	},
+	"europe": {
+		cosArtifacts:    cosToolsGCSEU,
+		nvidiaArtifacts: nvidiaDriversGCSEU,
+	},
+	"asia": {
+		cosArtifacts:    cosToolsGCSAsia,
+		nvidiaArtifacts: nvidiaDriversGCSAsia,
+	},
+	"australia": {
+		cosArtifacts:    cosToolsGCSAsia,
+		nvidiaArtifacts: nvidiaDriversGCSAsia,
+	},
 }
 
 // ArtifactsDownloader defines the interface to download COS artifacts.
@@ -45,8 +69,9 @@
 	DownloadKernelSrc(ctx context.Context, destDir string) error
 	DownloadToolchainEnv(ctx context.Context, destDir string) error
 	DownloadToolchain(ctx context.Context, destDir string) error
-	DownloadKernelHeaders(dctx context.Context, estDir string) error
+	DownloadKernelHeaders(ctx context.Context, destDir string) error
 	DownloadArtifact(ctx context.Context, destDir, artifact string) error
+	DownloadGenericNvidiaDriver(ctx context.Context, destDir, driverVersion string) (string, error)
 	GetArtifact(ctx context.Context, artifact string) ([]byte, error)
 	ArtifactExists(ctx context.Context, artifact string) (bool, error)
 	ListArtifacts(ctx context.Context, prefix string) ([]string, error)
@@ -54,38 +79,49 @@
 
 // GCSDownloader is the struct downloading COS artifacts from GCS bucket.
 type GCSDownloader struct {
-	gcsClient         *storage.Client
-	envReader         *EnvReader
-	gcsDownloadBucket string
-	gcsDownloadPrefix string
-	mutex             sync.Mutex
+	gcsClient               *storage.Client
+	envReader               *EnvReader
+	gcsDownloadBucket       string
+	gcsDownloadPrefix       string
+	gcsDownloadBucketNvidia string
+	gcsDownloadPrefixNvidia string
+	mutex                   sync.Mutex
 }
 
 // NewGCSDownloader creates a GCSDownloader instance.
-func NewGCSDownloader(gcsClient *storage.Client, e *EnvReader, bucket, prefix string) *GCSDownloader {
+func NewGCSDownloader(gcsClient *storage.Client, e *EnvReader, bucket, prefix, nvidiaBucket, nvidiaPrefix string) *GCSDownloader {
+	zone, err := metadata.Zone()
+	if err != nil {
+		glog.Warningf("failed to get zone from metadata, will use defaults, err: %v", err)
+	}
+
+	zonePrefix := "us" // Default to 'us' if zone is not available
+	if err == nil {
+		zonePrefix = strings.Split(zone, "-")[0]
+	}
+
+	// Get buckets based on zone, using defaults if necessary
+	buckets := geoBucketMap["us"] // Default to 'us' buckets
+	if zoneBuckets, found := geoBucketMap[zonePrefix]; found {
+		buckets = zoneBuckets
+	}
+
 	// If bucket is not set, use cos-tools, cos-tools-asia or cos-tools-eu
 	// according to the zone the VM is running in for geo-redundancy.
 	// If cannot fetch zone from metadata or get an unknown zone prefix,
 	// use cos-tools as the default GCS bucket.
 	if bucket == "" {
-		zone, err := metadata.Zone()
-		if err != nil {
-			glog.Warningf("failed to get zone from metadata, will use 'gs://cos-tools' as artifact bucket, err: %v", err)
-			bucket = cosToolsGCS
-		} else {
-			zonePrefix := strings.Split(zone, "-")[0]
-			if geoBucket, found := cosToolsPrefixMap[zonePrefix]; found {
-				bucket = geoBucket
-			} else {
-				bucket = cosToolsGCS
-			}
-		}
+		bucket = buckets[cosArtifacts]
 	}
 	// Use {build number}/{board} as the default GCS download prefix.
 	if prefix == "" {
 		prefix = path.Join(e.BuildNumber(), e.Board())
 	}
-	return &GCSDownloader{gcsClient: gcsClient, envReader: e, gcsDownloadBucket: bucket, gcsDownloadPrefix: prefix}
+
+	if nvidiaBucket == "" {
+		nvidiaBucket = buckets[nvidiaArtifacts]
+	}
+	return &GCSDownloader{gcsClient: gcsClient, envReader: e, gcsDownloadBucket: bucket, gcsDownloadPrefix: prefix, gcsDownloadBucketNvidia: nvidiaBucket, gcsDownloadPrefixNvidia: nvidiaPrefix}
 }
 
 // DownloadKernelSrc downloads COS kernel sources to destination directory.
@@ -130,23 +166,33 @@
 
 // DownloadArtifact downloads an artifact from the GCS prefix configured in GCSDownloader.
 func (d *GCSDownloader) DownloadArtifact(ctx context.Context, destDir, artifactPath string) error {
-	gcsPath := path.Join(d.gcsDownloadPrefix, artifactPath)
+	return d.downloadArtifact(ctx, destDir, d.gcsDownloadBucket, d.gcsDownloadPrefix, artifactPath)
+}
+
+func (d *GCSDownloader) downloadArtifact(ctx context.Context, destDir, bucket, prefix, artifactPath string) error {
+	gcsPath := path.Join(prefix, artifactPath)
 	filename := filepath.Base(gcsPath)
 	outputPath := filepath.Join(destDir, filename)
 	gcsClient, err := d.GetGCSClient(ctx)
 	if err != nil {
 		return err
 	}
-	glog.Infof("Start to download %s artifact from bucket: %s, object: %s to %s.", filename, d.gcsDownloadBucket, gcsPath, outputPath)
-	if err := gcs.DownloadGCSObject(ctx, gcsClient, d.gcsDownloadBucket, gcsPath, outputPath); err != nil {
-		glog.Errorf("Failed to download %s artifact from bucket: %s, object: %s to %s with error: %v.", filename, d.gcsDownloadBucket, gcsPath, outputPath, err)
-		return fmt.Errorf("failed to download %s artifact from bucket: %s, object: %s to %s with error: %w", filename, d.gcsDownloadBucket, gcsPath, outputPath, err)
+	glog.Infof("Start to download %s artifact from bucket: %s, object: %s to %s.", filename, bucket, gcsPath, outputPath)
+	if err := gcs.DownloadGCSObject(ctx, gcsClient, bucket, gcsPath, outputPath); err != nil {
+		glog.Errorf("Failed to download %s artifact from bucket: %s, object: %s to %s with error: %v.", filename, bucket, gcsPath, outputPath, err)
+		return fmt.Errorf("failed to download %s artifact from bucket: %s, object: %s to %s with error: %w", filename, bucket, gcsPath, outputPath, err)
 	}
-	glog.Infof("Sucessfully downloaded %s artifact from bucket: %s, object: %s to %s.", filename, d.gcsDownloadBucket, gcsPath, outputPath)
+	glog.Infof("Sucessfully downloaded %s artifact from bucket: %s, object: %s to %s.", filename, bucket, gcsPath, outputPath)
 	return nil
 
 }
 
+// DownloadNvidiaArtifact downloads a NVIDIA drivers from the NVIDIA GCS bucket and prefix configured in GCSDownloader.
+func (d *GCSDownloader) DownloadGenericNvidiaDriver(ctx context.Context, destDir, driverVersion string) (string, error) {
+	nvidiaArtifactPath := fmt.Sprintf(nvidiaGenericDriverTemplate, driverVersion)
+	return nvidiaArtifactPath, d.downloadArtifact(ctx, destDir, d.gcsDownloadBucketNvidia, d.gcsDownloadPrefixNvidia, nvidiaArtifactPath)
+}
+
 // ArtifactExists check whether the artifactpath exists.
 func (d *GCSDownloader) ArtifactExists(ctx context.Context, artifactPath string) (bool, error) {
 	gcsPath := path.Join(d.gcsDownloadPrefix, artifactPath)
diff --git a/src/pkg/cos/artifacts_test.go b/src/pkg/cos/artifacts_test.go
index 646237a..e624834 100644
--- a/src/pkg/cos/artifacts_test.go
+++ b/src/pkg/cos/artifacts_test.go
@@ -14,6 +14,92 @@
 	"github.com/google/go-cmp/cmp"
 )
 
+// TestNewGCSDownloader is testing the functionality to create a GCS downloader with appropriate buckets.
+func TestNewGCSDownloader(t *testing.T) {
+	fakeGCS := fakes.GCSForTest(t)
+	fakeGCSClient := fakeGCS.Client
+	fakeEnvReader := &EnvReader{}
+	fakeEnvReader.osRelease = map[string]string{"BUILD_ID": "10000.0.0"}
+	fakeEnvReader.lsbRelease = map[string]string{"CHROMEOS_RELEASE_BOARD": "lakitu"}
+	var testCases = []struct {
+		name                 string
+		bucket               string
+		prefix               string
+		bucketNvidia         string
+		prefixNvidia         string
+		expectedBucket       string
+		expectedPrefix       string
+		expectedBucketNvidia string
+		expectedPrefixNvidia string
+	}{
+		{
+			"all parameters specified",
+			"bucket",
+			"prefix",
+			"nvBucket",
+			"nvPrefix",
+			"bucket",
+			"prefix",
+			"nvBucket",
+			"nvPrefix",
+		},
+		{
+			"NVIDIA bucket default",
+			"bucket",
+			"prefix",
+			"",
+			"nvPrefix",
+			"bucket",
+			"prefix",
+			"cos-nvidia-gpu-drivers",
+			"nvPrefix",
+		},
+		{
+			"COS Tools prefix default",
+			"bucket",
+			"",
+			"nvBucket",
+			"nvPrefix",
+			"bucket",
+			"10000.0.0/lakitu",
+			"nvBucket",
+			"nvPrefix",
+		},
+		{
+			"COS Tools bucket default",
+			"",
+			"prefix",
+			"nvBucket",
+			"nvPrefix",
+			"cos-tools",
+			"prefix",
+			"nvBucket",
+			"nvPrefix",
+		},
+	}
+	for index, tc := range testCases {
+		t.Run(fmt.Sprintf("Test %d: NewGCSDownloader: %s", index, tc.name), func(t *testing.T) {
+			gcsDownloader := NewGCSDownloader(fakeGCSClient, fakeEnvReader, tc.bucket, tc.prefix, tc.bucketNvidia, tc.prefixNvidia)
+			if gcsDownloader == nil {
+				t.Fatalf("Failed to create GCSDownloader with %v", tc)
+			}
+			if !cmp.Equal(gcsDownloader.gcsDownloadBucket, tc.expectedBucket) {
+				t.Errorf("The expected bucket is: %s, but the actual bucket is: %s.", tc.expectedBucket, gcsDownloader.gcsDownloadBucket)
+			}
+			if !cmp.Equal(gcsDownloader.gcsDownloadPrefix, tc.expectedPrefix) {
+				t.Errorf("The expected prefix is: %s, but the actual prefix is: %s.", tc.expectedPrefix, gcsDownloader.gcsDownloadPrefix)
+			}
+			if !cmp.Equal(gcsDownloader.gcsDownloadBucketNvidia, tc.expectedBucketNvidia) {
+				t.Errorf("The expected NVIDIA bucket is: %s, but the actual NVIDIA bucket is: %s.", tc.expectedBucketNvidia, gcsDownloader.gcsDownloadBucketNvidia)
+			}
+			if !cmp.Equal(gcsDownloader.gcsDownloadPrefixNvidia, tc.expectedPrefixNvidia) {
+				t.Errorf("The expected NVIDIA prefix is: %s, but the actual NVIDIA prefix is: %s.", tc.expectedPrefixNvidia, gcsDownloader.gcsDownloadPrefixNvidia)
+			}
+
+		})
+	}
+}
+
 // TestDownloadArtifact is testing the functionality to download an artifact from the GCS prefix configured in GCSDownloader.
 func TestDownloadArtifact(t *testing.T) {
 	testDir, err := os.MkdirTemp("", "testdir")
@@ -43,7 +129,7 @@
 		fakeGCS.Objects[path.Join("/", fakeBucket, fakePrefix, data.artifactName)] = []byte(data.artifactContent)
 	}
 	defer fakeGCS.Close()
-	var FakeGCSDownloader = NewGCSDownloader(fakeGCSClient, nil, fakeBucket, fakePrefix)
+	var FakeGCSDownloader = NewGCSDownloader(fakeGCSClient, nil, fakeBucket, fakePrefix, "", "")
 	var testCases = []struct {
 		downloadDir  string
 		artifactName string
@@ -127,7 +213,7 @@
 		fakeGCS.Objects[path.Join("/", fakeBucket, fakePrefix, data.artifactName)] = []byte(data.artifactContent)
 	}
 	defer fakeGCS.Close()
-	var FakeGCSDownloader = NewGCSDownloader(fakeGCSClient, nil, fakeBucket, fakePrefix)
+	var FakeGCSDownloader = NewGCSDownloader(fakeGCSClient, nil, fakeBucket, fakePrefix, "", "")
 
 	var testCases = []struct {
 		artifactName string
@@ -192,7 +278,7 @@
 		fakeGCS.Objects[path.Join("/", fakeBucket, fakePrefix, data.artifactName)] = []byte(data.artifactContent)
 	}
 	defer fakeGCS.Close()
-	var FakeGCSDownloader = NewGCSDownloader(fakeGCSClient, nil, fakeBucket, fakePrefix)
+	var FakeGCSDownloader = NewGCSDownloader(fakeGCSClient, nil, fakeBucket, fakePrefix, "", "")
 
 	var testCases = []struct {
 		prefix   string
@@ -253,7 +339,7 @@
 		fakeGCS.Objects[path.Join("/", fakeBucket, fakePrefix, data.artifactName)] = []byte(data.artifactContent)
 	}
 	defer fakeGCS.Close()
-	var FakeGCSDownloader = NewGCSDownloader(fakeGCSClient, nil, fakeBucket, fakePrefix)
+	var FakeGCSDownloader = NewGCSDownloader(fakeGCSClient, nil, fakeBucket, fakePrefix, "", "")
 
 	var testCases = []struct {
 		url             string
diff --git a/src/pkg/cos/cos_test.go b/src/pkg/cos/cos_test.go
index 2457260..1377d1e 100644
--- a/src/pkg/cos/cos_test.go
+++ b/src/pkg/cos/cos_test.go
@@ -309,3 +309,7 @@
 func (*fakeDownloader) ListArtifacts(ctx context.Context, prefix string) ([]string, error) {
 	return nil, nil
 }
+
+func (*fakeDownloader) DownloadGenericNvidiaDriver(context.Context, string, string) (string, error) {
+	return "", nil
+}
diff --git a/src/pkg/gpuconfig/artifacts.go b/src/pkg/gpuconfig/artifacts.go
index 55e552c..942a2d8 100644
--- a/src/pkg/gpuconfig/artifacts.go
+++ b/src/pkg/gpuconfig/artifacts.go
@@ -89,3 +89,7 @@
 func (d *GPUArtifactsDownloader) ArtifactExists(ctx context.Context, artifactPath string) (bool, error) {
 	return false, fmt.Errorf("not implemented")
 }
+
+func (d *GPUArtifactsDownloader) DownloadGenericNvidiaDriver(ctx context.Context, destDir, driverVersion string) (string, error) {
+	return "", fmt.Errorf("not implemented")
+}