cos-gpu-installer-v2: Add utils package
Change-Id: I24ace0a2f3704e5afe93e332deb1c0c1bf2fef9c
diff --git a/src/pkg/utils/utils.go b/src/pkg/utils/utils.go
new file mode 100644
index 0000000..7022039
--- /dev/null
+++ b/src/pkg/utils/utils.go
@@ -0,0 +1,341 @@
+// Package utils provides utility functions.
+package utils
+
+import (
+ "archive/tar"
+ "bufio"
+ "encoding/json"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net/http"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "strings"
+ "syscall"
+ "time"
+
+ log "github.com/golang/glog"
+ "github.com/pkg/errors"
+)
+
+var (
+ downloadRetries = 3
+ lockFile = "/root/tmp/cos_gpu_installer_lock"
+)
+
+type serviceAccountToken struct {
+ Token string `json:"access_token"`
+ Expire int `json:"expires_in"`
+ TokenType string `json:"token_type"`
+}
+
+type listStorageObjectsResponse struct {
+ Kind string `json:"kind"`
+ Items []struct {
+ Kind string `json:"kind"`
+ ID string `json:"id"`
+ SelfLink string `json:"selfLink"`
+ MediaLink string `json:"mediaLink"`
+ Name string `json:"name"`
+ Bucket string `json:"bucket"`
+ Generation string `json:"generation"`
+ Metageneration string `json:"metageneration"`
+ ContentType string `json:"contentType"`
+ StorageClass string `json:"storageClass"`
+ Size string `json:"size"`
+ Md5Hash string `json:"md5Hash"`
+ Crc32c string `json:"crc32c"`
+ Etag string `json:"etag"`
+ TimeCreated string `json:"timeCreated"`
+ Updated string `json:"updated"`
+ TimeStorageClassUpdated string `json:"timeStorageClassUpdated"`
+ } `json:"items"`
+}
+
+// Flock exclusively locks a special file on the host to make sure only one calling process is running at any time.
+func Flock() {
+ // TODO(mikewu): generalize Flock to make it useful for other use cases.
+ f, err := os.OpenFile(lockFile, os.O_RDONLY|os.O_CREATE, 0666)
+ if err != nil {
+ log.Exitf("Failed to open lock file: %v", err)
+ }
+ if err := syscall.Flock(int(f.Fd()), syscall.LOCK_EX|syscall.LOCK_NB); err != nil {
+ log.Exitf("File %s is locked. Other process might be running.", lockFile)
+ }
+}
+
+// DownloadContentFromURL downloads file from a given URL.
+func DownloadContentFromURL(url, outputPath, infoStr string) error {
+ url = strings.TrimSpace(url)
+ log.Infof("Downloading %s from %s", infoStr, url)
+
+ req, err := http.NewRequest("GET", url, nil)
+ if err != nil {
+ return errors.Wrapf(err, "failed to download %s from %s", infoStr, url)
+ }
+ // TODO(mikewu): Consider using GCS GO package.
+ if strings.HasPrefix(url, "https://storage.googleapis.com") {
+ // TODO(mikewu): Consider using sgauth (https://github.com/google/oauth2l/tree/master/sgauth).
+ token, err := GetDefaultVMToken()
+ if err != nil {
+ return errors.Wrap(err, "failed to get VM token")
+ }
+ req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", token))
+ }
+
+ outputFile, err := os.Create(outputPath)
+ if err != nil {
+ return errors.Wrapf(err, "failed to create file %s", outputPath)
+ }
+ defer outputFile.Close()
+
+ client := &http.Client{}
+
+ var response *http.Response
+ retries := downloadRetries
+ for retries > 0 {
+ response, err = client.Do(req)
+ if err != nil {
+ log.Errorf("Failed to download %s: %v", infoStr, err)
+ retries--
+ time.Sleep(time.Second)
+ log.Info("Retry...")
+ } else {
+ break
+ }
+ }
+ if response == nil {
+ return errors.Wrapf(err, "failed to download %s", infoStr)
+ }
+ defer response.Body.Close()
+ if response.StatusCode != 200 {
+ return errors.Errorf("failed to download %s, status: %s", infoStr, response.Status)
+ }
+ if _, err := io.Copy(outputFile, response.Body); err != nil {
+ return errors.Wrapf(err, "failed to download %s", infoStr)
+ }
+
+ log.Infof("Successfully downloaded %s from %s", infoStr, url)
+ return nil
+}
+
+// DownloadFromGCS downloads an object from the given GCS path.
+func DownloadFromGCS(destDir, gcsBucket, gcsPath string) error {
+ downloadURL := fmt.Sprintf("https://storage.googleapis.com/%s/%s", gcsBucket, gcsPath)
+ filename := filepath.Base(gcsPath)
+ outputPath := filepath.Join(destDir, filename)
+ return DownloadContentFromURL(downloadURL, outputPath, filename)
+}
+
+// ListGCSBucket lists the objects whose names begin with the given prefix in the given GCS bucekt.
+func ListGCSBucket(bucket, prefix string) ([]string, error) {
+ log.Infof("Listing objects from GCS bucekt %s with prefix %s", bucket, prefix)
+
+ url := fmt.Sprintf("https://storage.googleapis.com/storage/v1/b/%s/o?prefix=%s", bucket, prefix)
+ dir, err := ioutil.TempDir("", "bucketlist")
+ if err != nil {
+ return nil, errors.Wrap(err, "failed to create tempdir")
+ }
+ defer os.RemoveAll(dir)
+ tmpfile := filepath.Join(dir, "bucketlist")
+ if err := DownloadContentFromURL(url, tmpfile, "bucketlist"); err != nil {
+ return nil, errors.Wrapf(err, "failed to downoad url %s", url)
+ }
+
+ content, err := ioutil.ReadFile(tmpfile)
+ if err != nil {
+ return nil, errors.Wrapf(err, "failed to read file %s", tmpfile)
+ }
+ var jsonContent listStorageObjectsResponse
+ if err := json.Unmarshal(content, &jsonContent); err != nil {
+ return nil, errors.Wrapf(err, "failed to parse json string %s", string(content))
+ }
+
+ var objects []string
+ for _, item := range jsonContent.Items {
+ objects = append(objects, item.Name)
+ }
+ return objects, nil
+}
+
+// GetDefaultVMToken returns the default GCE service account of the COS VM the program is running on.
+func GetDefaultVMToken() (string, error) {
+ tokenStr, err := GetGCEMetadata("service-accounts/default/token")
+ if err != nil {
+ return "", errors.Wrap(err, "failed to get default VM token")
+ }
+ token, err := parseVMToken(tokenStr)
+ if err != nil {
+ return "", errors.Wrap(err, "failed to parse VM token")
+ }
+ return token.Token, nil
+}
+
+// GetGCEMetadata queries GCE metadata server to get the value of a given metadata key.
+func GetGCEMetadata(metadataPath string) (string, error) {
+ req, err := http.NewRequest("GET", "http://metadata.google.internal/computeMetadata/v1/instance/"+metadataPath, nil)
+ if err != nil {
+ return "", errors.Wrap(err, "failed to get GCE metadata")
+ }
+ req.Header.Add("Metadata-Flavor", "Google")
+ client := &http.Client{}
+ resp, err := client.Do(req)
+ if err != nil {
+ return "", errors.Wrap(err, "failed to get GCE metadata")
+ }
+ defer resp.Body.Close()
+ body, err := ioutil.ReadAll(resp.Body)
+ if err != nil {
+ return "", errors.Wrap(err, "failed to get GCE metadata")
+ }
+ return string(body), nil
+}
+
+// IsDirEmpty returns whether a given directory is empty.
+func IsDirEmpty(dirName string) (bool, error) {
+ dir, err := os.Open(dirName)
+ if err != nil {
+ return false, err
+ }
+ defer dir.Close()
+ _, err = dir.Readdir(1)
+ if err == io.EOF {
+ return true, nil
+ }
+ return false, err
+}
+
+// LoadEnvFromFile reads an env file from fs into memory as a map.
+func LoadEnvFromFile(prefix, filePath string) (map[string]string, error) {
+ path := filepath.Join(prefix, filePath)
+ envs := make(map[string]string)
+ f, err := os.Open(path)
+ if err != nil {
+ return nil, errors.Wrapf(err, "failed to read file %s", path)
+ }
+ defer f.Close()
+ rd := bufio.NewReader(f)
+ // TODO(mikewu): Consider using https://golang.org/pkg/bufio/#Scanner.
+ for {
+ line, err := rd.ReadString('\n')
+ if err != nil && err != io.EOF {
+ return nil, errors.Wrapf(err, "failed to read file %s", path)
+ }
+ trimmedLine := strings.TrimSpace(line)
+ if trimmedLine != "" {
+ parts := strings.SplitN(trimmedLine, "=", 2)
+ if len(parts) != 2 {
+ return nil, errors.Wrapf(err, "Unrecognized format: %s", trimmedLine)
+ }
+ envs[parts[0]] = strings.Trim(parts[1], `"'`)
+ }
+ if err == io.EOF {
+ break
+ }
+ }
+ return envs, nil
+}
+
+// CreateTarFile creates a tar archive file given a map of {filename: content}.
+func CreateTarFile(tarFilename string, files map[string][]byte) error {
+ tarFile, err := os.Create(tarFilename)
+ if err != nil {
+ return err
+ }
+ defer tarFile.Close()
+
+ tw := tar.NewWriter(tarFile)
+ defer tw.Close()
+
+ for name, body := range files {
+ header := &tar.Header{
+ Name: name,
+ Mode: 0644,
+ Size: int64(len(body)),
+ }
+ if err := tw.WriteHeader(header); err != nil {
+ return err
+ }
+ if _, err := tw.Write(body); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// RunCommandAndLogOutput runs the given command and logs the stdout and stderr in parallel.
+func RunCommandAndLogOutput(cmd *exec.Cmd, expectError bool) error {
+ cmd.Stdout = &loggingWriter{logger: log.Info}
+ cmd.Stderr = &loggingWriter{logger: log.Error}
+
+ err := cmd.Run()
+ if _, ok := err.(*exec.ExitError); ok && expectError {
+ log.Warningf("command %s didn't complete successfully: %v", cmd.Path, err)
+ return nil
+ }
+ return err
+}
+
+// CopyFile copies a file from src to dest.
+func CopyFile(src, dest string) error {
+ srcfile, err := os.Open(src)
+ if err != nil {
+ return errors.Wrapf(err, "failed to open file %s", src)
+ }
+ defer srcfile.Close()
+ destfile, err := os.Create(dest)
+ if err != nil {
+ return errors.Wrapf(err, "failed to create file %s", dest)
+ }
+ defer destfile.Close()
+ if _, err := io.Copy(destfile, srcfile); err != nil {
+ return errors.Wrapf(err, "failed to copy file from %s to %s", dest, src)
+ }
+ if err := srcfile.Close(); err != nil {
+ return errors.Wrapf(err, "failed to close source file %s", src)
+ }
+ if err := destfile.Close(); err != nil {
+ return errors.Wrapf(err, "failed to close destination file %s", dest)
+ }
+ return nil
+}
+
+// MoveFile moves a file from src to dest.
+// Avoid to use os.Rename as the src and dst may on different filesystems,
+// e.g. (container temp fs -> host mounted volume).
+func MoveFile(src, dest string) error {
+ if err := CopyFile(src, dest); err != nil {
+ return errors.Wrapf(err, "failed to move file from %s to %s", src, dest)
+ }
+ if err := os.Remove(src); err != nil {
+ return errors.Wrapf(err, "failed to remove file %s", src)
+ }
+ return nil
+}
+
+func parseVMToken(tokenStr string) (*serviceAccountToken, error) {
+ var token serviceAccountToken
+ if err := json.Unmarshal([]byte(tokenStr), &token); err != nil {
+ return nil, err
+ }
+ return &token, nil
+}
+
+type loggingWriter struct {
+ logger func(...interface{})
+ buf []byte
+}
+
+func (l *loggingWriter) Write(p []byte) (int, error) {
+ for _, b := range p {
+ if b == '\n' {
+ l.logger(string(l.buf[:]))
+ l.buf = nil
+ continue
+ }
+ l.buf = append(l.buf, b)
+ }
+ return len(p), nil
+}
diff --git a/src/pkg/utils/utils_test.go b/src/pkg/utils/utils_test.go
new file mode 100644
index 0000000..05d2b9f
--- /dev/null
+++ b/src/pkg/utils/utils_test.go
@@ -0,0 +1,153 @@
+package utils
+
+import (
+ "io/ioutil"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "testing"
+ "time"
+
+ "github.com/google/go-cmp/cmp"
+)
+
+func TestLock(t *testing.T) {
+ if os.Getenv("TEST_LOCK") == "1" {
+ origLockFile := lockFile
+ lockFile = os.Getenv("TEST_DIR")
+ defer func(origLockFile string) { lockFile = origLockFile }(origLockFile)
+ Flock()
+ // forever so that the filelock won't be released.
+ for {
+ }
+ }
+
+ tmpfile, err := ioutil.TempFile("", "testing")
+ if err != nil {
+ t.Fatalf("Failed to create tempfile: %v", err)
+ }
+ defer os.Remove(tmpfile.Name())
+
+ // First time to call Lock(), expect to wait forever
+ cmd1 := exec.Command(os.Args[0], "-test.run=TestLock")
+ cmd1.Env = append(os.Environ(), "TEST_LOCK=1", "TEST_DIR="+tmpfile.Name())
+ if err := cmd1.Start(); err != nil {
+ t.Fatalf("Failed to start command: %v", err)
+ }
+
+ // Wait 1 sec for the first process to lock file.
+ time.Sleep(time.Second)
+
+ // Second time to call Lock(), expect to exit with status 1
+ cmd2 := exec.Command(os.Args[0], "-test.run=TestLock")
+ cmd2.Env = append(os.Environ(), "TEST_LOCK=1", "TEST_DIR="+tmpfile.Name())
+ if err := cmd2.Start(); err != nil {
+ t.Fatalf("Failed to start command: %v", err)
+ }
+
+ waitWithTimeout(t, cmd1, 3, false)
+ waitWithTimeout(t, cmd2, 3, true)
+}
+
+func waitWithTimeout(t *testing.T, cmd *exec.Cmd, timeout int, expectError bool) {
+ done := make(chan error, 1)
+ go func() {
+ done <- cmd.Wait()
+ }()
+
+ select {
+ case <-time.After(time.Duration(timeout) * time.Second):
+ if err := cmd.Process.Kill(); err != nil {
+ t.Fatalf("Failed to kill process: %v", err)
+ }
+ if expectError {
+ t.Errorf("Process %s didn't exit while expecting to exit with error", cmd.Path)
+ }
+ case err := <-done:
+ e, ok := err.(*exec.ExitError)
+ if !ok {
+ t.Fatal("Failed to convert error to exec.ExitError")
+ }
+ if e.Success() == expectError {
+ t.Errorf("Process %s exited with unexpected status, want error: %v, got error: %v",
+ cmd.Path, expectError, !e.Success())
+ }
+ }
+}
+
+func TestParseVMToken(t *testing.T) {
+ token, err := parseVMToken(
+ `{"access_token":"ya29.c.Kmi8B89nrn2Esf2e4WEk2MlZp7G8EpMatfxD36UuG3QJpwqePPxLAMvlb-WEi-nnZ7WmFsxyTAhzFMlxBV4AEYfs1tdJqolDay_3BXkwv0cwFe6OO86_dSUWDbiK9gIYQ6bAE_oR9SdVdw","expires_in":3248,"token_type":"Bearer"}`)
+ if err != nil {
+ t.Fatalf("Failed to run parseVMToken: %v", err)
+ }
+ expectedToken := serviceAccountToken{
+ Token: "ya29.c.Kmi8B89nrn2Esf2e4WEk2MlZp7G8EpMatfxD36UuG3QJpwqePPxLAMvlb-WEi-nnZ7WmFsxyTAhzFMlxBV4AEYfs1tdJqolDay_3BXkwv0cwFe6OO86_dSUWDbiK9gIYQ6bAE_oR9SdVdw",
+ Expire: 3248,
+ TokenType: "Bearer",
+ }
+ if diff := cmp.Diff(*token, expectedToken); diff != "" {
+ t.Errorf("Unexpected return\nwant: %v\ngot: %v\ndiff: %v", expectedToken, *token, diff)
+ }
+}
+
+func TestIsDirEmpty(t *testing.T) {
+ emptyDir, err := ioutil.TempDir("", "testing")
+ if err != nil {
+ t.Fatalf("Failed to create tmp dir: %v", err)
+ }
+ defer os.RemoveAll(emptyDir)
+
+ nonEmptyDir, err := ioutil.TempDir("", "testing")
+ if err != nil {
+ t.Fatalf("Failed to create tmp dir: %v", err)
+ }
+ defer os.RemoveAll(nonEmptyDir)
+
+ tmpfile, err := ioutil.TempFile(nonEmptyDir, "testing")
+ if err != nil {
+ t.Fatalf("Failed to create tmp file: %v", err)
+ }
+
+ defer os.Remove(tmpfile.Name())
+
+ for _, tc := range []struct {
+ testName string
+ dir string
+ expectEmpty bool
+ }{
+ {"EmptyDir", emptyDir, true},
+ {"NonEmptyDir", nonEmptyDir, false},
+ } {
+ ret, _ := IsDirEmpty(tc.dir)
+ if ret != tc.expectEmpty {
+ t.Errorf("%v: Unexpected return, want: %v, got: %v", tc.testName, tc.expectEmpty, ret)
+ }
+ }
+}
+
+func TestLoadEnvFromFile(t *testing.T) {
+ testDir, err := ioutil.TempDir("", "testing")
+ if err != nil {
+ t.Fatalf("Failed to create test dir: %v", err)
+ }
+ defer os.RemoveAll(testDir)
+ envStr := `key1=value1
+key2=value2`
+ if err := ioutil.WriteFile(filepath.Join(testDir, "env"), []byte(envStr), 0644); err != nil {
+ t.Fatalf("Failed to write to env file: %v", err)
+ }
+
+ envs, err := LoadEnvFromFile(testDir, "env")
+ if err != nil {
+ t.Fatalf("Failed to read from env file: %v", err)
+ }
+
+ expectedEnvs := map[string]string{
+ "key1": "value1",
+ "key2": "value2",
+ }
+ if diff := cmp.Diff(envs, expectedEnvs); diff != "" {
+ t.Errorf("Unexpected envs, want: %v, got: %v, diff: %v", expectedEnvs, envs, diff)
+ }
+}