[otbn] change calling convention for modexp

Gets rid of the clunky calling convention with the four
descriptors.

Expects the following layout in the 1st dmem cell now (32-bit words,
listed from the highest byte offset, 28, down to offset 0):
[p_out|p_exp|p_in|p_m|p_rr|p_m0d|N|reserved]

2nd dmem cell is reserved for future use
(blinding parameters).

Signed-off-by: Felix Miller <felix.miller@gi-de.com>
diff --git a/sw/otbn/code-snippets/modexp.s b/sw/otbn/code-snippets/modexp.s
index bd08e1b..7e66f88 100644
--- a/sw/otbn/code-snippets/modexp.s
+++ b/sw/otbn/code-snippets/modexp.s
@@ -184,11 +184,6 @@
   /* save pointer to modulus */
   addi      x22, x16, 0
 
-  /* load dmem[0] to w0. This is just used to have a non-zero number
-     available */
-  li        x3, 0
-  bn.lid    x3, 0(x0)
-
   /* zeroize w3 */
   bn.xor    w3, w3, w3
 
@@ -209,6 +204,7 @@
 
   /* compute R-M
      since R = 2^(N*w), this can be computed as R-M = unsigned(0-M) */
+  bn.addi w0, w31, 1
   bn.sub    w3, w31, w0, FG1
   addi      x16, x22, 0
   jal       x1, cond_sub_mod
