blob: 7eef2236b927cc43d7e266d0bb092c1aa48cf8b4 [file] [log] [blame] [edit]
package gpu
import (
"bytes"
"fmt"
"io"
"net/http"
"os"
"reflect"
"strings"
"testing"
)
// FakeHttpClient is a test double for an HTTP client. It remembers the URL and
// headers of the most recent request handed to Do and answers every request
// with a 200 OK whose body is Body.
type FakeHttpClient struct {
	// URL is the string form of the last request's URL.
	URL string
	// Header is the header map of the last request (shared, not copied).
	Header http.Header
	// Body is the payload served in every response.
	Body []byte
}

// Do records r and returns a canned 200 OK response that echoes r's headers
// and carries h.Body. It never returns an error.
func (h *FakeHttpClient) Do(r *http.Request) (*http.Response, error) {
	h.URL, h.Header = r.URL.String(), r.Header
	payload := io.NopCloser(bytes.NewReader(h.Body))
	return &http.Response{
		StatusCode: http.StatusOK,
		Header:     r.Header,
		Request:    r,
		Body:       payload,
	}, nil
}
// FakeOSUtils is a test double for the package's command-runner hook. A
// "uname" invocation yields the canned result; any other command line is
// remembered in cmds and echoed back as its output.
type FakeOSUtils struct {
	// result is the output returned for "uname".
	result string
	// cmds is the most recent non-uname command line: the command followed by
	// a space and its space-joined arguments.
	cmds string
}

