[crypto,test] Add test infrastructure for SPHINCS+.
Signed-off-by: Jade Philipoom <jadep@google.com>
diff --git a/sw/device/silicon_creator/lib/sigverify/sphincsplus/test/BUILD b/sw/device/silicon_creator/lib/sigverify/sphincsplus/test/BUILD
index 94ccae5..d0e4458 100644
--- a/sw/device/silicon_creator/lib/sigverify/sphincsplus/test/BUILD
+++ b/sw/device/silicon_creator/lib/sigverify/sphincsplus/test/BUILD
@@ -4,11 +4,45 @@
package(default_visibility = ["//visibility:public"])
+load("//rules:autogen.bzl", "autogen_cryptotest_header")
load(
"//rules:opentitan_test.bzl",
"opentitan_functest",
"verilator_params",
)
+load("@ot_python_deps//:requirements.bzl", "requirement")
+
+py_binary(
+ name = "sphincsplus_set_testvectors",
+ srcs = ["sphincsplus_set_testvectors.py"],
+ deps = [
+ requirement("hjson"),
+ requirement("mako"),
+ ],
+)
+
+autogen_cryptotest_header(
+ name = "sphincsplus_shake_128s_simple_testvectors_hardcoded_header",
+ hjson = "//sw/device/tests/crypto/testvectors:sphincsplus_shake_128s_simple_testvectors_hardcoded",
+ template = ":sphincsplus_shake_128s_simple_testvectors.h.tpl",
+ tool = ":sphincsplus_set_testvectors",
+)
+
+opentitan_functest(
+ name = "verify_test_hardcoded",
+ srcs = ["verify_test.c"],
+ verilator = verilator_params(
+ timeout = "eternal",
+ ),
+ deps = [
+ ":sphincsplus_shake_128s_simple_testvectors_hardcoded_header",
+ "//sw/device/lib/base:memory",
+ "//sw/device/lib/runtime:ibex",
+ "//sw/device/lib/testing/test_framework:ottf_main",
+ "//sw/device/silicon_creator/lib:test_main",
+ "//sw/device/silicon_creator/lib/sigverify/sphincsplus:verify",
+ ],
+)
opentitan_functest(
name = "fors_test",
diff --git a/sw/device/silicon_creator/lib/sigverify/sphincsplus/test/sphincsplus_set_testvectors.py b/sw/device/silicon_creator/lib/sigverify/sphincsplus/test/sphincsplus_set_testvectors.py
new file mode 100755
index 0000000..f40105b
--- /dev/null
+++ b/sw/device/silicon_creator/lib/sigverify/sphincsplus/test/sphincsplus_set_testvectors.py
@@ -0,0 +1,71 @@
+#!/usr/bin/env python3
+# Copyright lowRISC contributors.
+# Licensed under the Apache License, Version 2.0, see LICENSE for details.
+# SPDX-License-Identifier: Apache-2.0
+
+import argparse
+import sys
+
+import hjson
+from mako.template import Template
+
+'''
+Read in an HJSON test vector file, convert the test vector to C constants, and
+generate a header file with these test vectors.
+'''
+
+
+def hex_to_hexbytes(x):
+ '''Convert a hex string to a list of bytes as hex strings.'''
+ if x.startswith('0x'):
+ x = x[2:]
+
+ # Double-check that length is even
+ if len(x) % 2 != 0:
+ raise ValueError(f'Cannot convert odd-length hex string (length {len(x)}) to bytes: {x}')
+
+ out = []
+ for i in range(0, len(x), 2):
+ out.append('0x' + x[i:i + 2])
+ return out
+
+
+def main() -> int:
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--hjsonfile', '-j',
+ metavar='FILE',
+ required=True,
+ type=argparse.FileType('r'),
+ help='Read test vectors from this HJSON file.')
+ parser.add_argument('--template', '-t',
+ metavar='FILE',
+ required=True,
+ type=argparse.FileType('r'),
+ help='Read header template from this file.')
+ parser.add_argument('--headerfile', '-o',
+ metavar='FILE',
+ required=True,
+ type=argparse.FileType('w'),
+ help='Write output to this file.')
+
+ args = parser.parse_args()
+
+ # Read test vectors and stringify them
+ with args.hjsonfile as hjsonfile:
+ testvecs = hjson.load(hjsonfile)
+
+ # Convert the values to hexadecimal bytes.
+ for t in testvecs:
+ t['sig_hexbytes'] = hex_to_hexbytes(t['sig_hex'])
+ t['msg_hexbytes'] = hex_to_hexbytes(t['msg_hex'])
+ t['pk_hexbytes'] = hex_to_hexbytes(t['pk_hex'])
+
+ with args.template as template:
+ with args.headerfile as header:
+ header.write(Template(template.read()).render(tests=testvecs))
+
+ return 0
+
+
+if __name__ == '__main__':
+ sys.exit(main())
diff --git a/sw/device/silicon_creator/lib/sigverify/sphincsplus/test/sphincsplus_shake_128s_simple_testvectors.h.tpl b/sw/device/silicon_creator/lib/sigverify/sphincsplus/test/sphincsplus_shake_128s_simple_testvectors.h.tpl
new file mode 100644
index 0000000..45ebf1a
--- /dev/null
+++ b/sw/device/silicon_creator/lib/sigverify/sphincsplus/test/sphincsplus_shake_128s_simple_testvectors.h.tpl
@@ -0,0 +1,69 @@
+// Copyright lowRISC contributors.
+// Licensed under the Apache License, Version 2.0, see LICENSE for details.
+// SPDX-License-Identifier: Apache-2.0
+//
+// AUTOGENERATED. Do not edit this file by hand.
+// See the crypto/tests README for details.
+
+#ifndef OPENTITAN_SW_DEVICE_SILICON_CREATOR_LIB_SIGVERIFY_SPHINCSPLUS_TEST_SPHINCSPLUS_SHAKE_128S_SIMPLE_TESTVECTORS_H_
+#define OPENTITAN_SW_DEVICE_SILICON_CREATOR_LIB_SIGVERIFY_SPHINCSPLUS_TEST_SPHINCSPLUS_SHAKE_128S_SIMPLE_TESTVECTORS_H_
+
+#include "sw/device/silicon_creator/lib/sigverify/sphincsplus/params.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// A test vector for SPHINCS+ signature verification.
+typedef struct spx_verify_test_vector {
+ uint8_t sig[kSpxVerifySigBytes]; // Signature to verify.
+ uint8_t pk[kSpxVerifyPkBytes]; // Public key.
+ size_t msg_len; // Length of message.
+ uint8_t *msg; // Message.
+} spx_verify_test_vector_t;
+
+static const size_t kSpxVerifyNumTests = ${len(tests)};
+
+// Static message arrays.
+% for idx, t in enumerate(tests):
+ % if t["msg_len"] == 0:
+// msg${idx} is empty.
+ % else:
+static uint8_t msg${idx}[${t["msg_len"]}] = {
+ % for i in range(0, len(t["msg_hexbytes"]), 10):
+ ${', '.join(t["msg_hexbytes"][i:i + 10])},
+ % endfor
+};
+ %endif
+% endfor
+
+static const spx_verify_test_vector_t spx_verify_tests[${len(tests)}] = {
+% for idx, t in enumerate(tests):
+ {
+ .sig =
+ {
+ % for i in range(0, len(t["sig_hexbytes"]), 10):
+ ${', '.join(t["sig_hexbytes"][i:i + 10])},
+ % endfor
+ },
+ .pk =
+ {
+ % for i in range(0, len(t["pk_hexbytes"]), 10):
+ ${', '.join(t["pk_hexbytes"][i:i + 10])},
+ % endfor
+ },
+ .msg_len = ${t["msg_len"]},
+ % if t["msg_len"] == 0:
+ .msg = NULL,
+ % else:
+ .msg = msg${idx},
+ % endif
+ },
+% endfor
+};
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // OPENTITAN_SW_DEVICE_SILICON_CREATOR_LIB_SIGVERIFY_SPHINCSPLUS_TEST_SPHINCSPLUS_SHAKE_128S_SIMPLE_TESTVECTORS_H_
diff --git a/sw/device/silicon_creator/lib/sigverify/sphincsplus/test/verify_test.c b/sw/device/silicon_creator/lib/sigverify/sphincsplus/test/verify_test.c
new file mode 100644
index 0000000..d05a6c7
--- /dev/null
+++ b/sw/device/silicon_creator/lib/sigverify/sphincsplus/test/verify_test.c
@@ -0,0 +1,143 @@
+// Copyright lowRISC contributors.
+// Licensed under the Apache License, Version 2.0, see LICENSE for details.
+// SPDX-License-Identifier: Apache-2.0
+
+#include "sw/device/silicon_creator/lib/sigverify/sphincsplus/verify.h"
+
+#include <stdint.h>
+
+#include "sw/device/lib/base/memory.h"
+#include "sw/device/lib/base/status.h"
+#include "sw/device/lib/runtime/ibex.h"
+#include "sw/device/lib/runtime/log.h"
+#include "sw/device/lib/testing/test_framework/check.h"
+#include "sw/device/lib/testing/test_framework/ottf_main.h"
+#include "sw/device/silicon_creator/lib/test_main.h"
+
+// The autogen rule that creates this header creates it in a directory named
+// after the rule, then manipulates the include path in the
+// cc_compilation_context to include that directory, so the compiler will find
+// the version of this file matching the Bazel rule under test.
+#include "sphincsplus_shake_128s_simple_testvectors.h"
+
+// Index of the test vector currently under test
+static uint32_t test_index = 0;
+
+OTTF_DEFINE_TEST_CONFIG();
+
+enum {
+ /**
+ * Number of negative tests to run (manipulating the message and checking
+ * that the signature fails).
+ */
+ kNumNegativeTests = 1,
+};
+
+/**
+ * Start a cycle-count timing profile.
+ */
+static uint64_t profile_start() { return ibex_mcycle_read(); }
+
+/**
+ * End a cycle-count timing profile.
+ *
+ * Call `profile_start()` first.
+ */
+static uint32_t profile_end(uint64_t t_start) {
+ uint64_t t_end = ibex_mcycle_read();
+ uint64_t cycles = t_end - t_start;
+ CHECK(cycles <= UINT32_MAX);
+ return (uint32_t)cycles;
+}
+
+/**
+ * Run the SPHINCS+ verification procedure on the current test.
+ *
+ * @param test Test vector to run.
+ * @param[out] root Output buffer for root node computed from signature.
+ * @param[out] pub_root Output buffer for root node computed from public key.
+ */
+static rom_error_t run_verify(const spx_verify_test_vector_t *test,
+ uint32_t *root, uint32_t *pub_root) {
+ // Calculate the public-key root to compare against.
+ spx_public_key_root(test->pk, pub_root);
+
+ // Run verification and print the cycle count.
+ uint64_t t_start = profile_start();
+ rom_error_t err =
+ spx_verify(test->sig, test->msg, test->msg_len, test->pk, root);
+ uint32_t cycles = profile_end(t_start);
+ LOG_INFO("Verification took %u cycles.", cycles);
+
+ return err;
+}
+
+/**
+ * Run the current test.
+ *
+ * The verification is expected to succeed.
+ */
+static rom_error_t spx_verify_test(void) {
+ spx_verify_test_vector_t test = spx_verify_tests[test_index];
+
+ uint32_t root[kSpxVerifyRootNumWords];
+ uint32_t pub_root[kSpxVerifyRootNumWords];
+ RETURN_IF_ERROR(run_verify(&test, root, pub_root));
+
+ // Ensure that both roots are the same (verification passed).
+ CHECK_ARRAYS_EQ(root, pub_root, kSpxVerifyRootNumWords);
+ return kErrorOk;
+}
+
+/**
+ * Run the current test with a modified message or signature.
+ *
+ * The verification is expected to fail.
+ */
+static rom_error_t spx_verify_negative_test(void) {
+ spx_verify_test_vector_t test = spx_verify_tests[test_index];
+
+ if (test.msg_len > 0) {
+ // Bitwise-invert the first byte of the message.
+ test.msg[0] = ~test.msg[0];
+ } else {
+ // If the message is empty, change the signature.
+ test.sig[0] = ~test.sig[0];
+ }
+
+ uint32_t root[kSpxVerifyRootNumWords];
+ uint32_t pub_root[kSpxVerifyRootNumWords];
+ RETURN_IF_ERROR(run_verify(&test, root, pub_root));
+
+ // Ensure that the roots are the different (verification failed).
+ CHECK_ARRAYS_NE(root, pub_root, kSpxVerifyRootNumWords);
+ return kErrorOk;
+}
+
+bool test_main() {
+ rom_error_t result = kErrorOk;
+
+ CHECK(kNumNegativeTests <= kSpxVerifyNumTests,
+ "kNumNegativeTests (%d) cannot be larger than the total number of "
+ "tests (%d).",
+ kNumNegativeTests, kSpxVerifyNumTests);
+
+ LOG_INFO("Running %d tests with valid signatures.", kSpxVerifyNumTests);
+
+ for (size_t i = 0; i < kSpxVerifyNumTests; i++) {
+ EXECUTE_TEST(result, spx_verify_test);
+ test_index++;
+ LOG_INFO("Finished test %d of %d.", test_index, kSpxVerifyNumTests);
+ }
+
+ LOG_INFO("Running %d tests with invalid signatures.", kNumNegativeTests);
+
+ test_index = 0;
+ for (size_t i = 0; i < kNumNegativeTests; i++) {
+ EXECUTE_TEST(result, spx_verify_negative_test);
+ test_index++;
+ LOG_INFO("Finished negative test %d of %d.", test_index, kNumNegativeTests);
+ }
+
+ return result == kErrorOk;
+}