verity: add support for salting.

Salting is exposed as an optional salt=<hex> argument. If present, the salt is
decoded from hex, zero-padded or truncated to 32 bytes, and appended to every
hashed block.

Salting is only enabled when a salt is explicitly supplied, and no existing
tools supply one yet, so this change should have no effect on existing images.

BUG=chromium-os:12138
TEST=Adhoc, script:12138, unittest, autotest
Build an image (I did this for Kaen) and boot it.
Run the script attached to bug 12138; if it prints 'ok', everything's good.
Check that unit test 'CreateThenVerifyOkSalt' passes.
Run platform_DMVerityBitCorruption and platform_DMVerityCorruption.

Change-Id: I3eeb17d041bcd567c0908b017e9d57a896c11cc4
Signed-off-by: Elly Jones <ellyjones@chromium.org>
Reviewed-on: http://gerrit.chromium.org/gerrit/6708
Reviewed-by: Will Drewry <wad@chromium.org>
diff --git a/dm-bht.c b/dm-bht.c
index 209b08a..539847f 100644
--- a/dm-bht.c
+++ b/dm-bht.c
@@ -110,8 +110,19 @@
 			smp_processor_id());
 		return -EINVAL;
 	}
-	if (crypto_hash_digest(hash_desc, &sg, PAGE_SIZE, digest)) {
-		DMCRIT("crypto_hash_digest failed");
+	if (crypto_hash_update(hash_desc, &sg, PAGE_SIZE)) {
+		DMCRIT("crypto_hash_update failed");
+		return -EINVAL;
+	}
+	if (bht->have_salt) {
+		sg_set_buf(&sg, bht->salt, sizeof(bht->salt));
+		if (crypto_hash_update(hash_desc, &sg, sizeof(bht->salt))) {
+			DMCRIT("crypto_hash_update failed");
+			return -EINVAL;
+		}
+	}
+	if (crypto_hash_final(hash_desc, digest)) {
+		DMCRIT("crypto_hash_final failed");
 		return -EINVAL;
 	}
 
@@ -201,6 +212,8 @@
 	int status = 0;
 	int cpu = 0;
 
+	bht->have_salt = false;
+
 	/* Setup the hash first. Its length determines much of the bht layout */
 	for (cpu = 0; cpu < nr_cpu_ids; ++cpu) {
 		bht->hash_desc[cpu].tfm = crypto_alloc_hash(alg_name, 0, 0);
@@ -938,3 +951,38 @@
 	return 0;
 }
 EXPORT_SYMBOL(dm_bht_root_hexdigest);
+
+/**
+ * dm_bht_set_salt - sets the salt used, in hex
+ * @bht:      pointer to a dm_bht_create()d bht
+ * @hexsalt:  salt string, as hex; will be zero-padded or truncated to
+ *            DM_BHT_SALT_SIZE * 2 hex digits.
+ *
+ * Only strlen(@hexsalt) / 2 bytes are decoded, so a trailing odd hex digit
+ * is ignored. Once a salt is set, it is appended to every hashed block.
+ */
+void dm_bht_set_salt(struct dm_bht *bht, const char *hexsalt)
+{
+	size_t saltlen = min(strlen(hexsalt) / 2, sizeof(bht->salt));
+
+	bht->have_salt = true;
+	memset(bht->salt, 0, sizeof(bht->salt));
+	dm_bht_hex_to_bin(bht->salt, (const u8 *)hexsalt, saltlen);
+}
+EXPORT_SYMBOL(dm_bht_set_salt);
+
+/**
+ * dm_bht_salt - returns the salt used, in hex
+ * @bht:      pointer to a dm_bht_create()d bht
+ * @hexsalt:  buffer to put salt into, of length DM_BHT_SALT_SIZE * 2 + 1.
+ *
+ * Returns 0 on success, or -EINVAL if no salt has been set.
+ */
+int dm_bht_salt(struct dm_bht *bht, char *hexsalt)
+{
+	if (!bht->have_salt)
+		return -EINVAL;
+	dm_bht_bin_to_hex(bht->salt, (u8 *)hexsalt, sizeof(bht->salt));
+	return 0;
+}
+EXPORT_SYMBOL(dm_bht_salt);
diff --git a/dm-bht.h b/dm-bht.h
index 245797a..99a9425 100644
--- a/dm-bht.h
+++ b/dm-bht.h
@@ -17,6 +17,7 @@
  * max to use for now.
  */
 #define DM_BHT_MAX_DIGEST_SIZE 128  /* 1k hashes are unlikely for now */
+#define DM_BHT_SALT_SIZE       32   /* 256 bits of salt is a lot */
 
 /* UNALLOCATED, PENDING, READY, and VERIFIED are valid states. All other
  * values are entry-related return codes.
@@ -84,6 +85,12 @@
 	int depth;  /* Depth of the tree including the root */
 	unsigned int block_count;  /* Number of blocks hashed */
 	char hash_alg[CRYPTO_MAX_ALG_NAME];
+	unsigned char salt[DM_BHT_SALT_SIZE];
+
+	/* This is a temporary hack to ease the transition to salting. It will
+	 * be removed once salting is supported both in kernel and userspace,
+	 * and the salt will default to all zeroes instead. */
+	bool have_salt;
 
 	/* Computed values */
 	unsigned int node_count;  /* Data size (in hashes) for each entry */
@@ -114,6 +121,8 @@
 void dm_bht_set_write_cb(struct dm_bht *bht, dm_bht_callback write_cb);
 int dm_bht_set_root_hexdigest(struct dm_bht *bht, const u8 *hexdigest);
 int dm_bht_root_hexdigest(struct dm_bht *bht, u8 *hexdigest, int available);
+void dm_bht_set_salt(struct dm_bht *bht, const char *hexsalt);
+int dm_bht_salt(struct dm_bht *bht, char *hexsalt);
 
 /* Functions for loading in data from disk for verification */
 bool dm_bht_is_populated(struct dm_bht *bht, unsigned int block);
diff --git a/dm-bht_unittest.cc b/dm-bht_unittest.cc
index bbd7254..59d22d2 100644
--- a/dm-bht_unittest.cc
+++ b/dm-bht_unittest.cc
@@ -100,20 +100,23 @@
  protected:
   // Creates a new dm_bht and sets it in the existing MemoryBht.
   void NewBht(const unsigned int total_blocks,
-              const char *digest_algorithm) {
+              const char *digest_algorithm,
+              const char *salt) {
     bht_.reset(new dm_bht());
-    EXPECT_EQ(0, dm_bht_create(bht_.get(), total_blocks,
-                               digest_algorithm));
+    EXPECT_EQ(0, dm_bht_create(bht_.get(), total_blocks, digest_algorithm));
     if (hash_data_.get() == NULL) {
       sectors_ = dm_bht_sectors(bht_.get());
       hash_data_.reset(new u8[to_bytes(sectors_)]);
     }
     dm_bht_set_write_cb(bht_.get(), MemoryBhtTest::WriteCallback);
     dm_bht_set_read_cb(bht_.get(), MemoryBhtTest::ReadCallback);
+    if (salt)
+      dm_bht_set_salt(bht_.get(), salt);
   }
   void SetupBht(const unsigned int total_blocks,
-                const char *digest_algorithm) {
-    NewBht(total_blocks, digest_algorithm);
+                const char *digest_algorithm,
+                const char *salt) {
+    NewBht(total_blocks, digest_algorithm, salt);
 
     u8 *data = (u8 *)my_memalign(PAGE_SIZE, PAGE_SIZE);
 
@@ -133,7 +136,7 @@
     EXPECT_EQ(0, dm_bht_destroy(bht_.get()));
     // bht is now dead and mbht_ is a prepared hash image
 
-    NewBht(total_blocks, digest_algorithm);
+    NewBht(total_blocks, digest_algorithm, salt);
 
     // Load the tree from the pre-populated hash data
     for (blocks = 0; blocks < total_blocks; blocks += bht_->node_count)
@@ -159,7 +162,7 @@
 
   memset(zero_page, 0, PAGE_SIZE);
 
-  SetupBht(total_blocks, "sha256");
+  SetupBht(total_blocks, "sha256", NULL);
   dm_bht_set_root_hexdigest(bht_.get(),
                             reinterpret_cast<const u8 *>(kRootDigest));
 
@@ -183,7 +186,7 @@
 
   memset(zero_page, 0, PAGE_SIZE);
 
-  SetupBht(total_blocks, "sha256");
+  SetupBht(total_blocks, "sha256", NULL);
   dm_bht_set_root_hexdigest(bht_.get(),
                             reinterpret_cast<const u8 *>(kRootDigest));
 
@@ -207,7 +210,7 @@
 
   memset(zero_page, 0, PAGE_SIZE);
 
-  SetupBht(total_blocks, "sha256");
+  SetupBht(total_blocks, "sha256", NULL);
   dm_bht_set_root_hexdigest(bht_.get(),
                             reinterpret_cast<const u8 *>(kRootDigest));
 
@@ -231,7 +234,7 @@
 
   memset(zero_page, 0, PAGE_SIZE);
 
-  SetupBht(total_blocks, "sha256");
+  SetupBht(total_blocks, "sha256", NULL);
   dm_bht_set_root_hexdigest(bht_.get(),
                             reinterpret_cast<const u8 *>(kRootDigest));
 
@@ -255,7 +258,7 @@
 
   memset(zero_page, 0, PAGE_SIZE);
 
-  SetupBht(total_blocks, "sha256");
+  SetupBht(total_blocks, "sha256", NULL);
   dm_bht_set_root_hexdigest(bht_.get(),
                             reinterpret_cast<const u8 *>(kRootDigest));
 
@@ -279,7 +282,7 @@
 
   memset(zero_page, 0, PAGE_SIZE);
 
-  SetupBht(total_blocks, "sha256");
+  SetupBht(total_blocks, "sha256", NULL);
 
   dm_bht_set_root_hexdigest(bht_.get(),
                             reinterpret_cast<const u8 *>(kRootDigest));
@@ -320,7 +323,7 @@
 
 TEST_F(MemoryBhtTest, CreateThenVerifyBadDataBlock) {
   static const unsigned int total_blocks = 384;
-  SetupBht(total_blocks, "sha256");
+  SetupBht(total_blocks, "sha256", NULL);
   // Set the root hash for a 0-filled image
   static const char kRootDigest[] =
     "45d65d6f9e5a962f4d80b5f1bd7a918152251c27bdad8c5f52b590c129833372";
@@ -342,3 +345,55 @@
   EXPECT_EQ(0, dm_bht_destroy(bht_.get()));
   free(bad_page);
 }
+
+TEST_F(MemoryBhtTest, CreateThenVerifyOkSalt) {
+  static const unsigned int total_blocks = 16384;
+  // Root hash for a 0-filled image with |salt| appended to each hashed block
+  static const char kRootDigest[] =
+    "8015fea349568f5135ecc833bbc79c9179377207382b53c68d93190b286b1256";
+  static const char salt[] =
+    "01ad1f06255d452d91337bf037953053cc3e452541db4b8ca05811bf3e2b6027";
+  // A page of all zeros
+  u8 *zero_page = (u8 *)my_memalign(PAGE_SIZE, PAGE_SIZE);
+
+  memset(zero_page, 0, PAGE_SIZE);
+
+  SetupBht(total_blocks, "sha256", salt);
+  dm_bht_set_root_hexdigest(bht_.get(),
+                            reinterpret_cast<const u8 *>(kRootDigest));
+
+  for (unsigned int blocks = 0; blocks < total_blocks; ++blocks) {
+    DLOG(INFO) << "verifying block: " << blocks;
+    EXPECT_EQ(0, dm_bht_verify_block(bht_.get(), blocks,
+                                     virt_to_page(zero_page), 0));
+  }
+
+  EXPECT_EQ(0, dm_bht_destroy(bht_.get()));
+  free(zero_page);
+}
+
+TEST_F(MemoryBhtTest, CreateThenVerifyOkLongSalt) {
+  static const unsigned int total_blocks = 16384;
+  // Same root as CreateThenVerifyOkSalt: the salt is truncated to 32 bytes.
+  static const char kRootDigest[] =
+    "8015fea349568f5135ecc833bbc79c9179377207382b53c68d93190b286b1256";
+  static const char salt[] =
+    "01ad1f06255d452d91337bf037953053cc3e452541db4b8ca05811bf3e2b6027b2188a1d";
+  // A page of all zeros
+  u8 *zero_page = (u8 *)my_memalign(PAGE_SIZE, PAGE_SIZE);
+
+  memset(zero_page, 0, PAGE_SIZE);
+
+  SetupBht(total_blocks, "sha256", salt);
+  dm_bht_set_root_hexdigest(bht_.get(),
+                            reinterpret_cast<const u8 *>(kRootDigest));
+
+  for (unsigned int blocks = 0; blocks < total_blocks; ++blocks) {
+    DLOG(INFO) << "verifying block: " << blocks;
+    EXPECT_EQ(0, dm_bht_verify_block(bht_.get(), blocks,
+                                     virt_to_page(zero_page), 0));
+  }
+
+  EXPECT_EQ(0, dm_bht_destroy(bht_.get()));
+  free(zero_page);
+}
diff --git a/file_hasher.h b/file_hasher.h
index 2ff920b..142be69 100644
--- a/file_hasher.h
+++ b/file_hasher.h
@@ -35,6 +35,12 @@
   // Print a table to stdout which contains a dmsetup compatible format
   virtual void PrintTable(bool colocated);
 
+  virtual void set_salt(const char *salt) {  // |salt| is not copied.
+    dm_bht_set_salt(&tree_, salt);
+    salt_ = salt;  // NOTE(review): confirm the ctor NULL-initializes salt_.
+  }
+  virtual const char *salt(void) { return salt_; }  // Pointer from set_salt().
+
   static int WriteCallback(void *file,
                            sector_t start,
                            u8 *dst,
@@ -45,6 +51,7 @@
   simple_file::File *destination_;
   unsigned int block_limit_;
   const char *alg_;
+  const char *salt_;
   struct dm_bht tree_;
 };
 
diff --git a/verity_main.cc b/verity_main.cc
index a0dce06..0aa45fc 100644
--- a/verity_main.cc
+++ b/verity_main.cc
@@ -25,6 +25,7 @@
 "  payload_blocks    Size of the image, in blocks (4096 bytes)\n"
 "  hashtree          Path to a hash tree to create or read from\n"
 "  root_hexdigest    Digest of the root node (in hex) for verification\n"
+"  salt              Salt (in hex)\n"
 "\n", name);
 }
 
