| // Copyright 2025 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| package preloader |
| |
| import ( |
| "context" |
| "crypto/md5" |
| "fmt" |
| "io" |
| "io/fs" |
| "log" |
| "net/http" |
| "os" |
| "os/exec" |
| "path/filepath" |
| |
| "cos.googlesource.com/cos/tools.git/src/pkg/config" |
| "cos.googlesource.com/cos/tools.git/src/pkg/cos" |
| cosfs "cos.googlesource.com/cos/tools.git/src/pkg/fs" |
| "cos.googlesource.com/cos/tools.git/src/pkg/gpubuild" |
| "cos.googlesource.com/cos/tools.git/src/pkg/gpuconfig" |
| "cos.googlesource.com/cos/tools.git/src/pkg/gpuconfig/pb" |
| "cos.googlesource.com/cos/tools.git/src/pkg/utils" |
| |
| "cloud.google.com/go/storage" |
| "google.golang.org/protobuf/proto" |
| ) |
| |
// extractTarball unpacks the archive at srcPath into destDir by running the
// "tar" program found on PATH. The child's stdout and stderr are forwarded to
// this process's stdout and stderr. The returned error is whatever running
// the command reports.
func extractTarball(srcPath, destDir string) error {
	tar := exec.Command("tar", "xf", srcPath, "-C", destDir)
	tar.Stdout, tar.Stderr = os.Stdout, os.Stderr
	return tar.Run()
}
| |
| type envReaderCreator func(context.Context) (*cos.EnvReader, error) |
| |
| // cachedEnvReader returns a closure that creates an EnvReader from a VMImage. |
| // Since creating an EnvReader is very expensive in this context (it boots a VM |
| // from the image to read data from it), we use a closure to cache the resulting |
| // envReader. |
| // |
| // We could create the EnvReader once and pass it around everywhere. However, |
| // under specific conditions, the EnvReader won't need to be created at all, |
| // so this closure allows us to lazily create the EnvReader when it is needed. |
| func cachedEnvReader(input *config.Image, gcs *gcsManager, files *cosfs.Files, buildSpec *config.Build) envReaderCreator { |
| var envReader *cos.EnvReader |
| return func(ctx context.Context) (*cos.EnvReader, error) { |
| var err error |
| if envReader != nil { |
| return envReader, nil |
| } |
| envReader, err = cos.NewEnvReaderFromVMImage(input.Name, input.Project, cos.EnvReaderVMImageConfig{ |
| DaisyBin: files.DaisyBin, |
| DaisyWorkflow: "/data/read_cos_env.wf.json", |
| Project: buildSpec.Project, |
| Zone: buildSpec.Zone, |
| GCSDir: gcs.managedDirURL(), |
| DiskType: buildSpec.DiskType, |
| MachineType: buildSpec.MachineType, |
| Network: buildSpec.Network, |
| SubNet: buildSpec.Subnet, |
| ServiceAccount: buildSpec.ServiceAccount, |
| }) |
| if err != nil { |
| return nil, fmt.Errorf("failed to read env from VM image: %v", err) |
| } |
| return envReader, nil |
| } |
| } |
| |
| func installToolchain(ctx context.Context, envReaderFunc envReaderCreator, gcsClient *storage.Client, buildSpec *config.Build, toolchainDir string) error { |
| if buildSpec.Toolchain != nil { |
| log.Println("Extracting toolchain package") |
| if err := extractTarball(buildSpec.Toolchain.ToolchainTarPath, toolchainDir); err != nil { |
| return err |
| } |
| if err := extractTarball(buildSpec.Toolchain.KernelHeadersPath, toolchainDir); err != nil { |
| return err |
| } |
| return nil |
| } |
| log.Println("No toolchain package provided; getting toolchain from input image") |
| envReader, err := envReaderFunc(ctx) |
| if err != nil { |
| return err |
| } |
| downloader := cos.NewGCSDownloader(gcsClient, envReader, "", "", "", "") |
| if err := cos.InstallCrossToolchain(ctx, downloader, toolchainDir); err != nil { |
| return err |
| } |
| // InstallCrossToolchain has an undesirable side effect of setting the SYSROOT |
| // environment variable. It's unclear if removing this behavior from |
| // InstallCrossToolchain is safe for all usages. |
| if err := os.Unsetenv("SYSROOT"); err != nil { |
| log.Printf("WARNING: unable to unset SYSROOT: %v", err) |
| } |
| return nil |
| } |
| |
// getArchFromToolchain infers the target architecture of the toolchain
// installed at toolchainDir from the names of the programs in its bin
// directory.
//
// It returns "x86_64" if anything matches bin/x86_64-cros-linux-gnu-*,
// otherwise "aarch64" if anything matches bin/aarch64-cros-linux-gnu-*. If
// neither pattern matches, or a glob pattern is malformed, an error is
// returned.
func getArchFromToolchain(toolchainDir string) (string, error) {
	// Order matters: x86_64 wins if programs for both arches are present.
	for _, arch := range []string{"x86_64", "aarch64"} {
		matches, err := filepath.Glob(filepath.Join(toolchainDir, "bin", arch+"-cros-linux-gnu-*"))
		if err != nil {
			return "", err
		}
		if len(matches) != 0 {
			return arch, nil
		}
	}
	return "", fmt.Errorf("could not identify arch from toolchain at %q; are programs missing from the toolchain?", toolchainDir)
}
| |
| func downloadRunfile(version, arch, md5Sum, outDir string) (_ string, err error) { |
| outName := fmt.Sprintf("NVIDIA-Linux-%s-%s.run", arch, version) |
| url := fmt.Sprintf("https://us.download.nvidia.com/tesla/%s/%s", version, outName) |
| log.Printf("Downloading runfile from %q", url) |
| outPath := filepath.Join(outDir, outName) |
| outFile, err := os.Create(outPath) |
| if err != nil { |
| return "", err |
| } |
| defer utils.CheckClose(outFile, "error closing downloaded runfile", &err) |
| resp, err := http.Get(url) |
| if err != nil { |
| return "", err |
| } |
| defer utils.CheckClose(resp.Body, "error closing runfile HTTP body", &err) |
| if resp.StatusCode != http.StatusOK { |
| return "", fmt.Errorf("unhealthy status downloading Nvidia runfile: %v", resp.Status) |
| } |
| if md5Sum != "" { |
| h := md5.New() |
| mw := io.MultiWriter(outFile, h) |
| if _, err := io.Copy(mw, resp.Body); err != nil { |
| return "", err |
| } |
| got := fmt.Sprintf("%x", h.Sum(nil)) |
| if got != md5Sum { |
| return "", fmt.Errorf("md5sum mismatch when downloading runfile from %q; got %q, want %q", url, got, md5Sum) |
| } |
| } else { |
| if _, err := io.Copy(outFile, resp.Body); err != nil { |
| return "", err |
| } |
| } |
| return outPath, nil |
| } |
| |
// listModules walks dir recursively and returns the base name of every entry
// whose path has the ".ko" extension, in the order filepath.WalkDir visits
// them. The first error reported during the walk aborts it and is returned.
func listModules(dir string) ([]string, error) {
	var modules []string
	visit := func(path string, _ fs.DirEntry, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}
		if filepath.Ext(path) != ".ko" {
			return nil
		}
		modules = append(modules, filepath.Base(path))
		return nil
	}
	if err := filepath.WalkDir(dir, visit); err != nil {
		return nil, err
	}
	return modules, nil
}
| |
// setDiff returns the distinct elements of a that do not appear in b.
//
// The result is deterministic: elements are returned in the order of their
// first occurrence in a. A nil slice is returned when there is nothing to
// report.
func setDiff(a, b []string) []string {
	// skip holds everything that must not be emitted: all of b, plus every
	// element of a that has already been emitted (to drop duplicates).
	skip := make(map[string]struct{}, len(b))
	for _, e := range b {
		skip[e] = struct{}{}
	}
	var result []string
	for _, e := range a {
		if _, seen := skip[e]; seen {
			continue
		}
		skip[e] = struct{}{}
		result = append(result, e)
	}
	return result
}
| |
| func mergeDriverPackages(generatedDriver *gpubuild.DriverPackage, preBuiltPath, outDir string) (*gpubuild.DriverPackage, error) { |
| log.Println("Merging any missing precompiled modules from precompiled package into custom package") |
| nvidiaGenerated, err := os.MkdirTemp("", "nvidia-generated") |
| if err != nil { |
| return nil, err |
| } |
| defer os.RemoveAll(nvidiaGenerated) |
| if err := extractTarball(generatedDriver.Path, nvidiaGenerated); err != nil { |
| return nil, err |
| } |
| nvidiaPreBuilt, err := os.MkdirTemp("", "prebuilt-generated") |
| if err != nil { |
| return nil, err |
| } |
| defer os.RemoveAll(nvidiaPreBuilt) |
| if err := extractTarball(preBuiltPath, nvidiaPreBuilt); err != nil { |
| return nil, err |
| } |
| generatedModules, err := listModules(nvidiaGenerated) |
| if err != nil { |
| return nil, err |
| } |
| log.Printf("Generated package includes: %v", generatedModules) |
| preBuiltModules, err := listModules(nvidiaPreBuilt) |
| if err != nil { |
| return nil, err |
| } |
| log.Printf("Prebuilt package includes: %v", preBuiltModules) |
| notGeneratedModules := setDiff(preBuiltModules, generatedModules) |
| log.Printf("Adding the following to generated package: %v", notGeneratedModules) |
| for _, module := range notGeneratedModules { |
| src := filepath.Join(nvidiaPreBuilt, "drivers", module) |
| dst := filepath.Join(nvidiaGenerated, "drivers", module) |
| if err := utils.CopyFile(src, dst); err != nil { |
| return nil, err |
| } |
| } |
| result := &gpubuild.DriverPackage{ |
| Path: filepath.Join(outDir, fmt.Sprintf("nvidia-drivers-%s.tgz", generatedDriver.Version)), |
| Version: generatedDriver.Version, |
| } |
| if err := cosfs.TarDir(nvidiaGenerated, result.Path); err != nil { |
| return nil, err |
| } |
| return result, nil |
| } |
| |
| // installNonRunfileModules searches for kernel modules in any existing |
| // precompiled packages that are not present in the runfile we just compiled. If |
| // it finds anything, it adds it to the generated package we just compiled. This |
| // helps with supporting gdrcopy. |
| func installNonRunfileModules(ctx context.Context, envReaderFunc envReaderCreator, gcsClient *storage.Client, gpuConfig *config.GPUConfig, generatedDriver *gpubuild.DriverPackage, outDir string) (*gpubuild.DriverPackage, error) { |
| if gpuConfig.PrePackagedDriversPath != "" { |
| return mergeDriverPackages(generatedDriver, gpuConfig.PrePackagedDriversPath, outDir) |
| } |
| log.Println("No precompiled driver package provided, looking for an existing driver package for the given VM image") |
| envReader, err := envReaderFunc(ctx) |
| if err != nil { |
| return nil, err |
| } |
| downloader := cos.NewGCSDownloader(gcsClient, envReader, "", "", "", "") |
| precompiled := fmt.Sprintf("nvidia-drivers-%s.tgz", generatedDriver.Version) |
| if exists, err := downloader.ArtifactExists(ctx, precompiled); err != nil || !exists { |
| if err != nil { |
| log.Printf("WARNING: error searching for non-runfile modules for version %q: %v", generatedDriver.Version, err) |
| } |
| log.Printf("WARNING: could not find non-runfile modules for version %q, will not install them", generatedDriver.Version) |
| return generatedDriver, nil |
| } |
| tmpDir, err := os.MkdirTemp("", "non-runfile-modules") |
| if err != nil { |
| return nil, err |
| } |
| defer os.RemoveAll(tmpDir) |
| if err := downloader.DownloadArtifact(ctx, tmpDir, precompiled); err != nil { |
| return nil, err |
| } |
| precompiledPath := filepath.Join(tmpDir, precompiled) |
| return mergeDriverPackages(generatedDriver, precompiledPath, outDir) |
| } |
| |
| func downloadImexDriver(ctx context.Context, envReaderFunc envReaderCreator, gcsClient *storage.Client, generatedDriver *gpubuild.DriverPackage, arch, outDir string) (string, error) { |
| if arch != "aarch64" { |
| log.Printf("Arch is %s, not searching for IMEX drivers", arch) |
| return "", nil |
| } |
| log.Println("Searching for IMEX driver to install") |
| envReader, err := envReaderFunc(ctx) |
| if err != nil { |
| return "", err |
| } |
| downloader := cos.NewGCSDownloader(gcsClient, envReader, "", "", "", "") |
| result, err := downloader.DownloadImexDriver(ctx, outDir, generatedDriver.Version) |
| if err != nil { |
| log.Printf("WARNING: could not download IMEX driver for %s", generatedDriver.Version) |
| return "", nil |
| } |
| return result, nil |
| } |
| |
| func updateVersionInfo(versionsProto []byte, driverVersion, outDir string) (string, error) { |
| driverVersionInfo, err := gpuconfig.ParseGPUDriverVersionInfoList(versionsProto) |
| if err != nil { |
| return "", err |
| } |
| for _, gpuVersionInfo := range driverVersionInfo.GetGpuDriverVersionInfo() { |
| found := false |
| for _, v := range gpuVersionInfo.SupportedDriverVersions { |
| if v.Version == driverVersion { |
| found = true |
| break |
| } |
| } |
| if !found { |
| gpuVersionInfo.SupportedDriverVersions = append(gpuVersionInfo.SupportedDriverVersions, &pb.DriverVersion{Version: driverVersion}) |
| } |
| } |
| resultBytes, err := proto.Marshal(driverVersionInfo) |
| if err != nil { |
| return "", err |
| } |
| result := filepath.Join(outDir, "gpu_driver_versions.bin") |
| if err := os.WriteFile(result, resultBytes, 0644); err != nil { |
| return "", err |
| } |
| return result, nil |
| } |
| |
| // installDriverVersionsProto adds the version we just compiled to the |
| // gpu_driver_versions.bin associated with the input image. For each GPU type, |
| // add the version to it, if the version is not present at all for that GPU |
| // type. |
| // |
| // Some older COS images do not have a gpu_driver_versions.bin, and those |
| // versions do not require one. Do not install anything in this case and return |
| // an empty string. |
| func installDriverVersionsProto(ctx context.Context, envReaderFunc envReaderCreator, gcsClient *storage.Client, gpuConfig *config.GPUConfig, driverVersion, outDir string) (string, error) { |
| if gpuConfig.VersionsProtoPath != "" { |
| log.Println("Using provided gpu_driver_versions.bin") |
| versionsProto, err := os.ReadFile(gpuConfig.VersionsProtoPath) |
| if err != nil { |
| return "", err |
| } |
| return updateVersionInfo(versionsProto, driverVersion, outDir) |
| } |
| log.Println("No provided gpu_driver_versions.bin, looking for the one associated with the input image") |
| envReader, err := envReaderFunc(ctx) |
| if err != nil { |
| return "", err |
| } |
| downloader := cos.NewGCSDownloader(gcsClient, envReader, "", "", "", "") |
| if exists, err := downloader.ArtifactExists(ctx, "gpu_driver_versions.bin"); err != nil || !exists { |
| if err != nil { |
| log.Printf("WARNING: error searching for gpu_driver_versions.bin for input COS image: %v", err) |
| } |
| log.Println("WARNING: could not find gpu_driver_versions.bin for input image, will not provide one to cos-gpu-installer") |
| return "", nil |
| } |
| tmpDir, err := os.MkdirTemp("", "image-driver-versions") |
| if err != nil { |
| return "", err |
| } |
| defer os.RemoveAll(tmpDir) |
| if err := downloader.DownloadArtifact(ctx, tmpDir, "gpu_driver_versions.bin"); err != nil { |
| return "", err |
| } |
| versionsProto, err := os.ReadFile(filepath.Join(tmpDir, "gpu_driver_versions.bin")) |
| if err != nil { |
| return "", err |
| } |
| return updateVersionInfo(versionsProto, driverVersion, outDir) |
| } |
| |
| func buildGPUPackage(ctx context.Context, input *config.Image, gcs *gcsManager, files *cosfs.Files, buildSpec *config.Build, workDir string) error { |
| log.Println("Building GPU package") |
| envReaderFunc := cachedEnvReader(input, gcs, files, buildSpec) |
| toolchainDir := filepath.Join(workDir, "toolchain") |
| if err := os.Mkdir(toolchainDir, 0755); err != nil { |
| return err |
| } |
| if err := installToolchain(ctx, envReaderFunc, gcs.gcsClient, buildSpec, toolchainDir); err != nil { |
| return fmt.Errorf("unable to install toolchain: %v", err) |
| } |
| arch, err := getArchFromToolchain(toolchainDir) |
| if err != nil { |
| return err |
| } |
| var runfile string |
| if filepath.Ext(buildSpec.GPUConfig.Version) == ".run" { |
| runfile = buildSpec.GPUConfig.Version |
| } else { |
| runfile, err = downloadRunfile(buildSpec.GPUConfig.Version, arch, buildSpec.GPUConfig.MD5Sum, workDir) |
| if err != nil { |
| return err |
| } |
| } |
| if err := os.Chmod(runfile, 0755); err != nil { |
| return err |
| } |
| driverPackage, err := gpubuild.Compile(runfile, toolchainDir, workDir) |
| if err != nil { |
| return err |
| } |
| mergedDir := filepath.Join(workDir, "merged") |
| if err := os.Mkdir(mergedDir, 0755); err != nil { |
| return err |
| } |
| mergedPackage, err := installNonRunfileModules(ctx, envReaderFunc, gcs.gcsClient, buildSpec.GPUConfig, driverPackage, mergedDir) |
| if err != nil { |
| return err |
| } |
| imexPath, err := downloadImexDriver(ctx, envReaderFunc, gcs.gcsClient, driverPackage, arch, mergedDir) |
| if err != nil { |
| return err |
| } |
| generatedProtoPath, err := installDriverVersionsProto(ctx, envReaderFunc, gcs.gcsClient, buildSpec.GPUConfig, driverPackage.Version, mergedDir) |
| if err != nil { |
| return err |
| } |
| buildSpec.GCSFiles = append(buildSpec.GCSFiles, runfile, mergedPackage.Path) |
| if imexPath != "" { |
| buildSpec.GCSFiles = append(buildSpec.GCSFiles, imexPath) |
| } |
| if generatedProtoPath != "" { |
| buildSpec.GCSFiles = append(buildSpec.GCSFiles, generatedProtoPath) |
| } |
| buildSpec.GPUConfig.Version = driverPackage.Version |
| return nil |
| } |