blob: 13603edfaf016cbab2b84f12989217e646bee603 [file] [log] [blame] [edit]
package dkms
import (
"bytes"
"context"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/pem"
"fmt"
"hash/crc32"
"math/big"
"net"
"os"
"path"
"reflect"
"testing"
"time"
kms "cloud.google.com/go/kms/apiv1"
"cloud.google.com/go/kms/apiv1/kmspb"
"github.com/google/go-cmp/cmp"
"google.golang.org/api/option"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/protobuf/types/known/wrapperspb"
)
// fakeKeyManagementServer is an in-process stand-in for the Cloud KMS gRPC
// service used by the tests. Only GetCryptoKeyVersion and AsymmetricSign are
// overridden; every other RPC falls through to the embedded
// UnimplementedKeyManagementServiceServer.
type fakeKeyManagementServer struct {
	kmspb.UnimplementedKeyManagementServiceServer
}

// GetCryptoKeyVersion ignores the request and always reports a key version
// whose algorithm is RSA_DECRYPT_OAEP_2048_SHA1.
func (s *fakeKeyManagementServer) GetCryptoKeyVersion(ctx context.Context, req *kmspb.GetCryptoKeyVersionRequest) (*kmspb.CryptoKeyVersion, error) {
	return &kmspb.CryptoKeyVersion{
		Algorithm: kmspb.CryptoKeyVersion_RSA_DECRYPT_OAEP_2048_SHA1,
	}, nil
}

