Add support for GDRCopy
Tested log:
```
sudo COS_GPU_INSTALLER=us.gcr.io/cloud-kernel-build/cos-gpu-installer:latest cos-extensions install gpu -- -debug -test -gdr -gcs-download-prefix=lakitu-release-tryjob/R129-19275.0.0-90a6ad3f -gcs-download-bucket=cos-infra-prod-artifacts-presubmit
I0910 05:26:43.969611 1423 installer.go:934] Applying default module parameter: use_persistent_mapping=1
I0910 05:26:43.969635 1423 installer.go:934] Applying default module parameter: dbg_enabled=0
I0910 05:26:43.969642 1423 installer.go:934] Applying default module parameter: info_enabled=0
I0910 05:26:43.969647 1423 installer.go:939] Loading GDRCopy kernel module with dependencies.
I0910 05:26:43.973776 1423 modules.go:190] loading module: /usr/sbin/insmod /usr/local/nvidia/drivers/gdrdrv.ko use_persistent_mapping=1 dbg_enabled=0 info_enabled=0
I0910 05:26:44.024876 1423 installer.go:949] GDRCopy driver major is 241
I0910 05:26:44.024929 1423 installer.go:958] Creating device node /dev/gdrdrv
I0910 05:26:44.024955 1423 install.go:606] GDRCopy driver and device node created successfully.
```
BUG=b/428981220
TEST=Tested in a presubmit GPU VM with GDRCopy kernel module installed.
RELEASE_NOTE=None
Change-Id: Iab353c605ddf3d11643f883391eb30932a0ac911
Reviewed-on: https://cos-review.googlesource.com/c/cos/tools/+/110701
Tested-by: Chenglong Tang <chenglongtang@google.com>
Cloud-Build: GCB Service account <228075978874@cloudbuild.gserviceaccount.com>
Reviewed-by: Kevin Berry <kpberry@google.com>
Reviewed-by: Shuo Yang <gshuoy@google.com>
diff --git a/src/cmd/cos_gpu_installer/internal/commands/install.go b/src/cmd/cos_gpu_installer/internal/commands/install.go
index eb532ee..1ed27c4 100644
--- a/src/cmd/cos_gpu_installer/internal/commands/install.go
+++ b/src/cmd/cos_gpu_installer/internal/commands/install.go
@@ -104,6 +104,7 @@
hostInstallDir string
forceFallback FallBackFlag
unsignedDriver bool
+ gdrCopy bool
gcsDownloadBucket string
gcsDownloadPrefix string
nvidiaInstallerURL string
@@ -184,11 +185,12 @@
f.BoolVar(&c.noVerify, "no-verify", false, "Skip kernel module loading and installation verification. Useful for preloading drivers without attached GPU.")
f.BoolVar(&c.skipNvidiaSmi, "skip-nvidia-smi", false, "This flag disables the execution of nvidia-smi verification.")
c.kernelModuleParams = modules.NewModuleParameters()
- f.Var(&c.kernelModuleParams, "module-arg", "Kernel module parameters can be specified using this flag. These parameters are used while loading the specific kernel mode drivers into the kernel. Usage: -module-arg <module-x>.<parameter-y>=<value> -module-arg <module-y>.<parameter-z>=<value> .. For eg: –module-arg nvidia_uvm.uvm_debug_prints=1 –module-arg nvidia.NVreg_EnableGpuFirmware=0.")
+ f.Var(&c.kernelModuleParams, "module-arg", "Kernel module parameters can be specified using this flag. These parameters are used while loading the specific kernel mode drivers into the kernel. Usage: -module-arg <module-x>.<parameter-y>=<value> -module-arg <module-y>.<parameter-z>=<value> .. For eg: –module-arg nvidia_uvm.uvm_debug_prints=1 –module-arg nvidia.NVreg_EnableGpuFirmware=0.")
f.Var(&c.forceFallback, "force-fallback", "This flag specify whether to use fallback mechanism when specified GPU driver is not compatible with GPU devices.\n"+
"If unspecified, it is `false` for --version=R<major-version> eg. 'R470', 'R525' or --version=<precise-version> eg. '535.129.03', '525.147.05', it is `true` for version is not specified or --version=default or --version=latest.\n"+
"When fallback behavior is active, the installer will find a compatible driver to install for the detected GPU on the VM.")
f.StringVar(&c.localArtifactsDir, "local-artifacts-dir", "", "Local directory where NVIDIA driver artifacts are stored. If set, artifacts will be copied from this directory instead of downloaded from GCS.")
+ f.BoolVar(&c.gdrCopy, "gdr", false, "Install GDRCopy driver.")
}
func (c *InstallCommand) validateFlags() error {
@@ -353,6 +355,15 @@
c.logError(errors.Wrap(err, "failed to verify GPU driver installation"))
return subcommands.ExitFailure
}
+
+ if c.gdrCopy {
+ if err := installer.InstallGDRCopy(c.noVerify, c.kernelModuleParams); err != nil {
+ c.logError(err)
+ return subcommands.ExitFailure
+ }
+ log.V(1).Info("GDRCopy driver and device node created successfully.")
+ }
+
if err := modules.UpdateHostLdCache(hostRootPath, filepath.Join(c.hostInstallDir, "lib64")); err != nil {
c.logError(errors.Wrap(err, "failed to update host ld cache"))
return subcommands.ExitFailure
@@ -532,6 +543,14 @@
if err := installer.VerifyDriverInstallation(c.noVerify, c.debug, c.skipNvidiaSmi); err != nil {
return errors.Wrap(err, "failed to verify installation")
}
+
+ if c.gdrCopy {
+ if err := installer.InstallGDRCopy(c.noVerify, c.kernelModuleParams); err != nil {
+ return err
+ }
+ log.V(1).Info("GDRCopy driver and device node created successfully.")
+ }
+
if err := modules.UpdateHostLdCache(hostRootPath, filepath.Join(c.hostInstallDir, "lib64")); err != nil {
return errors.Wrap(err, "failed to update host ld cache")
}
@@ -580,6 +599,12 @@
if err := installer.VerifyDriverInstallation(c.noVerify, c.debug, c.skipNvidiaSmi); err != nil {
return errors.Wrap(err, "failed to verify installation")
}
+ if c.gdrCopy {
+ if err := installer.InstallGDRCopy(c.noVerify, c.kernelModuleParams); err != nil {
+ return err
+ }
+ log.V(1).Info("GDRCopy driver and device node created successfully.")
+ }
if err := modules.UpdateHostLdCache(hostRootPath, filepath.Join(c.hostInstallDir, "lib64")); err != nil {
return errors.Wrap(err, "failed to update host ld cache")
}
diff --git a/src/cmd/cos_gpu_installer/internal/installer/installer.go b/src/cmd/cos_gpu_installer/internal/installer/installer.go
index 43e97ee..58f55e8 100644
--- a/src/cmd/cos_gpu_installer/internal/installer/installer.go
+++ b/src/cmd/cos_gpu_installer/internal/installer/installer.go
@@ -13,6 +13,7 @@
"path/filepath"
"regexp"
"sort"
+ "strconv"
"strings"
"syscall"
@@ -42,6 +43,9 @@
LatestVersion = "latest"
MajorGPUDriverArtifactPrefix = "gpu_"
MajorGPUDriverArtifactSuffix = "_version"
+ gdrdrvDevicePath = "/dev/gdrdrv"
+ gdrdrvModuleName = "gdrdrv"
+ procDevicesPath = "/proc/devices"
)
var (
@@ -861,3 +865,114 @@
return nil
}
+
+// isDeviceRegistered checks if a device is registered in /proc/devices
+// and returns true and its major number if it is.
+func isDeviceRegistered(deviceName string) (bool, int) {
+ content, err := ioutil.ReadFile(procDevicesPath)
+ if err != nil {
+ log.Errorf("Failed to read %s: %v", procDevicesPath, err)
+ return false, 0
+ }
+
+ lines := strings.Split(string(content), "\n")
+ for _, line := range lines {
+ fields := strings.Fields(line)
+ if len(fields) == 2 && fields[1] == deviceName {
+ major, err := strconv.Atoi(fields[0])
+ if err != nil {
+ log.Errorf("Failed to parse major number for %s: %v", deviceName, err)
+ return false, 0
+ }
+ return true, major
+ }
+ }
+ return false, 0
+}
+
+// mergeModuleParams takes a list of user-provided parameters (as "key=val" strings)
+// and a map of default parameters. It returns a final slice of parameters, ensuring
+// that any key provided by the user is not overridden by a default.
+func mergeModuleParams(userParams []string, defaults map[string]string) []string {
+ userSetKeys := make(map[string]bool)
+ finalParamsList := []string{}
+
+ // 1. Add all user params first and record which keys they set.
+ for _, userParam := range userParams {
+ finalParamsList = append(finalParamsList, userParam)
+ if key, _, found := strings.Cut(userParam, "="); found {
+ userSetKeys[key] = true
+ }
+ }
+
+ // 2. Add defaults ONLY if the key wasn't already set by the user.
+ for key, value := range defaults {
+ if !userSetKeys[key] {
+ paramString := fmt.Sprintf("%s=%s", key, value)
+ finalParamsList = append(finalParamsList, paramString)
+ }
+ }
+ return finalParamsList
+}
+
+// InstallGDRCopy loads the GDRCopy kernel module and creates its device node.
+// This should be run after the main NVIDIA kernel modules are loaded.
+// It follows https://github.com/NVIDIA/gdrcopy/blob/master/insmod.sh.
+func InstallGDRCopy(noVerify bool, moduleParams modules.ModuleParameters) error {
+ if noVerify {
+ log.Info("Flag --no-verify is set, skipping GDRCopy installation.")
+ return nil
+ }
+
+ kernelModulePath := filepath.Join(gpuInstallDirContainer, "drivers")
+
+ // 1. Define the gdrdrv module.
+ gdrModule := &modules.Module{
+ Name: gdrdrvModuleName,
+ Path: filepath.Join(kernelModulePath, "gdrdrv.ko"),
+ }
+
+ // Set default module parameters if the user did not provide them.
+ // Flags are defined here: https://github.com/NVIDIA/gdrcopy/blob/master/insmod.sh#L28.
+ defaults := map[string]string{
+ "dbg_enabled": "0",
+ "info_enabled": "0",
+ "use_persistent_mapping": "1",
+ }
+
+ // Call our tested helper function to get the final parameter list.
+ userGDRParams := moduleParams[gdrdrvModuleName]
+ finalGDRParams := mergeModuleParams(userGDRParams, defaults)
+
+ // Assign the merged list back to the global map to be passed to LoadModule.
+ moduleParams[gdrdrvModuleName] = finalGDRParams
+ log.V(1).Infof("Applying final parameters for %s: %v", gdrdrvModuleName, finalGDRParams)
+
+ // 2. Load the module.
+ log.V(1).Info("Loading GDRCopy kernel module with dependencies.")
+ if err := modules.LoadModule(gdrModule, moduleParams); err != nil {
+ return errors.Wrap(err, "failed to load gdrdrv kernel module")
+ }
+
+ // 3. Create the device node
+ isLoaded, major := isDeviceRegistered(gdrdrvModuleName)
+ if !isLoaded {
+ return stderrors.New("gdrdrv module loaded but device not found in /proc/devices")
+ }
+ log.Infof("GDRCopy driver major is %d", major)
+
+ if _, err := os.Stat(gdrdrvDevicePath); err == nil {
+ log.Infof("Removing old inode %s", gdrdrvDevicePath)
+ if err := os.Remove(gdrdrvDevicePath); err != nil {
+ return errors.Wrapf(err, "failed to remove existing device node %s", gdrdrvDevicePath)
+ }
+ }
+
+ log.Infof("Creating device node %s", gdrdrvDevicePath)
+ dev := unix.Mkdev(uint32(major), 0)
+ if err := unix.Mknod(gdrdrvDevicePath, unix.S_IFCHR|0666, int(dev)); err != nil {
+ return errors.Wrapf(err, "failed to create device node for %s", gdrdrvDevicePath)
+ }
+
+ return nil
+}
diff --git a/src/cmd/cos_gpu_installer/internal/installer/installer_test.go b/src/cmd/cos_gpu_installer/internal/installer/installer_test.go
index 4c015ce..a06bc7c 100644
--- a/src/cmd/cos_gpu_installer/internal/installer/installer_test.go
+++ b/src/cmd/cos_gpu_installer/internal/installer/installer_test.go
@@ -8,12 +8,14 @@
"os"
"path"
"path/filepath"
+ "sort"
"testing"
"cos.googlesource.com/cos/tools.git/src/pkg/cos"
"cos.googlesource.com/cos/tools.git/src/pkg/fakes"
"cos.googlesource.com/cos/tools.git/src/pkg/gpuconfig/pb"
"github.com/golang/protobuf/proto"
+ "github.com/google/go-cmp/cmp"
)
func TestDownloadGPUDriverVersionsProto(t *testing.T) {
@@ -204,3 +206,58 @@
return tarPath
}
+
+func TestMergeModuleParams(t *testing.T) {
+ // These are the defaults from InstallGDRCopy
+ defaults := map[string]string{
+ "dbg_enabled": "0",
+ "info_enabled": "0",
+ "use_persistent_mapping": "1",
+ }
+
+ testCases := []struct {
+ name string
+ userParams []string
+ want []string
+ }{
+ {
+ name: "No user params, all defaults applied",
+ userParams: []string{},
+ want: []string{"dbg_enabled=0", "info_enabled=0", "use_persistent_mapping=1"},
+ },
+ {
+ name: "User overrides one default",
+ userParams: []string{"dbg_enabled=1"},
+ want: []string{"dbg_enabled=1", "info_enabled=0", "use_persistent_mapping=1"},
+ },
+ {
+ name: "User adds a custom param",
+ userParams: []string{"foo=bar"},
+ want: []string{"foo=bar", "dbg_enabled=0", "info_enabled=0", "use_persistent_mapping=1"},
+ },
+ {
+ name: "User overrides one and adds one",
+ userParams: []string{"use_persistent_mapping=0", "custom=true"},
+ want: []string{"use_persistent_mapping=0", "custom=true", "dbg_enabled=0", "info_enabled=0"},
+ },
+ {
+ name: "User overrides all defaults",
+ userParams: []string{"dbg_enabled=1", "info_enabled=1", "use_persistent_mapping=0"},
+ want: []string{"dbg_enabled=1", "info_enabled=1", "use_persistent_mapping=0"},
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ got := mergeModuleParams(tc.userParams, defaults)
+
+ // Sort slices before comparing since parameter order does not matter.
+ sort.Strings(got)
+ sort.Strings(tc.want)
+
+ if diff := cmp.Diff(tc.want, got); diff != "" {
+ t.Errorf("mergeModuleParams() mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}