blob: a59773f973efaec2eb4309b6e1a5faacd51a0bf8 [file] [log] [blame]
// Copyright 2021 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "missive/encryption/encryption_module.h"
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <base/bind.h>
#include <base/containers/flat_map.h>
#include <base/hash/hash.h>
#include <base/rand_util.h>
#include <base/strings/strcat.h>
#include <base/synchronization/waitable_event.h>
#include <base/feature_list.h>
#include <base/test/task_environment.h>
#include <base/time/time.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "missive/encryption/decryption.h"
#include "missive/encryption/encryption.h"
#include "missive/encryption/encryption_module_interface.h"
#include "missive/encryption/primitives.h"
#include "missive/encryption/scoped_encryption_feature.h"
#include "missive/encryption/testing_primitives.h"
#include "missive/proto/record.pb.h"
#include "missive/util/status.h"
#include "missive/util/status_macros.h"
#include "missive/util/statusor.h"
#include "missive/util/test_support_callbacks.h"
using ::testing::Eq;
using ::testing::Ne;
namespace reporting {
namespace {
// Test fixture that owns an encryption module together with a test decryptor
// and wraps their callback-based APIs into helpers that return the result
// directly to the calling test.
class EncryptionModuleTest : public ::testing::Test {
 protected:
  EncryptionModuleTest() = default;

  // Creates the encryption module and the test decryptor; fails the test if
  // the decryptor cannot be created.
  void SetUp() override {
    encryption_module_ = EncryptionModule::Create();

    auto created_decryptor = test::Decryptor::Create();
    ASSERT_OK(created_decryptor.status()) << created_decryptor.status();
    decryptor_ = std::move(created_decryptor.ValueOrDie());
  }

  // Hands `data` to the encryption module and returns whatever the module
  // reports through its callback (an encrypted record or an error status).
  StatusOr<EncryptedRecord> EncryptSync(base::StringPiece data) {
    test::TestEvent<StatusOr<EncryptedRecord>> encrypted_event;
    encryption_module_->EncryptRecord(data, encrypted_event.cb());
    return encrypted_event.result();
  }

  // Decrypts `encrypted.second` using the shared secret in `encrypted.first`.
  // Returns the decrypted text, or the first error reported while opening,
  // filling or closing the decryption handle.
  StatusOr<std::string> DecryptSync(
      std::pair<std::string /*shared_secret*/, std::string /*encrypted_data*/>
          encrypted) {
    // Open a decryption handle, passing in the shared secret.
    test::TestEvent<StatusOr<test::Decryptor::Handle*>> opened;
    decryptor_->OpenRecord(encrypted.first, opened.cb());
    auto opened_result = opened.result();
    RETURN_IF_ERROR(opened_result.status());
    test::Decryptor::Handle* const handle = opened_result.ValueOrDie();

    // Feed the encrypted payload to the handle.
    test::TestEvent<Status> added;
    handle->AddToRecord(encrypted.second, added.cb());
    RETURN_IF_ERROR(added.result());

    // Close the handle; on success the callback copies the decrypted text
    // into `plaintext` before signalling completion.
    std::string plaintext;
    test::TestEvent<Status> closed;
    handle->CloseRecord(base::BindOnce(
        [](std::string* output, base::OnceCallback<void(Status)> done_cb,
           StatusOr<base::StringPiece> result) {
          if (result.ok()) {
            *output = std::string(result.ValueOrDie());
            std::move(done_cb).Run(Status::StatusOK());
            return;
          }
          std::move(done_cb).Run(result.status());
        },
        base::Unretained(&plaintext), closed.cb()));
    RETURN_IF_ERROR(closed.result());
    return plaintext;
  }

  // Looks up the private key recorded for `public_key_id` and uses it to
  // decrypt `encrypted_key`, returning the resulting shared secret.
  StatusOr<std::string> DecryptMatchingSecret(
      Encryptor::PublicKeyId public_key_id, base::StringPiece encrypted_key) {
    // Fetch the private key that was recorded for this public key id.
    test::TestEvent<StatusOr<std::string>> private_key_event;
    decryptor_->RetrieveMatchingPrivateKey(public_key_id,
                                           private_key_event.cb());
    ASSIGN_OR_RETURN(std::string matching_private_key,
                     private_key_event.result());

    // Turn the private key plus the encrypted key into the shared secret.
    ASSIGN_OR_RETURN(
        std::string secret,
        decryptor_->DecryptSecret(matching_private_key, encrypted_key));
    return secret;
  }

