| package dkms |
| |
| import ( |
| "bytes" |
| "context" |
| "crypto" |
| "crypto/ecdsa" |
| "crypto/elliptic" |
| "crypto/rand" |
| "crypto/rsa" |
| "crypto/sha256" |
| "crypto/x509" |
| "crypto/x509/pkix" |
| "encoding/asn1" |
| "encoding/pem" |
| "fmt" |
| "hash/crc32" |
| "math/big" |
| "net" |
| "os" |
| "path" |
| "reflect" |
| "testing" |
| "time" |
| |
| kms "cloud.google.com/go/kms/apiv1" |
| "cloud.google.com/go/kms/apiv1/kmspb" |
| "github.com/google/go-cmp/cmp" |
| "google.golang.org/api/option" |
| "google.golang.org/grpc" |
| "google.golang.org/grpc/credentials/insecure" |
| "google.golang.org/protobuf/types/known/wrapperspb" |
| ) |
| |
| type fakeKeyManagementServer struct { |
| kmspb.UnimplementedKeyManagementServiceServer |
| } |
| |
| func (f *fakeKeyManagementServer) GetCryptoKeyVersion(ctx context.Context, req *kmspb.GetCryptoKeyVersionRequest) (*kmspb.CryptoKeyVersion, error) { |
| resp := &kmspb.CryptoKeyVersion{ |
| Algorithm: kmspb.CryptoKeyVersion_RSA_DECRYPT_OAEP_2048_SHA1, |
| } |
| return resp, nil |
| } |
| |
| func (f *fakeKeyManagementServer) AsymmetricSign(ctx context.Context, req *kmspb.AsymmetricSignRequest) (*kmspb.AsymmetricSignResponse, error) { |
| t := crc32.MakeTable(crc32.Castagnoli) |
| signature := []byte("asymmetric sign") |
| resp := &kmspb.AsymmetricSignResponse{ |
| Signature: signature, |
| Name: "key", |
| SignatureCrc32C: &wrapperspb.Int64Value{ |
| Value: int64(crc32.Checksum(signature, t)), |
| }, |
| VerifiedDigestCrc32C: true, |
| } |
| return resp, nil |
| } |
| |
| func TestParsePrivateKey(t *testing.T) { |
| ecdsaBytes, err := generateECDSAPrivateKey() |
| if err != nil { |
| t.Fatalf("failed to generate ecdsa private key: %v", err) |
| } |
| |
| _, rsaBytes, err := generateRSAPrivateKey() |
| if err != nil { |
| t.Fatalf("failed to generate rsa private key: %v", err) |
| } |
| testCases := []struct { |
| desc string |
| contents []byte |
| wantErr bool |
| expectedKey crypto.PrivateKey |
| }{ |
| { |
| "Unsupported key", |
| ecdsaBytes, |
| true, |
| nil, |
| }, |
| { |
| "Supported rsa key", |
| rsaBytes, |
| false, |
| &rsa.PrivateKey{}, |
| }, |
| { |
| "Empty bytes", |
| []byte{}, |
| true, |
| nil, |
| }, |
| { |
| "Nonsense bytes", |
| []byte("this is not a key"), |
| true, |
| nil, |
| }, |
| } |
| |
| for _, test := range testCases { |
| t.Run(test.desc, func(t *testing.T) { |
| privateKey, err := parsePrivateKey(test.contents) |
| if gotErr := err != nil; gotErr != test.wantErr { |
| t.Fatalf("got error: %v, want error: %v", err, test.wantErr) |
| } |
| if reflect.TypeOf(privateKey) != reflect.TypeOf(test.expectedKey) { |
| t.Fatalf("private key type did not match; expected: %T, got: %T", test.expectedKey, privateKey) |
| } |
| }) |
| } |
| } |
| |
| func TestGenerateModuleSignatureLocal(t *testing.T) { |
| privateKey, _, err := generateRSAPrivateKey() |
| if err != nil { |
| t.Fatalf("failed to generate rsa key: %v", err) |
| } |
| |
| content := []byte("module") |
| signer := &LocalSigner{key: privateKey} |
| hash := crypto.SHA256 |
| signature, err := signer.Sign(content, hash) |
| if err != nil { |
| t.Fatalf("%v", err) |
| } |
| |
| hashedDigest := sha256.Sum256(content) |
| // Ensure signature is RSA PKCS #1 v1.5 signature |
| publicKey := privateKey.PublicKey |
| if err := rsa.VerifyPKCS1v15(&publicKey, hash, hashedDigest[:], signature); err != nil { |
| t.Fatalf("failed to verify signature: %v", err) |
| } |
| } |
| |
| func TestGenerateModuleSignatureKMS(t *testing.T) { |
| ctx := context.Background() |
| fakeKmsClient, err := createFakeKMSClient() |
| if err != nil { |
| t.Fatalf("failed to create fake kms client: %v", err) |
| } |
| content := []byte("module") |
| |
| signer := &KmsSigner{ctx: ctx, client: fakeKmsClient, keyName: "key"} |
| hash := crypto.SHA256 |
| signature, err := signer.Sign(content, hash) |
| if err != nil { |
| t.Fatalf("%v", err) |
| } |
| |
| expectedSignature := []byte("asymmetric sign") |
| // Ensure signature is RSA PKCS #1 v1.5 signature |
| if diff := cmp.Diff(expectedSignature, signature); diff != "" { |
| t.Errorf("signature did not match expected value \nwant: %v\ngot: %v\ndiff: %v", |
| expectedSignature, signature, diff) |
| } |
| |
| } |
| |
| func TestAppendSignature(t *testing.T) { |
| signature := []byte("signature") |
| contents := []byte("module") |
| |
| signedBytes, err := appendSignature(signature, contents) |
| if err != nil { |
| t.Fatalf("%v", err) |
| } |
| |
| expectedBytes := [...]byte{ |
| // The following line is the bytes of the original module: "module" |
| 0x6D, 0x6F, 0x64, 0x75, 0x6c, 0x65, |
| // The following line is the bytes of the signature: "signature" |
| 0x73, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, |
| // The following lines are the bytes of module_signature struct |
| 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, |
| 0x00, 0x00, 0x00, 0x09, |
| // The following lines are the bytes of PKCS7 message: "~Module signature appended~\n" |
| 0x7e, 0x4d, 0x6f, 0x64, 0x75, 0x6c, 0x65, 0x20, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x20, 0x61, |
| 0x70, 0x70, 0x65, 0x6e, 0x64, 0x65, 0x64, 0x7e, 0xa, |
| } |
| |
| if diff := cmp.Diff(expectedBytes[:], signedBytes); diff != "" { |
| t.Errorf("signed module content doesn't match,\nwant: %v\ngot: %v\ndiff: %v", |
| expectedBytes, signedBytes, diff) |
| } |
| } |
| |
| func TestSignModule(t *testing.T) { |
| privateKey, _, err := generateRSAPrivateKey() |
| if err != nil { |
| t.Fatalf("failed to generate rsa key: %v", err) |
| } |
| |
| certBytes, err := generateCertificate(privateKey) |
| if err != nil { |
| t.Fatalf("failed to generate certificate: %v", err) |
| } |
| cert, err := x509.ParseCertificate(certBytes) |
| if err != nil { |
| t.Fatalf("failed to parse certificate: %v", err) |
| } |
| |
| fakeKmsClient, err := createFakeKMSClient() |
| if err != nil { |
| t.Fatalf("failed to create fake kms client: %v", err) |
| } |
| |
| testCases := []struct { |
| desc string |
| content []byte |
| signer ModuleSigner |
| }{ |
| { |
| "Local signing", |
| []byte("Signed locally"), |
| &LocalSigner{ |
| key: privateKey, |
| }, |
| }, |
| { |
| "KMS signing", |
| []byte("Signed by KMS"), |
| &KmsSigner{ |
| ctx: context.Background(), |
| client: fakeKmsClient, |
| keyName: "key", |
| }, |
| }, |
| } |
| |
| for _, test := range testCases { |
| t.Run(test.desc, func(t *testing.T) { |
| digest := digests["sha256"] |
| encryptAlg := &pkix.AlgorithmIdentifier{Algorithm: oidRsaEncryption, Parameters: asn1.NullRawValue} |
| signedModule, err := signModule(test.signer, test.content, digest, cert, encryptAlg) |
| if err != nil { |
| t.Fatalf("%v", err) |
| } |
| |
| if !bytes.HasPrefix(signedModule, test.content) { |
| t.Errorf("signed data is missing module content:\nexpected:%s\tgot:%s", test.content, signedModule[:len(test.content)]) |
| } |
| if !bytes.HasSuffix(signedModule, []byte(magicNumber)) { |
| t.Errorf("signed module does not end with magic number:\nexpected:%s\tgot:%s", |
| []byte(magicNumber), signedModule[len(signedModule)-len(magicNumber):]) |
| } |
| }) |
| } |
| } |
| |
| func TestSignModules(t *testing.T) { |
| tmpDir := t.TempDir() |
| |
| keyPath := path.Join(tmpDir, "rsaKey.pem") |
| key, keyBytes, err := generateRSAPrivateKey() |
| if err != nil { |
| t.Fatalf("failed to generate rsa key: %v", err) |
| } |
| if err = os.WriteFile(keyPath, keyBytes, 0600); err != nil { |
| t.Fatalf("failed to write rsa key to file (%s): %v", keyPath, err) |
| } |
| |
| certPath := path.Join(tmpDir, "cert.der") |
| certBytes, err := generateCertificate(key) |
| if err != nil { |
| t.Fatalf("failed to generate certificate: %v", err) |
| } |
| if err = os.WriteFile(certPath, certBytes, 0600); err != nil { |
| t.Fatalf("failed to write certificate content to file (%s): %v", keyPath, err) |
| } |
| |
| pkg := &Package{ |
| Name: "module", |
| Version: "1.0", |
| KernelVersion: "6.1.100", |
| Arch: "x86_64", |
| BuildId: "18244.151.14", |
| Trees: &Trees{Dkms: tmpDir}, |
| } |
| |
| module := Module{ |
| BuiltName: "module", |
| Package: pkg, |
| } |
| |
| modules := []Module{module} |
| modulePath := module.BuiltPath() |
| |
| err = os.MkdirAll(path.Dir(modulePath), 0755) |
| if err != nil { |
| t.Fatalf("failed to create module directory (%s): %v", modulePath, err) |
| } |
| |
| content := []byte("module") |
| if err := os.WriteFile(modulePath, content, 0600); err != nil { |
| t.Fatalf("failed to write temp module file (%s): %v", modulePath, err) |
| } |
| |
| signer := &LocalSigner{ |
| key: key, |
| } |
| |
| if err := SignModules(modules, certPath, "sha256", signer); err != nil { |
| t.Fatalf("%v", err) |
| } |
| |
| signedModule, err := os.ReadFile(modulePath) |
| if err != nil { |
| t.Fatalf("failed to read from module file after signing: %v", err) |
| } |
| |
| if !bytes.HasSuffix(signedModule, []byte(magicNumber)) { |
| t.Errorf("signed module does not end with magic number:\nexpected:%s\tgot:%s", |
| []byte(magicNumber), signedModule[len(signedModule)-len(magicNumber):]) |
| } |
| |
| if !bytes.HasPrefix(signedModule, content) { |
| t.Errorf("signed module does not start with original content:\nexpected:%s\tgot:%s", |
| content, signedModule[len(content):]) |
| } |
| } |
| |
| func createFakeKMSClient() (*kms.KeyManagementClient, error) { |
| ctx := context.Background() |
| |
| // Setup the fake server. |
| fakeKeyManagementServer := &fakeKeyManagementServer{} |
| l, err := net.Listen("tcp", "localhost:0") |
| if err != nil { |
| return nil, err |
| } |
| gsrv := grpc.NewServer() |
| kmspb.RegisterKeyManagementServiceServer(gsrv, fakeKeyManagementServer) |
| fakeServerAddr := l.Addr().String() |
| go func() { |
| if err := gsrv.Serve(l); err != nil { |
| panic(err) |
| } |
| }() |
| client, err := kms.NewKeyManagementClient(ctx, |
| option.WithEndpoint(fakeServerAddr), |
| option.WithoutAuthentication(), |
| option.WithGRPCDialOption(grpc.WithTransportCredentials(insecure.NewCredentials())), |
| ) |
| if err != nil { |
| return nil, err |
| } |
| return client, nil |
| } |
| |
// generateRSAPrivateKey creates a fresh RSA private key for tests. It returns
// the key together with its PKCS #1 PEM encoding (an "RSA PRIVATE KEY" block).
func generateRSAPrivateKey() (*rsa.PrivateKey, []byte, error) {
	// 2048 bits is ample for test fixtures, and key generation is several
	// times faster than at 4096 bits. Several tests in this file each
	// generate a key.
	const bits = 2048
	privateKey, err := rsa.GenerateKey(rand.Reader, bits)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to generate private key: %w", err)
	}

	pemBytes := pem.EncodeToMemory(&pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
	})
	return privateKey, pemBytes, nil
}
| |
// generateECDSAPrivateKey creates a fresh P-256 ECDSA private key and returns
// it PEM-encoded as an SEC 1 "EC PRIVATE KEY" block. Tests use it as an
// example of a key type that is not RSA.
func generateECDSAPrivateKey() ([]byte, error) {
	privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		return nil, fmt.Errorf("failed to generate private key: %w", err)
	}

	derBytes, err := x509.MarshalECPrivateKey(privateKey)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal ecdsa private key: %w", err)
	}

	// MarshalECPrivateKey emits SEC 1 DER, whose PEM label is "EC PRIVATE KEY".
	// The "PRIVATE KEY" label is reserved for PKCS #8, so using it would make
	// the fixture a malformed key rather than a well-formed non-RSA key.
	pemBytes := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: derBytes})
	return pemBytes, nil
}
| |
// generateCertificate creates a self-signed X.509 certificate for privateKey's
// public key, signed by privateKey itself, and returns it DER-encoded. The
// certificate is valid from now for one year.
func generateCertificate(privateKey *rsa.PrivateKey) ([]byte, error) {
	serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
	if err != nil {
		return nil, fmt.Errorf("failed to generate serial number: %w", err)
	}

	notBefore := time.Now()
	template := x509.Certificate{
		SerialNumber: serialNumber,
		Subject: pkix.Name{
			Country:    []string{"US"},
			CommonName: "secure-boot-cert",
		},
		NotBefore: notBefore,
		// Left unset, NotAfter would be the zero time (year 1). The
		// certificate's validity period would then end before it began.
		NotAfter: notBefore.Add(365 * 24 * time.Hour),

		KeyUsage:    x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		DNSNames:    []string{"secure-boot.com"},
	}

	// Passing the template as both subject and issuer makes it self-signed.
	derBytes, err := x509.CreateCertificate(
		rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
	if err != nil {
		return nil, fmt.Errorf("failed to create certificate: %w", err)
	}

	return derBytes, nil
}