cos-gpu-installer: Add gcsDownloadPrefix flag

This flag allows to specify GCS download prefix through flag. It
is more flexsible than the hard-coded GCS buckets. It makes it easier
to test against internal gcs buckets.

BUG=b/155192122
TEST=manually run "cos-extensions" againast the new cos-gpu-installer to
test.

Change-Id: Ia7ef3b5b91c8e4fe4c3dc8f08ff6d8f72c2b6746
Reviewed-on: https://cos-review.googlesource.com/c/cos/tools/+/9840
Reviewed-by: Robert Kolchmeyer <rkolchmeyer@google.com>
Tested-by: Ke Wu <mikewu@google.com>
diff --git a/src/cmd/cos_gpu_installer/internal/commands/install.go b/src/cmd/cos_gpu_installer/internal/commands/install.go
index a787cb9..e3973e7 100644
--- a/src/cmd/cos_gpu_installer/internal/commands/install.go
+++ b/src/cmd/cos_gpu_installer/internal/commands/install.go
@@ -29,11 +29,12 @@
 
 // InstallCommand is the subcommand to install GPU drivers.
 type InstallCommand struct {
-	driverVersion    string
-	hostInstallDir   string
-	unsignedDriver   bool
-	internalDownload bool
-	debug            bool
+	driverVersion     string
+	hostInstallDir    string
+	unsignedDriver    bool
+	gcsDownloadBucket string
+	gcsDownloadPrefix string
+	debug             bool
 }
 
 // Name implements subcommands.Command.Name.
@@ -51,14 +52,17 @@
 		"The GPU driver verion to install. It will install the default GPU driver if the flag is not set explicitly.")
 	f.StringVar(&c.hostInstallDir, "host-dir", "",
 		"Host directory that GPU drivers should be installed to. "+
-		"It tries to read from the env NVIDIA_INSTALL_DIR_HOST if the flag is not set explicitly.")
+			"It tries to read from the env NVIDIA_INSTALL_DIR_HOST if the flag is not set explicitly.")
 	f.BoolVar(&c.unsignedDriver, "allow-unsigned-driver", false,
 		"Whether to allow load unsigned GPU drivers. "+
 			"If this flag is set to true, module signing security features must be disabled on the host for driver installation to succeed. "+
 			"This flag is only for debugging.")
-	// TODO(mikewu): change this flag to a bucket prefix string.
-	f.BoolVar(&c.internalDownload, "internal-download", false,
-		"Whether to try to download files from Google internal server. This is only useful for internal developing.")
+	f.StringVar(&c.gcsDownloadBucket, "gcs-download-bucket", "cos-tools",
+		"The GCS bucket to download COS artifacts from. "+
+			"For example, the default value is 'cos-tools' which is the public COS artifacts bucket.")
+	f.StringVar(&c.gcsDownloadPrefix, "gcs-download-prefix", "",
+		"The GCS path prefix when downloading COS artifacts."+
+			"If not set then the COS version build number (e.g. 13310.1041.38) will be used.")
 	f.BoolVar(&c.debug, "debug", false,
 		"Enable debug mode.")
 }
@@ -73,7 +77,7 @@
 
 	log.Infof("Running on COS build id %s", envReader.BuildNumber())
 
