| // Copyright 2026 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| // Package main implements a tool to compile Nvidia GPU drivers for a specific COS version. |
| package main |
| |
import (
	"context"
	"flag"
	"fmt"
	"log"
	"os"
	"os/exec"
	"path"
	"path/filepath"

	"cloud.google.com/go/storage"
	"google.golang.org/protobuf/proto"

	"cos.googlesource.com/cos/tools.git/src/pkg/cos"
	"cos.googlesource.com/cos/tools.git/src/pkg/gpubuild"
	"cos.googlesource.com/cos/tools.git/src/pkg/gpuconfig"
	"cos.googlesource.com/cos/tools.git/src/pkg/gpuconfig/pb"
	"cos.googlesource.com/cos/tools.git/src/pkg/utils"
)
| |
// Command-line flags. main rejects the invocation unless -runfile,
// -cos-version and -cos-board are all non-empty; -out-dir falls back to the
// current directory.
var (
	// cosBoard is the COS board name; joined with cosVersion to form the
	// download prefix.
	cosBoard = flag.String("cos-board", "", "COS board (e.g., lakitu).")

	// cosVersion is the COS build number the driver is compiled for.
	cosVersion = flag.String("cos-version", "", "COS version (build number, e.g., 19506.120.64).")

	// outDir is the directory that receives the build outputs.
	outDir = flag.String("out-dir", ".", "Output directory.")

	// runfile is the local path of the Nvidia driver runfile to compile.
	runfile = flag.String("runfile", "", "Path to Nvidia driver runfile (local).")
)
| |
| func main() { |
| flag.Parse() |
| |
| if *runfile == "" || *cosVersion == "" || *cosBoard == "" { |
| log.Fatal("Missing required flags: -runfile, -cos-version, -cos-board") |
| } |
| |
| ctx := context.Background() |
| gcsClient, err := storage.NewClient(ctx) |
| if err != nil { |
| log.Fatalf("Failed to create GCS client: %v", err) |
| } |
| defer gcsClient.Close() |
| |
| if _, err := os.Stat(*runfile); err != nil { |
| log.Fatalf("Runfile not found: %v", err) |
| } |
| absRunfile, err := filepath.Abs(*runfile) |
| if err != nil { |
| log.Fatalf("Failed to get absolute path for runfile: %v", err) |
| } |
| if err := os.Chmod(absRunfile, 0755); err != nil { |
| log.Fatalf("Failed to make runfile executable: %v", err) |
| } |
| |
| driverVersion, driverArch, err := gpubuild.GetRunfileInfo(absRunfile) |
| if err != nil { |
| log.Fatalf("Failed to get runfile info: %v", err) |
| } |
| log.Printf("Driver Version: %s, Arch: %s", driverVersion, driverArch) |
| |
| workDir, err := os.MkdirTemp("", "gpu-compiler-work") |
| if err != nil { |
| log.Fatalf("Failed to create work dir: %v", err) |
| } |
| defer os.RemoveAll(workDir) |
| |
| toolchainTar := filepath.Join(workDir, "toolchain.tar.xz") |
| headersTar := filepath.Join(workDir, "kernel-headers.tgz") |
| |
| prefix := path.Join(*cosVersion, *cosBoard) |
| downloader := cos.NewGCSDownloader(gcsClient, nil, "", prefix, "", "", false) |
| |
| log.Println("Downloading toolchain...") |
| if err := downloader.DownloadToolchain(ctx, workDir); err != nil { |
| log.Fatalf("Failed to download toolchain: %v", err) |
| } |
| |
| log.Println("Downloading kernel headers...") |
| if err := downloader.DownloadKernelHeaders(ctx, workDir); err != nil { |
| log.Fatalf("Failed to download kernel headers: %v", err) |
| } |
| |
| toolchainDir := filepath.Join(workDir, "toolchain") |
| if err := os.Mkdir(toolchainDir, 0755); err != nil { |
| log.Fatalf("Failed to create toolchain dir: %v", err) |
| } |
| |
| log.Println("Extracting toolchain...") |
| if err := extractTar(toolchainTar, toolchainDir); err != nil { |
| log.Fatalf("Failed to extract toolchain: %v", err) |
| } |
| |
| log.Println("Extracting kernel headers...") |
| if err := extractTar(headersTar, toolchainDir); err != nil { |
| log.Fatalf("Failed to extract kernel headers: %v", err) |
| } |
| |
| log.Println("Compiling driver...") |
| |
| if err := os.MkdirAll(*outDir, 0755); err != nil { |
| log.Fatalf("Failed to create output directory: %v", err) |
| } |
| absOutDir, err := filepath.Abs(*outDir) |
| if err != nil { |
| log.Fatalf("Failed to get absolute path for out-dir: %v", err) |
| } |
| |
| compiledPkg, err := gpubuild.Compile(absRunfile, toolchainDir, absOutDir) |
| if err != nil { |
| log.Fatalf("Compilation failed: %v", err) |
| } |
| log.Printf("Compiled package: %s", compiledPkg.Path) |
| |
| log.Println("Updating gpu_driver_versions.bin...") |
| gpuVersionsBin := filepath.Join(workDir, "gpu_driver_versions.bin") |
| |
| log.Printf("Attempting to download gpu_driver_versions.bin with prefix %s", prefix) |
| if err := downloader.DownloadArtifact(ctx, workDir, "gpu_driver_versions.bin"); err != nil { |
| log.Fatalf("ERROR: gpu_driver_versions.bin not found for version %s (board %s): %v", *cosVersion, *cosBoard, err) |
| } |
| |
| if err := updateGPUVersionsBin(gpuVersionsBin, driverVersion, filepath.Join(absOutDir, "gpu_driver_versions.bin")); err != nil { |
| log.Fatalf("Failed to update gpu_driver_versions.bin: %v", err) |
| } |
| |
| destRunfile := filepath.Join(absOutDir, filepath.Base(absRunfile)) |
| log.Printf("Copying runfile to %s", destRunfile) |
| if err := utils.CopyFile(absRunfile, destRunfile); err != nil { |
| log.Fatalf("Failed to copy runfile to output directory: %v", err) |
| } |
| |
| log.Println("Successfully completed!") |
| } |
| |
// extractTar unpacks the archive at tarPath into destDir by running the
// external command "tar xf <tarPath> -C <destDir>". The command's stdout and
// stderr are forwarded to this process's own streams. It returns the error
// from starting the command or, once started, from waiting for it to finish.
//
// Shelling out is a deliberate shortcut: it is simpler than implementing tar
// extraction in Go, and the tool is only expected to run on Linux.
func extractTar(tarPath, destDir string) error {
	args := []string{"xf", tarPath, "-C", destDir}

	tar := exec.Command("tar", args...)
	tar.Stdout, tar.Stderr = os.Stdout, os.Stderr

	if err := tar.Start(); err != nil {
		return err
	}
	return tar.Wait()
}
| |
| func updateGPUVersionsBin(binPath, driverVersion, outPath string) error { |
| content, err := os.ReadFile(binPath) |
| if err != nil { |
| return err |
| } |
| |
| driverVersionInfo, err := gpuconfig.ParseGPUDriverVersionInfoList(content) |
| if err != nil { |
| return err |
| } |
| |
| // Update logic: append version to all GPUDriverVersionInfo if not present |
| updated := false |
| for _, gpuVersionInfo := range driverVersionInfo.GetGpuDriverVersionInfo() { |
| found := false |
| for _, v := range gpuVersionInfo.SupportedDriverVersions { |
| if v.Version == driverVersion { |
| found = true |
| break |
| } |
| } |
| if !found { |
| gpuVersionInfo.SupportedDriverVersions = append(gpuVersionInfo.SupportedDriverVersions, &pb.DriverVersion{Version: driverVersion}) |
| updated = true |
| } |
| } |
| |
| if !updated { |
| log.Printf("Version %s already present in gpu_driver_versions.bin, no update needed", driverVersion) |
| } else { |
| log.Printf("Added version %s to gpu_driver_versions.bin", driverVersion) |
| } |
| |
| resultBytes, err := proto.Marshal(driverVersionInfo) |
| if err != nil { |
| return err |
| } |
| |
| return os.WriteFile(outPath, resultBytes, 0644) |
| } |