Adding `--gcs-download-bucket-nvidia` and `--gcs-download-prefix-nvidia` flags to the cos-gpu-installer.
This would allow users to customize the source location from which the NVIDIA installer run file is downloaded
BUG=b/335661885
TEST=test locally, cloudbuild test
Change-Id: I4200784168ed00e961367ac065fb34029a7f9522
Reviewed-on: https://cos-review.googlesource.com/c/cos/tools/+/69634
Cloud-Build: GCB Service account <228075978874@cloudbuild.gserviceaccount.com>
Reviewed-by: Robert Kolchmeyer <rkolchmeyer@google.com>
Tested-by: Shuo Yang <gshuoy@google.com>
diff --git a/src/cmd/cos_gpu_installer/internal/commands/install.go b/src/cmd/cos_gpu_installer/internal/commands/install.go
index 9b8e343..267a3ae 100644
--- a/src/cmd/cos_gpu_installer/internal/commands/install.go
+++ b/src/cmd/cos_gpu_installer/internal/commands/install.go
@@ -100,22 +100,24 @@
// InstallCommand is the subcommand to install GPU drivers.
type InstallCommand struct {
- driverVersion string
- hostInstallDir string
- forceFallback FallBackFlag
- unsignedDriver bool
- gcsDownloadBucket string
- gcsDownloadPrefix string
- nvidiaInstallerURL string
- signatureURL string
- targetGPU string
- debug bool
- test bool
- prepareBuildTools bool
- kernelOpen bool
- noVerify bool
- kernelModuleParams modules.ModuleParameters
- nvidiaInstallerURLOpen string
+ driverVersion string
+ hostInstallDir string
+ forceFallback FallBackFlag
+ unsignedDriver bool
+ gcsDownloadBucket string
+ gcsDownloadPrefix string
+ nvidiaInstallerURL string
+ signatureURL string
+ targetGPU string
+ debug bool
+ test bool
+ prepareBuildTools bool
+ kernelOpen bool
+ noVerify bool
+ kernelModuleParams modules.ModuleParameters
+ nvidiaInstallerURLOpen string
+ gcsDownloadBucketNvidia string
+ gcsDownloadPrefixNvidia string
}
// Name implements subcommands.Command.Name.
@@ -162,6 +164,12 @@
f.StringVar(&c.signatureURL, "signature-url", "",
"A URL to the driver signature. This flag can only be used together with `-test` and `-nvidia-installer-url` for for debugging and testing.")
f.StringVar(&c.nvidiaInstallerURLOpen, "nvidia-installer-url-open", "", "This can be used to specify the location of the GSP firmware and user-space NVIDIA GPU driver components from a corresponding driver release of the OSS kernel modules. This flag is only for debugging and testing.")
+ f.StringVar(&c.gcsDownloadBucketNvidia, "gcs-download-bucket-nvidia", "", "The GCS bucket containing NVIDIA generic driver packages. "+
+ "Must be used in conjunction with `-gcs-download-prefix-nvidia`."+
+ "If not set, then it would be one of 'nvidia-drivers-us-public', 'nvidia-drivers-asia-public', and 'nvidia-drivers-eu-public' based on where the VM is running.")
+ f.StringVar(&c.gcsDownloadPrefixNvidia, "gcs-download-prefix-nvidia", "", "The GCS prefix where NVIDIA generic driver packages are located within the bucket. "+
+ "Must be used in conjunction with `-gcs-download-bucket-nvidia`."+
+ "If not set, defaults to 'tesla/{driver-version}' (e.g., 'tesla/535.129.03').")
f.BoolVar(&c.debug, "debug", false,
"Enable debug mode.")
f.BoolVar(&c.test, "test", false,
@@ -189,6 +197,9 @@
if c.nvidiaInstallerURLOpen != "" && (c.driverVersion == "" || c.test == false) {
return stderrors.New("-nvidia-installer-url-open must be used with -test and -version")
}
+ if c.gcsDownloadPrefixNvidia != "" && c.gcsDownloadBucketNvidia == "" {
+ return stderrors.New("-gcs-download-prefix-nvidia and -gcs-download-bucket-nvidia must be set together")
+ }
return nil
}
@@ -500,7 +511,7 @@
var installerFile string
if c.nvidiaInstallerURLOpen == "" {
- installerFile, err = installer.DownloadGenericDriverInstaller(ctx, c.driverVersion, downloader)
+ installerFile, err = installer.DownloadGenericDriverInstaller(ctx, downloader, c.driverVersion, c.gcsDownloadBucketNvidia, c.gcsDownloadPrefixNvidia, "")
} else {
installerFile, err = installer.DownloadToInstallDir(c.nvidiaInstallerURLOpen, "Unofficial GPU driver installer")
}
diff --git a/src/cmd/cos_gpu_installer/internal/installer/installer.go b/src/cmd/cos_gpu_installer/internal/installer/installer.go
index c74e24f..280ab59 100644
--- a/src/cmd/cos_gpu_installer/internal/installer/installer.go
+++ b/src/cmd/cos_gpu_installer/internal/installer/installer.go
@@ -17,6 +17,7 @@
"cos.googlesource.com/cos/tools.git/src/cmd/cos_gpu_installer/internal/signing"
"cos.googlesource.com/cos/tools.git/src/pkg/cos"
+ "cos.googlesource.com/cos/tools.git/src/pkg/gcs"
"cos.googlesource.com/cos/tools.git/src/pkg/modules"
"cos.googlesource.com/cos/tools.git/src/pkg/utils"
@@ -35,11 +36,13 @@
defaultFilePermission = 0755
signedURLKey = "Expires"
prebuiltModuleTemplate = "nvidia-drivers-%s.tgz"
+ gcsBucketNvidiaTemplate = "nvidia-drivers-%s-public"
+ gcsPrefixNvidiaTemplate = "tesla/%s"
+ nvidiaRunFileTemplate = "NVIDIA-Linux-x86_64-%s.run"
DefaultVersion = "default"
LatestVersion = "latest"
MajorGPUDriverArtifactPrefix = "gpu_"
MajorGPUDriverArtifactSuffix = "_version"
- installerURLTemplate = "https://storage.googleapis.com/nvidia-drivers-%[1]s-public/tesla/%[2]s/NVIDIA-Linux-x86_64-%[2]s.run"
)
var (
@@ -728,27 +731,43 @@
return downloader.ArtifactExists(ctx, prebuiltModulesArtifactPath)
}
-func getGenericDriverInstallerURL(driverVersion string) (string, error) {
+// getGenericDriverInstallerGCSVariables returns the GCSBucket, GCSPath, errorMsg for the NVIDIA driver installer.
+func getGenericDriverInstallerGCSVariables(driverVersion, bucketNvidia, prefixNvidia string) (string, string, error) {
metadataZone, err := utils.GetGCEMetadata("zone")
if err != nil {
- return "", errors.Wrap(err, "failed to get GCE metadata zone")
+ return "", "", errors.Wrap(err, "failed to get GCE metadata zone")
}
downloadLocation := getInstallerDownloadLocation(metadataZone)
- return fmt.Sprintf(installerURLTemplate, downloadLocation, driverVersion), nil
+ gcsBucketNvidia := fmt.Sprintf(gcsBucketNvidiaTemplate, downloadLocation)
+ gcsPrefixNvidia := fmt.Sprintf(gcsPrefixNvidiaTemplate, driverVersion)
+ if bucketNvidia != "" {
+ gcsBucketNvidia = bucketNvidia
+ gcsPrefixNvidia = prefixNvidia
+ }
+ runFileName := fmt.Sprintf(nvidiaRunFileTemplate, driverVersion)
+ gcsPathNvidia := filepath.Join(gcsPrefixNvidia, runFileName)
+ return gcsBucketNvidia, gcsPathNvidia, nil
}
// DownloadGenericDriverInstaller downloads the generic GPU driver installer given driver version.
-func DownloadGenericDriverInstaller(ctx context.Context, driverVersion string, downloader *cos.GCSDownloader) (string, error) {
+func DownloadGenericDriverInstaller(ctx context.Context, downloader *cos.GCSDownloader, driverVersion, bucketNvidia, prefixNvidia, installerDir string) (string, error) {
log.Infof("Downloading GPU driver installer version %s", driverVersion)
- downloadURL, err := getGenericDriverInstallerURL(driverVersion)
+ gcsClient, err := downloader.GetGCSClient(ctx)
if err != nil {
- return "", fmt.Errorf("failed to get driver installer URL: %v", err)
+ return "", err
}
- outputPath := filepath.Join(gpuInstallDirContainer, path.Base(downloadURL))
- err = downloader.DownloadArtifactFromURL(ctx, downloadURL, outputPath)
+ gcsBucketNvidia, gcsPathNvidia, err := getGenericDriverInstallerGCSVariables(driverVersion, bucketNvidia, prefixNvidia)
if err != nil {
- return "", fmt.Errorf("failed to download gpu driver installer file from %s with error: %v", downloadURL, err)
+ return "", fmt.Errorf("failed to get generic NVIDIA driver installer GCS variables with error: %v", err)
}
- return path.Base(downloadURL), nil
+ if installerDir == "" {
+ installerDir = gpuInstallDirContainer
+ }
+ outputPath := filepath.Join(installerDir, path.Base(gcsPathNvidia))
+ err = gcs.DownloadGCSObject(ctx, gcsClient, gcsBucketNvidia, gcsPathNvidia, outputPath)
+ if err != nil {
+ return "", fmt.Errorf("failed to download gpu driver installer file from bucket: %s object path: %s with error: %v", gcsBucketNvidia, gcsPathNvidia, err)
+ }
+ return path.Base(gcsPathNvidia), nil
}
diff --git a/src/cmd/cos_gpu_installer/internal/installer/installer_test.go b/src/cmd/cos_gpu_installer/internal/installer/installer_test.go
index ce946d6..93152e5 100644
--- a/src/cmd/cos_gpu_installer/internal/installer/installer_test.go
+++ b/src/cmd/cos_gpu_installer/internal/installer/installer_test.go
@@ -3,6 +3,7 @@
import (
"bytes"
"context"
+ "fmt"
"os"
"path"
"path/filepath"
@@ -56,15 +57,168 @@
}
}
-func TestGetGenericDriverInstallerURL(t *testing.T) {
- ret, err := getGenericDriverInstallerURL("525.125.06")
+func TestGetGenericDriverInstallerGCSVariables(t *testing.T) {
+
+ var testCases = []struct {
+ driverVersion string
+ bucketNvidia string
+ prefixNvidia string
+ expectedBucket string
+ expectedPath string
+ }{
+ {
+ "535.125.09",
+ "",
+ "",
+ "nvidia-drivers-us-public",
+ "tesla/535.125.09/NVIDIA-Linux-x86_64-535.125.09.run",
+ },
+ {
+ "535.125.09",
+ "cos-nvidia-bucket",
+ "tesla/temp",
+ "cos-nvidia-bucket",
+ "tesla/temp/NVIDIA-Linux-x86_64-535.125.09.run",
+ },
+ {
+ "470.129.23",
+ "",
+ "tesla/temp",
+ "nvidia-drivers-us-public",
+ "tesla/470.129.23/NVIDIA-Linux-x86_64-470.129.23.run",
+ },
+ {
+ "470.129.23",
+ "cos-nvidia-bucket",
+ "",
+ "cos-nvidia-bucket",
+ "NVIDIA-Linux-x86_64-470.129.23.run",
+ },
+ }
+ for index, tc := range testCases {
+ t.Run(fmt.Sprintf("Test %d: TestGetGenericDriverInstallerGCSVariables", index), func(t *testing.T) {
+ actualBucket, actualPath, err := getGenericDriverInstallerGCSVariables(tc.driverVersion, tc.bucketNvidia, tc.prefixNvidia)
+ if err != nil {
+ t.Errorf("Unexpected err, want: nil, got: %v", err)
+ }
+ if actualBucket != tc.expectedBucket {
+ t.Errorf("Unexpected bucket result, want: %s, got: %s", tc.expectedBucket, actualBucket)
+ }
+ if actualPath != tc.expectedPath {
+ t.Errorf("Unexpected aritifact path result, want: %s, got: %s", tc.expectedPath, actualPath)
+ }
+ })
+ }
+}
+
+func TestDownloadGenericDriverInstaller(t *testing.T) {
+ fakeGCS := fakes.GCSForTest(t)
+ fakeBucket := "cos-tools"
+ fakePrefix := "10000.00.00/lakitu"
+ fakeGCSClient := fakeGCS.Client
+ ctx := context.Background()
+ tempDir, err := os.MkdirTemp("", "installerDir")
if err != nil {
- t.Errorf("Unexpected err, want: nil, got: %v", err)
+ t.Fatalf("Failed to create tempdir: %v", err)
}
- expectedRet := "https://storage.googleapis.com/nvidia-drivers-us-public/tesla/525.125.06/NVIDIA-Linux-x86_64-525.125.06.run"
- if ret != expectedRet {
- t.Errorf("Unexpected return, want: %s, got: %s", expectedRet, ret)
+ defer os.RemoveAll(tempDir)
+ var mockData = []struct {
+ bucket string
+ prefix string
+ objectName string
+ content string
+ }{
+ {
+ "nvidia-drivers-us-public",
+ "tesla/550.121.43",
+ "NVIDIA-Linux-x86_64-550.121.43.run",
+ "This is 550.121.43 run file content",
+ },
+ {
+ "cos-tools",
+ "10000.00.00/lakitu",
+ "NVIDIA-Linux-x86_64-535.125.43.run",
+ "This is 535.125.43 run file content",
+ },
+ {
+ "testBucket",
+ "testPrefix",
+ "NVIDIA-Linux-x86_64-470.121.43.run",
+ "This is 470.121.43 run file content",
+ },
+ {
+ "testBucket1",
+ "",
+ "NVIDIA-Linux-x86_64-525.121.43.run",
+ "This is 525.121.43 run file content",
+ },
}
+ for _, testData := range mockData {
+ fakeGCS.Objects[path.Join("/", testData.bucket, testData.prefix, testData.objectName)] = []byte(testData.content)
+ }
+ var testCases = []struct {
+ driverVersion string
+ bucketNvidia string
+ prefixNvidia string
+ expectedContent string
+ expectedPath string
+ }{
+ {
+ "535.125.43",
+ "cos-tools",
+ "10000.00.00/lakitu",
+ "This is 535.125.43 run file content",
+ "NVIDIA-Linux-x86_64-535.125.43.run",
+ },
+ {
+ "470.121.43",
+ "testBucket",
+ "testPrefix",
+ "This is 470.121.43 run file content",
+ "NVIDIA-Linux-x86_64-470.121.43.run",
+ },
+ {
+ "525.121.43",
+ "testBucket1",
+ "",
+ "This is 525.121.43 run file content",
+ "NVIDIA-Linux-x86_64-525.121.43.run",
+ },
+ {
+ "550.121.43",
+ "",
+ "",
+ "This is 550.121.43 run file content",
+ "NVIDIA-Linux-x86_64-550.121.43.run",
+ },
+ {
+ "550.121.43",
+ "",
+ "invalidPrefix",
+ "This is 550.121.43 run file content",
+ "NVIDIA-Linux-x86_64-550.121.43.run",
+ },
+ }
+ var FakeGCSDownloader = cos.NewGCSDownloader(fakeGCSClient, nil, fakeBucket, fakePrefix)
+ for index, tc := range testCases {
+ t.Run(fmt.Sprintf("Test %d: TestDownloadGenericDriverInstaller", index), func(t *testing.T) {
+ pathName, err := DownloadGenericDriverInstaller(ctx, FakeGCSDownloader, tc.driverVersion, tc.bucketNvidia, tc.prefixNvidia, tempDir)
+ if err != nil {
+ t.Errorf("Unexpected err, want: nil, got: %v", err)
+ }
+ if pathName != tc.expectedPath {
+ t.Errorf("Unexpected path result, want: %s, got: %s", tc.expectedPath, pathName)
+ }
+ runFileContent, err := os.ReadFile(filepath.Join(tempDir, pathName))
+ if err != nil {
+ t.Errorf("Unexpected err, want: nil, got: %v", err)
+ }
+ if string(runFileContent) != tc.expectedContent {
+ t.Errorf("Unexpected content, want: %s, got: %s", tc.expectedContent, string(runFileContent))
+ }
+ })
+ }
+
}
func TestDownloadGPUDriverVersionsProto(t *testing.T) {
diff --git a/src/pkg/cos/artifacts.go b/src/pkg/cos/artifacts.go
index fcfb425..9a055d1 100644
--- a/src/pkg/cos/artifacts.go
+++ b/src/pkg/cos/artifacts.go
@@ -133,7 +133,7 @@
gcsPath := path.Join(d.gcsDownloadPrefix, artifactPath)
filename := filepath.Base(gcsPath)
outputPath := filepath.Join(destDir, filename)
- gcsClient, err := d.getGCSClient(ctx)
+ gcsClient, err := d.GetGCSClient(ctx)
if err != nil {
return err
}
@@ -151,7 +151,7 @@
func (d *GCSDownloader) ArtifactExists(ctx context.Context, artifactPath string) (bool, error) {
gcsPath := path.Join(d.gcsDownloadPrefix, artifactPath)
glog.Infof("Start to check whether artifact: %s exists in bucket: %s", gcsPath, d.gcsDownloadBucket)
- gcsClient, err := d.getGCSClient(ctx)
+ gcsClient, err := d.GetGCSClient(ctx)
if err != nil {
return false, err
}
@@ -168,7 +168,7 @@
var objects []string
var err error
gcsPath := path.Join(d.gcsDownloadPrefix, prefix)
- gcsClient, err := d.getGCSClient(ctx)
+ gcsClient, err := d.GetGCSClient(ctx)
if err != nil {
return nil, err
}
@@ -181,14 +181,14 @@
// DownloadArtifactFromURL will download the artifact from url and save it to the destinationPath.
func (d *GCSDownloader) DownloadArtifactFromURL(ctx context.Context, url string, destinationPath string) error {
- gcsClient, err := d.getGCSClient(ctx)
+ gcsClient, err := d.GetGCSClient(ctx)
if err != nil {
return err
}
return gcs.DownloadGCSObjectFromURL(ctx, gcsClient, url, destinationPath)
}
-func (d *GCSDownloader) getGCSClient(ctx context.Context) (*storage.Client, error) {
+func (d *GCSDownloader) GetGCSClient(ctx context.Context) (*storage.Client, error) {
glog.Infof("Start to fetch the GCSClient.")
d.mutex.Lock()
defer d.mutex.Unlock()
diff --git a/src/pkg/gcs/gcs_client.go b/src/pkg/gcs/gcs_client.go
index 4411de9..528d606 100644
--- a/src/pkg/gcs/gcs_client.go
+++ b/src/pkg/gcs/gcs_client.go
@@ -58,6 +58,7 @@
// DownloadGCSObject downloads the object at bucket, objectName and saves it at destinationPath
func DownloadGCSObject(ctx context.Context, gcsClient *storage.Client, gcsBucket, objectName, destinationPath string) error {
+ glog.Infof("Downloading gcs object from bucket: %s, object: %s to %s.", gcsBucket, objectName, destinationPath)
rc, err := gcsClient.Bucket(gcsBucket).Object(objectName).NewReader(ctx)
if err != nil {
return fmt.Errorf("failed to create the reader from GCS client: %w", err)