// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package preloader
import (
"context"
"crypto/md5"
"fmt"
"io"
"io/fs"
"log"
"net/http"
"os"
"os/exec"
"path/filepath"
"cos.googlesource.com/cos/tools.git/src/pkg/config"
"cos.googlesource.com/cos/tools.git/src/pkg/cos"
cosfs "cos.googlesource.com/cos/tools.git/src/pkg/fs"
"cos.googlesource.com/cos/tools.git/src/pkg/gpubuild"
"cos.googlesource.com/cos/tools.git/src/pkg/gpuconfig"
"cos.googlesource.com/cos/tools.git/src/pkg/gpuconfig/pb"
"cos.googlesource.com/cos/tools.git/src/pkg/utils"
"cloud.google.com/go/storage"
"google.golang.org/protobuf/proto"
)
// extractTarball unpacks the archive at srcPath into destDir by running the
// system "tar" program ("tar xf srcPath -C destDir"). The child's stdout and
// stderr are forwarded to this process's stdout and stderr. The error from
// running the command, if any, is returned as-is.
func extractTarball(srcPath, destDir string) error {
	untar := exec.Command("tar", "xf", srcPath, "-C", destDir)
	untar.Stdout, untar.Stderr = os.Stdout, os.Stderr
	return untar.Run()
}
// envReaderCreator is a function that, given a context, returns a
// *cos.EnvReader or an error explaining why one could not be obtained.
type envReaderCreator func(context.Context) (*cos.EnvReader, error)
// cachedEnvReader returns a closure that creates an EnvReader from a VMImage.
// Since creating an EnvReader is very expensive in this context (it boots a VM
// from the image to read data from it), we use a closure to cache the resulting
// envReader.
//
// We could create the EnvReader once and pass it around everywhere. However,
// under specific conditions, the EnvReader won't need to be created at all,
// so this closure allows us to lazily create the EnvReader when it is needed.
//
// Only a successfully created EnvReader is cached; if creation fails, the
// error is returned (wrapped, so callers can inspect it with errors.Is/As) and
// the next call tries again. The closure does no locking of its own. The ctx
// argument is accepted to satisfy envReaderCreator but is not forwarded to
// cos.NewEnvReaderFromVMImage.
func cachedEnvReader(input *config.Image, gcs *gcsManager, files *cosfs.Files, buildSpec *config.Build) envReaderCreator {
	var cached *cos.EnvReader
	return func(ctx context.Context) (*cos.EnvReader, error) {
		if cached != nil {
			return cached, nil
		}
		reader, err := cos.NewEnvReaderFromVMImage(input.Name, input.Project, cos.EnvReaderVMImageConfig{
			DaisyBin:       files.DaisyBin,
			DaisyWorkflow:  "/data/read_cos_env.wf.json",
			Project:        buildSpec.Project,
			Zone:           buildSpec.Zone,
			GCSDir:         gcs.managedDirURL(),
			DiskType:       buildSpec.DiskType,
			MachineType:    buildSpec.MachineType,
			Network:        buildSpec.Network,
			SubNet:         buildSpec.Subnet,
			ServiceAccount: buildSpec.ServiceAccount,
		})
		if err != nil {
			return nil, fmt.Errorf("failed to read env from VM image: %w", err)
		}
		// Publish to the cache only after the error check so that a failed
		// attempt can never be handed out by a later call.
		cached = reader
		return cached, nil
	}
}
// installToolchain populates toolchainDir with a cross toolchain.
//
// When buildSpec.Toolchain is set, its toolchain tarball and kernel headers
// tarball are extracted (in that order) into toolchainDir and nothing else is
// done. Otherwise an EnvReader is obtained from envReaderFunc and the
// toolchain is installed through cos.InstallCrossToolchain using a GCS
// downloader built from gcsClient. The first error encountered is returned
// unchanged.
func installToolchain(ctx context.Context, envReaderFunc envReaderCreator, gcsClient *storage.Client, buildSpec *config.Build, toolchainDir string) error {
	if tc := buildSpec.Toolchain; tc != nil {
		log.Println("Extracting toolchain package")
		for _, tarball := range []string{tc.ToolchainTarPath, tc.KernelHeadersPath} {
			if err := extractTarball(tarball, toolchainDir); err != nil {
				return err
			}
		}
		return nil
	}

	log.Println("No toolchain package provided; getting toolchain from input image")
	reader, err := envReaderFunc(ctx)
	if err != nil {
		return err
	}
	gcsDownloader := cos.NewGCSDownloader(gcsClient, reader, "", "", "", "")
	if err := cos.InstallCrossToolchain(ctx, gcsDownloader, toolchainDir); err != nil {
		return err
	}

	// InstallCrossToolchain has an undesirable side effect of setting the SYSROOT
	// environment variable. It's unclear if removing this behavior from
	// InstallCrossToolchain is safe for all usages, so it is undone here on a
	// best-effort basis: a failure is logged, not returned.
	if unsetErr := os.Unsetenv("SYSROOT"); unsetErr != nil {
		log.Printf("WARNING: unable to unset SYSROOT: %v", unsetErr)
	}
	return nil
}
// getArchFromToolchain infers the target architecture from the program names
// found in toolchainDir/bin. It returns "x86_64" if any entry matches
// "x86_64-cros-linux-gnu-*", otherwise "aarch64" if any entry matches
// "aarch64-cros-linux-gnu-*". If neither pattern matches, or globbing fails,
// an error is returned; the "no match" error names the toolchain directory
// that was inspected.
func getArchFromToolchain(toolchainDir string) (string, error) {
	// The order matters: x86_64 is probed first and wins if both are present.
	for _, arch := range []string{"x86_64", "aarch64"} {
		pattern := filepath.Join(toolchainDir, "bin", arch+"-cros-linux-gnu-*")
		matches, err := filepath.Glob(pattern)
		if err != nil {
			return "", err
		}
		if len(matches) != 0 {
			return arch, nil
		}
	}
	return "", fmt.Errorf("could not identify arch from toolchain at %q; are programs missing from the toolchain?", toolchainDir)
}
// downloadRunfile downloads the Nvidia driver runfile for the given version
// and arch from us.download.nvidia.com into outDir and returns the path of the
// downloaded file ("NVIDIA-Linux-<arch>-<version>.run" inside outDir).
//
// If md5Sum is non-empty, the MD5 of the downloaded bytes (lowercase hex) must
// equal it, otherwise an error is returned. A non-200 HTTP status is also an
// error.
//
// The output file is only created once the server has answered with a healthy
// status, and it is removed again if the download, the checksum verification
// or closing the file fails, so a failed call does not leave a truncated or
// unverified runfile behind.
func downloadRunfile(version, arch, md5Sum, outDir string) (_ string, err error) {
	outName := fmt.Sprintf("NVIDIA-Linux-%s-%s.run", arch, version)
	url := fmt.Sprintf("https://us.download.nvidia.com/tesla/%s/%s", version, outName)
	log.Printf("Downloading runfile from %q", url)

	resp, err := http.Get(url)
	if err != nil {
		return "", err
	}
	// NOTE(review): utils.CheckClose presumably reports a Close failure
	// through the named result err — confirm against its definition.
	defer utils.CheckClose(resp.Body, "error closing runfile HTTP body", &err)
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("unhealthy status downloading Nvidia runfile: %v", resp.Status)
	}

	outPath := filepath.Join(outDir, outName)
	outFile, err := os.Create(outPath)
	if err != nil {
		return "", err
	}
	// Registered before the Close defer so that it runs after it (defers run
	// last-in, first-out): the file is closed first, then removed on failure.
	defer func() {
		if err != nil {
			// Best-effort cleanup; the original error is the one worth reporting.
			_ = os.Remove(outPath)
		}
	}()
	defer utils.CheckClose(outFile, "error closing downloaded runfile", &err)

	if md5Sum == "" {
		if _, err := io.Copy(outFile, resp.Body); err != nil {
			return "", err
		}
		return outPath, nil
	}

	// Hash the stream while it is written to disk so the body is read once.
	h := md5.New()
	if _, err := io.Copy(io.MultiWriter(outFile, h), resp.Body); err != nil {
		return "", err
	}
	if got := fmt.Sprintf("%x", h.Sum(nil)); got != md5Sum {
		return "", fmt.Errorf("md5sum mismatch when downloading runfile from %q; got %q, want %q", url, got, md5Sum)
	}
	return outPath, nil
}
// listModules walks the tree rooted at dir and returns the base name of every
// entry whose name has the ".ko" extension, in the order filepath.WalkDir
// visits them. Entries are matched by name only. Any error reported during the
// walk aborts it and is returned with a nil slice.
func listModules(dir string) ([]string, error) {
	var modules []string
	collect := func(path string, _ fs.DirEntry, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}
		if filepath.Ext(path) != ".ko" {
			return nil
		}
		modules = append(modules, filepath.Base(path))
		return nil
	}
	if err := filepath.WalkDir(dir, collect); err != nil {
		return nil, err
	}
	return modules, nil
}
// setDiff returns the elements of a that do not appear in b. Each element is
// reported at most once, in the order of its first occurrence in a, so the
// result is deterministic. The result is nil when there is nothing to report.
func setDiff(a, b []string) []string {
	// seen starts as the set of b and then also records what has already been
	// emitted, which both excludes b's elements and de-duplicates a.
	seen := make(map[string]struct{}, len(b))
	for _, e := range b {
		seen[e] = struct{}{}
	}
	var result []string
	for _, e := range a {
		if _, skip := seen[e]; skip {
			continue
		}
		seen[e] = struct{}{}
		result = append(result, e)
	}
	return result
}
// mergeDriverPackages builds a new driver package containing everything in
// generatedDriver plus any ".ko" modules that exist in the prebuilt package at
// preBuiltPath but not in generatedDriver.
//
// Both tarballs are extracted into temporary directories (removed on return).
// Modules are compared by base name; each missing module is copied from the
// prebuilt tree's "drivers" directory into the generated tree's "drivers"
// directory. The merged tree is then archived to
// outDir/nvidia-drivers-<version>.tgz, and a DriverPackage describing that
// archive (with generatedDriver's version) is returned.
func mergeDriverPackages(generatedDriver *gpubuild.DriverPackage, preBuiltPath, outDir string) (*gpubuild.DriverPackage, error) {
	log.Println("Merging any missing precompiled modules from precompiled package into custom package")

	// Unpack the package we compiled ourselves.
	generatedDir, err := os.MkdirTemp("", "nvidia-generated")
	if err != nil {
		return nil, err
	}
	defer os.RemoveAll(generatedDir)
	if err = extractTarball(generatedDriver.Path, generatedDir); err != nil {
		return nil, err
	}

	// Unpack the prebuilt package next to it.
	preBuiltDir, err := os.MkdirTemp("", "prebuilt-generated")
	if err != nil {
		return nil, err
	}
	defer os.RemoveAll(preBuiltDir)
	if err = extractTarball(preBuiltPath, preBuiltDir); err != nil {
		return nil, err
	}

	have, err := listModules(generatedDir)
	if err != nil {
		return nil, err
	}
	log.Printf("Generated package includes: %v", have)

	available, err := listModules(preBuiltDir)
	if err != nil {
		return nil, err
	}
	log.Printf("Prebuilt package includes: %v", available)

	// Whatever the prebuilt package offers that we did not build gets copied over.
	missing := setDiff(available, have)
	log.Printf("Adding the following to generated package: %v", missing)
	for _, name := range missing {
		from := filepath.Join(preBuiltDir, "drivers", name)
		to := filepath.Join(generatedDir, "drivers", name)
		if err := utils.CopyFile(from, to); err != nil {
			return nil, err
		}
	}

	mergedPath := filepath.Join(outDir, fmt.Sprintf("nvidia-drivers-%s.tgz", generatedDriver.Version))
	if err := cosfs.TarDir(generatedDir, mergedPath); err != nil {
		return nil, err
	}
	return &gpubuild.DriverPackage{
		Path:    mergedPath,
		Version: generatedDriver.Version,
	}, nil
}
// installNonRunfileModules looks for kernel modules that ship in an existing
// precompiled driver package but are absent from the package we just compiled
// from the runfile, and folds them into the compiled package. This helps with
// supporting gdrcopy.
//
// If gpuConfig.PrePackagedDriversPath is set, that package is used as the
// source of extra modules. Otherwise the function asks a GCS downloader (built
// from the lazily created EnvReader) for "nvidia-drivers-<version>.tgz". When
// that artifact cannot be found, or the lookup itself fails, a warning is
// logged and generatedDriver is returned unchanged with a nil error.
func installNonRunfileModules(ctx context.Context, envReaderFunc envReaderCreator, gcsClient *storage.Client, gpuConfig *config.GPUConfig, generatedDriver *gpubuild.DriverPackage, outDir string) (*gpubuild.DriverPackage, error) {
	if prePackaged := gpuConfig.PrePackagedDriversPath; prePackaged != "" {
		return mergeDriverPackages(generatedDriver, prePackaged, outDir)
	}

	log.Println("No precompiled driver package provided, looking for an existing driver package for the given VM image")
	reader, err := envReaderFunc(ctx)
	if err != nil {
		return nil, err
	}
	gcsDownloader := cos.NewGCSDownloader(gcsClient, reader, "", "", "", "")
	artifact := fmt.Sprintf("nvidia-drivers-%s.tgz", generatedDriver.Version)

	found, lookupErr := gcsDownloader.ArtifactExists(ctx, artifact)
	if lookupErr != nil {
		log.Printf("WARNING: error searching for non-runfile modules for version %q: %v", generatedDriver.Version, lookupErr)
	}
	if lookupErr != nil || !found {
		// Missing extra modules are not fatal: carry on with what we compiled.
		log.Printf("WARNING: could not find non-runfile modules for version %q, will not install them", generatedDriver.Version)
		return generatedDriver, nil
	}

	scratch, err := os.MkdirTemp("", "non-runfile-modules")
	if err != nil {
		return nil, err
	}
	defer os.RemoveAll(scratch)
	if err := gcsDownloader.DownloadArtifact(ctx, scratch, artifact); err != nil {
		return nil, err
	}
	return mergeDriverPackages(generatedDriver, filepath.Join(scratch, artifact), outDir)
}
// downloadImexDriver downloads the IMEX driver matching generatedDriver's
// version into outDir and returns what the downloader reports for it.
//
// The search only happens when arch is "aarch64"; for any other arch it
// returns "" and a nil error without creating an EnvReader. A failed download
// is treated as "no IMEX driver available": the failure, including its cause,
// is logged as a warning and "" is returned with a nil error. Only a failure
// to obtain the EnvReader is returned as an error.
func downloadImexDriver(ctx context.Context, envReaderFunc envReaderCreator, gcsClient *storage.Client, generatedDriver *gpubuild.DriverPackage, arch, outDir string) (string, error) {
	if arch != "aarch64" {
		log.Printf("Arch is %s, not searching for IMEX drivers", arch)
		return "", nil
	}
	log.Println("Searching for IMEX driver to install")
	envReader, err := envReaderFunc(ctx)
	if err != nil {
		return "", err
	}
	downloader := cos.NewGCSDownloader(gcsClient, envReader, "", "", "", "")
	result, err := downloader.DownloadImexDriver(ctx, outDir, generatedDriver.Version)
	if err != nil {
		// Best effort: keep going without IMEX, but record why it was skipped
		// so that a genuine failure is distinguishable from "not published".
		log.Printf("WARNING: could not download IMEX driver for %s: %v", generatedDriver.Version, err)
		return "", nil
	}
	return result, nil
}
// updateVersionInfo parses versionsProto as a GPU driver version info list,
// makes sure every GPU entry in it lists driverVersion among its supported
// driver versions (appending it where absent), and writes the re-serialized
// list to outDir/gpu_driver_versions.bin with mode 0644. It returns the path
// of the written file.
func updateVersionInfo(versionsProto []byte, driverVersion, outDir string) (string, error) {
	infoList, err := gpuconfig.ParseGPUDriverVersionInfoList(versionsProto)
	if err != nil {
		return "", err
	}

nextGPU:
	for _, gpuInfo := range infoList.GetGpuDriverVersionInfo() {
		for _, supported := range gpuInfo.SupportedDriverVersions {
			if supported.Version == driverVersion {
				// Already listed for this GPU type; leave the entry untouched.
				continue nextGPU
			}
		}
		gpuInfo.SupportedDriverVersions = append(gpuInfo.SupportedDriverVersions, &pb.DriverVersion{Version: driverVersion})
	}

	serialized, err := proto.Marshal(infoList)
	if err != nil {
		return "", err
	}
	outPath := filepath.Join(outDir, "gpu_driver_versions.bin")
	if err := os.WriteFile(outPath, serialized, 0644); err != nil {
		return "", err
	}
	return outPath, nil
}
// installDriverVersionsProto produces a gpu_driver_versions.bin in outDir that
// includes the driver version we just compiled, and returns its path. For each
// GPU type in the list, the version is added only if that GPU type does not
// already list it.
//
// The starting list is gpuConfig.VersionsProtoPath when that is set; otherwise
// it is the gpu_driver_versions.bin artifact associated with the input image,
// fetched through a GCS downloader.
//
// Some older COS images do not have a gpu_driver_versions.bin, and those
// versions do not require one. When the artifact is not found (or the lookup
// fails), a warning is logged, nothing is installed, and an empty string is
// returned with a nil error.
func installDriverVersionsProto(ctx context.Context, envReaderFunc envReaderCreator, gcsClient *storage.Client, gpuConfig *config.GPUConfig, driverVersion, outDir string) (string, error) {
	const protoName = "gpu_driver_versions.bin"

	if provided := gpuConfig.VersionsProtoPath; provided != "" {
		log.Println("Using provided gpu_driver_versions.bin")
		contents, err := os.ReadFile(provided)
		if err != nil {
			return "", err
		}
		return updateVersionInfo(contents, driverVersion, outDir)
	}

	log.Println("No provided gpu_driver_versions.bin, looking for the one associated with the input image")
	reader, err := envReaderFunc(ctx)
	if err != nil {
		return "", err
	}
	gcsDownloader := cos.NewGCSDownloader(gcsClient, reader, "", "", "", "")

	found, lookupErr := gcsDownloader.ArtifactExists(ctx, protoName)
	if lookupErr != nil {
		log.Printf("WARNING: error searching for gpu_driver_versions.bin for input COS image: %v", lookupErr)
	}
	if lookupErr != nil || !found {
		log.Println("WARNING: could not find gpu_driver_versions.bin for input image, will not provide one to cos-gpu-installer")
		return "", nil
	}

	scratch, err := os.MkdirTemp("", "image-driver-versions")
	if err != nil {
		return "", err
	}
	defer os.RemoveAll(scratch)
	if err := gcsDownloader.DownloadArtifact(ctx, scratch, protoName); err != nil {
		return "", err
	}
	contents, err := os.ReadFile(filepath.Join(scratch, protoName))
	if err != nil {
		return "", err
	}
	return updateVersionInfo(contents, driverVersion, outDir)
}
// buildGPUPackage compiles an Nvidia driver package for the input image and
// registers the resulting files on buildSpec.
//
// Steps, in order:
//  1. Create workDir/toolchain and install the cross toolchain into it.
//  2. Detect the target arch from the toolchain.
//  3. Pick the runfile: if buildSpec.GPUConfig.Version ends in ".run" it is
//     used as a path to a local runfile, otherwise it is treated as a version
//     and the runfile is downloaded into workDir. The runfile is made
//     executable.
//  4. Compile the driver package with gpubuild.Compile.
//  5. Create workDir/merged and place there: the compiled package merged with
//     any non-runfile modules, the IMEX driver (if any), and the updated
//     gpu_driver_versions.bin (if any).
//
// On success the runfile, the merged package and the optional IMEX driver and
// versions proto are appended to buildSpec.GCSFiles, and
// buildSpec.GPUConfig.Version is overwritten with the compiled package's
// version. The EnvReader needed by several steps is created lazily and shared
// between them.
func buildGPUPackage(ctx context.Context, input *config.Image, gcs *gcsManager, files *cosfs.Files, buildSpec *config.Build, workDir string) error {
	log.Println("Building GPU package")
	getEnvReader := cachedEnvReader(input, gcs, files, buildSpec)

	// Toolchain and target architecture.
	toolchainDir := filepath.Join(workDir, "toolchain")
	if err := os.Mkdir(toolchainDir, 0755); err != nil {
		return err
	}
	if err := installToolchain(ctx, getEnvReader, gcs.gcsClient, buildSpec, toolchainDir); err != nil {
		return fmt.Errorf("unable to install toolchain: %v", err)
	}
	arch, err := getArchFromToolchain(toolchainDir)
	if err != nil {
		return err
	}

	// The configured "version" doubles as a path to a local runfile when it
	// carries the .run extension; anything else is downloaded.
	runfile := buildSpec.GPUConfig.Version
	if filepath.Ext(runfile) != ".run" {
		runfile, err = downloadRunfile(buildSpec.GPUConfig.Version, arch, buildSpec.GPUConfig.MD5Sum, workDir)
		if err != nil {
			return err
		}
	}
	if err := os.Chmod(runfile, 0755); err != nil {
		return err
	}

	compiled, err := gpubuild.Compile(runfile, toolchainDir, workDir)
	if err != nil {
		return err
	}

	// Everything that accompanies the compiled package lands in workDir/merged.
	mergedDir := filepath.Join(workDir, "merged")
	if err := os.Mkdir(mergedDir, 0755); err != nil {
		return err
	}
	merged, err := installNonRunfileModules(ctx, getEnvReader, gcs.gcsClient, buildSpec.GPUConfig, compiled, mergedDir)
	if err != nil {
		return err
	}
	imexPath, err := downloadImexDriver(ctx, getEnvReader, gcs.gcsClient, compiled, arch, mergedDir)
	if err != nil {
		return err
	}
	versionsProtoPath, err := installDriverVersionsProto(ctx, getEnvReader, gcs.gcsClient, buildSpec.GPUConfig, compiled.Version, mergedDir)
	if err != nil {
		return err
	}

	// The runfile and merged package are always registered; the IMEX driver
	// and versions proto only when they were actually produced.
	outputs := []string{runfile, merged.Path}
	for _, optional := range []string{imexPath, versionsProtoPath} {
		if optional != "" {
			outputs = append(outputs, optional)
		}
	}
	buildSpec.GCSFiles = append(buildSpec.GCSFiles, outputs...)
	buildSpec.GPUConfig.Version = compiled.Version
	return nil
}