// AsymmetricSign ignores the request and answers with the fixed signature
// "asymmetric sign" for key "key". It also returns the CRC32C (Castagnoli)
// checksum of that signature and sets VerifiedDigestCrc32C, presumably so the
// client-side response integrity checks pass. Confirm against KmsSigner.
func (s *fakeKeyManagementServer) AsymmetricSign(ctx context.Context, req *kmspb.AsymmetricSignRequest) (*kmspb.AsymmetricSignResponse, error) {
	sig := []byte("asymmetric sign")
	sum := crc32.Checksum(sig, crc32.MakeTable(crc32.Castagnoli))
	return &kmspb.AsymmetricSignResponse{
		Name:                 "key",
		Signature:            sig,
		SignatureCrc32C:      wrapperspb.Int64(int64(sum)),
		VerifiedDigestCrc32C: true,
	}, nil
}
// TestParsePrivateKey feeds parsePrivateKey four inputs and checks, for each,
// whether an error is returned and the dynamic type of the returned key. The
// inputs are a PEM-encoded ECDSA key (expected to be rejected), a PEM-encoded
// RSA key (expected to yield *rsa.PrivateKey), empty input and arbitrary
// non-PEM text.
func TestParsePrivateKey(t *testing.T) {
	ecdsaBytes, err := generateECDSAPrivateKey()
	if err != nil {
		t.Fatalf("failed to generate ecdsa private key: %v", err)
	}
	_, rsaBytes, err := generateRSAPrivateKey()
	if err != nil {
		t.Fatalf("failed to generate rsa private key: %v", err)
	}

	type testCase struct {
		desc        string
		contents    []byte
		wantErr     bool
		expectedKey crypto.PrivateKey // nil when no key is expected
	}
	cases := []testCase{
		{desc: "Unsupported key", contents: ecdsaBytes, wantErr: true},
		{desc: "Supported rsa key", contents: rsaBytes, expectedKey: &rsa.PrivateKey{}},
		{desc: "Empty bytes", contents: []byte{}, wantErr: true},
		{desc: "Nonsense bytes", contents: []byte("this is not a key"), wantErr: true},
	}

	for _, tc := range cases {
		t.Run(tc.desc, func(t *testing.T) {
			got, err := parsePrivateKey(tc.contents)
			hadErr := err != nil
			if hadErr != tc.wantErr {
				t.Fatalf("got error: %v, want error: %v", err, tc.wantErr)
			}
			// Only the concrete key type is compared, not the key material.
			wantType, gotType := reflect.TypeOf(tc.expectedKey), reflect.TypeOf(got)
			if gotType != wantType {
				t.Fatalf("private key type did not match; expected: %T, got: %T", tc.expectedKey, got)
			}
		})
	}
}
// TestGenerateModuleSignatureLocal signs a small payload with a LocalSigner and
// verifies the result as an RSA PKCS #1 v1.5 signature over the payload's
// SHA-256 digest, using the matching public key.
func TestGenerateModuleSignatureLocal(t *testing.T) {
	key, _, err := generateRSAPrivateKey()
	if err != nil {
		t.Fatalf("failed to generate rsa key: %v", err)
	}

	const hashAlg = crypto.SHA256
	payload := []byte("module")
	sig, err := (&LocalSigner{key: key}).Sign(payload, hashAlg)
	if err != nil {
		t.Fatalf("%v", err)
	}

	// VerifyPKCS1v15 expects the pre-computed digest, not the raw payload.
	digest := sha256.Sum256(payload)
	pub := key.PublicKey
	if err := rsa.VerifyPKCS1v15(&pub, hashAlg, digest[:], sig); err != nil {
		t.Fatalf("failed to verify signature: %v", err)
	}
}
// TestGenerateModuleSignatureKMS signs a payload through a KmsSigner backed by
// the in-process fake KMS server. It checks that the signer returns the
// signature bytes produced by the fake server unchanged.
func TestGenerateModuleSignatureKMS(t *testing.T) {
	ctx := context.Background()
	fakeKmsClient, err := createFakeKMSClient()
	if err != nil {
		t.Fatalf("failed to create fake kms client: %v", err)
	}
	// Release the client's gRPC connection when the test finishes.
	t.Cleanup(func() {
		if err := fakeKmsClient.Close(); err != nil {
			t.Errorf("failed to close fake kms client: %v", err)
		}
	})

	content := []byte("module")
	signer := &KmsSigner{ctx: ctx, client: fakeKmsClient, keyName: "key"}
	hash := crypto.SHA256
	signature, err := signer.Sign(content, hash)
	if err != nil {
		t.Fatalf("%v", err)
	}

	// fakeKeyManagementServer.AsymmetricSign always returns this fixed
	// payload. This assertion therefore only shows that KmsSigner passes the
	// KMS response through. It does not verify any cryptographic property of
	// the signature.
	expectedSignature := []byte("asymmetric sign")
	if diff := cmp.Diff(expectedSignature, signature); diff != "" {
		t.Errorf("signature did not match expected value \nwant: %v\ngot: %v\ndiff: %v",
			expectedSignature, signature, diff)
	}
}
// TestAppendSignature pins the exact byte layout appendSignature produces for a
// 6-byte module and a 9-byte signature. The layout is the module bytes, then
// the signature bytes, then a 12-byte descriptor, then the ASCII magic marker.
func TestAppendSignature(t *testing.T) {
	signature := []byte("signature")
	contents := []byte("module")
	signedBytes, err := appendSignature(signature, contents)
	if err != nil {
		t.Fatalf("%v", err)
	}
	expectedBytes := []byte{
		// Original module contents: "module".
		0x6d, 0x6f, 0x64, 0x75, 0x6c, 0x65,
		// Signature bytes: "signature".
		0x73, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65,
		// 12-byte signature descriptor. The trailing 0x00000009 equals
		// len("signature"). Presumably this is the Linux kernel's
		// struct module_signature: algo, hash, id_type (2), signer_len,
		// key_id_len, 3 pad bytes, then a big-endian sig_len. Confirm against
		// appendSignature.
		0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00,
		0x00, 0x00, 0x00, 0x09,
		// Trailing magic marker, plain ASCII rather than a PKCS #7 message:
		// "~Module signature appended~\n".
		0x7e, 0x4d, 0x6f, 0x64, 0x75, 0x6c, 0x65, 0x20, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x20, 0x61,
		0x70, 0x70, 0x65, 0x6e, 0x64, 0x65, 0x64, 0x7e, 0x0a,
	}
	if diff := cmp.Diff(expectedBytes, signedBytes); diff != "" {
		t.Errorf("signed module content doesn't match,\nwant: %v\ngot: %v\ndiff: %v",
			expectedBytes, signedBytes, diff)
	}
}
// TestSignModule runs signModule with two signers: a local RSA signer and a KMS
// signer backed by the fake server. Both use a self-signed certificate and the
// "sha256" digest. Each result must start with the original module bytes and
// end with the module-signature magic marker.
func TestSignModule(t *testing.T) {
	privateKey, _, err := generateRSAPrivateKey()
	if err != nil {
		t.Fatalf("failed to generate rsa key: %v", err)
	}
	certBytes, err := generateCertificate(privateKey)
	if err != nil {
		t.Fatalf("failed to generate certificate: %v", err)
	}
	cert, err := x509.ParseCertificate(certBytes)
	if err != nil {
		t.Fatalf("failed to parse certificate: %v", err)
	}
	fakeKmsClient, err := createFakeKMSClient()
	if err != nil {
		t.Fatalf("failed to create fake kms client: %v", err)
	}
	// Cleanup runs after all subtests below have finished.
	t.Cleanup(func() {
		if err := fakeKmsClient.Close(); err != nil {
			t.Errorf("failed to close fake kms client: %v", err)
		}
	})

	testCases := []struct {
		desc    string
		content []byte
		signer  ModuleSigner
	}{
		{
			desc:    "Local signing",
			content: []byte("Signed locally"),
			signer:  &LocalSigner{key: privateKey},
		},
		{
			desc:    "KMS signing",
			content: []byte("Signed by KMS"),
			signer: &KmsSigner{
				ctx:     context.Background(),
				client:  fakeKmsClient,
				keyName: "key",
			},
		},
	}
	for _, test := range testCases {
		t.Run(test.desc, func(t *testing.T) {
			digest := digests["sha256"]
			encryptAlg := &pkix.AlgorithmIdentifier{Algorithm: oidRsaEncryption, Parameters: asn1.NullRawValue}
			signedModule, err := signModule(test.signer, test.content, digest, cert, encryptAlg)
			if err != nil {
				t.Fatalf("%v", err)
			}
			// Print the whole output, quoted because it contains binary
			// signature bytes, instead of slicing it. Slicing an output shorter
			// than the expected prefix or suffix would panic.
			if !bytes.HasPrefix(signedModule, test.content) {
				t.Errorf("signed data is missing module content:\nexpected prefix: %q\ngot: %q", test.content, signedModule)
			}
			if !bytes.HasSuffix(signedModule, []byte(magicNumber)) {
				t.Errorf("signed module does not end with magic number:\nexpected suffix: %q\ngot: %q",
					magicNumber, signedModule)
			}
		})
	}
}
// TestSignModules exercises the end-to-end SignModules flow:
//  1. It writes an RSA key, a DER certificate and an unsigned module file into
//     a temporary DKMS tree.
//  2. It runs SignModules over the module.
//  3. It checks that the module file was rewritten in place. The file must
//     still start with the original content and must end with the
//     module-signature magic marker.
func TestSignModules(t *testing.T) {
	tmpDir := t.TempDir()

	key, keyBytes, err := generateRSAPrivateKey()
	if err != nil {
		t.Fatalf("failed to generate rsa key: %v", err)
	}
	// NOTE(review): the key file is written but keyPath is never passed to
	// SignModules; the signer below holds the key directly. Confirm whether
	// this file is still needed.
	keyPath := path.Join(tmpDir, "rsaKey.pem")
	if err := os.WriteFile(keyPath, keyBytes, 0600); err != nil {
		t.Fatalf("failed to write rsa key to file (%s): %v", keyPath, err)
	}

	certBytes, err := generateCertificate(key)
	if err != nil {
		t.Fatalf("failed to generate certificate: %v", err)
	}
	certPath := path.Join(tmpDir, "cert.der")
	if err := os.WriteFile(certPath, certBytes, 0600); err != nil {
		t.Fatalf("failed to write certificate content to file (%s): %v", certPath, err)
	}

	pkg := &Package{
		Name:          "module",
		Version:       "1.0",
		KernelVersion: "6.1.100",
		Arch:          "x86_64",
		BuildId:       "18244.151.14",
		Trees:         &Trees{Dkms: tmpDir},
	}
	module := Module{
		BuiltName: "module",
		Package:   pkg,
	}
	modules := []Module{module}

	modulePath := module.BuiltPath()
	moduleDir := path.Dir(modulePath)
	if err := os.MkdirAll(moduleDir, 0755); err != nil {
		t.Fatalf("failed to create module directory (%s): %v", moduleDir, err)
	}
	content := []byte("module")
	if err := os.WriteFile(modulePath, content, 0600); err != nil {
		t.Fatalf("failed to write temp module file (%s): %v", modulePath, err)
	}

	signer := &LocalSigner{key: key}
	if err := SignModules(modules, certPath, "sha256", signer); err != nil {
		t.Fatalf("%v", err)
	}

	signedModule, err := os.ReadFile(modulePath)
	if err != nil {
		t.Fatalf("failed to read from module file after signing: %v", err)
	}
	// Print the whole file (quoted, since it contains binary signature bytes)
	// rather than slicing it: slicing a file shorter than the expected
	// prefix/suffix would panic instead of failing cleanly.
	if !bytes.HasSuffix(signedModule, []byte(magicNumber)) {
		t.Errorf("signed module does not end with magic number:\nexpected suffix: %q\ngot: %q",
			magicNumber, signedModule)
	}
	if !bytes.HasPrefix(signedModule, content) {
		t.Errorf("signed module does not start with original content:\nexpected prefix: %q\ngot: %q",
			content, signedModule)
	}
}
// createFakeKMSClient does the following:
//   - starts a fakeKeyManagementServer on an ephemeral localhost port;
//   - returns a KMS client connected to that server over plaintext gRPC, with
//     authentication disabled.
//
// On success the server keeps running for the rest of the test binary. Callers
// should Close the returned client when done. If the client cannot be created,
// the server (and with it the listener) is stopped before the error is
// returned, so nothing is leaked.
func createFakeKMSClient() (*kms.KeyManagementClient, error) {
	ctx := context.Background()

	l, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		return nil, fmt.Errorf("listening for fake kms server: %w", err)
	}
	gsrv := grpc.NewServer()
	kmspb.RegisterKeyManagementServiceServer(gsrv, &fakeKeyManagementServer{})
	go func() {
		// Serve returns grpc.ErrServerStopped when Stop (error path below)
		// wins the race against it. Any other error is unexpected and fatal.
		if err := gsrv.Serve(l); err != nil && err != grpc.ErrServerStopped {
			panic(err)
		}
	}()

	client, err := kms.NewKeyManagementClient(ctx,
		option.WithEndpoint(l.Addr().String()),
		option.WithoutAuthentication(),
		option.WithGRPCDialOption(grpc.WithTransportCredentials(insecure.NewCredentials())),
	)
	if err != nil {
		gsrv.Stop() // also closes the listener
		return nil, fmt.Errorf("creating fake kms client: %w", err)
	}
	return client, nil
}
// generateRSAPrivateKey creates a new 4096-bit RSA private key. It returns the
// key together with its PKCS #1 encoding wrapped in an "RSA PRIVATE KEY" PEM
// block.
func generateRSAPrivateKey() (*rsa.PrivateKey, []byte, error) {
	return generateRSAPrivateKeyOfSize(4096)
}

