| package gpuconfig |
| |
| import ( |
| "errors" |
| "fmt" |
| "os" |
| "strings" |
| "testing" |
| |
| "cos.googlesource.com/cos/tools.git/src/pkg/gpuconfig/pb" |
| "github.com/golang/protobuf/proto" |
| "github.com/google/go-cmp/cmp" |
| "google.golang.org/protobuf/encoding/prototext" |
| ) |
| |
// testProtoDataPath locates the text-format GPUDriverVersionInfoList fixture
// that both table-driven tests below load and re-encode. The path is relative,
// so it resolves against the working directory of the test binary (the
// package source directory under `go test`).
const (
	testProtoDataPath = "./testdata/gpu_driver_versions_test.textproto"
)
| |
| func TestGetGPUDriverVersion(t *testing.T) { |
| textProtoData, err := os.ReadFile(testProtoDataPath) |
| if err != nil { |
| t.Errorf("Cannot read the test textproto data: %v", err) |
| } |
| var gpuDriverVersionInfoList = &pb.GPUDriverVersionInfoList{} |
| err = (prototext.UnmarshalOptions{DiscardUnknown: true, AllowPartial: true}).Unmarshal(textProtoData, gpuDriverVersionInfoList) |
| if err != nil { |
| t.Errorf("failed when parsing the GPU driver version: %v", err) |
| } |
| testData, err := proto.Marshal(gpuDriverVersionInfoList) |
| if err != nil { |
| t.Errorf("fail to encode to the binary proto data: %v", err) |
| } |
| var testCases = []struct { |
| gpuProtoContent []byte |
| gpuType string |
| input string |
| fallback bool |
| expectedDriverVersion string |
| err error |
| }{ |
| { |
| nil, |
| "NVIDIA_L4", |
| "", |
| false, |
| "", |
| errors.New("the gpu proto content must not be empty"), |
| }, |
| { |
| testData, |
| "", |
| "", |
| false, |
| "", |
| errors.New("the GPU type must not be empty"), |
| }, |
| { |
| testData, |
| " nvidia_l4 ", |
| "", |
| false, |
| "535.154.05", |
| nil, |
| }, |
| { |
| testData, |
| "NVIDIA_L4", |
| "560.38", |
| false, |
| "560.38", |
| nil, |
| }, |
| { |
| testData, |
| "NVIDIA_L4", |
| "DEFAULT", |
| false, |
| "535.154.05", |
| nil, |
| }, |
| { |
| testData, |
| "NVIDIA_L4", |
| "default", |
| false, |
| "535.154.05", |
| nil, |
| }, |
| { |
| testData, |
| "NVIDIA_L4", |
| "latest", |
| false, |
| "535.154.05", |
| nil, |
| }, |
| { |
| testData, |
| "NVIDIA_L4", |
| "R535", |
| false, |
| "535.154.05", |
| nil, |
| }, |
| { |
| testData, |
| "NVIDIA_L4", |
| "r535", |
| false, |
| "535.154.05", |
| nil, |
| }, |
| { |
| testData, |
| "NVIDIA_L4", |
| "535.129.03", |
| false, |
| "535.129.03", |
| nil, |
| }, |
| { |
| testData, |
| "NVIDIA_L4", |
| "R470", |
| false, |
| "", |
| errors.New("not supported for GPU type"), |
| }, |
| { |
| testData, |
| "NVIDIA_L4", |
| "R470", |
| true, |
| "535.154.05", |
| nil, |
| }, |
| { |
| testData, |
| "NVIDIA_L4", |
| "R525", |
| true, |
| "535.154.05", |
| nil, |
| }, |
| { |
| testData, |
| "invalidGPU", |
| "R525", |
| true, |
| "", |
| errors.New("no supported driver versions found for gpu"), |
| }, |
| { |
| testData, |
| "NVIDIA_L4", |
| "invalidInput", |
| false, |
| "", |
| errors.New("the input is invalid"), |
| }, |
| { |
| testData, |
| "NVIDIA_L4", |
| "invalidInput", |
| true, |
| "", |
| errors.New("the input is invalid"), |
| }, |
| { |
| testData, |
| "no_gpu", |
| "r525", |
| true, |
| "535.161.08", |
| nil, |
| }, |
| { |
| testData, |
| "no_gpu", |
| "r525", |
| true, |
| "535.161.08", |
| nil, |
| }, |
| { |
| testData, |
| "Sample_GPU_without_default_label", |
| "r525", |
| true, |
| "", |
| errors.New("default label is not supported in current gpu type"), |
| }, |
| } |
| for index, tc := range testCases { |
| t.Run(fmt.Sprintf("Test %d: GetGPUDriverVersion: with gpuType: %s, input: %s, fallback flag: %v", index, tc.gpuType, tc.input, tc.fallback), func(t *testing.T) { |
| driverVersion, err := GetGPUDriverVersion(tc.gpuProtoContent, tc.gpuType, tc.input, tc.fallback) |
| if tc.err == nil { |
| if err != nil { |
| t.Errorf("Failed to GetGPUDriverVersion with error: %v", err) |
| } |
| if !cmp.Equal(tc.expectedDriverVersion, driverVersion) { |
| t.Errorf("Test GetGPUDriverVersion failed: the expected driver version is: %s, but the actual driver version is: %s.", tc.expectedDriverVersion, driverVersion) |
| } |
| } else { |
| if err == nil || !strings.Contains(err.Error(), tc.err.Error()) { |
| t.Errorf("Test GetGPUDriverVersion failed: the error been sent out is not correct, the exected: %s, the actual: %s", tc.err.Error(), err.Error()) |
| } |
| } |
| }) |
| } |
| } |
| |
| func TestGetGPUDriverVersions(t *testing.T) { |
| textProtoData, err := os.ReadFile(testProtoDataPath) |
| if err != nil { |
| t.Errorf("Cannot read the test textproto data: %v", err) |
| } |
| var gpuDriverVersionInfoList = &pb.GPUDriverVersionInfoList{} |
| err = (prototext.UnmarshalOptions{DiscardUnknown: true, AllowPartial: true}).Unmarshal(textProtoData, gpuDriverVersionInfoList) |
| if err != nil { |
| t.Errorf("failed when parsing the GPU driver version: %v", err) |
| } |
| testData, err := proto.Marshal(gpuDriverVersionInfoList) |
| if err != nil { |
| t.Errorf("fail to encode to the binary proto data: %v", err) |
| } |
| var testCases = []struct { |
| gpuProtoContent []byte |
| gpuType string |
| expectedDriverVersions []*pb.DriverVersion |
| err error |
| }{ |
| { |
| nil, |
| "NVIDIA_L4", |
| nil, |
| errors.New("the gpu proto content must not be empty"), |
| }, |
| { |
| testData, |
| "", |
| nil, |
| errors.New("the GPU type must not be empty"), |
| }, |
| { |
| testData, |
| "NVIDIA_TESLA_V100", |
| []*pb.DriverVersion{ |
| {Version: "535.154.05", Label: "DEFAULT"}, |
| {Version: "535.154.05", Label: "LATEST"}, |
| {Version: "535.154.05", Label: "R535"}, |
| {Version: "535.129.03"}, |
| {Version: "535.104.12"}, |
| {Version: "535.104.05"}, |
| {Version: "470.223.02", Label: "R470"}, |
| {Version: "470.199.02"}, |
| }, |
| nil, |
| }, |
| { |
| testData, |
| "NVIDIA_TESLA_P100", |
| []*pb.DriverVersion{ |
| {Version: "535.154.05", Label: "DEFAULT"}, |
| {Version: "535.154.05", Label: "LATEST"}, |
| {Version: "535.154.05", Label: "R535"}, |
| {Version: "535.129.03"}, |
| {Version: "535.104.12"}, |
| {Version: "535.104.05"}, |
| {Version: "470.223.02", Label: "R470"}, |
| {Version: "470.199.02"}, |
| }, |
| nil, |
| }, |
| { |
| testData, |
| "InvalidGPU", |
| nil, |
| errors.New("no supported driver versions found for gpu"), |
| }, |
| } |
| |
| for index, tc := range testCases { |
| t.Run(fmt.Sprintf("Test %d: GetGPUDriverVersions: with gpuType: %s", index, tc.gpuType), func(t *testing.T) { |
| driverVersions, err := GetGPUDriverVersions(tc.gpuProtoContent, tc.gpuType) |
| if tc.err == nil { |
| if err != nil { |
| t.Errorf("Failed to GetGPUDriverVersions with error: %v", err) |
| } |
| if len(driverVersions) != len(tc.expectedDriverVersions) { |
| t.Errorf("Test GetGPUDriverVersions failed: the length of expected gpu drivers is %v, but the actual lenght is %v", len(tc.expectedDriverVersions), len(driverVersions)) |
| } |
| for index, driverVersion := range driverVersions { |
| expectedDriverVersion := tc.expectedDriverVersions[index] |
| if expectedDriverVersion.Label != driverVersion.Label { |
| t.Errorf("Test GetGPUDriverVersions failed: the expected driver version list contains DriverVersion{Label: %s, Version: %s},"+ |
| "but the actual driver version list contains {DriverVersion{Label: %s, Version: %s}}.", expectedDriverVersion.Label, expectedDriverVersion.Version, |
| driverVersion.Label, driverVersion.Version) |
| } |
| if expectedDriverVersion.Version != driverVersion.Version { |
| t.Errorf("Test GetGPUDriverVersions failed: the expected driver version list contains DriverVersion{Label: %s, Version: %s},"+ |
| "but the actual driver version list contains {DriverVersion{Label: %s, Version: %s}}.", expectedDriverVersion.Label, expectedDriverVersion.Version, |
| driverVersion.Label, driverVersion.Version) |
| } |
| } |
| } else { |
| if err == nil || !strings.Contains(err.Error(), tc.err.Error()) { |
| t.Errorf("Test GetGPUDriverVersions failed: the error been sent out is not correct, the exected: %s, the actual: %s", tc.err.Error(), err.Error()) |
| } |
| } |
| }) |
| } |
| } |