blob: 049bafdc5fc4d4da6952c16fc44f7d05a422aad1 [file] [log] [blame]
// Package gpuconfig implements routines for manipulating proto based
// GPU build configuration files.
//
// It also implements the construction of these configs for
// the COS Image and the COS Kernel CI.
package gpuconfig
import (
"errors"
"fmt"
"regexp"
"strings"
"cos.googlesource.com/cos/tools.git/src/pkg/gpuconfig/pb"
log "github.com/golang/glog"
"google.golang.org/protobuf/proto"
)
// Well-known driver version labels. User input is upper-cased before it is
// compared against these values.
const (
	// defaultLabel is the label used when the caller supplies no input, and
	// the label consulted when falling back.
	defaultLabel = "DEFAULT"
	// latestLabel is the label accepted as input alongside defaultLabel.
	latestLabel = "LATEST"
)
// gpuDriverVersion holds the driver version lookup tables built for a single
// GPU type by parseGPUDriverVersion.
type gpuDriverVersion struct {
// gpuType is the GPU type these tables were built for, as passed by the caller.
gpuType string
// labelToDriverVersion maps a non-empty driver version label to its driver version.
labelToDriverVersion map[string]string
// driverVersionToLabels maps every supported driver version to the labels
// attached to it; the slice is empty when the version carries no label.
driverVersionToLabels map[string][]string
}
// parseGPUDriverVersionInfoList unmarshals the serialized GPU driver version
// proto into a pb.GPUDriverVersionInfoList.
//
// Unknown fields are discarded and messages with missing required fields are
// accepted. An error is returned when gpuDriverVersionContent is nil or
// zero-length, or when the bytes cannot be unmarshaled; the unmarshal error is
// wrapped so callers can inspect it with errors.Is / errors.As.
func parseGPUDriverVersionInfoList(gpuDriverVersionContent []byte) (*pb.GPUDriverVersionInfoList, error) {
	// A non-nil but zero-length slice carries no data either, so reject it
	// together with nil, as the error message promises.
	if len(gpuDriverVersionContent) == 0 {
		return nil, errors.New("the gpu proto content must not be empty")
	}
	gpuDriverVersionInfoList := &pb.GPUDriverVersionInfoList{}
	opts := proto.UnmarshalOptions{DiscardUnknown: true, AllowPartial: true}
	if err := opts.Unmarshal(gpuDriverVersionContent, gpuDriverVersionInfoList); err != nil {
		return nil, fmt.Errorf("failed when parsing the GPU driver version: %w", err)
	}
	return gpuDriverVersionInfoList, nil
}
// parseGPUDriverVersion builds the label/version lookup tables for gpuType
// from the serialized GPU driver version proto.
//
// GPU types are compared after trimming surrounding whitespace and
// upper-casing both sides. Every supported driver version of a matching entry
// becomes a key of driverVersionToLabels, even when it carries no label;
// non-empty labels are additionally recorded in labelToDriverVersion.
// An error is returned when the content cannot be parsed or when no driver
// version is found for gpuType.
func parseGPUDriverVersion(gpuDriverVersionContent []byte, gpuType string) (*gpuDriverVersion, error) {
	infoList, err := parseGPUDriverVersionInfoList(gpuDriverVersionContent)
	if err != nil {
		return nil, err
	}

	wantType := strings.ToUpper(strings.TrimSpace(gpuType))
	versionByLabel := make(map[string]string)
	labelsByVersion := make(map[string][]string)
	for _, info := range infoList.GetGpuDriverVersionInfo() {
		if strings.ToUpper(strings.TrimSpace(info.GetGpuType())) != wantType {
			continue
		}
		for _, supported := range info.GetSupportedDriverVersions() {
			version, label := supported.GetVersion(), supported.GetLabel()
			// Register the version even when it has no label so that precise
			// version lookups can still find it.
			labels, seen := labelsByVersion[version]
			if !seen {
				labels = []string{}
			}
			if label != "" {
				versionByLabel[label] = version
				labels = append(labels, label)
			}
			labelsByVersion[version] = labels
		}
	}

	// No recorded version means the GPU type is absent from the proto content
	// (or is listed without any supported driver version).
	if len(labelsByVersion) == 0 {
		return nil, fmt.Errorf("no supported driver versions found for gpu: %s", gpuType)
	}
	return &gpuDriverVersion{
		gpuType:               gpuType,
		labelToDriverVersion:  versionByLabel,
		driverVersionToLabels: labelsByVersion,
	}, nil
}
// Patterns used to classify the user input of GetGPUDriverVersion. They are
// compiled once at package initialization instead of on every call.
var (
	// preciseDriverVersionPattern matches a full driver version, e.g. 535.125.67.
	preciseDriverVersionPattern = regexp.MustCompile(`^\d+(\.\d+){2}$`)
	// majorDriverVersionPattern matches a major version label, e.g. R535.
	majorDriverVersionPattern = regexp.MustCompile(`^R\d+$`)
)

