vboot2: use enum hash algorithm

This changes the internals of vboot2 to use the enumerated type for
hash algorithms.  The conversion from crypto algorithm to hash
algorithm is done only when unpacking the key (and, for now, when
checking the RSA padding; that use goes away in the next change).
This is preparation for the vboot2 data types, which separate the
signature and hash algorithms into their own fields.

There is no external change in the calling API to vboot, and no change
to the external data structures.

BUG=chromium:423882
BRANCH=none
TEST=VBOOT2=1 make runtests

Change-Id: I9c6de08d742dab941beb806fbd2bfc1e11c01e2c
Signed-off-by: Randall Spangler <rspangler@chromium.org>
Reviewed-on: https://chromium-review.googlesource.com/225208
Reviewed-by: Daisuke Nojiri <dnojiri@chromium.org>
Reviewed-by: Bill Richardson <wfrichar@chromium.org>
diff --git a/firmware/2lib/2api.c b/firmware/2lib/2api.c
index 8948093..1f128a4 100644
--- a/firmware/2lib/2api.c
+++ b/firmware/2lib/2api.c
@@ -190,7 +190,7 @@
 	if (size)
 		*size = pre->body_signature.data_size;
 
-	return vb2_digest_init(dc, key.algorithm);
+	return vb2_digest_init(dc, key.hash_alg);
 }
 
 int vb2api_extend_hash(struct vb2_context *ctx,
@@ -222,7 +222,7 @@
 	struct vb2_workbuf wb;
 
 	uint8_t *digest;
-	uint32_t digest_size = vb2_digest_size(dc->algorithm);
+	uint32_t digest_size = vb2_digest_size(dc->hash_alg);
 
 	struct vb2_fw_preamble *pre;
 	struct vb2_public_key key;
diff --git a/firmware/2lib/2common.c b/firmware/2lib/2common.c
index 21c42a3..0da3a61 100644
--- a/firmware/2lib/2common.c
+++ b/firmware/2lib/2common.c
@@ -192,10 +192,18 @@
 	if (rv)
 		return rv;
 
+	/* Check key algorithm */
 	if (packed_key->algorithm >= VB2_ALG_COUNT) {
 		VB2_DEBUG("Invalid algorithm.\n");
 		return VB2_ERROR_UNPACK_KEY_ALGORITHM;
 	}
+	key->algorithm = packed_key->algorithm;
+
+	key->hash_alg = vb2_crypto_to_hash(packed_key->algorithm);
+	if (key->hash_alg == VB2_HASH_INVALID) {
+		VB2_DEBUG("Unsupported hash algorithm.\n");
+		return VB2_ERROR_UNPACK_KEY_HASH_ALGORITHM;
+	}
 
 	expected_key_size = vb2_packed_key_size(packed_key->algorithm);
 	if (!expected_key_size || expected_key_size != packed_key->key_size) {
@@ -220,8 +228,6 @@
 	key->n = buf32 + 2;
 	key->rr = buf32 + 2 + key->arrsize;
 
-	key->algorithm = packed_key->algorithm;
-
 	return VB2_SUCCESS;
 }
 
@@ -264,7 +270,7 @@
 	}
 
 	/* Digest goes at start of work buffer */
-	digest_size = vb2_digest_size(key->algorithm);
+	digest_size = vb2_digest_size(key->hash_alg);
 	if (!digest_size)
 		return VB2_ERROR_VDATA_DIGEST_SIZE;
 
@@ -277,7 +283,7 @@
 	if (!dc)
 		return VB2_ERROR_VDATA_WORKBUF_HASHING;
 
-	rv = vb2_digest_init(dc, key->algorithm);
+	rv = vb2_digest_init(dc, key->hash_alg);
 	if (rv)
 		return rv;
 
