| package gpuconfig |
| |
| import ( |
| "fmt" |
| "io" |
| "net/http" |
| "strings" |
| |
| log "github.com/golang/glog" |
| ) |
| |
// vmShapeEndpoint is the metadata server URL that reports the instance's
// machine type.
const vmShapeEndpoint = "http://metadata.google.internal/computeMetadata/v1/instance/machine-type"

// hostDriverVersionEndpoint is the metadata server URL that reports the
// NVIDIA driver version of the host.
const hostDriverVersionEndpoint = "http://metadata.google.internal/computeMetadata/v1/instance/host/nvidia-host-driver-version"
| |
// FractionalVMShapes lists all VM shapes that have less than one GPU.
var FractionalVMShapes = []string{
	"g4-standard-6",
	"g4-standard-12",
	"g4-standard-24",
}
| |
| // queryMDS handles the HTTP GET request logic. Returns the response, status code, err. |
| // We can't use cloud.google.com/go/compute/metadata because we need to know the status code of the response |
| func queryMDS(client http.Client, endpoint string) (string, int, error) { |
| req, err := http.NewRequest("GET", endpoint, nil) |
| if err != nil { |
| return "", -1, fmt.Errorf("error creating request for endpoint %s: %v", endpoint, err) |
| } |
| |
| req.Header.Add("Metadata-Flavor", "Google") |
| |
| resp, err := client.Do(req) |
| if err != nil { |
| return "", -1, fmt.Errorf("error executing API request to %s: %v", endpoint, err) |
| } |
| defer resp.Body.Close() |
| |
| body, err := io.ReadAll(resp.Body) |
| if err != nil { |
| return "", -1, fmt.Errorf("error reading response from %s: %v", endpoint, err) |
| } |
| log.Infof("API response for endpoint %s is: %s, %d, %v", endpoint, string(body), resp.StatusCode, nil) |
| return string(body), resp.StatusCode, nil |
| } |
| |
| // Returns the VM shape. This is needed for us to know if we shuold make a second |
| // API call to get the host's driver version. |
| // Once the host version's API endpoint is available from all machine types, we can remove this logic. |
| func VmShapeFromMDS(client http.Client) (string, error) { |
| // expected API response is something like "projects/1234567890/machineTypes/g4-standard-24" - we care about the last part |
| fullShapeString, statusCode, err := queryMDS(client, vmShapeEndpoint) |
| if err != nil || statusCode != 200 { |
| return "", fmt.Errorf("Error getting vm shape. Got status code: %d, err: %v", statusCode, err) |
| } |
| shapeParts := strings.Split(fullShapeString, "/") |
| return shapeParts[len(shapeParts)-1], nil |
| } |
| |
| // Returns host_driver_version, isVGPU, err. isVGP is true iff MDS endpoint returns http 200 and the host version. |
| func VGpuStatusFromMds(client http.Client) (string, bool, error) { |
| hostVersion, statusCode, err := queryMDS(client, hostDriverVersionEndpoint) |
| // MDS returns version parts separated by '_'. We handle all versions with ".", so swap that |
| hostVersion = strings.ReplaceAll(hostVersion, "_", ".") |
| if err != nil { |
| return "", false, err |
| } |
| if statusCode == 404 { // not vGPU |
| return "", false, nil |
| } else if statusCode == 200 { |
| if PreciseVersionPattern.MatchString(hostVersion) { |
| // http 200 and a valid version - this is a vGPU |
| return hostVersion, true, nil |
| } |
| // http 200 and not a valid version - this is a vGPU but MDS could not get host's version |
| return "", true, fmt.Errorf("MDS: got 200 status code but host driver version unrecognized - %s", hostVersion) |
| } |
| // Any status code other than 200 and 404 is unexpected and is seen as an error |
| return "", false, fmt.Errorf("MDS: for unexpected http status code for host driver version - %d", statusCode) |
| } |