// generateRSAPrivateKeyOfSize is generateRSAPrivateKey with a caller-chosen
// modulus size in bits. Generation errors are wrapped, so callers can inspect
// them with errors.Is and errors.As.
func generateRSAPrivateKeyOfSize(bits int) (*rsa.PrivateKey, []byte, error) {
	privateKey, err := rsa.GenerateKey(rand.Reader, bits)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to generate private key: %w", err)
	}
	// Convert the private key to PEM format.
	pemBytes := pem.EncodeToMemory(&pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
	})
	return privateKey, pemBytes, nil
}
// generateECDSAPrivateKey creates a new P-256 ECDSA private key and returns its
// SEC 1 encoding (x509.MarshalECPrivateKey) wrapped in a PEM block. Errors are
// wrapped with %w.
//
// NOTE(review): the PEM block type is "PRIVATE KEY", the label conventionally
// used for PKCS #8, but the payload is SEC 1 (normally labelled
// "EC PRIVATE KEY"). TestParsePrivateKey only needs a key that
// parsePrivateKey rejects, so the label is left unchanged. Confirm whether the
// mismatch is intentional before changing it.
func generateECDSAPrivateKey() ([]byte, error) {
	privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		return nil, fmt.Errorf("failed to generate private key: %w", err)
	}
	derBytes, err := x509.MarshalECPrivateKey(privateKey)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal ecdsa private key: %w", err)
	}
	return pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: derBytes}), nil
}
// generateCertificate creates a self-signed X.509 certificate for privateKey's
// public key, signed by privateKey itself, and returns its DER encoding. Errors
// are wrapped with %w.
//
// The certificate is valid from now for one year. NotAfter must be set
// explicitly. At its zero value (January 1, year 1) the validity window would
// end before it starts.
func generateCertificate(privateKey *rsa.PrivateKey) ([]byte, error) {
	// Random serial number in [0, 2^128).
	serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
	if err != nil {
		return nil, fmt.Errorf("failed to generate serial number: %w", err)
	}
	notBefore := time.Now()
	template := x509.Certificate{
		SerialNumber: serialNumber,
		Subject: pkix.Name{
			Country:    []string{"US"},
			CommonName: "secure-boot-cert",
		},
		NotBefore: notBefore,
		NotAfter:  notBefore.AddDate(1, 0, 0),
		// NOTE(review): ServerAuth and DNSNames look like TLS-style settings
		// rather than code-signing ones. They are kept as-is; confirm they are
		// not relied on elsewhere before changing them.
		KeyUsage:    x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		DNSNames:    []string{"secure-boot.com"},
	}
	// Template doubles as parent: the certificate is self-signed.
	derBytes, err := x509.CreateCertificate(
		rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
	if err != nil {
		return nil, fmt.Errorf("failed to create certificate: %w", err)
	}
	return derBytes, nil
}