[otbn] change calling convention for modexp
Gets rid of the clunky calling convention with the four
descriptors.
Expects the following layout in the 1st dmem cell now:
[p_out|p_exp|p_in|p_m|p_rr|p_m0d|N|reserved]
2nd dmem cell is reserved for future use
(blinding parameters).
Signed-off-by: Felix Miller <felix.miller@gi-de.com>
diff --git a/sw/otbn/code-snippets/modexp.s b/sw/otbn/code-snippets/modexp.s
index bd08e1b..7e66f88 100644
--- a/sw/otbn/code-snippets/modexp.s
+++ b/sw/otbn/code-snippets/modexp.s
@@ -184,11 +184,6 @@
/* save pointer to modulus */
addi x22, x16, 0
- /* load dmem[0] to w0. This is just used to have a non-zero number
- available */
- li x3, 0
- bn.lid x3, 0(x0)
-
/* zeroize w3 */
bn.xor w3, w3, w3
@@ -209,6 +204,7 @@
/* compute R-M
since R = 2^(N*w), this can be computed as R-M = unsigned(0-M) */
+ bn.addi w0, w31, 1
bn.sub w3, w31, w0, FG1
addi x16, x22, 0
jal x1, cond_sub_mod
@@ -786,61 +782,6 @@
/**
- * Externally callable wrapper for Montgomery modular multiply by one
- *
- * Returns: C = montmul(1,A) = A*R^(-1) mod M
- *
- * Routine for back-conversion from Montgomery domain.
- * This implements the limb-by-limb interleadved Montgomory Modular
- * Multiplication Algorithm, with one operand fixed to 1. This is only a
- * wrapper around the main loop body. For algorithmic implementation details
- * see the mont_loop subroutine.
- *
- * Flags: The states of both FG0 and FG1 depend on intermediate values and are
- * not usable after return.
- *
- * @param[in] dmem[0] dptr_M: dmem pointer to first limb of modulus M
- * @param[in] dmem[4] dptr_m0d: dmem pointer to Montgomery Constant m0'
- * @param[in] dmem[8] dptr_RR: dmem pointer to first limb of
- * squared Montgomery Modulus RR mod M
- * @param[in] dmem[12] dptr_a: dmem pointer to first limb of operand A
- * @param[in] dmem[20] dptr_c: dmem pointer to first limb of result C
- * @param[in] dmem[24] N: Number of limbs per bignum
- * @param[in] dmem[28] N-1: Number of limbs per bignum minus 1
- * @param[in] w31: all-zero
- * @param[out] [dmem[dptr_c+N*32-1]:dmem[dptr_c]]: result C
- *
- * clobbered registers: x3, x4, x5, x6 to x13, x16 to x31
- * w1 to w3, w24 to w30
- * w4 to w[4+N-1]
- * clobbered Flag Groups: FG0, FG1
- */
-mul1:
- /* prepare pointers to temp regs */
- li x8, 4
- li x9, 3
- li x10, 4
- li x11, 2
-
- /* load pointer to modulus */
- lw x16, 0(x0)
-
- /* load pointer to m0' */
- lw x17, 4(x0)
-
- /* load number of limbs */
- lw x30, 24(x0)
- lw x31, 28(x0)
-
- /* call montmul(1,A) algorithm */
- lw x19, 12(x0)
- lw x21, 20(x0)
- jal x1, montmul_mul1
-
- ret
-
-
-/**
* Constant-time Montgomery Modular Multiplication
*
* Returns: C = montmul(A,B) = A*B*R^(-1) mod M
@@ -1036,36 +977,19 @@
* Flags: The states of both FG0 and FG1 depend on intermediate values and are
* not usable after return.
*
- * Calling convention:
- * Data is loaded and stored to and from dmem in accordance to four
- * descriptors that have to be provided in the first 4 dmem cells
- * (256 bit each).
+ * The base bignum A is expected in the input buffer, the exponent E in the
+ * exp buffer, the result C is written to the output buffer.
+ * Note, that the content of both, the input buffer and the exp buffer is
+ * modified during execution.
*
- * first descriptor used to convert to montgomery domain:
- * @param[in] dmem[0] dptr_M: dmem pointer to first limb of modulus M
- * @param[in] dmem[4] dptr_m0d: dmem pointer to Montgomery Constant m0'
- * @param[in] dmem[8] dptr_RR: dmem pointer to first limb of
- * squared Montgomery Modulus RR mod M
- * @param[in] dmem[12] dptr_a_mont: dmem pointer to first limb of base A
- * @param[in] dmem[16] dptr_b_mont: dmem pointer to first limb of RR
- * @param[in] dmem[20] dptr_c_mont: dmem pointer to first limb of base A
- * @param[in] dmem[24] N: Number of limbs per bignum
- * @param[in] dmem[28] N-1: Number of limbs per bignum minus 1
- *
- * second descriptor used for squaring:
- * @param[in] dmem[44] dptr_a_sqr: dmem pointer to first limb of result C
- * @param[in] dmem[48] dptr_b_sqr: dmem pointer to first limb of result C
- * @param[in] dmem[52] dptr_c_sqr: dmem pointer to first limb of result C
- *
- * third descriptor used for multiplication:
- * @param[in] dmem[76] dptr_a_mul: dmem pointer to first limb of base A
- * @param[in] dmem[80] dptr_b_mul: dmem pointer to first limb of result C
- * @param[in] dmem[84] dptr_c_mul: dmem pointer to first limb of result C
- *
- * fourth descriptor used for reading the exponent and back-conversion:
- * @param[in] dmem[108] dptr_a_ex1: dmem pointer to first limb of base A
- * @param[in] dmem[112] dptr_b_ex1: dmem pointer to first limb of exponent E
- * @param[in] dmem[116] dptr_c_ex1: dmem pointer to first limb of result C
+ * @param[in] dmem[2] dptr_rr: pointer to RR in dmem
+ * @param[in] dmem[4] N: Number of limbs per bignum
+ * @param[in] dmem[8] dptr_m0d: pointer to m0' in dmem
+ * @param[in] dmem[12] dptr_rr: pointer to RR in dmem
+ * @param[in] dmem[16] dptr_m: pointer to first limb of modulus in dmem
+ * @param[in] dmem[20] dptr_in: pointer to input/base buffer
+ * @param[in] dmem[20] dptr_exp: pointer to exp buffer
+ * @param[in] dmem[28] dptr_out: pointer to output/result buffer
*
* clobbered registers: x3 to x13, x16 to x31
* w0 to w3, w24 to w30
@@ -1080,18 +1004,19 @@
li x11, 2
/* load pointer to modulus */
- lw x16, 0(x0)
+ lw x16, 16(x0)
/* load pointer to m0' */
- lw x17, 4(x0)
+ lw x17, 8(x0)
/* load number of limbs */
- lw x30, 24(x0)
- lw x31, 28(x0)
+ lw x30, 4(x0)
+ addi x31, x30, -1
- /* convert to montgomery domain montmul(A,RR) */
- lw x19, 12(x0)
- lw x20, 16(x0)
+ /* convert to montgomery domain montmul(A,RR)
+ in = montmul(A,RR) montmul(A,RR) = C*R mod M */
+ lw x19, 20(x0)
+ lw x20, 12(x0)
lw x21, 20(x0)
jal x1, montmul
@@ -1099,8 +1024,8 @@
bn.sub w2, w2, w2
/* initialize the output buffer with -M */
- lw x16, 0(x0)
- lw x21, 116(x0)
+ lw x16, 16(x0)
+ lw x21, 28(x0)
loop x30, 3
/* load limb from modulus */
bn.lid x11, 0(x16++)
@@ -1112,22 +1037,23 @@
bn.sid x11, 0(x21++)
/* reload pointer to modulus */
- lw x16, 0(x0)
+ lw x16, 16(x0)
/* compute bit length of current bigint size */
slli x24, x30, 8
/* iterate over all bits of bigint */
- loop x24, 22
- /* square */
- lw x19, 44(x0)
- lw x20, 48(x0)
- lw x21, 52(x0)
+ loop x24, 17
+ /* square: out = montmul(out,out) */
+ lw x19, 28(x0)
+ lw x20, 28(x0)
+ lw x21, 28(x0)
jal x1, montmul_sqr
- /* multiply */
- lw x19, 76(x0)
- lw x20, 80(x0)
+ /* multiply: out = montmul(in,out) */
+ lw x19, 20(x0)
+ lw x20, 28(x0)
+ lw x21, 28(x0)
jal x1, montmul_mul
/* w2 <= w2 << 1 */
@@ -1135,7 +1061,7 @@
/* the loop performs a 1-bit left shift of the exponent. Last MSB moves
to FG0.C, such that it can be used for selection */
- lw x20, 112(x0)
+ lw x20, 24(x0)
loop x30, 3
bn.lid x11, 0(x20)
/* w2 <= w2 << 1 */
@@ -1143,19 +1069,15 @@
bn.sid x11, 0(x20++)
/* select squared or squared+multiplied result */
- lw x21, 116(x0)
+ lw x21, 28(x0)
jal x1, sel_sqr_or_sqrmul
nop
- /* load 4th descriptor to w0 */
- li x3, 0
- bn.lid x3, 96(x0)
-
/* convert back from montgomery domain */
- lw x19, 108(x0)
- lw x20, 112(x0)
- lw x21, 116(x0)
+ /* out = montmul(out,1) = out/R mod M */
+ lw x19, 28(x0)
+ lw x21, 28(x0)
jal x1, montmul_mul1
ret
@@ -1177,37 +1099,17 @@
* Flags: The states of both FG0 and FG1 depend on intermediate values and are
* not usable after return.
*
- * Calling convention:
- * Data is loaded and stored to and from dmem in accordance to four
- * descriptors that have to be provided in the first 4 dmem cells
- * (256 bit each).
+ * The base bignum A is expected in the input buffer, the result C is written
+ * to the output buffer. Note, that the content of the input buffer is
+ * modified during execution.
*
- * first descriptor used to convert to montgomery domain:
- * @param[in] dmem[0] dptr_M: dmem pointer to first limb of modulus M
- * @param[in] dmem[4] dptr_m0d: dmem pointer to Montgomery Constant m0'
- * @param[in] dmem[8] dptr_RR: dmem pointer to first limb of
- * squared Montgomery Modulus RR mod M
- * @param[in] dmem[12] dptr_a_mont: dmem pointer to first limb of base A
- * @param[in] dmem[16] dptr_b_mont: dmem pointer to first limb of RR
- * @param[in] dmem[20] dptr_c_mont: dmem pointer to first limb of base A
- * @param[in] dmem[24] N: Number of limbs per bignum
- * @param[in] dmem[28] N-1: Number of limbs per bignum minus 1
- *
- * second descriptor used for squaring:
- * @param[in] dmem[44] dptr_a_sqr: dmem pointer to first limb of result C
- * @param[in] dmem[48] dptr_b_sqr: dmem pointer to first limb of result C
- * @param[in] dmem[52] dptr_c_sqr: dmem pointer to first limb of result C
- *
- * third descriptor used for multiplication:
- * @param[in] dmem[76] dptr_a_mul: dmem pointer to first limb of base A
- * @param[in] dmem[80] dptr_b_mul: dmem pointer to first limb of result C
- * @param[in] dmem[84] dptr_c_mul: dmem pointer to first limb of result C
- *
- * fourth descriptor used for back-conversion:
- * @param[in] dmem[104] dptr_RR: dmem pointer to first limb of
- * squared Montgomery Modulus RR mod M
- * @param[in] dmem[108] dptr_a_ex1: dmem pointer to first limb of base A
- * @param[in] dmem[116] dptr_c_ex1: dmem pointer to first limb of result C
+ * @param[in] dmem[2] dptr_rr: pointer to RR in dmem
+ * @param[in] dmem[4] N: Number of limbs per bignum
+ * @param[in] dmem[8] dptr_m0d: pointer to m0' in dmem
+ * @param[in] dmem[12] dptr_rr: pointer to RR in dmem
+ * @param[in] dmem[16] dptr_m: pointer to first limb of modulus in dmem
+ * @param[in] dmem[20] dptr_in: pointer to input/base buffer
+ * @param[in] dmem[28] dptr_out: pointer to output/result buffer
*
* clobbered registers: x3 to x13, x16 to x31
* w0 to w3, w24 to w30
@@ -1222,24 +1124,24 @@
li x11, 2
/* load pointer to modulus */
- lw x16, 0(x0)
+ lw x16, 16(x0)
/* load pointer to m0' */
- lw x17, 4(x0)
+ lw x17, 8(x0)
/* load number of limbs */
- lw x30, 24(x0)
- lw x31, 28(x0)
+ lw x30, 4(x0)
+ addi x31, x30, -1
/* convert to montgomery domain montmul(A,RR)
- in = montmul(A,RR) = C*R mod M */
- lw x19, 12(x0)
- lw x20, 16(x0)
+ in = montmul(A,RR) montmul(A,RR) = C*R mod M */
+ lw x19, 20(x0)
+ lw x20, 12(x0)
lw x21, 20(x0)
jal x1, montmul
/* pointer to out buffer */
- lw x21, 116(x0)
+ lw x21, 28(x0)
/* zeroize w2 and reset flags */
bn.sub w2, w2, w2
@@ -1255,24 +1157,25 @@
/* store limb in dmem */
bn.sid x11, 0(x21++)
- /* reload pointer to modulus */
- lw x16, 32(x0)
+ /* reload pointer to 1st limb of modulus */
+ lw x16, 16(x0)
/* 65537 = 0b10000000000000001
^ sqr + mult
out = montmul(out,out) */
- lw x19, 44(x0)
- lw x20, 48(x0)
- lw x21, 52(x0)
+ lw x19, 28(x0)
+ lw x20, 28(x0)
+ lw x21, 28(x0)
jal x1, montmul_sqr
/* out = montmul(in,out) */
- lw x19, 76(x0)
- lw x20, 80(x0)
+ lw x19, 20(x0)
+ lw x20, 28(x0)
+ lw x20, 28(x0)
jal x1, montmul_mul
/* store multiplication result in output buffer */
- lw x21, 84(x0)
+ lw x21, 28(x0)
li x8, 4
loop x30, 2
/* store selected limb to dmem */
@@ -1283,33 +1186,31 @@
^<< 16 x sqr >>^ */
loopi 16, 5
/* square: out = montmul(out, out) */
- lw x19, 44(x0)
- lw x20, 48(x0)
- lw x21, 52(x0)
+ lw x19, 28(x0)
+ lw x20, 28(x0)
+ lw x21, 28(x0)
jal x1, montmul_sqr
nop
/* 65537 = 0b10000000000000001
mult ^
out = montmul(in,out) */
- lw x19, 76(x0)
- lw x20, 80(x0)
+ lw x19, 20(x0)
+ lw x20, 28(x0)
+ lw x21, 28(x0)
jal x1, montmul_mul
/* store multiplication result in output buffer */
- lw x21, 84(x0)
+ lw x21, 28(x0)
li x8, 4
loop x30, 2
bn.sid x8, 0(x21++)
addi x8, x8, 1
- /* pointer to out buffer */
- lw x19, 108(x0)
- /* pointer to out buffer */
- lw x21, 116(x0)
-
/* convert back from montgomery domain */
/* out = montmul(out,1) = out/R mod M */
+ lw x19, 28(x0)
+ lw x21, 28(x0)
jal x1, montmul_mul1
ret
@@ -1324,10 +1225,11 @@
*
* Needs to be executed once per constant Modulus.
*
- * @param[in] dmem[0] dptr_m: pointer to first limb of modulus in dmem
- * @param[in] dmem[1] dptr_m0d: pointer to m0' in dmem
* @param[in] dmem[2] dptr_rr: pointer to RR in dmem
- * @param[in] dmem[6] N: Number of limbs per bignum
+ * @param[in] dmem[4] N: Number of limbs per bignum
+ * @param[in] dmem[8] dptr_m0d: pointer to m0' in dmem
+ * @param[in] dmem[12] dptr_rr: pointer to RR in dmem
+ * @param[in] dmem[16] dptr_m: pointer to first limb of modulus in dmem
* @param[out] [dmem[dptr_m0d+31]:dmem[dptr_m0d]] computed m0'
* @parma[out] [dmem[dptr_RR+N*32-1]:dmem[dptr_RR]] computed RR
*/
@@ -1337,16 +1239,16 @@
bn.xor w31, w31, w31
/* load pointer to modulus (dptr_m) */
- lw x16, 0(x0)
+ lw x16, 16(x0)
/* load pointer to m0' (dptr_m0d) */
- lw x17, 4(x0)
+ lw x17, 8(x0)
/* load pointer to RR (dptr_rr) */
- lw x18, 8(x0)
+ lw x18, 12(x0)
/* load number of limbs (N) */
- lw x30, 24(x0)
+ lw x30, 4(x0)
/* load lowest limb of modulus to w28 */
li x8, 28
diff --git a/sw/otbn/code-snippets/rsa_1024_dec_test.s b/sw/otbn/code-snippets/rsa_1024_dec_test.s
index 5a7c2ac..6e125c6 100644
--- a/sw/otbn/code-snippets/rsa_1024_dec_test.s
+++ b/sw/otbn/code-snippets/rsa_1024_dec_test.s
@@ -18,7 +18,7 @@
jal x1, modload
jal x1, modexp
/* pointer to out buffer */
- lw x21, 116(x0)
+ lw x21, 28(x0)
/* copy all limbs of result to wide reg file */
li x8, 0
@@ -31,52 +31,35 @@
.data
-/* descriptor 1: a=in, b=RR, c=in
- convert to Montgomery */
-.word 0x00000080
-.word 0x00000280
-.word 0x000002c0
-.word 0x000004c0
-.word 0x000002c0
-.word 0x000004c0
-.word 0x00000004
-.word 0x00000003
+/* reserved */
+.word 0x00000000
-/* descriptor 2: a=out, b=out, c=out
- square */
-.word 0x00000080
-.word 0x00000280
-.word 0x000002c0
-.word 0x000008e0
-.word 0x000008e0
-.word 0x000008e0
+/* number of limbs (N) */
.word 0x00000004
-.word 0x00000003
-/* descriptor 3: a=in, b=out, c=out
- multiply */
-.word 0x00000080
+/* pointer to m0' (dptr_m0d) */
.word 0x00000280
+
+/* pointer to RR (dptr_rr) */
.word 0x000002c0
+
+/* load pointer to modulus (dptr_m) */
+.word 0x00000080
+
+/* pointer to base bignum buffer (dptr_in) */
.word 0x000004c0
-.word 0x000008e0
-.word 0x000008e0
-.word 0x00000004
-.word 0x00000003
-/* descriptor 4: a=in, b=exp, c=out
- shift exponent and convert back */
-.word 0x00000080
-.word 0x00000280
-.word 0x000002c0
-.word 0x000008e0
+/* pointer to exponent buffer (dptr_exp) */
.word 0x000006c0
-.word 0x000008e0
-.word 0x00000004
-.word 0x00000003
+
+/* pointer to out buffer (dptr_out) */
+.word 0x000008c0
-/* modulus */
+/* Modulus */
+/* skip to 128 */
+.skip 96
+
.word 0xc28cf49f
.word 0xb6e64c3b
.word 0xa21417f1
diff --git a/sw/otbn/code-snippets/rsa_1024_enc_test.s b/sw/otbn/code-snippets/rsa_1024_enc_test.s
index d7c8d7a..d985c06 100644
--- a/sw/otbn/code-snippets/rsa_1024_enc_test.s
+++ b/sw/otbn/code-snippets/rsa_1024_enc_test.s
@@ -19,7 +19,7 @@
jal x1, modload
jal x1, modexp_65537
/* pointer to out buffer */
- lw x21, 116(x0)
+ lw x21, 28(x0)
/* copy all limbs of result to wide reg file */
li x8, 0
@@ -32,52 +32,35 @@
.data
-/* descriptor 1: a=in, b=RR, c=in
- convert to Montgomery */
-.word 0x00000080
-.word 0x00000280
-.word 0x000002c0
-.word 0x000004c0
-.word 0x000002c0
-.word 0x000004c0
-.word 0x00000004
-.word 0x00000003
+/* reserved */
+.word 0x00000000
-/* descriptor 2: a=out, b=out, c=out
- square */
-.word 0x00000080
-.word 0x00000280
-.word 0x000002c0
-.word 0x000008e0
-.word 0x000008e0
-.word 0x000008e0
+/* number of limbs (N) */
.word 0x00000004
-.word 0x00000003
-/* descriptor 3: a=in, b=out, c=out
- multiply */
-.word 0x00000080
+/* pointer to m0' (dptr_m0d) */
.word 0x00000280
+
+/* pointer to RR (dptr_rr) */
.word 0x000002c0
+
+/* load pointer to modulus (dptr_m) */
+.word 0x00000080
+
+/* pointer to base bignum buffer (dptr_in) */
.word 0x000004c0
-.word 0x000008e0
-.word 0x000008e0
-.word 0x00000004
-.word 0x00000003
-/* descriptor 4: a=in, b=exp, c=out
- shift exponent and convert back */
-.word 0x00000080
-.word 0x00000280
-.word 0x000002c0
-.word 0x000008e0
+/* pointer to exponent buffer (dptr_exp, unused for encrypt) */
.word 0x000006c0
-.word 0x000008e0
-.word 0x00000004
-.word 0x00000003
+
+/* pointer to out buffer (dptr_out) */
+.word 0x000008c0
/* Modulus */
+/* skip to 128 */
+.skip 96
+
.word 0xc28cf49f
.word 0xb6e64c3b
.word 0xa21417f1