| // Package commands implements subcommands of cos_gpu_installer. |
| package commands |
| |
| import ( |
| "testing" |
| |
| "cos.googlesource.com/cos/tools.git/src/cmd/cos_gpu_installer/deviceinfo" |
| "cos.googlesource.com/cos/tools.git/src/pkg/modules" |
| ) |
| |
| func TestSetCoherentGPUMemoryMode(t *testing.T) { |
| testCases := []struct { |
| name string |
| gpuType deviceinfo.GPUType |
| driverVersion string |
| initialParams modules.ModuleParameters |
| expectParamSet bool |
| expectedParamValue string |
| userParamSet bool |
| }{ |
| { |
| name: "GB200 with R580 driver should set default", |
| gpuType: deviceinfo.GB200, |
| driverVersion: "580.01.01", |
| initialParams: modules.NewModuleParameters(), |
| expectParamSet: true, |
| expectedParamValue: "NVreg_CoherentGPUMemoryMode=driver", |
| userParamSet: false, |
| }, |
| { |
| name: "GB200 with R585 driver should set default", |
| gpuType: deviceinfo.GB200, |
| driverVersion: "585.01.01", |
| initialParams: modules.NewModuleParameters(), |
| expectParamSet: true, |
| expectedParamValue: "NVreg_CoherentGPUMemoryMode=driver", |
| userParamSet: false, |
| }, |
| { |
| name: "GB200 with pre-R580 driver should not set default", |
| gpuType: deviceinfo.GB200, |
| driverVersion: "535.01.01", |
| initialParams: modules.NewModuleParameters(), |
| expectParamSet: false, |
| userParamSet: false, |
| }, |
| { |
| name: "GB200 with user-provided value should not override", |
| gpuType: deviceinfo.GB200, |
| driverVersion: "580.01.01", |
| initialParams: func() modules.ModuleParameters { |
| p := modules.NewModuleParameters() |
| p.Set("nvidia.NVreg_CoherentGPUMemoryMode=user_value") |
| return p |
| }(), |
| expectParamSet: false, |
| userParamSet: false, |
| }, |
| { |
| name: "Non-GB200 GPU with R580 driver should not set default", |
| gpuType: deviceinfo.H200, |
| driverVersion: "580.01.01", |
| initialParams: modules.NewModuleParameters(), |
| expectParamSet: false, |
| userParamSet: false, |
| }, |
| { |
| name: "GB300 with R580 driver should set default", |
| gpuType: deviceinfo.GB300, |
| driverVersion: "580.01.01", |
| initialParams: modules.NewModuleParameters(), |
| expectParamSet: true, |
| expectedParamValue: "NVreg_CoherentGPUMemoryMode=driver", |
| userParamSet: false, |
| }, |
| { |
| name: "GB300 with R585 driver should set default", |
| gpuType: deviceinfo.GB300, |
| driverVersion: "585.01.01", |
| initialParams: modules.NewModuleParameters(), |
| expectParamSet: true, |
| expectedParamValue: "NVreg_CoherentGPUMemoryMode=driver", |
| userParamSet: false, |
| }, |
| { |
| name: "GB300 with pre-R580 driver should not set default", |
| gpuType: deviceinfo.GB300, |
| driverVersion: "535.01.01", |
| initialParams: modules.NewModuleParameters(), |
| expectParamSet: false, |
| userParamSet: false, |
| }, |
| { |
| name: "GB300 with user-provided value should not override", |
| gpuType: deviceinfo.GB300, |
| driverVersion: "580.01.01", |
| initialParams: func() modules.ModuleParameters { |
| p := modules.NewModuleParameters() |
| p.Set("nvidia.NVreg_CoherentGPUMemoryMode=user_value") |
| return p |
| }(), |
| expectParamSet: false, |
| userParamSet: false, |
| }, |
| { |
| name: "Non-GB300 GPU with R580 driver should not set default", |
| gpuType: deviceinfo.H200, |
| driverVersion: "580.01.01", |
| initialParams: modules.NewModuleParameters(), |
| expectParamSet: false, |
| userParamSet: false, |
| }, |
| } |
| |
| for _, tc := range testCases { |
| t.Run(tc.name, func(t *testing.T) { |
| cmd := &InstallCommand{ |
| driverVersion: tc.driverVersion, |
| kernelModuleParams: tc.initialParams, |
| } |
| |
| if err := cmd.setGracePlatformDefaultParams(tc.gpuType); err != nil { |
| t.Fatalf("setGracePlatformDefaultParams failed: %v", err) |
| } |
| |
| paramWasSet := false |
| if params, ok := cmd.kernelModuleParams["nvidia"]; ok { |
| for _, param := range params { |
| if param == "NVreg_CoherentGPUMemoryMode=driver" { |
| paramWasSet = true |
| break |
| } |
| } |
| } |
| |
| if tc.expectParamSet != paramWasSet { |
| t.Errorf("Expected param set state to be %v, but got %v", tc.expectParamSet, paramWasSet) |
| } |
| |
| // Verify user-set value is preserved |
| if tc.userParamSet { |
| foundUserValue := false |
| if params, ok := cmd.kernelModuleParams["nvidia"]; ok { |
| for _, param := range params { |
| if param == "NVreg_CoherentGPUMemoryMode=user_value" { |
| foundUserValue = true |
| break |
| } |
| } |
| } |
| if !foundUserValue { |
| t.Errorf("User-provided kernel module parameter was not preserved.") |
| } |
| } |
| }) |
| } |
| } |