blob: a63969f35421e05fdef4874727ab2368bc73efd1 [file] [log] [blame]
package gpuconfig
import (
"errors"
"fmt"
"os"
"strings"
"testing"
"cos.googlesource.com/cos/tools.git/src/pkg/gpuconfig/pb"
"github.com/golang/protobuf/proto"
"github.com/google/go-cmp/cmp"
"google.golang.org/protobuf/encoding/prototext"
)
// testProtoDataPath points at the textproto fixture holding the GPU driver
// version data used by the tests in this file. The "./" prefix makes it
// relative to the working directory, which `go test` sets to the package
// directory.
const (
	testProtoDataPath = "./testdata/gpu_driver_versions_test.textproto"
)
// TestGetGPUDriverVersion runs GetGPUDriverVersion against the driver version
// fixture at testProtoDataPath. The table covers:
//   - argument validation (empty proto content, empty GPU type);
//   - a padded, lower-case GPU type;
//   - label inputs in either case ("DEFAULT", "latest", "R535", ...);
//   - explicit version strings;
//   - the fallback flag;
//   - error reporting for unknown GPU types and invalid inputs.
func TestGetGPUDriverVersion(t *testing.T) {
	// Load the textproto fixture and re-encode it in the binary wire format
	// that GetGPUDriverVersion takes. Every case depends on this data, so any
	// failure here aborts the test rather than running the table against
	// empty or partially decoded input.
	textProtoData, err := os.ReadFile(testProtoDataPath)
	if err != nil {
		t.Fatalf("Cannot read the test textproto data: %v", err)
	}
	gpuDriverVersionInfoList := &pb.GPUDriverVersionInfoList{}
	err = (prototext.UnmarshalOptions{DiscardUnknown: true, AllowPartial: true}).Unmarshal(textProtoData, gpuDriverVersionInfoList)
	if err != nil {
		t.Fatalf("failed when parsing the GPU driver version: %v", err)
	}
	testData, err := proto.Marshal(gpuDriverVersionInfoList)
	if err != nil {
		t.Fatalf("fail to encode to the binary proto data: %v", err)
	}

	// A non-nil err means the call is expected to fail with an error whose
	// message contains err.Error(); otherwise it must succeed and return
	// expectedDriverVersion.
	testCases := []struct {
		gpuProtoContent       []byte
		gpuType               string
		input                 string
		fallback              bool
		expectedDriverVersion string
		err                   error
	}{
		// Argument validation.
		{gpuProtoContent: nil, gpuType: "NVIDIA_L4", err: errors.New("the gpu proto content must not be empty")},
		{gpuProtoContent: testData, gpuType: "", err: errors.New("the GPU type must not be empty")},
		// Padded, lower-case GPU type with an empty input.
		{gpuProtoContent: testData, gpuType: " nvidia_l4 ", expectedDriverVersion: "535.154.05"},
		// Explicit version string.
		{gpuProtoContent: testData, gpuType: "NVIDIA_L4", input: "560.38", expectedDriverVersion: "560.38"},
		// Labels, in upper and lower case.
		{gpuProtoContent: testData, gpuType: "NVIDIA_L4", input: "DEFAULT", expectedDriverVersion: "535.154.05"},
		{gpuProtoContent: testData, gpuType: "NVIDIA_L4", input: "default", expectedDriverVersion: "535.154.05"},
		{gpuProtoContent: testData, gpuType: "NVIDIA_L4", input: "latest", expectedDriverVersion: "535.154.05"},
		{gpuProtoContent: testData, gpuType: "NVIDIA_L4", input: "R535", expectedDriverVersion: "535.154.05"},
		{gpuProtoContent: testData, gpuType: "NVIDIA_L4", input: "r535", expectedDriverVersion: "535.154.05"},
		{gpuProtoContent: testData, gpuType: "NVIDIA_L4", input: "535.129.03", expectedDriverVersion: "535.129.03"},
		// An unsupported branch fails without fallback and resolves with it.
		{gpuProtoContent: testData, gpuType: "NVIDIA_L4", input: "R470", err: errors.New("not supported for GPU type")},
		{gpuProtoContent: testData, gpuType: "NVIDIA_L4", input: "R470", fallback: true, expectedDriverVersion: "535.154.05"},
		{gpuProtoContent: testData, gpuType: "NVIDIA_L4", input: "R525", fallback: true, expectedDriverVersion: "535.154.05"},
		// Unknown GPU type and invalid input, with and without fallback.
		{gpuProtoContent: testData, gpuType: "invalidGPU", input: "R525", fallback: true, err: errors.New("no supported driver versions found for gpu")},
		{gpuProtoContent: testData, gpuType: "NVIDIA_L4", input: "invalidInput", err: errors.New("the input is invalid")},
		{gpuProtoContent: testData, gpuType: "NVIDIA_L4", input: "invalidInput", fallback: true, err: errors.New("the input is invalid")},
		// Fixture-specific GPU entries.
		{gpuProtoContent: testData, gpuType: "no_gpu", input: "r525", fallback: true, expectedDriverVersion: "535.161.08"},
		{gpuProtoContent: testData, gpuType: "Sample_GPU_without_default_label", input: "r525", fallback: true, err: errors.New("default label is not supported in current gpu type")},
	}
	for index, tc := range testCases {
		name := fmt.Sprintf("Test %d: GetGPUDriverVersion: with gpuType: %q, input: %q, fallback flag: %v", index, tc.gpuType, tc.input, tc.fallback)
		t.Run(name, func(t *testing.T) {
			driverVersion, err := GetGPUDriverVersion(tc.gpuProtoContent, tc.gpuType, tc.input, tc.fallback)
			if tc.err != nil {
				// err may be nil when the call unexpectedly succeeds, so
				// check that before calling err.Error().
				if err == nil {
					t.Fatalf("GetGPUDriverVersion returned driver version %q and no error, want an error containing %q", driverVersion, tc.err.Error())
				}
				if !strings.Contains(err.Error(), tc.err.Error()) {
					t.Errorf("GetGPUDriverVersion returned an unexpected error: got %q, want it to contain %q", err.Error(), tc.err.Error())
				}
				return
			}
			if err != nil {
				t.Fatalf("GetGPUDriverVersion returned unexpected error: %v", err)
			}
			if !cmp.Equal(tc.expectedDriverVersion, driverVersion) {
				t.Errorf("GetGPUDriverVersion returned driver version %q, want %q", driverVersion, tc.expectedDriverVersion)
			}
		})
	}
}
// TestGetGPUDriverVersions checks two things about GetGPUDriverVersions:
//   - it validates its arguments and reports unknown GPU types;
//   - for a known GPU type, it returns the ordered list of driver versions
//     (with labels) recorded in the fixture at testProtoDataPath.
func TestGetGPUDriverVersions(t *testing.T) {
	// Load the textproto fixture and re-encode it in the binary wire format
	// that GetGPUDriverVersions takes. Every case depends on this data, so
	// any failure here aborts the test.
	textProtoData, err := os.ReadFile(testProtoDataPath)
	if err != nil {
		t.Fatalf("Cannot read the test textproto data: %v", err)
	}
	gpuDriverVersionInfoList := &pb.GPUDriverVersionInfoList{}
	err = (prototext.UnmarshalOptions{DiscardUnknown: true, AllowPartial: true}).Unmarshal(textProtoData, gpuDriverVersionInfoList)
	if err != nil {
		t.Fatalf("failed when parsing the GPU driver version: %v", err)
	}
	testData, err := proto.Marshal(gpuDriverVersionInfoList)
	if err != nil {
		t.Fatalf("fail to encode to the binary proto data: %v", err)
	}

	// The fixture lists the same driver versions, in the same order, for
	// both Tesla GPU types exercised below. This list is only read.
	teslaDriverVersions := []*pb.DriverVersion{
		{Version: "535.154.05", Label: "DEFAULT"},
		{Version: "535.154.05", Label: "LATEST"},
		{Version: "535.154.05", Label: "R535"},
		{Version: "535.129.03"},
		{Version: "535.104.12"},
		{Version: "535.104.05"},
		{Version: "470.223.02", Label: "R470"},
		{Version: "470.199.02"},
	}

	// A non-nil err means the call is expected to fail with an error whose
	// message contains err.Error(); otherwise it must succeed and return
	// expectedDriverVersions in order.
	testCases := []struct {
		gpuProtoContent        []byte
		gpuType                string
		expectedDriverVersions []*pb.DriverVersion
		err                    error
	}{
		{gpuProtoContent: nil, gpuType: "NVIDIA_L4", err: errors.New("the gpu proto content must not be empty")},
		{gpuProtoContent: testData, gpuType: "", err: errors.New("the GPU type must not be empty")},
		{gpuProtoContent: testData, gpuType: "NVIDIA_TESLA_V100", expectedDriverVersions: teslaDriverVersions},
		{gpuProtoContent: testData, gpuType: "NVIDIA_TESLA_P100", expectedDriverVersions: teslaDriverVersions},
		{gpuProtoContent: testData, gpuType: "InvalidGPU", err: errors.New("no supported driver versions found for gpu")},
	}
	for index, tc := range testCases {
		name := fmt.Sprintf("Test %d: GetGPUDriverVersions: with gpuType: %q", index, tc.gpuType)
		t.Run(name, func(t *testing.T) {
			driverVersions, err := GetGPUDriverVersions(tc.gpuProtoContent, tc.gpuType)
			if tc.err != nil {
				// err may be nil when the call unexpectedly succeeds, so
				// check that before calling err.Error().
				if err == nil {
					t.Fatalf("GetGPUDriverVersions returned %d driver versions and no error, want an error containing %q", len(driverVersions), tc.err.Error())
				}
				if !strings.Contains(err.Error(), tc.err.Error()) {
					t.Errorf("GetGPUDriverVersions returned an unexpected error: got %q, want it to contain %q", err.Error(), tc.err.Error())
				}
				return
			}
			if err != nil {
				t.Fatalf("GetGPUDriverVersions returned unexpected error: %v", err)
			}
			if len(driverVersions) != len(tc.expectedDriverVersions) {
				t.Errorf("GetGPUDriverVersions returned %d driver versions, want %d", len(driverVersions), len(tc.expectedDriverVersions))
			}
			// Compare only the overlapping prefix. A length mismatch is
			// reported above; indexing past the shorter slice would panic.
			common := len(driverVersions)
			if len(tc.expectedDriverVersions) < common {
				common = len(tc.expectedDriverVersions)
			}
			for i := 0; i < common; i++ {
				got, want := driverVersions[i], tc.expectedDriverVersions[i]
				if got.Label != want.Label || got.Version != want.Version {
					t.Errorf("driver version at index %d: got DriverVersion{Label: %q, Version: %q}, want DriverVersion{Label: %q, Version: %q}",
						i, got.Label, got.Version, want.Label, want.Version)
				}
			}
		})
	}
}