Add `--gcs-download-bucket-nvidia` and `--gcs-download-prefix-nvidia` flags to cos-gpu-installer.

This allows users to customize the GCS bucket and prefix from which the NVIDIA installer run file is downloaded.

BUG=b/335661885
TEST=test locally, cloudbuild test

Change-Id: I4200784168ed00e961367ac065fb34029a7f9522
Reviewed-on: https://cos-review.googlesource.com/c/cos/tools/+/69634
Cloud-Build: GCB Service account <228075978874@cloudbuild.gserviceaccount.com>
Reviewed-by: Robert Kolchmeyer <rkolchmeyer@google.com>
Tested-by: Shuo Yang <gshuoy@google.com>
diff --git a/src/cmd/cos_gpu_installer/internal/commands/install.go b/src/cmd/cos_gpu_installer/internal/commands/install.go
index 9b8e343..267a3ae 100644
--- a/src/cmd/cos_gpu_installer/internal/commands/install.go
+++ b/src/cmd/cos_gpu_installer/internal/commands/install.go
@@ -100,22 +100,24 @@
 
 // InstallCommand is the subcommand to install GPU drivers.
 type InstallCommand struct {
-	driverVersion          string
-	hostInstallDir         string
-	forceFallback          FallBackFlag
-	unsignedDriver         bool
-	gcsDownloadBucket      string
-	gcsDownloadPrefix      string
-	nvidiaInstallerURL     string
-	signatureURL           string
-	targetGPU              string
-	debug                  bool
-	test                   bool
-	prepareBuildTools      bool
-	kernelOpen             bool
-	noVerify               bool
-	kernelModuleParams     modules.ModuleParameters
-	nvidiaInstallerURLOpen string
+	driverVersion           string
+	hostInstallDir          string
+	forceFallback           FallBackFlag
+	unsignedDriver          bool
+	gcsDownloadBucket       string
+	gcsDownloadPrefix       string
+	nvidiaInstallerURL      string
+	signatureURL            string
+	targetGPU               string
+	debug                   bool
+	test                    bool
+	prepareBuildTools       bool
+	kernelOpen              bool
+	noVerify                bool
+	kernelModuleParams      modules.ModuleParameters
+	nvidiaInstallerURLOpen  string
+	gcsDownloadBucketNvidia string // Value of -gcs-download-bucket-nvidia: GCS bucket holding the NVIDIA generic driver packages.
+	gcsDownloadPrefixNvidia string // Value of -gcs-download-prefix-nvidia: prefix of those packages within the bucket.
 }
 
 // Name implements subcommands.Command.Name.
@@ -162,6 +164,12 @@
 	f.StringVar(&c.signatureURL, "signature-url", "",
 		"A URL to the driver signature. This flag can only be used together with `-test` and `-nvidia-installer-url` for for debugging and testing.")
 	f.StringVar(&c.nvidiaInstallerURLOpen, "nvidia-installer-url-open", "", "This can be used to specify the location of the GSP firmware and user-space NVIDIA GPU driver components from a corresponding driver release of the OSS kernel modules. This flag is only for debugging and testing.")
+	f.StringVar(&c.gcsDownloadBucketNvidia, "gcs-download-bucket-nvidia", "", "The GCS bucket containing NVIDIA generic driver packages. "+
+		"Must be used in conjunction with `-gcs-download-prefix-nvidia`. "+
+		"If not set, then it would be one of 'nvidia-drivers-us-public', 'nvidia-drivers-asia-public', and 'nvidia-drivers-eu-public' based on where the VM is running.")
+	f.StringVar(&c.gcsDownloadPrefixNvidia, "gcs-download-prefix-nvidia", "", "The GCS prefix where NVIDIA generic driver packages are located within the bucket. "+
+		"Must be used in conjunction with `-gcs-download-bucket-nvidia`. "+
+		"If not set, defaults to 'tesla/{driver-version}' (e.g., 'tesla/535.129.03').")
 	f.BoolVar(&c.debug, "debug", false,
 		"Enable debug mode.")
 	f.BoolVar(&c.test, "test", false,
@@ -189,6 +197,9 @@
 	if c.nvidiaInstallerURLOpen != "" && (c.driverVersion == "" || c.test == false) {
 		return stderrors.New("-nvidia-installer-url-open must be used with -test and -version")
 	}
+	if c.gcsDownloadPrefixNvidia != "" && c.gcsDownloadBucketNvidia == "" {
+		return stderrors.New("-gcs-download-prefix-nvidia and -gcs-download-bucket-nvidia must be set together")
+	}
 	return nil
 }
 
