blob: 57ab93d09fb42455b529590e6cf0a81b8d8b2e27 [file] [log] [blame]
// Package gpuconfig implements routines for manipulating proto based
// GPU build configuration files.
//
// It also implements the construction of these configs for
// the COS Image and the COS Kernel CI.
package gpuconfig
import (
"errors"
"fmt"
"maps"
"net/http"
"regexp"
"slices"
"sort"
"strings"
"cos.googlesource.com/cos/tools.git/src/cmd/cos_gpu_installer/deviceinfo"
"cos.googlesource.com/cos/tools.git/src/pkg/gpuconfig/pb"
log "github.com/golang/glog"
"golang.org/x/mod/semver"
"google.golang.org/protobuf/proto"
)
const (
	// defaultLabel is the label an empty user input is mapped to, and the
	// label resolved when falling back to a compatible driver.
	defaultLabel string = "DEFAULT"
	// latestLabel is the "latest" label accepted as user input.
	latestLabel string = "LATEST"
	// gridDriverExtension is the suffix carried by grid driver versions.
	gridDriverExtension string = "-grid"
	// gridGcpDriverExtension is the suffix carried by grid-gcp driver versions.
	// Lookups in this file try it before gridDriverExtension.
	gridGcpDriverExtension string = "-grid-gcp"
)
var (
	// PreciseVersionPattern matches a precise driver version: two or three
	// dot-separated numeric components, e.g. 560.38 or 535.125.67.
	PreciseVersionPattern = regexp.MustCompile(`^\d+(\.\d+){1,2}$`)
	// preciseGridVersionPattern matches a precise driver version followed by a
	// -grid or -grid-gcp suffix. The suffix is matched case-insensitively.
	preciseGridVersionPattern = regexp.MustCompile(`^\d+(\.\d+){1,2}(?i:-grid(-gcp)?)$`)
)
// gpuDriverVersion holds the driver version lookup tables for a single GPU
// type, together with the vGPU state of the machine being configured.
type gpuDriverVersion struct {
	// gpu identifies the device: its type, whether it supports vGPU and
	// whether it is currently running as a vGPU.
	gpu deviceinfo.GPU
	// hostDriverVersion is the host driver version reported by VGPUInfo.
	// It is empty when VGPUInfo reports none.
	hostDriverVersion string
	// labelToDriverVersion maps a label to the driver version carrying it.
	labelToDriverVersion map[string]string
	// driverVersionToLabels maps every supported driver version to its labels.
	// A version without labels is present with an empty list, so the key set
	// doubles as the set of supported versions.
	driverVersionToLabels map[string][]string
	// versionToSupportedHostVersions maps a guest driver version to the host
	// versions it supports. Entries are compared against the host driver
	// branch (the part of the host version before the first dot).
	versionToSupportedHostVersions map[string][]string
}
// ParseGPUDriverVersionInfoList parses the serialized GPU driver version proto
// content into a pb.GPUDriverVersionInfoList.
//
// Unknown fields are discarded and partially populated messages are accepted.
// An error is returned when the content is nil or empty, or when it cannot be
// unmarshalled. The unmarshal error is wrapped, so callers can inspect it with
// errors.Is / errors.As.
func ParseGPUDriverVersionInfoList(gpuDriverVersionContent []byte) (*pb.GPUDriverVersionInfoList, error) {
	// len() rejects both a nil slice and a non-nil zero-length one (e.g. the
	// result of reading an empty file), matching what the error message says.
	if len(gpuDriverVersionContent) == 0 {
		return nil, errors.New("the gpu proto content must not be empty")
	}

	gpuDriverVersionInfoList := &pb.GPUDriverVersionInfoList{}
	opts := proto.UnmarshalOptions{DiscardUnknown: true, AllowPartial: true}
	if err := opts.Unmarshal(gpuDriverVersionContent, gpuDriverVersionInfoList); err != nil {
		return nil, fmt.Errorf("failed when parsing the GPU driver version: %w", err)
	}
	return gpuDriverVersionInfoList, nil
}
// parseGPUDriverVersion builds the gpuDriverVersion lookup tables for
// gpuTypeString from the serialized GPU driver version proto content.
//
// The GPU type is matched case-insensitively, ignoring surrounding whitespace.
// Every matching entry contributes to the tables. An error is returned when
// the content cannot be parsed, when no driver version is listed for the GPU
// type, or when gpuTypeString cannot be converted to a deviceinfo GPU type.
func parseGPUDriverVersion(gpuDriverVersionContent []byte, gpuTypeString string) (*gpuDriverVersion, error) {
	infoList, err := ParseGPUDriverVersionInfoList(gpuDriverVersionContent)
	if err != nil {
		return nil, err
	}

	var (
		versionByLabel  = map[string]string{}   // e.g. {default -> "575.00.11", latest -> "575.11.00"}
		labelsByVersion = map[string][]string{} // e.g. {"575.00.00" -> ["default", "latest"]}
		hostsByVersion  = map[string][]string{} // e.g. {"575.00.80" -> ["580", "570"]}, guest version -> supported host versions
		supportsVGPU    bool
	)

	wantedType := strings.TrimSpace(gpuTypeString)
	for _, info := range infoList.GetGpuDriverVersionInfo() {
		device := info.GetGpuDevice()
		if !strings.EqualFold(wantedType, strings.TrimSpace(device.GetGpuType())) {
			continue
		}

		// This entry describes the requested GPU type.
		supportsVGPU = device.GetSupportsVgpu()
		for _, supported := range info.GetSupportedDriverVersions() { // e.g. {label: "default", version: "575.57.08"}
			version, label := supported.GetVersion(), supported.GetLabel()

			// Every version gets an entry, even when it has no label: the key
			// set of this map is used to test whether a version is supported.
			labels, known := labelsByVersion[version]
			if !known {
				labels = []string{}
			}
			if label != "" {
				versionByLabel[label] = version
				labels = append(labels, label)
			}
			labelsByVersion[version] = labels

			// Likewise every guest version gets a (possibly empty) host list.
			hosts, known := hostsByVersion[version]
			if !known {
				hosts = []string{}
			}
			for _, host := range supported.GetSupportedHostVersions() {
				hosts = append(hosts, host.GetVersion())
			}
			hostsByVersion[version] = hosts
		}
		// No break here: if the same gpu_type is listed more than once, all of
		// its entries are merged.
	}

	// An empty table means the GPU type was not found in the proto content
	// (or was found without any supported driver version).
	if len(labelsByVersion) == 0 {
		return nil, fmt.Errorf("no supported driver versions found for gpu: %s", gpuTypeString)
	}

	gpuType, err := deviceinfo.ParseGPUType(gpuTypeString)
	if err != nil {
		return nil, fmt.Errorf("unable to parse gpu_type %s to GPUType", gpuTypeString)
	}

	parsed := &gpuDriverVersion{
		gpu:                            deviceinfo.GPU{GPUType: gpuType, SupportsVGPU: supportsVGPU},
		labelToDriverVersion:           versionByLabel,
		driverVersionToLabels:          labelsByVersion,
		versionToSupportedHostVersions: hostsByVersion,
	}
	return parsed, nil
}
// majorVersionPattern matches a major version label such as R535. It is
// compiled once at package initialization rather than on every call.
var majorVersionPattern = regexp.MustCompile(`^R\d+$`)

