// Source blob: a06bc7c430bc0e334302446a59f42eae64731a37 (repository viewer header, not part of the Go source).

package installer
import (
"archive/tar"
"bytes"
"context"
"fmt"
"os"
"path"
"path/filepath"
"sort"
"testing"
"cos.googlesource.com/cos/tools.git/src/pkg/cos"
"cos.googlesource.com/cos/tools.git/src/pkg/fakes"
"cos.googlesource.com/cos/tools.git/src/pkg/gpuconfig/pb"
"github.com/golang/protobuf/proto"
"github.com/google/go-cmp/cmp"
)
// TestDownloadGPUDriverVersionsProto seeds a fake GCS bucket with a serialized
// GPUDriverVersionInfoList, calls DownloadGPUDriverVersionsProto, and checks
// that gpu_driver_versions.bin exists in the destination directory and that
// the returned bytes equal the seeded bytes.
func TestDownloadGPUDriverVersionsProto(t *testing.T) {
	const (
		fakeBucket        = "cos-tools"
		fakePrefix        = "10000.00.00/lakitu"
		gpuDriverProtoBin = "gpu_driver_versions.bin"
	)
	fakeGCS := fakes.GCSForTest(t)
	ctx := context.Background()

	mockData := &pb.GPUDriverVersionInfoList{
		GpuDriverVersionInfo: []*pb.GPUDriverVersionInfo{
			{
				GpuDevice: &pb.GPUDevice{
					GpuType: "NVIDIA_TESLA_V100",
				},
				SupportedDriverVersions: []*pb.DriverVersion{
					{
						Label:   "DEFAULT",
						Version: "535.154.05",
					},
					{
						Label:   "LATEST",
						Version: "535.154.05",
					},
					{
						Label:   "R535",
						Version: "535.154.05",
					},
					{
						Version: "535.129.03",
					},
					{
						Version: "535.104.12",
					},
					{
						Version: "535.104.05",
					},
					{
						Label:   "R470",
						Version: "470.223.02",
					},
					{
						Version: "470.199.02",
					},
				},
			},
		},
	}
	binaryMockData, err := proto.Marshal(mockData)
	if err != nil {
		t.Fatalf("Failed to marshal mockdata to binary array: %v", err)
	}
	// The fake bucket is keyed by /<bucket>/<prefix>/<object>.
	fakeGCS.Objects[path.Join("/", fakeBucket, fakePrefix, gpuDriverProtoBin)] = binaryMockData

	// t.TempDir removes the directory when the test finishes, so no manual
	// cleanup is required.
	tempDir := t.TempDir()
	expectedFilePath := filepath.Join(tempDir, gpuDriverProtoBin)

	downloader := cos.NewGCSDownloader(fakeGCS.Client, nil, fakeBucket, fakePrefix, "", "")
	actualData, err := DownloadGPUDriverVersionsProto(ctx, downloader, tempDir)
	if err != nil {
		t.Fatalf("DownloadGPUDriverVersionsProto returned an error: %v", err)
	}

	if _, err := os.Stat(expectedFilePath); os.IsNotExist(err) {
		t.Errorf("Expected file %s does not exist.", expectedFilePath)
	} else if err != nil {
		// Report the underlying error so an unexpected stat failure can be diagnosed.
		t.Fatalf("Error occurs when checking the existence of the file %s: %v", expectedFilePath, err)
	}
	if !bytes.Equal(binaryMockData, actualData) {
		t.Errorf("Data mismatch. Expected %v, got %v", binaryMockData, actualData)
	}
}
// TestGetLoadedNVIDIAKernelModuleVersion runs GetLoadedNVIDIAKernelModuleVersion
// against version files with different contents, plus a path that is never
// created, and compares the returned string with the expected version.
func TestGetLoadedNVIDIAKernelModuleVersion(t *testing.T) {
	// t.TempDir is cleaned up automatically when the test completes.
	tempDir := t.TempDir()
	for index, tc := range []struct {
		testName     string
		fileName     string
		contentStr   string
		expectOutput string
		// skipCreate leaves the file absent so the missing-file path is exercised.
		skipCreate bool
	}{
		{
			testName: "550DriverVersionFile",
			fileName: "550versionFile",
			contentStr: "NVRM version: NVIDIA UNIX Open Kernel Module for x86_64 550.90.07 \n Release Build " +
				"(builder@1f48d05e873d) Fri Jul 12 16:29:04 UTC 2024 GCC version: Selected multilib: .;@m64",
			expectOutput: "550.90.07",
		},
		{
			testName: "535DriverVersionFile",
			fileName: "535versionFile",
			contentStr: "NVRM version: NVIDIA UNIX x86_64 Kernel Module 535.129.03 Thu Oct 19 18:56:32 UTC 2023" +
				"GCC version: Selected multilib: .;@m64",
			expectOutput: "535.129.03",
		},
		{
			// The content holds no version number, so an empty result is expected.
			testName:     "EmptyFile",
			fileName:     "emptyVersionFile",
			contentStr:   "Test empty",
			expectOutput: "",
		},
		{
			testName:     "FileNotExist",
			fileName:     "nonExistFile",
			contentStr:   "",
			expectOutput: "",
			skipCreate:   true,
		},
	} {
		t.Run(fmt.Sprintf("Test %v: %s", index, tc.testName), func(t *testing.T) {
			testVersionFilePath := filepath.Join(tempDir, tc.fileName)
			if !tc.skipCreate {
				// os.WriteFile writes and closes the file before the function
				// under test opens it.
				if err := os.WriteFile(testVersionFilePath, []byte(tc.contentStr), 0644); err != nil {
					t.Fatalf("Error writing to %s with error: %v", testVersionFilePath, err)
				}
			}
			actualResult := GetLoadedNVIDIAKernelModuleVersion(testVersionFilePath)
			if actualResult != tc.expectOutput {
				t.Errorf("Unexpected return value, want %v, got %v", tc.expectOutput, actualResult)
			}
		})
	}
}
// TestInstallImexDriver installs from a fake IMEX tarball into temporary
// container/host directories and verifies that both binaries and the rewritten
// config file end up under the container directory with the expected content.
func TestInstallImexDriver(t *testing.T) {
	const (
		driverVersion = "570.124.06"
		imexConfigDir = "imex-config"
	)
	tarballPath := createFakeImexTarball(t, driverVersion)

	// Destination directories for the installation.
	containerDir := t.TempDir()
	hostDir := t.TempDir()

	// The bin directory is created up front, before InstallImexDriver runs.
	if mkErr := os.MkdirAll(filepath.Join(containerDir, "bin"), 0755); mkErr != nil {
		t.Fatalf("failed to create bin directory: %v", mkErr)
	}

	if installErr := InstallImexDriver(tarballPath, driverVersion, hostDir, containerDir); installErr != nil {
		t.Fatalf("expected success but got error: %v", installErr)
	}

	hostNodesConfig := filepath.Join(hostDir, imexConfigDir, "nodes_config.cfg")
	checks := []struct {
		dstpath string
		want    string
	}{
		{filepath.Join(containerDir, "bin", "nvidia-imex"), "test1"},
		{filepath.Join(containerDir, "bin", "nvidia-imex-ctl"), "test2"},
		{filepath.Join(containerDir, "imex-config", "config.cfg"), fmt.Sprintf("IMEX_NODE_CONFIG_FILE=%s", hostNodesConfig)},
	}
	for _, c := range checks {
		raw, readErr := os.ReadFile(c.dstpath)
		if readErr != nil {
			t.Errorf("expected file %s not found: %v", c.dstpath, readErr)
			continue
		}
		if got := string(raw); got != c.want {
			t.Errorf("file content mismatch for %s: got %q, want %q", c.dstpath, got, c.want)
		}
	}
}
// createFakeImexTarball writes an uncompressed tar archive named fake-imex.tar
// into a fresh temporary directory and returns its path. The archive contains
// three regular files under nvidia-imex-linux-sbsa-<version>-archive/:
// usr/bin/nvidia-imex, usr/bin/nvidia-imex-ctl and etc/nvidia-imex/config.cfg.
// Any failure aborts the calling test.
func createFakeImexTarball(t *testing.T, version string) string {
	t.Helper()

	tarPath := filepath.Join(t.TempDir(), "fake-imex.tar")
	file, err := os.Create(tarPath)
	if err != nil {
		t.Fatalf("failed to create tar file: %v", err)
	}
	// Safety net for the t.Fatalf paths below; on success the file is closed
	// explicitly and this second Close is a harmless no-op error.
	defer file.Close()

	tw := tar.NewWriter(file)
	archiveRoot := fmt.Sprintf("nvidia-imex-linux-sbsa-%s-archive", version)
	files := []struct {
		path, content string
	}{
		{fmt.Sprintf("%s/usr/bin/nvidia-imex", archiveRoot), "test1"},
		{fmt.Sprintf("%s/usr/bin/nvidia-imex-ctl", archiveRoot), "test2"},
		{fmt.Sprintf("%s/etc/nvidia-imex/config.cfg", archiveRoot), "IMEX_NODE_CONFIG_FILE=/etc/nvidia-imex/nodes_config.cfg"},
	}
	for _, f := range files {
		header := &tar.Header{
			Typeflag: tar.TypeReg,
			Name:     f.path,
			Mode:     0755,
			Size:     int64(len(f.content)),
		}
		if err := tw.WriteHeader(header); err != nil {
			t.Fatalf("failed to write tar header: %v", err)
		}
		if _, err := tw.Write([]byte(f.content)); err != nil {
			t.Fatalf("failed to write file content to tar: %v", err)
		}
	}
	// Close flushes padding and writes the tar footer; an error here means the
	// archive is incomplete, so it must not be ignored.
	if err := tw.Close(); err != nil {
		t.Fatalf("failed to finalize tar archive: %v", err)
	}
	if err := file.Close(); err != nil {
		t.Fatalf("failed to close tar file: %v", err)
	}
	return tarPath
}
// TestMergeModuleParams exercises mergeModuleParams with several combinations
// of user-supplied "key=value" parameters and checks the merged result against
// the expected set, ignoring ordering.
func TestMergeModuleParams(t *testing.T) {
	// Default parameters, matching those from InstallGDRCopy.
	gdrDefaults := map[string]string{
		"dbg_enabled":            "0",
		"info_enabled":           "0",
		"use_persistent_mapping": "1",
	}

	for _, tt := range []struct {
		name       string
		userParams []string
		want       []string
	}{
		{
			name:       "No user params, all defaults applied",
			userParams: []string{},
			want:       []string{"dbg_enabled=0", "info_enabled=0", "use_persistent_mapping=1"},
		},
		{
			name:       "User overrides one default",
			userParams: []string{"dbg_enabled=1"},
			want:       []string{"dbg_enabled=1", "info_enabled=0", "use_persistent_mapping=1"},
		},
		{
			name:       "User adds a custom param",
			userParams: []string{"foo=bar"},
			want:       []string{"foo=bar", "dbg_enabled=0", "info_enabled=0", "use_persistent_mapping=1"},
		},
		{
			name:       "User overrides one and adds one",
			userParams: []string{"use_persistent_mapping=0", "custom=true"},
			want:       []string{"use_persistent_mapping=0", "custom=true", "dbg_enabled=0", "info_enabled=0"},
		},
		{
			name:       "User overrides all defaults",
			userParams: []string{"dbg_enabled=1", "info_enabled=1", "use_persistent_mapping=0"},
			want:       []string{"dbg_enabled=1", "info_enabled=1", "use_persistent_mapping=0"},
		},
	} {
		t.Run(tt.name, func(t *testing.T) {
			merged := mergeModuleParams(tt.userParams, gdrDefaults)

			// Parameter order is irrelevant, so both sides are sorted before diffing.
			sort.Strings(merged)
			sort.Strings(tt.want)

			delta := cmp.Diff(tt.want, merged)
			if delta == "" {
				return
			}
			t.Errorf("mergeModuleParams() mismatch (-want +got):\n%s", delta)
		})
	}
}