@@ -500,7 +511,7 @@
 
 	var installerFile string
 	if c.nvidiaInstallerURLOpen == "" {
-		installerFile, err = installer.DownloadGenericDriverInstaller(ctx, c.driverVersion, downloader)
+		installerFile, err = installer.DownloadGenericDriverInstaller(ctx, downloader, c.driverVersion, c.gcsDownloadBucketNvidia, c.gcsDownloadPrefixNvidia, "")
 	} else {
 		installerFile, err = installer.DownloadToInstallDir(c.nvidiaInstallerURLOpen, "Unofficial GPU driver installer")
 	}
diff --git a/src/cmd/cos_gpu_installer/internal/installer/installer.go b/src/cmd/cos_gpu_installer/internal/installer/installer.go
index c74e24f..280ab59 100644
--- a/src/cmd/cos_gpu_installer/internal/installer/installer.go
+++ b/src/cmd/cos_gpu_installer/internal/installer/installer.go
@@ -17,6 +17,7 @@
 
 	"cos.googlesource.com/cos/tools.git/src/cmd/cos_gpu_installer/internal/signing"
 	"cos.googlesource.com/cos/tools.git/src/pkg/cos"
+	"cos.googlesource.com/cos/tools.git/src/pkg/gcs"
 	"cos.googlesource.com/cos/tools.git/src/pkg/modules"
 	"cos.googlesource.com/cos/tools.git/src/pkg/utils"
 
@@ -35,11 +36,13 @@
 	defaultFilePermission         = 0755
 	signedURLKey                  = "Expires"
 	prebuiltModuleTemplate        = "nvidia-drivers-%s.tgz"
+	gcsBucketNvidiaTemplate       = "nvidia-drivers-%s-public"
+	gcsPrefixNvidiaTemplate       = "tesla/%s"
+	nvidiaRunFileTemplate         = "NVIDIA-Linux-x86_64-%s.run"
 	DefaultVersion                = "default"
 	LatestVersion                 = "latest"
 	MajorGPUDriverArtifactPrefix  = "gpu_"
 	MajorGPUDriverArtifactSuffix  = "_version"
-	installerURLTemplate          = "https://storage.googleapis.com/nvidia-drivers-%[1]s-public/tesla/%[2]s/NVIDIA-Linux-x86_64-%[2]s.run"
 )
 
 var (
@@ -728,27 +731,43 @@
 	return downloader.ArtifactExists(ctx, prebuiltModulesArtifactPath)
 }
 
