cos-gpu-installer: fall back to R470 on K80 GPU

R470 is the last driver family supporting K80 GPU devices.
If a newer driver version is requested on a VM with a K80, the installer
automatically falls back to an available R470 driver version.

Logs when fallback happens: https://paste.googleplex.com/6419100317253632

BUG=b/240587402
TEST=run on VMs with K80

Change-Id: I044b147c77179d6a893f11e155c675755d9c5beb
Reviewed-on: https://cos-review.googlesource.com/c/cos/tools/+/36594
Reviewed-by: Robert Kolchmeyer <rkolchmeyer@google.com>
Tested-by: He Gao <hegao@google.com>
Cloud-Build: GCB Service account <228075978874@cloudbuild.gserviceaccount.com>
Reviewed-by: Arnav Kansal <rnv@google.com>
diff --git a/src/cmd/cos_gpu_installer/internal/commands/install.go b/src/cmd/cos_gpu_installer/internal/commands/install.go
index ca193c5..35e11d5 100644
--- a/src/cmd/cos_gpu_installer/internal/commands/install.go
+++ b/src/cmd/cos_gpu_installer/internal/commands/install.go
@@ -7,6 +7,7 @@
 	"fmt"
 	"io/ioutil"
 	"os"
+	"os/exec"
 	"path/filepath"
 	"strconv"
 	"strings"
@@ -18,7 +19,6 @@
 	"cos.googlesource.com/cos/tools.git/src/cmd/cos_gpu_installer/internal/signing"
 	"cos.googlesource.com/cos/tools.git/src/pkg/cos"
 	"cos.googlesource.com/cos/tools.git/src/pkg/modules"
-	"cos.googlesource.com/cos/tools.git/src/pkg/utils"
 
 	log "github.com/golang/glog"
 	"github.com/google/subcommands"
@@ -32,6 +32,51 @@
 	toolchainPkgDir = "/build/cos-tools"
 )
 
+// GPUType identifies the GPU model detected on the host. Only models that
+// need special driver handling get their own value.
+type GPUType int
+
+const (
+	// Others covers every GPU model without special handling. It is the zero
+	// value on purpose: an unset GPUType must not be mistaken for K80, or the
+	// installer would force the R470 fallback without having detected a K80.
+	Others GPUType = iota
+	// K80 is the Tesla K80, currently the only model we need to tell apart.
+	K80
+)
+
+// String returns the name of the GPU type, or "Unknown" for a value that is
+// not one of the declared constants.
+func (g GPUType) String() string {
+	switch g {
+	case K80:
+		return "K80"
+	case Others:
+		return "Others"
+	default:
+		return "Unknown"
+	}
+}
+
+// Fallback describes the driver fallback policy for one GPU type:
+// maxMajorVersion is the newest driver major version the GPU type supports,
+// and getFallbackDriverVersion looks up the version to install instead when
+// the requested driver is newer than that.
+type Fallback struct {
+	maxMajorVersion          int
+	getFallbackDriverVersion func(cos.ArtifactsDownloader) (string, error)
+}
+
+// fallbackMap holds the fallback policy of every GPU type that has one; GPU
+// types without an entry are not subject to any fallback.
+var fallbackMap = map[GPUType]Fallback{
+	// R470 is the last driver family supporting K80 GPU devices.
+	K80: {
+		maxMajorVersion:          470,
+		getFallbackDriverVersion: installer.GetR470GPUDriverVersion,
+	},
+}
+
 // InstallCommand is the subcommand to install GPU drivers.
 type InstallCommand struct {
 	driverVersion      string
@@ -60,7 +92,10 @@
 	f.StringVar(&c.driverVersion, "version", "",
 		"The GPU driver verion to install. "+
 			"It will install the default GPU driver if the flag is not set explicitly. "+
-			"Set the flag to 'latest' to install the latest GPU driver version.")
+			"Set the flag to 'latest' to install the latest GPU driver version. "+
+			"Please note that R470 is the last driver family supporting K80 GPU devices. "+
+			"If a higher version is used with K80 GPU, the installer will automatically "+
+			"choose an available R470 driver version.")
 	f.StringVar(&c.hostInstallDir, "host-dir", "",
 		"Host directory that GPU drivers should be installed to. "+
 			"It tries to read from the env NVIDIA_INSTALL_DIR_HOST if the flag is not set explicitly.")
@@ -128,10 +163,12 @@
 		return subcommands.ExitFailure
 	}
 
+	var gpuType GPUType
+
 	if !c.prepareBuildTools {
 		var isGpuConfigured bool
-		if isGpuConfigured, err = c.isGpuConfigured(); err != nil {
-			c.logError(errors.Wrapf(err, "failed to check if GPU is configured"))
+		if isGpuConfigured, gpuType, err = c.getGPUTypeInfo(); err != nil {
+			c.logError(errors.Wrapf(err, "failed to get GPU type information"))
 			return subcommands.ExitFailure
 		}
 
@@ -159,6 +196,10 @@
 				return subcommands.ExitFailure
 			}
 		}
