// Package gpuconfig implements routines for manipulating proto-based
// GPU build configuration files.
| // |
| // It also implements the construction of these configs for |
| // the COS Image and the COS Kernel CI. |
| package gpuconfig |
| |
| import ( |
| "errors" |
| "fmt" |
| "regexp" |
| "strings" |
| |
| "cos.googlesource.com/cos/tools.git/src/pkg/gpuconfig/pb" |
| log "github.com/golang/glog" |
| "google.golang.org/protobuf/proto" |
| ) |
| |
// Well-known driver version labels. User input is upper-cased before it is
// compared against these values.
const (
	defaultLabel = "DEFAULT"
	latestLabel  = "LATEST"
)
| |
// gpuDriverVersion is the set of driver versions supported by one GPU type,
// indexed both by label and by version.
//
// NOTE: field order matters; the struct is built with a positional literal.
type gpuDriverVersion struct {
	// gpuType is the GPU type the versions were collected for.
	gpuType string
	// labelToDriverVersion maps a version label (e.g. "DEFAULT", "R535") to
	// the driver version carrying that label.
	labelToDriverVersion map[string]string
	// driverVersionToLabels maps every supported driver version to its labels;
	// an unlabelled version maps to an empty slice.
	driverVersionToLabels map[string][]string
}
| |
| // parseGPUDriverVersionInfoLists will parse the GPU driver version proto content to pb.GPUDriverVersionInfoList. |
| func parseGPUDriverVersionInfoList(gpuDriverVersionContent []byte) (*pb.GPUDriverVersionInfoList, error) { |
| var gpuDriverVersionInfoList = &pb.GPUDriverVersionInfoList{} |
| if gpuDriverVersionContent == nil { |
| return nil, errors.New("the gpu proto content must not be empty") |
| } |
| err := (proto.UnmarshalOptions{DiscardUnknown: true, AllowPartial: true}).Unmarshal(gpuDriverVersionContent, gpuDriverVersionInfoList) |
| if err != nil { |
| return nil, fmt.Errorf("failed when parsing the GPU driver version: %v", err) |
| } |
| return gpuDriverVersionInfoList, nil |
| } |
| |
| // parseGPUDriverVersion will parse the GPU driver version content to gpuDriverVersion struct. |
| func parseGPUDriverVersion(gpuDriverVersionContent []byte, gpuType string) (*gpuDriverVersion, error) { |
| gpuDriverVersionInfoList, err := parseGPUDriverVersionInfoList(gpuDriverVersionContent) |
| if err != nil { |
| return nil, err |
| } |
| var labelToDriverVersion = map[string]string{} |
| var driverVersionToLabels = map[string][]string{} |
| for _, gpuDriverVersionInfo := range gpuDriverVersionInfoList.GetGpuDriverVersionInfo() { |
| if strings.EqualFold(strings.TrimSpace(gpuType), strings.TrimSpace(gpuDriverVersionInfo.GetGpuDevice().GetGpuType())) { |
| for _, driverVersionInfo := range gpuDriverVersionInfo.GetSupportedDriverVersions() { |
| driverVersion := driverVersionInfo.GetVersion() |
| driverVersionLabel := driverVersionInfo.GetLabel() |
| if _, ok := driverVersionToLabels[driverVersion]; !ok { |
| driverVersionToLabels[driverVersion] = []string{} |
| } |
| if driverVersionLabel != "" { |
| labelToDriverVersion[driverVersionLabel] = driverVersion |
| driverVersionToLabels[driverVersion] = append(driverVersionToLabels[driverVersion], driverVersionLabel) |
| } |
| } |
| } |
| } |
| // If the driverVersionToLabels is empty means the input GPU type is not found in the GPU driver version proto content |
| if len(driverVersionToLabels) == 0 { |
| return nil, fmt.Errorf("no supported driver versions found for gpu: %s", gpuType) |
| } |
| return &gpuDriverVersion{gpuType, labelToDriverVersion, driverVersionToLabels}, nil |
| |
| } |
| |
| // GetGPUDriverVersion fetch the GPU driver version from the input given the specific gpuType, fallback flag and proto bytes. |
| // The input here can be: |
| // 1. Precise GPU driver versions, e.g. 535.125.67, 560.38 |
| // In this case, we will check whether this version is compatible with gpuType. |
| // 2. Major GPU version label, e.g. R535, or default or latest or empty label. |
| // In this case, we will find out the driver version which is associated to this label for gpuType. |
| func GetGPUDriverVersion(gpuProtoContent []byte, gpuType string, input string, fallback bool) (string, error) { |
| preciseVersionPattern := regexp.MustCompile(`^\d+(\.\d+){1,2}$`) |
| majorVersionPattern := regexp.MustCompile(`^R\d+$`) |
| processedInput := strings.ToUpper(strings.TrimSpace(input)) |
| if processedInput != "" && processedInput != defaultLabel && processedInput != latestLabel && |
| !preciseVersionPattern.MatchString(processedInput) && !majorVersionPattern.MatchString(processedInput) { |
| return "", fmt.Errorf("the input is invalid: %s", input) |
| } |
| if gpuType == "" { |
| return "", errors.New("the GPU type must not be empty") |
| } |
| if gpuProtoContent == nil { |
| return "", errors.New("the gpu proto content must not be empty") |
| } |
| gpuDriverVersions, err := parseGPUDriverVersion(gpuProtoContent, gpuType) |
| if err != nil { |
| return "", fmt.Errorf("cannot parse the GPU driver versions info given the proto content and gpuType with error: %v", err) |
| } |
| // If the input label is empty, we will just use default GPU driver versions for this gpu type. |
| if processedInput == "" { |
| processedInput = defaultLabel |
| } |
| // The input label is the precise GPU driver version. |
| if preciseVersionPattern.MatchString(processedInput) { |
| // If the input driver version is compatible with the current GPUType, we just return as it is. |
| log.Infof("The input label: %s is a precise GPU driver version, checking whether it's compatible with GPU: %s", processedInput, gpuType) |
| _, found := gpuDriverVersions.driverVersionToLabels[processedInput] |
| if found { |
| log.Infof("%s is compatible with GPU: %s", processedInput, gpuType) |
| return processedInput, nil |
| } |
| } |
| |
| driverVersion, ok := gpuDriverVersions.labelToDriverVersion[processedInput] |
| if !ok { |
| // The input label is not supported in current GPU Type. |
| if fallback { |
| //If fallback flag is true, we will try to fetch the default GPU driver version for this gpu type. |
| defaultDriverVersion, ok := gpuDriverVersions.labelToDriverVersion[defaultLabel] |
| if !ok { |
| return "", fmt.Errorf("default label is not supported in current gpu type: %s", gpuType) |
| } |
| return defaultDriverVersion, nil |
| } else { |
| // If fallback flag is false, an error will be thrown. |
| return "", fmt.Errorf("failed to check driver compatibility, the input: %s is not supported for GPU type: %s, please use a compatible driver, or use the default/latest flag or consider forcing a fallback using `--force-fallback=true` for the installer to select a compatible driver for the device", processedInput, gpuType) |
| } |
| } |
| return driverVersion, nil |
| } |
| |
| // GetGPUDriverVersions will fetch a list of GPU driver versions that support the gpuType. |
| func GetGPUDriverVersions(gpuProtoContent []byte, gpuType string) ([]*pb.DriverVersion, error) { |
| if gpuType == "" { |
| return nil, errors.New("the GPU type must not be empty") |
| } |
| if gpuProtoContent == nil { |
| return nil, errors.New("the gpu proto content must not be empty") |
| } |
| gpuDriverVersionInfoList, err := parseGPUDriverVersionInfoList(gpuProtoContent) |
| if err != nil { |
| return nil, err |
| } |
| for _, gpuDriverVersionInfo := range gpuDriverVersionInfoList.GetGpuDriverVersionInfo() { |
| if gpuType == gpuDriverVersionInfo.GetGpuDevice().GetGpuType() { |
| return gpuDriverVersionInfo.GetSupportedDriverVersions(), nil |
| } |
| } |
| return nil, fmt.Errorf("no supported driver versions found for gpu: %s", gpuType) |
| } |