-func getGenericDriverInstallerURL(driverVersion string) (string, error) {
+// getGenericDriverInstallerGCSVariables returns the GCS bucket and object path of the NVIDIA installer run file; a non-empty bucketNvidia overrides both the default bucket and the default prefix.
+func getGenericDriverInstallerGCSVariables(driverVersion, bucketNvidia, prefixNvidia string) (string, string, error) {
 	metadataZone, err := utils.GetGCEMetadata("zone")
 	if err != nil {
-		return "", errors.Wrap(err, "failed to get GCE metadata zone")
+		return "", "", errors.Wrap(err, "failed to get GCE metadata zone")
 	}
 	downloadLocation := getInstallerDownloadLocation(metadataZone)
 
-	return fmt.Sprintf(installerURLTemplate, downloadLocation, driverVersion), nil
+	gcsBucketNvidia := fmt.Sprintf(gcsBucketNvidiaTemplate, downloadLocation)
+	gcsPrefixNvidia := fmt.Sprintf(gcsPrefixNvidiaTemplate, driverVersion)
+	if bucketNvidia != "" {
+		gcsBucketNvidia = bucketNvidia
+		gcsPrefixNvidia = prefixNvidia
+	}
+	runFileName := fmt.Sprintf(nvidiaRunFileTemplate, driverVersion)
+	gcsPathNvidia := path.Join(gcsPrefixNvidia, runFileName) // GCS object names are always '/'-separated, so use path rather than filepath.
+	return gcsBucketNvidia, gcsPathNvidia, nil
 }
 
 // DownloadGenericDriverInstaller downloads the generic GPU driver installer given driver version.
-func DownloadGenericDriverInstaller(ctx context.Context, driverVersion string, downloader *cos.GCSDownloader) (string, error) {
+func DownloadGenericDriverInstaller(ctx context.Context, downloader *cos.GCSDownloader, driverVersion, bucketNvidia, prefixNvidia, installerDir string) (string, error) {
 	log.Infof("Downloading GPU driver installer version %s", driverVersion)
-	downloadURL, err := getGenericDriverInstallerURL(driverVersion)
+	gcsClient, err := downloader.GetGCSClient(ctx)
 	if err != nil {
-		return "", fmt.Errorf("failed to get driver installer URL: %v", err)
+		return "", err
 	}
-	outputPath := filepath.Join(gpuInstallDirContainer, path.Base(downloadURL))
-	err = downloader.DownloadArtifactFromURL(ctx, downloadURL, outputPath)
+	gcsBucketNvidia, gcsPathNvidia, err := getGenericDriverInstallerGCSVariables(driverVersion, bucketNvidia, prefixNvidia)
 	if err != nil {
-		return "", fmt.Errorf("failed to download gpu driver installer file from %s with error: %v", downloadURL, err)
+		return "", fmt.Errorf("failed to get generic NVIDIA driver installer GCS variables with error: %v", err)
 	}
-	return path.Base(downloadURL), nil
+	if installerDir == "" {
+		installerDir = gpuInstallDirContainer
+	}
+	outputPath := filepath.Join(installerDir, path.Base(gcsPathNvidia))
+	err = gcs.DownloadGCSObject(ctx, gcsClient, gcsBucketNvidia, gcsPathNvidia, outputPath)
+	if err != nil {
+		return "", fmt.Errorf("failed to download gpu driver installer file from bucket: %s object path: %s with error: %v", gcsBucketNvidia, gcsPathNvidia, err)
+	}
+	return path.Base(gcsPathNvidia), nil
 }
diff --git a/src/cmd/cos_gpu_installer/internal/installer/installer_test.go b/src/cmd/cos_gpu_installer/internal/installer/installer_test.go
index ce946d6..93152e5 100644
--- a/src/cmd/cos_gpu_installer/internal/installer/installer_test.go
+++ b/src/cmd/cos_gpu_installer/internal/installer/installer_test.go
@@ -3,6 +3,7 @@
 import (
 	"bytes"
 	"context"
+	"fmt"
 	"os"
 	"path"
 	"path/filepath"
@@ -56,15 +57,168 @@
 	}
 }
 
-func TestGetGenericDriverInstallerURL(t *testing.T) {
-	ret, err := getGenericDriverInstallerURL("525.125.06")
+func TestGetGenericDriverInstallerGCSVariables(t *testing.T) {
+
+	var testCases = []struct {
+		driverVersion  string
+		bucketNvidia   string
+		prefixNvidia   string
+		expectedBucket string
+		expectedPath   string
+	}{
+		{
+			"535.125.09",
+			"",
+			"",
+			"nvidia-drivers-us-public",
+			"tesla/535.125.09/NVIDIA-Linux-x86_64-535.125.09.run",
+		},
+		{
+			"535.125.09",
+			"cos-nvidia-bucket",
+			"tesla/temp",
+			"cos-nvidia-bucket",
+			"tesla/temp/NVIDIA-Linux-x86_64-535.125.09.run",
+		},
+		{
+			"470.129.23",
+			"",
+			"tesla/temp",
+			"nvidia-drivers-us-public",
+			"tesla/470.129.23/NVIDIA-Linux-x86_64-470.129.23.run",
+		},
+		{
+			"470.129.23",
+			"cos-nvidia-bucket",
+			"",
+			"cos-nvidia-bucket",
+			"NVIDIA-Linux-x86_64-470.129.23.run",
+		},
+	}
+	for index, tc := range testCases {
+		t.Run(fmt.Sprintf("Test %d: TestGetGenericDriverInstallerGCSVariables", index), func(t *testing.T) {
+			actualBucket, actualPath, err := getGenericDriverInstallerGCSVariables(tc.driverVersion, tc.bucketNvidia, tc.prefixNvidia)
+			if err != nil {
+				t.Errorf("Unexpected err, want: nil, got: %v", err)
+			}
+			if actualBucket != tc.expectedBucket {
+				t.Errorf("Unexpected bucket result, want: %s, got: %s", tc.expectedBucket, actualBucket)
+			}
+			if actualPath != tc.expectedPath {
+				t.Errorf("Unexpected artifact path result, want: %s, got: %s", tc.expectedPath, actualPath)
+			}
+		})
+	}
+}
+
+func TestDownloadGenericDriverInstaller(t *testing.T) {
+	fakeGCS := fakes.GCSForTest(t)
+	fakeBucket := "cos-tools"
+	fakePrefix := "10000.00.00/lakitu"
+	fakeGCSClient := fakeGCS.Client
+	ctx := context.Background()
+	tempDir, err := os.MkdirTemp("", "installerDir")
 	if err != nil {
-		t.Errorf("Unexpected err, want: nil, got: %v", err)
+		t.Fatalf("Failed to create tempdir: %v", err)
 	}
-	expectedRet := "https://storage.googleapis.com/nvidia-drivers-us-public/tesla/525.125.06/NVIDIA-Linux-x86_64-525.125.06.run"
-	if ret != expectedRet {
-		t.Errorf("Unexpected return, want: %s, got: %s", expectedRet, ret)
+	defer os.RemoveAll(tempDir)
+	var mockData = []struct {
+		bucket     string
+		prefix     string
+		objectName string
+		content    string
+	}{
+		{
+			"nvidia-drivers-us-public",
+			"tesla/550.121.43",
+			"NVIDIA-Linux-x86_64-550.121.43.run",
+			"This is 550.121.43 run file content",
+		},
+		{
+			"cos-tools",
+			"10000.00.00/lakitu",
+			"NVIDIA-Linux-x86_64-535.125.43.run",
+			"This is 535.125.43 run file content",
+		},
+		{
+			"testBucket",
+			"testPrefix",
+			"NVIDIA-Linux-x86_64-470.121.43.run",
+			"This is 470.121.43 run file content",
+		},
+		{
+			"testBucket1",
+			"",
+			"NVIDIA-Linux-x86_64-525.121.43.run",
+			"This is 525.121.43 run file content",
+		},
 	}
+	for _, testData := range mockData {
+		fakeGCS.Objects[path.Join("/", testData.bucket, testData.prefix, testData.objectName)] = []byte(testData.content)
+	}
+	var testCases = []struct {
+		driverVersion   string
+		bucketNvidia    string
+		prefixNvidia    string
+		expectedContent string
+		expectedPath    string
+	}{
+		{
+			"535.125.43",
+			"cos-tools",
+			"10000.00.00/lakitu",
+			"This is 535.125.43 run file content",
+			"NVIDIA-Linux-x86_64-535.125.43.run",
+		},
+		{
+			"470.121.43",
+			"testBucket",
+			"testPrefix",
+			"This is 470.121.43 run file content",
+			"NVIDIA-Linux-x86_64-470.121.43.run",
+		},
+		{
+			"525.121.43",
+			"testBucket1",
+			"",
+			"This is 525.121.43 run file content",
+			"NVIDIA-Linux-x86_64-525.121.43.run",
+		},
+		{
+			"550.121.43",
+			"",
+			"",
+			"This is 550.121.43 run file content",
+			"NVIDIA-Linux-x86_64-550.121.43.run",
+		},
+		{
+			"550.121.43",
+			"",
+			"invalidPrefix",
+			"This is 550.121.43 run file content",
+			"NVIDIA-Linux-x86_64-550.121.43.run",
+		},
+	}
+	var FakeGCSDownloader = cos.NewGCSDownloader(fakeGCSClient, nil, fakeBucket, fakePrefix)
+	for index, tc := range testCases {
+		t.Run(fmt.Sprintf("Test %d: TestDownloadGenericDriverInstaller", index), func(t *testing.T) {
+			pathName, err := DownloadGenericDriverInstaller(ctx, FakeGCSDownloader, tc.driverVersion, tc.bucketNvidia, tc.prefixNvidia, tempDir)
+			if err != nil {
+				t.Errorf("Unexpected err, want: nil, got: %v", err)
+			}
+			if pathName != tc.expectedPath {
+				t.Errorf("Unexpected path result, want: %s, got: %s", tc.expectedPath, pathName)
+			}
+			runFileContent, err := os.ReadFile(filepath.Join(tempDir, pathName))
+			if err != nil {
+				t.Errorf("Unexpected err, want: nil, got: %v", err)
+			}
+			if string(runFileContent) != tc.expectedContent {
+				t.Errorf("Unexpected content, want: %s, got: %s", tc.expectedContent, string(runFileContent))
+			}
+		})
+	}
+
 }
 
 func TestDownloadGPUDriverVersionsProto(t *testing.T) {
diff --git a/src/pkg/cos/artifacts.go b/src/pkg/cos/artifacts.go
index fcfb425..9a055d1 100644
--- a/src/pkg/cos/artifacts.go
+++ b/src/pkg/cos/artifacts.go
@@ -133,7 +133,7 @@
 	gcsPath := path.Join(d.gcsDownloadPrefix, artifactPath)
 	filename := filepath.Base(gcsPath)
 	outputPath := filepath.Join(destDir, filename)
-	gcsClient, err := d.getGCSClient(ctx)
+	gcsClient, err := d.GetGCSClient(ctx)
 	if err != nil {
 		return err
 	}
@@ -151,7 +151,7 @@
 func (d *GCSDownloader) ArtifactExists(ctx context.Context, artifactPath string) (bool, error) {
 	gcsPath := path.Join(d.gcsDownloadPrefix, artifactPath)
 	glog.Infof("Start to check whether artifact: %s exists in bucket: %s", gcsPath, d.gcsDownloadBucket)
-	gcsClient, err := d.getGCSClient(ctx)
+	gcsClient, err := d.GetGCSClient(ctx)
 	if err != nil {
 		return false, err
 	}
@@ -168,7 +168,7 @@
 	var objects []string
 	var err error
 	gcsPath := path.Join(d.gcsDownloadPrefix, prefix)
-	gcsClient, err := d.getGCSClient(ctx)
+	gcsClient, err := d.GetGCSClient(ctx)
 	if err != nil {
 		return nil, err
 	}
@@ -181,14 +181,14 @@
 
 // DownloadArtifactFromURL will download the artifact from url and save it to the destinationPath.
 func (d *GCSDownloader) DownloadArtifactFromURL(ctx context.Context, url string, destinationPath string) error {
-	gcsClient, err := d.getGCSClient(ctx)
+	gcsClient, err := d.GetGCSClient(ctx)
 	if err != nil {
 		return err
 	}
 	return gcs.DownloadGCSObjectFromURL(ctx, gcsClient, url, destinationPath)
 }
 
-func (d *GCSDownloader) getGCSClient(ctx context.Context) (*storage.Client, error) {
+func (d *GCSDownloader) GetGCSClient(ctx context.Context) (*storage.Client, error) {
 	glog.Infof("Start to fetch the GCSClient.")
 	d.mutex.Lock()
 	defer d.mutex.Unlock()
diff --git a/src/pkg/gcs/gcs_client.go b/src/pkg/gcs/gcs_client.go
index 4411de9..528d606 100644
--- a/src/pkg/gcs/gcs_client.go
+++ b/src/pkg/gcs/gcs_client.go
@@ -58,6 +58,7 @@
 
 // DownloadGCSObject downloads the object at bucket, objectName and saves it at destinationPath
 func DownloadGCSObject(ctx context.Context, gcsClient *storage.Client, gcsBucket, objectName, destinationPath string) error {
+	glog.Infof("Downloading gcs object from bucket: %s, object: %s to %s.", gcsBucket, objectName, destinationPath)
 	rc, err := gcsClient.Bucket(gcsBucket).Object(objectName).NewReader(ctx)
 	if err != nil {
 		return fmt.Errorf("failed to create the reader from GCS client: %w", err)