[otbn] unify montmul routines in modexp

This unfiies the three Montgomery multiplication
subroutine montmul_sqr, montmul_exp and montmul
in a single one.

This is related to #3776.

Signed-off-by: Felix Miller <felix.miller@gi-de.com>
diff --git a/sw/otbn/code-snippets/modexp.s b/sw/otbn/code-snippets/modexp.s
index 7e66f88..b33cf0b 100644
--- a/sw/otbn/code-snippets/modexp.s
+++ b/sw/otbn/code-snippets/modexp.s
@@ -561,76 +561,6 @@
 
 
 /**
- * Constant-time Montgomery Modular Multiplication
- *
- * Returns: C = montmul(A,B) = A*B*R^(-1) mod M
- *
- * This implements the limb-by-limb interleadved Montgomory Modular
- * Multiplication Algorithm. This is only a wrapper around the main loop body.
- * For algorithmic implementation details see the mont_loop subroutine.
- *
- * Flags: The states of both FG0 and FG1 depend on intermediate values and are
- *        not usable after return.
- *
- * @param[in]  x16 dptr_M: dmem pointer to first limb of modulus M
- * @param[in]  x17 dptr_m0d: dmem pointer to Montgomery Constant m0'
- * @param[in]  x19 dptr_a: dmem pointer to first limb of operand A
- * @param[in]  x20 dptr_b: dmem pointer to first limb of operand B
- * @param[in]  x21 dptr_c: dmem pointer to first limb of result C
- * @param[in]  w31 all-zero
- * @param[in]  x30 N: number of limbs
- * @param[in]  x31 N-1: number of limbs minus one
- * @param[in]  x9: pointer to temp reg, must be set to 3
- * @param[in]  x10: pointer to temp reg, must be set to 4
- * @param[in]  x11: pointer to temp reg, must be set to 2
- * @param[out] [dmem[dptr_c+N*32-1]:dmem[dptr_c]]: result C
- *
- * clobbered registers: x3, x4, x5, x6, x8 to x13, x16 to x31
- *                      w1 to w3, w24 to w30
- *                      w4 to w[4+N-1]
- * clobbered Flag Groups: FG0, FG1
- */
-montmul:
-  /* load Montgomery constant: w3 = dmem[x17] = dmem[dptr_m0d] = m0'*/
-  bn.lid    x9, 0(x17)
-
-  /* init regfile bigint buffer with zeros */
-  bn.mov    w2, w31
-  loop      x30, 1
-    bn.movr   x10++, x11
-
-  /* iterate over limbs of operand B */
-  loop      x30, 8
-
-    /* load limb of operand b */
-    bn.lid    x11, 0(x20++)
-
-    /* save some regs */
-    add       x4, x16, x0
-    add       x5, x19, x0
-    add       x6, x20, x0
-
-    /* Main loop body of Montgomory Multiplication algorithm */
-    jal       x1, mont_loop
-
-    /* restore regs */
-    add       x16, x4, x0
-    add       x19, x5, x0
-    add       x20, x6, x0
-
-  /* Store result in dmem starting at dmem[dptr_c] */
-  loop      x30, 2
-    bn.sid    x8, 0(x21++)
-    addi      x8, x8, 1
-
-  /* restore pointer */
-  li        x8, 4
-  li        x10, 4
-
-  ret
-
-
-/**
  * Constant time conditional bigint subtraction
  *
  * Returns C = A-x*B
@@ -790,82 +720,6 @@
  * Multiplication Algorithm. This is only a wrapper around the main loop body.
  * For algorithmic implementation details see the mont_loop subroutine.
  *
- * This variant loads the 3rd descriptor (dmem cell 2) and stores the result
- * in dmem. It is intended to be used as squaring primitive in a
- * square and multiply implementation.
- *
- * Flags: The states of both FG0 and FG1 depend on intermediate values and are
- *        not usable after return.
- *
- * @param[in]  x16 dptr_M: dmem pointer to first limb of modulus M
- * @param[in]  x17 dptr_m0d: dmem pointer to Montgomery Constant m0'
- * @param[in]  x19 dptr_a: dmem pointer to first limb of operand A
- * @param[in]  x20 dptr_b: dmem pointer to first limb of operand B
- * @param[in]  x21 dptr_c: dmem pointer to first limb of result C
- * @param[in]  w31 all-zero
- * @param[in]  x30 N: number of limbs
- * @param[in]  x31 N-1: number of limbs minus one
- * @param[in]  x9: pointer to temp reg, must be set to 3
- * @param[in]  x10: pointer to temp reg, must be set to 4
- * @param[in]  x11: pointer to temp reg, must be set to 2
- *
- * clobbered registers: x5, x6, x7, x8, x10, x12, x13, x16, x17, x19, x20, x21
- *                      w2, w3, w24 to w30, w4 to w[4+N-1]
- * clobbered Flag Groups: FG0, FG1
- */
-montmul_sqr:
-  /* load Montgomery constant: w3 = dmem[x17] = dmem[dptr_m0d] = m0' */
-  bn.lid    x9, 0(x17)
-
-    /* init regfile bigint buffer with zeros */
-  bn.mov    w2, w31
-  loop      x30, 1
-    bn.movr   x10++, x11
-
-  /* iterate over limbs of operand B */
-  loop      x30, 8
-
-    /* load limb of operand b */
-    bn.lid    x11, 0(x20++)
-
-    /* save some regs */
-    addi      x5, x20, 0
-    addi      x6, x16, 0
-    addi      x7, x19, 0
-
-    /* Main loop body of Montgomory Multiplication algorithm */
-    jal       x1, mont_loop
-
-    /* restore regs */
-    addi      x20, x5, 0
-    addi      x16, x6, 0
-    addi      x19, x7, 0
-
-  /* Store result in dmem starting at dmem[dptr_c] */
-  loop      x30, 2
-    bn.sid    x8, 0(x21++)
-    addi      x8, x8, 1
-
-  /* restore pointers */
-  li         x8, 4
-  li        x10, 4
-
-  ret
-
-
-/**
- * Constant-time Montgomery Modular Multiplication
- *
- * Returns: C = montmul(A,B) = A*B*R^(-1) mod M
- *
- * This implements the limb-by-limb interleadved Montgomory Modular
- * Multiplication Algorithm. This is only a wrapper around the main loop body.
- * For algorithmic implementation details see the mont_loop subroutine.
- *
- * This variant loads the 2nd descriptor (dmem cell 1) and stores the result
- * in the regfile. It is intended to be used as multiplication primitive in a
- * square and multiply implementation.
- *
  * Flags: The states of both FG0 and FG1 depend on intermediate values and are
  *        not usable after return.
  *
@@ -885,7 +739,7 @@
  *                      w2, w3, w24 to w30, w4 to w[4+N-1]
  * clobbered Flag Groups: FG0, FG1
  */