-	downloader := cos.NewGCSDownloader(envReader, c.internalDownload)
+	downloader := cos.NewGCSDownloader(envReader, c.gcsDownloadBucket, c.gcsDownloadPrefix)
 	if c.driverVersion == "" {
 		defaultVersion, err := installer.GetDefaultGPUDriverVersion(downloader)
 		if err != nil {
diff --git a/src/cmd/cos_gpu_installer/internal/commands/list.go b/src/cmd/cos_gpu_installer/internal/commands/list.go
index da76756..ae8e155 100644
--- a/src/cmd/cos_gpu_installer/internal/commands/list.go
+++ b/src/cmd/cos_gpu_installer/internal/commands/list.go
@@ -17,8 +17,9 @@
 
 // ListCommand is the subcommand to list supported GPU drivers.
 type ListCommand struct {
-	internalDownload bool
-	debug            bool
+	gcsDownloadBucket string
+	gcsDownloadPrefix string
+	debug             bool
 }
 
 // Name implements subcommands.Command.Name.
@@ -32,9 +33,12 @@
 
 // SetFlags implements subcommands.Command.SetFlags.
 func (c *ListCommand) SetFlags(f *flag.FlagSet) {
-	// TODO(mikewu): change this flag to a bucket prefix string.
-	f.BoolVar(&c.internalDownload, "internal-download", false,
-		"Whether to try to download files from Google internal server. This is only useful for internal developing.")
+	f.StringVar(&c.gcsDownloadBucket, "gcs-download-bucket", "cos-tools",
+		"The GCS bucket to download COS artifacts from. "+
+			"For example, the default value is 'cos-tools' which is the public COS artifacts bucket.")
+	f.StringVar(&c.gcsDownloadPrefix, "gcs-download-prefix", "",
+		"The GCS path prefix when downloading COS artifacts."+
+			"If not set then the COS version build number (e.g. 13310.1041.38) will be used.")
 	f.BoolVar(&c.debug, "debug", false,
 		"Enable debug mode.")
 }
@@ -47,8 +51,8 @@
 		return subcommands.ExitFailure
 	}
 	log.Infof("Running on COS build id %s", envReader.BuildNumber())
-	downloader := cos.NewGCSDownloader(envReader, c.internalDownload)
-	artifacts, err := downloader.ListExtensionArtifacts("gpu")
+	downloader := cos.NewGCSDownloader(envReader, c.gcsDownloadBucket, c.gcsDownloadPrefix)
+	artifacts, err := downloader.ListGPUExtensionArtifacts()
 	if err != nil {
 		c.logError(errors.Wrap(err, "failed to list gpu extension artifacts"))
 		return subcommands.ExitFailure
diff --git a/src/pkg/cos/artifacts.go b/src/pkg/cos/artifacts.go
index 1d720d7..76fb55c 100644
--- a/src/pkg/cos/artifacts.go
+++ b/src/pkg/cos/artifacts.go
@@ -4,6 +4,7 @@
 	"fmt"
 	"io/ioutil"
 	"os"
+	"path"
 	"path/filepath"
 
 	log "github.com/golang/glog"
@@ -13,9 +14,7 @@
 )
 
 const (
-	// TODO(mikewu): consider making GCS buckets as flags.
 	cosToolsGCS      = "cos-tools"
-	internalGCS      = "container-vm-image-staging"
 	chromiumOSSDKGCS = "chromiumos-sdk"
 	kernelInfo       = "kernel_info"
 	kernelSrcArchive = "kernel-src.tar.gz"
@@ -38,13 +37,22 @@
 
 // GCSDownloader is the struct downloading COS artifacts from GCS bucket.
 type GCSDownloader struct {
-	envReader *EnvReader
-	Internal  bool
+	envReader         *EnvReader
+	gcsDownloadBucket string
+	gcsDownloadPrefix string
 }
 
 // NewGCSDownloader creates a GCSDownloader instance.
-func NewGCSDownloader(e *EnvReader, i bool) *GCSDownloader {
-	return &GCSDownloader{e, i}
+func NewGCSDownloader(e *EnvReader, bucket, prefix string) *GCSDownloader {
+	// Use cos-tools as the default GCS bucket.
+	if bucket == "" {
+		bucket = cosToolsGCS
+	}
+	// Use build number as the default GCS download prefix.
+	if prefix == "" {
+		prefix = e.BuildNumber()
+	}
+	return &GCSDownloader{e, bucket, prefix}
 }
 
 // DownloadKernelSrc downloads COS kernel sources to destination directory.
@@ -95,32 +103,13 @@
 	return content, nil
 }
 
-// DownloadArtifact downloads an artifact from GCS buckets, including public bucket and internal bucket.
-// TODO(mikewu): consider allow users to pass in GCS directories in arguments.
+// DownloadArtifact downloads an artifact from the GCS prefix configured in GCSDownloader.
 func (d *GCSDownloader) DownloadArtifact(destDir, artifactPath string) error {
-	var err error
-
-	if err = utils.DownloadFromGCS(destDir, cosToolsGCS, d.artifactPublicPath(artifactPath)); err == nil {
-		return nil
+	gcsPath := path.Join(d.gcsDownloadPrefix, artifactPath)
+	if err := utils.DownloadFromGCS(destDir, d.gcsDownloadBucket, gcsPath); err != nil {
+		return errors.Errorf("failed to download %s from gs://%s/%s", artifactPath, d.gcsDownloadBucket, gcsPath)
 	}
-	log.Errorf("Failed to download %s from public GCS: %v", artifactPath, err)
-
-	if d.Internal {
-		if err = utils.DownloadFromGCS(destDir, internalGCS, d.artifactInternalPath(artifactPath)); err == nil {
-			return nil
-		}
-		log.Errorf("Failed to download %s from internal GCS: %v", artifactPath, err)
-	}
-
-	return errors.Errorf("failed to download %s", artifactPath)
-}
-
-func (d *GCSDownloader) artifactPublicPath(artifactPath string) string {
-	return fmt.Sprintf("%s/%s", d.envReader.BuildNumber(), artifactPath)
-}
-
-func (d *GCSDownloader) artifactInternalPath(artifactPath string) string {
-	return fmt.Sprintf("lakitu-release/R%s-%s/%s", d.envReader.Milestone(), d.envReader.BuildNumber(), artifactPath)
+	return nil
 }
 
 func (d *GCSDownloader) getToolchainURL() (string, error) {
diff --git a/src/pkg/cos/extensions.go b/src/pkg/cos/extensions.go
index c8968d7..151ed80 100644
--- a/src/pkg/cos/extensions.go
+++ b/src/pkg/cos/extensions.go
@@ -2,12 +2,11 @@
 
 import (
 	"fmt"
-	"path/filepath"
+	"path"
 	"regexp"
 
 	"cos.googlesource.com/cos/tools.git/src/pkg/utils"
 
-	log "github.com/golang/glog"
 	"github.com/pkg/errors"
 )
 
@@ -28,18 +27,10 @@
 func (d *GCSDownloader) ListExtensions() ([]string, error) {
 	var objects []string
 	var err error
-	if objects, err = utils.ListGCSBucket(cosToolsGCS, d.artifactPublicPath("extensions")); err != nil || len(objects) == 0 {
-		log.Errorf("Failed to list extensions from public GCS: %v", err)
-		if d.Internal {
-			if objects, err = utils.ListGCSBucket(internalGCS, d.artifactInternalPath("extensions")); err != nil {
-				log.Errorf("Failed to list extensions from internal GCS: %v", err)
-			}
-		}
-	}
-	if err != nil {
+	gcsPath := path.Join(d.gcsDownloadPrefix, "extensions")
+	if objects, err = utils.ListGCSBucket(d.gcsDownloadBucket, gcsPath); err != nil {
 		return nil, errors.Wrap(err, "failed to list extensions")
 	}
-
 	var extensions []string
 	re := regexp.MustCompile(`extensions/(\w+)$`)
 	for _, object := range objects {
@@ -51,21 +42,11 @@
 }
 
 // ListExtensionArtifacts lists all artifacts of a given extension.
-// TODO(mikewu): make this extension specific.
 func (d *GCSDownloader) ListExtensionArtifacts(extension string) ([]string, error) {
 	var objects []string
 	var err error
-	extensionPath := filepath.Join("extensions", extension)
-	if objects, err = utils.ListGCSBucket(cosToolsGCS, d.artifactPublicPath(extensionPath)); err != nil || len(objects) == 0 {
-		log.Errorf("Failed to list extension artifacts from public GCS: %v", err)
-		// TODO(mikewu): use flags to specify GCS directories.
-		if d.Internal {
-			if objects, err = utils.ListGCSBucket(internalGCS, d.artifactInternalPath(extensionPath)); err != nil {
-				log.Errorf("Failed to list extension artifacts from internal GCS: %v", err)
-			}
-		}
-	}
-	if err != nil {
+	gcsPath := path.Join(d.gcsDownloadPrefix, "extensions", extension)
+	if objects, err = utils.ListGCSBucket(d.gcsDownloadBucket, gcsPath); err != nil {
 		return nil, errors.Wrap(err, "failed to list extensions")
 	}
 
@@ -79,14 +60,19 @@
 	return artifacts, nil
 }
 
+// ListGPUExtensionArtifacts lists all artifacts of GPU extension.
+func (d *GCSDownloader) ListGPUExtensionArtifacts() ([]string, error) {
+	return d.ListExtensionArtifacts(GPUExtension)
+}
+
 // DownloadExtensionArtifact downloads an artifact of the given extension.
 func (d *GCSDownloader) DownloadExtensionArtifact(destDir, extension, artifact string) error {
-	artifactPath := filepath.Join("extensions", extension, artifact)
+	artifactPath := path.Join("extensions", extension, artifact)
 	return d.DownloadArtifact(destDir, artifactPath)
 }
 
 // GetExtensionArtifact reads the content of an artifact of the given extension.
 func (d *GCSDownloader) GetExtensionArtifact(extension, artifact string) ([]byte, error) {
-	artifactPath := filepath.Join("extensions", extension, artifact)
+	artifactPath := path.Join("extensions", extension, artifact)
 	return d.GetArtifact(artifactPath)
 }