[crypto] Add AES-GCM encryption.

Using the GHASH and GCTR building blocks, add the top-level
authenticated encryption algorithm as specified in NIST SP800-38D.

Signed-off-by: Jade Philipoom <jadep@google.com>
diff --git a/sw/device/lib/crypto/aes_gcm/BUILD b/sw/device/lib/crypto/aes_gcm/BUILD
index 48a1bfa..17c73b4 100644
--- a/sw/device/lib/crypto/aes_gcm/BUILD
+++ b/sw/device/lib/crypto/aes_gcm/BUILD
@@ -9,6 +9,7 @@
     srcs = ["aes_gcm.c"],
     hdrs = ["aes_gcm.h"],
     deps = [
+        "//sw/device/lib/base:hardened",
         "//sw/device/lib/base:memory",
         "//sw/device/lib/crypto/drivers:aes",
     ],
diff --git a/sw/device/lib/crypto/aes_gcm/aes_gcm.c b/sw/device/lib/crypto/aes_gcm/aes_gcm.c
index 1885bb8..6c793c2 100644
--- a/sw/device/lib/crypto/aes_gcm/aes_gcm.c
+++ b/sw/device/lib/crypto/aes_gcm/aes_gcm.c
@@ -7,6 +7,7 @@
 #include <stddef.h>
 #include <stdint.h>
 
+#include "sw/device/lib/base/hardened.h"
 #include "sw/device/lib/base/macros.h"
 #include "sw/device/lib/base/memory.h"
 #include "sw/device/lib/crypto/drivers/aes.h"
@@ -56,8 +57,8 @@
  *
  * @param x First operand block
  * @param y Second operand block
- * @param out Buffer in which to store output; can be the same as one or both
- * operands.
+ * @param[out] out Buffer in which to store output; can be the same as one or
+ * both operands.
  */
 static inline void block_xor(const aes_block_t *x, const aes_block_t *y,
                              aes_block_t *out) {
@@ -141,15 +142,23 @@
 }
 
 /**
- * Get the number of bytes past the last full block of the buffer.
+ * Get the size of the last block for a given input size.
  *
- * Equivalent to `sz % kAesBlockNumBytes`.
+ * Equivalent to `sz % kAesBlockNumBytes`, except that if `sz` is a multiple of
+ * `kAesBlockNumBytes` then this will return `kAesBlockNumBytes` (since the
+ * last block would then be a full block).
  *
- * @param sz Number of bytes to represent
- * @return Offset of end of buffer from last full block
+ * Assumes that `sz` is nonzero.
+ *
+ * @param sz Number of bytes in the buffer, must be nonzero
+ * @return Number of bytes in the last block
  */
