blob: 3f9a9b48f708331dd6ba2ffad7dba33644056c45 [file] [log] [blame] [edit]
package gpuconfig
import (
"fmt"
"io"
"net/http"
"strings"
log "github.com/golang/glog"
)
// Metadata server (MDS) endpoints queried by this file.
const (
// vmShapeEndpoint is the MDS URL for this instance's machine type.
vmShapeEndpoint = "http://metadata.google.internal/computeMetadata/v1/instance/machine-type"
// hostDriverVersionEndpoint is the MDS URL for the NVIDIA driver version
// running on the instance's host.
hostDriverVersionEndpoint = "http://metadata.google.internal/computeMetadata/v1/instance/host/nvidia-host-driver-version"
)
// FractionalVMShapes lists all VM shapes that have less than one GPU
// (i.e. only a fraction of a GPU).
var FractionalVMShapes = []string{"g4-standard-6", "g4-standard-12", "g4-standard-24"}
// queryMDS issues an HTTP GET against the given metadata server (MDS)
// endpoint, with the "Metadata-Flavor: Google" header set on the request.
//
// It returns the response body as a string, the HTTP status code, and an
// error. If the request cannot be created or executed, or the body cannot be
// read, it returns "", -1 and an error that wraps the underlying cause (so
// callers can inspect it with errors.Is / errors.As). The HTTP status code is
// not interpreted here: any response whose body was read successfully is
// returned with a nil error, and callers decide what each status means.
//
// We can't use cloud.google.com/go/compute/metadata because we need to know
// the status code of the response.
func queryMDS(client http.Client, endpoint string) (string, int, error) {
	req, err := http.NewRequest(http.MethodGet, endpoint, nil)
	if err != nil {
		return "", -1, fmt.Errorf("error creating request for endpoint %s: %w", endpoint, err)
	}
	req.Header.Add("Metadata-Flavor", "Google")

	resp, err := client.Do(req)
	if err != nil {
		return "", -1, fmt.Errorf("error executing API request to %s: %w", endpoint, err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", -1, fmt.Errorf("error reading response from %s: %w", endpoint, err)
	}

	// The trailing nil is the (always nil at this point) error slot of the
	// logged "response, status, err" triple.
	log.Infof("API response for endpoint %s is: %s, %d, %v", endpoint, string(body), resp.StatusCode, nil)
	return string(body), resp.StatusCode, nil
}
// VmShapeFromMDS returns the VM shape (machine type name) reported by the
// metadata server. This is needed for us to know if we should make a second
// API call to get the host's driver version.
// Once the host version's API endpoint is available from all machine types,
// we can remove this logic.
//
// An error is returned when the query fails or the status code is not 200.
func VmShapeFromMDS(client http.Client) (string, error) {
	machineType, code, queryErr := queryMDS(client, vmShapeEndpoint)
	if code != 200 || queryErr != nil {
		return "", fmt.Errorf("Error getting vm shape. Got status code: %d, err: %v", code, queryErr)
	}
	// The expected response looks like
	// "projects/1234567890/machineTypes/g4-standard-24"; only the segment
	// after the final '/' is the shape. With no '/' present LastIndex yields
	// -1, so the whole string is returned.
	lastSlash := strings.LastIndex(machineType, "/")
	return machineType[lastSlash+1:], nil
}
// VGpuStatusFromMds queries the metadata server for the host's driver version
// and reports whether this VM is a vGPU.
//
// It returns (hostDriverVersion, isVGPU, err):
//   - query failure: "", false, and the error from the query.
//   - HTTP 404: "", false, nil (not a vGPU).
//   - HTTP 200 with a version matching PreciseVersionPattern: the version
//     (with '_' replaced by '.'), true, nil.
//   - HTTP 200 with an unrecognized version: "", true, and an error (this is
//     a vGPU, but MDS could not provide the host's version).
//   - any other status code: "", false, and an error.
func VGpuStatusFromMds(client http.Client) (string, bool, error) {
	rawVersion, statusCode, err := queryMDS(client, hostDriverVersionEndpoint)
	if err != nil {
		return "", false, err
	}

	switch statusCode {
	case http.StatusNotFound:
		// The endpoint does not exist for this VM: not a vGPU.
		return "", false, nil
	case http.StatusOK:
		// MDS returns version parts separated by '_'. We handle all versions
		// with ".", so swap that.
		hostVersion := strings.ReplaceAll(rawVersion, "_", ".")
		if PreciseVersionPattern.MatchString(hostVersion) {
			// http 200 and a valid version - this is a vGPU
			return hostVersion, true, nil
		}
		// http 200 and not a valid version - this is a vGPU but MDS could
		// not get host's version
		return "", true, fmt.Errorf("MDS: got 200 status code but host driver version unrecognized - %s", hostVersion)
	default:
		// Any status code other than 200 and 404 is unexpected and is seen
		// as an error.
		return "", false, fmt.Errorf("MDS: got unexpected http status code for host driver version - %d", statusCode)
	}
}