blob: ffc1af06ba063a1a89e6a6176ba8ee2d0ff527ba [file] [log] [blame]
package deviceinfo
import (
"fmt"
"os"
"path/filepath"
"slices"
"strings"
"testing"
)
// TestArchSupported verifies that SupportedArches returns, for every GPU type
// in the table, exactly the expected architecture list (same elements in the
// same order, as compared by slices.Equal).
func TestArchSupported(t *testing.T) {
	type archCase struct {
		gpu  GPUType
		want []string
	}
	cases := []archCase{
		{gpu: GB200, want: []string{"aarch64"}},
		{gpu: GB300, want: []string{"aarch64"}},
		{gpu: B200, want: []string{"x86_64"}},
		{gpu: H100, want: []string{"x86_64"}},
		{gpu: L4, want: []string{"x86_64"}},
		{gpu: NO_GPU, want: []string{"x86_64", "aarch64"}},
		{gpu: A100_40GB, want: []string{"x86_64"}},
		{gpu: A100_80GB, want: []string{"x86_64"}},
		{gpu: Others, want: []string{"x86_64", "aarch64"}},
	}
	for i := range cases {
		tc := cases[i]
		got := tc.gpu.SupportedArches()
		if slices.Equal(got, tc.want) {
			continue
		}
		// Report and keep going so one run surfaces every mismatching row.
		t.Errorf("ArchSupported for %s: expected %v, got %v", tc.gpu, tc.want, got)
	}
}
// TestGetGPUTypeInfo runs GetGPUTypeInfo against a fake PCI device tree that
// each subtest builds inside its own temporary directory, and checks the
// returned GPU type (or the returned error for cases that set wantErr).
func TestGetGPUTypeInfo(t *testing.T) {
	// createMockPCIDevice creates a "0000:00:<slot>.0" directory under base and
	// writes the vendor and device IDs, when non-empty, to files named "vendor"
	// and "device" inside it. The slot is the last character of the device ID;
	// an empty device ID falls back to slot "0" rather than panicking on the
	// slice expression.
	createMockPCIDevice := func(t *testing.T, base, vendor, device string) {
		t.Helper()
		slot := "0"
		if device != "" {
			slot = device[len(device)-1:]
		}
		devicePath := filepath.Join(base, fmt.Sprintf("0000:00:%s.0", slot))
		if err := os.MkdirAll(devicePath, 0755); err != nil {
			t.Fatalf("Failed to create test dir: %v", err)
		}
		if vendor != "" {
			if err := os.WriteFile(filepath.Join(devicePath, "vendor"), []byte(vendor), 0644); err != nil {
				t.Fatalf("Failed to write vendor file: %v", err)
			}
		}
		if device != "" {
			if err := os.WriteFile(filepath.Join(devicePath, "device"), []byte(device), 0644); err != nil {
				t.Fatalf("Failed to write device file: %v", err)
			}
		}
	}
	tests := []struct {
		name string
		// setup populates pciPath. It receives the subtest's *testing.T so that a
		// fatal setup failure is reported on the subtest that is actually running,
		// not on the parent test from a different goroutine.
		setup       func(t *testing.T, pciPath string)
		wantGpuType GPUType
		wantErr     bool
		errContains string
	}{
		{
			name: "found L4",
			setup: func(t *testing.T, pciPath string) {
				createMockPCIDevice(t, pciPath, "0x10de", "0x27b8")
			},
			wantGpuType: L4,
		},
		{
			// The directory exists but contains no device entries.
			name:        "empty pci path",
			setup:       func(t *testing.T, pciPath string) {},
			wantGpuType: NO_GPU,
		},
		{
			name: "Unknown Nvidia GPU",
			setup: func(t *testing.T, pciPath string) {
				createMockPCIDevice(t, pciPath, "0x10de", "0xffff") // Unknown Device ID
			},
			wantGpuType: Others,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			pciPath := t.TempDir()
			tt.setup(t, pciPath)

			gotGpuType, err := GetGPUTypeInfo(pciPath)
			if (err != nil) != tt.wantErr {
				t.Fatalf("GetGPUTypeInfo() error = %v, wantErr %v", err, tt.wantErr)
			}
			if tt.wantErr {
				// err is non-nil here: the mismatch check above already passed.
				if !strings.Contains(err.Error(), tt.errContains) {
					t.Errorf("GetGPUTypeInfo() error = %q, wantErrMessage containing %q", err, tt.errContains)
				}
				return
			}
			if gotGpuType != tt.wantGpuType {
				t.Errorf("GetGPUTypeInfo() = %v, want %v", gotGpuType, tt.wantGpuType)
			}
		})
	}
}
// TestSupportsArch checks GPUType.SupportsArch against a table of GPU type and
// architecture pairs. The table includes mixed-case ("AaRCh64", "X86_64") and
// space-padded (" x86_64 ") architecture strings that are expected to be
// accepted.
func TestSupportsArch(t *testing.T) {
	tests := []struct {
		gpuType  GPUType
		arch     string
		expected bool
	}{
		{GB200, "x86_64", false},
		{GB200, "aarch64", true},
		{GB200, "AaRCh64", true},
		{GB300, "x86_64", false},
		{GB300, "aarch64", true},
		{GB300, "AaRCh64", true},
		{B200, "x86_64", true},
		{A100_40GB, "aarch64", false},
		{Others, "x86_64", true},
		{Others, "aarch64", true},
		{L4, "x86_64", true},
		{H100, "X86_64", true},
		{H200, " x86_64 ", true},
		{NO_GPU, "x86_64", true},
		{NO_GPU, "aarch64", true},
	}
	for _, tt := range tests {
		res := tt.gpuType.SupportsArch(tt.arch)
		if res != tt.expected {
			// %q keeps padded inputs such as " x86_64 " visible in the failure
			// output; the message names the method actually under test.
			t.Errorf("SupportsArch(%q) for %s: expected %v, got %v", tt.arch, tt.gpuType, tt.expected, res)
		}
	}
}