// run has the same shape as the package-level run hook. hideStderr is ignored
// and the returned error is always nil.
func (f *FakeOSUtils) run(command string, args []string, hideStderr bool) (string, error) {
	if command != "uname" {
		f.cmds = command + " " + strings.Join(args, " ")
		return f.cmds, nil
	}
	return f.result, nil
}
// TestRetrieveInstanceZone verifies that retrieveInstanceZone sends its request
// to instanceURL with the "Metadata-Flavor: Google" header and reduces the
// zone path served by the fake client to the expected value ("us").
func TestRetrieveInstanceZone(t *testing.T) {
	h := &FakeHttpClient{Body: []byte(`projects/475556798229/zones/us-central1-a`)}
	zone, err := retrieveInstanceZone(h, instanceURL)
	if err != nil {
		t.Fatalf("retrieveInstanceZone failed: %v", err)
	}
	wantZone := "us"
	wantHeader := "Google"
	// Header.Get yields "" for a missing header instead of panicking the way
	// indexing the header slice would.
	gotHeader := h.Header.Get("Metadata-Flavor")
	if h.URL != instanceURL || gotHeader != wantHeader {
		t.Errorf("TestRetrieveInstanceZone failed: Request not valid:\n wantUrl: %s\tgotUrl:%s\n wantHeader:%q\n gotHeader:%q",
			instanceURL, h.URL, wantHeader, gotHeader)
	}
	if zone != wantZone {
		t.Errorf("TestRetrieveInstanceZone failed: Returned unexpected zone:\n wantZone: %q\n gotZone: %q",
			wantZone, zone)
	}
}
func TestCheckArch(t *testing.T) {
tests := []struct {
desc string
wantErr bool
hideStderr bool
arch string
}{
{
desc: "Success when arch is x86_64",
wantErr: false,
hideStderr: false,
arch: "x86_64",
},
{
desc: "Failure when arch is ARM",
wantErr: false,
hideStderr: false,
arch: "aarch64",
},
{
desc: "Failure when arch is other",
wantErr: true,
hideStderr: false,
arch: "ARM",
},
}
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
f := FakeOSUtils{result: test.arch}
run = f.run
err := checkArch()
if gotErr := err != nil; gotErr != test.wantErr {
t.Errorf("TestCheckArch(%s): Error: %s\n gotErr: %t\n wantErr: %t", test.desc, err, gotErr, test.wantErr)
}
})
}
}
// TestGetCosInstaller checks the installer image GetCosInstaller picks for
// several combinations of fake metadata-server response (httpRes), the
// COS_GPU_INSTALLER environment variable, and an API-domains config file.
// When apiDomains is set, apiDomainsPath is temporarily pointed at a scratch
// file holding that content and restored afterwards.
func TestGetCosInstaller(t *testing.T) {
	tests := []struct {
		desc          string
		wantErr       bool
		httpRes       []byte
		wantInstaller Installers
		setEnv        bool
		setInstaller  string
		apiDomains    string
	}{
		{
			desc:          "Vm instance zone in artifact regions",
			wantErr:       false,
			httpRes:       []byte("us"),
			wantInstaller: Installers{DefaultInstaller: "us.gcr.io/cos-cloud/cos-gpu-installer:v2.4.4", SelectedInstaller: ""},
		},
		{
			desc:          "Vm instance zone not in artifact regions",
			wantErr:       false,
			httpRes:       []byte("cn"),
			wantInstaller: Installers{DefaultInstaller: "gcr.io/cos-cloud/cos-gpu-installer:v2.4.4", SelectedInstaller: ""},
		},
		{
			desc:          "Installer set in env variables.",
			wantErr:       false,
			setEnv:        true,
			setInstaller:  "cloud-installer-101",
			httpRes:       []byte("us"),
			wantInstaller: Installers{DefaultInstaller: "gcr.io/cos-cloud/cos-gpu-installer:v2.4.4", SelectedInstaller: "cloud-installer-101"},
		},
		{
			desc:          "Running in TPC",
			wantErr:       false,
			httpRes:       []byte("u"),
			wantInstaller: Installers{DefaultInstaller: "docker.fake-ar-domain.goog/fake-project-prefix/cos-cloud/cos-gpu-installer/cos-gpu-installer:v2.4.4", SelectedInstaller: ""},
			apiDomains:    "API_DOMAIN=fake-api-domain.goog\nARTIFACT_REGISTRY_DOMAIN=fake-ar-domain.goog\nPROJECT_PREFIX=fake-project-prefix",
		},
		{
			desc:          "Running in GDU with domain file",
			wantErr:       false,
			httpRes:       []byte("us"),
			wantInstaller: Installers{DefaultInstaller: "us.gcr.io/cos-cloud/cos-gpu-installer:v2.4.4", SelectedInstaller: ""},
			apiDomains:    "API_DOMAIN=googleapis.com\nARTIFACT_REGISTRY_DOMAIN=pkg.dev\nPROJECT_PREFIX=",
		},
	}
	for _, test := range tests {
		t.Run(test.desc, func(t *testing.T) {
			c := &FakeHttpClient{Body: test.httpRes}
			if test.apiDomains != "" {
				// The scratch file lives in t.TempDir, which the testing package
				// removes automatically.
				tmpFile, err := os.CreateTemp(t.TempDir(), "api-domains")
				if err != nil {
					t.Fatalf("failed to create tmpFile: %v", err)
				}
				if _, err := tmpFile.WriteString(test.apiDomains); err != nil {
					tmpFile.Close()
					t.Fatalf("failed to write to tmpFile: %v", err)
				}
				// Close before use so the descriptor is not leaked.
				if err := tmpFile.Close(); err != nil {
					t.Fatalf("failed to close tmpFile: %v", err)
				}
				realAPIDomainsPath := apiDomainsPath
				apiDomainsPath = tmpFile.Name()
				t.Cleanup(func() { apiDomainsPath = realAPIDomainsPath })
			}
			if test.setEnv {
				t.Setenv("COS_GPU_INSTALLER", test.setInstaller)
			}
			cosInstaller, err := GetCosInstaller(c)
			if gotErr := err != nil; gotErr != test.wantErr {
				t.Errorf("TestGetCosInstaller(%s): Error: %v\n gotErr: %t\n wantErr: %t", test.desc, err, gotErr, test.wantErr)
			}
			if !reflect.DeepEqual(cosInstaller, test.wantInstaller) {
				t.Errorf("TestGetCosInstaller(%s): returned unexpected difference;\n wantInstaller:%+v\t gotInstaller:%+v", test.desc, test.wantInstaller, cosInstaller)
			}
		})
	}
}
// TestList checks List's output prefix. When listInstaller is set, the output
// is expected to start with the installer name; otherwise it is expected to
// start with "Available". Output is captured by swapping the package-level
// writer for a buffer. Both writer and the run hook are restored when the
// test finishes.
func TestList(t *testing.T) {
	tests := []struct {
		desc          string
		args          []string
		wantErr       bool
		listInstaller bool
		installer     Installers
		want          string
	}{
		{
			desc:          "Flag --gpu-installer passed in",
			args:          []string{},
			listInstaller: true,
			installer:     Installers{DefaultInstaller: "installer", SelectedInstaller: "selected installer"},
			want:          "installer",
		},
		{
			desc:      "Flag --gpu-installer not passed in",
			args:      []string{},
			installer: Installers{DefaultInstaller: "installer"},
			want:      "Available",
		},
	}
	origRun, origWriter := run, writer
	t.Cleanup(func() { run, writer = origRun, origWriter })
	for _, test := range tests {
		t.Run(test.desc, func(t *testing.T) {
			var got bytes.Buffer
			writer = &got
			f := FakeOSUtils{result: "x86_64"}
			run = f.run
			err := List(test.installer, test.listInstaller, test.args)
			if gotErr := err != nil; gotErr != test.wantErr {
				t.Errorf("TestList(%s): Error: %v\n gotErr: %t\n wantErr: %t", test.desc, err, gotErr, test.wantErr)
			}
			if !strings.HasPrefix(got.String(), test.want) {
				t.Errorf("TestList(%s): returned unexpected difference;\n want prefix:%s\t got:%s", test.desc, test.want, got.String())
			}
		})
	}
}
// TestRunInstaller checks the exact docker command line that runInstaller
// passes to the run hook. The hook is replaced by a FakeOSUtils that records
// the command in f.cmds. Each case covers either the installer choice or
// whether the extension cache volume is mounted. The package-level run and
// writer are restored when the test finishes.
func TestRunInstaller(t *testing.T) {
	tests := []struct {
		desc       string
		wantArgs   []string
		wantErr    bool
		installer  Installers
		cmd        string
		args       []string
		hideStderr bool
	}{
		{
			desc: "Uses set installer",
			wantArgs: []string{
				"/usr/bin/docker",
				"run",
				"--rm",
				"--name=cos-gpu-installer",
				"--privileged",
				"--net=host",
				"--pid=host",
				"--volume", "/dev:/dev",
				"--volume", "/:/root",
				"--log-driver", "journald",
				"installer",
				"install",
			},
			wantErr:    false,
			cmd:        "install",
			installer:  Installers{SelectedInstaller: "installer", DefaultInstaller: "default installer"},
			hideStderr: false,
		},
		{
			desc:      "Uses extension cache: flag --prepare-build-tools passed in",
			wantErr:   false,
			args:      []string{"--prepare-build-tools"},
			installer: Installers{DefaultInstaller: "installer"},
			cmd:       "install",
			wantArgs: []string{
				"/usr/bin/docker",
				"run",
				"--rm",
				"--name=cos-gpu-installer",
				"--privileged",
				"--net=host",
				"--pid=host",
				"--volume", "/dev:/dev",
				"--volume", "/:/root",
				"--log-driver", "journald",
				"--volume", "/var/lib/cos-extensions/:/build/",
				"installer",
				"install",
				"--prepare-build-tools",
			},
			hideStderr: false,
		},
		{
			desc: "Does not use extension cache",
			wantArgs: []string{
				"/usr/bin/docker",
				"run",
				"--rm",
				"--name=cos-gpu-installer",
				"--privileged",
				"--net=host",
				"--pid=host",
				"--volume", "/dev:/dev",
				"--volume", "/:/root",
				"--log-driver", "journald",
				"installer",
				"install",
			},
			wantErr:    false,
			cmd:        "install",
			installer:  Installers{DefaultInstaller: "installer"},
			hideStderr: false,
		},
	}
	origRun, origWriter := run, writer
	t.Cleanup(func() { run, writer = origRun, origWriter })
	for _, test := range tests {
		t.Run(test.desc, func(t *testing.T) {
			var got bytes.Buffer
			writer = &got
			f := FakeOSUtils{result: "x86_64"}
			run = f.run
			err := runInstaller(test.cmd, test.hideStderr, extensionCache, test.installer, test.args...)
			if gotErr := err != nil; gotErr != test.wantErr {
				t.Errorf("TestRunInstaller(%s): Error: %v\n gotErr: %t\n wantErr: %t", test.desc, err, gotErr, test.wantErr)
			}
			want := strings.Join(test.wantArgs, " ")
			if f.cmds != want {
				t.Errorf("TestRunInstaller(%s): returned unexpected difference;\n want:%s\t got:%v", test.desc, want, f.cmds)
			}
		})
	}
}
func TestRunInstallerCacheExists(t *testing.T) {
tmpFile, err := os.CreateTemp(".", "cos-extensions")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
cacheFile := tmpFile.Name()
defer os.Remove(tmpFile.Name())
f := FakeOSUtils{result: "x86_64"}
run = f.run
err = runInstaller("install", true, cacheFile, Installers{DefaultInstaller: "installer"})
if gotErr := err != nil; gotErr {
t.Errorf("TestRunInstallerCacheExists: Error: %s\n", err)
}
wantArgs := []string{
"/usr/bin/docker",
"run",
"--rm",
"--name=cos-gpu-installer",
"--privileged",
"--net=host",
"--pid=host",
"--volume", "/dev:/dev",
"--volume", "/:/root",
"--log-driver", "journald",
fmt.Sprintf("--volume %s/:/build/", tmpFile.Name()),
"installer",
"install",
}
want := strings.Join(wantArgs, " ")
if f.cmds != want {
t.Errorf("TestRunInstallerCacheExists: returned unexpected difference;\n want:%s\t got:%v", want, f.cmds)
}
}
func TestRunInstallerCleanCache(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "tmpDir")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
args := []string{"--clean-build-tools"}
// Ensuring temp dir is deleted if test fails.
defer os.RemoveAll(tmpDir)
f := FakeOSUtils{result: "x86_64"}
run = f.run
err = runInstaller("install", false, tmpDir, Installers{DefaultInstaller: "installer"}, args...)
if err != nil {
t.Errorf("TestRunInstallerCleanCache: Error: %s\n", err)
}
if _, err := os.Stat(tmpDir); !os.IsNotExist(err) {
t.Errorf("TestRunInstallerCleanCache: did not remove cache file")
}
}