@@ -786,61 +782,6 @@
 
 
 /**
- * Externally callable wrapper for Montgomery modular multiply by one
- *
- * Returns: C = montmul(1,A) = A*R^(-1) mod M
- *
- * Routine for back-conversion from Montgomery domain.
- * This implements the limb-by-limb interleadved Montgomory Modular
- * Multiplication Algorithm, with one operand fixed to 1. This is only a
- * wrapper around the main loop body. For algorithmic implementation details
- * see the mont_loop subroutine.
- *
- * Flags: The states of both FG0 and FG1 depend on intermediate values and are
- *        not usable after return.
- *
- * @param[in]  dmem[0] dptr_M: dmem pointer to first limb of modulus M
- * @param[in]  dmem[4] dptr_m0d: dmem pointer to Montgomery Constant m0'
- * @param[in]  dmem[8] dptr_RR: dmem pointer to first limb of
- *                              squared Montgomery Modulus RR mod M
- * @param[in]  dmem[12] dptr_a: dmem pointer to first limb of operand A
- * @param[in]  dmem[20] dptr_c: dmem pointer to first limb of result C
- * @param[in]  dmem[24] N: Number of limbs per bignum
- * @param[in]  dmem[28] N-1: Number of limbs per bignum minus 1
- * @param[in]  w31: all-zero
- * @param[out] [dmem[dptr_c+N*32-1]:dmem[dptr_c]]: result C
- *
- * clobbered registers: x3, x4, x5, x6 to x13, x16 to x31
- *                      w1 to w3, w24 to w30
- *                      w4 to w[4+N-1]
- * clobbered Flag Groups: FG0, FG1
- */
-mul1:
-  /* prepare pointers to temp regs */
-  li         x8, 4
-  li         x9, 3
-  li        x10, 4
-  li        x11, 2
-
-  /* load pointer to modulus */
-  lw        x16, 0(x0)
-
-  /* load pointer to m0' */
-  lw        x17, 4(x0)
-
-  /* load number of limbs */
-  lw        x30, 24(x0)
-  lw        x31, 28(x0)
-
-  /* call montmul(1,A) algorithm */
-  lw        x19, 12(x0)
-  lw        x21, 20(x0)
-  jal       x1, montmul_mul1
-
-  ret
-
-
-/**
  * Constant-time Montgomery Modular Multiplication
  *
  * Returns: C = montmul(A,B) = A*B*R^(-1) mod M
@@ -1036,36 +977,19 @@
  * Flags: The states of both FG0 and FG1 depend on intermediate values and are
  *        not usable after return.
  *
- * Calling convention:
- * Data is loaded and stored to and from dmem in accordance to four
- *   descriptors that have to be provided in the first 4 dmem cells
- *   (256 bit each).
+ * The base bignum A is expected in the input buffer, the exponent E in the
+ * exp buffer, the result C is written to the output buffer.
+ * Note, that the content of both, the input buffer and the exp buffer is
+ * modified during execution.
  *
- * first descriptor used to convert to montgomery domain:
- *   @param[in]  dmem[0] dptr_M: dmem pointer to first limb of modulus M
- *   @param[in]  dmem[4] dptr_m0d: dmem pointer to Montgomery Constant m0'
- *   @param[in]  dmem[8] dptr_RR: dmem pointer to first limb of
- *                              squared Montgomery Modulus RR mod M
- *   @param[in]  dmem[12] dptr_a_mont: dmem pointer to first limb of base A
- *   @param[in]  dmem[16] dptr_b_mont: dmem pointer to first limb of RR
- *   @param[in]  dmem[20] dptr_c_mont: dmem pointer to first limb of base A
- *   @param[in]  dmem[24] N: Number of limbs per bignum
- *   @param[in]  dmem[28] N-1: Number of limbs per bignum minus 1
- *
- * second descriptor used for squaring:
- *   @param[in]  dmem[44] dptr_a_sqr: dmem pointer to first limb of result C
- *   @param[in]  dmem[48] dptr_b_sqr: dmem pointer to first limb of result C
- *   @param[in]  dmem[52] dptr_c_sqr: dmem pointer to first limb of result C
- *
- * third descriptor used for multiplication:
- *   @param[in]  dmem[76] dptr_a_mul: dmem pointer to first limb of base A
- *   @param[in]  dmem[80] dptr_b_mul: dmem pointer to first limb of result C
- *   @param[in]  dmem[84] dptr_c_mul: dmem pointer to first limb of result C
- *
- * fourth descriptor used for reading the exponent and back-conversion:
- *   @param[in]  dmem[108] dptr_a_ex1: dmem pointer to first limb of base A
- *   @param[in]  dmem[112] dptr_b_ex1: dmem pointer to first limb of exponent E
- *   @param[in]  dmem[116] dptr_c_ex1: dmem pointer to first limb of result C
+ * @param[in]  dmem[0] reserved
+ * @param[in]  dmem[4] N: Number of limbs per bignum
+ * @param[in]  dmem[8] dptr_m0d: pointer to m0' in dmem
+ * @param[in]  dmem[12] dptr_rr: pointer to RR in dmem
+ * @param[in]  dmem[16] dptr_m: pointer to first limb of modulus in dmem
+ * @param[in]  dmem[20] dptr_in: pointer to input/base buffer
+ * @param[in]  dmem[24] dptr_exp: pointer to exp buffer
+ * @param[in]  dmem[28] dptr_out: pointer to output/result buffer
  *
  * clobbered registers: x3 to x13, x16 to x31
  *                      w0 to w3, w24 to w30
@@ -1080,18 +1004,19 @@
   li        x11, 2
 
   /* load pointer to modulus */
-  lw        x16, 0(x0)
+  lw        x16, 16(x0)
 
   /* load pointer to m0' */
-  lw        x17, 4(x0)
+  lw        x17, 8(x0)
 
   /* load number of limbs */
-  lw        x30, 24(x0)
-  lw        x31, 28(x0)
+  lw        x30, 4(x0)
+  addi      x31, x30, -1
 
-  /* convert to montgomery domain montmul(A,RR) */
-  lw        x19, 12(x0)
-  lw        x20, 16(x0)
+  /* convert to montgomery domain montmul(A,RR)
+  in = montmul(A,RR) = A*R mod M */
+  lw        x19, 20(x0)
+  lw        x20, 12(x0)
   lw        x21, 20(x0)
   jal       x1, montmul
 
@@ -1099,8 +1024,8 @@
   bn.sub    w2, w2, w2
 
   /* initialize the output buffer with -M */