// GetGPUDriverVersion resolves the GPU driver version to use for gpuType from
// the serialized driver version proto in gpuProtoContent.
//
// input is trimmed and upper-cased before use and can be:
//  1. A precise GPU driver version, e.g. 535.125.67 or 560.38, optionally
//     followed by a -grid / -grid-gcp suffix. The version is returned as is
//     when it is listed as supported for gpuType.
//  2. A label: a major version label such as R535, "default", "latest", or
//     the empty string (treated as "default"). The version associated with
//     that label for gpuType is returned.
//
// Any other input is rejected. When VGPUInfo reports the device as a vGPU,
// selection is delegated to getVGPUVersion.
//
// If the requested version or label is not available for gpuType, the default
// driver version is returned when fallback is true; otherwise an error is
// returned.
func GetGPUDriverVersion(gpuProtoContent []byte, gpuType string, input string, fallback bool) (string, error) {
	processedInput := strings.ToUpper(strings.TrimSpace(input))
	if processedInput != "" && processedInput != defaultLabel && processedInput != latestLabel &&
		!PreciseVersionPattern.MatchString(processedInput) && !majorVersionPattern.MatchString(processedInput) &&
		!preciseGridVersionPattern.MatchString(processedInput) {
		return "", fmt.Errorf("the input is invalid: %s", input)
	}
	if gpuType == "" {
		return "", errors.New("the GPU type must not be empty")
	}
	if gpuProtoContent == nil {
		return "", errors.New("the gpu proto content must not be empty")
	}

	gpuDriverVersions, err := parseGPUDriverVersion(gpuProtoContent, gpuType)
	if err != nil {
		return "", fmt.Errorf("cannot parse the GPU driver versions info given the proto content and gpuType with error: %w", err)
	}

	hostDriverVersion, isVGPU, err := VGPUInfo(gpuType)
	if err != nil {
		return "", err
	}
	gpuDriverVersions.gpu.IsVGPU = isVGPU
	gpuDriverVersions.hostDriverVersion = hostDriverVersion

	// If the input label is empty, we will just use default GPU driver versions for this gpu type.
	if processedInput == "" {
		processedInput = defaultLabel
	}

	// vGPU devices must also match the host driver, so they are resolved separately.
	if gpuDriverVersions.gpu.IsVGPU {
		return getVGPUVersion(gpuDriverVersions, processedInput, fallback, gpuDriverVersions.gpu)
	}

	// The input is a precise GPU driver version: return it as is when it is
	// compatible with the current GPU type.
	if PreciseVersionPattern.MatchString(processedInput) {
		log.Infof("The input label: %s is a precise GPU driver version, checking whether it's compatible with GPU: %s", processedInput, gpuType)
		if _, found := gpuDriverVersions.driverVersionToLabels[processedInput]; found {
			log.Infof("%s is compatible with GPU: %s", processedInput, gpuType)
			return processedInput, nil
		}
	}

	// Otherwise (or when the precise version is unsupported) treat the input as a label.
	if driverVersion, ok := gpuDriverVersions.labelToDriverVersion[processedInput]; ok {
		return driverVersion, nil
	}

	// The input is not supported for the current GPU type.
	if !fallback {
		return "", fmt.Errorf("failed to check driver compatibility, the input: %s is not supported for GPU type: %s, please use a compatible driver, or use the default/latest flag or consider forcing a fallback using `--force-fallback=true` for the installer to select a compatible driver for the device", processedInput, gpuType)
	}

	// Fallback is enabled: use the default GPU driver version for this GPU type.
	defaultDriverVersion, ok := gpuDriverVersions.labelToDriverVersion[defaultLabel]
	if !ok {
		return "", fmt.Errorf("default label is not supported in current gpu type: %s", gpuType)
	}
	return defaultDriverVersion, nil
}
// VGPUInfo looks up the vGPU state for the given gpuType.
//
// It returns the host driver version, whether the device is a vGPU, and an
// error. For any GPU type other than RTX_PRO_6000, and for VM shapes not
// listed in FractionalVMShapes, it returns ("", false, nil) without querying
// the vGPU status.
func VGPUInfo(gpuType string) (string, bool, error) {
	// To minimize risk, only make the API calls for RTX_PRO_6000. Other devices do not support vGPU.
	if deviceinfo.RTX_PRO_6000.String() != gpuType {
		return "", false, nil
	}

	// NOTE(review): this client has no Timeout set. Whether the MDS helpers
	// apply their own deadline is not visible from here - confirm.
	client := http.Client{}

	// This is a temporary safe guard as host-driver-version will not be available from all machine types for a few months.
	vmShape, err := VmShapeFromMDS(client)
	if err != nil {
		return "", false, fmt.Errorf("Could not get vm shape from MDS: %v", err)
	}
	if !slices.Contains(FractionalVMShapes, vmShape) {
		return "", false, nil
	}

	// Once the MDS rollout is done globally we can remove the fractional-shape check and only use this method (TODO mirilio - b/490226943)
	hostDriverVersion, isVGPU, err := VGpuStatusFromMds(client)
	log.Infof("VM shape - %s, isVGPU - %t, host driver version - %s", vmShape, isVGPU, hostDriverVersion)
	return hostDriverVersion, isVGPU, err
}
// getVGPUVersion finds the vGPU (grid) driver version to install on the guest.
//
// gpuDriverVersions holds the lookup tables for the GPU device and the host
// driver version. processedInput is the normalized user request (default,
// latest, RXXX or a specific version, with or without a -grid / -grid-gcp
// suffix). gpu is only used for log and error messages.
//
// The returned version is compatible with the host driver branch, except for
// the last-resort result of findVGPUFallbackVersion. An error is returned
// when the host driver version is unknown, or when no compatible driver is
// found and fallback is disabled.
func getVGPUVersion(gpuDriverVersions *gpuDriverVersion, processedInput string, fallback bool, gpu deviceinfo.GPU) (string, error) {
	if gpuDriverVersions.hostDriverVersion == "" {
		return "", fmt.Errorf("host driver version missing, can't install vgpu driver on guest")
	}
	// Extract the branch number from a xxx.xx.xx version format.
	hostDriverBranch, _, _ := strings.Cut(gpuDriverVersions.hostDriverVersion, ".")

	if PreciseVersionPattern.MatchString(processedInput) {
		// The specific version does not explicitly end in "-grid" or
		// "-grid-gcp". Prioritize -grid-gcp over -grid.
		gridGcpVersion := processedInput + gridGcpDriverExtension
		if _, ok := gpuDriverVersions.driverVersionToLabels[gridGcpVersion]; ok {
			processedInput = gridGcpVersion
		} else {
			processedInput += gridDriverExtension
		}
	}

	if preciseGridVersionPattern.MatchString(processedInput) {
		// GetGPUDriverVersion upper-cases the user input, so an explicit
		// "575.57.08-grid" arrives as "575.57.08-GRID". The tables are keyed
		// with the lower-case suffixes, so fall back to the lower-case
		// spelling when the version is not known exactly as given.
		if _, ok := gpuDriverVersions.driverVersionToLabels[processedInput]; !ok {
			processedInput = strings.ToLower(processedInput)
		}

		// If the input driver version is compatible with the current GPUType and host branch, we return it as is.
		log.Infof("The input label: %s is a precise GPU driver version, checking whether it's compatible with GPU: %s and host driver version %s", processedInput, gpu.GPUType.String(), gpuDriverVersions.hostDriverVersion)
		if guestHostVersionsCompatible(gpuDriverVersions, processedInput, hostDriverBranch) {
			log.Infof("%s is compatible with GPU: %s and with host driver version %s", processedInput, gpu.GPUType.String(), gpuDriverVersions.hostDriverVersion)
			return processedInput, nil
		}
		// The version is either not available or not compatible with the host - try to fall back if possible.
		if !fallback {
			return "", fmt.Errorf("the requested driver version %s is either not supported for vGPU type: %s or is not compatible with host driver version %s. Please use a compatible driver, or use the default/latest flag or consider forcing a fallback using `--force-fallback=true` for the installer to select a compatible driver for the device", processedInput, gpu.GPUType.String(), gpuDriverVersions.hostDriverVersion)
		}
		// Fallback is enabled: treat the installation as though --version=default was specified.
		processedInput = defaultLabel
	}

	// "default"/"latest"/"RXXX" labels are all handled the same way.
	driverVersion, ok := gpuDriverVersions.labelToDriverVersion[processedInput]
	if !ok { // The input label is not supported for the current GPU type.
		if !fallback {
			return "", fmt.Errorf("the requested driver version %s is not supported for vGPU type: %s ", processedInput, gpu.GPUType.String())
		}
		// Fallback is enabled: use the default driver version for this GPU
		// type. When processedInput is already "default" this repeats the
		// lookup above, which is harmless.
		driverVersion, ok = gpuDriverVersions.labelToDriverVersion[defaultLabel]
		if !ok {
			return "", fmt.Errorf("default label is not supported in current gpu type: %s", gpu.GPUType.String()) // this should never happen because every device_type has a default driver version
		}
	}

	// From this point on the logic is the same whether driverVersion came from
	// the requested label or from the default label.
	if guestHostVersionsCompatible(gpuDriverVersions, driverVersion, hostDriverBranch) {
		return driverVersion, nil
	}
	if !fallback {
		return "", fmt.Errorf("the requested driver version %s is not supported for vGPU type: %s and host driver vesion %s and fallback is disabled. consider forcing a fallback using `--force-fallback=true` for the installer to select a compatible driver for the device", processedInput, gpu.GPUType.String(), gpuDriverVersions.hostDriverVersion)
	}
	return findVGPUFallbackVersion(gpuDriverVersions, gpu, hostDriverBranch)
}
// guestHostVersionsCompatible reports whether guestVersion lists
// hostVersionBranch among its supported host versions. A guest version that
// is not present in the table is never compatible.
func guestHostVersionsCompatible(gpuDriverVersions *gpuDriverVersion, guestVersion string, hostVersionBranch string) bool {
	supportedHostBranches, known := gpuDriverVersions.versionToSupportedHostVersions[guestVersion]
	if !known {
		return false
	}
	return slices.Contains(supportedHostBranches, hostVersionBranch)
}
// findVGPUFallbackVersion picks a grid driver version to install on the guest
// when the default version for the vGPU type is not compatible with the host
// and fallback is enabled.
//
// Candidates are tried in this order:
//  1. the exact host driver version as a -grid-gcp driver, then as a -grid driver;
//  2. the newest grid version that lists hostDriverBranch as a supported host;
//  3. the newest grid version available at all (logged as a warning, since it
//     may not match the host).
//
// An error is returned only when no grid version is available. gpu is
// currently unused.
func findVGPUFallbackVersion(gpuDriverVersions *gpuDriverVersion, gpu deviceinfo.GPU, hostDriverBranch string) (string, error) {
	// Check whether the exact host driver version is available to install on
	// the VM as a grid driver; host version N is always compatible with guest
	// version N. Prioritize -grid-gcp over -grid.
	for _, extension := range []string{gridGcpDriverExtension, gridDriverExtension} {
		hostVersionAsGrid := gpuDriverVersions.hostDriverVersion + extension
		if _, ok := gpuDriverVersions.driverVersionToLabels[hostVersionAsGrid]; ok {
			return hostVersionAsGrid, nil
		}
	}

	// The host version isn't available for installation on the guest. Look for
	// the newest grid version that supports the host's driver branch.
	var matchingGuestVersions []string
	for guestVersion, hostBranchList := range gpuDriverVersions.versionToSupportedHostVersions {
		if slices.Contains(hostBranchList, hostDriverBranch) {
			matchingGuestVersions = append(matchingGuestVersions, guestVersion)
		}
	}
	// sortGridVersions drops non-grid versions, so emptiness must be checked on
	// its result: indexing the last element of an empty result would panic.
	if sortedMatchingVersions := sortGridVersions(matchingGuestVersions); len(sortedMatchingVersions) > 0 {
		latestGridVersion := sortedMatchingVersions[len(sortedMatchingVersions)-1]
		log.Infof("Found driver versions that support the host's version. Will try to install the latest one - %s.", latestGridVersion)
		return latestGridVersion, nil
	}

	// No grid version matches the host version. Return the latest grid version
	// but warn the user that it might not match the host.
	sortedVersions := sortGridVersions(slices.Collect(maps.Keys(gpuDriverVersions.versionToSupportedHostVersions))) // Collect turns the key iterator into a slice
	if len(sortedVersions) == 0 {
		return "", fmt.Errorf("Could not find a fallback driver version to install.") // Should never happen
	}
	latestGridVersion := sortedVersions[len(sortedVersions)-1]
	log.Warningf("Could not find a grid driver that is compatible with the host driver versions. It is possible that the host driver was upgraded. Trying to install the latest driver version available (%s), which could result in degraded performance.", latestGridVersion)
	return latestGridVersion, nil
}
// sortGridVersions takes a list of versions (grid or non-grid) and returns
// the grid versions among them (those ending in -grid or -grid-gcp), sorted
// from oldest to newest. The result is never nil.
//
// Versions are ordered numerically component by component. For the same
// numeric version, -grid sorts before -grid-gcp. Versions that still cannot
// be interpreted sort first, in plain string order.
func sortGridVersions(versions []string) []string {
	type keyedVersion struct {
		original string // the version exactly as supplied
		key      string // canonical semver form, used only for ordering
	}

	var gridVersions []keyedVersion
	for _, version := range versions {
		if strings.HasSuffix(version, gridDriverExtension) || strings.HasSuffix(version, gridGcpDriverExtension) {
			gridVersions = append(gridVersions, keyedVersion{original: version, key: gridVersionSemverKey(version)})
		}
	}

	sort.Slice(gridVersions, func(i, j int) bool {
		if c := semver.Compare(gridVersions[i].key, gridVersions[j].key); c != 0 {
			return c < 0
		}
		// Equal keys (or two uninterpretable versions): fall back to string
		// order so the result is deterministic.
		return gridVersions[i].original < gridVersions[j].original
	})

	sorted := make([]string, len(gridVersions))
	for i, version := range gridVersions {
		sorted[i] = version.original
	}
	return sorted
}

