cos-gpu-installer: Switch to using cos-nvidia-gpu-drivers
Use cos-nvidia-gpu-drivers(-asia/eu) for sourcing generic datacenter
Linux NVIDIA drivers used for installing matching userspace components
for OSS kernel modules.
BUG=b/335514880
Change-Id: I752cdcd5bdd8abc5421f91add46f56b3415da80f
Reviewed-on: https://cos-review.googlesource.com/c/cos/tools/+/71059
Tested-by: Arnav Kansal <rnv@google.com>
Reviewed-by: Robert Kolchmeyer <rkolchmeyer@google.com>
Cloud-Build: GCB Service account <228075978874@cloudbuild.gserviceaccount.com>
Reviewed-by: Shuo Yang <gshuoy@google.com>
diff --git a/src/cmd/cos_gpu_installer/internal/commands/install.go b/src/cmd/cos_gpu_installer/internal/commands/install.go
index 267a3ae..f9ad8f0 100644
--- a/src/cmd/cos_gpu_installer/internal/commands/install.go
+++ b/src/cmd/cos_gpu_installer/internal/commands/install.go
@@ -166,10 +166,10 @@
f.StringVar(&c.nvidiaInstallerURLOpen, "nvidia-installer-url-open", "", "This can be used to specify the location of the GSP firmware and user-space NVIDIA GPU driver components from a corresponding driver release of the OSS kernel modules. This flag is only for debugging and testing.")
f.StringVar(&c.gcsDownloadBucketNvidia, "gcs-download-bucket-nvidia", "", "The GCS bucket containing NVIDIA generic driver packages. "+
"Must be used in conjunction with `-gcs-download-prefix-nvidia`."+
- "If not set, then it would be one of 'nvidia-drivers-us-public', 'nvidia-drivers-asia-public', and 'nvidia-drivers-eu-public' based on where the VM is running.")
+ "If not set, then it would be one of 'cos-nvidia-gpu-drivers', 'cos-nvidia-gpu-drivers-asia', and 'cos-nvidia-gpu-drivers-eu' based on where the VM is running.")
f.StringVar(&c.gcsDownloadPrefixNvidia, "gcs-download-prefix-nvidia", "", "The GCS prefix where NVIDIA generic driver packages are located within the bucket. "+
"Must be used in conjunction with `-gcs-download-bucket-nvidia`."+
- "If not set, defaults to 'tesla/{driver-version}' (e.g., 'tesla/535.129.03').")
+ "If not set, defaults to ''.")
f.BoolVar(&c.debug, "debug", false,
"Enable debug mode.")
f.BoolVar(&c.test, "test", false,
@@ -263,7 +263,7 @@
}
hostInstallDir := filepath.Join(hostRootPath, c.hostInstallDir)
- downloader := cos.NewGCSDownloader(nil, envReader, c.gcsDownloadBucket, c.gcsDownloadPrefix)
+ downloader := cos.NewGCSDownloader(nil, envReader, c.gcsDownloadBucket, c.gcsDownloadPrefix, c.gcsDownloadBucketNvidia, c.gcsDownloadPrefixNvidia)
if c.nvidiaInstallerURL == "" {
versionInput := c.driverVersion
useFallback := useFallbackMechanism(c)
@@ -511,7 +511,7 @@
var installerFile string
if c.nvidiaInstallerURLOpen == "" {
- installerFile, err = installer.DownloadGenericDriverInstaller(ctx, downloader, c.driverVersion, c.gcsDownloadBucketNvidia, c.gcsDownloadPrefixNvidia, "")
+ installerFile, err = installer.DownloadGenericDriverInstaller(ctx, downloader, c.driverVersion)
} else {
installerFile, err = installer.DownloadToInstallDir(c.nvidiaInstallerURLOpen, "Unofficial GPU driver installer")
}
diff --git a/src/cmd/cos_gpu_installer/internal/commands/list.go b/src/cmd/cos_gpu_installer/internal/commands/list.go
index 665053b..c7bdb78 100644
--- a/src/cmd/cos_gpu_installer/internal/commands/list.go
+++ b/src/cmd/cos_gpu_installer/internal/commands/list.go
@@ -68,7 +68,7 @@
log.Errorf("Unable to init feature flags: %v", err)
return subcommands.ExitFailure
}
- downloader := cos.NewGCSDownloader(nil, envReader, c.gcsDownloadBucket, c.gcsDownloadPrefix)
+ downloader := cos.NewGCSDownloader(nil, envReader, c.gcsDownloadBucket, c.gcsDownloadPrefix, "", "")
if featureConfig.PerGPULabelFeatureFlag {
var gpuProtoCacheDir = filepath.Join(hostRootPath, c.gpuProtoCacheDir)
if c.gpuProtoCacheDir == "" {
diff --git a/src/cmd/cos_gpu_installer/internal/installer/installer.go b/src/cmd/cos_gpu_installer/internal/installer/installer.go
index 02c1416..8d7f42d 100644
--- a/src/cmd/cos_gpu_installer/internal/installer/installer.go
+++ b/src/cmd/cos_gpu_installer/internal/installer/installer.go
@@ -17,7 +17,6 @@
"cos.googlesource.com/cos/tools.git/src/cmd/cos_gpu_installer/internal/signing"
"cos.googlesource.com/cos/tools.git/src/pkg/cos"
- "cos.googlesource.com/cos/tools.git/src/pkg/gcs"
"cos.googlesource.com/cos/tools.git/src/pkg/modules"
"cos.googlesource.com/cos/tools.git/src/pkg/utils"
@@ -35,9 +34,6 @@
defaultFilePermission = 0755
signedURLKey = "Expires"
prebuiltModuleTemplate = "nvidia-drivers-%s.tgz"
- gcsBucketNvidiaTemplate = "nvidia-drivers-%s-public"
- gcsPrefixNvidiaTemplate = "tesla/%s"
- nvidiaRunFileTemplate = "NVIDIA-Linux-x86_64-%s.run"
DefaultVersion = "default"
LatestVersion = "latest"
MajorGPUDriverArtifactPrefix = "gpu_"
@@ -506,21 +502,6 @@
return nil
}
-func getInstallerDownloadLocation(metadataZone string) string {
- fields := strings.Split(metadataZone, "/")
- zone := fields[len(fields)-1]
- locationMapping := map[string]string{
- "us": "us",
- "asia": "asia",
- "europe": "eu",
- }
- location, ok := locationMapping[strings.Split(zone, "-")[0]]
- if !ok {
- location = "us"
- }
- return location
-}
-
func createHostDirBindMount(hostDir, bindMountPath string) error {
if err := os.MkdirAll(hostDir, defaultFilePermission); err != nil {
return errors.Wrapf(err, "failed to create dir %s", hostDir)
@@ -700,43 +681,8 @@
return downloader.ArtifactExists(ctx, prebuiltModulesArtifactPath)
}
-// getGenericDriverInstallerGCSVariables returns the GCSBucket, GCSPath, errorMsg for the NVIDIA driver installer.
-func getGenericDriverInstallerGCSVariables(driverVersion, bucketNvidia, prefixNvidia string) (string, string, error) {
- metadataZone, err := utils.GetGCEMetadata("zone")
- if err != nil {
- return "", "", errors.Wrap(err, "failed to get GCE metadata zone")
- }
- downloadLocation := getInstallerDownloadLocation(metadataZone)
-
- gcsBucketNvidia := fmt.Sprintf(gcsBucketNvidiaTemplate, downloadLocation)
- gcsPrefixNvidia := fmt.Sprintf(gcsPrefixNvidiaTemplate, driverVersion)
- if bucketNvidia != "" {
- gcsBucketNvidia = bucketNvidia
- gcsPrefixNvidia = prefixNvidia
- }
- runFileName := fmt.Sprintf(nvidiaRunFileTemplate, driverVersion)
- gcsPathNvidia := filepath.Join(gcsPrefixNvidia, runFileName)
- return gcsBucketNvidia, gcsPathNvidia, nil
-}
-
// DownloadGenericDriverInstaller downloads the generic GPU driver installer given driver version.
-func DownloadGenericDriverInstaller(ctx context.Context, downloader *cos.GCSDownloader, driverVersion, bucketNvidia, prefixNvidia, installerDir string) (string, error) {
+func DownloadGenericDriverInstaller(ctx context.Context, downloader *cos.GCSDownloader, driverVersion string) (string, error) {
log.Infof("Downloading GPU driver installer version %s", driverVersion)
- gcsClient, err := downloader.GetGCSClient(ctx)
- if err != nil {
- return "", err
- }
- gcsBucketNvidia, gcsPathNvidia, err := getGenericDriverInstallerGCSVariables(driverVersion, bucketNvidia, prefixNvidia)
- if err != nil {
- return "", fmt.Errorf("failed to get generic NVIDIA driver installer GCS variables with error: %v", err)
- }
- if installerDir == "" {
- installerDir = gpuInstallDirContainer
- }
- outputPath := filepath.Join(installerDir, path.Base(gcsPathNvidia))
- err = gcs.DownloadGCSObject(ctx, gcsClient, gcsBucketNvidia, gcsPathNvidia, outputPath)
- if err != nil {
- return "", fmt.Errorf("failed to download gpu driver installer file from bucket: %s object path: %s with error: %v", gcsBucketNvidia, gcsPathNvidia, err)
- }
- return path.Base(gcsPathNvidia), nil
+ return downloader.DownloadGenericNvidiaDriver(ctx, gpuInstallDirContainer, driverVersion)
}
diff --git a/src/cmd/cos_gpu_installer/internal/installer/installer_test.go b/src/cmd/cos_gpu_installer/internal/installer/installer_test.go
index 64d511d..597faa3 100644
--- a/src/cmd/cos_gpu_installer/internal/installer/installer_test.go
+++ b/src/cmd/cos_gpu_installer/internal/installer/installer_test.go
@@ -3,7 +3,6 @@
import (
"bytes"
"context"
- "fmt"
"os"
"path"
"path/filepath"
@@ -15,204 +14,6 @@
"github.com/golang/protobuf/proto"
)
-func TestGetInstallerDownloadLocation(t *testing.T) {
- for _, tc := range []struct {
- testName string
- metadataZone string
- expectedLocation string
- }{
- {
- "us-west1-b",
- "projects/123456789/zones/us-west1-b",
- "us",
- },
- {
- "asia-east1-a",
- "projects/123456789/zones/asia-east1-a",
- "asia",
- },
- {
- "europe-west1-b",
- "projects/123456789/zones/europe-west1-b",
- "eu",
- },
- {
- "australia-southeast1-a",
- "projects/123456789/zones/australia-southeast1-a",
- "us",
- },
- } {
- location := getInstallerDownloadLocation(tc.metadataZone)
- if location != tc.expectedLocation {
- t.Errorf("%s: expect location: %s, got: %s", tc.testName, tc.expectedLocation, location)
- }
- }
-}
-
-func TestGetGenericDriverInstallerGCSVariables(t *testing.T) {
-
- var testCases = []struct {
- driverVersion string
- bucketNvidia string
- prefixNvidia string
- expectedBucket string
- expectedPath string
- }{
- {
- "535.125.09",
- "",
- "",
- "nvidia-drivers-us-public",
- "tesla/535.125.09/NVIDIA-Linux-x86_64-535.125.09.run",
- },
- {
- "535.125.09",
- "cos-nvidia-bucket",
- "tesla/temp",
- "cos-nvidia-bucket",
- "tesla/temp/NVIDIA-Linux-x86_64-535.125.09.run",
- },
- {
- "470.129.23",
- "",
- "tesla/temp",
- "nvidia-drivers-us-public",
- "tesla/470.129.23/NVIDIA-Linux-x86_64-470.129.23.run",
- },
- {
- "470.129.23",
- "cos-nvidia-bucket",
- "",
- "cos-nvidia-bucket",
- "NVIDIA-Linux-x86_64-470.129.23.run",
- },
- }
- for index, tc := range testCases {
- t.Run(fmt.Sprintf("Test %d: TestGetGenericDriverInstallerGCSVariables", index), func(t *testing.T) {
- actualBucket, actualPath, err := getGenericDriverInstallerGCSVariables(tc.driverVersion, tc.bucketNvidia, tc.prefixNvidia)
- if err != nil {
- t.Errorf("Unexpected err, want: nil, got: %v", err)
- }
- if actualBucket != tc.expectedBucket {
- t.Errorf("Unexpected bucket result, want: %s, got: %s", tc.expectedBucket, actualBucket)
- }
- if actualPath != tc.expectedPath {
- t.Errorf("Unexpected aritifact path result, want: %s, got: %s", tc.expectedPath, actualPath)
- }
- })
- }
-}
-
-func TestDownloadGenericDriverInstaller(t *testing.T) {
- fakeGCS := fakes.GCSForTest(t)
- fakeBucket := "cos-tools"
- fakePrefix := "10000.00.00/lakitu"
- fakeGCSClient := fakeGCS.Client
- ctx := context.Background()
- tempDir, err := os.MkdirTemp("", "installerDir")
- if err != nil {
- t.Fatalf("Failed to create tempdir: %v", err)
- }
- defer os.RemoveAll(tempDir)
- var mockData = []struct {
- bucket string
- prefix string
- objectName string
- content string
- }{
- {
- "nvidia-drivers-us-public",
- "tesla/550.121.43",
- "NVIDIA-Linux-x86_64-550.121.43.run",
- "This is 550.121.43 run file content",
- },
- {
- "cos-tools",
- "10000.00.00/lakitu",
- "NVIDIA-Linux-x86_64-535.125.43.run",
- "This is 535.125.43 run file content",
- },
- {
- "testBucket",
- "testPrefix",
- "NVIDIA-Linux-x86_64-470.121.43.run",
- "This is 470.121.43 run file content",
- },
- {
- "testBucket1",
- "",
- "NVIDIA-Linux-x86_64-525.121.43.run",
- "This is 525.121.43 run file content",
- },
- }
- for _, testData := range mockData {
- fakeGCS.Objects[path.Join("/", testData.bucket, testData.prefix, testData.objectName)] = []byte(testData.content)
- }
- var testCases = []struct {
- driverVersion string
- bucketNvidia string
- prefixNvidia string
- expectedContent string
- expectedPath string
- }{
- {
- "535.125.43",
- "cos-tools",
- "10000.00.00/lakitu",
- "This is 535.125.43 run file content",
- "NVIDIA-Linux-x86_64-535.125.43.run",
- },
- {
- "470.121.43",
- "testBucket",
- "testPrefix",
- "This is 470.121.43 run file content",
- "NVIDIA-Linux-x86_64-470.121.43.run",
- },
- {
- "525.121.43",
- "testBucket1",
- "",
- "This is 525.121.43 run file content",
- "NVIDIA-Linux-x86_64-525.121.43.run",
- },
- {
- "550.121.43",
- "",
- "",
- "This is 550.121.43 run file content",
- "NVIDIA-Linux-x86_64-550.121.43.run",
- },
- {
- "550.121.43",
- "",
- "invalidPrefix",
- "This is 550.121.43 run file content",
- "NVIDIA-Linux-x86_64-550.121.43.run",
- },
- }
- var FakeGCSDownloader = cos.NewGCSDownloader(fakeGCSClient, nil, fakeBucket, fakePrefix)
- for index, tc := range testCases {
- t.Run(fmt.Sprintf("Test %d: TestDownloadGenericDriverInstaller", index), func(t *testing.T) {
- pathName, err := DownloadGenericDriverInstaller(ctx, FakeGCSDownloader, tc.driverVersion, tc.bucketNvidia, tc.prefixNvidia, tempDir)
- if err != nil {
- t.Errorf("Unexpected err, want: nil, got: %v", err)
- }
- if pathName != tc.expectedPath {
- t.Errorf("Unexpected path result, want: %s, got: %s", tc.expectedPath, pathName)
- }
- runFileContent, err := os.ReadFile(filepath.Join(tempDir, pathName))
- if err != nil {
- t.Errorf("Unexpected err, want: nil, got: %v", err)
- }
- if string(runFileContent) != tc.expectedContent {
- t.Errorf("Unexpected content, want: %s, got: %s", tc.expectedContent, string(runFileContent))
- }
- })
- }
-
-}
-
func TestDownloadGPUDriverVersionsProto(t *testing.T) {
fakeGCS := fakes.GCSForTest(t)
fakeBucket := "cos-tools"
@@ -268,7 +69,7 @@
}
defer os.RemoveAll(tempDir)
expectedFilePath := filepath.Join(tempDir, gpuDriverProtoBin)
- var FakeGCSDownloader = cos.NewGCSDownloader(fakeGCSClient, nil, fakeBucket, fakePrefix)
+ var FakeGCSDownloader = cos.NewGCSDownloader(fakeGCSClient, nil, fakeBucket, fakePrefix, "", "")
actualData, err := DownloadGPUDriverVersionsProto(ctx, FakeGCSDownloader, tempDir)
if err != nil {
t.Fatalf("DownloadGPUDriverVersionsProto returned an error: %v", err)
diff --git a/src/pkg/cos/artifacts.go b/src/pkg/cos/artifacts.go
index 9a055d1..f1d46cb 100644
--- a/src/pkg/cos/artifacts.go
+++ b/src/pkg/cos/artifacts.go
@@ -18,26 +18,50 @@
)
const (
- cosToolsGCS = "cos-tools"
- cosToolsGCSAsia = "cos-tools-asia"
- cosToolsGCSEU = "cos-tools-eu"
- kernelInfo = "kernel_info"
- kernelSrcArchive = "kernel-src.tar.gz"
- kernelHeaders = "kernel-headers.tgz"
- toolchainURL = "toolchain_url"
- toolchainArchive = "toolchain.tar.xz"
- toolchainEnv = "toolchain_env"
- crosKernelRepo = "https://chromium.googlesource.com/chromiumos/third_party/kernel"
+ cosArtifacts = "cos"
+ nvidiaArtifacts = "nvidia"
+ cosToolsGCS = "cos-tools"
+ cosToolsGCSAsia = "cos-tools-asia"
+ cosToolsGCSEU = "cos-tools-eu"
+ nvidiaDriversGCS = "cos-nvidia-gpu-drivers"
+ nvidiaDriversGCSAsia = "cos-nvidia-gpu-drivers-asia"
+ nvidiaDriversGCSEU = "cos-nvidia-gpu-drivers-eu"
+ kernelInfo = "kernel_info"
+ kernelSrcArchive = "kernel-src.tar.gz"
+ kernelHeaders = "kernel-headers.tgz"
+ toolchainURL = "toolchain_url"
+ toolchainArchive = "toolchain.tar.xz"
+ toolchainEnv = "toolchain_env"
+ nvidiaGenericDriverTemplate = "NVIDIA-Linux-x86_64-%s.run"
+ crosKernelRepo = "https://chromium.googlesource.com/chromiumos/third_party/kernel"
)
-// Map VM zone prefix to specific cos-tools bucket for geo-redundancy.
-var cosToolsPrefixMap = map[string]string{
- "us": cosToolsGCS,
- "northamerica": cosToolsGCS,
- "southamerica": cosToolsGCS,
- "europe": cosToolsGCSEU,
- "asia": cosToolsGCSAsia,
- "australia": cosToolsGCSAsia,
+// Map VM zone prefix to GCS buckets for cos-tools and NVIDIA drivers for geo-redundancy.
+var geoBucketMap = map[string]map[string]string{
+ "us": {
+ cosArtifacts: cosToolsGCS,
+ nvidiaArtifacts: nvidiaDriversGCS,
+ },
+ "northamerica": {
+ cosArtifacts: cosToolsGCS,
+ nvidiaArtifacts: nvidiaDriversGCS,
+ },
+ "southamerica": {
+ cosArtifacts: cosToolsGCS,
+ nvidiaArtifacts: nvidiaDriversGCS,
+ },
+ "europe": {
+ cosArtifacts: cosToolsGCSEU,
+ nvidiaArtifacts: nvidiaDriversGCSEU,
+ },
+ "asia": {
+ cosArtifacts: cosToolsGCSAsia,
+ nvidiaArtifacts: nvidiaDriversGCSAsia,
+ },
+ "australia": {
+ cosArtifacts: cosToolsGCSAsia,
+ nvidiaArtifacts: nvidiaDriversGCSAsia,
+ },
}
// ArtifactsDownloader defines the interface to download COS artifacts.
@@ -45,8 +69,9 @@
DownloadKernelSrc(ctx context.Context, destDir string) error
DownloadToolchainEnv(ctx context.Context, destDir string) error
DownloadToolchain(ctx context.Context, destDir string) error
- DownloadKernelHeaders(dctx context.Context, estDir string) error
+ DownloadKernelHeaders(ctx context.Context, destDir string) error
DownloadArtifact(ctx context.Context, destDir, artifact string) error
+ DownloadGenericNvidiaDriver(ctx context.Context, destDir, driverVersion string) (string, error)
GetArtifact(ctx context.Context, artifact string) ([]byte, error)
ArtifactExists(ctx context.Context, artifact string) (bool, error)
ListArtifacts(ctx context.Context, prefix string) ([]string, error)
@@ -54,38 +79,49 @@
// GCSDownloader is the struct downloading COS artifacts from GCS bucket.
type GCSDownloader struct {
- gcsClient *storage.Client
- envReader *EnvReader
- gcsDownloadBucket string
- gcsDownloadPrefix string
- mutex sync.Mutex
+ gcsClient *storage.Client
+ envReader *EnvReader
+ gcsDownloadBucket string
+ gcsDownloadPrefix string
+ gcsDownloadBucketNvidia string
+ gcsDownloadPrefixNvidia string
+ mutex sync.Mutex
}
// NewGCSDownloader creates a GCSDownloader instance.
-func NewGCSDownloader(gcsClient *storage.Client, e *EnvReader, bucket, prefix string) *GCSDownloader {
+func NewGCSDownloader(gcsClient *storage.Client, e *EnvReader, bucket, prefix, nvidiaBucket, nvidiaPrefix string) *GCSDownloader {
+ zone, err := metadata.Zone()
+ if err != nil {
+ glog.Warningf("failed to get zone from metadata, will use defaults, err: %v", err)
+ }
+
+ zonePrefix := "us" // Default to 'us' if zone is not available
+ if err == nil {
+ zonePrefix = strings.Split(zone, "-")[0]
+ }
+
+ // Get buckets based on zone, using defaults if necessary
+ buckets := geoBucketMap["us"] // Default to 'us' buckets
+ if zoneBuckets, found := geoBucketMap[zonePrefix]; found {
+ buckets = zoneBuckets
+ }
+
// If bucket is not set, use cos-tools, cos-tools-asia or cos-tools-eu
// according to the zone the VM is running in for geo-redundancy.
// If cannot fetch zone from metadata or get an unknown zone prefix,
// use cos-tools as the default GCS bucket.
if bucket == "" {
- zone, err := metadata.Zone()
- if err != nil {
- glog.Warningf("failed to get zone from metadata, will use 'gs://cos-tools' as artifact bucket, err: %v", err)
- bucket = cosToolsGCS
- } else {
- zonePrefix := strings.Split(zone, "-")[0]
- if geoBucket, found := cosToolsPrefixMap[zonePrefix]; found {
- bucket = geoBucket
- } else {
- bucket = cosToolsGCS
- }
- }
+ bucket = buckets[cosArtifacts]
}
// Use {build number}/{board} as the default GCS download prefix.
if prefix == "" {
prefix = path.Join(e.BuildNumber(), e.Board())
}
- return &GCSDownloader{gcsClient: gcsClient, envReader: e, gcsDownloadBucket: bucket, gcsDownloadPrefix: prefix}
+
+ if nvidiaBucket == "" {
+ nvidiaBucket = buckets[nvidiaArtifacts]
+ }
+ return &GCSDownloader{gcsClient: gcsClient, envReader: e, gcsDownloadBucket: bucket, gcsDownloadPrefix: prefix, gcsDownloadBucketNvidia: nvidiaBucket, gcsDownloadPrefixNvidia: nvidiaPrefix}
}
// DownloadKernelSrc downloads COS kernel sources to destination directory.
@@ -130,23 +166,33 @@
// DownloadArtifact downloads an artifact from the GCS prefix configured in GCSDownloader.
func (d *GCSDownloader) DownloadArtifact(ctx context.Context, destDir, artifactPath string) error {
- gcsPath := path.Join(d.gcsDownloadPrefix, artifactPath)
+ return d.downloadArtifact(ctx, destDir, d.gcsDownloadBucket, d.gcsDownloadPrefix, artifactPath)
+}
+
+func (d *GCSDownloader) downloadArtifact(ctx context.Context, destDir, bucket, prefix, artifactPath string) error {
+ gcsPath := path.Join(prefix, artifactPath)
filename := filepath.Base(gcsPath)
outputPath := filepath.Join(destDir, filename)
gcsClient, err := d.GetGCSClient(ctx)
if err != nil {
return err
}
- glog.Infof("Start to download %s artifact from bucket: %s, object: %s to %s.", filename, d.gcsDownloadBucket, gcsPath, outputPath)
- if err := gcs.DownloadGCSObject(ctx, gcsClient, d.gcsDownloadBucket, gcsPath, outputPath); err != nil {
- glog.Errorf("Failed to download %s artifact from bucket: %s, object: %s to %s with error: %v.", filename, d.gcsDownloadBucket, gcsPath, outputPath, err)
- return fmt.Errorf("failed to download %s artifact from bucket: %s, object: %s to %s with error: %w", filename, d.gcsDownloadBucket, gcsPath, outputPath, err)
+ glog.Infof("Start to download %s artifact from bucket: %s, object: %s to %s.", filename, bucket, gcsPath, outputPath)
+ if err := gcs.DownloadGCSObject(ctx, gcsClient, bucket, gcsPath, outputPath); err != nil {
+ glog.Errorf("Failed to download %s artifact from bucket: %s, object: %s to %s with error: %v.", filename, bucket, gcsPath, outputPath, err)
+ return fmt.Errorf("failed to download %s artifact from bucket: %s, object: %s to %s with error: %w", filename, bucket, gcsPath, outputPath, err)
}
- glog.Infof("Sucessfully downloaded %s artifact from bucket: %s, object: %s to %s.", filename, d.gcsDownloadBucket, gcsPath, outputPath)
+ glog.Infof("Sucessfully downloaded %s artifact from bucket: %s, object: %s to %s.", filename, bucket, gcsPath, outputPath)
return nil
}
+// DownloadNvidiaArtifact downloads a NVIDIA drivers from the NVIDIA GCS bucket and prefix configured in GCSDownloader.
+func (d *GCSDownloader) DownloadGenericNvidiaDriver(ctx context.Context, destDir, driverVersion string) (string, error) {
+ nvidiaArtifactPath := fmt.Sprintf(nvidiaGenericDriverTemplate, driverVersion)
+ return nvidiaArtifactPath, d.downloadArtifact(ctx, destDir, d.gcsDownloadBucketNvidia, d.gcsDownloadPrefixNvidia, nvidiaArtifactPath)
+}
+
// ArtifactExists check whether the artifactpath exists.
func (d *GCSDownloader) ArtifactExists(ctx context.Context, artifactPath string) (bool, error) {
gcsPath := path.Join(d.gcsDownloadPrefix, artifactPath)
diff --git a/src/pkg/cos/artifacts_test.go b/src/pkg/cos/artifacts_test.go
index 646237a..e624834 100644
--- a/src/pkg/cos/artifacts_test.go
+++ b/src/pkg/cos/artifacts_test.go
@@ -14,6 +14,92 @@
"github.com/google/go-cmp/cmp"
)
+// TestNewGCSDownloader is testing the functionality to create a GCS downloader with appropriate buckets.
+func TestNewGCSDownloader(t *testing.T) {
+ fakeGCS := fakes.GCSForTest(t)
+ fakeGCSClient := fakeGCS.Client
+ fakeEnvReader := &EnvReader{}
+ fakeEnvReader.osRelease = map[string]string{"BUILD_ID": "10000.0.0"}
+ fakeEnvReader.lsbRelease = map[string]string{"CHROMEOS_RELEASE_BOARD": "lakitu"}
+ var testCases = []struct {
+ name string
+ bucket string
+ prefix string
+ bucketNvidia string
+ prefixNvidia string
+ expectedBucket string
+ expectedPrefix string
+ expectedBucketNvidia string
+ expectedPrefixNvidia string
+ }{
+ {
+ "all parameters specified",
+ "bucket",
+ "prefix",
+ "nvBucket",
+ "nvPrefix",
+ "bucket",
+ "prefix",
+ "nvBucket",
+ "nvPrefix",
+ },
+ {
+ "NVIDIA bucket default",
+ "bucket",
+ "prefix",
+ "",
+ "nvPrefix",
+ "bucket",
+ "prefix",
+ "cos-nvidia-gpu-drivers",
+ "nvPrefix",
+ },
+ {
+ "COS Tools prefix default",
+ "bucket",
+ "",
+ "nvBucket",
+ "nvPrefix",
+ "bucket",
+ "10000.0.0/lakitu",
+ "nvBucket",
+ "nvPrefix",
+ },
+ {
+ "COS Tools bucket default",
+ "",
+ "prefix",
+ "nvBucket",
+ "nvPrefix",
+ "cos-tools",
+ "prefix",
+ "nvBucket",
+ "nvPrefix",
+ },
+ }
+ for index, tc := range testCases {
+ t.Run(fmt.Sprintf("Test %d: NewGCSDownloader: %s", index, tc.name), func(t *testing.T) {
+ gcsDownloader := NewGCSDownloader(fakeGCSClient, fakeEnvReader, tc.bucket, tc.prefix, tc.bucketNvidia, tc.prefixNvidia)
+ if gcsDownloader == nil {
+ t.Fatalf("Failed to create GCSDownloader with %v", tc)
+ }
+ if !cmp.Equal(gcsDownloader.gcsDownloadBucket, tc.expectedBucket) {
+ t.Errorf("The expected bucket is: %s, but the actual bucket is: %s.", tc.expectedBucket, gcsDownloader.gcsDownloadBucket)
+ }
+ if !cmp.Equal(gcsDownloader.gcsDownloadPrefix, tc.expectedPrefix) {
+ t.Errorf("The expected prefix is: %s, but the actual prefix is: %s.", tc.expectedPrefix, gcsDownloader.gcsDownloadPrefix)
+ }
+ if !cmp.Equal(gcsDownloader.gcsDownloadBucketNvidia, tc.expectedBucketNvidia) {
+ t.Errorf("The expected NVIDIA bucket is: %s, but the actual NVIDIA bucket is: %s.", tc.expectedBucketNvidia, gcsDownloader.gcsDownloadBucketNvidia)
+ }
+ if !cmp.Equal(gcsDownloader.gcsDownloadPrefixNvidia, tc.expectedPrefixNvidia) {
+ t.Errorf("The expected NVIDIA prefix is: %s, but the actual NVIDIA prefix is: %s.", tc.expectedPrefixNvidia, gcsDownloader.gcsDownloadPrefixNvidia)
+ }
+
+ })
+ }
+}
+
// TestDownloadArtifact is testing the functionality to download an artifact from the GCS prefix configured in GCSDownloader.
func TestDownloadArtifact(t *testing.T) {
testDir, err := os.MkdirTemp("", "testdir")
@@ -43,7 +129,7 @@
fakeGCS.Objects[path.Join("/", fakeBucket, fakePrefix, data.artifactName)] = []byte(data.artifactContent)
}
defer fakeGCS.Close()
- var FakeGCSDownloader = NewGCSDownloader(fakeGCSClient, nil, fakeBucket, fakePrefix)
+ var FakeGCSDownloader = NewGCSDownloader(fakeGCSClient, nil, fakeBucket, fakePrefix, "", "")
var testCases = []struct {
downloadDir string
artifactName string
@@ -127,7 +213,7 @@
fakeGCS.Objects[path.Join("/", fakeBucket, fakePrefix, data.artifactName)] = []byte(data.artifactContent)
}
defer fakeGCS.Close()
- var FakeGCSDownloader = NewGCSDownloader(fakeGCSClient, nil, fakeBucket, fakePrefix)
+ var FakeGCSDownloader = NewGCSDownloader(fakeGCSClient, nil, fakeBucket, fakePrefix, "", "")
var testCases = []struct {
artifactName string
@@ -192,7 +278,7 @@
fakeGCS.Objects[path.Join("/", fakeBucket, fakePrefix, data.artifactName)] = []byte(data.artifactContent)
}
defer fakeGCS.Close()
- var FakeGCSDownloader = NewGCSDownloader(fakeGCSClient, nil, fakeBucket, fakePrefix)
+ var FakeGCSDownloader = NewGCSDownloader(fakeGCSClient, nil, fakeBucket, fakePrefix, "", "")
var testCases = []struct {
prefix string
@@ -253,7 +339,7 @@
fakeGCS.Objects[path.Join("/", fakeBucket, fakePrefix, data.artifactName)] = []byte(data.artifactContent)
}
defer fakeGCS.Close()
- var FakeGCSDownloader = NewGCSDownloader(fakeGCSClient, nil, fakeBucket, fakePrefix)
+ var FakeGCSDownloader = NewGCSDownloader(fakeGCSClient, nil, fakeBucket, fakePrefix, "", "")
var testCases = []struct {
url string
diff --git a/src/pkg/cos/cos_test.go b/src/pkg/cos/cos_test.go
index 2457260..1377d1e 100644
--- a/src/pkg/cos/cos_test.go
+++ b/src/pkg/cos/cos_test.go
@@ -309,3 +309,7 @@
func (*fakeDownloader) ListArtifacts(ctx context.Context, prefix string) ([]string, error) {
return nil, nil
}
+
+func (*fakeDownloader) DownloadGenericNvidiaDriver(context.Context, string, string) (string, error) {
+ return "", nil
+}
diff --git a/src/pkg/gpuconfig/artifacts.go b/src/pkg/gpuconfig/artifacts.go
index 55e552c..942a2d8 100644
--- a/src/pkg/gpuconfig/artifacts.go
+++ b/src/pkg/gpuconfig/artifacts.go
@@ -89,3 +89,7 @@
func (d *GPUArtifactsDownloader) ArtifactExists(ctx context.Context, artifactPath string) (bool, error) {
return false, fmt.Errorf("not implemented")
}
+
+func (d *GPUArtifactsDownloader) DownloadGenericNvidiaDriver(ctx context.Context, destDir, driverVersion string) (string, error) {
+ return "", fmt.Errorf("not implemented")
+}