| // Package gpuconfig implements routines for manipulating proto based |
| // GPU build configuration files. |
| // |
| // It also implements the construction of these configs for |
| // the COS Image and the COS Kernel CI. |
| package gpuconfig |
| |
| import ( |
| "errors" |
| "fmt" |
| "maps" |
| "net/http" |
| "regexp" |
| "slices" |
| "sort" |
| "strings" |
| |
| "cos.googlesource.com/cos/tools.git/src/cmd/cos_gpu_installer/deviceinfo" |
| "cos.googlesource.com/cos/tools.git/src/pkg/gpuconfig/pb" |
| log "github.com/golang/glog" |
| "golang.org/x/mod/semver" |
| "google.golang.org/protobuf/proto" |
| ) |
| |
// Labels and version suffixes used when resolving a requested driver version.
const (
	// defaultLabel names the driver version chosen when the input is empty,
	// and the one used when falling back from an unsupported input.
	defaultLabel = "DEFAULT"

	// latestLabel names the newest driver version offered for a GPU type.
	latestLabel = "LATEST"

	// gridDriverExtension is the suffix carried by vGPU (grid) driver versions.
	gridDriverExtension = "-grid"

	// gridGcpDriverExtension is the suffix of GCP grid driver builds; these are
	// preferred over plain gridDriverExtension builds of the same version.
	gridGcpDriverExtension = "-grid-gcp"
)
| |
| var ( |
| // regex for checking if user input is a valid percise driver version |
| PreciseVersionPattern = regexp.MustCompile(`^\d+(\.\d+){1,2}$`) |
| preciseGridVersionPattern = regexp.MustCompile(`^\d+(\.\d+){1,2}(?i:-grid(-gcp)?)$`) |
| ) |
| |
// gpuDriverVersion holds the driver version information parsed from the GPU
// driver version proto for a single GPU type, plus the host-side state used
// when selecting a vGPU (grid) driver.
type gpuDriverVersion struct {
	// gpu is the GPU type together with its SupportsVGPU/IsVGPU flags.
	gpu deviceinfo.GPU
	// hostDriverVersion is the host's driver version as reported by VGPUInfo;
	// it may be empty (e.g. for non-vGPU devices).
	hostDriverVersion string
	// labelToDriverVersion maps a label (e.g. "DEFAULT", "LATEST", "R535") to
	// the driver version carrying that label.
	labelToDriverVersion map[string]string
	// driverVersionToLabels maps every listed driver version to its labels;
	// versions without a label map to an empty slice.
	driverVersionToLabels map[string][]string
	// versionToSupportedHostVersions maps a guest driver version to the host
	// driver branches (e.g. "580") it is compatible with.
	versionToSupportedHostVersions map[string][]string
}
| |
| // ParseGPUDriverVersionInfoList will parse the GPU driver version proto content to pb.GPUDriverVersionInfoList. |
| func ParseGPUDriverVersionInfoList(gpuDriverVersionContent []byte) (*pb.GPUDriverVersionInfoList, error) { |
| var gpuDriverVersionInfoList = &pb.GPUDriverVersionInfoList{} |
| if gpuDriverVersionContent == nil { |
| return nil, errors.New("the gpu proto content must not be empty") |
| } |
| err := (proto.UnmarshalOptions{DiscardUnknown: true, AllowPartial: true}).Unmarshal(gpuDriverVersionContent, gpuDriverVersionInfoList) |
| if err != nil { |
| return nil, fmt.Errorf("failed when parsing the GPU driver version: %v", err) |
| } |
| return gpuDriverVersionInfoList, nil |
| } |
| |
// parseGPUDriverVersion parses gpuDriverVersionContent and collects the driver
// versions listed for gpuTypeString into a gpuDriverVersion.
//
// gpu_type entries are matched against gpuTypeString ignoring case and
// surrounding whitespace. If several entries match, their driver versions are
// merged and SupportsVGPU is taken from the last matching entry. The returned
// value's hostDriverVersion and gpu.IsVGPU are left unset for the caller.
//
// An error is returned if the content cannot be parsed, if no driver versions
// are listed for gpuTypeString, or if deviceinfo.ParseGPUType rejects it (the
// latter is wrapped).
func parseGPUDriverVersion(gpuDriverVersionContent []byte, gpuTypeString string) (*gpuDriverVersion, error) {
	gpuDriverVersionInfoList, err := ParseGPUDriverVersionInfoList(gpuDriverVersionContent)
	if err != nil {
		return nil, err
	}
	// Label -> driver version, e.g. {"DEFAULT": "575.00.11", "LATEST": "575.11.00"}.
	labelToDriverVersion := map[string]string{}
	// Driver version -> its labels, e.g. {"575.00.00": ["DEFAULT", "LATEST"]}.
	// Every listed version becomes a key; unlabelled versions map to an empty slice.
	driverVersionToLabels := map[string][]string{}
	// Guest driver version -> supported host driver versions, e.g. {"575.00.80": ["580", "570"]}.
	versionToSupportedHostVersions := map[string][]string{}
	var supportsVGPU bool

	wantedType := strings.TrimSpace(gpuTypeString)
	for _, gpuDriverVersionInfo := range gpuDriverVersionInfoList.GetGpuDriverVersionInfo() {
		if !strings.EqualFold(wantedType, strings.TrimSpace(gpuDriverVersionInfo.GetGpuDevice().GetGpuType())) {
			continue
		}
		// No break after a match: if the gpu_type is listed more than once, the
		// driver versions of every matching entry are merged.
		supportsVGPU = gpuDriverVersionInfo.GetGpuDevice().GetSupportsVgpu()
		for _, driverVersionInfo := range gpuDriverVersionInfo.GetSupportedDriverVersions() {
			driverVersion := driverVersionInfo.GetVersion()
			if _, ok := driverVersionToLabels[driverVersion]; !ok {
				driverVersionToLabels[driverVersion] = []string{}
			}
			if label := driverVersionInfo.GetLabel(); label != "" {
				labelToDriverVersion[label] = driverVersion
				driverVersionToLabels[driverVersion] = append(driverVersionToLabels[driverVersion], label)
			}
			// Initialize with an empty list so every listed version is a key,
			// even when it declares no supported host versions.
			if _, ok := versionToSupportedHostVersions[driverVersion]; !ok {
				versionToSupportedHostVersions[driverVersion] = []string{}
			}
			for _, hostVersion := range driverVersionInfo.GetSupportedHostVersions() {
				versionToSupportedHostVersions[driverVersion] = append(versionToSupportedHostVersions[driverVersion], hostVersion.GetVersion())
			}
		}
	}
	// An empty driverVersionToLabels means gpuTypeString is not listed in the
	// proto content (or is listed without any driver versions).
	if len(driverVersionToLabels) == 0 {
		return nil, fmt.Errorf("no supported driver versions found for gpu: %s", gpuTypeString)
	}
	gpuType, err := deviceinfo.ParseGPUType(gpuTypeString)
	if err != nil {
		return nil, fmt.Errorf("unable to parse gpu_type %s to GPUType: %w", gpuTypeString, err)
	}
	return &gpuDriverVersion{
		gpu:                            deviceinfo.GPU{GPUType: gpuType, SupportsVGPU: supportsVGPU},
		labelToDriverVersion:           labelToDriverVersion,
		driverVersionToLabels:          driverVersionToLabels,
		versionToSupportedHostVersions: versionToSupportedHostVersions,
	}, nil
}
| |
// GetGPUDriverVersion resolves input to the GPU driver version to install for
// gpuType, using the driver versions listed in gpuProtoContent.
//
// input is trimmed and upper-cased before matching. It can be:
//  1. A precise driver version, e.g. 535.125.67 or 560.38. It is returned as-is
//     if it is listed for gpuType.
//  2. A label: a major version label such as R535, "default", "latest", or an
//     empty string (treated as "default"). The driver version carrying that
//     label for gpuType is returned.
//  3. A precise grid version, e.g. 580.65.06-grid, used for vGPU devices.
//
// If input cannot be satisfied and fallback is true, the DEFAULT driver
// version for gpuType is returned instead; otherwise an error is returned.
// When VGPUInfo reports a vGPU, selection is delegated to getVGPUVersion,
// which additionally requires compatibility with the host driver.
func GetGPUDriverVersion(gpuProtoContent []byte, gpuType string, input string, fallback bool) (string, error) {
	majorVersionPattern := regexp.MustCompile(`^R\d+$`)
	processedInput := strings.ToUpper(strings.TrimSpace(input))
	if processedInput != "" && processedInput != defaultLabel && processedInput != latestLabel &&
		!PreciseVersionPattern.MatchString(processedInput) && !majorVersionPattern.MatchString(processedInput) &&
		!preciseGridVersionPattern.MatchString(processedInput) {
		return "", fmt.Errorf("the input is invalid: %s", input)
	}
	if gpuType == "" {
		return "", errors.New("the GPU type must not be empty")
	}
	if gpuProtoContent == nil {
		return "", errors.New("the gpu proto content must not be empty")
	}

	gpuDriverVersions, err := parseGPUDriverVersion(gpuProtoContent, gpuType)
	if err != nil {
		return "", fmt.Errorf("cannot parse the GPU driver versions info given the proto content and gpuType with error: %w", err)
	}

	hostDriverVersion, isVGPU, err := VGPUInfo(gpuType)
	if err != nil {
		return "", err
	}
	gpuDriverVersions.gpu.IsVGPU = isVGPU
	gpuDriverVersions.hostDriverVersion = hostDriverVersion

	// An empty input selects the default driver version for this GPU type.
	if processedInput == "" {
		processedInput = defaultLabel
	}
	// vGPU guests must also match the host driver, so they take a separate path.
	if gpuDriverVersions.gpu.IsVGPU {
		return getVGPUVersion(gpuDriverVersions, processedInput, fallback, gpuDriverVersions.gpu)
	}
	// A precise driver version is returned as-is when it is listed for this GPU type.
	if PreciseVersionPattern.MatchString(processedInput) {
		log.Infof("The input label: %s is a precise GPU driver version, checking whether it's compatible with GPU: %s", processedInput, gpuType)
		if _, found := gpuDriverVersions.driverVersionToLabels[processedInput]; found {
			log.Infof("%s is compatible with GPU: %s", processedInput, gpuType)
			return processedInput, nil
		}
	}

	// Otherwise resolve the input as a label. A precise version that was not
	// listed for this GPU type also falls through to this lookup.
	if driverVersion, ok := gpuDriverVersions.labelToDriverVersion[processedInput]; ok {
		return driverVersion, nil
	}
	// The input is not supported for this GPU type.
	if !fallback {
		return "", fmt.Errorf("failed to check driver compatibility, the input: %s is not supported for GPU type: %s, please use a compatible driver, or use the default/latest flag or consider forcing a fallback using `--force-fallback=true` for the installer to select a compatible driver for the device", processedInput, gpuType)
	}
	// With fallback enabled, use the default driver version for this GPU type.
	defaultDriverVersion, ok := gpuDriverVersions.labelToDriverVersion[defaultLabel]
	if !ok {
		return "", fmt.Errorf("default label is not supported in current gpu type: %s", gpuType)
	}
	return defaultDriverVersion, nil
}
| |
| // Returns hostDriverVersion string, isVGPU bool, err error |
| func VGPUInfo(gpuType string) (string, bool, error) { |
| var hostDriverVersion string |
| isVGPU := false |
| // To minimize risk, only make the API calls for RTX_PRO_6000. Other devices do not support vGPU. |
| if gpuType != deviceinfo.RTX_PRO_6000.String() { |
| return "", false, nil |
| } |
| // This is a temporary safe guard as host-driver-version will not be available from all machine types for a few months. |
| client := http.Client{} |
| vmShape, err := VmShapeFromMDS(client) |
| if err != nil { |
| return "", false, fmt.Errorf("Could not get vm shape from MDS: %v", err) |
| } |
| fractional := slices.Contains(FractionalVMShapes, vmShape) |
| if !fractional { |
| return "", false, nil |
| } |
| |
| // Once the MDS rollout is done globally we can remove the `fractional` check and only use this method (TODO mirilio - b/490226943) |
| hostDriverVersion, isVGPU, err = VGpuStatusFromMds(client) |
| log.Infof("VM shape - %s, isVGPU - %t, host driver version - %s", vmShape, isVGPU, hostDriverVersion) |
| return hostDriverVersion, isVGPU, err |
| } |
| |
// getVGPUVersion selects the grid (vGPU) guest driver version to install. The
// result is compatible with the host driver branch (the first component of
// gpuDriverVersions.hostDriverVersion).
//
// processedInput is the normalized user input: "DEFAULT", "LATEST", an "RXXX"
// label, a precise version, or a precise grid version. A precise version
// without a suffix is tried as "<version>-grid-gcp" when that is listed, and
// otherwise as "<version>-grid". If the requested version is unsupported or
// incompatible with the host and fallback is true, selection proceeds as if
// DEFAULT had been requested; if that is also incompatible,
// findVGPUFallbackVersion picks a version. With fallback false, an error is
// returned instead. An error is also returned if the host driver version is
// unknown.
func getVGPUVersion(gpuDriverVersions *gpuDriverVersion, processedInput string, fallback bool, gpu deviceinfo.GPU) (string, error) {
	if gpuDriverVersions.hostDriverVersion == "" {
		return "", errors.New("host driver version missing, can't install vgpu driver on guest")
	}
	// Branch number of an xxx.xx.xx host version, e.g. "580" for "580.65.06".
	hostDriverBranch := strings.Split(gpuDriverVersions.hostDriverVersion, ".")[0]

	if PreciseVersionPattern.MatchString(processedInput) {
		// The precise version has no "-grid"/"-grid-gcp" suffix; prefer
		// "-grid-gcp" when such a build is listed, otherwise use "-grid".
		gridGcpVersion := processedInput + gridGcpDriverExtension
		if _, ok := gpuDriverVersions.driverVersionToLabels[gridGcpVersion]; ok {
			processedInput = gridGcpVersion
		} else {
			processedInput += gridDriverExtension
		}
	}
	if preciseGridVersionPattern.MatchString(processedInput) {
		// preciseGridVersionPattern accepts the suffix in any case, and
		// GetGPUDriverVersion upper-cases user input (e.g. "580.65.06-GRID").
		// However, grid versions are built and matched everywhere with the
		// lower-case gridDriverExtension/gridGcpDriverExtension suffixes, so
		// normalize them before the lookups. Lower-casing leaves the numeric
		// part unchanged.
		processedInput = strings.ToLower(processedInput)
		log.Infof("The input label: %s is a precise GPU driver version, checking whether it's compatible with GPU: %s and host driver version %s", processedInput, gpu.GPUType.String(), gpuDriverVersions.hostDriverVersion)

		if guestHostVersionsCompatible(gpuDriverVersions, processedInput, hostDriverBranch) {
			log.Infof("%s is compatible with GPU: %s and with host driver version %s", processedInput, gpu.GPUType.String(), gpuDriverVersions.hostDriverVersion)
			return processedInput, nil
		}
		// The version is either not listed or not compatible with the host.
		if !fallback {
			return "", fmt.Errorf("the requested driver version %s is either not supported for vGPU type: %s or is not compatible with host driver version %s. Please use a compatible driver, or use the default/latest flag or consider forcing a fallback using `--force-fallback=true` for the installer to select a compatible driver for the device", processedInput, gpu.GPUType.String(), gpuDriverVersions.hostDriverVersion)
		}
		// With fallback enabled, continue as though --version=default had been given.
		processedInput = defaultLabel
	}

	// "DEFAULT", "LATEST" and "RXXX" labels are all resolved the same way.
	driverVersion, ok := gpuDriverVersions.labelToDriverVersion[processedInput]
	if !ok { // The label is not supported for this GPU type.
		if !fallback {
			return "", fmt.Errorf("the requested driver version %s is not supported for vGPU type: %s ", processedInput, gpu.GPUType.String())
		}
		// Fall back to the default version. When processedInput is already
		// DEFAULT this repeats the lookup, which is harmless.
		driverVersion, ok = gpuDriverVersions.labelToDriverVersion[defaultLabel]
		if !ok {
			return "", fmt.Errorf("default label is not supported in current gpu type: %s", gpu.GPUType.String())
		}
	}
	// Whether the label was found directly or replaced by the default, the
	// resolved version must still be compatible with the host branch.
	if guestHostVersionsCompatible(gpuDriverVersions, driverVersion, hostDriverBranch) {
		return driverVersion, nil
	}
	if !fallback {
		return "", fmt.Errorf("the requested driver version %s is not supported for vGPU type: %s and host driver version %s and fallback is disabled. consider forcing a fallback using `--force-fallback=true` for the installer to select a compatible driver for the device", processedInput, gpu.GPUType.String(), gpuDriverVersions.hostDriverVersion)
	}
	return findVGPUFallbackVersion(gpuDriverVersions, gpu, hostDriverBranch)
}
| |
// guestHostVersionsCompatible reports whether guestVersion lists
// hostVersionBranch among its supported host versions. A guest version that is
// not present in the map is never compatible.
func guestHostVersionsCompatible(gpuDriverVersions *gpuDriverVersion, guestVersion string, hostVersionBranch string) bool {
	supportedHosts := gpuDriverVersions.versionToSupportedHostVersions[guestVersion]
	return slices.Index(supportedHosts, hostVersionBranch) >= 0
}
| |
// findVGPUFallbackVersion picks a grid driver version to install when the
// requested (or default) version is not compatible with the host and fallback
// is enabled. Candidates are tried in this order:
//  1. the exact host driver version with the "-grid-gcp" suffix, then "-grid";
//  2. the newest grid version that lists hostDriverBranch as a supported host;
//  3. the newest grid version overall. This is logged as a warning because it
//     may not match the host.
//
// An error is returned only when no grid version is listed at all.
func findVGPUFallbackVersion(gpuDriverVersions *gpuDriverVersion, gpu deviceinfo.GPU, hostDriverBranch string) (string, error) {
	// Host version N is always compatible with guest version N, so check
	// whether the exact host version is installable as a grid driver,
	// preferring -grid-gcp over -grid.
	hostVersionAsGridGcp := gpuDriverVersions.hostDriverVersion + gridGcpDriverExtension
	if _, ok := gpuDriverVersions.driverVersionToLabels[hostVersionAsGridGcp]; ok {
		return hostVersionAsGridGcp, nil
	}
	hostVersionAsGrid := gpuDriverVersions.hostDriverVersion + gridDriverExtension
	if _, ok := gpuDriverVersions.driverVersionToLabels[hostVersionAsGrid]; ok {
		return hostVersionAsGrid, nil
	}

	// The host version isn't available for the guest: look for the newest grid
	// version that declares support for the host's branch.
	var matchingGuestVersions []string
	for guestVersion, hostBranchList := range gpuDriverVersions.versionToSupportedHostVersions {
		if slices.Contains(hostBranchList, hostDriverBranch) {
			matchingGuestVersions = append(matchingGuestVersions, guestVersion)
		}
	}
	// sortGridVersions drops non-grid versions, so its result can be empty even
	// when matchingGuestVersions is not; indexing it unguarded would panic.
	if sortedMatchingVersions := sortGridVersions(matchingGuestVersions); len(sortedMatchingVersions) > 0 {
		latestGridVersion := sortedMatchingVersions[len(sortedMatchingVersions)-1]
		log.Infof("Found driver versions that support the host's version. Will try to install the latest one - %s.", latestGridVersion)
		return latestGridVersion, nil
	}

	// No grid version matches the host branch. As a last resort, return the
	// newest grid version and warn that it might not match the host.
	sortedVersions := sortGridVersions(slices.Collect(maps.Keys(gpuDriverVersions.versionToSupportedHostVersions)))
	if len(sortedVersions) == 0 {
		return "", errors.New("Could not find a fallback driver version to install.")
	}
	latestGridVersion := sortedVersions[len(sortedVersions)-1]
	log.Warningf("Could not find a grid driver that is compatible with the host driver versions. It is possible that the host driver was upgraded. Trying to install the latest driver version available (%s), which could result in degraded performance.", latestGridVersion)
	return latestGridVersion, nil
}
| |
| // Gets a list of versions (grid or non-grid) and returns a sorted list of all the grid versions |
| func sortGridVersions(versions []string) []string { |
| var gridVersions []string |
| // Collect all grid versions and append "v" in the beginning - required by semver package |
| for _, version := range versions { |
| if strings.HasSuffix(version, gridDriverExtension) || strings.HasSuffix(version, gridGcpDriverExtension) { |
| gridVersions = append(gridVersions, "v"+version) |
| } |
| } |
| sort.Sort(semver.ByVersion(gridVersions)) |
| semver.Sort(gridVersions) |
| // Clean out the prefix "v" in all items |
| cleanGridVersions := make([]string, len(gridVersions)) |
| for i, version := range gridVersions { |
| cleanGridVersions[i] = strings.TrimPrefix(version, "v") |
| } |
| return cleanGridVersions |
| } |
| |
// GetGPUDriverVersions returns the driver versions listed as supported for
// gpuType in gpuProtoContent.
//
// gpuType is matched against each entry's gpu_type ignoring case and
// surrounding whitespace, the same way parseGPUDriverVersion matches it, and
// the first matching entry wins. An error is returned for an empty gpuType,
// nil or unparsable content, or when no entry matches.
func GetGPUDriverVersions(gpuProtoContent []byte, gpuType string) ([]*pb.DriverVersion, error) {
	if gpuType == "" {
		return nil, errors.New("the GPU type must not be empty")
	}
	if gpuProtoContent == nil {
		return nil, errors.New("the gpu proto content must not be empty")
	}
	gpuDriverVersionInfoList, err := ParseGPUDriverVersionInfoList(gpuProtoContent)
	if err != nil {
		return nil, err
	}
	wantedType := strings.TrimSpace(gpuType)
	for _, gpuDriverVersionInfo := range gpuDriverVersionInfoList.GetGpuDriverVersionInfo() {
		if strings.EqualFold(wantedType, strings.TrimSpace(gpuDriverVersionInfo.GetGpuDevice().GetGpuType())) {
			return gpuDriverVersionInfo.GetSupportedDriverVersions(), nil
		}
	}
	return nil, fmt.Errorf("no supported driver versions found for gpu: %s", gpuType)
}