Merge "cos-gpu-installer-v2: Add commands package"
diff --git a/src/cmd/cos_gpu_installer/internal/commands/install.go b/src/cmd/cos_gpu_installer/internal/commands/install.go
new file mode 100644
index 0000000..f75805f
--- /dev/null
+++ b/src/cmd/cos_gpu_installer/internal/commands/install.go
@@ -0,0 +1,164 @@
+// Package commands implements subcommands of cos_gpu_installer.
+package commands
+
+import (
+ "context"
+ "path/filepath"
+
+ "cmd/cos_gpu_installer/internal/installer"
+ "cmd/cos_gpu_installer/internal/signing"
+ "flag"
+ "pkg/cos"
+ "pkg/modules"
+
+ log "github.com/golang/glog"
+ "github.com/google/subcommands"
+ "github.com/pkg/errors"
+)
+
+// hostRootPath is the root under which the host environment is read
+// (cos.NewEnvReader) and under which the user-supplied install directory is
+// placed. NOTE(review): presumably the host's / mounted into the container.
+const hostRootPath = "/root"
+
+// Build-related locations. toolchainPkgDir is handed to
+// cos.InstallCrossToolchain; the kernel source/header paths are not used in
+// this file.
+const (
+	kernelSrcDir    = "/build/usr/src/linux"
+	kernelHeaderDir = "/build/usr/src/linux-headers"
+	toolchainPkgDir = "/build/cos-tools"
+)
+
+// InstallCommand is the subcommand to install GPU drivers. Its fields are
+// populated from command-line flags.
+type InstallCommand struct {
+	// Driver version to install; when empty, the default version is looked up.
+	driverVersion string
+	// Host install directory, joined onto hostRootPath before use.
+	hostInstallDir string
+	// When false, kernel module locking is disabled before installing.
+	enforceSigning bool
+	// Passed to cos.GCSDownloader to choose where files are downloaded from.
+	internalDownload bool
+	// When true, errors are logged with "%+v" instead of "%v".
+	debug bool
+}
+
+// Name implements subcommands.Command.Name; the subcommand is invoked as
+// "install".
+func (*InstallCommand) Name() string {
+	return "install"
+}
+
+// Synopsis implements subcommands.Command.Synopsis with a one-line summary.
+func (*InstallCommand) Synopsis() string {
+	return "Install GPU drivers."
+}
+
+// Usage implements subcommands.Command.Usage. The usage line names every flag
+// registered by SetFlags rather than only -dir.
+func (*InstallCommand) Usage() string {
+	return "install [-version <version>] [-dir <filepath>] [-enforce-signing=<true|false>] [-internal-download] [-debug]\n"
+}
+
+// SetFlags implements subcommands.Command.SetFlags, binding each flag to the
+// corresponding InstallCommand field.
+func (c *InstallCommand) SetFlags(f *flag.FlagSet) {
+	f.StringVar(&c.driverVersion, "version", "",
+		"The GPU driver version to install. It will install the default GPU driver if the flag is not set explicitly.")
+	f.StringVar(&c.hostInstallDir, "dir", "/var/lib/nvidia",
+		"Host directory that GPU drivers should be installed to")
+	f.BoolVar(&c.enforceSigning, "enforce-signing", true,
+		"Whether to enforce GPU drivers being signed. Setting to false will disable kernel module signing security feature.")
+	// TODO(mikewu): change this flag to a bucket prefix string.
+	f.BoolVar(&c.internalDownload, "internal-download", false,
+		"Whether to try to download files from Google internal server. This is only useful for internal developing.")
+	f.BoolVar(&c.debug, "debug", false,
+		"Enable debug mode.")
+}
+
+// Execute implements subcommands.Command.Execute.
+//
+// It reads the environment under hostRootPath and resolves the driver version,
+// asking installer.GetDefaultGPUDriverVersion when -version is empty. It
+// disables kernel module locking when signing is not enforced. It then either
+// configures a cached installation for this build/driver version or builds a
+// fresh one via installDriver. Every error is logged and reported as
+// subcommands.ExitFailure.
+func (c *InstallCommand) Execute(ctx context.Context, _ *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus {
+	envReader, err := cos.NewEnvReader(hostRootPath)
+	if err != nil {
+		c.logError(errors.Wrapf(err, "failed to create envReader with host root path %s", hostRootPath))
+		return subcommands.ExitFailure
+	}
+
+	log.Infof("Running on COS build id %s", envReader.BuildNumber())
+
+	// NOTE(review): unkeyed composite literal of an imported struct type;
+	// `go vet` (composites check) reports this. Prefer keyed fields once the
+	// GCSDownloader field names are confirmed.
+	downloader := &cos.GCSDownloader{envReader, c.internalDownload}
+	if c.driverVersion == "" {
+		defaultVersion, err := installer.GetDefaultGPUDriverVersion(downloader)
+		if err != nil {
+			c.logError(errors.Wrap(err, "failed to get default driver version"))
+			return subcommands.ExitFailure
+		}
+		c.driverVersion = defaultVersion
+	}
+	log.Infof("Installing GPU driver version %s", c.driverVersion)
+
+	if !c.enforceSigning {
+		log.Info("Doesn't enforce signing. Need to disable module locking.")
+		if err := cos.DisableKernelModuleLocking(); err != nil {
+			c.logError(errors.Wrap(err, "failed to configure kernel module locking"))
+			return subcommands.ExitFailure
+		}
+	}
+
+	hostInstallDir := filepath.Join(hostRootPath, c.hostInstallDir)
+	cacher := installer.NewCacher(hostInstallDir, envReader.BuildNumber(), c.driverVersion)
+	isCached, err := cacher.IsCached()
+	if err != nil {
+		// A failed cache lookup is not fatal: fall back to a fresh build, but
+		// record why the cache was bypassed instead of dropping the error.
+		log.Warningf("Failed to check for a cached installation, building the drivers instead: %v", err)
+	}
+	if isCached && err == nil {
+		log.Info("Found cached version, NOT building the drivers.")
+		if err := installer.ConfigureCachedInstalltion(hostInstallDir, c.enforceSigning); err != nil {
+			c.logError(errors.Wrap(err, "failed to configure cached installation"))
+			return subcommands.ExitFailure
+		}
+		if err := installer.VerifyDriverInstallation(); err != nil {
+			c.logError(errors.Wrap(err, "failed to verify GPU driver installation"))
+			return subcommands.ExitFailure
+		}
+		if err := modules.UpdateHostLdCache(hostRootPath, filepath.Join(c.hostInstallDir, "lib64")); err != nil {
+			c.logError(errors.Wrap(err, "failed to update host ld cache"))
+			return subcommands.ExitFailure
+		}
+		return subcommands.ExitSuccess
+	}
+
+	log.Info("Did not find cached version, installing the drivers...")
+	if err := installDriver(c, cacher, envReader, downloader); err != nil {
+		c.logError(err)
+		return subcommands.ExitFailure
+	}
+
+	return subcommands.ExitSuccess
+}
+
+// installDriver builds and installs the GPU driver for
+// envReader.KernelRelease() into c.hostInstallDir (under hostRootPath).
+//
+// Steps:
+//   - configure the installation directories;
+//   - download driver signatures (only when c.enforceSigning is set);
+//   - download the driver installer;
+//   - set up the compilation environment and cross toolchain;
+//   - run the installer, verify the result, and record it in cacher;
+//   - update the host ld cache with the installed lib64 directory.
+//
+// Each failure is returned wrapped with the step that failed.
+func installDriver(c *InstallCommand, cacher *installer.Cacher, envReader *cos.EnvReader, downloader *cos.GCSDownloader) error {
+	callback, err := installer.ConfigureDriverInstallationDirs(filepath.Join(hostRootPath, c.hostInstallDir), envReader.KernelRelease())
+	if err != nil {
+		return errors.Wrap(err, "failed to configure GPU driver installation dirs")
+	}
+	// Signal the returned channel when this function returns, on success or
+	// failure. NOTE(review): presumably this releases whatever
+	// ConfigureDriverInstallationDirs set up — confirm.
+	defer func() { callback <- 0 }()
+
+	if c.enforceSigning {
+		if err := signing.DownloadDriverSignatures(downloader, c.driverVersion); err != nil {
+			return errors.Wrap(err, "failed to download driver signature")
+		}
+	}
+
+	installerFile, err := installer.DownloadDriverInstaller(
+		c.driverVersion, envReader.Milestone(), envReader.BuildNumber())
+	if err != nil {
+		return errors.Wrap(err, "failed to download GPU driver installer")
+	}
+
+	if err := cos.SetCompilationEnv(downloader); err != nil {
+		return errors.Wrap(err, "failed to set compilation environment variables")
+	}
+	if err := cos.InstallCrossToolchain(downloader, toolchainPkgDir); err != nil {
+		return errors.Wrap(err, "failed to install toolchain")
+	}
+
+	if err := installer.RunDriverInstaller(installerFile, c.enforceSigning); err != nil {
+		return errors.Wrap(err, "failed to run GPU driver installer")
+	}
+	// Verify before caching. A cache entry written for a build that fails
+	// verification would make later runs take Execute's cached path and skip
+	// rebuilding.
+	if err := installer.VerifyDriverInstallation(); err != nil {
+		return errors.Wrap(err, "failed to verify installation")
+	}
+	if err := cacher.Cache(); err != nil {
+		return errors.Wrap(err, "failed to cache installation")
+	}
+	if err := modules.UpdateHostLdCache(hostRootPath, filepath.Join(c.hostInstallDir, "lib64")); err != nil {
+		return errors.Wrap(err, "failed to update host ld cache")
+	}
+	log.Info("Finished installing the drivers.")
+	return nil
+}
+
+// logError logs err at error level. In debug mode it uses "%+v" (for
+// github.com/pkg/errors values this includes the recorded stack trace);
+// otherwise it uses "%v".
+func (c *InstallCommand) logError(err error) {
+	format := "%v"
+	if c.debug {
+		format = "%+v"
+	}
+	log.Errorf(format, err)
+}
diff --git a/src/cmd/cos_gpu_installer/internal/commands/list.go b/src/cmd/cos_gpu_installer/internal/commands/list.go
new file mode 100644
index 0000000..30523cf
--- /dev/null
+++ b/src/cmd/cos_gpu_installer/internal/commands/list.go
@@ -0,0 +1,79 @@
+package commands
+
+import (
+ "context"
+ "fmt"
+ "strings"
+
+ "cmd/cos_gpu_installer/internal/installer"
+ "flag"
+ "pkg/cos"
+
+ log "github.com/golang/glog"
+ "github.com/google/subcommands"
+ "github.com/pkg/errors"
+)
+
+// ListCommand is the subcommand to list supported GPU drivers. Its fields
+// are populated from command-line flags.
+type ListCommand struct {
+	// Passed to cos.GCSDownloader to choose where files are listed from.
+	internalDownload bool
+	// When true, errors are logged with "%+v" instead of "%v".
+	debug bool
+}
+
+// Name implements subcommands.Command.Name; the subcommand is invoked as
+// "list".
+func (*ListCommand) Name() string {
+	return "list"
+}
+
+// Synopsis implements subcommands.Command.Synopsis with a one-line summary.
+func (*ListCommand) Synopsis() string {
+	return "List supported GPU drivers for this version."
+}
+
+// Usage implements subcommands.Command.Usage.
+func (*ListCommand) Usage() string {
+	return "list\n"
+}
+
+// SetFlags implements subcommands.Command.SetFlags.
+func (c *ListCommand) SetFlags(f *flag.FlagSet) {
+ // TODO(mikewu): change this flag to a bucket prefix string.
+ f.BoolVar(&c.internalDownload, "internal-download", false,
+ "Whether to try to download files from Google internal server. This is only useful for internal developing.")
+ f.BoolVar(&c.debug, "debug", false,
+ "Enable debug mode.")
+}
+
+// Execute implements subcommands.Command.Execute.
+func (c *ListCommand) Execute(ctx context.Context, _ *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus {
+ envReader, err := cos.NewEnvReader(hostRootPath)
+ if err != nil {
+ c.logError(errors.Wrap(err, "failed to create envReader"))
+ return subcommands.ExitFailure
+ }
+ log.Infof("Running on COS build id %s", envReader.BuildNumber())
+ downloader := &cos.GCSDownloader{envReader, c.internalDownload}
+ artifacts, err := downloader.ListExtensionArtifacts("gpu")
+ if err != nil {
+ c.logError(errors.Wrap(err, "failed to list gpu extension artifacts"))
+ return subcommands.ExitFailure
+ }
+ defaultVersion, err := installer.GetDefaultGPUDriverVersion(downloader)
+ if err != nil {
+ c.logError(errors.Wrap(err, "failed to get default driver version"))
+ return subcommands.ExitFailure
+ }
+ for _, artifact := range artifacts {
+ if strings.HasSuffix(artifact, ".signature.tar.gz") {
+ driverVersion := strings.TrimSuffix(artifact, ".signature.tar.gz")
+ if defaultVersion == driverVersion {
+ fmt.Printf("%s [default]\n", driverVersion)
+ } else {
+ fmt.Printf("%s\n", driverVersion)
+ }
+ }
+ }
+ return subcommands.ExitSuccess
+}
+
+// logError logs err at error level. Debug mode uses "%+v" (for
+// github.com/pkg/errors values this includes the recorded stack trace);
+// otherwise it uses "%v".
+func (c *ListCommand) logError(err error) {
+	if !c.debug {
+		log.Errorf("%v", err)
+		return
+	}
+	log.Errorf("%+v", err)
+}