Merge "cos-gpu-installer-v2: Add modules package"
diff --git a/src/pkg/modules/modules.go b/src/pkg/modules/modules.go
new file mode 100644
index 0000000..bb7b9d0
--- /dev/null
+++ b/src/pkg/modules/modules.go
@@ -0,0 +1,171 @@
+// Package modules provides functionality to load and sign Linux kernel modules.
+package modules
+
+import (
+	"bytes"
+	"encoding/binary"
+	"io"
+	"io/ioutil"
+	"os"
+	"os/exec"
+	"path/filepath"
+	"strings"
+
+	log "github.com/golang/glog"
+	"github.com/pkg/errors"
+
+	"pkg/utils"
+)
+
+const (
+	// PKEYIDPKCS7 mirrors PKEY_ID_PKCS7, the id_type value used by Linux scripts/sign-file.c.
+	PKEYIDPKCS7 byte = 2
+	// magicNumber mirrors magic_number from scripts/sign-file.c; it is appended after the signature.
+	magicNumber = "~Module signature appended~\n"
+)
+
+var (
+	execCommand = exec.Command
+)
+
+// LoadModule inserts the kernel module at modulePath, unless isModuleLoaded
+// already reports moduleName as loaded, in which case it does nothing.
+func LoadModule(moduleName, modulePath string) error {
+	alreadyLoaded, err := isModuleLoaded(moduleName)
+	if err == nil && !alreadyLoaded {
+		err = loadModule(modulePath)
+	}
+	// A failed lsmod probe and a failed insmod are reported with the same
+	// wrapped message, naming both the module and its path.
+	if err != nil {
+		return errors.Wrapf(err, "failed to load module %s (%s)", moduleName, modulePath)
+	}
+	return nil
+}
+
+// UpdateHostLdCache appends moduleLibDir as its own line of hostRootDir/etc/ld.so.conf, then runs `ldconfig -r hostRootDir`.
+func UpdateHostLdCache(hostRootDir, moduleLibDir string) error {
+	log.Info("Updating host's ld cache")
+	ldPath := filepath.Join(hostRootDir, "/etc/ld.so.conf")
+	f, err := os.OpenFile(ldPath, os.O_APPEND|os.O_WRONLY, 0644)
+	if err != nil {
+		return errors.Wrapf(err, "failed to open %s", ldPath)
+	}
+	_, err = f.WriteString("\n" + moduleLibDir + "\n") // newlines keep the entry off any unterminated neighbour line
+	if closeErr := f.Close(); err == nil {
+		err = closeErr // close before ldconfig reads the file, and do not drop close errors
+	}
+	if err != nil {
+		return errors.Wrapf(err, "failed to write \"%s\" to %s", moduleLibDir, ldPath)
+	}
+	if err := execCommand("ldconfig", "-r", hostRootDir).Run(); err != nil {
+		return errors.Wrapf(err, "failed to run `ldconfig -r %s`", hostRootDir)
+	}
+	return nil
+}
+
+// LoadPublicKey loads the given public key to the secondary keyring.
+func LoadPublicKey(keyName, keyPath string) error {
+	log.Infof("Loading %s to secondary system keyring", keyName)
+
+	keyBytes, err := ioutil.ReadFile(keyPath)
+	if err != nil {
+		return errors.Wrapf(err, "failed to read key %s", keyPath)
+	}
+
+	cmd := execCommand("/bin/keyctl", "padd", "asymmetric", keyName, "%keyring:.secondary_trusted_keys")
+	cmd.Stdin = bytes.NewBuffer(keyBytes)
+	if err := cmd.Run(); err != nil {
+		return errors.Wrapf(err, "failed to load %s to system keyring", keyName)
+	}
+	log.Infof("Successfully load key %s into secondary system keyring.", keyName)
+	return nil
+}
+
+// AppendSignature appends a raw PKCS#7 signature to the end of a given kernel module.
+// This is basically the Go implementation of `scripts/sign-file -s` in Linux upstream.
+// The bytes written to outfilePath (which may equal modulefilePath to sign in place)
+// are laid out as:
+//
+//	module | signature | 12-byte struct module_signature | magicNumber
+func AppendSignature(outfilePath, modulefilePath, sigfilePath string) error {
+	tempFile, err := ioutil.TempFile("", "tempFile")
+	if err != nil {
+		return errors.Wrap(err, "failed to create temp file")
+	}
+	defer os.Remove(tempFile.Name())
+	defer tempFile.Close()
+
+	// Copy bytes of kernel module into the temp file.
+	modulefile, err := os.Open(modulefilePath)
+	if err != nil {
+		return errors.Wrapf(err, "failed to open file %s", modulefilePath)
+	}
+	defer modulefile.Close()
+	if _, err := io.Copy(tempFile, modulefile); err != nil {
+		return errors.Wrap(err, "failed to copy file")
+	}
+
+	// Append the raw signature (the PKCS#7 message itself) into the temp file.
+	sigfile, err := os.Open(sigfilePath)
+	if err != nil {
+		return errors.Wrapf(err, "failed to open file %s", sigfilePath)
+	}
+	defer sigfile.Close()
+	sigSize, err := io.Copy(tempFile, sigfile)
+	if err != nil {
+		return errors.Wrap(err, "failed to copy file")
+	}
+	// sig_len below is a 32-bit field; refuse rather than silently truncate it.
+	if sigSize > 1<<32-1 {
+		return errors.Errorf("signature file %s is too large (%d bytes)", sigfilePath, sigSize)
+	}
+
+	// Append struct module_signature from scripts/sign-file.c. Every field is zero
+	// except id_type (byte 2) and sig_len (bytes 8-11, big-endian/network order).
+	moduleSignature := [12]byte{}
+	moduleSignature[2] = PKEYIDPKCS7
+	binary.BigEndian.PutUint32(moduleSignature[8:12], uint32(sigSize))
+	if _, err := tempFile.Write(moduleSignature[:]); err != nil {
+		return errors.Wrapf(err, "failed to write to file %s", tempFile.Name())
+	}
+	// The magicNumber marker goes last, after the struct.
+	if _, err := tempFile.WriteString(magicNumber); err != nil {
+		return errors.Wrapf(err, "failed to write to file %s", tempFile.Name())
+	}
+
+	// Close explicitly: the deferred Close discards its error.
+	if err := tempFile.Close(); err != nil {
+		return errors.Wrapf(err, "failed to close file %s", tempFile.Name())
+	}
+
+	// Finally, move the staged file to outfilePath. This overwrites the
+	// original module file when signing in place.
+	if err := utils.MoveFile(tempFile.Name(), outfilePath); err != nil {
+		return errors.Wrapf(err, "failed to rename file from %s to %s", tempFile.Name(), outfilePath)
+	}
+
+	return nil
+}
+
+// isModuleLoaded reports whether lsmod lists moduleName, treating '-' and '_' as the same character.
+func isModuleLoaded(moduleName string) (bool, error) {
+	out, err := execCommand("lsmod").Output()
+	if err != nil {
+		return false, errors.Wrap(err, "failed to run command `lsmod`")
+	}
+	want := strings.Replace(moduleName, "-", "_", -1) // loaded module names use '_', never '-'
+	for _, line := range strings.Split(string(out), "\n") {
+		if fields := strings.Fields(line); len(fields) > 0 && fields[0] == want {
+			return true, nil
+		}
+	}
+	return false, nil
+}
+
+// loadModule inserts the module file at modulePath by running `insmod`. errors.Wrapf
+// returns nil for a nil error, so a successful insmod yields nil here.
+func loadModule(modulePath string) error {
+	insmod := execCommand("insmod", modulePath)
+	return errors.Wrapf(insmod.Run(), "failed to run command `insmod %s`", modulePath)
+}
diff --git a/src/pkg/modules/modules_test.go b/src/pkg/modules/modules_test.go
new file mode 100644
index 0000000..bb50d38
--- /dev/null
+++ b/src/pkg/modules/modules_test.go
@@ -0,0 +1,132 @@
+package modules
+
+import (
+	"fmt"
+	"io/ioutil"
+	"os"
+	"os/exec"
+	"strconv"
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+)
+
+// mockCmdStdout and mockCmdExitStatus are what the process built by fakeExecCommand
+// prints and exits with; each test case sets them before calling the code under test.
+var mockCmdStdout string
+var mockCmdExitStatus int
+
+// fakeExecCommand replaces exec.Command in tests: it re-runs this test binary so that
+// TestHelperProcess prints mockCmdStdout and exits with mockCmdExitStatus.
+func fakeExecCommand(command string, args ...string) *exec.Cmd {
+	helperArgs := append([]string{"-test.run=TestHelperProcess", "--", command}, args...)
+	fake := exec.Command(os.Args[0], helperArgs...)
+	fake.Env = []string{"GO_WANT_HELPER_PROCESS=1",
+		"STDOUT=" + mockCmdStdout,
+		"EXIT_STATUS=" + strconv.Itoa(mockCmdExitStatus)}
+	return fake
+}
+
+// TestHelperProcess is not a real test. With GO_WANT_HELPER_PROCESS=1 it prints $STDOUT and exits with $EXIT_STATUS.
+func TestHelperProcess(t *testing.T) {
+	if os.Getenv("GO_WANT_HELPER_PROCESS") != "1" {
+		return
+	}
+	fmt.Fprint(os.Stdout, os.Getenv("STDOUT")) // not Fprintf: a '%' in the canned output must stay literal
+	es, err := strconv.Atoi(os.Getenv("EXIT_STATUS"))
+	if err != nil {
+		t.Fatalf("Failed to convert EXIT_STATUS to int: %v", err)
+	}
+	os.Exit(es)
+}
+
+func TestHasInstalled(t *testing.T) {
+	execCommand = fakeExecCommand
+	defer func() {
+		execCommand = exec.Command
+		mockCmdExitStatus = 0
+	}()
+
+	for _, tc := range []struct {
+		testName      string
+		moduleName    string
+		cmdStdout     string
+		cmdExitStatus int
+		expectOutput  bool
+	}{
+		{"TestModuleInstalled", "nf_nat",
+			"Module\tSize\tUsed by\nnf_nat_ipv4\t16384\t2 ipt_MASQUERADE,iptable_nat\nnf_nat\t53248\t1 nf_nat_ipv4\n",
+			0, true,
+		},
+		{"TestModuleNotInstalled", "fat",
+			"Module\tSize\tUsed by\nnf_nat_ipv4\t16384\t2 ipt_MASQUERADE,iptable_nat\nnf_nat\t53248\t1 nf_nat_ipv4\n",
+			0, false,
+		},
+	} {
+		t.Run(tc.testName, func(t *testing.T) {
+			mockCmdStdout = tc.cmdStdout
+			mockCmdExitStatus = tc.cmdExitStatus
+			out, err := isModuleLoaded(tc.moduleName)
+			if err != nil {
+				t.Errorf("Unexpected error: %v", err)
+			}
+			if out != tc.expectOutput {
+				t.Errorf("Unexpected return value, want %v, got %v", tc.expectOutput, out)
+			}
+		})
+	}
+}
+
+// TestAppendSignature signs a small module in place and checks the exact resulting bytes.
+func TestAppendSignature(t *testing.T) {
+	modulefile, err := ioutil.TempFile("", "modulefile")
+	if err != nil {
+		t.Fatalf("AppendSignature: failed to create temp file: %v", err)
+	}
+	defer os.Remove(modulefile.Name())
+	sigfile, err := ioutil.TempFile("", "sigfile")
+	if err != nil {
+		t.Fatalf("AppendSignature: failed to create temp file: %v", err)
+	}
+	defer os.Remove(sigfile.Name())
+
+	_, err = modulefile.Write([]byte("module"))
+	if err != nil {
+		t.Fatalf("AppendSignature: failed to write to file %s: %v", modulefile.Name(), err)
+	}
+	if err := modulefile.Close(); err != nil {
+		t.Fatalf("AppendSignature: failed to close file %s: %v", modulefile.Name(), err)
+	}
+
+	_, err = sigfile.Write([]byte("signature"))
+	if err != nil {
+		t.Fatalf("AppendSignature: failed to write to file %s: %v", sigfile.Name(), err)
+	}
+	if err := sigfile.Close(); err != nil {
+		t.Fatalf("AppendSignature: failed to close file %s: %v", sigfile.Name(), err)
+	}
+
+	if err := AppendSignature(modulefile.Name(), modulefile.Name(), sigfile.Name()); err != nil {
+		t.Fatalf("AppendSignature: failed to run with error: %v", err)
+	}
+	signedModuleBytes, err := ioutil.ReadFile(modulefile.Name())
+	if err != nil {
+		t.Fatalf("AppendSignature: failed to read signed module file: %v", err)
+	}
+	expectedBytes := [...]byte{
+		// The original module bytes: "module"
+		0x6D, 0x6F, 0x64, 0x75, 0x6c, 0x65,
+		// The raw PKCS#7 signature bytes: "signature"
+		0x73, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65,
+		// struct module_signature: id_type = PKEYIDPKCS7 (2) at byte 2, sig_len = 9 big-endian at bytes 8-11
+		0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00,
+		0x00, 0x00, 0x00, 0x09,
+		// The trailing magicNumber marker (not part of the PKCS#7 message): "~Module signature appended~\n"
+		0x7e, 0x4d, 0x6f, 0x64, 0x75, 0x6c, 0x65, 0x20, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x20, 0x61,
+		0x70, 0x70, 0x65, 0x6e, 0x64, 0x65, 0x64, 0x7e, 0xa,
+	}
+	if diff := cmp.Diff(expectedBytes[:], signedModuleBytes); diff != "" {
+		t.Errorf("AppendSignature: signedModuleBytes doesn't match,\nwant: %v\ngot: %v\ndiff: %v",
+			expectedBytes, signedModuleBytes, diff)
+	}
+}