[otbn] unify montmul routines in modexp
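This unifies the three Montgomery multiplication
subroutines montmul_sqr, montmul_mul and montmul
into a single one. The unified montmul leaves its result
in the register file; callers that need the result in dmem
now store it themselves after the call.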
This is related to #3776.
Signed-off-by: Felix Miller <felix.miller@gi-de.com>
diff --git a/sw/otbn/code-snippets/modexp.s b/sw/otbn/code-snippets/modexp.s
index 7e66f88..b33cf0b 100644
--- a/sw/otbn/code-snippets/modexp.s
+++ b/sw/otbn/code-snippets/modexp.s
@@ -561,76 +561,6 @@
/**
- * Constant-time Montgomery Modular Multiplication
- *
- * Returns: C = montmul(A,B) = A*B*R^(-1) mod M
- *
- * This implements the limb-by-limb interleadved Montgomory Modular
- * Multiplication Algorithm. This is only a wrapper around the main loop body.
- * For algorithmic implementation details see the mont_loop subroutine.
- *
- * Flags: The states of both FG0 and FG1 depend on intermediate values and are
- * not usable after return.
- *
- * @param[in] x16 dptr_M: dmem pointer to first limb of modulus M
- * @param[in] x17 dptr_m0d: dmem pointer to Montgomery Constant m0'
- * @param[in] x19 dptr_a: dmem pointer to first limb of operand A
- * @param[in] x20 dptr_b: dmem pointer to first limb of operand B
- * @param[in] x21 dptr_c: dmem pointer to first limb of result C
- * @param[in] w31 all-zero
- * @param[in] x30 N: number of limbs
- * @param[in] x31 N-1: number of limbs minus one
- * @param[in] x9: pointer to temp reg, must be set to 3
- * @param[in] x10: pointer to temp reg, must be set to 4
- * @param[in] x11: pointer to temp reg, must be set to 2
- * @param[out] [dmem[dptr_c+N*32-1]:dmem[dptr_c]]: result C
- *
- * clobbered registers: x3, x4, x5, x6, x8 to x13, x16 to x31
- * w1 to w3, w24 to w30
- * w4 to w[4+N-1]
- * clobbered Flag Groups: FG0, FG1
- */
-montmul:
- /* load Montgomery constant: w3 = dmem[x17] = dmem[dptr_m0d] = m0'*/
- bn.lid x9, 0(x17)
-
- /* init regfile bigint buffer with zeros */
- bn.mov w2, w31
- loop x30, 1
- bn.movr x10++, x11
-
- /* iterate over limbs of operand B */
- loop x30, 8
-
- /* load limb of operand b */
- bn.lid x11, 0(x20++)
-
- /* save some regs */
- add x4, x16, x0
- add x5, x19, x0
- add x6, x20, x0
-
- /* Main loop body of Montgomory Multiplication algorithm */
- jal x1, mont_loop
-
- /* restore regs */
- add x16, x4, x0
- add x19, x5, x0
- add x20, x6, x0
-
- /* Store result in dmem starting at dmem[dptr_c] */
- loop x30, 2
- bn.sid x8, 0(x21++)
- addi x8, x8, 1
-
- /* restore pointer */
- li x8, 4
- li x10, 4
-
- ret
-
-
-/**
* Constant time conditional bigint subtraction
*
* Returns C = A-x*B
@@ -790,82 +720,6 @@
* Multiplication Algorithm. This is only a wrapper around the main loop body.
* For algorithmic implementation details see the mont_loop subroutine.
*
- * This variant loads the 3rd descriptor (dmem cell 2) and stores the result
- * in dmem. It is intended to be used as squaring primitive in a
- * square and multiply implementation.
- *
- * Flags: The states of both FG0 and FG1 depend on intermediate values and are
- * not usable after return.
- *
- * @param[in] x16 dptr_M: dmem pointer to first limb of modulus M
- * @param[in] x17 dptr_m0d: dmem pointer to Montgomery Constant m0'
- * @param[in] x19 dptr_a: dmem pointer to first limb of operand A
- * @param[in] x20 dptr_b: dmem pointer to first limb of operand B
- * @param[in] x21 dptr_c: dmem pointer to first limb of result C
- * @param[in] w31 all-zero
- * @param[in] x30 N: number of limbs
- * @param[in] x31 N-1: number of limbs minus one
- * @param[in] x9: pointer to temp reg, must be set to 3
- * @param[in] x10: pointer to temp reg, must be set to 4
- * @param[in] x11: pointer to temp reg, must be set to 2
- *
- * clobbered registers: x5, x6, x7, x8, x10, x12, x13, x16, x17, x19, x20, x21
- * w2, w3, w24 to w30, w4 to w[4+N-1]
- * clobbered Flag Groups: FG0, FG1
- */
-montmul_sqr:
- /* load Montgomery constant: w3 = dmem[x17] = dmem[dptr_m0d] = m0' */
- bn.lid x9, 0(x17)
-
- /* init regfile bigint buffer with zeros */
- bn.mov w2, w31
- loop x30, 1
- bn.movr x10++, x11
-
- /* iterate over limbs of operand B */
- loop x30, 8
-
- /* load limb of operand b */
- bn.lid x11, 0(x20++)
-
- /* save some regs */
- addi x5, x20, 0
- addi x6, x16, 0
- addi x7, x19, 0
-
- /* Main loop body of Montgomory Multiplication algorithm */
- jal x1, mont_loop
-
- /* restore regs */
- addi x20, x5, 0
- addi x16, x6, 0
- addi x19, x7, 0
-
- /* Store result in dmem starting at dmem[dptr_c] */
- loop x30, 2
- bn.sid x8, 0(x21++)
- addi x8, x8, 1
-
- /* restore pointers */
- li x8, 4
- li x10, 4
-
- ret
-
-
-/**
- * Constant-time Montgomery Modular Multiplication
- *
- * Returns: C = montmul(A,B) = A*B*R^(-1) mod M
- *
- * This implements the limb-by-limb interleadved Montgomory Modular
- * Multiplication Algorithm. This is only a wrapper around the main loop body.
- * For algorithmic implementation details see the mont_loop subroutine.
- *
- * This variant loads the 2nd descriptor (dmem cell 1) and stores the result
- * in the regfile. It is intended to be used as multiplication primitive in a
- * square and multiply implementation.
- *
* Flags: The states of both FG0 and FG1 depend on intermediate values and are
* not usable after return.
*
@@ -885,7 +739,7 @@
* w2, w3, w24 to w30, w4 to w[4+N-1]
* clobbered Flag Groups: FG0, FG1
*/
-montmul_mul:
+montmul:
/* load Montgomery constant: w3 = dmem[x17] = dmem[dptr_m0d] = m0' */
bn.lid x9, 0(x17)
@@ -969,8 +823,8 @@
* This implements the square and multiply algorithm, i.e. for each bit of the
* exponent both the squared only and the squared with multiply results are
* computed but one result is discarded.
- * Computation is carried out in the Montgomery domain, by using the primitives
- * montmul, montmul_sqr, montmul_mul and montmul_mul1.
+ * Computation is carried out in the Montgomery domain, by using the montmul
+ * primitive.
* The squared Montgomery modulus RR and the Montgomery constant m0' have to
* be precomputed and provided at the appropriate locations in dmem.
*
@@ -1019,6 +873,10 @@
lw x20, 12(x0)
lw x21, 20(x0)
jal x1, montmul
+ /* Store result in dmem starting at dmem[dptr_c] */
+ loop x30, 2
+ bn.sid x8, 0(x21++)
+ addi x8, x8, 1
/* zeroize w2 and reset flags */
bn.sub w2, w2, w2
@@ -1043,18 +901,22 @@
slli x24, x30, 8
/* iterate over all bits of bigint */
- loop x24, 17
+ loop x24, 20
/* square: out = montmul(out,out) */
lw x19, 28(x0)
lw x20, 28(x0)
lw x21, 28(x0)
- jal x1, montmul_sqr
+ jal x1, montmul
+ /* Store result in dmem starting at dmem[dptr_c] */
+ loop x30, 2
+ bn.sid x8, 0(x21++)
+ addi x8, x8, 1
/* multiply: out = montmul(in,out) */
lw x19, 20(x0)
lw x20, 28(x0)
lw x21, 28(x0)
- jal x1, montmul_mul
+ jal x1, montmul
/* w2 <= w2 << 1 */
bn.add w2, w2, w2
@@ -1139,6 +1001,10 @@
lw x20, 12(x0)
lw x21, 20(x0)
jal x1, montmul
+ /* Store result in dmem starting at dmem[dptr_c] */
+ loop x30, 2
+ bn.sid x8, 0(x21++)
+ addi x8, x8, 1
/* pointer to out buffer */
lw x21, 28(x0)
@@ -1166,13 +1032,17 @@
lw x19, 28(x0)
lw x20, 28(x0)
lw x21, 28(x0)
- jal x1, montmul_sqr
+ jal x1, montmul
+ /* Store result in dmem starting at dmem[dptr_c] */
+ loop x30, 2
+ bn.sid x8, 0(x21++)
+ addi x8, x8, 1
/* out = montmul(in,out) */
lw x19, 20(x0)
lw x20, 28(x0)
lw x20, 28(x0)
- jal x1, montmul_mul
+ jal x1, montmul
/* store multiplication result in output buffer */
lw x21, 28(x0)
@@ -1184,13 +1054,16 @@
/* 65537 = 0b10000000000000001
^<< 16 x sqr >>^ */
- loopi 16, 5
+ loopi 16, 7
/* square: out = montmul(out, out) */
lw x19, 28(x0)
lw x20, 28(x0)
lw x21, 28(x0)
- jal x1, montmul_sqr
- nop
+ jal x1, montmul
+ /* Store result in dmem starting at dmem[dptr_c] */
+ loop x30, 2
+ bn.sid x8, 0(x21++)
+ addi x8, x8, 1
/* 65537 = 0b10000000000000001
mult ^
@@ -1198,7 +1071,7 @@
lw x19, 20(x0)
lw x20, 28(x0)
lw x21, 28(x0)
- jal x1, montmul_mul
+ jal x1, montmul
/* store multiplication result in output buffer */
lw x21, 28(x0)