blob: 412bfd2a062d88e3126e6c1a659d0bb75d6387e8 [file] [edit]
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package main implements a tool to compile Nvidia GPU drivers for a specific COS version.
package main
import (
"context"
"flag"
"log"
"os"
"os/exec"
"path"
"path/filepath"
"cloud.google.com/go/storage"
"google.golang.org/protobuf/proto"
"cos.googlesource.com/cos/tools.git/src/pkg/cos"
"cos.googlesource.com/cos/tools.git/src/pkg/gpubuild"
"cos.googlesource.com/cos/tools.git/src/pkg/gpuconfig"
"cos.googlesource.com/cos/tools.git/src/pkg/gpuconfig/pb"
"cos.googlesource.com/cos/tools.git/src/pkg/utils"
)
// runfile is the local path to the Nvidia driver runfile to compile (required).
var runfile = flag.String("runfile", "", "Path to Nvidia driver runfile (local).")

// cosVersion is the COS build number to compile against (required).
var cosVersion = flag.String("cos-version", "", "COS version (build number, e.g., 19506.120.64).")

// cosBoard is the COS board name to compile against (required).
var cosBoard = flag.String("cos-board", "", "COS board (e.g., lakitu).")

// outDir receives the build outputs; it defaults to the current directory.
var outDir = flag.String("out-dir", ".", "Output directory.")
// main validates the required command-line flags and delegates the real work
// to runGPUCompiler. The process exits with runGPUCompiler's status only after
// it has returned, so the deferred cleanups registered there (closing the GCS
// client, deleting the temporary work directory) run on both success and
// failure. Calling log.Fatal* from inside that work would exit immediately
// and skip them.
func main() {
	flag.Parse()
	if *runfile == "" || *cosVersion == "" || *cosBoard == "" {
		// Nothing has been deferred yet, so exiting directly is safe here.
		log.Fatal("Missing required flags: -runfile, -cos-version, -cos-board")
	}
	os.Exit(runGPUCompiler())
}

