cos-gpu-installer: Add gcsDownloadPrefix flag
This flag allows to specify GCS download prefix through flag. It
is more flexsible than the hard-coded GCS buckets. It makes it easier
to test against internal gcs buckets.
BUG=b/155192122
TEST=manually run "cos-extensions" againast the new cos-gpu-installer to
test.
Change-Id: Ia7ef3b5b91c8e4fe4c3dc8f08ff6d8f72c2b6746
Reviewed-on: https://cos-review.googlesource.com/c/cos/tools/+/9840
Reviewed-by: Robert Kolchmeyer <rkolchmeyer@google.com>
Tested-by: Ke Wu <mikewu@google.com>
diff --git a/src/cmd/cos_gpu_installer/internal/commands/install.go b/src/cmd/cos_gpu_installer/internal/commands/install.go
index a787cb9..e3973e7 100644
--- a/src/cmd/cos_gpu_installer/internal/commands/install.go
+++ b/src/cmd/cos_gpu_installer/internal/commands/install.go
@@ -29,11 +29,12 @@
// InstallCommand is the subcommand to install GPU drivers.
type InstallCommand struct {
- driverVersion string
- hostInstallDir string
- unsignedDriver bool
- internalDownload bool
- debug bool
+ driverVersion string
+ hostInstallDir string
+ unsignedDriver bool
+ gcsDownloadBucket string
+ gcsDownloadPrefix string
+ debug bool
}
// Name implements subcommands.Command.Name.
@@ -51,14 +52,17 @@
"The GPU driver verion to install. It will install the default GPU driver if the flag is not set explicitly.")
f.StringVar(&c.hostInstallDir, "host-dir", "",
"Host directory that GPU drivers should be installed to. "+
- "It tries to read from the env NVIDIA_INSTALL_DIR_HOST if the flag is not set explicitly.")
+ "It tries to read from the env NVIDIA_INSTALL_DIR_HOST if the flag is not set explicitly.")
f.BoolVar(&c.unsignedDriver, "allow-unsigned-driver", false,
"Whether to allow load unsigned GPU drivers. "+
"If this flag is set to true, module signing security features must be disabled on the host for driver installation to succeed. "+
"This flag is only for debugging.")
- // TODO(mikewu): change this flag to a bucket prefix string.
- f.BoolVar(&c.internalDownload, "internal-download", false,
- "Whether to try to download files from Google internal server. This is only useful for internal developing.")
+ f.StringVar(&c.gcsDownloadBucket, "gcs-download-bucket", "cos-tools",
+ "The GCS bucket to download COS artifacts from. "+
+ "For example, the default value is 'cos-tools' which is the public COS artifacts bucket.")
+ f.StringVar(&c.gcsDownloadPrefix, "gcs-download-prefix", "",
+ "The GCS path prefix when downloading COS artifacts."+
+ "If not set then the COS version build number (e.g. 13310.1041.38) will be used.")
f.BoolVar(&c.debug, "debug", false,
"Enable debug mode.")
}
@@ -73,7 +77,7 @@
log.Infof("Running on COS build id %s", envReader.BuildNumber())
- downloader := cos.NewGCSDownloader(envReader, c.internalDownload)
+ downloader := cos.NewGCSDownloader(envReader, c.gcsDownloadBucket, c.gcsDownloadPrefix)
if c.driverVersion == "" {
defaultVersion, err := installer.GetDefaultGPUDriverVersion(downloader)
if err != nil {
diff --git a/src/cmd/cos_gpu_installer/internal/commands/list.go b/src/cmd/cos_gpu_installer/internal/commands/list.go
index da76756..ae8e155 100644
--- a/src/cmd/cos_gpu_installer/internal/commands/list.go
+++ b/src/cmd/cos_gpu_installer/internal/commands/list.go
@@ -17,8 +17,9 @@
// ListCommand is the subcommand to list supported GPU drivers.
type ListCommand struct {
- internalDownload bool
- debug bool
+ gcsDownloadBucket string
+ gcsDownloadPrefix string
+ debug bool
}
// Name implements subcommands.Command.Name.
@@ -32,9 +33,12 @@
// SetFlags implements subcommands.Command.SetFlags.
func (c *ListCommand) SetFlags(f *flag.FlagSet) {
- // TODO(mikewu): change this flag to a bucket prefix string.
- f.BoolVar(&c.internalDownload, "internal-download", false,
- "Whether to try to download files from Google internal server. This is only useful for internal developing.")
+ f.StringVar(&c.gcsDownloadBucket, "gcs-download-bucket", "cos-tools",
+ "The GCS bucket to download COS artifacts from. "+
+ "For example, the default value is 'cos-tools' which is the public COS artifacts bucket.")
+ f.StringVar(&c.gcsDownloadPrefix, "gcs-download-prefix", "",
+ "The GCS path prefix when downloading COS artifacts."+
+ "If not set then the COS version build number (e.g. 13310.1041.38) will be used.")
f.BoolVar(&c.debug, "debug", false,
"Enable debug mode.")
}
@@ -47,8 +51,8 @@
return subcommands.ExitFailure
}
log.Infof("Running on COS build id %s", envReader.BuildNumber())
- downloader := cos.NewGCSDownloader(envReader, c.internalDownload)
- artifacts, err := downloader.ListExtensionArtifacts("gpu")
+ downloader := cos.NewGCSDownloader(envReader, c.gcsDownloadBucket, c.gcsDownloadPrefix)
+ artifacts, err := downloader.ListGPUExtensionArtifacts()
if err != nil {
c.logError(errors.Wrap(err, "failed to list gpu extension artifacts"))
return subcommands.ExitFailure
diff --git a/src/pkg/cos/artifacts.go b/src/pkg/cos/artifacts.go
index 1d720d7..76fb55c 100644
--- a/src/pkg/cos/artifacts.go
+++ b/src/pkg/cos/artifacts.go
@@ -4,6 +4,7 @@
"fmt"
"io/ioutil"
"os"
+ "path"
"path/filepath"
log "github.com/golang/glog"
@@ -13,9 +14,7 @@
)
const (
- // TODO(mikewu): consider making GCS buckets as flags.
cosToolsGCS = "cos-tools"
- internalGCS = "container-vm-image-staging"
chromiumOSSDKGCS = "chromiumos-sdk"
kernelInfo = "kernel_info"
kernelSrcArchive = "kernel-src.tar.gz"
@@ -38,13 +37,22 @@
// GCSDownloader is the struct downloading COS artifacts from GCS bucket.
type GCSDownloader struct {
- envReader *EnvReader
- Internal bool
+ envReader *EnvReader
+ gcsDownloadBucket string
+ gcsDownloadPrefix string
}
// NewGCSDownloader creates a GCSDownloader instance.
-func NewGCSDownloader(e *EnvReader, i bool) *GCSDownloader {
- return &GCSDownloader{e, i}
+func NewGCSDownloader(e *EnvReader, bucket, prefix string) *GCSDownloader {
+ // Use cos-tools as the default GCS bucket.
+ if bucket == "" {
+ bucket = cosToolsGCS
+ }
+ // Use build number as the default GCS download prefix.
+ if prefix == "" {
+ prefix = e.BuildNumber()
+ }
+ return &GCSDownloader{e, bucket, prefix}
}
// DownloadKernelSrc downloads COS kernel sources to destination directory.
@@ -95,32 +103,13 @@
return content, nil
}
-// DownloadArtifact downloads an artifact from GCS buckets, including public bucket and internal bucket.
-// TODO(mikewu): consider allow users to pass in GCS directories in arguments.
+// DownloadArtifact downloads an artifact from the GCS prefix configured in GCSDownloader.
func (d *GCSDownloader) DownloadArtifact(destDir, artifactPath string) error {
- var err error
-
- if err = utils.DownloadFromGCS(destDir, cosToolsGCS, d.artifactPublicPath(artifactPath)); err == nil {
- return nil
+ gcsPath := path.Join(d.gcsDownloadPrefix, artifactPath)
+ if err := utils.DownloadFromGCS(destDir, d.gcsDownloadBucket, gcsPath); err != nil {
+ return errors.Errorf("failed to download %s from gs://%s/%s", artifactPath, d.gcsDownloadBucket, gcsPath)
}
- log.Errorf("Failed to download %s from public GCS: %v", artifactPath, err)
-
- if d.Internal {
- if err = utils.DownloadFromGCS(destDir, internalGCS, d.artifactInternalPath(artifactPath)); err == nil {
- return nil
- }
- log.Errorf("Failed to download %s from internal GCS: %v", artifactPath, err)
- }
-
- return errors.Errorf("failed to download %s", artifactPath)
-}
-
-func (d *GCSDownloader) artifactPublicPath(artifactPath string) string {
- return fmt.Sprintf("%s/%s", d.envReader.BuildNumber(), artifactPath)
-}
-
-func (d *GCSDownloader) artifactInternalPath(artifactPath string) string {
- return fmt.Sprintf("lakitu-release/R%s-%s/%s", d.envReader.Milestone(), d.envReader.BuildNumber(), artifactPath)
+ return nil
}
func (d *GCSDownloader) getToolchainURL() (string, error) {
diff --git a/src/pkg/cos/extensions.go b/src/pkg/cos/extensions.go
index c8968d7..151ed80 100644
--- a/src/pkg/cos/extensions.go
+++ b/src/pkg/cos/extensions.go
@@ -2,12 +2,11 @@
import (
"fmt"
- "path/filepath"
+ "path"
"regexp"
"cos.googlesource.com/cos/tools.git/src/pkg/utils"
- log "github.com/golang/glog"
"github.com/pkg/errors"
)
@@ -28,18 +27,10 @@
func (d *GCSDownloader) ListExtensions() ([]string, error) {
var objects []string
var err error
- if objects, err = utils.ListGCSBucket(cosToolsGCS, d.artifactPublicPath("extensions")); err != nil || len(objects) == 0 {
- log.Errorf("Failed to list extensions from public GCS: %v", err)
- if d.Internal {
- if objects, err = utils.ListGCSBucket(internalGCS, d.artifactInternalPath("extensions")); err != nil {
- log.Errorf("Failed to list extensions from internal GCS: %v", err)
- }
- }
- }
- if err != nil {
+ gcsPath := path.Join(d.gcsDownloadPrefix, "extensions")
+ if objects, err = utils.ListGCSBucket(d.gcsDownloadBucket, gcsPath); err != nil {
return nil, errors.Wrap(err, "failed to list extensions")
}
-
var extensions []string
re := regexp.MustCompile(`extensions/(\w+)$`)
for _, object := range objects {
@@ -51,21 +42,11 @@
}
// ListExtensionArtifacts lists all artifacts of a given extension.
-// TODO(mikewu): make this extension specific.
func (d *GCSDownloader) ListExtensionArtifacts(extension string) ([]string, error) {
var objects []string
var err error
- extensionPath := filepath.Join("extensions", extension)
- if objects, err = utils.ListGCSBucket(cosToolsGCS, d.artifactPublicPath(extensionPath)); err != nil || len(objects) == 0 {
- log.Errorf("Failed to list extension artifacts from public GCS: %v", err)
- // TODO(mikewu): use flags to specify GCS directories.
- if d.Internal {
- if objects, err = utils.ListGCSBucket(internalGCS, d.artifactInternalPath(extensionPath)); err != nil {
- log.Errorf("Failed to list extension artifacts from internal GCS: %v", err)
- }
- }
- }
- if err != nil {
+ gcsPath := path.Join(d.gcsDownloadPrefix, "extensions", extension)
+ if objects, err = utils.ListGCSBucket(d.gcsDownloadBucket, gcsPath); err != nil {
return nil, errors.Wrap(err, "failed to list extensions")
}
@@ -79,14 +60,19 @@
return artifacts, nil
}
+// ListGPUExtensionArtifacts lists all artifacts of GPU extension.
+func (d *GCSDownloader) ListGPUExtensionArtifacts() ([]string, error) {
+ return d.ListExtensionArtifacts(GPUExtension)
+}
+
// DownloadExtensionArtifact downloads an artifact of the given extension.
func (d *GCSDownloader) DownloadExtensionArtifact(destDir, extension, artifact string) error {
- artifactPath := filepath.Join("extensions", extension, artifact)
+ artifactPath := path.Join("extensions", extension, artifact)
return d.DownloadArtifact(destDir, artifactPath)
}
// GetExtensionArtifact reads the content of an artifact of the given extension.
func (d *GCSDownloader) GetExtensionArtifact(extension, artifact string) ([]byte, error) {
- artifactPath := filepath.Join("extensions", extension, artifact)
+ artifactPath := path.Join("extensions", extension, artifact)
return d.GetArtifact(artifactPath)
}