| package deviceinfo |
| |
| import ( |
| "fmt" |
| "os" |
| "path/filepath" |
| "slices" |
| "strings" |
| "testing" |
| ) |
| |
| func TestArchSupported(t *testing.T) { |
| tests := []struct { |
| gpuType GPUType |
| expectedArch []string |
| }{ |
| {GB200, []string{"aarch64"}}, |
| {GB300, []string{"aarch64"}}, |
| {B200, []string{"x86_64"}}, |
| {H100, []string{"x86_64"}}, |
| {L4, []string{"x86_64"}}, |
| {NO_GPU, []string{"x86_64", "aarch64"}}, |
| {A100_40GB, []string{"x86_64"}}, |
| {A100_80GB, []string{"x86_64"}}, |
| {Others, []string{"x86_64", "aarch64"}}, |
| } |
| |
| for _, tt := range tests { |
| res := tt.gpuType.SupportedArches() |
| if !slices.Equal(res, tt.expectedArch) { |
| t.Errorf("ArchSupported for %s: expected %v, got %v", tt.gpuType, tt.expectedArch, res) |
| } |
| } |
| } |
| |
| func TestGetGPUTypeInfo(t *testing.T) { |
| createMockPCIDevice := func(t *testing.T, base, vendor, device string) { |
| t.Helper() |
| devicePath := filepath.Join(base, fmt.Sprintf("0000:00:%s.0", device[len(device)-1:])) |
| if err := os.MkdirAll(devicePath, 0755); err != nil { |
| t.Fatalf("Failed to create test dir: %v", err) |
| } |
| if vendor != "" { |
| if err := os.WriteFile(filepath.Join(devicePath, "vendor"), []byte(vendor), 0644); err != nil { |
| t.Fatalf("Failed to write vendor file: %v", err) |
| } |
| } |
| if device != "" { |
| if err := os.WriteFile(filepath.Join(devicePath, "device"), []byte(device), 0644); err != nil { |
| t.Fatalf("Failed to write device file: %v", err) |
| } |
| } |
| } |
| |
| tests := []struct { |
| name string |
| setup func(pciPath string) |
| wantGpuType GPUType |
| wantErr bool |
| errContains string |
| }{ |
| { |
| name: "found L4", |
| setup: func(pciPath string) { |
| createMockPCIDevice(t, pciPath, "0x10de", "0x27b8") |
| }, |
| wantGpuType: L4, |
| }, |
| { |
| name: "empty pci path", |
| setup: func(pciPath string) {}, |
| wantGpuType: NO_GPU, |
| }, |
| { |
| name: "Unknown Nvidia GPU", |
| setup: func(pciPath string) { |
| createMockPCIDevice(t, pciPath, "0x10de", "0xffff") // Unknown Device ID |
| }, |
| wantGpuType: Others, |
| }, |
| } |
| |
| for _, tt := range tests { |
| t.Run(tt.name, func(t *testing.T) { |
| pciPath := t.TempDir() |
| tt.setup(pciPath) |
| gotGpuType, err := GetGPUTypeInfo(pciPath) |
| |
| if (err != nil) != tt.wantErr { |
| t.Fatalf("GetGPUTypeInfo() error = %v, wantErr %v", err, tt.wantErr) |
| } |
| |
| if tt.wantErr { |
| if err == nil || !strings.Contains(err.Error(), tt.errContains) { |
| t.Errorf("GetGPUTypeInfo() error = %q, wantErrMessage containing %q", err, tt.errContains) |
| } |
| } |
| |
| if !tt.wantErr && gotGpuType != tt.wantGpuType { |
| t.Errorf("GetGPUTypeInfo() = %v, want %v", gotGpuType, tt.wantGpuType) |
| } |
| }) |
| } |
| } |
| |
| func TestSupportsArch(t *testing.T) { |
| tests := []struct { |
| gpuType GPUType |
| arch string |
| expected bool |
| }{ |
| {GB200, "x86_64", false}, |
| {GB200, "aarch64", true}, |
| {GB200, "AaRCh64", true}, |
| {GB300, "x86_64", false}, |
| {GB300, "aarch64", true}, |
| {GB300, "AaRCh64", true}, |
| {B200, "x86_64", true}, |
| {A100_40GB, "aarch64", false}, |
| {Others, "x86_64", true}, |
| {Others, "aarch64", true}, |
| {L4, "x86_64", true}, |
| {H100, "X86_64", true}, |
| {H200, " x86_64 ", true}, |
| {NO_GPU, "x86_64", true}, |
| {NO_GPU, "aarch64", true}, |
| } |
| for _, tt := range tests { |
| res := tt.gpuType.SupportsArch(tt.arch) |
| if res != tt.expected { |
| t.Errorf("SupportedArch(%s) for %s: expected %v, got %v", tt.arch, tt.gpuType, tt.expected, res) |
| } |
| } |
| } |