// runGPUCompiler performs the build selected by the command-line flags:
//   - chmods the -runfile to 0755 (in place) and reads its driver version and
//     architecture with gpubuild.GetRunfileInfo;
//   - downloads the toolchain and kernel headers for <cos-version>/<cos-board>
//     into a temporary work directory and extracts both into one toolchain
//     directory;
//   - compiles the driver into -out-dir with gpubuild.Compile;
//   - writes an updated gpu_driver_versions.bin and a copy of the runfile
//     into -out-dir.
//
// It returns the process exit status: 0 on success, or 1 after logging the
// first failure. Returning (instead of log.Fatal) lets the defers run.
func runGPUCompiler() int {
	// fail logs a failure message and yields the non-zero exit status, so each
	// error path is a single `return fail(...)`.
	fail := func(format string, args ...interface{}) int {
		log.Printf(format, args...)
		return 1
	}

	ctx := context.Background()
	gcsClient, err := storage.NewClient(ctx)
	if err != nil {
		return fail("Failed to create GCS client: %v", err)
	}
	defer gcsClient.Close()

	if _, err := os.Stat(*runfile); err != nil {
		return fail("Runfile not found: %v", err)
	}
	absRunfile, err := filepath.Abs(*runfile)
	if err != nil {
		return fail("Failed to get absolute path for runfile: %v", err)
	}
	// NOTE: this changes the mode of the caller's file itself, not a copy.
	if err := os.Chmod(absRunfile, 0755); err != nil {
		return fail("Failed to make runfile executable: %v", err)
	}
	driverVersion, driverArch, err := gpubuild.GetRunfileInfo(absRunfile)
	if err != nil {
		return fail("Failed to get runfile info: %v", err)
	}
	log.Printf("Driver Version: %s, Arch: %s", driverVersion, driverArch)

	workDir, err := os.MkdirTemp("", "gpu-compiler-work")
	if err != nil {
		return fail("Failed to create work dir: %v", err)
	}
	defer os.RemoveAll(workDir)

	// NOTE(review): these names must match what the cos downloader writes
	// into workDir; confirm against its implementation.
	toolchainTar := filepath.Join(workDir, "toolchain.tar.xz")
	headersTar := filepath.Join(workDir, "kernel-headers.tgz")

	// Slash-separated (path, not filepath) because this is presumably a
	// remote object-name prefix rather than a local filesystem path.
	prefix := path.Join(*cosVersion, *cosBoard)
	downloader := cos.NewGCSDownloader(gcsClient, nil, "", prefix, "", "", false)
	log.Println("Downloading toolchain...")
	if err := downloader.DownloadToolchain(ctx, workDir); err != nil {
		return fail("Failed to download toolchain: %v", err)
	}
	log.Println("Downloading kernel headers...")
	if err := downloader.DownloadKernelHeaders(ctx, workDir); err != nil {
		return fail("Failed to download kernel headers: %v", err)
	}

	// Toolchain and kernel headers are extracted into the same directory,
	// which is what gpubuild.Compile receives below.
	toolchainDir := filepath.Join(workDir, "toolchain")
	if err := os.Mkdir(toolchainDir, 0755); err != nil {
		return fail("Failed to create toolchain dir: %v", err)
	}
	log.Println("Extracting toolchain...")
	if err := extractTar(toolchainTar, toolchainDir); err != nil {
		return fail("Failed to extract toolchain: %v", err)
	}
	log.Println("Extracting kernel headers...")
	if err := extractTar(headersTar, toolchainDir); err != nil {
		return fail("Failed to extract kernel headers: %v", err)
	}

	log.Println("Compiling driver...")
	if err := os.MkdirAll(*outDir, 0755); err != nil {
		return fail("Failed to create output directory: %v", err)
	}
	absOutDir, err := filepath.Abs(*outDir)
	if err != nil {
		return fail("Failed to get absolute path for out-dir: %v", err)
	}
	compiledPkg, err := gpubuild.Compile(absRunfile, toolchainDir, absOutDir)
	if err != nil {
		return fail("Compilation failed: %v", err)
	}
	log.Printf("Compiled package: %s", compiledPkg.Path)

	log.Println("Updating gpu_driver_versions.bin...")
	gpuVersionsBin := filepath.Join(workDir, "gpu_driver_versions.bin")
	log.Printf("Attempting to download gpu_driver_versions.bin with prefix %s", prefix)
	if err := downloader.DownloadArtifact(ctx, workDir, "gpu_driver_versions.bin"); err != nil {
		return fail("ERROR: gpu_driver_versions.bin not found for version %s (board %s): %v", *cosVersion, *cosBoard, err)
	}
	if err := updateGPUVersionsBin(gpuVersionsBin, driverVersion, filepath.Join(absOutDir, "gpu_driver_versions.bin")); err != nil {
		return fail("Failed to update gpu_driver_versions.bin: %v", err)
	}

	destRunfile := filepath.Join(absOutDir, filepath.Base(absRunfile))
	// Both paths come from filepath.Abs/Join and are therefore cleaned, so
	// equality means the runfile already lives in the output directory.
	// Copying a file onto itself is pointless and, depending on how
	// utils.CopyFile opens the destination, may truncate it.
	if destRunfile == absRunfile {
		log.Printf("Runfile already in output directory: %s", destRunfile)
	} else {
		log.Printf("Copying runfile to %s", destRunfile)
		if err := utils.CopyFile(absRunfile, destRunfile); err != nil {
			return fail("Failed to copy runfile to output directory: %v", err)
		}
	}

	log.Println("Successfully completed!")
	return 0
}
// extractTar unpacks the archive at tarPath into destDir by running the
// system tar binary as `tar xf <tarPath> -C <destDir>`. destDir is expected to
// exist already (callers create it first).
//
// tar's stdout and stderr are forwarded to this process's, so extraction
// diagnostics appear in the tool's output. The returned error is non-nil if
// tar could not be started or exited with a non-zero status. Shelling out is
// simpler than a pure-Go extractor, and this tool only targets Linux. It
// relies on tar detecting the compression itself: callers pass both .tar.xz
// and .tgz archives.
func extractTar(tarPath, destDir string) error {
	tarArgs := []string{"xf", tarPath, "-C", destDir}
	extract := exec.Command("tar", tarArgs...)
	extract.Stdout, extract.Stderr = os.Stdout, os.Stderr
	return extract.Run()
}
// updateGPUVersionsBin reads the serialized driver-version list at binPath,
// decoding it with gpuconfig.ParseGPUDriverVersionInfoList. It appends
// driverVersion to the SupportedDriverVersions of every GPU entry that does
// not already list it, then writes the re-marshaled proto to outPath with
// mode 0644.
//
// outPath is written even when no entry changed, so a nil return always means
// outPath holds the (possibly unchanged) list. Errors from reading, parsing,
// marshaling, or writing are returned as-is.
func updateGPUVersionsBin(binPath, driverVersion, outPath string) error {
	content, err := os.ReadFile(binPath)
	if err != nil {
		return err
	}
	driverVersionInfo, err := gpuconfig.ParseGPUDriverVersionInfoList(content)
	if err != nil {
		return err
	}

	entries := driverVersionInfo.GetGpuDriverVersionInfo()
	added := 0
	for _, gpuVersionInfo := range entries {
		alreadyListed := false
		for _, v := range gpuVersionInfo.SupportedDriverVersions {
			if v.Version == driverVersion {
				alreadyListed = true
				break
			}
		}
		if !alreadyListed {
			gpuVersionInfo.SupportedDriverVersions = append(gpuVersionInfo.SupportedDriverVersions, &pb.DriverVersion{Version: driverVersion})
			added++
		}
	}

	switch {
	case len(entries) == 0:
		// With no GPU entries there is nothing to attach the version to.
		// Reporting "already present" here would be false.
		log.Printf("gpu_driver_versions.bin has no GPU entries; version %s was not added", driverVersion)
	case added == 0:
		log.Printf("Version %s already present in gpu_driver_versions.bin, no update needed", driverVersion)
	default:
		log.Printf("Added version %s to %d of %d GPU entries in gpu_driver_versions.bin", driverVersion, added, len(entries))
	}

	resultBytes, err := proto.Marshal(driverVersionInfo)
	if err != nil {
		return err
	}
	return os.WriteFile(outPath, resultBytes, 0644)
}