[crypto,test] Add test infrastructure for SPHINCS+.

Signed-off-by: Jade Philipoom <jadep@google.com>
diff --git a/sw/device/silicon_creator/lib/sigverify/sphincsplus/test/BUILD b/sw/device/silicon_creator/lib/sigverify/sphincsplus/test/BUILD
index 94ccae5..d0e4458 100644
--- a/sw/device/silicon_creator/lib/sigverify/sphincsplus/test/BUILD
+++ b/sw/device/silicon_creator/lib/sigverify/sphincsplus/test/BUILD
@@ -4,11 +4,45 @@
 
 package(default_visibility = ["//visibility:public"])
 
+load("//rules:autogen.bzl", "autogen_cryptotest_header")
 load(
     "//rules:opentitan_test.bzl",
     "opentitan_functest",
     "verilator_params",
 )
+load("@ot_python_deps//:requirements.bzl", "requirement")
+
+py_binary(
+    name = "sphincsplus_set_testvectors",
+    srcs = ["sphincsplus_set_testvectors.py"],
+    deps = [
+        requirement("hjson"),
+        requirement("mako"),
+    ],
+)
+
+autogen_cryptotest_header(
+    name = "sphincsplus_shake_128s_simple_testvectors_hardcoded_header",
+    hjson = "//sw/device/tests/crypto/testvectors:sphincsplus_shake_128s_simple_testvectors_hardcoded",
+    template = ":sphincsplus_shake_128s_simple_testvectors.h.tpl",
+    tool = ":sphincsplus_set_testvectors",
+)
+
+opentitan_functest(
+    name = "verify_test_hardcoded",
+    srcs = ["verify_test.c"],
+    verilator = verilator_params(
+        timeout = "eternal",
+    ),
+    deps = [
+        ":sphincsplus_shake_128s_simple_testvectors_hardcoded_header",
+        "//sw/device/lib/base:memory",
+        "//sw/device/lib/runtime:ibex",
+        "//sw/device/lib/testing/test_framework:ottf_main",
+        "//sw/device/silicon_creator/lib:test_main",
+        "//sw/device/silicon_creator/lib/sigverify/sphincsplus:verify",
+    ],
+)
 
 opentitan_functest(
     name = "fors_test",
diff --git a/sw/device/silicon_creator/lib/sigverify/sphincsplus/test/sphincsplus_set_testvectors.py b/sw/device/silicon_creator/lib/sigverify/sphincsplus/test/sphincsplus_set_testvectors.py
new file mode 100755
index 0000000..f40105b
--- /dev/null
+++ b/sw/device/silicon_creator/lib/sigverify/sphincsplus/test/sphincsplus_set_testvectors.py
@@ -0,0 +1,71 @@
+#!/usr/bin/env python3
+# Copyright lowRISC contributors.
+# Licensed under the Apache License, Version 2.0, see LICENSE for details.
+# SPDX-License-Identifier: Apache-2.0
+
+import argparse
+import sys
+
+import hjson
+from mako.template import Template
+
+'''
+Read in an HJSON test vector file, convert the test vector to C constants, and
+generate a header file with these test vectors.
+'''
+
+
+def hex_to_hexbytes(x):
+    '''Convert a hex string to a list of bytes as hex strings.'''
+    if x.startswith('0x'):
+        x = x[2:]
+
+    # Double-check that length is even
+    if len(x) % 2 != 0:
+        raise ValueError(f'Cannot convert odd-length hex string (length {len(x)}) to bytes: {x}')
+
+    out = []
+    for i in range(0, len(x), 2):
+        out.append('0x' + x[i:i + 2])
+    return out
+
+
+def main() -> int:
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--hjsonfile', '-j',
+                        metavar='FILE',
+                        required=True,
+                        type=argparse.FileType('r'),
+                        help='Read test vectors from this HJSON file.')
+    parser.add_argument('--template', '-t',
+                        metavar='FILE',
+                        required=True,
+                        type=argparse.FileType('r'),
+                        help='Read header template from this file.')
+    parser.add_argument('--headerfile', '-o',
+                        metavar='FILE',
+                        required=True,
+                        type=argparse.FileType('w'),
+                        help='Write output to this file.')
+
+    args = parser.parse_args()
+
+    # Read test vectors and stringify them
+    with args.hjsonfile as hjsonfile:
+        testvecs = hjson.load(hjsonfile)
+
+    # Convert the values to hexadecimal bytes.
+    for t in testvecs:
+        t['sig_hexbytes'] = hex_to_hexbytes(t['sig_hex'])
+        t['msg_hexbytes'] = hex_to_hexbytes(t['msg_hex'])
+        t['pk_hexbytes'] = hex_to_hexbytes(t['pk_hex'])
+
+    with args.template as template:
+        with args.headerfile as header:
+            header.write(Template(template.read()).render(tests=testvecs))
+
+    return 0
+
+
+if __name__ == '__main__':
+    sys.exit(main())
diff --git a/sw/device/silicon_creator/lib/sigverify/sphincsplus/test/sphincsplus_shake_128s_simple_testvectors.h.tpl b/sw/device/silicon_creator/lib/sigverify/sphincsplus/test/sphincsplus_shake_128s_simple_testvectors.h.tpl
new file mode 100644
index 0000000..45ebf1a
--- /dev/null
+++ b/sw/device/silicon_creator/lib/sigverify/sphincsplus/test/sphincsplus_shake_128s_simple_testvectors.h.tpl
@@ -0,0 +1,69 @@
+// Copyright lowRISC contributors.
+// Licensed under the Apache License, Version 2.0, see LICENSE for details.
+// SPDX-License-Identifier: Apache-2.0
+//
+// AUTOGENERATED. Do not edit this file by hand.
+// See the crypto/tests README for details.
+
+#ifndef OPENTITAN_SW_DEVICE_SILICON_CREATOR_LIB_SIGVERIFY_SPHINCSPLUS_TEST_SPHINCSPLUS_SHAKE_128S_SIMPLE_TESTVECTORS_H_
+#define OPENTITAN_SW_DEVICE_SILICON_CREATOR_LIB_SIGVERIFY_SPHINCSPLUS_TEST_SPHINCSPLUS_SHAKE_128S_SIMPLE_TESTVECTORS_H_
+
+#include "sw/device/silicon_creator/lib/sigverify/sphincsplus/params.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// A test vector for SPHINCS+ signature verification.
+typedef struct spx_verify_test_vector {
+  uint8_t sig[kSpxVerifySigBytes];  // Signature to verify.
+  uint8_t pk[kSpxVerifyPkBytes];    // Public key.
+  size_t msg_len;                   // Length of message.
+  uint8_t *msg;                     // Message.
+} spx_verify_test_vector_t;
+
+static const size_t kSpxVerifyNumTests = ${len(tests)};
+
+// Static message arrays.
+% for idx, t in enumerate(tests):
+  % if t["msg_len"] == 0:
+// msg${idx} is empty.
+  % else:
+static uint8_t msg${idx}[${t["msg_len"]}] = {
+    % for i in range(0, len(t["msg_hexbytes"]), 10):
+   ${', '.join(t["msg_hexbytes"][i:i + 10])},
+    % endfor
+};
+  %endif
+% endfor
+
+static const spx_verify_test_vector_t spx_verify_tests[${len(tests)}] = {
+% for idx, t in enumerate(tests):
+    {
+        .sig =
+            {
+  % for i in range(0, len(t["sig_hexbytes"]), 10):
+                ${', '.join(t["sig_hexbytes"][i:i + 10])},
+  % endfor
+            },
+        .pk =
+            {
+  % for i in range(0, len(t["pk_hexbytes"]), 10):
+                ${', '.join(t["pk_hexbytes"][i:i + 10])},
+  % endfor
+            },
+        .msg_len = ${t["msg_len"]},
+  % if t["msg_len"] == 0:
+        .msg = NULL,
+  % else:
+        .msg = msg${idx},
+  % endif
+    },
+% endfor
+};
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif // OPENTITAN_SW_DEVICE_SILICON_CREATOR_LIB_SIGVERIFY_SPHINCSPLUS_TEST_SPHINCSPLUS_SHAKE_128S_SIMPLE_TESTVECTORS_H_
diff --git a/sw/device/silicon_creator/lib/sigverify/sphincsplus/test/verify_test.c b/sw/device/silicon_creator/lib/sigverify/sphincsplus/test/verify_test.c
new file mode 100644
index 0000000..d05a6c7
--- /dev/null
+++ b/sw/device/silicon_creator/lib/sigverify/sphincsplus/test/verify_test.c
@@ -0,0 +1,143 @@
+// Copyright lowRISC contributors.
+// Licensed under the Apache License, Version 2.0, see LICENSE for details.
+// SPDX-License-Identifier: Apache-2.0
+
+#include "sw/device/silicon_creator/lib/sigverify/sphincsplus/verify.h"
+
+#include <stdint.h>
+
+#include "sw/device/lib/base/memory.h"
+#include "sw/device/lib/base/status.h"
+#include "sw/device/lib/runtime/ibex.h"
+#include "sw/device/lib/runtime/log.h"
+#include "sw/device/lib/testing/test_framework/check.h"
+#include "sw/device/lib/testing/test_framework/ottf_main.h"
+#include "sw/device/silicon_creator/lib/test_main.h"
+
+// The autogen rule that creates this header creates it in a directory named
+// after the rule, then manipulates the include path in the
+// cc_compilation_context to include that directory, so the compiler will find
+// the version of this file matching the Bazel rule under test.
+#include "sphincsplus_shake_128s_simple_testvectors.h"
+
+// Index of the test vector currently under test
+static uint32_t test_index = 0;
+
+OTTF_DEFINE_TEST_CONFIG();
+
+enum {
+  /**
+   * Number of negative tests to run (manipulating the message and checking
+   * that the signature fails).
+   */
+  kNumNegativeTests = 1,
+};
+
+/**
+ * Start a cycle-count timing profile.
+ */
+static uint64_t profile_start() { return ibex_mcycle_read(); }
+
+/**
+ * End a cycle-count timing profile.
+ *
+ * Call `profile_start()` first.
+ */
+static uint32_t profile_end(uint64_t t_start) {
+  uint64_t t_end = ibex_mcycle_read();
+  uint64_t cycles = t_end - t_start;
+  CHECK(cycles <= UINT32_MAX);
+  return (uint32_t)cycles;
+}
+
+/**
+ * Run the SPHINCS+ verification procedure on the current test.
+ *
+ * @param test Test vector to run.
+ * @param[out] root Output buffer for root node computed from signature.
+ * @param[out] pub_root Output buffer for root node computed from public key.
+ */
+static rom_error_t run_verify(const spx_verify_test_vector_t *test,
+                              uint32_t *root, uint32_t *pub_root) {
+  // Calculate the public-key root to compare against.
+  spx_public_key_root(test->pk, pub_root);
+
+  // Run verification and print the cycle count.
+  uint64_t t_start = profile_start();
+  rom_error_t err =
+      spx_verify(test->sig, test->msg, test->msg_len, test->pk, root);
+  uint32_t cycles = profile_end(t_start);
+  LOG_INFO("Verification took %u cycles.", cycles);
+
+  return err;
+}
+
+/**
+ * Run the current test.
+ *
+ * The verification is expected to succeed.
+ */
+static rom_error_t spx_verify_test(void) {
+  spx_verify_test_vector_t test = spx_verify_tests[test_index];
+
+  uint32_t root[kSpxVerifyRootNumWords];
+  uint32_t pub_root[kSpxVerifyRootNumWords];
+  RETURN_IF_ERROR(run_verify(&test, root, pub_root));
+
+  // Ensure that both roots are the same (verification passed).
+  CHECK_ARRAYS_EQ(root, pub_root, kSpxVerifyRootNumWords);
+  return kErrorOk;
+}
+
+/**
+ * Run the current test with a modified message or signature.
+ *
+ * The verification is expected to fail.
+ */
+static rom_error_t spx_verify_negative_test(void) {
+  spx_verify_test_vector_t test = spx_verify_tests[test_index];
+
+  if (test.msg_len > 0) {
+    // Bitwise-invert the first byte of the message.
+    test.msg[0] = ~test.msg[0];
+  } else {
+    // If the message is empty, change the signature.
+    test.sig[0] = ~test.sig[0];
+  }
+
+  uint32_t root[kSpxVerifyRootNumWords];
+  uint32_t pub_root[kSpxVerifyRootNumWords];
+  RETURN_IF_ERROR(run_verify(&test, root, pub_root));
+
+  // Ensure that the roots are the different (verification failed).
+  CHECK_ARRAYS_NE(root, pub_root, kSpxVerifyRootNumWords);
+  return kErrorOk;
+}
+
+bool test_main() {
+  rom_error_t result = kErrorOk;
+
+  CHECK(kNumNegativeTests <= kSpxVerifyNumTests,
+        "kNumNegativeTests (%d) cannot be larger than the total number of "
+        "tests (%d).",
+        kNumNegativeTests, kSpxVerifyNumTests);
+
+  LOG_INFO("Running %d tests with valid signatures.", kSpxVerifyNumTests);
+
+  for (size_t i = 0; i < kSpxVerifyNumTests; i++) {
+    EXECUTE_TEST(result, spx_verify_test);
+    test_index++;
+    LOG_INFO("Finished test %d of %d.", test_index, kSpxVerifyNumTests);
+  }
+
+  LOG_INFO("Running %d tests with invalid signatures.", kNumNegativeTests);
+
+  test_index = 0;
+  for (size_t i = 0; i < kNumNegativeTests; i++) {
+    EXECUTE_TEST(result, spx_verify_negative_test);
+    test_index++;
+    LOG_INFO("Finished negative test %d of %d.", test_index, kNumNegativeTests);
+  }
+
+  return result == kErrorOk;
+}