// Package commands implements subcommands of cos_gpu_installer.
package commands

import (
	"context"
	stderrors "errors"
	"fmt"
	"io/ioutil"
	"os"
	"path/filepath"
	"regexp"
	"strconv"
	"strings"
	"syscall"

	"flag"

	"cos.googlesource.com/cos/tools.git/src/cmd/cos_gpu_installer/deviceinfo"
	"cos.googlesource.com/cos/tools.git/src/cmd/cos_gpu_installer/features"
	"cos.googlesource.com/cos/tools.git/src/cmd/cos_gpu_installer/internal/installer"
	"cos.googlesource.com/cos/tools.git/src/cmd/cos_gpu_installer/internal/signing"
	"cos.googlesource.com/cos/tools.git/src/pkg/cos"
	"cos.googlesource.com/cos/tools.git/src/pkg/gpuconfig"
	"cos.googlesource.com/cos/tools.git/src/pkg/modules"

	log "github.com/golang/glog"
	"github.com/google/subcommands"
	"github.com/pkg/errors"
)

// Fixed paths and values shared by the installer subcommands in this package.
//
// hostRootPath is the prefix under which host paths are resolved: it is handed
// to cos.NewEnvReader and joined with the host install directory and with
// cosGPUConfigJsonPath. toolchainPkgDir is the directory that is remounted and
// then populated with the cross toolchain. grepFound, scratchDir and
// kernelSrcDir are not referenced in this part of the file; presumably they are
// used by sibling code in the package (verify before removing).
const (
	grepFound            = 0
	hostRootPath         = "/root"
	scratchDir           = "/tmp"
	kernelSrcDir         = "/build/usr/src/linux"
	toolchainPkgDir      = "/build/cos-tools"
	cosGPUConfigJsonPath = "/etc/cos-gpu-config.json"
)

// Fallback describes the range of driver major versions usable with a GPU
// device type, together with the driver version to fall back to when the
// requested driver lies outside that range.
type Fallback struct {
	// minMajorVersion is the lowest accepted major version; 0 means no lower bound.
	minMajorVersion int

	// maxMajorVersion is the highest accepted major version; 0 means no upper bound.
	maxMajorVersion int

	// fallbackDriverVersion is the driver version label to use instead of an
	// incompatible one.
	fallbackDriverVersion string
}

// Compatible reports whether driverMajorVersion lies within the inclusive
// [minMajorVersion, maxMajorVersion] range. A bound that is left at zero is
// not enforced.
func (f Fallback) Compatible(driverMajorVersion int) bool {
	aboveMax := f.maxMajorVersion != 0 && driverMajorVersion > f.maxMajorVersion
	belowMin := f.minMajorVersion != 0 && driverMajorVersion < f.minMajorVersion
	return !aboveMax && !belowMin
}

// fallbackMap lists, per GPU device type, the accepted driver major version
// range and the driver version to fall back to when the requested driver is
// outside that range. GPU types without an entry are not restricted by
// checkDriverCompatibility.
var fallbackMap = map[deviceinfo.GPUType]Fallback{
	// R470 is the last driver family supporting K80 GPU devices.
	deviceinfo.K80: {
		maxMajorVersion:       470,
		minMajorVersion:       450,
		fallbackDriverVersion: "R470",
	},
	// L4 has only a lower bound (525); incompatible requests fall back to R535.
	deviceinfo.L4: {
		minMajorVersion:       525,
		fallbackDriverVersion: "R535",
	},
	// H100 has only a lower bound (525); incompatible requests fall back to R535.
	deviceinfo.H100: {
		minMajorVersion:       525,
		fallbackDriverVersion: "R535",
	},
}

// FallBackFlag is a boolean flag.Value that additionally remembers whether it
// was set explicitly, so callers can tell "not provided" apart from "false".
type FallBackFlag struct {
	isSet bool
	value bool
}

// String returns the current value formatted as "true" or "false".
func (f *FallBackFlag) String() string {
	return fmt.Sprintf("%v", f.value)
}

// Set parses s with strconv.ParseBool and records the result.
//
// On a parse failure it returns a "parse error" and leaves the flag untouched:
// a rejected input must neither mark the flag as explicitly set nor overwrite
// a value stored by an earlier successful call.
func (f *FallBackFlag) Set(s string) error {
	v, err := strconv.ParseBool(s)
	if err != nil {
		return errors.New("parse error")
	}
	f.value = v
	f.isSet = true
	return nil
}