diff --git a/firmware/2lib/2rsa.c b/firmware/2lib/2rsa.c
index 1df9115..15951c5 100644
--- a/firmware/2lib/2rsa.c
+++ b/firmware/2lib/2rsa.c
@@ -225,33 +225,25 @@
 int vb2_check_padding(uint8_t *sig, int algorithm)
 {
 	/* Determine padding to use depending on the signature type */
+	uint32_t hash_alg = vb2_crypto_to_hash(algorithm);
 	uint32_t pad_size = vb2_rsa_sig_size(algorithm) -
-		vb2_digest_size(algorithm);
+		vb2_digest_size(hash_alg);
 	const uint8_t *tail;
 	uint32_t tail_size;
 	int result = 0;
 
 	int i;
 
-	switch (algorithm) {
-	case VB2_ALG_RSA1024_SHA1:
-	case VB2_ALG_RSA2048_SHA1:
-	case VB2_ALG_RSA4096_SHA1:
-	case VB2_ALG_RSA8192_SHA1:
+	switch (hash_alg) {
+	case VB2_HASH_SHA1:
 		tail = sha1_tail;
 		tail_size = sizeof(sha1_tail);
 		break;
-	case VB2_ALG_RSA1024_SHA256:
-	case VB2_ALG_RSA2048_SHA256:
-	case VB2_ALG_RSA4096_SHA256:
-	case VB2_ALG_RSA8192_SHA256:
+	case VB2_HASH_SHA256:
 		tail = sha256_tail;
 		tail_size = sizeof(sha256_tail);
 		break;
-	case VB2_ALG_RSA1024_SHA512:
-	case VB2_ALG_RSA2048_SHA512:
-	case VB2_ALG_RSA4096_SHA512:
-	case VB2_ALG_RSA8192_SHA512:
+	case VB2_HASH_SHA512:
 		tail = sha512_tail;
 		tail_size = sizeof(sha512_tail);
 		break;
@@ -321,7 +313,7 @@
 	 * we don't return before this check if the padding check failed.)
 	 */
 	pad_size = vb2_rsa_sig_size(key->algorithm) -
-		vb2_digest_size(key->algorithm);
+		vb2_digest_size(key->hash_alg);
 
 	if (vb2_safe_memcmp(sig + pad_size, digest, key_bytes - pad_size)) {
 		VB2_DEBUG("Digest check failed!\n");
diff --git a/firmware/2lib/2sha_utility.c b/firmware/2lib/2sha_utility.c
index 52492ab..47581ea 100644
--- a/firmware/2lib/2sha_utility.c
+++ b/firmware/2lib/2sha_utility.c
@@ -28,7 +28,7 @@
 #define CTH_SHA512 VB2_HASH_INVALID
 #endif
 
-static const uint8_t crypto_to_hash[VB2_ALG_COUNT] = {
+static const uint8_t crypto_to_hash[] = {
 	CTH_SHA1,
 	CTH_SHA256,
 	CTH_SHA512,
@@ -52,17 +52,17 @@
  * the crypto algorithm or its corresponding hash algorithm is invalid or not
  * supported.
  */
-enum vb2_hash_algorithm vb2_hash_algorithm(uint32_t algorithm)
+enum vb2_hash_algorithm vb2_crypto_to_hash(uint32_t algorithm)
 {
-	if (algorithm < VB2_ALG_COUNT)
+	if (algorithm < ARRAY_SIZE(crypto_to_hash))
 		return crypto_to_hash[algorithm];
 	else
 		return VB2_HASH_INVALID;
 }
 
-int vb2_digest_size(uint32_t algorithm)
+int vb2_digest_size(enum vb2_hash_algorithm hash_alg)
 {
-	switch (vb2_hash_algorithm(algorithm)) {
+	switch (hash_alg) {
 #if VB2_SUPPORT_SHA1
 	case VB2_HASH_SHA1:
 		return VB2_SHA1_DIGEST_SIZE;
@@ -80,11 +80,12 @@
 	}
 }
 
-int vb2_digest_init(struct vb2_digest_context *dc, uint32_t algorithm)
+int vb2_digest_init(struct vb2_digest_context *dc,
+		    enum vb2_hash_algorithm hash_alg)
 {
-	dc->algorithm = algorithm;
+	dc->hash_alg = hash_alg;
 
-	switch (vb2_hash_algorithm(dc->algorithm)) {
+	switch (dc->hash_alg) {
 #if VB2_SUPPORT_SHA1
 	case VB2_HASH_SHA1:
 		vb2_sha1_init(&dc->sha1);
@@ -109,7 +110,7 @@
 		      const uint8_t *buf,
 		      uint32_t size)
 {
-	switch (vb2_hash_algorithm(dc->algorithm)) {
+	switch (dc->hash_alg) {
 #if VB2_SUPPORT_SHA1
 	case VB2_HASH_SHA1:
 		vb2_sha1_update(&dc->sha1, buf, size);
@@ -134,10 +135,10 @@
 			uint8_t *digest,
 			uint32_t digest_size)
 {
-	if (digest_size < vb2_digest_size(dc->algorithm))
+	if (digest_size < vb2_digest_size(dc->hash_alg))
 		return VB2_ERROR_SHA_FINALIZE_DIGEST_SIZE;
 
-	switch (vb2_hash_algorithm(dc->algorithm)) {
+	switch (dc->hash_alg) {
 #if VB2_SUPPORT_SHA1
 	case VB2_HASH_SHA1:
 		vb2_sha1_finalize(&dc->sha1, digest);
diff --git a/firmware/2lib/include/2common.h b/firmware/2lib/include/2common.h
index c1b9861..dcf799a 100644
--- a/firmware/2lib/include/2common.h
+++ b/firmware/2lib/include/2common.h
@@ -23,6 +23,11 @@
 #define VB2_MAX(A, B) ((A) > (B) ? (A) : (B))
 #endif
 
+/* Return the number of elements in an array */
+#ifndef ARRAY_SIZE
+#define ARRAY_SIZE(array) (sizeof(array)/sizeof(array[0]))
+#endif
+
 /*
  * Debug output. printf() for tests. otherwise, it's platform-dependent.
  */
diff --git a/firmware/2lib/include/2return_codes.h b/firmware/2lib/include/2return_codes.h
index b530bcd..4070f00 100644
--- a/firmware/2lib/include/2return_codes.h
+++ b/firmware/2lib/include/2return_codes.h
@@ -161,6 +161,9 @@
 	 */
 	VB2_ERROR_VDATA_DIGEST_SIZE,
 
+	/* Unsupported hash algorithm in vb2_unpack_key() */
+	VB2_ERROR_UNPACK_KEY_HASH_ALGORITHM,
+
         /**********************************************************************
 	 * Keyblock verification errors (all in vb2_verify_keyblock())
 	 */
diff --git a/firmware/2lib/include/2rsa.h b/firmware/2lib/include/2rsa.h
index 47225ca..5409ce3 100644
--- a/firmware/2lib/include/2rsa.h
+++ b/firmware/2lib/include/2rsa.h
@@ -7,6 +7,7 @@
 #define VBOOT_REFERENCE_2RSA_H_
 
 #include "2crypto.h"
+#include "2struct.h"
 
 struct vb2_workbuf;
 
@@ -17,6 +18,7 @@
 	const uint32_t *n;   /* Modulus as little endian array */
 	const uint32_t *rr;  /* R^2 as little endian array */
 	uint32_t algorithm;  /* Algorithm to use when verifying with the key */
+	enum vb2_hash_algorithm hash_alg;  /* Hash algorithm */
 };
 
 /**
diff --git a/firmware/2lib/include/2sha.h b/firmware/2lib/include/2sha.h
index 675fc66..5879236 100644
--- a/firmware/2lib/include/2sha.h
+++ b/firmware/2lib/include/2sha.h
@@ -7,6 +7,7 @@
 #define VBOOT_REFERENCE_2SHA_H_
 
 #include "2crypto.h"
+#include "2struct.h"
 
 /* Hash algorithms may be disabled individually to save code space */
 
@@ -75,8 +76,8 @@
 #endif
 	};
 
-	/* Current hash algorithm (enum vb2_crypto_algorithm) */
-	uint32_t algorithm;
+	/* Current hash algorithm */
+	enum vb2_hash_algorithm hash_alg;
 };
 
 /**
@@ -124,24 +125,25 @@
  * the crypto algorithm or its corresponding hash algorithm is invalid or not
  * supported.
  */
-enum vb2_hash_algorithm vb2_hash_algorithm(uint32_t algorithm);
+enum vb2_hash_algorithm vb2_crypto_to_hash(uint32_t algorithm);
 
 /**
- * Return the size of the digest for a key algorithm.
+ * Return the size of the digest for a hash algorithm.
  *
- * @param algorithm	Key algorithm (enum vb2_crypto_algorithm)
+ * @param hash_alg	Hash algorithm
  * @return The size of the digest, or 0 if error.
  */
-int vb2_digest_size(uint32_t algorithm);
+int vb2_digest_size(enum vb2_hash_algorithm hash_alg);
 
 /**
  * Initialize a digest context for doing block-style digesting.
  *
  * @param dc		Digest context
- * @param algorithm	Key algorithm (enum vb2_crypto_algorithm)
+ * @param hash_alg	Hash algorithm
  * @return VB2_SUCCESS, or non-zero on error.
  */
-int vb2_digest_init(struct vb2_digest_context *dc, uint32_t algorithm);
+int vb2_digest_init(struct vb2_digest_context *dc,
+		    enum vb2_hash_algorithm hash_alg);
 
 /**
  * Extend a digest's hash with another block of data.
diff --git a/tests/vb2_api_tests.c b/tests/vb2_api_tests.c
index ebd228e..49b59b2 100644
--- a/tests/vb2_api_tests.c
+++ b/tests/vb2_api_tests.c
@@ -26,6 +26,7 @@
 const char mock_body[320] = "Mock body";
 const int mock_body_size = sizeof(mock_body);
 const int mock_algorithm = VB2_ALG_RSA2048_SHA256;
+const int mock_hash_alg = VB2_HASH_SHA256;
 const int mock_sig_size = 64;
 
 /* Mocked function data */
@@ -140,16 +141,18 @@
 		return VB2_ERROR_UNPACK_KEY_SIZE;
 
 	key->algorithm = k->algorithm;
+	key->hash_alg = vb2_crypto_to_hash(k->algorithm);
 
 	return VB2_SUCCESS;
 }
 
-int vb2_digest_init(struct vb2_digest_context *dc, uint32_t algorithm)
+int vb2_digest_init(struct vb2_digest_context *dc,
+		    enum vb2_hash_algorithm hash_alg)
 {
-	if (algorithm != mock_algorithm)
+	if (hash_alg != mock_hash_alg)
 		return VB2_ERROR_SHA_INIT_ALGORITHM;
 
-	dc->algorithm = algorithm;
+	dc->hash_alg = hash_alg;
 
 	return VB2_SUCCESS;
 }
@@ -158,7 +161,7 @@
 		      const uint8_t *buf,
 		      uint32_t size)
 {
-	if (dc->algorithm != mock_algorithm)
+	if (dc->hash_alg != mock_hash_alg)
 		return VB2_ERROR_SHA_EXTEND_ALGORITHM;
 
 	return VB2_SUCCESS;
@@ -370,7 +373,7 @@
 	reset_common_data(FOR_EXTEND_HASH);
 	dc = (struct vb2_digest_context *)
 		(cc.workbuf + sd->workbuf_hash_offset);
-	dc->algorithm++;
+	dc->hash_alg = mock_hash_alg + 1;
 	TEST_EQ(vb2api_extend_hash(&cc, mock_body, mock_body_size),
 		VB2_ERROR_SHA_EXTEND_ALGORITHM, "hash extend fail");
 }
diff --git a/tests/vb2_common2_tests.c b/tests/vb2_common2_tests.c
index 6c3a3e4..2afa9f9 100644
--- a/tests/vb2_common2_tests.c
+++ b/tests/vb2_common2_tests.c
@@ -36,6 +36,9 @@
 	TEST_SUCC(vb2_unpack_key(&rsa, buf, size), "vb2_unpack_key() ok");
 
 	TEST_EQ(rsa.algorithm, key2->algorithm, "vb2_unpack_key() algorithm");
+	TEST_EQ(rsa.hash_alg, vb2_crypto_to_hash(key2->algorithm),
+		"vb2_unpack_key() hash_alg");
+
 
 	PublicKeyCopy(key, orig_key);
 	key2->algorithm = VB2_ALG_COUNT;
diff --git a/tests/vb2_rsa_padding_tests.c b/tests/vb2_rsa_padding_tests.c
index e85b54f..3dd3316 100644
--- a/tests/vb2_rsa_padding_tests.c
+++ b/tests/vb2_rsa_padding_tests.c
@@ -34,6 +34,7 @@
 	k2->n = key->n;
 	k2->rr = key->rr;
 	k2->algorithm = key->algorithm;
+	k2->hash_alg = vb2_crypto_to_hash(key->algorithm);
 }
 
 /**
diff --git a/tests/vb2_sha_tests.c b/tests/vb2_sha_tests.c
index c60bbd1..501f90e 100644
--- a/tests/vb2_sha_tests.c
+++ b/tests/vb2_sha_tests.c
@@ -14,15 +14,15 @@
 #include "test_common.h"
 
 static int vb2_digest(const uint8_t *buf,
-	       uint32_t size,
-	       uint32_t algorithm,
-	       uint8_t *digest,
-	       uint32_t digest_size)
+		      uint32_t size,
+		      enum vb2_hash_algorithm hash_alg,
+		      uint8_t *digest,
+		      uint32_t digest_size)
 {
 	struct vb2_digest_context dc;
 	int rv;
 
-	rv = vb2_digest_init(&dc, algorithm);
+	rv = vb2_digest_init(&dc, hash_alg);
 	if (rv)
 		return rv;
 
@@ -46,15 +46,14 @@
 	for (i = 0; i < 3; i++) {
 		TEST_SUCC(vb2_digest(test_inputs[i],
 				     strlen((char *)test_inputs[i]),
-				     VB2_ALG_RSA1024_SHA1, digest,
-				     sizeof(digest)),
+				     VB2_HASH_SHA1, digest, sizeof(digest)),
 			  "vb2_digest() SHA1");
 		TEST_EQ(memcmp(digest, sha1_results[i], sizeof(digest)),
 			0, "SHA1 digest");
 	}
 
 	TEST_EQ(vb2_digest(test_inputs[0], strlen((char *)test_inputs[0]),
-			    VB2_ALG_RSA1024_SHA1, digest, sizeof(digest) - 1),
+			   VB2_HASH_SHA1, digest, sizeof(digest) - 1),
 		VB2_ERROR_SHA_FINALIZE_DIGEST_SIZE, "vb2_digest() too small");
 }
 
@@ -71,15 +70,14 @@
 	for (i = 0; i < 3; i++) {
 		TEST_SUCC(vb2_digest(test_inputs[i],
 				     strlen((char *)test_inputs[i]),
-				     VB2_ALG_RSA1024_SHA256, digest,
-				     sizeof(digest)),
+				     VB2_HASH_SHA256, digest, sizeof(digest)),
 			  "vb2_digest() SHA256");
 		TEST_EQ(memcmp(digest, sha256_results[i], sizeof(digest)),
 			0, "SHA-256 digest");
 	}
 
 	TEST_EQ(vb2_digest(test_inputs[0], strlen((char *)test_inputs[0]),
-			   VB2_ALG_RSA1024_SHA256, digest, sizeof(digest) - 1),
+			   VB2_HASH_SHA256, digest, sizeof(digest) - 1),
 		VB2_ERROR_SHA_FINALIZE_DIGEST_SIZE, "vb2_digest() too small");
 }
 
@@ -96,7 +94,7 @@
 	for (i = 0; i < 3; i++) {
 		TEST_SUCC(vb2_digest(test_inputs[i],
 				     strlen((char *)test_inputs[i]),
-				     VB2_ALG_RSA1024_SHA512, digest,
+				     VB2_HASH_SHA512, digest,
 				     sizeof(digest)),
 			  "vb2_digest() SHA512");
 		TEST_EQ(memcmp(digest, sha512_results[i], sizeof(digest)),
@@ -104,7 +102,7 @@
 	}
 
 	TEST_EQ(vb2_digest(test_inputs[0], strlen((char *)test_inputs[0]),
-			   VB2_ALG_RSA1024_SHA512, digest, sizeof(digest) - 1),
+			   VB2_HASH_SHA512, digest, sizeof(digest) - 1),
 		VB2_ERROR_SHA_FINALIZE_DIGEST_SIZE, "vb2_digest() too small");
 }
 
@@ -113,16 +111,29 @@
 	uint8_t digest[VB2_SHA512_DIGEST_SIZE];
 	struct vb2_digest_context dc;
 
-	TEST_EQ(vb2_digest_size(VB2_ALG_COUNT), 0, "digest size invalid alg");
+	/* Crypto algorithm to hash algorithm mapping */
+	TEST_EQ(vb2_crypto_to_hash(VB2_ALG_RSA1024_SHA1), VB2_HASH_SHA1,
+		"Crypto map to SHA1");
+	TEST_EQ(vb2_crypto_to_hash(VB2_ALG_RSA2048_SHA256), VB2_HASH_SHA256,
+		"Crypto map to SHA256");
+	TEST_EQ(vb2_crypto_to_hash(VB2_ALG_RSA4096_SHA256), VB2_HASH_SHA256,
+		"Crypto map to SHA256 2");
+	TEST_EQ(vb2_crypto_to_hash(VB2_ALG_RSA8192_SHA512), VB2_HASH_SHA512,
+		"Crypto map to SHA512");
+	TEST_EQ(vb2_crypto_to_hash(VB2_ALG_COUNT), VB2_HASH_INVALID,
+		"Crypto map to invalid");
+
+	TEST_EQ(vb2_digest_size(VB2_HASH_INVALID), 0,
+		"digest size invalid alg");
 
 	TEST_EQ(vb2_digest((uint8_t *)oneblock_msg, strlen(oneblock_msg),
-			   VB2_ALG_COUNT, digest, sizeof(digest)),
+			   VB2_HASH_INVALID, digest, sizeof(digest)),
 		VB2_ERROR_SHA_INIT_ALGORITHM,
 		"vb2_digest() invalid alg");
 
 	/* Test bad algorithm inside extend and finalize */
-	vb2_digest_init(&dc, VB2_ALG_RSA1024_SHA1);
-	dc.algorithm = VB2_ALG_COUNT;
+	vb2_digest_init(&dc, VB2_HASH_SHA256);
+	dc.hash_alg = VB2_HASH_INVALID;
 	TEST_EQ(vb2_digest_extend(&dc, digest, sizeof(digest)),
 		VB2_ERROR_SHA_EXTEND_ALGORITHM,
 		"vb2_digest_extend() invalid alg");