// GetGPUDriverVersion resolves input to a concrete GPU driver version for
// gpuType, using the serialized driver version proto in gpuProtoContent.
//
// The input is trimmed and upper-cased, and can be:
//  1. A precise GPU driver version, e.g. 535.125.67. It is returned as is when
//     it is listed as supported for gpuType.
//  2. A major version label, e.g. R535, or "default", "latest" or the empty
//     string (treated as "default"). The driver version associated with that
//     label for gpuType is returned.
//
// When the input cannot be resolved for gpuType and fallback is true, the
// driver version labeled DEFAULT is returned instead; when fallback is false
// an error is returned.
//
// An error is also returned when the input has none of the accepted forms,
// when gpuType is empty, when gpuProtoContent is nil or zero-length, or when
// the proto content cannot be parsed or lists no driver version for gpuType.
func GetGPUDriverVersion(gpuProtoContent []byte, gpuType string, input string, fallback bool) (string, error) {
	processedInput := strings.ToUpper(strings.TrimSpace(input))
	isPreciseVersion := preciseDriverVersionPattern.MatchString(processedInput)
	if processedInput != "" && processedInput != defaultLabel && processedInput != latestLabel &&
		!isPreciseVersion && !majorDriverVersionPattern.MatchString(processedInput) {
		return "", fmt.Errorf("the input is invalid: %s", input)
	}
	if gpuType == "" {
		return "", errors.New("the GPU type must not be empty")
	}
	// Reject zero-length content as well as nil: neither carries any data.
	if len(gpuProtoContent) == 0 {
		return "", errors.New("the gpu proto content must not be empty")
	}
	gpuDriverVersions, err := parseGPUDriverVersion(gpuProtoContent, gpuType)
	if err != nil {
		return "", fmt.Errorf("cannot parse the GPU driver versions info given the proto content and gpuType with error: %w", err)
	}

	// An empty input selects the default GPU driver version for this GPU type.
	if processedInput == "" {
		processedInput = defaultLabel
	}

	if isPreciseVersion {
		// A precise version that is supported by this GPU type is returned unchanged.
		log.Infof("The input label: %s is a precise GPU driver version, checking whether it's compatible with GPU: %s", processedInput, gpuType)
		if _, found := gpuDriverVersions.driverVersionToLabels[processedInput]; found {
			log.Infof("%s is compatible with GPU: %s", processedInput, gpuType)
			return processedInput, nil
		}
		// An unsupported precise version falls through to the label lookup
		// below, which misses and triggers the fallback / error handling.
	}

	if driverVersion, ok := gpuDriverVersions.labelToDriverVersion[processedInput]; ok {
		return driverVersion, nil
	}

	// The input is not supported for the current GPU type.
	if !fallback {
		return "", fmt.Errorf("the input: %s is not supported for GPU type: %s", processedInput, gpuType)
	}
	// With fallback enabled, use the default GPU driver version for this GPU type.
	defaultDriverVersion, hasDefault := gpuDriverVersions.labelToDriverVersion[defaultLabel]
	if !hasDefault {
		return "", fmt.Errorf("default label is not supported in current gpu type: %s", gpuType)
	}
	return defaultDriverVersion, nil
}
// GetGPUDriverVersions returns the driver versions listed as supported for
// gpuType in the serialized driver version proto gpuProtoContent.
//
// The GPU type is matched ignoring surrounding whitespace and letter case,
// consistent with how GetGPUDriverVersion resolves GPU types. The supported
// versions of the first matching entry are returned.
//
// An error is returned when gpuType is empty (after trimming), when
// gpuProtoContent is nil or zero-length, when the content cannot be parsed, or
// when no entry matches gpuType.
func GetGPUDriverVersions(gpuProtoContent []byte, gpuType string) ([]*pb.DriverVersion, error) {
	wantType := strings.TrimSpace(gpuType)
	if wantType == "" {
		return nil, errors.New("the GPU type must not be empty")
	}
	// Reject zero-length content as well as nil: neither carries any data.
	if len(gpuProtoContent) == 0 {
		return nil, errors.New("the gpu proto content must not be empty")
	}
	gpuDriverVersionInfoList, err := parseGPUDriverVersionInfoList(gpuProtoContent)
	if err != nil {
		return nil, err
	}
	for _, gpuDriverVersionInfo := range gpuDriverVersionInfoList.GetGpuDriverVersionInfo() {
		if strings.EqualFold(wantType, strings.TrimSpace(gpuDriverVersionInfo.GetGpuType())) {
			return gpuDriverVersionInfo.GetSupportedDriverVersions(), nil
		}
	}
	return nil, fmt.Errorf("no supported driver versions found for gpu: %s", gpuType)
}