-  lw        x16, 0(x0)
-  lw        x21, 116(x0)
+  lw        x16, 16(x0)
+  lw        x21, 28(x0)
   loop      x30, 3
     /* load limb from modulus */
     bn.lid    x11, 0(x16++)
@@ -1112,22 +1037,23 @@
     bn.sid    x11, 0(x21++)
 
   /* reload pointer to modulus */
-  lw        x16, 0(x0)
+  lw        x16, 16(x0)
 
   /* compute bit length of current bigint size */
   slli      x24, x30, 8
 
   /* iterate over all bits of bigint */
-  loop      x24, 22
-    /* square */
-    lw        x19, 44(x0)
-    lw        x20, 48(x0)
-    lw        x21, 52(x0)
+  loop      x24, 17
+    /* square: out = montmul(out,out)  */
+    lw        x19, 28(x0)
+    lw        x20, 28(x0)
+    lw        x21, 28(x0)
     jal       x1, montmul_sqr
 
-    /* multiply */
-    lw        x19, 76(x0)
-    lw        x20, 80(x0)
+    /* multiply: out = montmul(in,out) */
+    lw        x19, 20(x0)
+    lw        x20, 28(x0)
+    lw        x21, 28(x0)
     jal       x1, montmul_mul
 
     /* w2 <= w2 << 1 */
@@ -1135,7 +1061,7 @@
 
     /* the loop performs a 1-bit left shift of the exponent. Last MSB moves
        to FG0.C, such that it can be used for selection */
-    lw        x20, 112(x0)
+    lw        x20, 24(x0)
     loop      x30, 3
       bn.lid    x11, 0(x20)
       /* w2 <= w2 << 1 */
@@ -1143,19 +1069,15 @@
       bn.sid    x11, 0(x20++)
 
     /* select squared or squared+multiplied result */
-    lw        x21, 116(x0)
+    lw        x21, 28(x0)
     jal       x1, sel_sqr_or_sqrmul
 
     nop
 
-  /* load 4th descriptor to w0 */
-  li        x3, 0
-  bn.lid    x3, 96(x0)
-
   /* convert back from montgomery domain */
-  lw        x19, 108(x0)
-  lw        x20, 112(x0)
-  lw        x21, 116(x0)
+  /* out = montmul(out,1) = out/R mod M  */
+  lw        x19, 28(x0)
+  lw        x21, 28(x0)
   jal       x1, montmul_mul1
 
   ret
@@ -1177,37 +1099,17 @@
  * Flags: The states of both FG0 and FG1 depend on intermediate values and are
  *        not usable after return.
  *
- * Calling convention:
- * Data is loaded and stored to and from dmem in accordance to four
- *   descriptors that have to be provided in the first 4 dmem cells
- *   (256 bit each).
+ * The base bignum A is expected in the input buffer, the result C is written
+ * to the output buffer. Note, that the content of the input buffer is
+ * modified during execution.
  *