@@ -38,7 +39,8 @@
 static int verity_create(const char *alg,
                          const char *image_path,
                          unsigned int image_blocks,
-                         const char *hash_path);
+                         const char *hash_path,
+                         const char *salt);
 
 void splitarg(char *arg, char **key, char **val) {
   char *sp = NULL;
@@ -51,6 +53,7 @@
   const char *alg = NULL;
   const char *payload = NULL;
   const char *hashtree = NULL;
+  const char *salt = NULL;
   unsigned int payload_blocks = 0;
   int i;
   char *key, *val;
@@ -71,6 +74,8 @@
     else if (!strcmp(key, "mode"))
       // Silently drop the mode for now...
       ;
+    else if (!strcmp(key, "salt"))
+      salt = val;
     else {
       fprintf(stderr, "bogus key: '%s'\n", key);
       print_usage(argv[0]);
@@ -88,7 +93,7 @@
   }
 
   if (mode == VERITY_CREATE) {
-    return verity_create(alg, payload, payload_blocks, hashtree);
+    return verity_create(alg, payload, payload_blocks, hashtree, salt);
   } else {
     LOG(FATAL) << "Verification not done yet";
   }
@@ -98,7 +103,8 @@
 static int verity_create(const char *alg,
                          const char *image_path,
                          unsigned int image_blocks,
-                         const char *hash_path) {
+                         const char *hash_path,
+                         const char *salt) {
   // Configure files
   simple_file::Env env;
 
@@ -118,6 +124,8 @@
                                    image_blocks,
                                    alg))
     << "Failed to initialize hasher";
+  if (salt)
+    hasher.set_salt(salt);
   LOG_IF(FATAL, !hasher.Hash());
   LOG_IF(FATAL, !hasher.Store());
   hasher.PrintTable(true);