// IsSet reports whether Set has successfully stored a value.
func (f *FallBackFlag) IsSet() bool {
	return f.isSet
}

// Get returns the stored value; it is false until Set succeeds.
func (f *FallBackFlag) Get() bool {
	return f.value
}

// InstallCommand is the subcommand to install GPU drivers.
//
// Every field except kernelOpen is bound to a command-line flag in SetFlags
// (see the flag help texts there for the meaning of each one). kernelOpen is
// not a flag: Execute sets it when the driver is signed and the GPU type
// reports open kernel module support. hostInstallDir may additionally be
// filled from the NVIDIA_INSTALL_DIR_HOST environment variable, and
// driverVersion is overwritten by Execute with the resolved driver version.
type InstallCommand struct {
	driverVersion           string
	hostInstallDir          string
	forceFallback           FallBackFlag
	unsignedDriver          bool
	gcsDownloadBucket       string
	gcsDownloadPrefix       string
	nvidiaInstallerURL      string
	signatureURL            string
	targetGPU               string
	debug                   bool
	test                    bool
	prepareBuildTools       bool
	kernelOpen              bool
	noVerify                bool
	kernelModuleParams      modules.ModuleParameters
	nvidiaInstallerURLOpen  string
	gcsDownloadBucketNvidia string
	gcsDownloadPrefixNvidia string
}

// Name implements subcommands.Command.Name; the subcommand is invoked as
// "install".
func (*InstallCommand) Name() string {
	const name = "install"
	return name
}

// Synopsis implements subcommands.Command.Synopsis and returns the one-line
// description of the subcommand.
func (*InstallCommand) Synopsis() string {
	const synopsis = "Install GPU drivers."
	return synopsis
}

// Usage implements subcommands.Command.Usage and returns the usage line of the
// subcommand.
//
// NOTE(review): the text mentions "-dir", while SetFlags registers the install
// directory flag as "-host-dir" — confirm which spelling is intended.
func (*InstallCommand) Usage() string {
	const usage = "install [-dir <filepath>]\n"
	return usage
}

