blob: a4dd36e9e2f2a260721835ef6bd3767da0e0c2c4 [file] [log] [blame]
package deviceinfo
import (
"slices"
"testing"
)
// TestArchSupported checks that GPUType.SupportedArches returns exactly the
// expected architecture list for each GPU type in the table. The comparison
// uses slices.Equal, so the order of the expected architectures matters.
func TestArchSupported(t *testing.T) {
	tests := []struct {
		gpuType      GPUType
		expectedArch []string
	}{
		{GB200, []string{"aarch64"}},
		{B200, []string{"x86_64"}},
		{H100, []string{"x86_64"}},
		{L4, []string{"x86_64"}},
		{NO_GPU, []string{"x86_64", "aarch64"}},
		{A100_40GB, []string{"x86_64"}},
		{A100_80GB, []string{"x86_64"}},
		{Others, []string{"x86_64", "aarch64"}},
	}
	for _, tt := range tests {
		if got := tt.gpuType.SupportedArches(); !slices.Equal(got, tt.expectedArch) {
			// %v works whether GPUType is string-backed or a Stringer;
			// %q on the slices exposes empty or whitespace-padded entries.
			t.Errorf("SupportedArches() for %v: expected %q, got %q", tt.gpuType, tt.expectedArch, got)
		}
	}
}
// TestSupportsArch checks GPUType.SupportsArch against a table of
// (GPU type, architecture string, expected result) cases. Some cases pass
// mixed-case ("AaRCh64", "X86_64") or whitespace-padded (" x86_64 ") input
// and expect true, so the table requires SupportsArch to tolerate differences
// in case and surrounding whitespace.
func TestSupportsArch(t *testing.T) {
	tests := []struct {
		gpuType  GPUType
		arch     string
		expected bool
	}{
		{GB200, "x86_64", false},
		{GB200, "aarch64", true},
		{GB200, "AaRCh64", true},
		{B200, "x86_64", true},
		{A100_40GB, "aarch64", false},
		{Others, "x86_64", true},
		{Others, "aarch64", true},
		{L4, "x86_64", true},
		{H100, "X86_64", true},
		{H200, " x86_64 ", true},
		{NO_GPU, "x86_64", true},
		{NO_GPU, "aarch64", true},
	}
	for _, tt := range tests {
		if got := tt.gpuType.SupportsArch(tt.arch); got != tt.expected {
			// %q keeps the padded/mixed-case inputs distinguishable in the
			// failure output; %v is safe for any underlying GPUType kind.
			t.Errorf("SupportsArch(%q) for %v: expected %v, got %v", tt.arch, tt.gpuType, tt.expected, got)
		}
	}
}