cos/tools: Remove using http client for cos-gpu-installer
OSS GPU driver installer downloading is using http client and we are changing to use GCS go client instead.
BUG=b/330361078
TEST=test locally, cloudbuild test
Change-Id: I6aa2472613f257671225013e950841a712774e43
Reviewed-on: https://cos-review.googlesource.com/c/cos/tools/+/69410
Reviewed-by: Arnav Kansal <rnv@google.com>
Tested-by: Shuo Yang <gshuoy@google.com>
Cloud-Build: GCB Service account <228075978874@cloudbuild.gserviceaccount.com>
diff --git a/src/cmd/cos_gpu_installer/internal/commands/install.go b/src/cmd/cos_gpu_installer/internal/commands/install.go
index 55f17b8..9b8e343 100644
--- a/src/cmd/cos_gpu_installer/internal/commands/install.go
+++ b/src/cmd/cos_gpu_installer/internal/commands/install.go
@@ -500,7 +500,7 @@
var installerFile string
if c.nvidiaInstallerURLOpen == "" {
- installerFile, err = installer.DownloadGenericDriverInstaller(c.driverVersion)
+ installerFile, err = installer.DownloadGenericDriverInstaller(ctx, c.driverVersion, downloader)
} else {
installerFile, err = installer.DownloadToInstallDir(c.nvidiaInstallerURLOpen, "Unofficial GPU driver installer")
}
diff --git a/src/cmd/cos_gpu_installer/internal/installer/installer.go b/src/cmd/cos_gpu_installer/internal/installer/installer.go
index 1773536..c74e24f 100644
--- a/src/cmd/cos_gpu_installer/internal/installer/installer.go
+++ b/src/cmd/cos_gpu_installer/internal/installer/installer.go
@@ -739,11 +739,16 @@
}
// DownloadGenericDriverInstaller downloads the generic GPU driver installer given driver version.
-func DownloadGenericDriverInstaller(driverVersion string) (string, error) {
+func DownloadGenericDriverInstaller(ctx context.Context, driverVersion string, downloader *cos.GCSDownloader) (string, error) {
log.Infof("Downloading GPU driver installer version %s", driverVersion)
downloadURL, err := getGenericDriverInstallerURL(driverVersion)
if err != nil {
- return "", errors.Wrap(err, "failed to get driver installer URL")
+ return "", fmt.Errorf("failed to get driver installer URL: %v", err)
}
- return DownloadToInstallDir(downloadURL, "GPU driver installer")
+ outputPath := filepath.Join(gpuInstallDirContainer, path.Base(downloadURL))
+ err = downloader.DownloadArtifactFromURL(ctx, downloadURL, outputPath)
+ if err != nil {
+ return "", fmt.Errorf("failed to download gpu driver installer file from %s with error: %v", downloadURL, err)
+ }
+ return path.Base(downloadURL), nil
}
diff --git a/src/pkg/cos/artifacts.go b/src/pkg/cos/artifacts.go
index 9e3e4e4..fcfb425 100644
--- a/src/pkg/cos/artifacts.go
+++ b/src/pkg/cos/artifacts.go
@@ -179,6 +179,15 @@
return objects, nil
}
+// DownloadArtifactFromURL will download the artifact from url and save it to the destinationPath.
+func (d *GCSDownloader) DownloadArtifactFromURL(ctx context.Context, url string, destinationPath string) error {
+ gcsClient, err := d.getGCSClient(ctx)
+ if err != nil {
+ return err
+ }
+ return gcs.DownloadGCSObjectFromURL(ctx, gcsClient, url, destinationPath)
+}
+
func (d *GCSDownloader) getGCSClient(ctx context.Context) (*storage.Client, error) {
glog.Infof("Start to fetch the GCSClient.")
d.mutex.Lock()
diff --git a/src/pkg/cos/artifacts_test.go b/src/pkg/cos/artifacts_test.go
index 34a684b..646237a 100644
--- a/src/pkg/cos/artifacts_test.go
+++ b/src/pkg/cos/artifacts_test.go
@@ -226,3 +226,60 @@
}
}
+
+// TestDownloadArtifactFromURL is testing the functionality to download the artifacts from the url
+func TestDownloadArtifactFromURL(t *testing.T) {
+ testDir, err := os.MkdirTemp("", "testdir")
+ if err != nil {
+ t.Fatalf("Failed to create tempdir: %v", err)
+ }
+ defer os.RemoveAll(testDir)
+ fakeGCS := fakes.GCSForTest(t)
+ fakeBucket := "cos-tools"
+ fakePrefix := "10000.00.00/lakitu"
+ fakeArtifactName := "NVIDIA-Linux-x86_64-550.run"
+ fakeGCSClient := fakeGCS.Client
+ ctx := context.Background()
+ var fakeData = []struct {
+ artifactName string
+ artifactContent string
+ }{
+ {
+ fakeArtifactName,
+ "Artifact content",
+ },
+ }
+ for _, data := range fakeData {
+ fakeGCS.Objects[path.Join("/", fakeBucket, fakePrefix, data.artifactName)] = []byte(data.artifactContent)
+ }
+ defer fakeGCS.Close()
+ var FakeGCSDownloader = NewGCSDownloader(fakeGCSClient, nil, fakeBucket, fakePrefix)
+
+ var testCases = []struct {
+ url string
+ destinationPath string
+ expected string
+ }{
+ {
+ fmt.Sprintf("https://storage.googleapis.com/%[1]s/%[2]s/%[3]s", fakeBucket, fakePrefix, fakeArtifactName),
+ fmt.Sprintf("%[1]s/%[2]s", testDir, fakeArtifactName),
+ "Artifact content",
+ },
+ }
+ for index, tc := range testCases {
+ t.Run(fmt.Sprintf("TestDownloadArtifactFromURL %d:", index), func(t *testing.T) {
+ err := FakeGCSDownloader.DownloadArtifactFromURL(ctx, tc.url, tc.destinationPath)
+ if err != nil {
+ t.Fatalf("TestDownloadArtifactFromURL Failed: %v", err)
+ }
+ actualData, err := os.ReadFile(tc.destinationPath)
+ if err != nil {
+ t.Fatalf("Fail to read the content from the output file: %s with error: %v", tc.destinationPath, err)
+ }
+ if !cmp.Equal(string(actualData), tc.expected) {
+ t.Errorf("TestDownloadArtifactFromURL failed: the expected data is: %s, but the actual data is: %s.", tc.expected, string(actualData))
+ }
+ })
+ }
+
+}