  // Generates a fresh key pair, records it with the decryptor and installs
  // the public value (with the id the decryptor returned) in the encryption
  // module. Returns the first error encountered, or OK.
  Status AddNewKeyPair() {
    // Raw key material produced by the test primitive.
    uint8_t public_value_bytes[kKeySize];
    uint8_t private_key_bytes[kKeySize];
    test::GenerateEncryptionKeyPair(private_key_bytes, public_value_bytes);

    const std::string private_key(
        reinterpret_cast<const char*>(private_key_bytes), kKeySize);
    const std::string public_value(
        reinterpret_cast<const char*>(public_value_bytes), kKeySize);

    // Record the pair with the decryptor to obtain its public key id.
    test::TestEvent<StatusOr<Encryptor::PublicKeyId>> pair_recorded;
    decryptor_->RecordKeyPair(private_key, public_value, pair_recorded.cb());
    ASSIGN_OR_RETURN(Encryptor::PublicKeyId recorded_key_id,
                     pair_recorded.result());

    // Make the encryption module use the new public value from now on.
    test::TestEvent<Status> key_installed;
    encryption_module_->UpdateAsymmetricKey(public_value, recorded_key_id,
                                            key_installed.cb());
    return key_installed.result();
  }

  // Task environment with mocked time, so tests can fast-forward the clock.
  base::test::TaskEnvironment task_environment_{
      base::test::TaskEnvironment::TimeSource::MOCK_TIME};
  scoped_refptr<EncryptionModuleInterface> encryption_module_;
  scoped_refptr<test::Decryptor> decryptor_;