// SetFlags implements subcommands.Command.SetFlags.
//
// It binds every command-line flag of the install subcommand to the matching
// InstallCommand field and initializes kernelModuleParams before registering
// it as the -module-arg flag value.
//
// Help texts deliberately quote flag names and commands with single quotes:
// the flag package treats the first back-quoted span of a usage string as the
// name of the flag's argument and prints it next to the flag, so back quotes
// around other flag names or commands would produce misleading help output.
func (c *InstallCommand) SetFlags(f *flag.FlagSet) {
	f.StringVar(&c.driverVersion, "version", "",
		"The GPU driver version to install. "+
			"It will install the default GPU driver if the flag is not set explicitly. "+
			"Set the flag to 'latest' to install the latest GPU driver version. "+
			"Please note that R470 is the last driver family supporting K80 GPU devices. "+
			"If a higher version is used with K80 GPU, the installer will automatically "+
			"choose an available R470 driver version. "+
			"Supported values are 'default', 'latest', 'R<major-version>' eg. 'R470', 'R525', '<precise-version>' eg. '535.129.03', '525.147.05'. "+
			"For a list of supported driver version for your COS revision "+
			"please use '$cos-extensions list gpu' or view the versions available at "+
			"https://cloud.google.com/container-optimized-os/docs/release-notes.")
	f.StringVar(&c.targetGPU, "target-gpu", "", fmt.Sprintf("This flag specifies the GPU device for driver installation. "+
		"If specified, it must be one of %s. If not specified, the GPU device will be auto-detected by the installer.", strings.Join(deviceinfo.AllGPUTypeStrings, ", ")))
	f.StringVar(&c.hostInstallDir, "host-dir", "",
		"Host directory that GPU drivers should be installed to. "+
			"It tries to read from the env NVIDIA_INSTALL_DIR_HOST if the flag is not set explicitly.")
	f.BoolVar(&c.unsignedDriver, "allow-unsigned-driver", false,
		"Whether to allow load unsigned GPU drivers. "+
			"If this flag is set to true, module signing security features must be disabled on the host for driver installation to succeed. "+
			"This flag is only for debugging and testing.")
	f.StringVar(&c.gcsDownloadBucket, "gcs-download-bucket", "",
		"The GCS bucket to download COS artifacts from. "+
			"The default bucket is one of 'cos-tools', 'cos-tools-asia' and 'cos-tools-eu' based on where the VM is running. "+
			"Those are the public COS artifacts buckets.")
	f.StringVar(&c.gcsDownloadPrefix, "gcs-download-prefix", "",
		"The GCS path prefix when downloading COS artifacts. "+
			"If not set then the COS build number and board (e.g. 13310.1041.38/lakitu) will be used.")
	f.StringVar(&c.nvidiaInstallerURL, "nvidia-installer-url", "",
		"A URL to an nvidia-installer to use for driver installation. This flag is mutually exclusive with '-version'. "+
			"This flag must be used with '-allow-unsigned-driver'. This flag is only for debugging and testing.")
	f.StringVar(&c.signatureURL, "signature-url", "",
		"A URL to the driver signature. This flag can only be used together with '-test' and '-nvidia-installer-url' for debugging and testing.")
	f.StringVar(&c.nvidiaInstallerURLOpen, "nvidia-installer-url-open", "", "This can be used to specify the location of the GSP firmware and user-space NVIDIA GPU driver components from a corresponding driver release of the OSS kernel modules. This flag is only for debugging and testing.")
	f.StringVar(&c.gcsDownloadBucketNvidia, "gcs-download-bucket-nvidia", "", "The GCS bucket containing NVIDIA generic driver packages. "+
		"Must be used in conjunction with '-gcs-download-prefix-nvidia'. "+
		"If not set, then it would be one of 'nvidia-drivers-us-public', 'nvidia-drivers-asia-public', and 'nvidia-drivers-eu-public' based on where the VM is running.")
	f.StringVar(&c.gcsDownloadPrefixNvidia, "gcs-download-prefix-nvidia", "", "The GCS prefix where NVIDIA generic driver packages are located within the bucket. "+
		"Must be used in conjunction with '-gcs-download-bucket-nvidia'. "+
		"If not set, defaults to 'tesla/{driver-version}' (e.g., 'tesla/535.129.03').")
	f.BoolVar(&c.debug, "debug", false,
		"Enable debug mode.")
	f.BoolVar(&c.test, "test", false,
		"Enable test mode. "+
			"In test mode, '-nvidia-installer-url' can be used without '-allow-unsigned-driver'.")
	f.BoolVar(&c.prepareBuildTools, "prepare-build-tools", false, "Whether to populate the build tools cache, i.e. to download and install the toolchain and the kernel headers. Drivers are NOT installed when this flag is set and running with this flag does not require GPU attached to the instance.")
	f.BoolVar(&c.noVerify, "no-verify", false, "Skip kernel module loading and installation verification. Useful for preloading drivers without attached GPU.")
	// The parameter map must exist before the flag package calls Set on it.
	c.kernelModuleParams = modules.NewModuleParameters()
	f.Var(&c.kernelModuleParams, "module-arg", "Kernel module parameters can be specified using this flag. These parameters are used while loading the specific kernel mode drivers into the kernel. Usage: -module-arg <module-x>.<parameter-y>=<value> -module-arg <module-y>.<parameter-z>=<value> ..    For eg: -module-arg nvidia_uvm.uvm_debug_prints=1 -module-arg nvidia.NVreg_EnableGpuFirmware=0.")
	f.Var(&c.forceFallback, "force-fallback", "This flag specifies whether to use fallback mechanism when specified GPU driver is not compatible with GPU devices.\n"+
		"If unspecified, it is 'false' for --version=R<major-version> eg. 'R470', 'R525' or --version=<precise-version> eg. '535.129.03', '525.147.05', it is 'true' for version is not specified or --version=default or --version=latest.\n"+
		"When fallback behavior is active, the installer will find a compatible driver to install for the detected GPU on the VM.")
}

// validateFlags rejects flag combinations that must not be used together.
// The rules are evaluated top to bottom and the first violated rule decides
// which error is returned; nil means the combination is acceptable.
func (c *InstallCommand) validateFlags() error {
	customInstaller := c.nvidiaInstallerURL != ""
	switch {
	case customInstaller && c.driverVersion != "":
		return stderrors.New("-nvidia-installer-url and -version are both set; these flags are mutually exclusive")
	case customInstaller && !c.unsignedDriver && !c.test:
		return stderrors.New("-nvidia-installer-url is set, and -allow-unsigned-driver is not; -nvidia-installer-url must be used with -allow-unsigned-driver if not in test mode")
	case c.signatureURL != "" && !(customInstaller && c.test):
		return stderrors.New("-signature-url must be used with -nvidia-installer-url and -test")
	case c.nvidiaInstallerURLOpen != "" && !(c.driverVersion != "" && c.test):
		return stderrors.New("-nvidia-installer-url-open must be used with -test and -version")
	case c.gcsDownloadPrefixNvidia != "" && c.gcsDownloadBucketNvidia == "":
		return stderrors.New("-gcs-download-prefix-nvidia and -gcs-download-bucket-nvidia must be set together")
	}
	return nil
}

