blob: a37b4a4dcfd642ba874016e62127f5f193023b46 [file] [log] [blame]
package dkms
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"os"
"path"
"reflect"
"testing"
"time"
"github.com/google/go-cmp/cmp"
)
// TestParsePrivateKey feeds a table of PEM inputs to parsePrivateKey and checks
// two things per case: whether an error is reported, and the dynamic type of
// the returned key (a nil wantKey means no key is expected).
func TestParsePrivateKey(t *testing.T) {
	ecdsaPEM, err := generateECDSAPrivateKey()
	if err != nil {
		t.Fatalf("failed to generate ecdsa private key: %v", err)
	}
	_, rsaPEM, err := generateRSAPrivateKey()
	if err != nil {
		t.Fatalf("failed to generate rsa private key: %v", err)
	}

	for _, tc := range []struct {
		name    string
		input   []byte
		wantErr bool
		wantKey crypto.PrivateKey // only its dynamic type is compared
	}{
		{name: "Unsupported key", input: ecdsaPEM, wantErr: true},
		{name: "Supported rsa key", input: rsaPEM, wantKey: &rsa.PrivateKey{}},
		{name: "Empty bytes", input: []byte{}, wantErr: true},
		{name: "Nonsense bytes", input: []byte("this is not a key"), wantErr: true},
	} {
		t.Run(tc.name, func(t *testing.T) {
			key, err := parsePrivateKey(tc.input)
			if (err != nil) != tc.wantErr {
				t.Fatalf("got error: %v, want error: %v", err, tc.wantErr)
			}
			gotType, wantType := reflect.TypeOf(key), reflect.TypeOf(tc.wantKey)
			if gotType != wantType {
				t.Fatalf("private key type did not match; expected: %T, got: %T", tc.wantKey, key)
			}
		})
	}
}
// TestGenerateModuleSignature signs a short payload and checks that the result
// verifies as an RSA PKCS #1 v1.5 signature over the payload's SHA-256 digest.
func TestGenerateModuleSignature(t *testing.T) {
	key, _, err := generateRSAPrivateKey()
	if err != nil {
		t.Fatalf("failed to generate rsa key: %v", err)
	}

	const hashAlg = crypto.SHA256
	payload := []byte("module")

	sig, err := generateModuleSignature(key, payload, hashAlg)
	if err != nil {
		t.Fatalf("%v", err)
	}

	// VerifyPKCS1v15 expects the digest, not the raw payload.
	digest := sha256.Sum256(payload)
	if err := rsa.VerifyPKCS1v15(&key.PublicKey, hashAlg, digest[:], sig); err != nil {
		t.Fatalf("failed to verify signature: %v", err)
	}
}
// TestAppendSignature pins the exact byte layout produced by appendSignature:
// the original module bytes, then the raw signature bytes, then a 12-byte
// module_signature trailer describing them, then the trailing marker string.
func TestAppendSignature(t *testing.T) {
	signature := []byte("signature")
	contents := []byte("module")
	signedBytes, err := appendSignature(signature, contents)
	if err != nil {
		t.Fatalf("appendSignature failed: %v", err)
	}
	expectedBytes := []byte{
		// Original module contents: "module".
		0x6d, 0x6f, 0x64, 0x75, 0x6c, 0x65,
		// Signature data: "signature" (presumably the PKCS #7 message in real use).
		0x73, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65,
		// module_signature trailer, laid out like the Linux kernel's
		// struct module_signature: algo, hash, id_type (2 = PKEY_ID_PKCS7),
		// signer_len, key_id_len, 3 padding bytes, then sig_len as a
		// big-endian uint32 (9 == len("signature")).
		0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00,
		0x00, 0x00, 0x00, 0x09,
		// Trailing marker string (not part of the PKCS #7 data):
		// "~Module signature appended~\n".
		0x7e, 0x4d, 0x6f, 0x64, 0x75, 0x6c, 0x65, 0x20, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x20, 0x61,
		0x70, 0x70, 0x65, 0x6e, 0x64, 0x65, 0x64, 0x7e, 0x0a,
	}
	if diff := cmp.Diff(expectedBytes, signedBytes); diff != "" {
		t.Errorf("signed module content mismatch (-want +got):\n%s", diff)
	}
}
// TestSignModules writes a fake module, an RSA key and a self-signed
// certificate into a temporary tree, runs SignModules over the module, and
// checks that the rewritten file still starts with the original contents and
// now ends with magicNumber.
func TestSignModules(t *testing.T) {
	tmpDir := t.TempDir()

	keyPath := path.Join(tmpDir, "rsaKey.pem")
	key, keyBytes, err := generateRSAPrivateKey()
	if err != nil {
		t.Fatalf("failed to generate rsa key: %v", err)
	}
	if err := os.WriteFile(keyPath, keyBytes, 0600); err != nil {
		t.Fatalf("failed to write rsa key to file (%s): %v", keyPath, err)
	}

	certPath := path.Join(tmpDir, "cert.der")
	certBytes, err := generateCertificate(key)
	if err != nil {
		t.Fatalf("failed to generate certificate: %v", err)
	}
	if err := os.WriteFile(certPath, certBytes, 0600); err != nil {
		t.Fatalf("failed to write certificate content to file (%s): %v", certPath, err)
	}

	pkg := &Package{
		Name:          "module",
		Version:       "1.0",
		KernelVersion: "6.1.100",
		Arch:          "x86_64",
		BuildId:       "18244.151.14",
		Trees:         &Trees{Dkms: tmpDir},
	}
	module := Module{
		BuiltName: "module",
		Package:   pkg,
	}
	modules := []Module{module}

	modulePath := module.BuiltPath()
	moduleDir := path.Dir(modulePath)
	if err := os.MkdirAll(moduleDir, 0755); err != nil {
		t.Fatalf("failed to create module directory (%s): %v", moduleDir, err)
	}
	content := []byte("module")
	if err := os.WriteFile(modulePath, content, 0600); err != nil {
		t.Fatalf("failed to write temp module file (%s): %v", modulePath, err)
	}

	if err := SignModules(modules, keyPath, certPath, "sha256"); err != nil {
		t.Fatalf("SignModules failed: %v", err)
	}

	signedModule, err := os.ReadFile(modulePath)
	if err != nil {
		t.Fatalf("failed to read from module file after signing: %v", err)
	}

	magic := []byte(magicNumber)
	if !bytes.HasSuffix(signedModule, magic) {
		// Clamp before slicing so a file shorter than the marker reports
		// instead of panicking.
		tail := signedModule
		if len(tail) > len(magic) {
			tail = tail[len(tail)-len(magic):]
		}
		t.Errorf("signed module does not end with magic number:\nexpected: %q\ngot:      %q", magic, tail)
	}
	if !bytes.HasPrefix(signedModule, content) {
		// Show the leading bytes that should have matched the original content.
		head := signedModule
		if len(head) > len(content) {
			head = head[:len(content)]
		}
		t.Errorf("signed module does not start with original content:\nexpected: %q\ngot:      %q", content, head)
	}
}
// generateRSAPrivateKey creates a new RSA private key for use as a test
// fixture. It returns the key together with its PKCS #1 DER encoding wrapped
// in a PEM block of type "RSA PRIVATE KEY".
//
// A 2048-bit modulus is used. It is enough to exercise the signing code, and
// it is several times faster to generate than 4096 bits. Several tests in
// this file each generate their own key, so the difference adds up.
func generateRSAPrivateKey() (*rsa.PrivateKey, []byte, error) {
	privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to generate private key: %w", err)
	}
	pemBytes := pem.EncodeToMemory(&pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
	})
	return privateKey, pemBytes, nil
}
// generateECDSAPrivateKey creates a new P-256 ECDSA private key and returns it
// PEM-encoded. TestParsePrivateKey uses it as an example of an unsupported key.
//
// x509.MarshalECPrivateKey produces the SEC 1 / RFC 5915 encoding, whose PEM
// label is "EC PRIVATE KEY". The generic "PRIVATE KEY" label denotes PKCS #8
// and would mislabel these bytes.
func generateECDSAPrivateKey() ([]byte, error) {
	privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		return nil, fmt.Errorf("failed to generate private key: %w", err)
	}
	derBytes, err := x509.MarshalECPrivateKey(privateKey)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal ecdsa private key: %w", err)
	}
	return pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: derBytes}), nil
}
// generateCertificate creates a self-signed X.509 certificate for the public
// half of privateKey, signed by privateKey itself, and returns its DER bytes.
//
// The certificate is valid for 24 hours from the time of the call. NotAfter
// is set explicitly: if it were left as the zero time.Time, the encoded
// expiry would be in year 1, before NotBefore, so the certificate would
// already be expired.
func generateCertificate(privateKey *rsa.PrivateKey) ([]byte, error) {
	// Random serial number in [0, 2^128).
	serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
	if err != nil {
		return nil, fmt.Errorf("failed to generate serial number: %w", err)
	}
	notBefore := time.Now()
	template := x509.Certificate{
		SerialNumber: serialNumber,
		Subject: pkix.Name{
			Country:    []string{"US"},
			CommonName: "secure-boot-cert",
		},
		NotBefore:   notBefore,
		NotAfter:    notBefore.Add(24 * time.Hour),
		KeyUsage:    x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		DNSNames:    []string{"secure-boot.com"},
	}
	// Passing the template as both subject and issuer makes it self-signed.
	derBytes, err := x509.CreateCertificate(
		rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
	if err != nil {
		return nil, fmt.Errorf("failed to create certificate: %w", err)
	}
	return derBytes, nil
}