- * first descriptor used to convert to montgomery domain:
- *   @param[in]  dmem[0] dptr_M: dmem pointer to first limb of modulus M
- *   @param[in]  dmem[4] dptr_m0d: dmem pointer to Montgomery Constant m0'
- *   @param[in]  dmem[8] dptr_RR: dmem pointer to first limb of
- *                              squared Montgomery Modulus RR mod M
- *   @param[in]  dmem[12] dptr_a_mont: dmem pointer to first limb of base A
- *   @param[in]  dmem[16] dptr_b_mont: dmem pointer to first limb of RR
- *   @param[in]  dmem[20] dptr_c_mont: dmem pointer to first limb of base A
- *   @param[in]  dmem[24] N: Number of limbs per bignum
- *   @param[in]  dmem[28] N-1: Number of limbs per bignum minus 1
- *
- * second descriptor used for squaring:
- *   @param[in]  dmem[44] dptr_a_sqr: dmem pointer to first limb of result C
- *   @param[in]  dmem[48] dptr_b_sqr: dmem pointer to first limb of result C
- *   @param[in]  dmem[52] dptr_c_sqr: dmem pointer to first limb of result C
- *
- * third descriptor used for multiplication:
- *   @param[in]  dmem[76] dptr_a_mul: dmem pointer to first limb of base A
- *   @param[in]  dmem[80] dptr_b_mul: dmem pointer to first limb of result C
- *   @param[in]  dmem[84] dptr_c_mul: dmem pointer to first limb of result C
- *
- * fourth descriptor used for back-conversion:
- *   @param[in]  dmem[104] dptr_RR: dmem pointer to first limb of
- *                              squared Montgomery Modulus RR mod M
- *   @param[in]  dmem[108] dptr_a_ex1: dmem pointer to first limb of base A
- *   @param[in]  dmem[116] dptr_c_ex1: dmem pointer to first limb of result C
+ * @param[in]  dmem[0] reserved
+ * @param[in]  dmem[4] N: Number of limbs per bignum
+ * @param[in]  dmem[8] dptr_m0d: pointer to m0' in dmem
+ * @param[in]  dmem[12] dptr_rr: pointer to RR in dmem
+ * @param[in]  dmem[16] dptr_m: pointer to first limb of modulus in dmem
+ * @param[in]  dmem[20] dptr_in: pointer to input/base buffer
+ * @param[in]  dmem[28] dptr_out: pointer to output/result buffer
  *
  * clobbered registers: x3 to x13, x16 to x31
  *                      w0 to w3, w24 to w30
@@ -1222,24 +1124,24 @@
   li        x11, 2
 
   /* load pointer to modulus */
-  lw        x16, 0(x0)
+  lw        x16, 16(x0)
 
   /* load pointer to m0' */
-  lw        x17, 4(x0)
+  lw        x17, 8(x0)
 
   /* load number of limbs */
-  lw        x30, 24(x0)
-  lw        x31, 28(x0)
+  lw        x30, 4(x0)
+  addi      x31, x30, -1
 
   /* convert to montgomery domain montmul(A,RR)
-  in = montmul(A,RR) = C*R mod M */
-  lw        x19, 12(x0)
-  lw        x20, 16(x0)
+  in = montmul(A,RR) = A*R mod M */
+  lw        x19, 20(x0)
+  lw        x20, 12(x0)
   lw        x21, 20(x0)
   jal       x1, montmul
 
   /* pointer to out buffer */
-  lw        x21, 116(x0)
+  lw        x21, 28(x0)
 
   /* zeroize w2 and reset flags */
   bn.sub    w2, w2, w2
@@ -1255,24 +1157,25 @@
     /* store limb in dmem */
     bn.sid    x11, 0(x21++)
 
-  /* reload pointer to modulus */
-  lw        x16, 32(x0)
+  /* reload pointer to 1st limb of modulus */
+  lw        x16, 16(x0)
 
   /* 65537 = 0b10000000000000001
                ^ sqr + mult
     out = montmul(out,out)       */
-  lw        x19, 44(x0)
-  lw        x20, 48(x0)
-  lw        x21, 52(x0)
+  lw        x19, 28(x0)
+  lw        x20, 28(x0)
+  lw        x21, 28(x0)
   jal       x1, montmul_sqr
 
   /* out = montmul(in,out)       */
-  lw        x19, 76(x0)
-  lw        x20, 80(x0)
+  lw        x19, 20(x0)
+  lw        x20, 28(x0)
+  lw        x21, 28(x0)
   jal       x1, montmul_mul
 
   /* store multiplication result in output buffer */
-  lw        x21, 84(x0)
+  lw        x21, 28(x0)
   li        x8, 4
   loop      x30, 2
     /* store selected limb to dmem */
@@ -1283,33 +1186,31 @@
                 ^<< 16 x sqr >>^   */
   loopi      16, 5
     /* square: out = montmul(out, out) */