-static inline size_t get_block_offset(size_t sz) {
-  return sz & ((1 << kAesBlockLog2NumBytes) - 1);
+static inline size_t get_last_block_num_bytes(size_t sz) {
+  size_t remainder = sz % kAesBlockNumBytes;
+  if (remainder == 0) {
+    return kAesBlockNumBytes;
+  }
+  return remainder;
 }
 
 /**
@@ -163,7 +172,7 @@
  */
 static inline size_t get_nblocks(size_t sz) {
   size_t out = sz >> kAesBlockLog2NumBytes;
-  if (get_block_offset(sz) != 0) {
+  if (get_last_block_num_bytes(sz) != kAesBlockNumBytes) {
     out += 1;
   }
   return out;
@@ -186,6 +195,20 @@
 }
 
 /**
+ * Implements the inc32() function on blocks from NIST SP800-38D.
+ *
+ * Interprets the last (rightmost) 32 bits of the block as a big-endian integer
+ * and increments this value modulo 2^32 in-place.
+ *
+ * @param block AES block, modified in place.
+ */
+static inline void block_inc32(aes_block_t *block) {
+  // Set the last word to the incremented value.
+  block->data[kAesBlockNumWords - 1] =
+      word_inc32(block->data[kAesBlockNumWords - 1]);
+}
+
+/**
  * Multiply an element of the GCM Galois field by the polynomial `x`.
  *
  * This corresponds to a shift right in the bit representation, and then
@@ -194,7 +217,7 @@
  * Runs in constant time.
  *
  * @param p Polynomial to be multiplied
- * @param out Buffer for output
+ * @param[out] out Buffer for output
  */
 static inline void galois_mulx(const aes_block_t *p, aes_block_t *out) {
   // Get the very rightmost bit of the input block (coefficient of x^127).
@@ -215,7 +238,7 @@
  * interpreted as Galois field polynomials.
  *
  * @param H Hash subkey
- * @param product_table Buffer for output
+ * @param[out] product_table Buffer for output
  */
 static void make_product_table(const aes_block_t *H,
                                aes_gcm_product_table_t *tbl) {
@@ -257,7 +280,7 @@
  *
  * @param p Polynomial to multiply
  * @param tbl Precomputed product table for the hash subkey
- * @param result Block in which to store output
+ * @param[out] result Block in which to store output
  */
 static void galois_mul(const aes_block_t *p, const aes_gcm_product_table_t *tbl,
                        aes_block_t *result) {
@@ -335,9 +358,9 @@
   for (size_t i = 0; i < nblocks; ++i) {
     // Construct block i of the input.
     aes_block_t input_block;
-    if ((i == nblocks - 1) && (get_block_offset(input_len) != 0)) {
-      // Last block is not full; pad with zeroes.
-      size_t nbytes = get_block_offset(input_len);
+    if (i == nblocks - 1) {
+      // Last block may be partial; pad with zeroes.
+      size_t nbytes = get_last_block_num_bytes(input_len);
       memset(input_block.data, 0, kAesBlockNumBytes);
       memcpy(input_block.data, input, nbytes);
     } else {
@@ -362,7 +385,7 @@
   make_product_table(hash_subkey, &tbl);
 
   // If the input length is not a multiple of the block size, fail.
-  if (get_block_offset(input_len) != 0) {
+  if (get_last_block_num_bytes(input_len) != kAesBlockNumBytes) {
     return kAesInternalError;
   }
 
@@ -410,12 +433,13 @@
  * @param icb Initial counter block, 128 bits
  * @param len Number of bytes for input and output
  * @param input Pointer to input buffer (may be NULL if `len` is 0)
- * @param output Pointer to output buffer (same size as input, may be the same
- * buffer)
+ * @param[out] output Pointer to output buffer (same size as input, may be the
+ * same buffer)
  */
-aes_error_t aes_gcm_gctr(const aes_key_len_t key_len, uint32_t *key_shares[2],
-                         const aes_block_t *icb, const size_t len,
-                         const uint8_t *input, uint8_t *output) {
+aes_error_t aes_gcm_gctr(const aes_key_len_t key_len,
+                         const uint32_t *key_shares[2], const aes_block_t *icb,
+                         const size_t len, const uint8_t *input,
+                         uint8_t *output) {
   // If the input is empty, the output must be as well. Since the output length
   // is 0, simply return.
   if (len == 0) {
@@ -446,21 +470,21 @@
     // Retrieve the next block of input. All blocks are full-size except for
     // the last block, which may be partial. If the block is partial, the input
     // data will be padded with zeroes.
-    aes_block_t block_in = {.data = {0}};
+    aes_block_t block_in;
     size_t nbytes = kAesBlockNumBytes;
-    if ((i == nblocks - 1) && (get_block_offset(len) != 0)) {
-      // Last block is partial; copy over only the bytes that exist.
-      nbytes = get_block_offset(len);
+    if (i == nblocks - 1) {
+      // Last block is partial; copy the bytes that exist and set the rest of
+      // the block to 0.
+      nbytes = get_last_block_num_bytes(len);
+      memset(block_in.data, 0, kAesBlockNumBytes);
       memcpy(block_in.data, &input[i * kAesBlockNumBytes], nbytes);
     } else {
       // This block is a full block.
       memcpy(block_in.data, &input[i * kAesBlockNumBytes], kAesBlockNumBytes);
     }
 
-    // Allocate a buffer for the cipher output.
-    aes_block_t block_out = {.data = {0}};
-
     // Run the AES-CTR encryption operation on the next block of input.
+    aes_block_t block_out;
     aes_error_t err = aes_single_block(aes_ctr_params, &block_in, &block_out);
     if (err != kAesOk) {
       return err;
@@ -477,3 +501,210 @@
 
   return kAesOk;
 }
+
+/**
+ * Verify that the lengths of AES-GCM parameters are acceptable.
+ *
+ * This routine can be used for both authenticated encryption and authenticated
+ * decryption; the lengths of the plaintext and ciphertext always match, and
+ * `plaintext_len` may represent either.
+ *
+ * @param iv_len IV length in bytes
+ * @param plaintext_len Plaintext/ciphertext length in bytes
+ * @param aad_len Associated data length in bytes
+ */
+static hardened_bool_t check_buffer_lengths(const size_t iv_len,
+                                            const size_t plaintext_len,
+                                            const size_t aad_len) {
+  // Check IV length (must be 96 or 128 bits = 12 or 16 bytes).
+  if (iv_len != 12 && iv_len != 16) {
+    return kHardenedBoolFalse;
+  }
+
+  // Check plaintext/AAD length. Both must be less than 2^32 bytes long. This
+  // is stricter than NIST requires, but SP800-38D also allows implementations
+  // to stipulate lower length limits.
+  if (plaintext_len > UINT32_MAX || aad_len > UINT32_MAX) {
+    return kHardenedBoolFalse;
+  }
+
+  return kHardenedBoolTrue;
+}
+
+/**
+ * Compute the hash subkey for AES-GCM.
+ *
+ * This routine computes the hash subkey H and the product table for H; see
+ * `make_product_table` for representation details.
+ *
+ * If any step in this process fails, the function returns an error and the
+ * output should not be used.
+ *
+ * @param key_len length of key
+ * @param key_shares key, expressed in two shares
+ * @param[out] tbl Destination for the output hash subkey product table
+ * @return OK or error
+ */
+static aes_error_t aes_gcm_hash_subkey(const aes_key_len_t key_len,
+                                       const uint32_t *key_shares[2],
+                                       aes_gcm_product_table_t *tbl) {
+  // Set AES parameters to perform AES-CTR encryption with IV=0.
+  aes_params_t aes_ctr_params = {
+      .encrypt = true,
+      .mode = kAesCipherModeCtr,
+      .key_len = key_len,
+      .key = {key_shares[0], key_shares[1]},
+      .iv = {0},
+  };
+
+  // Compute the initial hash subkey H = AES_K(0). Note that to get this
+  // result from AES_CTR, we set both the IV and plaintext to zero; this way,
+  // AES-CTR's final XOR with the plaintext does nothing.
+  aes_block_t zero;
+  memset(zero.data, 0, kAesBlockNumBytes);
+  aes_block_t hash_subkey;
+  aes_error_t err = aes_single_block(aes_ctr_params, &zero, &hash_subkey);
+  if (err != kAesOk) {
+    return err;
+  }
+
+  // Compute the product table for H.
+  make_product_table(&hash_subkey, tbl);
+
+  return kAesOk;
+}
+
+/**
+ * Compute the counter block based on the given IV and hash subkey.
+ *
+ * This block is called J0 in the NIST documentation, and is the same for both
+ * encryption and decryption.
+ *
+ * @param iv_len IV length in bytes
+ * @param iv IV value
+ * @param tbl Product table for the hash subkey H
+ * @param[out] j0 Destination for the output counter block
+ * @return OK or error
+ */
+static aes_error_t aes_gcm_counter(const size_t iv_len, const uint8_t *iv,
+                                   aes_gcm_product_table_t *tbl,
+                                   aes_block_t *j0) {
+  if (iv_len == 12) {
+    // If the IV is 96 bits, then J0 = (IV || {0}^31 || 1).
+    memcpy(j0->data, iv, iv_len);
+    // Set the last word to 1 (as a big-endian integer).
+    j0->data[kAesBlockNumWords - 1] = reverse_bytes(1);
+  } else if (iv_len == 16) {
+    // If the IV is 128 bits, then J0 = GHASH(H, IV || {0}^120 || 0x80), where
+    // {0}^120 means 120 zero bits (15 0x00 bytes).
+    memset(j0->data, 0, kAesBlockNumBytes);
+    aes_gcm_ghash_update(tbl, iv_len, iv, j0);
+    uint8_t buffer[kAesBlockNumBytes];
+    memset(buffer, 0, kAesBlockNumBytes);
+    buffer[kAesBlockNumBytes - 1] = 0x80;
+    aes_gcm_ghash_update(tbl, kAesBlockNumBytes, buffer, j0);
+  } else {
+    // Should not happen; invalid IV length.
+    return kAesInternalError;
+  }
+
+  return kAesOk;
+}
+
+/**
+ * Compute the AES-GCM authentication tag.
+ *
+ * @param key_len Length of the AES key
+ * @param key_shares AES key, split into two shares
+ * @param tbl Product table for the hash subkey H
+ * @param ciphertext_len Length of the ciphertext in bytes
+ * @param ciphertext Ciphertext value
+ * @param aad_len Length of the associated data in bytes
+ * @param aad Associated data value
+ * @param j0 Counter block (J0 in the NIST specification)
+ * @param[out] tag Buffer for output tag (128 bits)
+ */
+static aes_error_t aes_gcm_compute_tag(const aes_key_len_t key_len,
+                                       const uint32_t *key_shares[2],
+                                       const aes_gcm_product_table_t *tbl,
+                                       const size_t ciphertext_len,
+                                       const uint8_t *ciphertext,
+                                       const size_t aad_len, const uint8_t *aad,
+                                       const aes_block_t *j0, uint8_t *tag) {
+  // Compute S = GHASH(H, expand(A) || expand(C) || len64(A) || len64(C))
+  // where:
+  //   * A is the aad, C is the ciphertext
+  //   * expand(x) pads x to a multiple of 128 bits by adding zeroes to the
+  //     right-hand side
+  //   * len64(x) is the length of x in bits expressed as a
+  //     big-endian 64-bit integer.
+
+  // Compute GHASH(H, expand(A) || expand(C)).
+  aes_block_t s;
+  memset(s.data, 0, kAesBlockNumBytes);
+  aes_gcm_ghash_update(tbl, aad_len, aad, &s);
+  aes_gcm_ghash_update(tbl, ciphertext_len, ciphertext, &s);
+
+  // Compute len64(A) and len64(C) by computing the length in *bits* (shift by
+  // 3) and then converting to big-endian.
+  uint64_t last_block[2] = {
+      __builtin_bswap64(((uint64_t)aad_len) * 8),
+      __builtin_bswap64(((uint64_t)ciphertext_len) * 8),
+  };
+
+  // Use memcpy() to avoid violating strict aliasing when converting to bytes.
+  uint8_t last_block_bytes[sizeof(last_block)];
+  memcpy(last_block_bytes, last_block, sizeof(last_block));
+
+  // Finish computing S by appending (len64(A) || len64(C)).
+  aes_gcm_ghash_update(tbl, kAesBlockNumBytes, last_block_bytes, &s);
+
+  // Compute the tag T = GCTR(K, J0, S).
+  uint8_t s_data_bytes[sizeof(s.data)];
+  memcpy(s_data_bytes, s.data, sizeof(s.data));
+  return aes_gcm_gctr(key_len, key_shares, j0, kAesBlockNumBytes, s_data_bytes,
+                      tag);
+}
+
+aes_error_t aes_gcm_encrypt(const aes_key_len_t key_len,
+                            const uint32_t *key_shares[2], const size_t iv_len,
+                            const uint8_t *iv, const size_t plaintext_len,
+                            const uint8_t *plaintext, const size_t aad_len,
+                            const uint8_t *aad, uint8_t *ciphertext,
+                            uint8_t *tag) {
+  // Check that the input parameter sizes are valid.
+  if (check_buffer_lengths(iv_len, plaintext_len, aad_len) !=
+      kHardenedBoolTrue) {
+    return kAesInternalError;
+  }
+
+  // Compute the hash subkey H as a product table.
+  aes_gcm_product_table_t Htbl;
+  aes_error_t err = aes_gcm_hash_subkey(key_len, key_shares, &Htbl);
+  if (err != kAesOk) {
+    return err;
+  }
+
+  // Compute the counter block (called J0 in the NIST specification).
+  aes_block_t j0;
+  err = aes_gcm_counter(iv_len, iv, &Htbl, &j0);
+  if (err != kAesOk) {
+    return err;
+  }
+
+  // Compute inc32(J0).
+  aes_block_t j0_inc;
+  memcpy(j0_inc.data, j0.data, kAesBlockNumBytes);
+  block_inc32(&j0_inc);
+
+  // Compute ciphertext C = GCTR(K, inc32(J0), plaintext).
+  err = aes_gcm_gctr(key_len, key_shares, &j0_inc, plaintext_len, plaintext,
+                     ciphertext);
+  if (err != kAesOk) {
+    return err;
+  }
+
+  // Compute the authentication tag T.
+  return aes_gcm_compute_tag(key_len, key_shares, &Htbl, plaintext_len,
+                             ciphertext, aad_len, aad, &j0, tag);
+}
diff --git a/sw/device/lib/crypto/aes_gcm/aes_gcm.h b/sw/device/lib/crypto/aes_gcm/aes_gcm.h
index e2470db..fafe542 100644
--- a/sw/device/lib/crypto/aes_gcm/aes_gcm.h
+++ b/sw/device/lib/crypto/aes_gcm/aes_gcm.h
@@ -16,6 +16,35 @@
 #endif  // __cplusplus
 
 /**
+ * AES-GCM authenticated encryption as defined in NIST SP800-38D, algorithm 4.
+ *
+ * The key is represented as two shares which are XORed together to get the
+ * real key value. The IV must be either 96 or 128 bits.
+ *
+ * The byte-lengths of the plaintext and AAD must each be < 2^32. This is a
+ * tighter constraint than the length limits in section 5.2.1.1.
+ *
+ * This implementation does not support short tags.
+ *
+ * @param key_len length of key
+ * @param key_shares key, expressed in two shares
+ * @param iv_len length of IV in bytes
+ * @param iv IV value (may be NULL if iv_len is 0)
+ * @param plaintext_len length of plaintext in bytes
+ * @param plaintext plaintext value (may be NULL if plaintext_len is 0)
+ * @param aad_len length of AAD in bytes
+ * @param aad AAD value (may be NULL if aad_len is 0)
+ * @param ciphertext Output buffer for ciphertext (same length as plaintext)
+ * @param[out] tag Output buffer for tag (128 bits)
+ */
+OT_WARN_UNUSED_RESULT
+aes_error_t aes_gcm_encrypt(const aes_key_len_t key_len,
+                            const uint32_t *key_shares[2], const size_t iv_len,
+                            const uint8_t *iv, const size_t plaintext_len,
+                            const uint8_t *plaintext, const size_t aad_len,
+                            const uint8_t *aad, uint8_t *ciphertext,
+                            uint8_t *tag);
+/**
  * GHASH operation as defined in NIST SP800-38D, algorithm 2.
  *
  * The input size must be a multiple of the block size.
@@ -28,7 +57,7 @@
  * @param hash_subkey The hash subkey (a 128-bit cipher block).
  * @param input_len Number of bytes in the input.
  * @param input Pointer to input buffer.
- * @param output Block in which to store the output.
+ * @param[out] output Block in which to store the output.
  */
 OT_WARN_UNUSED_RESULT
 aes_error_t aes_gcm_ghash(const aes_block_t *hash_subkey,
diff --git a/sw/device/lib/crypto/drivers/aes.h b/sw/device/lib/crypto/drivers/aes.h
index f1ca26b..152afc0 100644
--- a/sw/device/lib/crypto/drivers/aes.h
+++ b/sw/device/lib/crypto/drivers/aes.h
@@ -95,7 +95,7 @@
    * the two pointers should be pointed to an array of zero-valued words
    * of sufficient length.
    */
-  uint32_t *key[2];
+  const uint32_t *key[2];
 
   /**
    * The IV to use with the CBC or CTR modes.
diff --git a/sw/device/lib/crypto/drivers/aes_test.c b/sw/device/lib/crypto/drivers/aes_test.c
index b61d08d..66df04a 100644
--- a/sw/device/lib/crypto/drivers/aes_test.c
+++ b/sw/device/lib/crypto/drivers/aes_test.c
@@ -54,12 +54,16 @@
   // hardware; in general, the key should be generated by either generating
   // two shares and setting key = a ^ b, or generating a mask and setting
   // a = key ^ mask, b = mask.
-  uint32_t share0[8] = {0};
-  uint32_t share1[8] = {0};
-  for (int i = 0; i < ARRAYSIZE(share0); ++i) {
-    share0[i] = ~kSecretKey[i];
-    share1[i] = UINT32_MAX;
-  }
+  const uint32_t share0[8] = {~kSecretKey[0],
+                              ~kSecretKey[1],
+                              ~kSecretKey[2],
+                              ~kSecretKey[3],
+                              0,
+                              0,
+                              0,
+                              0};
+  const uint32_t share1[8] = {UINT32_MAX, UINT32_MAX, UINT32_MAX, UINT32_MAX,
+                              0,          0,          0,          0};
 
   LOG_INFO("Configuring the AES hardware.");
   aes_params_t params = {