// Execute implements subcommands.Command.Execute.
//
// It validates the flags, determines the GPU type and the driver version to
// install (unless -nvidia-installer-url is given), reuses a cached
// installation when the cacher reports one, and otherwise installs the driver
// either from prebuilt kernel modules or through the full installer flow.
// Every failure is logged and reported as subcommands.ExitFailure; a completed
// run returns subcommands.ExitSuccess.
func (c *InstallCommand) Execute(ctx context.Context, _ *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus {
	if err := c.validateFlags(); err != nil {
		c.logError(err)
		return subcommands.ExitFailure
	}
	envReader, err := cos.NewEnvReader(hostRootPath)
	if err != nil {
		c.logError(errors.Wrapf(err, "failed to create envReader with host root path %s", hostRootPath))
		return subcommands.ExitFailure
	}
	// Debug mode raises the glog verbosity; a failure to do so is not fatal.
	if c.debug {
		if err := flag.Set("v", "2"); err != nil {
			log.Errorf("Unable to set debug logging: %v", err)
		}
	}

	log.V(2).Infof("Running on COS build id %s", envReader.BuildNumber())

	// The feature configuration is read from the host's copy of the GPU config file.
	fullCOSGPUConfigJsonPath := filepath.Join(hostRootPath, cosGPUConfigJsonPath)
	log.Infof("Feature flags initialization")
	featureConfig, err := features.InitFlags(fullCOSGPUConfigJsonPath)
	if err != nil {
		log.Errorf("Unable to init feature flags: %v", err)
		return subcommands.ExitFailure
	}

	// All prerelease builds are in dev-channel. For testing we don't need to check release track.
	// we can preload dependencies for dev-channel images too.
	if releaseTrack := envReader.ReleaseTrack(); !c.prepareBuildTools && !c.test && releaseTrack == "dev-channel" {
		c.logError(fmt.Errorf("GPU installation is not supported on dev images for now; Please use LTS image."))
		return subcommands.ExitFailure
	}

	// Determine the GPU type. When only preparing build tools no detection is
	// attempted and gpuType keeps its NO_GPU initial value.
	var gpuType deviceinfo.GPUType = deviceinfo.NO_GPU
	if !c.prepareBuildTools {
		if featureConfig.PerGPULabelFeatureFlag && c.targetGPU != "" {
			// An explicit -target-gpu is honored only when the per-GPU-label feature is on.
			gpuType, err = deviceinfo.ParseGPUType(c.targetGPU)
			if err != nil {
				c.logError(fmt.Errorf("failed to parse the target GPU type: %v", err))
				return subcommands.ExitFailure
			}
		} else {
			if gpuType, err = deviceinfo.GetGPUTypeInfo(); err != nil {
				// With -no-verify a detection failure is tolerated so drivers can be preloaded.
				if !c.noVerify {
					c.logError(errors.Wrapf(err, "failed to detect GPU type information"))
					return subcommands.ExitFailure
				}
				log.Infof("No GPU device detected, continue driver preoloading without verification.")
			}
		}
	}
	log.Infof("Install GPU driver for device type: %s", gpuType)

	// Read value from env NVIDIA_INSTALL_DIR_HOST if the flag is not set. This is to be compatible with old interface.
	if c.hostInstallDir == "" {
		c.hostInstallDir = os.Getenv("NVIDIA_INSTALL_DIR_HOST")
	}
	hostInstallDir := filepath.Join(hostRootPath, c.hostInstallDir)

	downloader := cos.NewGCSDownloader(nil, envReader, c.gcsDownloadBucket, c.gcsDownloadPrefix)
	// Resolve the driver version unless the installer is given by explicit URL.
	// c.driverVersion is overwritten with the resolved version; versionInput
	// keeps the user's original input for error messages.
	if c.nvidiaInstallerURL == "" {
		versionInput := c.driverVersion
		useFallback := useFallbackMechanism(c)
		if featureConfig.PerGPULabelFeatureFlag {
			c.driverVersion, err = getDriverVersionFromProto(ctx, downloader, gpuType, hostInstallDir, c.driverVersion, useFallback)
			if err != nil {
				c.logError(fmt.Errorf("failed to get %s driver version with error: %w", versionInput, err))
				return subcommands.ExitFailure
			}
			if gpuType == deviceinfo.Others {
				log.Warningf("Warning: the GPU device is not recognized by the cos-gpu-installer, "+
					"the GPU driver: %s we are installing may not compatible with the current GPU device. "+
					"Please go to https://cloud.google.com/container-optimized-os/docs/how-to/run-gpus to see the supported GPU devices on COS.", c.driverVersion)
			}
			if gpuType == deviceinfo.NO_GPU {
				log.Warningf("Warning: the GPU device is not detected by the cos-gpu-installer, "+
					"the GPU driver: %s we are installing may not compatible on the GPUs this image is being preloaded for.", c.driverVersion)

			}
		} else {
			// Without the per-GPU-label feature the version is resolved first
			// and then checked against fallbackMap, which may replace it.
			c.driverVersion, err = getDriverVersion(ctx, downloader, c.driverVersion)
			if err != nil {
				c.logError(fmt.Errorf("failed to get %s driver version with error: %w", versionInput, err))
				return subcommands.ExitFailure
			}
			if err := c.checkDriverCompatibility(ctx, downloader, gpuType, useFallback); err != nil {
				c.logError(errors.Wrap(err, "failed to check driver compatibility"))
				return subcommands.ExitFailure
			}
		}

		log.Infof("Installing GPU driver version %s", c.driverVersion)
	} else {
		log.Infof("Installing GPU driver from %q", c.nvidiaInstallerURL)
	}

	// For unsigned drivers only warn about the kernel command line; a failure
	// to read /proc/cmdline is logged and execution continues with whatever
	// was read.
	if c.unsignedDriver {
		kernelCmdline, err := ioutil.ReadFile("/proc/cmdline")
		if err != nil {
			c.logError(fmt.Errorf("failed to read kernel command line: %v", err))
		}
		if cos.CheckKernelModuleSigning(string(kernelCmdline)) {
			log.Warning("Current kernel command line does not support unsigned kernel modules. Not enforcing kernel module signing may cause installation fail.")
		}
	}

	// cacher stays nil when an installer URL is supplied; the install helpers
	// skip caching for a nil cacher.
	var cacher *installer.Cacher
	// We only want to cache drivers installed from official sources.
	if c.nvidiaInstallerURL == "" && c.nvidiaInstallerURLOpen == "" {
		cacher = installer.NewCacher(hostInstallDir, envReader.BuildNumber(), c.driverVersion)
		// An error from IsCached is treated the same as "not cached".
		if isCached, isOpen, err := cacher.IsCached(); isCached && err == nil {
			log.V(2).Info("Found cached version, NOT building the drivers.")
			if err := installer.ConfigureCachedInstallation(hostInstallDir, !c.unsignedDriver, c.test, isOpen, c.noVerify, c.kernelModuleParams); err != nil {
				c.logError(errors.Wrap(err, "failed to configure cached installation"))
				return subcommands.ExitFailure
			}
			if err := installer.VerifyDriverInstallation(c.noVerify, c.debug); err != nil {
				c.logError(errors.Wrap(err, "failed to verify GPU driver installation"))
				return subcommands.ExitFailure
			}
			if err := modules.UpdateHostLdCache(hostRootPath, filepath.Join(c.hostInstallDir, "lib64")); err != nil {
				c.logError(errors.Wrap(err, "failed to update host ld cache"))
				return subcommands.ExitFailure
			}
			return subcommands.ExitSuccess
		}
	}

	// This point is also reached when the cache lookup was skipped or failed.
	log.V(2).Info("Did not find cached version, installing the drivers...")

	// install OSS kernel modules (if available) if device supports
	if !c.unsignedDriver && gpuType.OpenSupported() {
		c.kernelOpen = gpuType.OpenSupported()
	}

	prebuiltModulesAvailable, err := installer.PrebuiltModulesAvailable(ctx, downloader, c.driverVersion, c.kernelOpen)

	if err != nil {
		c.logError(errors.Wrap(err, "failed to find prebuilt modules"))
		return subcommands.ExitFailure
	}

	// skip prebuilt module installation if preparing build tools
	if !c.prepareBuildTools && prebuiltModulesAvailable {
		log.V(2).Info("Found prebuilt kernel modules, installing additional components...")
		if err := installDriverPrebuiltModules(ctx, c, cacher, envReader, downloader); err != nil {
			c.logError(err)
			return subcommands.ExitFailure
		}
		return subcommands.ExitSuccess
	}

	// Full flow: set up the toolchain and, unless only preparing build tools,
	// run the driver installer.
	if err := installDriver(ctx, c, cacher, envReader, downloader); err != nil {
		c.logError(err)
		return subcommands.ExitFailure
	}

	return subcommands.ExitSuccess
}

// getDriverVersionFromProto downloads the GPU driver versions proto into
// gpuInstallDir and asks gpuconfig for the driver version matching gpuType,
// the user-supplied argVersion and the fallback setting.
//
// A download failure is wrapped with %w so the underlying cause stays
// reachable through errors.Is / errors.As for callers that keep wrapping it.
func getDriverVersionFromProto(ctx context.Context, downloader cos.ArtifactsDownloader, gpuType deviceinfo.GPUType, gpuInstallDir string, argVersion string, fallback bool) (string, error) {
	gpuProtoContent, err := installer.DownloadGPUDriverVersionsProto(ctx, downloader, gpuInstallDir)
	if err != nil {
		return "", fmt.Errorf("Failed to download and read GPU driver versions proto with error: %w", err)
	}
	return gpuconfig.GetGPUDriverVersion(gpuProtoContent, gpuType.String(), argVersion, fallback)
}

// getDriverVersion turns the user-supplied -version value into a precise
// driver version.
//
// The input is trimmed and lower-cased before it is interpreted:
//   - "" or the default label: the default version is looked up;
//   - the latest label: the latest version is looked up;
//   - "r<digits>": the major version label (upper-cased) is looked up;
//   - "<digits>.<digits>.<digits>": returned as the precise version.
//
// Any other input yields an error. The precise version is returned in its
// normalized (trimmed) form, so surrounding whitespace in the flag value
// cannot leak into the result.
func getDriverVersion(ctx context.Context, downloader *cos.GCSDownloader, argVersion string) (string, error) {
	processedArgVersion := strings.ToLower(strings.TrimSpace(argVersion))
	if processedArgVersion == "" || processedArgVersion == installer.DefaultVersion {
		// install the default version.
		return installer.GetGPUDriverVersion(ctx, downloader, installer.DefaultVersion)
	}

	if processedArgVersion == installer.LatestVersion {
		// install the latest version.
		return installer.GetGPUDriverVersion(ctx, downloader, installer.LatestVersion)
	}
	majorVersionPattern := regexp.MustCompile(`^r\d+$`)
	if majorVersionPattern.MatchString(processedArgVersion) {
		// argVersion is a major version, getting the precise gpu driver version.
		driverVersion, err := installer.GetGPUDriverVersion(ctx, downloader, strings.ToUpper(processedArgVersion))
		if err != nil {
			// Any lookup failure is reported to the user as an invalid major version.
			return "", errors.Errorf("The input GPU driver major version: %s is invalid, please use `$cos-extensions list gpu` to get the available major driver versions.", argVersion)
		}
		return driverVersion, nil
	}
	preciseVersionPattern := regexp.MustCompile(`^\d+(\.\d+){2}$`)
	if !preciseVersionPattern.MatchString(processedArgVersion) {
		// precise version is invalid.
		return "", errors.Errorf("The input GPU driver version: %s is invalid, please use `$cos-extensions list gpu` to get the available GPU driver versions.", argVersion)
	}
	// The input is an actual version. Return the validated, trimmed form: the
	// pattern only admits digits and dots, so it differs from argVersion only
	// by surrounding whitespace.
	return processedArgVersion, nil
}

// remountExecutable creates dir if needed, bind-mounts it onto itself and then
// remounts it with the nosuid, nodev and relatime flags. noexec is not among
// the remount flags (presumably so binaries under dir can be executed —
// confirm against the mount options of the parent filesystem).
func remountExecutable(dir string) error {
	const remountFlags = syscall.MS_REMOUNT | syscall.MS_NOSUID | syscall.MS_NODEV | syscall.MS_RELATIME

	mkdirErr := os.MkdirAll(dir, 0755)
	if mkdirErr != nil {
		return fmt.Errorf("failed to create dir %q: %v", dir, mkdirErr)
	}

	bindErr := syscall.Mount(dir, dir, "", syscall.MS_BIND, "")
	if bindErr != nil {
		return fmt.Errorf("failed to create bind mount at %q: %v", dir, bindErr)
	}

	remountErr := syscall.Mount("", dir, "", remountFlags, "")
	if remountErr != nil {
		return fmt.Errorf("failed to remount %q: %v", dir, remountErr)
	}
	return nil
}

// useFallbackMechanism decides whether the installer may substitute a
// compatible driver for the requested one.
//
// An explicitly provided -force-fallback value always wins. Otherwise fallback
// is enabled when no version, "default" or "latest" was requested (compared
// after trimming and lower-casing), and disabled for every other version value
// such as 'R470' or '535.129.03'.
func useFallbackMechanism(c *InstallCommand) bool {
	if c.forceFallback.IsSet() {
		return c.forceFallback.Get()
	}
	switch strings.ToLower(strings.TrimSpace(c.driverVersion)) {
	case "", "default", "latest":
		return true
	default:
		return false
	}
}

// installDriver runs the full installation flow: it configures the
// installation directories, prepares the compilation environment and cross
// toolchain, downloads the driver installer (and the signatures for signed
// installs), runs the installer, caches the result when a cacher is given,
// verifies the installation and updates the host ld cache.
//
// When c.prepareBuildTools is set the function returns right after the
// toolchain is installed. cacher may be nil, in which case nothing is cached.
func installDriver(ctx context.Context, c *InstallCommand, cacher *installer.Cacher, envReader *cos.EnvReader, downloader *cos.GCSDownloader) error {
	callback, err := installer.ConfigureDriverInstallationDirs(filepath.Join(hostRootPath, c.hostInstallDir), envReader.KernelRelease())
	if err != nil {
		return errors.Wrap(err, "failed to configure GPU driver installation dirs")
	}
	// Signal the directory configuration through its channel on every exit path.
	defer func() { callback <- 0 }()

	if err := cos.SetCompilationEnv(ctx, downloader); err != nil {
		return errors.Wrap(err, "failed to set compilation environment variables")
	}
	if err := remountExecutable(toolchainPkgDir); err != nil {
		// Report the directory that was actually remounted.
		return fmt.Errorf("failed to remount %q as executable: %w", toolchainPkgDir, err)
	}
	if err := cos.InstallCrossToolchain(ctx, downloader, toolchainPkgDir); err != nil {
		return errors.Wrap(err, "failed to install toolchain")
	}

	// Skip driver installation if we are only populating build tools cache
	if c.prepareBuildTools {
		return nil
	}

	var installerFile string
	if c.nvidiaInstallerURL == "" {
		installerFile, err = installer.DownloadDriverInstallerV2(ctx, downloader, c.driverVersion)
		if err != nil {
			return errors.Wrap(err, "failed to download GPU driver installer")
		}
	} else {
		installerFile, err = installer.DownloadToInstallDir(c.nvidiaInstallerURL, "Unofficial GPU driver installer")
		if err != nil {
			return err
		}
	}

	// Signatures are only needed for signed installs; an explicit signature
	// URL takes precedence over the signatures published for the version.
	if !c.unsignedDriver {
		if c.signatureURL != "" {
			if err := signing.DownloadDriverSignaturesFromURL(c.signatureURL); err != nil {
				return errors.Wrap(err, "failed to download driver signature")
			}
		} else {
			if err = signing.DownloadDriverSignaturesV2(ctx, downloader, c.driverVersion); err != nil {
				return errors.Wrap(err, "failed to download driver signature")
			}
		}
	}

	if err := installer.RunDriverInstaller(toolchainPkgDir, installerFile, c.driverVersion, !c.unsignedDriver, c.test, false, c.noVerify, c.kernelModuleParams); err != nil {
		if errors.Is(err, installer.ErrDriverLoad) {
			// Drivers were linked, but couldn't load; try again with legacy linking
			log.Infof("Failed to load kernel module, err: %v. Retrying driver installation with legacy linking", err)
			if err := installer.RunDriverInstaller(toolchainPkgDir, installerFile, c.driverVersion, !c.unsignedDriver, c.test, true, c.noVerify, c.kernelModuleParams); err != nil {
				return fmt.Errorf("failed to run GPU driver installer: %w", err)
			}
		} else {
			return errors.Wrap(err, "failed to run GPU driver installer")
		}
	}
	if cacher != nil {
		if err := cacher.Cache(false); err != nil {
			return errors.Wrap(err, "failed to cache installation")
		}
	}
	if err := installer.VerifyDriverInstallation(c.noVerify, c.debug); err != nil {
		return errors.Wrap(err, "failed to verify installation")
	}
	if err := modules.UpdateHostLdCache(hostRootPath, filepath.Join(c.hostInstallDir, "lib64")); err != nil {
		return errors.Wrap(err, "failed to update host ld cache")
	}
	log.Info("Finished installing the drivers.")
	return nil
}

// installDriverPrebuiltModules installs the driver using prebuilt kernel
// modules. It configures the installation directories, obtains the driver
// package (from c.nvidiaInstallerURLOpen when set, otherwise the generic
// driver installer for c.driverVersion), runs the prebuilt-modules installer,
// caches the result when cacher is non-nil, verifies the installation and
// updates the host ld cache.
func installDriverPrebuiltModules(ctx context.Context, c *InstallCommand, cacher *installer.Cacher, envReader *cos.EnvReader, downloader *cos.GCSDownloader) error {
	installDir := filepath.Join(hostRootPath, c.hostInstallDir)
	done, err := installer.ConfigureDriverInstallationDirs(installDir, envReader.KernelRelease())
	if err != nil {
		return errors.Wrap(err, "failed to configure GPU driver installation dirs")
	}
	// Signal the directory configuration through its channel on every exit path.
	defer func() { done <- 0 }()

	var driverPkg string
	if c.nvidiaInstallerURLOpen != "" {
		driverPkg, err = installer.DownloadToInstallDir(c.nvidiaInstallerURLOpen, "Unofficial GPU driver installer")
	} else {
		driverPkg, err = installer.DownloadGenericDriverInstaller(ctx, downloader, c.driverVersion, c.gcsDownloadBucketNvidia, c.gcsDownloadPrefixNvidia, "")
	}
	if err != nil {
		return err
	}

	err = installer.RunDriverInstallerPrebuiltModules(ctx, downloader, driverPkg, c.driverVersion, c.noVerify, c.kernelModuleParams)
	if err != nil {
		return err
	}

	if cacher != nil {
		if cacheErr := cacher.Cache(true); cacheErr != nil {
			return errors.Wrap(cacheErr, "failed to cache installation")
		}
	}

	err = installer.VerifyDriverInstallation(c.noVerify, c.debug)
	if err != nil {
		return errors.Wrap(err, "failed to verify installation")
	}

	libDir := filepath.Join(c.hostInstallDir, "lib64")
	err = modules.UpdateHostLdCache(hostRootPath, libDir)
	if err != nil {
		return errors.Wrap(err, "failed to update host ld cache")
	}

	log.Info("Finished installing the drivers.")
	return nil
}

// logError writes err to the error log. In debug mode the error is formatted
// with %+v, otherwise with %v.
func (c *InstallCommand) logError(err error) {
	format := "%v"
	if c.debug {
		format = "%+v"
	}
	log.Errorf(format, err)
}

// checkDriverCompatibility verifies c.driverVersion against the fallbackMap
// entry of gpuType.
//
// GPU types without an entry and compatible versions pass unchanged. For an
// incompatible version the function fails when useFallback is false; otherwise
// it looks up the entry's fallback driver and stores the resulting version in
// c.driverVersion.
func (c *InstallCommand) checkDriverCompatibility(ctx context.Context, downloader *cos.GCSDownloader, gpuType deviceinfo.GPUType, useFallback bool) error {
	// The major version is the part of the version before the first dot.
	majorPart := strings.Split(c.driverVersion, ".")[0]
	major, err := strconv.Atoi(majorPart)
	if err != nil {
		return errors.Wrap(err, "failed to get driver major version")
	}

	rule, restricted := fallbackMap[gpuType]
	if !restricted || rule.Compatible(major) {
		return nil
	}

	if !useFallback {
		return fmt.Errorf("driver version %s is not compatible with %s GPU device, "+
			"please use a compatible driver, or consider forcing a fallback using `--force-fallback=true` for the installer to select a compatible driver for the device", c.driverVersion, gpuType)
	}

	log.Warningf("\n\nDriver version %s doesn't support %s GPU devices.\n\n", c.driverVersion, gpuType)
	replacement, err := installer.GetGPUDriverVersion(ctx, downloader, rule.fallbackDriverVersion)
	if err != nil {
		return errors.Wrap(err, "failed to get fallback driver")
	}
	log.Warningf("\n\nUsing driver version %s for %s GPU compatibility.\n\n", replacement, gpuType)
	c.driverVersion = replacement
	return nil
}
