[crypto] Add AES-GCM authenticated decryption.

Adds an authenticated decryption operation to complete the AES-GCM
scheme. Decryption is very similar to encryption and uses most of the
same underlying operations.

Also add a note to `hardened_memeq` to explicitly specify that it is
constant-time.

Signed-off-by: Jade Philipoom <jadep@google.com>
diff --git a/sw/device/lib/base/hardened_memory.h b/sw/device/lib/base/hardened_memory.h
index 9891615..69ab538 100644
--- a/sw/device/lib/base/hardened_memory.h
+++ b/sw/device/lib/base/hardened_memory.h
@@ -69,6 +69,7 @@
  * - It only computes equality, not lexicographic ordering, which would be even
  *   slower.
  * - It returns a `hardened_bool_t`.
+ * - It is constant-time.
  *
  * Input pointers *MUST* be 32-bit aligned, although they do not need to
  * actually point to memory declared as `uint32_t` per the C aliasing rules.
diff --git a/sw/device/lib/crypto/impl/aes_gcm/BUILD b/sw/device/lib/crypto/impl/aes_gcm/BUILD
index 17c73b4..987f3d4 100644
--- a/sw/device/lib/crypto/impl/aes_gcm/BUILD
+++ b/sw/device/lib/crypto/impl/aes_gcm/BUILD
@@ -10,6 +10,7 @@
     hdrs = ["aes_gcm.h"],
     deps = [
         "//sw/device/lib/base:hardened",
+        "//sw/device/lib/base:hardened_memory",
         "//sw/device/lib/base:memory",
         "//sw/device/lib/crypto/drivers:aes",
     ],
diff --git a/sw/device/lib/crypto/impl/aes_gcm/aes_gcm.c b/sw/device/lib/crypto/impl/aes_gcm/aes_gcm.c
index fb2532f..c17a603 100644
--- a/sw/device/lib/crypto/impl/aes_gcm/aes_gcm.c
+++ b/sw/device/lib/crypto/impl/aes_gcm/aes_gcm.c
@@ -8,6 +8,7 @@
 #include <stdint.h>
 
 #include "sw/device/lib/base/hardened.h"
+#include "sw/device/lib/base/hardened_memory.h"
 #include "sw/device/lib/base/macros.h"
 #include "sw/device/lib/base/memory.h"
 #include "sw/device/lib/crypto/drivers/aes.h"
@@ -695,3 +696,57 @@
   return aes_gcm_compute_tag(key, &Htbl, plaintext_len, ciphertext, aad_len,
                              aad, &j0, tag);
 }
+
+aes_error_t aes_gcm_decrypt(const aes_key_t key, const size_t iv_len,
+                            const uint8_t *iv, const size_t ciphertext_len,
+                            const uint8_t *ciphertext, const size_t aad_len,
+                            const uint8_t *aad, const uint8_t *tag,
+                            uint8_t *plaintext, hardened_bool_t *success) {
+  // Check that the input parameter sizes are valid.
+  if (check_buffer_lengths(iv_len, ciphertext_len, aad_len) !=
+      kHardenedBoolTrue) {
+    return kAesInternalError;
+  }
+
+  // Compute the hash subkey H as a product table.
+  aes_gcm_product_table_t Htbl;
+  aes_error_t err = aes_gcm_hash_subkey(key, &Htbl);
+  if (err != kAesOk) {
+    return err;
+  }
+
+  // Compute the counter block (called J0 in the NIST specification).
+  aes_block_t j0;
+  err = aes_gcm_counter(iv_len, iv, &Htbl, &j0);
+  if (err != kAesOk) {
+    return err;
+  }
+
+  // Compute the expected authentication tag T.
+  uint8_t expected_tag[kAesGcmTagNumBytes];
+  err = aes_gcm_compute_tag(key, &Htbl, ciphertext_len, ciphertext, aad_len,
+                            aad, &j0, expected_tag);
+  if (err != kAesOk) {
+    return err;
+  }
+
+  // Copy expected and actual tag to word-size buffers to avoid violating
+  // strict aliasing rules.
+  uint32_t expected_tag_words[kAesGcmTagNumWords];
+  uint32_t tag_words[kAesGcmTagNumWords];
+  memcpy(expected_tag_words, expected_tag, kAesGcmTagNumBytes);
+  memcpy(tag_words, tag, kAesGcmTagNumBytes);
+
+  // Compare the expected tag to the actual tag (in constant time).
+  *success = hardened_memeq(expected_tag_words, tag_words, kAesGcmTagNumWords);
+  if (*success != kHardenedBoolTrue) {
+    // If authentication fails, do not proceed to decryption; simply exit
+    // with success = False. We still use `kAesOk` because there was no
+    // internal error during the authentication check.
+    return kAesOk;
+  }
+
+  // Compute plaintext P = GCTR(K, inc32(J0), ciphertext).
+  block_inc32(&j0);
+  return aes_gcm_gctr(key, &j0, ciphertext_len, ciphertext, plaintext);
+}
diff --git a/sw/device/lib/crypto/impl/aes_gcm/aes_gcm.h b/sw/device/lib/crypto/impl/aes_gcm/aes_gcm.h
index fb07485..4fe2f34 100644
--- a/sw/device/lib/crypto/impl/aes_gcm/aes_gcm.h
+++ b/sw/device/lib/crypto/impl/aes_gcm/aes_gcm.h
@@ -8,6 +8,7 @@
 #include <stddef.h>
 #include <stdint.h>
 