+		if err := c.checkDriverCompatibility(downloader, gpuType); err != nil {
+			c.logError(errors.Wrap(err, "failed to check driver compatibility"))
+			return subcommands.ExitFailure
+		}
 		log.Infof("Installing GPU driver version %s", c.driverVersion)
 	} else {
 		log.Infof("Installing GPU driver from %q", c.nvidiaInstallerURL)
@@ -316,12 +357,59 @@
 	}
 }
 
-func (c *InstallCommand) isGpuConfigured() (bool, error) {
+// getGPUTypeInfo reports whether an NVIDIA GPU is present and, if so, which
+// GPUType it is, based on the output of `lspci | grep -i "nvidia"`.
+//
+// It returns (false, Others, nil) when grep matches nothing, (true, K80, nil)
+// when the output contains "[Tesla K80]", and (true, Others, nil) for any
+// other match. A non-nil error means the command could not be run or failed
+// for a reason other than "no match".
+func (c *InstallCommand) getGPUTypeInfo() (bool, GPUType, error) {
 	cmd := "lspci | grep -i \"nvidia\""
-	returnCode, err := utils.RunCommandWithExitCode([]string{"/bin/bash", "-c", cmd}, "", nil)
+	outBytes, err := exec.Command("/bin/bash", "-c", cmd).Output()
 	if err != nil {
-		return false, err
+		// Output reports every non-zero exit status as an *exec.ExitError, but
+		// grep exits with 1 simply because no line matched. That means "no
+		// NVIDIA GPU" and is reported through the boolean result, not as an error.
+		if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 {
+			return false, Others, nil
+		}
+		return false, Others, err
 	}
-	isConfigured := returnCode == grepFound
-	return isConfigured, nil
+	out := string(outBytes)
+	switch {
+	case strings.Contains(out, "[Tesla K80]"):
+		return true, K80, nil
+	default:
+		return true, Others, nil
+	}
+}
+
+// checkDriverCompatibility makes sure c.driverVersion is usable with gpuType.
+//
+// If fallbackMap has an entry for gpuType and the major version parsed from
+// c.driverVersion exceeds the entry's maxMajorVersion, c.driverVersion is
+// replaced with the version returned by the entry's getFallbackDriverVersion
+// and two warnings are logged. Otherwise c.driverVersion is left unchanged.
+// An error is returned when the major version cannot be parsed or the
+// fallback version cannot be obtained.
+func (c *InstallCommand) checkDriverCompatibility(downloader *cos.GCSDownloader, gpuType GPUType) error {
+	driverMajorVersion, err := strconv.Atoi(strings.Split(c.driverVersion, ".")[0])
+	if err != nil {
+		return errors.Wrap(err, "failed to get driver major version")
+	}
+
+	fallback, found := fallbackMap[gpuType]
+	if !found || driverMajorVersion <= fallback.maxMajorVersion {
+		return nil
+	}
+
+	log.Warningf("\n\nDriver version %s doesn't support %s GPU devices.\n\n", c.driverVersion, gpuType)
+	fallbackVersion, err := fallback.getFallbackDriverVersion(downloader)
+	if err != nil {
+		return errors.Wrap(err, "failed to get fallback driver")
+	}
+	log.Warningf("\n\nUsing driver version %s for %s GPU compatibility.\n\n", fallbackVersion, gpuType)
+	c.driverVersion = fallbackVersion
+	return nil
 }
diff --git a/src/cmd/cos_gpu_installer/internal/installer/installer.go b/src/cmd/cos_gpu_installer/internal/installer/installer.go
index d02b959..d83c294 100644
--- a/src/cmd/cos_gpu_installer/internal/installer/installer.go
+++ b/src/cmd/cos_gpu_installer/internal/installer/installer.go
@@ -28,6 +28,7 @@
 	gpuInstallDirContainer        = "/usr/local/nvidia"
 	defaultGPUDriverFile          = "gpu_default_version"
 	latestGPUDriverFile           = "gpu_latest_version"
+	r470GPUDriverFile             = "gpu_R470_version"
 	precompiledInstallerURLFormat = "https://storage.googleapis.com/nvidia-drivers-%s-public/nvidia-cos-project/%s/tesla/%s_00/%s/NVIDIA-Linux-x86_64-%s_%s-%s.cos"
 	defaultFilePermission         = 0755
 	signedURLKey                  = "Expires"
@@ -414,6 +415,16 @@
 	return strings.Trim(string(content), "\n "), nil
 }
 
+// GetR470GPUDriverVersion returns the content of the r470GPUDriverFile artifact, trimmed of newlines and spaces.
+func GetR470GPUDriverVersion(downloader cos.ArtifactsDownloader) (string, error) {
+	log.Info("Getting the R470 GPU driver version")
+	raw, err := downloader.GetArtifact(r470GPUDriverFile)
+	if err != nil {
+		return "", errors.Wrapf(err, "failed to get R470 GPU driver version")
+	}
+	return strings.Trim(string(raw), "\n "), nil
+}
+
 func updateContainerLdCache() error {
 	log.V(2).Info("Updating container's ld cache")