chaps: Asynchronously init TPM slot

We don't need the TokenInitThread after we have the asynchronous TPM
API. This change could also help us cleanup the thread model in chaps.

BUG=b:205087097
TEST=tast run $DUT hwsec.Chaps*

Change-Id: I0b4686d3f4fcc0ddb776a1ad52d820fd29a778b2
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform2/+/3347495
Reviewed-by: Andrey Pronin <apronin@chromium.org>
Reviewed-by: John L Chen <zuan@chromium.org>
Tested-by: Yi Chou <yich@google.com>
Commit-Queue: Yi Chou <yich@google.com>
diff --git a/chaps/slot_manager_impl.cc b/chaps/slot_manager_impl.cc
index 01c25b7..26757e1 100644
--- a/chaps/slot_manager_impl.cc
+++ b/chaps/slot_manager_impl.cc
@@ -146,15 +146,14 @@
   return version + hash_byte;
 }
 
-// Sanity checks authorization data by comparing against a hash stored in the
-// token database.
-// Args:
+// Checks authorization data by comparing against a hash stored in the token
+// database. Args:
 //   auth_data_hash - A hash of the authorization data to be verified.
 //   saved_auth_data_hash - The hash currently stored in the database.
 // Returns:
 //   False if both hash values are valid and they do not match.
