// Source-viewer header (not part of the Go source): blob 55e552c0897691a4f8bb6c48d68d5b3eb026113d [file] [log] [blame]

package gpuconfig
import (
"context"
"fmt"
"net/url"
"path"
"path/filepath"
"log"
"cloud.google.com/go/storage"
"cos.googlesource.com/cos/tools.git/src/pkg/gcs"
"cos.googlesource.com/cos/tools.git/src/pkg/utils"
)
// GPUArtifactsDownloader downloads the artifacts whose addresses are read from
// a GPUPrecompilationConfig, using a Cloud Storage client for gs:// sources.
type GPUArtifactsDownloader struct {
// client is the Cloud Storage client passed to the GCS download helper.
client *storage.Client
// config provides the artifact addresses (read through config.ProtoConfig).
config GPUPrecompilationConfig
}
// NewGPUArtifactsDownloader returns a GPUArtifactsDownloader that holds the
// given storage client and GPU precompilation config.
func NewGPUArtifactsDownloader(client *storage.Client, config GPUPrecompilationConfig) *GPUArtifactsDownloader {
	downloader := GPUArtifactsDownloader{
		client: client,
		config: config,
	}
	return &downloader
}
// DownloadKernelSrc fetches the kernel source tarball whose address comes from
// the config's GetKernelSrcTarballGcs and stores it in destDir as
// "kernel-src.tar.gz".
func (d *GPUArtifactsDownloader) DownloadKernelSrc(ctx context.Context, destDir string) error {
	const fileName = "kernel-src.tar.gz"
	source := d.config.ProtoConfig.GetKernelSrcTarballGcs()
	return d.downloadArtifact(ctx, destDir, source, fileName)
}
// DownloadToolchainEnv fetches the toolchain environment file whose address
// comes from the config's GetToolchainEnvGcs and stores it in destDir as
// "toolchain_env".
func (d *GPUArtifactsDownloader) DownloadToolchainEnv(ctx context.Context, destDir string) error {
	const fileName = "toolchain_env"
	source := d.config.ProtoConfig.GetToolchainEnvGcs()
	return d.downloadArtifact(ctx, destDir, source, fileName)
}
// DownloadToolchain fetches the toolchain tarball whose address comes from the
// config's GetToolchainTarballGcs and stores it in destDir as
// "toolchain.tar.xz".
func (d *GPUArtifactsDownloader) DownloadToolchain(ctx context.Context, destDir string) error {
	const fileName = "toolchain.tar.xz"
	source := d.config.ProtoConfig.GetToolchainTarballGcs()
	return d.downloadArtifact(ctx, destDir, source, fileName)
}
// DownloadKernelHeaders fetches the kernel headers tarball whose address comes
// from the config's GetKernelHeadersTarballGcs and stores it in destDir as
// "kernel-headers.tgz".
func (d *GPUArtifactsDownloader) DownloadKernelHeaders(ctx context.Context, destDir string) error {
	const fileName = "kernel-headers.tgz"
	source := d.config.ProtoConfig.GetKernelHeadersTarballGcs()
	return d.downloadArtifact(ctx, destDir, source, fileName)
}
// GetArtifact is a stub: it ignores ctx and artifact and always returns a nil
// byte slice together with a nil error.
//
// NOTE(review): presumably present only to satisfy an interface declared
// elsewhere — confirm before relying on the nil/nil result.
func (d *GPUArtifactsDownloader) GetArtifact(ctx context.Context, artifact string) ([]byte, error) {
	var contents []byte
	return contents, nil
}
// ListArtifacts is a stub: it ignores ctx and prefix and always returns a nil
// slice together with a nil error.
//
// NOTE(review): presumably present only to satisfy an interface declared
// elsewhere — confirm before relying on the nil/nil result.
func (d *GPUArtifactsDownloader) ListArtifacts(ctx context.Context, prefix string) ([]string, error) {
	var names []string
	return names, nil
}
// DownloadNVIDIARunfile downloads the NVIDIA runfile located at the address
// returned by the config's GetNvidiaRunfileAddress into destDir.
//
// The local file name is the last path element of that address. On success the
// file name (not the full destination path) is returned. An error is returned
// when the address cannot be parsed, when it has no usable file name
// component, or when the download itself fails.
func (d *GPUArtifactsDownloader) DownloadNVIDIARunfile(ctx context.Context, destDir string) (string, error) {
	address := d.config.ProtoConfig.GetNvidiaRunfileAddress()
	// Named runfileURL rather than url so the net/url package is not shadowed.
	runfileURL, err := url.Parse(address)
	if err != nil {
		return "", fmt.Errorf("error parsing the artifact path: %w", err)
	}
	nvidiaInstaller := path.Base(runfileURL.Path)
	// path.Base yields "." for an empty path and "/" for a path made only of
	// slashes, and ".." can survive as a last element. None of these is a file
	// name; joined onto destDir they would point at a directory instead.
	switch nvidiaInstaller {
	case ".", "/", "..":
		return "", fmt.Errorf("nvidia runfile address %q has no file name", address)
	}
	if err := d.downloadArtifact(ctx, destDir, runfileURL.String(), nvidiaInstaller); err != nil {
		return "", err
	}
	return nvidiaInstaller, nil
}
// DownloadArtifact is currently a no-op: it ignores ctx, destDir and
// artifactPath, downloads nothing, and always returns nil.
//
// NOTE(review): presumably present only to satisfy an interface declared
// elsewhere — confirm before relying on it to fetch anything.
func (d *GPUArtifactsDownloader) DownloadArtifact(ctx context.Context, destDir, artifactPath string) error {
	var err error
	return err
}
// downloadArtifact fetches artifactPath and stores it as fileName inside
// destDir, logging the source address first.
//
// The URL scheme of artifactPath selects the transport:
//   - gs://    uses the GCS helper with the downloader's storage client and ctx.
//   - https:// uses the HTTP download helper (ctx is not forwarded on this path).
//
// An unparsable artifactPath or any other scheme is reported as an error.
func (d *GPUArtifactsDownloader) downloadArtifact(ctx context.Context, destDir, artifactPath, fileName string) error {
	log.Printf("downloading artifact from:%s\n", artifactPath)
	// Named artifactURL rather than url so the net/url package is not shadowed.
	artifactURL, err := url.Parse(artifactPath)
	if err != nil {
		return fmt.Errorf("error parsing the artifact path: %w", err)
	}
	// Both transports write to the same destination, so build it once.
	dest := filepath.Join(destDir, fileName)
	switch artifactURL.Scheme {
	case "gs":
		return gcs.DownloadGCSObjectFromURL(ctx, d.client, artifactPath, dest)
	case "https":
		return utils.DownloadContentFromURL(artifactPath, dest, fileName)
	default:
		return fmt.Errorf("only https:// or gs:// urls supported: %s", artifactURL)
	}
}
// ArtifactExists is not implemented: it ignores ctx and artifactPath and
// always returns false together with a "not implemented" error.
func (d *GPUArtifactsDownloader) ArtifactExists(ctx context.Context, artifactPath string) (bool, error) {
	err := fmt.Errorf("not implemented")
	return false, err
}