blob: 9d101cfbc3f3c91dd1592d2e9163ca5cca2afd8a [file] [log] [blame]
// Package commands implements subcommands of cos_gpu_installer.
package commands
import (
"testing"
"cos.googlesource.com/cos/tools.git/src/cmd/cos_gpu_installer/deviceinfo"
"cos.googlesource.com/cos/tools.git/src/pkg/modules"
)
func TestSetCoherentGPUMemoryMode(t *testing.T) {
testCases := []struct {
name string
gpuType deviceinfo.GPUType
driverVersion string
initialParams modules.ModuleParameters
expectParamSet bool
expectedParamValue string
userParamSet bool
}{
{
name: "GB200 with R580 driver should set default",
gpuType: deviceinfo.GB200,
driverVersion: "580.01.01",
initialParams: modules.NewModuleParameters(),
expectParamSet: true,
expectedParamValue: "NVreg_CoherentGPUMemoryMode=driver",
userParamSet: false,
},
{
name: "GB200 with R585 driver should set default",
gpuType: deviceinfo.GB200,
driverVersion: "585.01.01",
initialParams: modules.NewModuleParameters(),
expectParamSet: true,
expectedParamValue: "NVreg_CoherentGPUMemoryMode=driver",
userParamSet: false,
},
{
name: "GB200 with pre-R580 driver should not set default",
gpuType: deviceinfo.GB200,
driverVersion: "535.01.01",
initialParams: modules.NewModuleParameters(),
expectParamSet: false,
userParamSet: false,
},
{
name: "GB200 with user-provided value should not override",
gpuType: deviceinfo.GB200,
driverVersion: "580.01.01",
initialParams: func() modules.ModuleParameters {
p := modules.NewModuleParameters()
p.Set("nvidia.NVreg_CoherentGPUMemoryMode=user_value")
return p
}(),
expectParamSet: false,
userParamSet: false,
},
{
name: "Non-GB200 GPU with R580 driver should not set default",
gpuType: deviceinfo.H200,
driverVersion: "580.01.01",
initialParams: modules.NewModuleParameters(),
expectParamSet: false,
userParamSet: false,
},
{
name: "GB300 with R580 driver should set default",
gpuType: deviceinfo.GB300,
driverVersion: "580.01.01",
initialParams: modules.NewModuleParameters(),
expectParamSet: true,
expectedParamValue: "NVreg_CoherentGPUMemoryMode=driver",
userParamSet: false,
},
{
name: "GB300 with R585 driver should set default",
gpuType: deviceinfo.GB300,
driverVersion: "585.01.01",
initialParams: modules.NewModuleParameters(),
expectParamSet: true,
expectedParamValue: "NVreg_CoherentGPUMemoryMode=driver",
userParamSet: false,
},
{
name: "GB300 with pre-R580 driver should not set default",
gpuType: deviceinfo.GB300,
driverVersion: "535.01.01",
initialParams: modules.NewModuleParameters(),
expectParamSet: false,
userParamSet: false,
},
{
name: "GB300 with user-provided value should not override",
gpuType: deviceinfo.GB300,
driverVersion: "580.01.01",
initialParams: func() modules.ModuleParameters {
p := modules.NewModuleParameters()
p.Set("nvidia.NVreg_CoherentGPUMemoryMode=user_value")
return p
}(),
expectParamSet: false,
userParamSet: false,
},
{
name: "Non-GB300 GPU with R580 driver should not set default",
gpuType: deviceinfo.H200,
driverVersion: "580.01.01",
initialParams: modules.NewModuleParameters(),
expectParamSet: false,
userParamSet: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
cmd := &InstallCommand{
driverVersion: tc.driverVersion,
kernelModuleParams: tc.initialParams,
}
if err := cmd.setGracePlatformDefaultParams(tc.gpuType); err != nil {
t.Fatalf("setGracePlatformDefaultParams failed: %v", err)
}
paramWasSet := false
if params, ok := cmd.kernelModuleParams["nvidia"]; ok {
for _, param := range params {
if param == "NVreg_CoherentGPUMemoryMode=driver" {
paramWasSet = true
break
}
}
}
if tc.expectParamSet != paramWasSet {
t.Errorf("Expected param set state to be %v, but got %v", tc.expectParamSet, paramWasSet)
}
// Verify user-set value is preserved
if tc.userParamSet {
foundUserValue := false
if params, ok := cmd.kernelModuleParams["nvidia"]; ok {
for _, param := range params {
if param == "NVreg_CoherentGPUMemoryMode=user_value" {
foundUserValue = true
break
}
}
}
if !foundUserValue {
t.Errorf("User-provided kernel module parameter was not preserved.")
}
}
})
}
}