+#include "sw/device/lib/base/hardened.h"
 #include "sw/device/lib/base/macros.h"
 #include "sw/device/lib/crypto/drivers/aes.h"
 
@@ -15,6 +16,15 @@
 extern "C" {
 #endif  // __cplusplus
 
+enum {
+  /* Full tag size in bits. */
+  kAesGcmTagNumBits = 128,
+  /* Full tag size in bytes. */
+  kAesGcmTagNumBytes = kAesGcmTagNumBits / 8,
+  /* Full tag size in words. */
+  kAesGcmTagNumWords = kAesGcmTagNumBytes / sizeof(uint32_t),
+};
+
 /**
  * AES-GCM authenticated encryption as defined in NIST SP800-38D, algorithm 4.
  *
@@ -35,6 +45,7 @@
  * @param aad AAD value (may be NULL if aad_len is 0)
  * @param ciphertext Output buffer for ciphertext (same length as plaintext)
  * @param[out] tag Output buffer for tag (128 bits)
+ * @return Error status; OK if no errors
  */
 OT_WARN_UNUSED_RESULT
 aes_error_t aes_gcm_encrypt(const aes_key_t key, const size_t iv_len,
@@ -56,12 +67,52 @@
  * @param input_len Number of bytes in the input.
  * @param input Pointer to input buffer.
  * @param[out] output Block in which to store the output.
+ * @return Error status; OK if no errors
  */
 OT_WARN_UNUSED_RESULT
 aes_error_t aes_gcm_ghash(const aes_block_t *hash_subkey,
                           const size_t input_len, const uint8_t *input,
                           aes_block_t *output);
 
+/**
+ * AES-GCM authenticated decryption as defined in NIST SP800-38D, algorithm 5.
+ *
+ * The key is represented as two shares which are XORed together to get the
+ * real key value. The IV must be either 96 or 128 bits.
+ *
+ * The byte-lengths of the plaintext and AAD must each be < 2^32. This is a
+ * tighter constraint than the length limits in section 5.2.1.1.
+ *
+ * If authentication fails, this function will return `kHardenedBoolFalse` for
+ * the `success` output parameter, and the plaintext should be ignored. Note
+ * the distinction between the `success` output parameter and the return value
+ * (type `aes_error_t`): the return value indicates whether there was an
+ * internal error while processing the function, and `success` indicates
+ * whether the authentication check passed. If the return value is anything
+ * other than OK, all output from this function should be discarded, including
+ * `success`.
+ *
+ * This implementation does not support short tags.
+ *
+ * @param key AES key
+ * @param iv_len length of IV in bytes
+ * @param iv IV value (may be NULL if iv_len is 0)
+ * @param ciphertext_len length of ciphertext in bytes
+ * @param ciphertext plaintext value (may be NULL if ciphertext_len is 0)
+ * @param aad_len length of AAD in bytes
+ * @param aad AAD value (may be NULL if aad_len is 0)
+ * @param tag Authentication tag (128 bits)
+ * @param plaintext[out] Output buffer for plaintext (same length as ciphertext)
+ * @param success[out] True if authentication was successful, otherwise false
+ * @return Error status; OK if no errors
+ */
+OT_WARN_UNUSED_RESULT
+aes_error_t aes_gcm_decrypt(const aes_key_t key, const size_t iv_len,
+                            const uint8_t *iv, const size_t ciphertext_len,
+                            const uint8_t *ciphertext, const size_t aad_len,
+                            const uint8_t *aad, const uint8_t *tag,
+                            uint8_t *plaintext, hardened_bool_t *success);
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif  // __cplusplus
diff --git a/sw/device/tests/crypto/aes_gcm_functest.c b/sw/device/tests/crypto/aes_gcm_functest.c
index 4be706c..fffefa3 100644
--- a/sw/device/tests/crypto/aes_gcm_functest.c
+++ b/sw/device/tests/crypto/aes_gcm_functest.c
@@ -2,6 +2,7 @@
 // Licensed under the Apache License, Version 2.0, see LICENSE for details.
 // SPDX-License-Identifier: Apache-2.0
 
+#include "sw/device/lib/base/hardened.h"
 #include "sw/device/lib/base/macros.h"
 #include "sw/device/lib/crypto/drivers/aes.h"
 #include "sw/device/lib/crypto/impl/aes_gcm/aes_gcm.h"
@@ -25,10 +26,10 @@
   // of 4, then the most significant bytes of the last word are ignored.
   size_t aad_len;
   uint8_t *aad;
-  // Expected authentication tag.
-  uint8_t expected_tag[16];
-  // Expected ciphertext (same length as plaintext).
-  uint8_t *expected_ciphertext;
+  // Authentication tag.
+  uint8_t tag[16];
+  // Ciphertext (same length as plaintext).
+  uint8_t *ciphertext;
 } aes_gcm_test_t;
 
 /**
@@ -91,11 +92,11 @@
         .plaintext = NULL,
         .aad_len = 0,
         .aad = NULL,
-        .expected_tag =
+        .tag =
             {// Tag = b7aa223a6c75a0976633ce79d9fddf06
              0xb7, 0xaa, 0x22, 0x3a, 0x6c, 0x75, 0xa0, 0x97, 0x66, 0x33, 0xce,
              0x79, 0xd9, 0xfd, 0xdf, 0x06},
-        .expected_ciphertext = NULL,
+        .ciphertext = NULL,
     },
 
     // Empty input, empty aad, 128-bit IV, 128-bit key
@@ -110,11 +111,11 @@
         .plaintext = NULL,
         .aad_len = 0,
         .aad = NULL,
-        .expected_tag =
+        .tag =
             {// Tag = 4c59f0d420d9eb8669c40ad23b5419ba
              0x4c, 0x59, 0xf0, 0xd4, 0x20, 0xd9, 0xeb, 0x86, 0x69, 0xc4, 0x0a,
              0xd2, 0x3b, 0x54, 0x19, 0xba},
-        .expected_ciphertext = NULL,
+        .ciphertext = NULL,
     },
 
     // 128-bit IV, 256-bit key, real message and aad
@@ -129,11 +130,11 @@
         .plaintext = kPlaintext,
         .aad_len = kAadLen,
         .aad = kAad,
-        .expected_tag =
+        .tag =
             {// Tag = 324895b3d2f656e4fa2f8ce056137061
              0x32, 0x48, 0x95, 0xb3, 0xd2, 0xf6, 0x56, 0xe4, 0xfa, 0x2f, 0x8c,
              0xe0, 0x56, 0x13, 0x70, 0x61},
-        .expected_ciphertext = kCiphertext256,
+        .ciphertext = kCiphertext256,
     },
 };
 
@@ -178,12 +179,44 @@
     LOG_INFO("aes_gcm_encrypt() took %u cycles", cycles);
     CHECK(err == kAesOk, "AES-GCM encryption returned an error: %08x", err);
 
-    CHECK_ARRAYS_EQ(actual_tag, test.expected_tag, sizeof(test.expected_tag));
+    CHECK_ARRAYS_EQ(actual_tag, test.tag, sizeof(test.tag));
     if (test.plaintext_len > 0) {
-      CHECK_ARRAYS_EQ(actual_ciphertext, test.expected_ciphertext,
-                      test.plaintext_len);
+      CHECK_ARRAYS_EQ(actual_ciphertext, test.ciphertext, test.plaintext_len);
     }
 
+    // Call AES-GCM decrypt with the correct tag.
+    uint8_t actual_plaintext[test.plaintext_len];
+    hardened_bool_t success;
+    start = ibex_mcycle_read();
+    err = aes_gcm_decrypt(test_key, test.iv_len, test.iv, test.plaintext_len,
+                          test.ciphertext, test.aad_len, test.aad, test.tag,
+                          actual_plaintext, &success);
+    end = ibex_mcycle_read();
+    cycles = end - start;
+    LOG_INFO("aes_gcm_decrypt() took %u cycles", cycles);
+    CHECK(err == kAesOk, "AES-GCM decryption returned an error: %08x", err);
+    CHECK(success == kHardenedBoolTrue,
+          "AES-GCM decryption failed on valid input");
+
+    if (test.plaintext_len > 0) {
+      CHECK_ARRAYS_EQ(actual_plaintext, test.plaintext, test.plaintext_len);
+    }
+
+    // Call AES-GCM decrypt with an incorrect tag.
+    uint8_t bad_tag[16];
+    memcpy(bad_tag, test.tag, 16);
+    bad_tag[15]++;
+    start = ibex_mcycle_read();
+    err = aes_gcm_decrypt(test_key, test.iv_len, test.iv, test.plaintext_len,
+                          test.ciphertext, test.aad_len, test.aad, bad_tag,
+                          actual_plaintext, &success);
+    end = ibex_mcycle_read();
+    cycles = end - start;
+    LOG_INFO("aes_gcm_decrypt() took %u cycles", cycles);
+    CHECK(err == kAesOk, "AES-GCM decryption returned an error: %08x", err);
+    CHECK(success == kHardenedBoolFalse,
+          "AES-GCM decryption passed an invalid tag");
+
     LOG_INFO("Finished AES-GCM test %d.", i + 1);
   }