[sw/silicon_creator] Add unit tests for sigverify_rsa_key_get

Signed-off-by: Alphan Ulusoy <alphan@google.com>
diff --git a/sw/device/silicon_creator/mask_rom/sigverify_keys_unittest.cc b/sw/device/silicon_creator/mask_rom/sigverify_keys_unittest.cc
index b46810a..62c67f7 100644
--- a/sw/device/silicon_creator/mask_rom/sigverify_keys_unittest.cc
+++ b/sw/device/silicon_creator/mask_rom/sigverify_keys_unittest.cc
@@ -4,15 +4,20 @@
 
 #include "sw/device/silicon_creator/mask_rom/sigverify_keys.h"
 
+#include <cstring>
 #include <unordered_set>
 
-#include "gmock/gmock.h"
 #include "gtest/gtest.h"
+#include "sw/device/lib/base/hardened.h"
+#include "sw/device/lib/testing/mask_rom_test.h"
 #include "sw/device/silicon_creator/lib/drivers/mock_hmac.h"
 #include "sw/device/silicon_creator/lib/drivers/mock_otp.h"
+#include "sw/device/silicon_creator/lib/error.h"
 #include "sw/device/silicon_creator/lib/sigverify.h"
 #include "sw/device/silicon_creator/lib/sigverify_mod_exp.h"
 
+#include "otp_ctrl_regs.h"
+
 namespace sigverify_keys_unittest {
 namespace {
 using ::testing::DoAll;
@@ -20,6 +25,70 @@
 using ::testing::Return;
 using ::testing::SetArgPointee;
 
+class SigverifyRsaKeyGet : public mask_rom_test::MaskRomTest,
+                           public testing::WithParamInterface<size_t> {
+ protected:
+  /**
+   * Sets an expectation for an OTP read for the key at the given index.
+   *
+   * The value that corresponds to `key_index` will be `is_valid` and the values
+   * for all other keys in the corresponding OTP word will be the complement of
+   * `is_valid`.
+   *
+   * @param key_index Index of a key.
+   * @param is_valid Validitiy of the key.
+   */
+  void ExpectOtpRead(size_t key_index, hardened_byte_bool_t is_valid) {
+    const uint32_t read_addr =
+        OTP_CTRL_PARAM_CREATOR_SW_CFG_KEY_IS_VALID_OFFSET +
+        (key_index / kSigverifyNumEntriesPerOtpWord) * sizeof(uint32_t);
+    const size_t entry_index = key_index % kSigverifyNumEntriesPerOtpWord;
+
+    std::array<uint8_t, kSigverifyNumEntriesPerOtpWord> entries;
+    hardened_byte_bool_t others_val = is_valid == kHardenedByteBoolTrue
+                                          ? kHardenedByteBoolFalse
+                                          : kHardenedByteBoolTrue;
+    entries.fill(others_val);
+    entries[entry_index] = is_valid;
+
+    uint32_t read_val;
+    std::memcpy(&read_val, entries.data(), sizeof(read_val));
+    EXPECT_CALL(otp_, read32(read_addr)).WillOnce(Return(read_val));
+  }
+  mask_rom_test::MockOtp otp_;
+};
+
+TEST_P(SigverifyRsaKeyGet, ValidInOtp) {
+  const size_t key_index = GetParam();
+  ExpectOtpRead(key_index, kHardenedByteBoolTrue);
+
+  const sigverify_rsa_key_t *key;
+  EXPECT_EQ(
+      sigverify_rsa_key_get(
+          sigverify_rsa_key_id_get(&kSigVerifyRsaKeys[key_index].n), &key),
+      kErrorOk);
+  EXPECT_EQ(key, &kSigVerifyRsaKeys[key_index]);
+}
+
+TEST_P(SigverifyRsaKeyGet, InvalidInOtp) {
+  const size_t key_index = GetParam();
+  ExpectOtpRead(key_index, kHardenedByteBoolFalse);
+
+  const sigverify_rsa_key_t *key;
+  EXPECT_EQ(
+      sigverify_rsa_key_get(
+          sigverify_rsa_key_id_get(&kSigVerifyRsaKeys[key_index].n), &key),
+      kErrorSigverifyBadKey);
+}
+
+INSTANTIATE_TEST_SUITE_P(AllMaskRomKeys, SigverifyRsaKeyGet,
+                         testing::Range<size_t>(0, kSigVerifyNumRsaKeys));
+
+TEST(SigverifyRsaKeyGet, InvalidId) {
+  const sigverify_rsa_key_t *key;
+  EXPECT_EQ(sigverify_rsa_key_get(0, &key), kErrorSigverifyBadKey);
+}
+
 TEST(Keys, UniqueIds) {
   std::unordered_set<uint32_t> ids;
   for (auto const &key : kSigVerifyRsaKeys) {