blob: e94c1422e86b5c0b905221fb6e6c151e7646036c [file] [log] [blame]
// Package deviceinfo provides GPU device information (supported types, PCI IDs, and per-type capabilities) for cos-gpu-installer.
package deviceinfo
import (
"bytes"
"fmt"
"os"
"path/filepath"
"slices"
"strconv"
"strings"
)
// GPUType enumerates the GPU models this package distinguishes, plus the
// NO_GPU and Others sentinels reported by GetGPUTypeInfo.
type GPUType int

// GPUType values. Numeric values come from iota in declaration order and are
// unrelated to PCI IDs (see gpuTypeToPciIDs for those).
const (
	K80 GPUType = iota
	P4
	P100
	V100
	L4
	T4
	H100
	H200
	B200
	GB200
	GB300
	A100_40GB
	A100_80GB
	RTX_PRO_6000
	// NO_GPU is reported when no NVIDIA PCI device was found.
	NO_GPU
	// Others is reported when an NVIDIA PCI device was found but its ID is not
	// one of the known models.
	Others
)

// NvidiaVendorID is the PCI vendor ID that identifies NVIDIA devices.
const NvidiaVendorID uint16 = 0x10de
// gpuTypeToPciIDs records, for each known GPU model, the PCI vendor and device
// IDs the hardware reports. GetGPUTypeInfo matches detected devices against
// this table, so it is the source of truth for hardware identification.
// NO_GPU and Others have no entry.
var gpuTypeToPciIDs = map[GPUType]struct {
	VendorID uint16
	DeviceID uint16
}{
	K80:          {VendorID: NvidiaVendorID, DeviceID: 0x102d},
	P4:           {VendorID: NvidiaVendorID, DeviceID: 0x1bb3},
	P100:         {VendorID: NvidiaVendorID, DeviceID: 0x15f8},
	V100:         {VendorID: NvidiaVendorID, DeviceID: 0x1db1},
	L4:           {VendorID: NvidiaVendorID, DeviceID: 0x27b8},
	T4:           {VendorID: NvidiaVendorID, DeviceID: 0x1eb8},
	H100:         {VendorID: NvidiaVendorID, DeviceID: 0x2330},
	H200:         {VendorID: NvidiaVendorID, DeviceID: 0x2335},
	B200:         {VendorID: NvidiaVendorID, DeviceID: 0x2901},
	GB200:        {VendorID: NvidiaVendorID, DeviceID: 0x2941},
	GB300:        {VendorID: NvidiaVendorID, DeviceID: 0x31c2},
	A100_40GB:    {VendorID: NvidiaVendorID, DeviceID: 0x20b0},
	A100_80GB:    {VendorID: NvidiaVendorID, DeviceID: 0x20b2},
	RTX_PRO_6000: {VendorID: NvidiaVendorID, DeviceID: 0x2bb5},
}
// AvailableGPUTypesList returns the GPU types supported on the given milestone.
// milestone must be a base-10 integer string such as "113".
//
// P4, P100, V100, L4, T4, H100, H200, A100_40GB and A100_80GB are always
// included. B200 and RTX_PRO_6000 are added for milestones above 105, GB200 for
// milestones above 113, and GB300 for milestone 125 and later. K80 is never
// returned. Every call builds a fresh slice, so callers may modify the result.
//
// If milestone cannot be parsed, the returned error wraps the strconv error,
// so errors.Is(err, strconv.ErrSyntax) / errors.Is(err, strconv.ErrRange) work.
func AvailableGPUTypesList(milestone string) ([]GPUType, error) {
	milestoneInt, err := strconv.Atoi(milestone)
	if err != nil {
		// %w (not %v) keeps the *strconv.NumError in the chain for callers.
		return nil, fmt.Errorf("invalid milestone input: %w", err)
	}
	gpuTypes := []GPUType{P4, P100, V100, L4, T4, H100, H200, A100_40GB, A100_80GB}
	if milestoneInt > 105 {
		gpuTypes = append(gpuTypes, B200, RTX_PRO_6000)
	}
	if milestoneInt > 113 {
		gpuTypes = append(gpuTypes, GB200)
	}
	if milestoneInt >= 125 {
		gpuTypes = append(gpuTypes, GB300)
	}
	return gpuTypes, nil
}
// AllGPUTypeStrings lists every GPU type name that ParseGPUType accepts, in the
// order ParseGPUType prints them in its error message.
var AllGPUTypeStrings = []string{
	"NVIDIA_TESLA_K80", "NVIDIA_TESLA_P4", "NVIDIA_TESLA_P100", "NVIDIA_TESLA_V100",
	"NVIDIA_L4", "NVIDIA_H100_80GB", "NVIDIA_H200", "NVIDIA_B200",
	"NVIDIA_TESLA_A100", "NVIDIA_A100_80GB", "NVIDIA_TESLA_T4",
	"NVIDIA_GB200", "NVIDIA_GB300", "NVIDIA_RTX_PRO_6000",
}
// ParseGPUType converts a GPU type name such as "NVIDIA_L4" to its GPUType.
// Surrounding whitespace is ignored and matching is case-insensitive. The
// accepted names are those in AllGPUTypeStrings; NO_GPU and Others cannot be
// parsed.
//
// On failure it returns Others, not the zero value, because the zero value is
// K80 and would look like a valid result. The error quotes the rejected input
// and lists the accepted names.
func ParseGPUType(gpu string) (GPUType, error) {
	switch strings.ToUpper(strings.TrimSpace(gpu)) {
	case "NVIDIA_TESLA_K80":
		return K80, nil
	case "NVIDIA_TESLA_P4":
		return P4, nil
	case "NVIDIA_TESLA_P100":
		return P100, nil
	case "NVIDIA_TESLA_V100":
		return V100, nil
	case "NVIDIA_L4":
		return L4, nil
	case "NVIDIA_H100_80GB":
		return H100, nil
	case "NVIDIA_TESLA_A100":
		return A100_40GB, nil
	case "NVIDIA_A100_80GB":
		return A100_80GB, nil
	case "NVIDIA_TESLA_T4":
		return T4, nil
	case "NVIDIA_H200":
		return H200, nil
	case "NVIDIA_B200":
		return B200, nil
	case "NVIDIA_GB200":
		return GB200, nil
	case "NVIDIA_GB300":
		return GB300, nil
	case "NVIDIA_RTX_PRO_6000":
		return RTX_PRO_6000, nil
	default:
		return Others, fmt.Errorf("invalid GPU type string %q. Available GPU types are: %s", gpu, strings.Join(AllGPUTypeStrings, ", "))
	}
}
// String returns the canonical name of g. For the GPU models this is the
// spelling ParseGPUType accepts (e.g. "NVIDIA_L4"). NO_GPU yields "NO_GPU",
// Others yields "OTHERS", and any value outside the declared constants yields
// "UNKNOWN".
func (g GPUType) String() string {
	// Keyed by the constants themselves, so entry order does not have to follow
	// the const declaration order.
	names := [...]string{
		K80:          "NVIDIA_TESLA_K80",
		P4:           "NVIDIA_TESLA_P4",
		P100:         "NVIDIA_TESLA_P100",
		V100:         "NVIDIA_TESLA_V100",
		L4:           "NVIDIA_L4",
		H100:         "NVIDIA_H100_80GB",
		A100_40GB:    "NVIDIA_TESLA_A100",
		A100_80GB:    "NVIDIA_A100_80GB",
		T4:           "NVIDIA_TESLA_T4",
		H200:         "NVIDIA_H200",
		B200:         "NVIDIA_B200",
		GB200:        "NVIDIA_GB200",
		GB300:        "NVIDIA_GB300",
		RTX_PRO_6000: "NVIDIA_RTX_PRO_6000",
		NO_GPU:       "NO_GPU",
		Others:       "OTHERS",
	}
	if g < 0 || int(g) >= len(names) || names[g] == "" {
		return "UNKNOWN"
	}
	return names[g]
}
// OpenSupported reports whether g is marked as having open-source support.
// It is false for NO_GPU, K80, P4, P100 and V100, and true for every other
// value, including Others. (Presumably this refers to NVIDIA's open kernel
// modules — confirm with callers.)
//
// TODO(gshuoy): b/331317222 - Add the open source supported in the proto file.
func (g GPUType) OpenSupported() bool {
	excluded := g == NO_GPU || g == K80 || g == P4 || g == P100 || g == V100
	return !excluded
}
// SupportedArches returns the CPU architectures supported for the GPU type, as
// a newly allocated slice:
//   - GB200 and GB300: "aarch64" only.
//   - NO_GPU and Others: both "x86_64" and "aarch64".
//   - every other type: "x86_64" only.
//
// TODO(gshuoy): b/331317222 - Add the arch support in the proto file.
func (g GPUType) SupportedArches() []string {
	if g == GB200 || g == GB300 {
		return []string{"aarch64"}
	}
	if g == NO_GPU || g == Others {
		return []string{"x86_64", "aarch64"}
	}
	return []string{"x86_64"}
}
// SupportsArch reports whether arch is one of the values returned by
// SupportedArches. Surrounding whitespace in arch is trimmed and the comparison
// is done in lower case.
func (g GPUType) SupportsArch(arch string) bool {
	want := strings.TrimSpace(arch)
	want = strings.ToLower(want)
	return slices.Index(g.SupportedArches(), want) != -1
}
// GetGPUTypeInfo identifies the installed GPU type from PCI device metadata.
//
// Every entry directly under pciDevicesPath is treated as a PCI device
// directory containing "vendor" and "device" files (presumably a sysfs path
// such as /sys/bus/pci/devices — callers choose it). Entries that cannot be
// read or parsed, and non-NVIDIA devices, are skipped. The result is:
//   - the first known model (per gpuTypeToPciIDs) found, in the order
//     filepath.Glob returns the entries;
//   - Others if NVIDIA devices exist but none has a known PCI ID;
//   - NO_GPU if no NVIDIA device was found.
//
// An error is returned only if the glob pattern built from pciDevicesPath
// is malformed.
func GetGPUTypeInfo(pciDevicesPath string) (GPUType, error) {
	entries, err := filepath.Glob(filepath.Join(pciDevicesPath, "*"))
	if err != nil {
		return NO_GPU, fmt.Errorf("failed to glob PCI devices: %w", err)
	}

	// lookup returns the known GPUType whose PCI IDs match, if any.
	lookup := func(vendor, device uint16) (GPUType, bool) {
		for candidate, ids := range gpuTypeToPciIDs {
			if ids.VendorID == vendor && ids.DeviceID == device {
				return candidate, true
			}
		}
		return NO_GPU, false
	}

	sawNvidia := false
	for _, entry := range entries {
		vendor, device, readErr := parsePciDeviceInfo(entry)
		if readErr != nil || vendor != NvidiaVendorID {
			// Unreadable entries and non-NVIDIA devices cannot identify a GPU.
			continue
		}
		sawNvidia = true
		if known, ok := lookup(vendor, device); ok {
			return known, nil
		}
	}

	if sawNvidia {
		// An NVIDIA device exists, but its PCI ID is not in gpuTypeToPciIDs.
		return Others, nil
	}
	return NO_GPU, nil
}
// parsePciIdPart parses one hexadecimal PCI ID, as read from a sysfs "vendor"
// or "device" file (for example "0x10de\n"), into a 16-bit value.
//
// Surrounding whitespace (including "\r\n") and an optional "0x"/"0X" prefix
// are accepted. On error it returns 0, not the clamped value that
// strconv.ParseUint produces for out-of-range input.
func parsePciIdPart(part []byte) (uint16, error) {
	digits := bytes.TrimSpace(part)
	if len(digits) >= 2 && digits[0] == '0' && (digits[1] == 'x' || digits[1] == 'X') {
		digits = digits[2:]
	}
	value, err := strconv.ParseUint(string(digits), 16, 16)
	if err != nil {
		return 0, err
	}
	return uint16(value), nil
}
// parsePciDeviceInfo reads the "vendor" and "device" files inside devicePath
// and returns the parsed PCI vendor ID and device ID.
//
// Read failures are returned as-is (an *os.PathError already names the file).
// Parse failures are wrapped with the offending file path, so a malformed
// attribute can be traced to a specific device.
func parsePciDeviceInfo(devicePath string) (uint16, uint16, error) {
	// readID reads one attribute file under devicePath and parses its PCI ID.
	readID := func(name string) (uint16, error) {
		path := filepath.Join(devicePath, name)
		raw, err := os.ReadFile(path)
		if err != nil {
			return 0, err
		}
		id, err := parsePciIdPart(raw)
		if err != nil {
			return 0, fmt.Errorf("parsing PCI ID from %s: %w", path, err)
		}
		return id, nil
	}

	vendorID, err := readID("vendor")
	if err != nil {
		return 0, 0, err
	}
	deviceID, err := readID("device")
	if err != nil {
		return 0, 0, err
	}
	return vendorID, deviceID, nil
}