-bool SanityCheckAuthData(const string& auth_data_hash,
-                         const string& saved_auth_data_hash) {
+bool CheckAuthDataValid(const string& auth_data_hash,
+                        const string& saved_auth_data_hash) {
   CHECK_EQ(auth_data_hash.length(), 2u);
   if (saved_auth_data_hash.length() != 2 ||
       saved_auth_data_hash[0] != kAuthDataHashVersion)
@@ -205,133 +204,6 @@
                << " indicated that token " << reinitialized_token_path
                << " has been reinitialized.";
 }
-
-// Performs expensive tasks required to initialize a token.
-class TokenInitThread : public base::PlatformThread::Delegate {
- public:
-  // This class will not take ownership of any pointers.
-  TokenInitThread(int slot_id,
-                  FilePath path,
-                  const SecureBlob& auth_data,
-                  TPMUtility* tpm_utility,
-                  ObjectPool* object_pool,
-                  SystemShutdownBlocker* system_shutdown_blocker);
-  TokenInitThread(const TokenInitThread&) = delete;
-  TokenInitThread& operator=(const TokenInitThread&) = delete;
-
-  ~TokenInitThread() override {}
-
-  // PlatformThread::Delegate interface.
-  void ThreadMain() override;
-
- private:
-  bool InitializeKeyHierarchy(SecureBlob* root_key);
-
-  int slot_id_;
-  FilePath path_;
-  SecureBlob auth_data_;
-  TPMUtility* tpm_utility_;
-  ObjectPool* object_pool_;
-  SystemShutdownBlocker* system_shutdown_blocker_;
-};
-
-TokenInitThread::TokenInitThread(int slot_id,
-                                 FilePath path,
-                                 const SecureBlob& auth_data,
-                                 TPMUtility* tpm_utility,
-                                 ObjectPool* object_pool,
-                                 SystemShutdownBlocker* system_shutdown_blocker)
-    : slot_id_(slot_id),
-      path_(path),
-      auth_data_(auth_data),
-      tpm_utility_(tpm_utility),
-      object_pool_(object_pool),
-      system_shutdown_blocker_(system_shutdown_blocker) {}
-
-void TokenInitThread::ThreadMain() {
-  // Create a message loop for this thread.
-  brillo::BaseMessageLoop loop;
-  loop.SetAsCurrent();
-
-  // Block system shutdown while TokenInitThread is running. Unblock shutdown
-  // once TokenInitThread completes or a fallback timeout of
-  // |kTokenInitBlockSystemShutdownFallbackTimeout| has expired.
-  // |system_shutdown_blocker_| can be nullptr in tests.
-  std::unique_ptr<base::ScopedClosureRunner> scoped_closure_runner;
-  if (system_shutdown_blocker_) {
-    auto unblock_closure =
-        base::Bind(&SystemShutdownBlocker::Unblock,
-                   base::Unretained(system_shutdown_blocker_), slot_id_);
-    scoped_closure_runner =
-        std::make_unique<base::ScopedClosureRunner>(unblock_closure);
-    system_shutdown_blocker_->Block(
-        slot_id_, kTokenInitBlockSystemShutdownFallbackTimeout);
-  }
-
-  string auth_data_hash = HashAuthData(auth_data_);
-  string saved_auth_data_hash;
-  string auth_key_blob;
-  string encrypted_root_key;
-  SecureBlob root_key;
-  // Determine whether the key hierarchy has already been initialized based on
-  // whether the relevant blobs exist.
-  if (!object_pool_->GetInternalBlob(kEncryptedAuthKey, &auth_key_blob) ||
-      !object_pool_->GetInternalBlob(kEncryptedRootKey, &encrypted_root_key)) {
-    LOG(INFO) << "Initializing key hierarchy for token at " << path_.value();
-    if (!InitializeKeyHierarchy(&root_key)) {
-      LOG(ERROR) << "Failed to initialize key hierarchy at " << path_.value();
-    }
-  } else {
-    // Don't send the auth data to the TPM if it fails to verify against the
-    // saved hash.
-    object_pool_->GetInternalBlob(kAuthDataHash, &saved_auth_data_hash);
-    if (!SanityCheckAuthData(auth_data_hash, saved_auth_data_hash) ||
-        !tpm_utility_->UnsealData(auth_key_blob, encrypted_root_key,
-                                  Sha1(auth_data_), &root_key)) {
-      LOG(ERROR) << "Failed to unseal for token at " << path_.value()
-                 << ", reinitializing token.";
-      CreateTokenReinitializedFlagFile(path_);
-      if (object_pool_->DeleteAll() != ObjectPool::Result::Success)
-        LOG(WARNING) << "Failed to delete all existing objects.";
-      if (!InitializeKeyHierarchy(&root_key)) {
-        LOG(ERROR) << "Failed to initialize key hierarchy at " << path_.value();
-      }
-    }
-  }
-  if (!object_pool_->SetEncryptionKey(root_key)) {
-    LOG(ERROR) << "SetEncryptionKey failed for token at " << path_.value();
-    return;
-  }
-  if (!root_key.empty()) {
-    if (auth_data_hash != saved_auth_data_hash)
-      object_pool_->SetInternalBlob(kAuthDataHash, auth_data_hash);
-    LOG(INFO) << "Root key is ready for token at " << path_.value();
-  }
-}
-
-bool TokenInitThread::InitializeKeyHierarchy(SecureBlob* root_key) {
-  string root_key_str;
-  if (!tpm_utility_->GenerateRandom(kUserKeySize, &root_key_str)) {
-    LOG(ERROR) << "Failed to generate user encryption key.";
-    return false;
-  }
-  *root_key = SecureBlob(root_key_str.begin(), root_key_str.end());
-  string auth_key_blob;
-  string encrypted_root_key;
-  if (!tpm_utility_->SealData(root_key_str, Sha1(auth_data_), &auth_key_blob,
-                              &encrypted_root_key)) {
-    LOG(ERROR) << "Failed to seal user encryption key.";
-    return false;
-  }
-  if (!object_pool_->SetInternalBlob(kEncryptedAuthKey, auth_key_blob) ||
-      !object_pool_->SetInternalBlob(kEncryptedRootKey, encrypted_root_key)) {
-    LOG(ERROR) << "Failed to write key hierarchy blobs.";
-    return false;
-  }
-  brillo::SecureClearContainer(root_key_str);
-  return true;
-}
-
 }  // namespace
 
 SlotManagerImpl::SlotManagerImpl(ChapsFactory* factory,
@@ -368,11 +240,6 @@
 SlotManagerImpl::~SlotManagerImpl() {
   LOG(INFO) << "SlotManagerImpl is shutting down.";
   for (size_t i = 0; i < slot_list_.size(); ++i) {
-    // Wait for any worker thread to finish.
-    if (slot_list_[i].worker_thread.get()) {
-      LOG(INFO) << "Waiting for worker thread for slot " << i << " to exit.";
-      base::PlatformThread::Join(slot_list_[i].worker_thread_handle);
-    }
     if (tpm_utility_->IsTPMAvailable()) {
       // Unload any keys that have been loaded in the TPM.
       LOG(INFO) << "Unloading keys for slot " << i << ".";
@@ -650,23 +517,11 @@
       factory_->CreateObjectImporter(*slot_id, path, tpm_utility_)));
   CHECK(object_pool.get());
 
-  // Wait for the termination of a previous token.
-  if (slot_list_[*slot_id].worker_thread.get()) {
-    base::PlatformThread::Join(slot_list_[*slot_id].worker_thread_handle);
-    slot_list_[*slot_id].worker_thread.reset();
-    slot_list_[*slot_id].worker_thread_handle = base::PlatformThreadHandle();
-  }
-
   if (tpm_utility_->IsTPMAvailable()) {
-    // Decrypting (or creating) the root key requires the TPM so we'll put
-    // this on a worker thread. This has the effect that queries for public
-    // objects are responsive but queries for private objects will be waiting
-    // for the root key to be ready.
-    slot_list_[*slot_id].worker_thread.reset(
-        new TokenInitThread(*slot_id, path, auth_data, tpm_utility_,
-                            object_pool.get(), system_shutdown_blocker_));
-    base::PlatformThread::Create(0, slot_list_[*slot_id].worker_thread.get(),
-                                 &slot_list_[*slot_id].worker_thread_handle);
+    // Asynchronously Decrypting (or creating) the root key.
+    // This has the effect that queries for public objects are responsive but
+    // queries for private objects will be waiting for the root key to be ready.
+    LoadTPMToken(base::DoNothing(), *slot_id, path, auth_data, object_pool);
   } else {
     // Load a software-only token.
     LOG(WARNING) << "No TPM is available. Loading a software-only token.";
@@ -690,6 +545,169 @@
   return true;
 }
 
+void SlotManagerImpl::LoadTPMToken(base::OnceCallback<void(bool)> callback,
+                                   int slot_id,
+                                   const base::FilePath& path,
+                                   const brillo::SecureBlob& auth_data,
+                                   std::shared_ptr<ObjectPool> object_pool) {
+  if (system_shutdown_blocker_) {
+    base::OnceClosure unblock_closure =
+        base::BindOnce(&SystemShutdownBlocker::Unblock,
+                       base::Unretained(system_shutdown_blocker_), slot_id);
+
+    // Hook the unblock callback into the final callback.
+    callback = base::BindOnce(
+        [](base::OnceClosure unblock, base::OnceCallback<void(bool)> callback,
+           bool result) {
+          std::move(unblock).Run();
+          std::move(callback).Run(result);
+        },
+        std::move(unblock_closure), std::move(callback));
+
+    system_shutdown_blocker_->Block(
+        slot_id, kTokenInitBlockSystemShutdownFallbackTimeout);
+  }
+
+  string auth_data_hash = HashAuthData(auth_data);
+  string saved_auth_data_hash;
+  string auth_key_blob;
+  string encrypted_root_key;
+  // Determine whether the key hierarchy has already been initialized based on
+  // whether the relevant blobs exist.
+  if (!object_pool->GetInternalBlob(kEncryptedAuthKey, &auth_key_blob) ||
+      !object_pool->GetInternalBlob(kEncryptedRootKey, &encrypted_root_key)) {
+    LOG(INFO) << "Initializing key hierarchy for token at " << path.value();
+    InitializeTPMToken(std::move(callback), path, auth_data, object_pool);
+    return;
+  }
+
+  // Don't send the auth data to the TPM if it fails to verify against the
+  // saved hash.
+  object_pool->GetInternalBlob(kAuthDataHash, &saved_auth_data_hash);
+  if (!CheckAuthDataValid(auth_data_hash, saved_auth_data_hash)) {
+    LOG(ERROR) << "Failed to check the auth data is valid for token at "
+               << path.value() << ", reinitializing token.";
+    CreateTokenReinitializedFlagFile(path);
+    if (object_pool->DeleteAll() != ObjectPool::Result::Success)
+      LOG(WARNING) << "Failed to delete all existing objects.";
+
+    InitializeTPMToken(std::move(callback), path, auth_data, object_pool);
+    return;
+  }
+
+  AsyncTPMUtility::UnsealDataCallback unseal_callback = base::BindOnce(
+      &SlotManagerImpl::LoadTPMTokenAfterUnseal, base::Unretained(this),
+      std::move(callback), path, auth_data, object_pool);
+  tpm_utility_->UnsealDataAsync(auth_key_blob, encrypted_root_key,
+                                Sha1(auth_data), std::move(unseal_callback));
+  return;
+}
+
+void SlotManagerImpl::LoadTPMTokenAfterUnseal(
+    base::OnceCallback<void(bool)> callback,
+    const base::FilePath& path,
+    const brillo::SecureBlob& auth_data,
+    std::shared_ptr<ObjectPool> object_pool,
+    bool success,
+    brillo::SecureBlob unsealed_data) {
+  if (!success) {
+    LOG(ERROR) << "Failed to unseal for token at " << path.value()
+               << ", reinitializing token.";
+    CreateTokenReinitializedFlagFile(path);
+    if (object_pool->DeleteAll() != ObjectPool::Result::Success)
+      LOG(WARNING) << "Failed to delete all existing objects.";
+
+    InitializeTPMToken(std::move(callback), path, auth_data, object_pool);
+    return;
+  }
+  LoadTPMTokenFinal(std::move(callback), path, auth_data, object_pool,
+                    unsealed_data);
+}
+
+void SlotManagerImpl::LoadTPMTokenFinal(base::OnceCallback<void(bool)> callback,
+                                        const base::FilePath& path,
+                                        const brillo::SecureBlob& auth_data,
+                                        std::shared_ptr<ObjectPool> object_pool,
+                                        brillo::SecureBlob root_key) {
+  if (!object_pool->SetEncryptionKey(root_key)) {
+    LOG(ERROR) << "SetEncryptionKey failed for token at " << path.value();
+    std::move(callback).Run(false);
+    return;
+  }
+  if (!root_key.empty()) {
+    string auth_data_hash = HashAuthData(auth_data);
+    string saved_auth_data_hash;
+    object_pool->GetInternalBlob(kAuthDataHash, &saved_auth_data_hash);
+    if (auth_data_hash != saved_auth_data_hash) {
+      object_pool->SetInternalBlob(kAuthDataHash, auth_data_hash);
+    }
+    LOG(INFO) << "Root key is ready for token at " << path.value();
+    std::move(callback).Run(true);
+    return;
+  }
+  std::move(callback).Run(false);
+}
+
+void SlotManagerImpl::InitializeTPMToken(
+    base::OnceCallback<void(bool)> callback,
+    const base::FilePath& path,
+    const brillo::SecureBlob& auth_data,
+    std::shared_ptr<ObjectPool> object_pool) {
+  AsyncTPMUtility::GenerateRandomCallback gen_rand_callback =
+      base::BindOnce(&SlotManagerImpl::InitializeTPMTokenAfterGenerateRandom,
+                     base::Unretained(this), std::move(callback), path,
+                     auth_data, object_pool);
+  tpm_utility_->GenerateRandomAsync(kUserKeySize, std::move(gen_rand_callback));
+}
+
+void SlotManagerImpl::InitializeTPMTokenAfterGenerateRandom(
+    base::OnceCallback<void(bool)> callback,
+    const base::FilePath& path,
+    const brillo::SecureBlob& auth_data,
+    std::shared_ptr<ObjectPool> object_pool,
+    bool success,
+    std::string random_data) {
+  if (!success) {
+    LOG(ERROR) << "Failed to generate user encryption key.";
+    std::move(callback).Run(false);
+    return;
+  }
+
+  SecureBlob root_key(random_data.begin(), random_data.end());
+
+  AsyncTPMUtility::SealDataCallback seal_callback = base::BindOnce(
+      &SlotManagerImpl::InitializeTPMTokenAfterSealData, base::Unretained(this),
+      std::move(callback), path, auth_data, object_pool, root_key);
+  tpm_utility_->SealDataAsync(random_data, Sha1(auth_data),
+                              std::move(seal_callback));
+}
+
+void SlotManagerImpl::InitializeTPMTokenAfterSealData(
+    base::OnceCallback<void(bool)> callback,
+    const base::FilePath& path,
+    const brillo::SecureBlob& auth_data,
+    std::shared_ptr<ObjectPool> object_pool,
+    brillo::SecureBlob root_key,
+    bool success,
+    std::string key_blob,
+    std::string encrypted_data) {
+  if (!success) {
+    LOG(ERROR) << "Failed to seal user encryption key.";
+    std::move(callback).Run(false);
+    return;
+  }
+
+  if (!object_pool->SetInternalBlob(kEncryptedAuthKey, key_blob) ||
+      !object_pool->SetInternalBlob(kEncryptedRootKey, encrypted_data)) {
+    LOG(ERROR) << "Failed to write key hierarchy blobs.";
+    std::move(callback).Run(false);
+    return;
+  }
+
+  LoadTPMTokenFinal(std::move(callback), path, auth_data, object_pool,
+                    root_key);
+}
+
 bool SlotManagerImpl::LoadSoftwareToken(const SecureBlob& auth_data,
                                         ObjectPool* object_pool) {
   SecureBlob auth_key_encrypt =
@@ -786,13 +804,6 @@
     return;
   }
 
-  // Wait for initialization to be finished before cleaning up.
-  if (slot_list_[slot_id].worker_thread.get()) {
-    base::PlatformThread::Join(slot_list_[slot_id].worker_thread_handle);
-    slot_list_[slot_id].worker_thread.reset();
-    slot_list_[slot_id].worker_thread_handle = base::PlatformThreadHandle();
-  }
-
   if (tpm_utility_->IsTPMAvailable()) {
     tpm_utility_->UnloadKeysForSlotAsync(slot_id, base::DoNothing());
   }
@@ -831,11 +842,11 @@
   }
   CHECK(object_pool);
   if (tpm_utility_->IsTPMAvailable()) {
-    // Before we attempt the change, sanity check old_auth_data.
+    // Before we attempt the change, check the auth_data is valid.
     string saved_auth_data_hash;
     object_pool->GetInternalBlob(kAuthDataHash, &saved_auth_data_hash);
-    if (!SanityCheckAuthData(HashAuthData(old_auth_data),
-                             saved_auth_data_hash)) {
+    if (!CheckAuthDataValid(HashAuthData(old_auth_data),
+                            saved_auth_data_hash)) {
       LOG(ERROR) << "Old authorization data is not correct.";
       return;
     }
diff --git a/chaps/slot_manager_impl.h b/chaps/slot_manager_impl.h
index 5a5b373..6970922 100644
--- a/chaps/slot_manager_impl.h
+++ b/chaps/slot_manager_impl.h
@@ -19,7 +19,6 @@
 
 #include <base/macros.h>
 #include <base/synchronization/lock.h>
-#include <base/threading/platform_thread.h>
 
 #include "chaps/chaps_factory.h"
 #include "chaps/object_pool.h"
@@ -116,8 +115,6 @@
     // Key: A session identifier.
     // Value: The associated session object.
     std::map<int, std::shared_ptr<Session>> sessions;
-    std::shared_ptr<base::PlatformThread::Delegate> worker_thread;
-    base::PlatformThreadHandle worker_thread_handle;
   };
 
   // Internal token presence check without isolate_credential check.
@@ -171,6 +168,53 @@
                          const std::string& label,
                          int* slot_id);
 
+  // Loads the root key for a TPM token.
+  void LoadTPMToken(base::OnceCallback<void(bool)> callback,
+                    int slot_id,
+                    const base::FilePath& path,
+                    const brillo::SecureBlob& auth_data,
+                    std::shared_ptr<ObjectPool> object_pool);
+
+  // Loads the root key for a TPM token after the TPM unseals data.
+  void LoadTPMTokenAfterUnseal(base::OnceCallback<void(bool)> callback,
+                               const base::FilePath& path,
+                               const brillo::SecureBlob& auth_data,
+                               std::shared_ptr<ObjectPool> object_pool,
+                               bool success,
+                               brillo::SecureBlob unsealed_data);
+
+  // The final operation of loading the root key for a TPM token.
+  void LoadTPMTokenFinal(base::OnceCallback<void(bool)> callback,
+                         const base::FilePath& path,
+                         const brillo::SecureBlob& auth_data,
+                         std::shared_ptr<ObjectPool> object_pool,
+                         brillo::SecureBlob root_key);
+
+  // Initializes a new TPM token.
+  void InitializeTPMToken(base::OnceCallback<void(bool)> callback,
+                          const base::FilePath& path,
+                          const brillo::SecureBlob& auth_data,
+                          std::shared_ptr<ObjectPool> object_pool);
+
+  // Initializes a new TPM token after the TPM generates random.
+  void InitializeTPMTokenAfterGenerateRandom(
+      base::OnceCallback<void(bool)> callback,
+      const base::FilePath& path,
+      const brillo::SecureBlob& auth_data,
+      std::shared_ptr<ObjectPool> object_pool,
+      bool success,
+      std::string random_data);
+
+  // Initializes a new TPM token after the TPM seals data.
+  void InitializeTPMTokenAfterSealData(base::OnceCallback<void(bool)> callback,
+                                       const base::FilePath& path,
+                                       const brillo::SecureBlob& auth_data,
+                                       std::shared_ptr<ObjectPool> object_pool,
+                                       brillo::SecureBlob root_key,
+                                       bool success,
+                                       std::string key_blob,
+                                       std::string encrypted_data);
+
   // Loads the root key for a software-only token.
   bool LoadSoftwareToken(const brillo::SecureBlob& auth_data,
                          ObjectPool* object_pool);