[crypto] Add an OTBN implementation of constant-time GCD.
This is a component of RSA key generation.
Signed-off-by: Jade Philipoom <jadep@google.com>
diff --git a/sw/otbn/crypto/BUILD b/sw/otbn/crypto/BUILD
index 3794bb6..42f4634 100644
--- a/sw/otbn/crypto/BUILD
+++ b/sw/otbn/crypto/BUILD
@@ -35,6 +35,13 @@
)
otbn_library(
+ name = "gcd",
+ srcs = [
+ "gcd.s",
+ ],
+)
+
+otbn_library(
name = "modexp",
srcs = [
"modexp.s",
diff --git a/sw/otbn/crypto/gcd.s b/sw/otbn/crypto/gcd.s
new file mode 100644
index 0000000..b9c43f4
--- /dev/null
+++ b/sw/otbn/crypto/gcd.s
@@ -0,0 +1,310 @@
+/* Copyright lowRISC contributors. */
+/* Licensed under the Apache License, Version 2.0, see LICENSE for details. */
+/* SPDX-License-Identifier: Apache-2.0 */
+
+/**
+ * Constant-time conditional subtraction.
+ *
+ * Returns a' = if FG1.L then (a - b) else a.
+ *
+ * This is a specialized helper function for the GCD computation. Modifies the
+ * input a in-place.
+ *
+ * Flags have no meaning beyond the scope of this subroutine.
+ *
+ * @param[in] x9: n, number of 256-bit limbs for inputs a and b
+ * @param[in] x3: dptr_a, pointer to input a in DMEM
+ * @param[in] x4: dptr_b, pointer to input b in DMEM
+ * @param[in] x23: 23, constant
+ * @param[in] x24: 24, constant
+ * @param[in] w31: all-zero
+ * @param[in] FG1.L: selection flag
+ * @param[out] dmem[dptr_a:dptr_a+n*32]: a', result
+ *
+ * clobbered registers: x3, x4, x5, w23, w24
+ * clobbered flag groups: FG0
+ */
+gcd_cond_sub:
+ /* Clear flags. */
+ bn.add w31, w31, w31
+
+ /* Loop through each limb. */
+ loop x9, 5
+ /* w23 <= dmem[x3] = a[i] */
+ bn.lid x23, 0(x3)
+ /* w24 <= dmem[x4] = b[i] */
+ bn.lid x24, 0(x4++)
+ /* w24 <= w23 - w24 = (a - b)[i] */
+ bn.subb w24, w23, w24
+ /* w24 <= FG1.L ? w24 : w23 = c[i] */
+ bn.sel w24, w24, w23, FG1.L
+ /* dmem[x3] <= w24 = c[i] */
+ bn.sid x24, 0(x3++)
+
+ ret
+
+/**
+ * Shifts input 1 bit to the right if it is even.
+ *
+ * Returns a' = if a[0] then a else a >> 1.
+ *
+ * This is a specialized helper function for the GCD computation. The input is
+ * modified in-place. This routine runs in constant time.
+ *
+ * Flags have no meaning beyond the scope of this subroutine.
+ *
+ * @param[in] x3: dptr_a, pointer to input a in DMEM
+ * @param[in] x9: n, number of 256-bit limbs for input a
+ * @param[in] x23: 23, constant
+ * @param[in] x24: 24, constant
+ * @param[in] w31: all-zero
+ * @param[out] dmem[dptr_a:dptr_a+n*32]: a', result
+ *
+ * clobbered registers: x3, x4, x22, x23, x24, w23, w24
+ * clobbered flag groups: FG0
+ */
+gcd_cond_rshift1:
+ /* x22 <= x9 - 1 = n - 1 */
+ addi x22, x0, 1
+ sub x22, x9, x22
+
+ /* Get a pointer to the second limb of the input.
+ x4 <= x3 + 32 = dptr_a + 32 */
+ addi x4, x3, 32
+
+ /* w23 <= dmem[x3] = a[255:0] */
+ bn.lid x23, 0(x3)
+
+ /* FG0.L <= a[0] */
+ bn.add w23, w23, w31
+
+ /* If the number of limbs is 1, skip the loop. This check is required because
+ a loop with n-1=0 iterations will cause a LOOP error. */
+ beq x22, x0, _gcd_cond_rshift1_loop_end
+
+ /* Loop through all limbs except the last. */
+ loop x22, 5
+ /* w23 <= dmem[x3] = a[i] */
+ bn.lid x23, 0(x3)
+ /* w24 <= dmem[x4] = a[i+1] */
+ bn.lid x24, 0(x4++)
+ /* w24 <= (a >> 1)[i] */
+ bn.rshi w24, w24, w23 >> 1
+ /* w24 <= FG0.L ? w23 : w24 = a'[i] */
+ bn.sel w24, w23, w24, FG0.L
+ /* dmem[x3] <= w24 = a'[i] */
+ bn.sid x24, 0(x3++)
+
+ _gcd_cond_rshift1_loop_end:
+
+ /* Last limb is special because there's no next limb; we use 0. */
+ bn.lid x23, 0(x3)
+ bn.rshi w24, w31, w23 >> 1
+ bn.sel w24, w23, w24, FG0.L
+ bn.sid x24, 0(x3++)
+
+ ret
+
+/**
+ * Shifts input 1 bit to the left if the counter is nonzero.
+ *
+ * Returns a' = if ctr > 0 then a << 1 else a.
+ *
+ * This is a specialized helper function for the GCD computation. The input a
+ * is modified in-place. This routine runs in constant time.
+ *
+ * Flags have no meaning beyond the scope of this subroutine.
+ *
+ * @param[in] x3: dptr_a, pointer to input a in DMEM
+ * @param[in] x9: n, number of 256-bit limbs for input a
+ * @param[in] x23: 23, constant
+ * @param[in] x24: 24, constant
+ * @param[in] x25: 25, constant
+ * @param[in] w20: ctr
+ * @param[in] w31: all-zero
+ * @param[out] dmem[dptr_a:dptr_a+n*32]: a', result
+ *
+ * clobbered registers: x3, x22, w20, w23, w24, w25
+ * clobbered flag groups: FG0
+ */
+gcd_cond_lshift1:
+ /* Check if counter is zero.
+ FG0.Z <= ctr != 0 */
+ bn.addi w20, w20, 0
+
+ /* w23 <= 0 */
+ bn.mov w23, w31
+
+ /* Loop through remaining limbs.
+
+ Loop invariants (i=0 to i=n-1):
+ w23 = i == 0 ? 0 : a[i-1]
+ x3 = dptr_a + (i * 32)
+ */
+ loop x9, 5
+ /* w24 <= dmem[x3] = a[i] */
+ bn.lid x24, 0(x3)
+ /* w25 <= (a << 1)[i] */
+ bn.rshi w25, w24, w23 >> 255
+ /* w25 <= FG0.Z ? w24 : w25 = a'[i] */
+ bn.sel w25, w24, w25, FG0.Z
+ /* dmem[x3] <= w25 = a'[i] */
+ bn.sid x25, 0(x3++)
+ /* w23 <= w24 = a[i] */
+ bn.mov w23, w24
+
+ ret
+
+/**
+ * Compute the greatest common denominator of two large numbers.
+ *
+ * Returns g = gcd(x, y).
+ *
+ * This implementation is based on an implementation from BoringSSL, which is
+ * in turn a constant-time version of HAC Algorithm 14.54 (binary GCD). The
+ * full implementation is here:
+ * https://boringssl.googlesource.com/boringssl/+/1b2b7b2e70ce5ff50df917ee7745403d824155c5/crypto/fipsmodule/bn/gcd_extra.c#49
+ *
+ * In pseudocode, this algorithm is:
+ * shift = 0
+ * num_iters = nbits(x) + nbits(y)
+ * for i=0..num_iters:
+ * if x[0] & y[0]:
+ * x = (x >= y) ? x-y : x
+ * y = (x >= y) ? y : y-x
+ * if !(x[0] | y[0]):
+ * shift += 1
+ * x = x[0] ? x : x >> 1
+ * y = y[0] ? y : y >> 1
+ * y |= x
+ * return y << shift
+ *
+ * The final y |= x is only needed if y is zero; the algorithm guarantees that
+ * x is zero at the end of the computation otherwise. This implementation skips
+ * the final or, so y must not be zero.
+ *
+ * This routine overwrites its inputs. At the end, the buffer for x will be 0
+ * and the buffer for y will hold the result.
+ *
+ * The number of limbs is in principle only limited by DMEM space.
+ *
+ * Flags have no meaning beyond the scope of this subroutine.
+ *
+ * @param[in] x9: n, number of 256-bit limbs for x, y, and g
+ * @param[in] x10: dptr_x, pointer to input x in DMEM
+ * @param[in] x11: dptr_y, pointer to input y in DMEM
+ * @param[in] w31: all-zero
+ * @param[out] dmem[dptr_y:dptr_y+n*32]: g, result
+ *
+ * clobbered registers: x3, x4, x21 to x25, w20 to w25
+ * clobbered flag groups: FG0, FG1
+ */
+.globl gcd
+gcd:
+ /* Initialize the shift to 0.
+ w20 <= 0 = shift */
+ bn.mov w20, w31
+
+ /* Compute the number of iterations. The inputs x and y have n * 256 bits
+ each, so the sum of their number of bits is n * 2 * 256 = n << 9.
+ x21 <= x9 << 9 = num_iters */
+ slli x21, x9, 9
+
+ /* Set up constants for loop.
+ x23 <= 23
+ x24 <= 24
+ x25 <= 25 */
+ li x23, 23
+ li x24, 24
+ li x25, 25
+
+ /* Main loop. */
+ loop x21, 30
+ /* Load the least significant limbs of x and y.
+ w23 <= dmem[x10] = x[255:0]
+ w24 <= dmem[x11] = y[255:0] */
+ bn.lid x23, 0(x10)
+ bn.lid x24, 0(x11)
+
+ /* w25 <= x[255:0] & y[255:0] */
+ bn.and w25, w23, w24
+
+ /* Clear flags. */
+ bn.add w31, w31, w31
+
+ /* Compare x and y.
+ FG0.C <= y < x = !(x >= y) */
+ addi x3, x10, 0
+ addi x4, x11, 0
+ loop x9, 3
+ bn.lid x23, 0(x3++)
+ bn.lid x24, 0(x4++)
+ bn.cmpb w23, w24
+
+ /* Capture final borrow in a wide register.
+ w22 <= FG0.C = !(x >= y) */
+ bn.addc w22, w31, w31
+
+ /* Update FG1.L so that it is 1 if x should be updated.
+ FG1.L <= (w25 & ~w22)[0] = (x[0] & y[0]) && (x >= y) */
+ bn.not w23, w22
+ bn.and w23, w23, w25, FG1
+
+ /* dmem[dptr_x] = cond_sub(dptr_x, dptr_y) = if FG1.L then x - y else x */
+ addi x3, x10, 0
+ addi x4, x11, 0
+ jal x1, gcd_cond_sub
+
+ /* Update FG1.L so that it is 1 if y should be updated.
+ FG1.L <= (w25[0] & w22)[0] = (x[0] & y[0]) && !(x >= y) */
+ bn.and w23, w22, w25, FG1
+
+ /* dmem[dptr_y] = cond_sub(dptr_y, dptr_x) = if FG1.L then y - x else y */
+ addi x3, x11, 0
+ addi x4, x10, 0
+ jal x1, gcd_cond_sub
+
+ /* Reload the least significant limbs of x and y.
+ w23 <= dmem[x10] = x[255:0]
+ w24 <= dmem[x11] = y[255:0] */
+ bn.lid x23, 0(x10)
+ bn.lid x24, 0(x11)
+
+ /* FG1.L = (x[0] | y[0]) */
+ bn.or w25, w23, w24, FG1
+
+ /* Update shift.
+ w20 <= FG1.L ? w20 : w20 + 1
+ = if (x[0] | y[0]) then shift else shift + 1 */
+ bn.addi w25, w20, 1
+ bn.sel w20, w20, w25, FG1.L
+
+ /* Shift x to the right by 1 if it is even.
+ dmem[dptr_x] <= if x[0] then x else x >> 1 */
+ addi x3, x10, 0
+ jal x1, gcd_cond_rshift1
+
+ /* Shift y to the right by 1 if it is even.
+ dmem[dptr_y] <= if y[0] then y else y >> 1 */
+ addi x3, x11, 0
+ jal x1, gcd_cond_rshift1
+ nop
+
+ /* End of loop. At this point we are guaranteed x = 0. */
+
+ /* Compute the maximum value of the shift. This is the number of bits in each
+ input, since the gcd cannot be greater than either of its operands.
+ x21 <= x9 << 8 = n * 256 */
+ slli x21, x9, 8
+
+ /* Shift y left to obtain the final result.
+ dmem[dptr_y] <= y << w20 = y << shift */
+ loop x21, 4
+ /* dmem[dptr_y] <= if w20 != 0 then dmem[dptr_y] << 1 else dmem[dptr_y] */
+ addi x3, x11, 0
+ jal x1, gcd_cond_lshift1
+ /* w20 <= max(0, w20 - 1) */
+ bn.subi w23, w20, 1
+ bn.sel w20, w20, w23, FG0.C
+
+ ret