// gridVersionSemverKey converts a driver version such as "575.57.08-grid" or
// "560.38-grid" into a form the semver package accepts ("v575.57.8-grid",
// "v560.38.0-grid"). The semver package rejects numeric components with
// leading zeros and two-component versions that carry a suffix; it treats
// such versions as older than every valid one, which would mis-order them.
func gridVersionSemverKey(version string) string {
	numeric, suffix, hasSuffix := strings.Cut(version, "-")
	parts := strings.Split(numeric, ".")
	for i, part := range parts {
		if part == "" {
			continue
		}
		// Drop leading zeros but keep a lone "0".
		if trimmed := strings.TrimLeft(part, "0"); trimmed != "" {
			parts[i] = trimmed
		} else {
			parts[i] = "0"
		}
	}
	// Pad MAJOR.MINOR to MAJOR.MINOR.0.
	for len(parts) < 3 {
		parts = append(parts, "0")
	}
	key := "v" + strings.Join(parts, ".") // the "v" prefix is required by the semver package
	if hasSuffix {
		key += "-" + suffix
	}
	return key
}
// GetGPUDriverVersions returns the driver versions listed as supported for
// gpuType in the serialized driver version proto gpuProtoContent.
//
// The GPU type is matched case-insensitively, ignoring surrounding
// whitespace, in the same way as the driver version lookup used by
// GetGPUDriverVersion. The first matching entry wins. An error is returned
// when gpuType is empty, when the content is nil or cannot be parsed, or when
// the GPU type is not listed.
func GetGPUDriverVersions(gpuProtoContent []byte, gpuType string) ([]*pb.DriverVersion, error) {
	if gpuType == "" {
		return nil, errors.New("the GPU type must not be empty")
	}
	if gpuProtoContent == nil {
		return nil, errors.New("the gpu proto content must not be empty")
	}

	gpuDriverVersionInfoList, err := ParseGPUDriverVersionInfoList(gpuProtoContent)
	if err != nil {
		return nil, err
	}

	wantedType := strings.TrimSpace(gpuType)
	for _, gpuDriverVersionInfo := range gpuDriverVersionInfoList.GetGpuDriverVersionInfo() {
		if strings.EqualFold(wantedType, strings.TrimSpace(gpuDriverVersionInfo.GetGpuDevice().GetGpuType())) {
			return gpuDriverVersionInfo.GetSupportedDriverVersions(), nil
		}
	}
	return nil, fmt.Errorf("no supported driver versions found for gpu: %s", gpuType)
}