-    lw        x19, 44(x0)
-    lw        x20, 48(x0)
-    lw        x21, 52(x0)
+    lw        x19, 28(x0)
+    lw        x20, 28(x0)
+    lw        x21, 28(x0)
     jal       x1, montmul_sqr
     nop
 
   /* 65537 = 0b10000000000000001
                           mult ^
      out = montmul(in,out)       */
-  lw        x19, 76(x0)
-  lw        x20, 80(x0)
+  lw        x19, 20(x0)
+  lw        x20, 28(x0)
+  lw        x21, 28(x0)
   jal       x1, montmul_mul
 
   /* store multiplication result in output buffer */
-  lw        x21, 84(x0)
+  lw        x21, 28(x0)
   li        x8, 4
   loop      x30, 2
     bn.sid    x8, 0(x21++)
     addi      x8, x8, 1
 
-  /* pointer to out buffer */
-  lw        x19, 108(x0)
-  /* pointer to out buffer */
-  lw        x21, 116(x0)
-
   /* convert back from montgomery domain */
   /* out = montmul(out,1) = out/R mod M  */
+  lw        x19, 28(x0)
+  lw        x21, 28(x0)
   jal       x1, montmul_mul1
 
   ret
@@ -1324,10 +1225,11 @@
  *
  * Needs to be executed once per constant Modulus.
  *
- * @param[in]  dmem[0] dptr_m: pointer to first limb of modulus in dmem
- * @param[in]  dmem[1] dptr_m0d: pointer to m0' in dmem
  * @param[in]  dmem[2] dptr_rr: pointer to RR in dmem
- * @param[in]  dmem[6] N: Number of limbs per bignum
+ * @param[in]  dmem[4] N: Number of limbs per bignum
+ * @param[in]  dmem[8] dptr_m0d: pointer to m0' in dmem
+ * @param[in]  dmem[12] dptr_rr: pointer to RR in dmem
+ * @param[in]  dmem[16] dptr_m: pointer to first limb of modulus in dmem
  * @param[out] [dmem[dptr_m0d+31]:dmem[dptr_m0d]] computed m0'
  * @parma[out] [dmem[dptr_RR+N*32-1]:dmem[dptr_RR]] computed RR
  */
@@ -1337,16 +1239,16 @@
   bn.xor   w31, w31, w31
 
   /* load pointer to modulus (dptr_m) */
-  lw       x16, 0(x0)
+  lw       x16, 16(x0)
 
   /* load pointer to m0' (dptr_m0d) */
-  lw       x17, 4(x0)
+  lw       x17, 8(x0)
 
   /* load pointer to RR (dptr_rr) */
-  lw       x18, 8(x0)
+  lw       x18, 12(x0)
 
   /* load number of limbs (N) */
-  lw       x30, 24(x0)
+  lw       x30, 4(x0)
 
   /* load lowest limb of modulus to w28 */
   li       x8, 28
diff --git a/sw/otbn/code-snippets/rsa_1024_dec_test.s b/sw/otbn/code-snippets/rsa_1024_dec_test.s
index 5a7c2ac..6e125c6 100644
--- a/sw/otbn/code-snippets/rsa_1024_dec_test.s
+++ b/sw/otbn/code-snippets/rsa_1024_dec_test.s
@@ -18,7 +18,7 @@
   jal      x1, modload
   jal      x1, modexp
   /* pointer to out buffer */
-  lw        x21, 116(x0)
+  lw        x21, 28(x0)
 
   /* copy all limbs of result to wide reg file */
   li       x8, 0
@@ -31,52 +31,35 @@
 
 .data
 
-/* descriptor 1: a=in, b=RR, c=in
-   convert to Montgomery */
-.word 0x00000080
-.word 0x00000280
-.word 0x000002c0
-.word 0x000004c0
-.word 0x000002c0
-.word 0x000004c0
-.word 0x00000004
-.word 0x00000003
+/* reserved */
+.word 0x00000000
 
