| package commands |
| |
import (
	"context"
	"flag"
	"fmt"
	"os"
	"path/filepath"
	"sort"
	"strings"

	"cos.googlesource.com/cos/tools.git/src/cmd/cos_gpu_installer/deviceinfo"
	"cos.googlesource.com/cos/tools.git/src/cmd/cos_gpu_installer/features"
	"cos.googlesource.com/cos/tools.git/src/cmd/cos_gpu_installer/internal/installer"
	"cos.googlesource.com/cos/tools.git/src/pkg/cos"
	"cos.googlesource.com/cos/tools.git/src/pkg/gpuconfig"
	log "github.com/golang/glog"
	"github.com/google/subcommands"
	"github.com/pkg/errors"
)
| |
// ListCommand is the subcommand to list supported GPU drivers.
//
// Its fields are populated from command-line flags registered in SetFlags.
type ListCommand struct {
	// Overrides for where COS artifacts are downloaded from; empty means the
	// defaults described in the flag help text.
	gcsDownloadBucket, gcsDownloadPrefix string
	// Directory (relative to the host root) in which the GPU driver versions
	// proto is stored; empty means a temporary directory is used instead.
	gpuProtoCacheDir string
	// GPU type whose drivers should be listed; empty means auto-detect.
	targetGPU string
	// When true, errors are logged with the %+v verb instead of %v.
	debug bool
}
| |
| // Name implements subcommands.Command.Name. |
| func (*ListCommand) Name() string { return "list" } |
| |
| // Synopsis implements subcommands.Command.Synopsis. |
| func (*ListCommand) Synopsis() string { return "List supported GPU drivers for this version." } |
| |
| // Usage implements subcommands.Command.Usage. |
| func (*ListCommand) Usage() string { return "list\n" } |
| |
| // SetFlags implements subcommands.Command.SetFlags. |
| func (c *ListCommand) SetFlags(f *flag.FlagSet) { |
| f.StringVar(&c.gcsDownloadBucket, "gcs-download-bucket", "", |
| "The GCS bucket to download COS artifacts from. "+ |
| "The default bucket is one of 'cos-tools', 'cos-tools-asia' and 'cos-tools-eu' based on where the VM is running. "+ |
| "Those are the public COS artifacts buckets.") |
| f.StringVar(&c.gcsDownloadPrefix, "gcs-download-prefix", "", |
| "The GCS path prefix when downloading COS artifacts."+ |
| "If not set then the COS build number and board (e.g. 13310.1041.38/lakitu) will be used.") |
| f.StringVar(&c.targetGPU, "target-gpu", "", fmt.Sprintf("This flag specifies the GPU device to display its compatible drivers. "+ |
| "If specified, it must be one of %s. If not specified, the GPU device will be auto-detected by the installer.", strings.Join(deviceinfo.AllGPUTypeStrings, ", "))) |
| f.StringVar(&c.gpuProtoCacheDir, "gpu-proto-cache-dir", "", |
| "The GPU proto cache directory that GPU driver versions proto file is stored into. If unspecified, the GPU driver versions proto file will not be cached.") |
| f.BoolVar(&c.debug, "debug", false, |
| "Enable debug mode.") |
| } |
| |
| // Execute implements subcommands.Command.Execute. |
| func (c *ListCommand) Execute(ctx context.Context, _ *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus { |
| envReader, err := cos.NewEnvReader(hostRootPath) |
| if err != nil { |
| c.logError(errors.Wrap(err, "failed to create envReader")) |
| return subcommands.ExitFailure |
| } |
| log.Infof("Running on COS build id %s", envReader.BuildNumber()) |
| fullCOSGPUConfigJsonPath := filepath.Join(hostRootPath, cosGPUConfigJsonPath) |
| log.Infof("Feature flags initialization") |
| featureConfig, err := features.InitFlags(fullCOSGPUConfigJsonPath) |
| if err != nil { |
| log.Errorf("Unable to init feature flags: %v", err) |
| return subcommands.ExitFailure |
| } |
| downloader := cos.NewGCSDownloader(nil, envReader, c.gcsDownloadBucket, c.gcsDownloadPrefix) |
| if featureConfig.PerGPULabelFeatureFlag { |
| var gpuProtoCacheDir = filepath.Join(hostRootPath, c.gpuProtoCacheDir) |
| if c.gpuProtoCacheDir == "" { |
| gpuProtoCacheDir, err = os.MkdirTemp(scratchDir, "gpuProtoCacheDir") |
| if err != nil { |
| log.Errorf("Failed to create tmp gpu proto cache dir in: %s", scratchDir) |
| return subcommands.ExitFailure |
| } |
| log.Infof("Created the tmp gpu proto cache directory: %s", gpuProtoCacheDir) |
| } |
| var gpuType deviceinfo.GPUType = deviceinfo.NO_GPU |
| if featureConfig.PerGPULabelFeatureFlag && c.targetGPU != "" { |
| gpuType, err = deviceinfo.ParseGPUType(c.targetGPU) |
| if err != nil { |
| c.logError(fmt.Errorf("failed to parse target GPU type: %v", err)) |
| return subcommands.ExitFailure |
| } |
| } else { |
| gpuType, err = deviceinfo.GetGPUTypeInfo() |
| if err != nil { |
| log.Infof("No GPU device detected!") |
| } |
| } |
| log.Infof("List GPU driver for device type: %s", gpuType) |
| |
| gpuVersionLabelMappings, err := downloadGPUDriverLabelMapping(ctx, downloader, gpuProtoCacheDir, gpuType) |
| if err != nil { |
| c.logError(fmt.Errorf("failed to fetch the available GPU driver versions and GPU driver labels for GPU type: %s with error: %v", gpuType.String(), err)) |
| return subcommands.ExitFailure |
| } |
| for gpuType, gpuDriverMap := range gpuVersionLabelMappings { |
| fmt.Printf("\nGPU Device: %s\n", gpuType.String()) |
| for driverVersion, gpuDriverLabel := range gpuDriverMap { |
| promptMsg := driverVersion |
| if gpuDriverLabel != "" { |
| promptMsg += " " + gpuDriverLabel |
| } |
| fmt.Printf("%s\n", promptMsg) |
| } |
| } |
| } else { |
| artifacts, err := downloader.ListGPUExtensionArtifacts() |
| if err != nil { |
| c.logError(errors.Wrap(err, "failed to list gpu extension artifacts")) |
| return subcommands.ExitFailure |
| } |
| gpuDriverToLabelMap, err := getGPUDriverToMajorLabelMap(ctx, downloader) |
| if err != nil { |
| c.logError(fmt.Errorf("error appears when fetching the GPU driver versions to GPU driver major labels map: %w", err)) |
| } |
| for _, artifact := range artifacts { |
| driverVersion := "" |
| if strings.HasSuffix(artifact, ".signature.tar.gz") { |
| driverVersion = strings.TrimSuffix(artifact, ".signature.tar.gz") |
| } else if strings.HasPrefix(artifact, "nvidia-drivers-") && strings.HasSuffix(artifact, "-signature.tar.gz") { |
| driverVersion = strings.TrimPrefix(artifact, "nvidia-drivers-") |
| driverVersion = strings.TrimSuffix(driverVersion, "-signature.tar.gz") |
| } |
| if driverVersion != "" { |
| promptMsg := driverVersion |
| if gpuDriverLabel, ok := gpuDriverToLabelMap[driverVersion]; ok { |
| promptMsg += " " + gpuDriverLabel |
| } |
| fmt.Printf("%s\n", promptMsg) |
| } |
| } |
| } |
| return subcommands.ExitSuccess |
| } |
| |
| // downloadGPUDriverLabelMapping retrieves mapping from GPUType to corresponding driver versions and their labels. |
| // This function will download the GPU driver version data to gpuInstallDir and parse it and |
| // constructs a hierarchical map where the first level keys are GPUTypes and the second level keys are driver versions |
| // with their associated labels as values. |
| // Example of the expected mapping: |
| // |
| // V100 -> { |
| // 535.129.03 -> [R535][latest][default], |
| // 470.223.02 -> [R470] |
| // 525.147.05 -> [R525] |
| // } |
| // |
| // Note: |
| // For NO_GPU case, this function will return mappings for all known GPU and their corresponding driver versions and label. |
| // For OTHERS case, this function will log the information that the GPU device is not recognized, and the GPU drivers in the OTHERS category from proto will be listed. |
| func downloadGPUDriverLabelMapping(ctx context.Context, downloader cos.ArtifactsDownloader, gpuInstallDir string, gpuType deviceinfo.GPUType) (map[deviceinfo.GPUType]map[string]string, error) { |
| gpuProtoContent, err := installer.DownloadGPUDriverVersionsProto(ctx, downloader, gpuInstallDir) |
| if err != nil { |
| return nil, fmt.Errorf("failed to download and read GPU driver versions proto with error: %v", err) |
| } |
| var gpuTypeToDriverLabelMap = map[deviceinfo.GPUType]map[string]string{} |
| if gpuType == deviceinfo.NO_GPU { |
| log.Infof("GPU driver versions for all available GPU types supported in COS will be listed.") |
| for _, availableGpuType := range deviceinfo.AvailableGPUTypesList { |
| driverLabelMap, err := buildGPUDriverLabelMap(gpuProtoContent, availableGpuType) |
| if err != nil { |
| return nil, err |
| } |
| gpuTypeToDriverLabelMap[availableGpuType] = driverLabelMap |
| } |
| return gpuTypeToDriverLabelMap, nil |
| } |
| if gpuType == deviceinfo.Others { |
| fmt.Printf("\nNote: This GPU device is not recognized by the gpu installer, the listed GPU drivers below may not compatible with current GPU device.\n" + |
| "Please go to https://cloud.google.com/compute/docs/gpus/install-drivers-gpu for more details\n") |
| } |
| driverLabelMap, err := buildGPUDriverLabelMap(gpuProtoContent, gpuType) |
| if err != nil { |
| return nil, err |
| } |
| gpuTypeToDriverLabelMap[gpuType] = driverLabelMap |
| return gpuTypeToDriverLabelMap, nil |
| } |
| func buildGPUDriverLabelMap(gpuProtoContent []byte, gpuType deviceinfo.GPUType) (map[string]string, error) { |
| driverVersions, err := gpuconfig.GetGPUDriverVersions(gpuProtoContent, gpuType.String()) |
| var gpuDriverMap = map[string]string{} |
| if err != nil { |
| return nil, fmt.Errorf("failed to get the supported GPU driver versions for GPU type: %s with error: %v", gpuType.String(), err) |
| } |
| for _, driverVersion := range driverVersions { |
| gpuDriverLabel := driverVersion.Label |
| gpuDriverVersion := driverVersion.Version |
| if _, ok := gpuDriverMap[gpuDriverVersion]; !ok { |
| gpuDriverMap[gpuDriverVersion] = "" |
| } |
| if gpuDriverLabel != "" { |
| gpuDriverMap[gpuDriverVersion] += "[" + gpuDriverLabel + "]" |
| } |
| } |
| return gpuDriverMap, nil |
| } |
| |
| // getGPUDriverToMajorLabelMap gets the GPU Driver version and associated labels |
| // E.g. |
| // 535.129.03 -> [R535][latest][default], |
| // 470.223.02 -> [R470] |
| // 525.147.05 -> [R525] |
| func getGPUDriverToMajorLabelMap(ctx context.Context, downloader cos.ArtifactsDownloader) (map[string]string, error) { |
| GPUDriverMajorVersionArtifactsMap, err := installer.DownloadGPUDriverVersionArtifacts(ctx, downloader) |
| if err != nil { |
| return nil, fmt.Errorf("error while fetching available gpu driver version files - %w", err) |
| } |
| var gpuDriverMap = map[string]string{} |
| for gpuArtifact, GPUDriverVersion := range GPUDriverMajorVersionArtifactsMap { |
| majorGPUDriverLabel := strings.TrimPrefix(gpuArtifact, installer.MajorGPUDriverArtifactPrefix) |
| majorGPUDriverLabel = strings.TrimSuffix(majorGPUDriverLabel, installer.MajorGPUDriverArtifactSuffix) |
| gpuDriverMap[GPUDriverVersion] += "[" + majorGPUDriverLabel + "]" |
| } |
| return gpuDriverMap, nil |
| } |
| func (c *ListCommand) logError(err error) { |
| if c.debug { |
| log.Errorf("%+v", err) |
| } else { |
| log.Errorf("%v", err) |
| } |
| } |
| func (c *ListCommand) logWarning(err error) { |
| if c.debug { |
| log.Warningf("%+v", err) |
| } else { |
| log.Warningf("%v", err) |
| } |
| } |