-montmul_mul:
+montmul:
   /* load Montgomery constant: w3 = dmem[x17] = dmem[dptr_m0d] = m0' */
   bn.lid    x9, 0(x17)
 
@@ -969,8 +823,8 @@
  * This implements the square and multiply algorithm, i.e. for each bit of the
  * exponent both the squared only and the squared with multiply results are
  * computed but one result is discarded.
- * Computation is carried out in the Montgomery domain, by using the primitives
- * montmul, montmul_sqr, montmul_mul and montmul_mul1.
+ * Computation is carried out in the Montgomery domain, by using the montmul
+ * primitive.
  * The squared Montgomery modulus RR and the Montgomery constant m0' have to
  * be precomputed and provided at the appropriate locations in dmem.
  *
@@ -1019,6 +873,10 @@
   lw        x20, 12(x0)
   lw        x21, 20(x0)
   jal       x1, montmul
+  /* Store result in dmem starting at dmem[dptr_c] */
+  loop      x30, 2
+    bn.sid    x8, 0(x21++)
+    addi      x8, x8, 1
 
   /* zeroize w2 and reset flags */
   bn.sub    w2, w2, w2
@@ -1043,18 +901,22 @@
   slli      x24, x30, 8
 
   /* iterate over all bits of bigint */
-  loop      x24, 17
+  loop      x24, 20
     /* square: out = montmul(out,out)  */
     lw        x19, 28(x0)
     lw        x20, 28(x0)
     lw        x21, 28(x0)
-    jal       x1, montmul_sqr
+    jal       x1, montmul
+    /* Store result in dmem starting at dmem[dptr_c] */
+    loop      x30, 2
+      bn.sid    x8, 0(x21++)
+      addi      x8, x8, 1
 
     /* multiply: out = montmul(in,out) */
     lw        x19, 20(x0)
     lw        x20, 28(x0)
     lw        x21, 28(x0)
-    jal       x1, montmul_mul
+    jal       x1, montmul
 
     /* w2 <= w2 << 1 */
     bn.add    w2, w2, w2
@@ -1139,6 +1001,10 @@
   lw        x20, 12(x0)
   lw        x21, 20(x0)
   jal       x1, montmul
+  /* Store result in dmem starting at dmem[dptr_c] */
+  loop      x30, 2
+    bn.sid    x8, 0(x21++)
+    addi      x8, x8, 1
 
   /* pointer to out buffer */
   lw        x21, 28(x0)
@@ -1166,13 +1032,17 @@
   lw        x19, 28(x0)
   lw        x20, 28(x0)
   lw        x21, 28(x0)
-  jal       x1, montmul_sqr
+  jal       x1, montmul
+  /* Store result in dmem starting at dmem[dptr_c] */
+  loop      x30, 2
+    bn.sid    x8, 0(x21++)
+    addi      x8, x8, 1
 
   /* out = montmul(in,out)       */
   lw        x19, 20(x0)
   lw        x20, 28(x0)
   lw        x20, 28(x0)
-  jal       x1, montmul_mul
+  jal       x1, montmul
 
   /* store multiplication result in output buffer */
   lw        x21, 28(x0)
@@ -1184,13 +1054,16 @@
 
   /* 65537 = 0b10000000000000001
                 ^<< 16 x sqr >>^   */
-  loopi      16, 5
+  loopi      16, 7
     /* square: out = montmul(out, out) */
     lw        x19, 28(x0)
     lw        x20, 28(x0)
     lw        x21, 28(x0)
-    jal       x1, montmul_sqr
-    nop
+    jal       x1, montmul
+    /* Store result in dmem starting at dmem[dptr_c] */
+    loop      x30, 2
+      bn.sid    x8, 0(x21++)
+      addi      x8, x8, 1
 
   /* 65537 = 0b10000000000000001
                           mult ^
@@ -1198,7 +1071,7 @@
   lw        x19, 20(x0)
   lw        x20, 28(x0)
   lw        x21, 28(x0)
-  jal       x1, montmul_mul
+  jal       x1, montmul
 
   /* store multiplication result in output buffer */
   lw        x21, 28(x0)