-/* descriptor 2: a=out, b=out, c=out
-   square */
-.word 0x00000080
-.word 0x00000280
-.word 0x000002c0
-.word 0x000008e0
-.word 0x000008e0
-.word 0x000008e0
+/* number of limbs (N) */
 .word 0x00000004
-.word 0x00000003
 
-/* descriptor 3: a=in, b=out, c=out
-   multiply */
-.word 0x00000080
+/* pointer to m0' (dptr_m0d) */
 .word 0x00000280
+
+/* pointer to RR (dptr_rr) */
 .word 0x000002c0
+
+/* pointer to modulus (dptr_m) */
+.word 0x00000080
+
+/* pointer to base bignum buffer (dptr_in) */
 .word 0x000004c0
-.word 0x000008e0
-.word 0x000008e0
-.word 0x00000004
-.word 0x00000003
 
-/* descriptor 4: a=in, b=exp, c=out
-   shift exponent and convert back */
-.word 0x00000080
-.word 0x00000280
-.word 0x000002c0
-.word 0x000008e0
+/* pointer to exponent buffer (dptr_exp) */
 .word 0x000006c0
-.word 0x000008e0
-.word 0x00000004
-.word 0x00000003
+
+/* pointer to out buffer (dptr_out) */
+.word 0x000008c0
 
 
-/* modulus */
+/* Modulus */
+/* skip to 128 */
+.skip 96
+
 .word 0xc28cf49f
 .word 0xb6e64c3b
 .word 0xa21417f1
diff --git a/sw/otbn/code-snippets/rsa_1024_enc_test.s b/sw/otbn/code-snippets/rsa_1024_enc_test.s
index d7c8d7a..d985c06 100644
--- a/sw/otbn/code-snippets/rsa_1024_enc_test.s
+++ b/sw/otbn/code-snippets/rsa_1024_enc_test.s
@@ -19,7 +19,7 @@
   jal      x1, modload
   jal      x1, modexp_65537
   /* pointer to out buffer */
-  lw        x21, 116(x0)
+  lw        x21, 28(x0)
 
   /* copy all limbs of result to wide reg file */
   li       x8, 0
@@ -32,52 +32,35 @@
 
 .data
 
-/* descriptor 1: a=in, b=RR, c=in
-   convert to Montgomery */
-.word 0x00000080
-.word 0x00000280
-.word 0x000002c0
-.word 0x000004c0
-.word 0x000002c0
-.word 0x000004c0
-.word 0x00000004
-.word 0x00000003
+/* reserved */
+.word 0x00000000
 
-/* descriptor 2: a=out, b=out, c=out
-   square */
-.word 0x00000080
-.word 0x00000280
-.word 0x000002c0
-.word 0x000008e0
-.word 0x000008e0
-.word 0x000008e0
+/* number of limbs (N) */
 .word 0x00000004
-.word 0x00000003
 
-/* descriptor 3: a=in, b=out, c=out
-   multiply */
-.word 0x00000080
+/* pointer to m0' (dptr_m0d) */
 .word 0x00000280
+
+/* pointer to RR (dptr_rr) */
 .word 0x000002c0
+
+/* pointer to modulus (dptr_m) */
+.word 0x00000080
+
+/* pointer to base bignum buffer (dptr_in) */
 .word 0x000004c0
-.word 0x000008e0
-.word 0x000008e0
-.word 0x00000004
-.word 0x00000003
 
-/* descriptor 4: a=in, b=exp, c=out
-   shift exponent and convert back */
-.word 0x00000080
-.word 0x00000280
-.word 0x000002c0
-.word 0x000008e0
+/* pointer to exponent buffer (dptr_exp, unused for encrypt) */
 .word 0x000006c0
-.word 0x000008e0
-.word 0x00000004
-.word 0x00000003
+
+/* pointer to out buffer (dptr_out) */
+.word 0x000008c0
 
 
 /* Modulus */
+/* skip to 128 */
+.skip 96
+
 .word 0xc28cf49f
 .word 0xb6e64c3b
 .word 0xa21417f1