| package installer |
| |
| import ( |
| "archive/tar" |
| "bytes" |
| "context" |
| "fmt" |
| "os" |
| "path" |
| "path/filepath" |
| "sort" |
| "testing" |
| |
| "cos.googlesource.com/cos/tools.git/src/pkg/cos" |
| "cos.googlesource.com/cos/tools.git/src/pkg/fakes" |
| "cos.googlesource.com/cos/tools.git/src/pkg/gpuconfig/pb" |
| "github.com/golang/protobuf/proto" |
| "github.com/google/go-cmp/cmp" |
| ) |
| |
| func TestDownloadGPUDriverVersionsProto(t *testing.T) { |
| fakeGCS := fakes.GCSForTest(t) |
| fakeBucket := "cos-tools" |
| fakePrefix := "10000.00.00/lakitu" |
| fakeGCSClient := fakeGCS.Client |
| ctx := context.Background() |
| var gpuDriverProtoBin = "gpu_driver_versions.bin" |
| var mockData = &pb.GPUDriverVersionInfoList{ |
| GpuDriverVersionInfo: []*pb.GPUDriverVersionInfo{ |
| { |
| GpuDevice: &pb.GPUDevice{ |
| GpuType: "NVIDIA_TESLA_V100", |
| }, |
| SupportedDriverVersions: []*pb.DriverVersion{ |
| { |
| Label: "DEFAULT", |
| Version: "535.154.05", |
| }, |
| { |
| Label: "LATEST", |
| Version: "535.154.05", |
| }, |
| { |
| Label: "R535", |
| Version: "535.154.05", |
| }, |
| { |
| Version: "535.129.03", |
| }, |
| { |
| Version: "535.104.12", |
| }, |
| { |
| Version: "535.104.05", |
| }, |
| { |
| Label: "R470", |
| Version: "470.223.02", |
| }, |
| { |
| Version: "470.199.02", |
| }, |
| }, |
| }, |
| }, |
| } |
| binaryMockData, err := proto.Marshal(mockData) |
| if err != nil { |
| t.Fatalf("Failed to marshal mockdata to binary array: %v", err) |
| } |
| fakeGCS.Objects[path.Join("/", fakeBucket, fakePrefix, gpuDriverProtoBin)] = binaryMockData |
| tempDir, err := os.MkdirTemp("", "mockGpuInstallDir") |
| if err != nil { |
| t.Fatalf("Failed to create tempdir: %v", err) |
| } |
| defer os.RemoveAll(tempDir) |
| expectedFilePath := filepath.Join(tempDir, gpuDriverProtoBin) |
| var FakeGCSDownloader = cos.NewGCSDownloader(fakeGCSClient, nil, fakeBucket, fakePrefix, "", "") |
| actualData, err := DownloadGPUDriverVersionsProto(ctx, FakeGCSDownloader, tempDir) |
| if err != nil { |
| t.Fatalf("DownloadGPUDriverVersionsProto returned an error: %v", err) |
| } |
| _, err = os.Stat(expectedFilePath) |
| if os.IsNotExist(err) { |
| t.Errorf("Expected file %s does not exist.", expectedFilePath) |
| } else if err != nil { |
| t.Fatalf("Error occurs when checking the existence of the file: %s", expectedFilePath) |
| } |
| |
| if !bytes.Equal(binaryMockData, actualData) { |
| t.Errorf("Data mismatch. Expected %v, got %v", binaryMockData, actualData) |
| } |
| |
| } |
| |
| func TestGetLoadedNVIDIAKernelModuleVersion(t *testing.T) { |
| tempDir, err := os.MkdirTemp("", "mockNVIDIAKernelModuleVersionDir") |
| if err != nil { |
| t.Fatalf("Failed to create tempdir: %v", err) |
| } |
| defer os.RemoveAll(tempDir) |
| |
| for index, tc := range []struct { |
| testName string |
| fileName string |
| contentStr string |
| expectOutput string |
| }{ |
| {"550DriverVersionFile", "550versionFile", "NVRM version: NVIDIA UNIX Open Kernel Module for x86_64 550.90.07 \n Release Build " + |
| "(builder@1f48d05e873d) Fri Jul 12 16:29:04 UTC 2024 GCC version: Selected multilib: .;@m64", "550.90.07"}, |
| {"535DriverVersionFile", "535versionFile", "NVRM version: NVIDIA UNIX x86_64 Kernel Module 535.129.03 Thu Oct 19 18:56:32 UTC 2023" + |
| "GCC version: Selected multilib: .;@m64", "535.129.03"}, |
| {"EmptyFile", "emptyVersionFile", "Test empty", ""}, |
| {"FileNotExist", "nonExistFile", "", ""}, |
| } { |
| t.Run(fmt.Sprintf("Test %v: %s", index, tc.testName), func(t *testing.T) { |
| testVersionFilePath := filepath.Join(tempDir, tc.fileName) |
| if tc.fileName != "nonExistFile" { |
| file, err := os.Create(testVersionFilePath) |
| if err != nil { |
| t.Fatalf("Error creating %s with error: %v", testVersionFilePath, err) |
| } |
| defer file.Close() |
| _, err = file.WriteString(tc.contentStr) |
| if err != nil { |
| t.Fatalf("Error writing to %s with error: %v", testVersionFilePath, err) |
| } |
| } |
| actualResult := GetLoadedNVIDIAKernelModuleVersion(testVersionFilePath) |
| if actualResult != tc.expectOutput { |
| t.Errorf("Unexpected return value, want %v, got %v", tc.expectOutput, actualResult) |
| } |
| }) |
| } |
| } |
| |
| func TestInstallImexDriver(t *testing.T) { |
| driverVersion := "570.124.06" |
| tarballPath := createFakeImexTarball(t, driverVersion) |
| |
| // Test destination dir |
| containerDir := t.TempDir() |
| hostDir := t.TempDir() |
| imexConfigDir := "imex-config" |
| |
| binDir := filepath.Join(containerDir, "bin") |
| if err := os.MkdirAll(binDir, 0755); err != nil { |
| t.Fatalf("failed to create bin directory: %v", err) |
| } |
| |
| err := InstallImexDriver(tarballPath, driverVersion, hostDir, containerDir) |
| if err != nil { |
| t.Fatalf("expected success but got error: %v", err) |
| } |
| |
| expectedFiles := map[string]string{ |
| fmt.Sprintf("%s/bin/nvidia-imex", containerDir): "test1", |
| fmt.Sprintf("%s/bin/nvidia-imex-ctl", containerDir): "test2", |
| fmt.Sprintf("%s/imex-config/config.cfg", containerDir): fmt.Sprintf("IMEX_NODE_CONFIG_FILE=%s", filepath.Join(hostDir, imexConfigDir, "nodes_config.cfg")), |
| } |
| for dstpath, expectedContent := range expectedFiles { |
| data, err := os.ReadFile(dstpath) |
| if err != nil { |
| t.Errorf("expected file %s not found: %v", dstpath, err) |
| continue |
| } |
| if string(data) != expectedContent { |
| t.Errorf("file content mismatch for %s: got %q, want %q", dstpath, string(data), expectedContent) |
| } |
| } |
| } |
| |
// createFakeImexTarball writes an uncompressed tar archive that mimics the layout
// of an NVIDIA IMEX release archive for the given version, and returns its path.
// Every entry lives under nvidia-imex-linux-sbsa-<version>-archive/ and holds small
// placeholder contents that tests can assert on. The archive is placed in a
// t.TempDir(), so it is removed when the calling test finishes.
func createFakeImexTarball(t *testing.T, version string) string {
	t.Helper()

	archiveRoot := fmt.Sprintf("nvidia-imex-linux-sbsa-%s-archive", version)
	entries := []struct {
		name, content string
	}{
		{archiveRoot + "/usr/bin/nvidia-imex", "test1"},
		{archiveRoot + "/usr/bin/nvidia-imex-ctl", "test2"},
		{archiveRoot + "/etc/nvidia-imex/config.cfg", "IMEX_NODE_CONFIG_FILE=/etc/nvidia-imex/nodes_config.cfg"},
	}

	// Build the archive in memory so that tw.Close, which writes the tar
	// footer, can be error-checked before anything is written to disk.
	var buf bytes.Buffer
	tw := tar.NewWriter(&buf)
	for _, e := range entries {
		header := &tar.Header{
			Name: e.name,
			Mode: 0755,
			Size: int64(len(e.content)),
		}
		if err := tw.WriteHeader(header); err != nil {
			t.Fatalf("failed to write tar header: %v", err)
		}
		if _, err := tw.Write([]byte(e.content)); err != nil {
			t.Fatalf("failed to write file content to tar: %v", err)
		}
	}
	if err := tw.Close(); err != nil {
		t.Fatalf("failed to finalize tar archive: %v", err)
	}

	tarPath := filepath.Join(t.TempDir(), "fake-imex.tar")
	if err := os.WriteFile(tarPath, buf.Bytes(), 0644); err != nil {
		t.Fatalf("failed to create tar file: %v", err)
	}
	return tarPath
}
| |
| func TestMergeModuleParams(t *testing.T) { |
| // These are the defaults from InstallGDRCopy |
| defaults := map[string]string{ |
| "dbg_enabled": "0", |
| "info_enabled": "0", |
| "use_persistent_mapping": "1", |
| } |
| |
| testCases := []struct { |
| name string |
| userParams []string |
| want []string |
| }{ |
| { |
| name: "No user params, all defaults applied", |
| userParams: []string{}, |
| want: []string{"dbg_enabled=0", "info_enabled=0", "use_persistent_mapping=1"}, |
| }, |
| { |
| name: "User overrides one default", |
| userParams: []string{"dbg_enabled=1"}, |
| want: []string{"dbg_enabled=1", "info_enabled=0", "use_persistent_mapping=1"}, |
| }, |
| { |
| name: "User adds a custom param", |
| userParams: []string{"foo=bar"}, |
| want: []string{"foo=bar", "dbg_enabled=0", "info_enabled=0", "use_persistent_mapping=1"}, |
| }, |
| { |
| name: "User overrides one and adds one", |
| userParams: []string{"use_persistent_mapping=0", "custom=true"}, |
| want: []string{"use_persistent_mapping=0", "custom=true", "dbg_enabled=0", "info_enabled=0"}, |
| }, |
| { |
| name: "User overrides all defaults", |
| userParams: []string{"dbg_enabled=1", "info_enabled=1", "use_persistent_mapping=0"}, |
| want: []string{"dbg_enabled=1", "info_enabled=1", "use_persistent_mapping=0"}, |
| }, |
| } |
| |
| for _, tc := range testCases { |
| t.Run(tc.name, func(t *testing.T) { |
| got := mergeModuleParams(tc.userParams, defaults) |
| |
| // Sort slices before comparing since parameter order does not matter. |
| sort.Strings(got) |
| sort.Strings(tc.want) |
| |
| if diff := cmp.Diff(tc.want, got); diff != "" { |
| t.Errorf("mergeModuleParams() mismatch (-want +got):\n%s", diff) |
| } |
| }) |
| } |
| } |