 private:
  // Turns the encryption feature on for the lifetime of the fixture.
  test::ScopedEncryptionFeature encryption_feature_{/*enable=*/true};
};
// Round trip of a single record: encrypt with a newly registered key pair,
// recover the shared secret, decrypt, and compare with the original text.
TEST_F(EncryptionModuleTest, EncryptAndDecrypt) {
  constexpr char kTestString[] = "ABCDEF";

  // Make a new key pair known to the decryptor and the encryption module.
  ASSERT_OK(AddNewKeyPair());

  // Encrypt the test string; the module uses the key installed above.
  const auto encryption_outcome = EncryptSync(kTestString);
  ASSERT_OK(encryption_outcome.status()) << encryption_outcome.status();
  const EncryptedRecord& encrypted_record = encryption_outcome.ValueOrDie();

  // Recover the shared secret from the key id and encrypted key that travel
  // with the record.
  auto secret_outcome = DecryptMatchingSecret(
      encrypted_record.encryption_info().public_key_id(),
      encrypted_record.encryption_info().encryption_key());
  ASSERT_OK(secret_outcome.status()) << secret_outcome.status();

  // Decrypt the payload and check that it matches the input.
  const auto decryption_outcome =
      DecryptSync(std::make_pair(secret_outcome.ValueOrDie(),
                                 encrypted_record.encrypted_wrapped_record()));
  ASSERT_OK(decryption_outcome.status()) << decryption_outcome.status();
  EXPECT_THAT(decryption_outcome.ValueOrDie(), ::testing::StrEq(kTestString));
}
// Verifies that with the encryption feature disabled the module passes the
// record through unchanged and attaches no encryption info.
TEST_F(EncryptionModuleTest, EncryptionDisabled) {
  constexpr char kTestString[] = "ABCDEF";

  // Disable encryption for the rest of this test. The variable is a plain
  // local, so it carries no trailing underscore and does not hide the
  // fixture's own `encryption_feature_` member.
  test::ScopedEncryptionFeature disabled_encryption_feature{/*enable=*/false};

  // Encrypt the test string; stream the status so a failure is diagnosable,
  // as the other tests in this file do.
  const auto encrypted_result = EncryptSync(kTestString);
  ASSERT_OK(encrypted_result.status()) << encrypted_result.status();

  // Expect the result to be identical to the original record,
  // and have no encryption_info.
  EXPECT_EQ(encrypted_result.ValueOrDie().encrypted_wrapped_record(),
            kTestString);
  EXPECT_FALSE(encrypted_result.ValueOrDie().has_encryption_info());
}
// Walks the module through its key states: no key, fresh key, key after a
// short wait, key after a long wait, and a replacement key. At each stage it
// checks has_encryption_key() / need_encryption_key() and tries to encrypt.
TEST_F(EncryptionModuleTest, PublicKeyUpdate) {
  constexpr char kTestString[] = "ABCDEF";

  // Stage 0: no key has been installed; encryption is expected to report
  // NOT_FOUND.
  ASSERT_FALSE(encryption_module_->has_encryption_key());
  ASSERT_TRUE(encryption_module_->need_encryption_key());
  {
    const auto outcome = EncryptSync(kTestString);
    EXPECT_EQ(outcome.status().error_code(), error::NOT_FOUND);
  }

  // Stage 1: install the first key pair; the key is present and no new one is
  // requested, and encryption succeeds.
  ASSERT_OK(AddNewKeyPair());
  ASSERT_TRUE(encryption_module_->has_encryption_key());
  ASSERT_FALSE(encryption_module_->need_encryption_key());
  {
    const auto outcome = EncryptSync(kTestString);
    ASSERT_OK(outcome.status()) << outcome.status();
  }

  // Stage 2: after 8 hours of mock time nothing has changed.
  task_environment_.FastForwardBy(base::TimeDelta::FromHours(8));
  ASSERT_TRUE(encryption_module_->has_encryption_key());
  ASSERT_FALSE(encryption_module_->need_encryption_key());
  {
    const auto outcome = EncryptSync(kTestString);
    ASSERT_OK(outcome.status()) << outcome.status();
  }

  // Stage 3: after one more day the key still works, but the module now asks
  // for a new one.
  task_environment_.FastForwardBy(base::TimeDelta::FromDays(1));
  ASSERT_TRUE(encryption_module_->has_encryption_key());
  ASSERT_TRUE(encryption_module_->need_encryption_key());
  {
    const auto outcome = EncryptSync(kTestString);
    ASSERT_OK(outcome.status()) << outcome.status();
  }

  // Stage 4: installing another key pair clears the request for a new key.
  ASSERT_OK(AddNewKeyPair());
  ASSERT_TRUE(encryption_module_->has_encryption_key());
  ASSERT_FALSE(encryption_module_->need_encryption_key());
  {
    const auto outcome = EncryptSync(kTestString);
    ASSERT_OK(outcome.status()) << outcome.status();
  }
}
// Encrypts six strings while rotating through three key pairs (3 + 2 + 1
// records per key), then decrypts every record with the private key matching
// the key id stored in it and checks the text against the original.
TEST_F(EncryptionModuleTest, EncryptAndDecryptMultiple) {
  constexpr const char* kTestStrings[] = {"Rec1",    "Rec22",    "Rec333",
                                          "Rec4444", "Rec55555", "Rec666666"};

  // Encrypted records, stored in the same order as `kTestStrings`.
  std::vector<EncryptedRecord> encrypted_records;
  // Index of the next test string to encrypt.
  size_t next = 0;

  // 1. Register first key pair.
  ASSERT_OK(AddNewKeyPair());

  // 2. Encrypt the first three test strings.
  for (; next < 3; ++next) {
    const auto encryption = EncryptSync(kTestStrings[next]);
    ASSERT_OK(encryption.status()) << encryption.status();
    encrypted_records.push_back(encryption.ValueOrDie());
  }

  // 3. Register second key pair.
  ASSERT_OK(AddNewKeyPair());

  // 4. Encrypt the next two test strings.
  for (; next < 5; ++next) {
    const auto encryption = EncryptSync(kTestStrings[next]);
    ASSERT_OK(encryption.status()) << encryption.status();
    encrypted_records.push_back(encryption.ValueOrDie());
  }

  // 5. Register third key pair.
  ASSERT_OK(AddNewKeyPair());

  // 6. Encrypt the last test string.
  for (; next < 6; ++next) {
    const auto encryption = EncryptSync(kTestStrings[next]);
    ASSERT_OK(encryption.status()) << encryption.status();
    encrypted_records.push_back(encryption.ValueOrDie());
  }

  // 7. Decrypt every record and compare it with its source string.
  size_t position = 0;
  for (const EncryptedRecord& record : encrypted_records) {
    // Recover the shared secret with the matching private asymmetric key.
    auto secret = DecryptMatchingSecret(
        record.encryption_info().public_key_id(),
        record.encryption_info().encryption_key());
    ASSERT_OK(secret.status()) << secret.status();

    // Decrypt the payload.
    const auto decryption = DecryptSync(
        std::make_pair(secret.ValueOrDie(), record.encrypted_wrapped_record()));
    ASSERT_OK(decryption.status()) << decryption.status();

    // Verify match.
    EXPECT_THAT(decryption.ValueOrDie(),
                ::testing::StrEq(kTestStrings[position]));
    ++position;
  }
}
// Exercises encryption and decryption through chains of tasks posted to the
// thread pool: three key pairs are registered, each of six test strings is
// encrypted with a randomly chosen key pair, every resulting record is
// decrypted, and the decrypted text is compared with the original string.
TEST_F(EncryptionModuleTest, EncryptAndDecryptMultipleParallel) {
// Context of single encryption. Self-destructs upon completion or failure.
// Steps run as a chain of posted tasks:
//   Start() -> SetPublicKey() -> EncryptRecord() -> Respond().
// A failure to set the key goes straight to Respond(). The object is created
// with `new` and deleted only inside Respond().
class SingleEncryptionContext {
public:
// Copies `test_string` and `public_key` into owned strings and takes
// ownership of the `response` callback.
SingleEncryptionContext(
base::StringPiece test_string,
base::StringPiece public_key,
Encryptor::PublicKeyId public_key_id,
scoped_refptr<EncryptionModuleInterface> encryption_module,
base::OnceCallback<void(StatusOr<EncryptedRecord>)> response)
: test_string_(test_string),
public_key_(public_key),
public_key_id_(public_key_id),
encryption_module_(encryption_module),
response_(std::move(response)) {}
SingleEncryptionContext(const SingleEncryptionContext& other) = delete;
SingleEncryptionContext& operator=(const SingleEncryptionContext& other) =
delete;
// Checks that `response_` is no longer set at destruction time; presumably
// running it via std::move(...).Run() in Respond() leaves it null.
~SingleEncryptionContext() {
DCHECK(!response_) << "Self-destruct without prior response";
}
// Kicks off the chain by posting SetPublicKey() to the thread pool.
void Start() {
base::ThreadPool::PostTask(
FROM_HERE, base::BindOnce(&SingleEncryptionContext::SetPublicKey,
base::Unretained(this)));
}
private:
// Final step: delivers `result` through the response callback and then
// destroys this context. Nothing may touch `this` after the call.
void Respond(StatusOr<EncryptedRecord> result) {
std::move(response_).Run(result);
delete this;
}
// Installs this context's public key and id in the encryption module. On
// error responds immediately; otherwise posts EncryptRecord().
void SetPublicKey() {
encryption_module_->UpdateAsymmetricKey(
public_key_, public_key_id_,
base::BindOnce(
[](SingleEncryptionContext* self, Status status) {
if (!status.ok()) {
self->Respond(status);
return;
}
base::ThreadPool::PostTask(
FROM_HERE,
base::BindOnce(&SingleEncryptionContext::EncryptRecord,
base::Unretained(self)));
},
base::Unretained(this)));
}
// Asks the module to encrypt `test_string_` and posts Respond() with the
// outcome, whether it is a record or an error.
void EncryptRecord() {
encryption_module_->EncryptRecord(
test_string_,
base::BindOnce(
[](SingleEncryptionContext* self,
StatusOr<EncryptedRecord> encryption_result) {
base::ThreadPool::PostTask(
FROM_HERE,
base::BindOnce(&SingleEncryptionContext::Respond,
base::Unretained(self), encryption_result));
},
base::Unretained(this)));
}
private:
// Inputs captured at construction time.
const std::string test_string_;
const std::string public_key_;
const Encryptor::PublicKeyId public_key_id_;
const scoped_refptr<EncryptionModuleInterface> encryption_module_;
// Callback that receives the final result; consumed by Respond().
base::OnceCallback<void(StatusOr<EncryptedRecord>)> response_;
};
// Context of single decryption. Self-destructs upon completion or failure.
// Steps run as a chain of posted tasks:
//   Start() -> RetrieveMatchingPrivateKey() -> DecryptSharedSecret() ->
//   OpenRecord() -> AddToRecord() -> CloseRecord() -> Respond().
// Any failing step calls Respond() with the error status. The object is
// created with `new` and deleted only inside Respond().
class SingleDecryptionContext {
public:
// Stores a copy of `encrypted_record` and takes ownership of `response`.
SingleDecryptionContext(
const EncryptedRecord& encrypted_record,
scoped_refptr<test::Decryptor> decryptor,
base::OnceCallback<void(StatusOr<base::StringPiece>)> response)
: encrypted_record_(encrypted_record),
decryptor_(decryptor),
response_(std::move(response)) {}
SingleDecryptionContext(const SingleDecryptionContext& other) = delete;
SingleDecryptionContext& operator=(const SingleDecryptionContext& other) =
delete;
// Checks that `response_` is no longer set at destruction time; presumably
// running it via std::move(...).Run() in Respond() leaves it null.
~SingleDecryptionContext() {
DCHECK(!response_) << "Self-destruct without prior response";
}
// Kicks off the chain by posting RetrieveMatchingPrivateKey() to the thread
// pool.
void Start() {
base::ThreadPool::PostTask(
FROM_HERE,
base::BindOnce(&SingleDecryptionContext::RetrieveMatchingPrivateKey,
base::Unretained(this)));
}
private:
// Final step: delivers `result` through the response callback and then
// destroys this context. Nothing may touch `this` after the call.
void Respond(StatusOr<base::StringPiece> result) {
std::move(response_).Run(result);
delete this;
}
// Asks the decryptor for the private key matching the record's public key
// id, then posts DecryptSharedSecret() with that key.
void RetrieveMatchingPrivateKey() {
// Retrieve private key that matches public key hash.
decryptor_->RetrieveMatchingPrivateKey(
encrypted_record_.encryption_info().public_key_id(),
base::BindOnce(
[](SingleDecryptionContext* self,
StatusOr<std::string> private_key_result) {
if (!private_key_result.ok()) {
self->Respond(private_key_result.status());
return;
}
base::ThreadPool::PostTask(
FROM_HERE,
base::BindOnce(
&SingleDecryptionContext::DecryptSharedSecret,
base::Unretained(self),
private_key_result.ValueOrDie()));
},
base::Unretained(this)));
}
// Computes the shared secret from `private_key` and the record's encrypted
// key (a direct call, no callback), then posts OpenRecord() with it.
void DecryptSharedSecret(base::StringPiece private_key) {
// Decrypt shared secret from private key and peer public key.
auto shared_secret_result = decryptor_->DecryptSecret(
private_key, encrypted_record_.encryption_info().encryption_key());
if (!shared_secret_result.ok()) {
Respond(shared_secret_result.status());
return;
}
base::ThreadPool::PostTask(
FROM_HERE, base::BindOnce(&SingleDecryptionContext::OpenRecord,
base::Unretained(this),
shared_secret_result.ValueOrDie()));
}
// Opens a decryption handle for `shared_secret`, then posts AddToRecord()
// with the handle.
void OpenRecord(base::StringPiece shared_secret) {
decryptor_->OpenRecord(
shared_secret,
base::BindOnce(
[](SingleDecryptionContext* self,
StatusOr<test::Decryptor::Handle*> handle_result) {
if (!handle_result.ok()) {
self->Respond(handle_result.status());
return;
}
base::ThreadPool::PostTask(
FROM_HERE,
base::BindOnce(
&SingleDecryptionContext::AddToRecord,
base::Unretained(self),
base::Unretained(handle_result.ValueOrDie())));
},
base::Unretained(this)));
}
// Feeds the record's encrypted payload to `handle`, then posts
// CloseRecord() with the same handle.
void AddToRecord(test::Decryptor::Handle* handle) {
handle->AddToRecord(
encrypted_record_.encrypted_wrapped_record(),
base::BindOnce(
[](SingleDecryptionContext* self, test::Decryptor::Handle* handle,
Status status) {
if (!status.ok()) {
self->Respond(status);
return;
}
base::ThreadPool::PostTask(
FROM_HERE,
base::BindOnce(&SingleDecryptionContext::CloseRecord,
base::Unretained(self),
base::Unretained(handle)));
},
base::Unretained(this), base::Unretained(handle)));
}
// Closes `handle` and forwards its result (decrypted text or error) to
// Respond() directly from the callback.
void CloseRecord(test::Decryptor::Handle* handle) {
handle->CloseRecord(base::BindOnce(
[](SingleDecryptionContext* self,
StatusOr<base::StringPiece> decryption_result) {
self->Respond(decryption_result);
},
base::Unretained(this)));
}
private:
// Inputs captured at construction time.
const EncryptedRecord encrypted_record_;
const scoped_refptr<test::Decryptor> decryptor_;
// Callback that receives the final result; consumed by Respond().
base::OnceCallback<void(StatusOr<base::StringPiece>)> response_;
};
// NOTE(review): <array> is not among this file's direct includes; presumably
// it arrives transitively - confirm.
constexpr std::array<const char*, 6> kTestStrings = {
"Rec1", "Rec22", "Rec333", "Rec4444", "Rec55555", "Rec666666"};
// Key material for three key pairs, produced below by
// test::GenerateEncryptionKeyPair, plus the ids the decryptor assigns to them.
std::vector<std::string> private_key_strings;
std::vector<std::string> public_value_strings;
std::vector<Encryptor::PublicKeyId> public_value_ids;
for (size_t i = 0; i < 3; ++i) {
// Generate new pair of private key and public value.
uint8_t out_public_value[kKeySize];
uint8_t out_private_key[kKeySize];
test::GenerateEncryptionKeyPair(out_private_key, out_public_value);
private_key_strings.emplace_back(
reinterpret_cast<const char*>(out_private_key), kKeySize);
public_value_strings.emplace_back(
reinterpret_cast<const char*>(out_public_value), kKeySize);
}
// Register all key pairs for decryption.
// One task per key pair is posted to the thread pool; each task gets its own
// copies of the key strings and a reference to the decryptor.
std::vector<test::TestEvent<StatusOr<Encryptor::PublicKeyId>>> record_results(
public_value_strings.size());
for (size_t i = 0; i < public_value_strings.size(); ++i) {
base::ThreadPool::PostTask(
FROM_HERE,
base::BindOnce(
[](base::StringPiece private_key_string,
base::StringPiece public_key_string,
scoped_refptr<test::Decryptor> decryptor,
base::OnceCallback<void(StatusOr<Encryptor::PublicKeyId>)>
done_cb) {
decryptor->RecordKeyPair(private_key_string, public_key_string,
std::move(done_cb));
},
private_key_strings[i], public_value_strings[i], decryptor_,
record_results[i].cb()));
}
// Verify registration success.
// The ids are collected in the same order as the key strings, so index i of
// `public_value_ids` belongs to index i of `public_value_strings`.
for (auto& record_result : record_results) {
const auto result = record_result.result();
ASSERT_OK(result.status()) << result.status();
public_value_ids.push_back(result.ValueOrDie());
}
// Encrypt all records in parallel.
// Every context is started before any result is read; each context deletes
// itself once it has responded.
std::vector<test::TestEvent<StatusOr<EncryptedRecord>>> results(
kTestStrings.size());
for (size_t i = 0; i < kTestStrings.size(); ++i) {
// Choose random key pair.
size_t i_key_pair = base::RandInt(0, public_value_strings.size() - 1);
(new SingleEncryptionContext(
kTestStrings[i], public_value_strings[i_key_pair],
public_value_ids[i_key_pair], encryption_module_, results[i].cb()))
->Start();
}
// Decrypt all records in parallel.
// The adapter callback below converts the StringPiece result into an owned
// std::string before handing it to the TestEvent.
std::vector<test::TestEvent<StatusOr<std::string>>> decryption_results(
kTestStrings.size());
for (size_t i = 0; i < results.size(); ++i) {
// Verify encryption success.
const auto result = results[i].result();
ASSERT_OK(result.status()) << result.status();
// Decrypt and compare encrypted_record.
(new SingleDecryptionContext(
result.ValueOrDie(), decryptor_,
base::BindOnce(
[](base::OnceCallback<void(StatusOr<std::string>)>
decryption_result,
StatusOr<base::StringPiece> result) {
if (!result.ok()) {
std::move(decryption_result).Run(result.status());
return;
}
std::move(decryption_result)
.Run(std::string(result.ValueOrDie()));
},
decryption_results[i].cb())))
->Start();
}
// Verify decryption results.
// Result i is compared with test string i.
for (size_t i = 0; i < decryption_results.size(); ++i) {
const auto decryption_result = decryption_results[i].result();
ASSERT_OK(decryption_result.status()) << decryption_result.status();
// Verify data match.
EXPECT_THAT(decryption_result.ValueOrDie(),
::testing::StrEq(kTestStrings[i]));
}
}
} // namespace
} // namespace reporting