[crypto] Simplify DMEM interface for RSA.

Now that we have a .bss directive for OTBN, there's no need to rely on
indirect pointers or specific memory locations in RSA code. This also
will make it easier to use the same underlying libraries for keygen and
signing.

Also includes a minor improvement to the OTBN testutils, so that we
print the error bits if OTBN fails a test because it became locked.

Signed-off-by: Jade Philipoom <jadep@google.com>
diff --git a/sw/device/lib/testing/otbn_testutils.c b/sw/device/lib/testing/otbn_testutils.c
index 472ba4c..fae936f 100644
--- a/sw/device/lib/testing/otbn_testutils.c
+++ b/sw/device/lib/testing/otbn_testutils.c
@@ -26,16 +26,18 @@
     busy = status != kDifOtbnStatusIdle && status != kDifOtbnStatusLocked;
   }
 
-  // Error out if OTBN is locked.
-  CHECK(status == kDifOtbnStatusIdle, "OTBN is locked.");
-
-  // Get instruction count so that, if error bits are unexpected, we can print
-  // it to help with debugging.
+  // Get instruction count so that we can print it to help with debugging.
   uint32_t instruction_count;
   CHECK_DIF_OK(dif_otbn_get_insn_cnt(otbn, &instruction_count));
 
   dif_otbn_err_bits_t err_bits;
   CHECK_DIF_OK(dif_otbn_get_err_bits(otbn, &err_bits));
+
+  // Error out if OTBN is locked.
+  CHECK(status == kDifOtbnStatusIdle, "OTBN is locked. Error bits: 0x%08x",
+        err_bits);
+
+  // Error out if error bits do not match expectations.
   CHECK(err_bits == expected_err_bits,
         "OTBN error bits: got: 0x%08x, expected: 0x%08x.\nInstruction count: "
         "0x%08x",
diff --git a/sw/otbn/crypto/modexp.s b/sw/otbn/crypto/modexp.s
index 20c17e2..7b50b41 100644
--- a/sw/otbn/crypto/modexp.s
+++ b/sw/otbn/crypto/modexp.s
@@ -174,9 +174,9 @@
 * @param[in]  x18: dptr_RR: dmem pointer to first limb of output buffer for RR
 * @param[in]  x30: N, number of limbs
 * @param[in]  w31: all-zero
-* @param[out] dmem[x18+N*32:x18]: computed RR
+* @param[out] dmem[dptr_RR+N*32:dptr_RR]: computed RR
 *
-* clobbered registers: x3, x8, x10, x11, x16, x18
+* clobbered registers: x3, x8, x10, x11, x22
 *                      w0, w2, w3, w4, w5 to w20 depending on N
 * clobbered flag groups: FG0, FG1
 */
@@ -264,9 +264,13 @@
   /* reset pointer to 1st limb of bigint in regfile */
   li        x8, 5
 
+  /* reset pointer to modulus */
+  addi      x16, x22, 0
+
   /* store computed RR in dmem */
+  addi      x3, x18, 0
   loop      x30, 2
-    bn.sid    x8, 0(x18++)
+    bn.sid    x8, 0(x3++)
     addi      x8, x8, 1
 
   ret
@@ -554,6 +558,9 @@
      subtraction of the modulus from the output buffer. */
   jal       x1, cond_sub_to_reg
 
+  /* reset pointer to modulus */
+  addi      x16, x22, 0
+
   /* restore pointer */
   li        x8, 4
 
@@ -735,7 +742,7 @@
  * @param[in]  x11: pointer to temp reg, must be set to 2
  * @param[out] [w[4+N-1]:w4]: result C
  *
- * clobbered registers: x5, x6, x7, x8, x10, x12, x13, x16, x17, x19, x20, x21
+ * clobbered registers: x5, x6, x7, x8, x10, x12, x13, x16, x17, x19, x20
  *                      w2, w3, w24 to w30, w4 to w[4+N-1]
  * clobbered Flag Groups: FG0, FG1
  */
@@ -836,14 +843,15 @@
  * Note, that the content of both, the input buffer and the exp buffer is
  * modified during execution.
  *
- * @param[in]  dmem[2] dptr_rr: pointer to RR in dmem
- * @param[in]  dmem[4] N: Number of limbs per bignum
- * @param[in]  dmem[8] dptr_m0d: pointer to m0' in dmem
- * @param[in]  dmem[12] dptr_rr: pointer to RR in dmem
- * @param[in]  dmem[16] dptr_m: pointer to first limb of modulus in dmem
- * @param[in]  dmem[20] dptr_in: pointer to input/base buffer
- * @param[in]  dmem[20] dptr_exp: pointer to exp buffer
- * @param[in]  dmem[28] dptr_out: pointer to output/result buffer
+ * @param[in]   x2: dptr_c, dmem pointer to buffer for output C
+ * @param[in]  x14: dptr_a, dmem pointer to first limb of input A
+ * @param[in]  x15: dptr_e, dmem pointer to first limb of exponent E
+ * @param[in]  x16: dptr_M, dmem pointer to first limb of modulus M
+ * @param[in]  x17: dptr_m0d, dmem pointer to first limb of m0'
+ * @param[in]  x18: dptr_RR, dmem pointer to first limb of RR
+ * @param[in]  x30: N, number of limbs per bignum
+ * @param[in]  w31: all-zero
+ * @param[out] dmem[dptr_c:dptr_c+N*32] C, A^E mod M
  *
  * clobbered registers: x3 to x13, x16 to x31
  *                      w0 to w3, w24 to w30
@@ -857,23 +865,16 @@
   li        x10, 4
   li        x11, 2
 
-  /* load pointer to modulus */
-  lw        x16, 16(x0)
-
-  /* load pointer to m0' */
-  lw        x17, 8(x0)
-
-  /* load number of limbs */
-  lw        x30, 4(x0)
+  /* Compute (N-1).
+       x31 <= x30 - 1 = N - 1 */
   addi      x31, x30, -1
 
-  /* convert to montgomery domain montmul(A,RR)
-  in = montmul(A,RR) montmul(A,RR) = C*R mod M */
-  lw        x19, 20(x0)
-  lw        x20, 12(x0)
-  lw        x21, 20(x0)
+  /* Convert input to montgomery domain.
+       dmem[dptr_a] <= montmul(A,RR) = A*R mod M */
+  addi      x19, x14, 0
+  addi      x20, x18, 0
+  addi      x21, x14, 0
   jal       x1, montmul
-  /* Store result in dmem starting at dmem[dptr_c] */
   loop      x30, 2
     bn.sid    x8, 0(x21++)
     addi      x8, x8, 1
@@ -882,11 +883,11 @@
   bn.sub    w2, w2, w2
 
   /* initialize the output buffer with -M */
-  lw        x16, 16(x0)
-  lw        x21, 28(x0)
+  addi      x3, x16, 0
+  addi      x21, x2, 0
   loop      x30, 3
     /* load limb from modulus */
-    bn.lid    x11, 0(x16++)
+    bn.lid    x11, 0(x3++)
 
     /* subtract limb from 0 */
     bn.subb   w2, w31, w2
@@ -894,18 +895,15 @@
     /* store limb in dmem */
     bn.sid    x11, 0(x21++)
 
-  /* reload pointer to modulus */
-  lw        x16, 16(x0)
-
   /* compute bit length of current bigint size */
   slli      x24, x30, 8
 
   /* iterate over all bits of bigint */
   loop      x24, 20
     /* square: out = montmul(out,out)  */
-    lw        x19, 28(x0)
-    lw        x20, 28(x0)
-    lw        x21, 28(x0)
+    addi      x19, x2, 0
+    addi      x20, x2, 0
+    addi      x21, x2, 0
     jal       x1, montmul
     /* Store result in dmem starting at dmem[dptr_c] */
     loop      x30, 2
@@ -913,9 +911,9 @@
       addi      x8, x8, 1
 
     /* multiply: out = montmul(in,out) */
-    lw        x19, 20(x0)
-    lw        x20, 28(x0)
-    lw        x21, 28(x0)
+    addi      x19, x14, 0
+    addi      x20, x2, 0
+    addi      x21, x2, 0
     jal       x1, montmul
 
     /* w2 <= w2 << 1 */
@@ -923,7 +921,7 @@
 
     /* the loop performs a 1-bit left shift of the exponent. Last MSB moves
        to FG0.C, such that it can be used for selection */
-    lw        x20, 24(x0)
+    addi      x20, x15, 0
     loop      x30, 3
       bn.lid    x11, 0(x20)
       /* w2 <= w2 << 1 */
@@ -931,15 +929,15 @@
       bn.sid    x11, 0(x20++)
 
     /* select squared or squared+multiplied result */
-    lw        x21, 28(x0)
+    addi      x21, x2, 0
     jal       x1, sel_sqr_or_sqrmul
 
     nop
 
   /* convert back from montgomery domain */
   /* out = montmul(out,1) = out/R mod M  */
-  lw        x19, 28(x0)
-  lw        x21, 28(x0)
+  addi      x19, x2, 0
+  addi      x21, x2, 0
   jal       x1, montmul_mul1
 
   ret
@@ -965,13 +963,14 @@
  * to the output buffer. Note, that the content of the input buffer is
  * modified during execution.
  *
- * @param[in]  dmem[2] dptr_rr: pointer to RR in dmem
- * @param[in]  dmem[4] N: Number of limbs per bignum
- * @param[in]  dmem[8] dptr_m0d: pointer to m0' in dmem
- * @param[in]  dmem[12] dptr_rr: pointer to RR in dmem
- * @param[in]  dmem[16] dptr_m: pointer to first limb of modulus in dmem
- * @param[in]  dmem[20] dptr_in: pointer to input/base buffer
- * @param[in]  dmem[28] dptr_out: pointer to output/result buffer
+ * @param[in]   x2: dptr_c, dmem pointer to buffer for output C
+ * @param[in]  x14: dptr_a, dmem pointer to first limb of input A
+ * @param[in]  x16: dptr_M, dmem pointer to first limb of modulus M
+ * @param[in]  x17: dptr_m0d, dmem pointer to Montgomery constant m0'
+ * @param[in]  x18: dptr_RR, dmem pointer to Montgomery constant RR
+ * @param[in]  x30: N, number of limbs per bignum
+ * @param[in]  w31: all-zero
+ * @param[out] dmem[dptr_c:dptr_c+N*32] C, A^65537 mod M
  *
  * clobbered registers: x3 to x13, x16 to x31
  *                      w0 to w3, w24 to w30
@@ -985,37 +984,34 @@
   li        x10, 4
   li        x11, 2
 
-  /* load pointer to modulus */
-  lw        x16, 16(x0)
-
-  /* load pointer to m0' */
-  lw        x17, 8(x0)
-
-  /* load number of limbs */
-  lw        x30, 4(x0)
+  /* Compute (N-1).
+       x31 <= x30 - 1 = N - 1 */
   addi      x31, x30, -1
 
   /* convert to montgomery domain montmul(A,RR)
   in = montmul(A,RR) montmul(A,RR) = C*R mod M */
-  lw        x19, 20(x0)
-  lw        x20, 12(x0)
-  lw        x21, 20(x0)
+  addi      x19, x14, 0
+  addi      x20, x18, 0
+  addi      x21, x14, 0
   jal       x1, montmul
-  /* Store result in dmem starting at dmem[dptr_c] */
+  /* Store result in dmem starting at dmem[dptr_a] */
   loop      x30, 2
     bn.sid    x8, 0(x21++)
     addi      x8, x8, 1
 
   /* pointer to out buffer */
-  lw        x21, 28(x0)
+  addi      x21, x2, 0
 
   /* zeroize w2 and reset flags */
   bn.sub    w2, w2, w2
 
+  /* pointer to modulus */
+  addi      x3, x16, 0
+
   /* this loop initializes the output buffer with -M */
   loop      x30, 3
     /* load limb from modulus */
-    bn.lid    x11, 0(x16++)
+    bn.lid    x11, 0(x3++)
 
     /* subtract limb from 0 */
     bn.subb   w2, w31, w2
@@ -1023,32 +1019,28 @@
     /* store limb in dmem */
     bn.sid    x11, 0(x21++)
 
-  /* reload pointer to 1st limb of modulus */
-  lw        x16, 16(x0)
-
+  /* TODO: Is this squaring necessary? */
   /* 65537 = 0b10000000000000001
                ^ sqr + mult
     out = montmul(out,out)       */
-  lw        x19, 28(x0)
-  lw        x20, 28(x0)
-  lw        x21, 28(x0)
+  addi      x19, x2, 0
+  addi      x20, x2, 0
   jal       x1, montmul
   /* Store result in dmem starting at dmem[dptr_c] */
+  addi      x21, x2, 0
   loop      x30, 2
     bn.sid    x8, 0(x21++)
     addi      x8, x8, 1
 
   /* out = montmul(in,out)       */
-  lw        x19, 20(x0)
-  lw        x20, 28(x0)
-  lw        x20, 28(x0)
+  addi      x19, x14, 0
+  addi      x20, x2, 0
   jal       x1, montmul
 
   /* store multiplication result in output buffer */
-  lw        x21, 28(x0)
+  addi      x21, x2, 0
   li        x8, 4
   loop      x30, 2
-    /* store selected limb to dmem */
     bn.sid    x8, 0(x21++)
     addi      x8, x8, 1
 
@@ -1056,11 +1048,11 @@
                 ^<< 16 x sqr >>^   */
   loopi      16, 8
     /* square: out = montmul(out, out) */
-    lw        x19, 28(x0)
-    lw        x20, 28(x0)
-    lw        x21, 28(x0)
+    addi      x19, x2, 0
+    addi      x20, x2, 0
     jal       x1, montmul
     /* Store result in dmem starting at dmem[dptr_c] */
+    addi      x21, x2, 0
     loop      x30, 2
       bn.sid    x8, 0(x21++)
       addi      x8, x8, 1
@@ -1069,13 +1061,12 @@
   /* 65537 = 0b10000000000000001
                           mult ^
      out = montmul(in,out)       */
-  lw        x19, 20(x0)
-  lw        x20, 28(x0)
-  lw        x21, 28(x0)
+  addi      x19, x14, 0
+  addi      x20, x2, 0
   jal       x1, montmul
 
   /* store multiplication result in output buffer */
-  lw        x21, 28(x0)
+  addi      x21, x2, 0
   li        x8, 4
   loop      x30, 2
     bn.sid    x8, 0(x21++)
@@ -1083,8 +1074,8 @@
 
   /* convert back from montgomery domain */
   /* out = montmul(out,1) = out/R mod M  */
-  lw        x19, 28(x0)
-  lw        x21, 28(x0)
+  addi      x19, x2, 0
+  addi      x21, x2, 0
   jal       x1, montmul_mul1
 
   ret
@@ -1099,31 +1090,15 @@
  *
  * Needs to be executed once per constant Modulus.
  *
- * @param[in]  dmem[2] dptr_rr: pointer to RR in dmem
- * @param[in]  dmem[4] N: Number of limbs per bignum
- * @param[in]  dmem[8] dptr_m0d: pointer to m0' in dmem
- * @param[in]  dmem[12] dptr_rr: pointer to RR in dmem
- * @param[in]  dmem[16] dptr_m: pointer to first limb of modulus in dmem
+ * @param[in]  x16: dptr_M, dmem pointer to first limb of modulus M
+ * @param[in]  x17: dptr_m0d, dmem pointer to buffer for m0'
+ * @param[in]  x18: dptr_RR, dmem pointer to buffer for RR
+ * @param[in]  x30: N, number of limbs per bignum
+ * @param[in]  w31: all-zero
  * @param[out] [dmem[dptr_m0d+31]:dmem[dptr_m0d]] computed m0'
- * @parma[out] [dmem[dptr_RR+N*32-1]:dmem[dptr_RR]] computed RR
+ * @param[out] [dmem[dptr_RR+N*32-1]:dmem[dptr_RR]] computed RR
  */
 modload:
-
-  /* prepare all-zero reg */
-  bn.xor   w31, w31, w31
-
-  /* load pointer to modulus (dptr_m) */
-  lw       x16, 16(x0)
-
-  /* load pointer to m0' (dptr_m0d) */
-  lw       x17, 8(x0)
-
-  /* load pointer to RR (dptr_rr) */
-  lw       x18, 12(x0)
-
-  /* load number of limbs (N) */
-  lw       x30, 4(x0)
-
   /* load lowest limb of modulus to w28 */
   li       x8, 28
   bn.lid   x8, 0(x16)
diff --git a/sw/otbn/crypto/rsa.s b/sw/otbn/crypto/rsa.s
index c6b7ade..7f54880 100644
--- a/sw/otbn/crypto/rsa.s
+++ b/sw/otbn/crypto/rsa.s
@@ -5,6 +5,18 @@
 .section .text.start
 .globl start
 start:
+  /* Init all-zero register. */
+  bn.xor  w31, w31, w31
+
+  /* Load number of limbs. */
+  la    x2, n_limbs
+  lw    x30, 0(x2)
+
+  /* Load pointers to modulus and Montgomery constant buffers. */
+  la    x16, modulus
+  la    x17, m0d
+  la    x18, RR
+
   /* Read mode, then tail-call either rsa_encrypt or rsa_decrypt */
   la    x2, mode
   lw    x2, 0(x2)
@@ -19,13 +31,21 @@
   unimp
 
 .text
+
 /**
  * RSA encryption
  */
 rsa_encrypt:
-  jal      x1, zero_work_buf
+  /* Compute Montgomery constants. */
   jal      x1, modload
+
+  /* Run exponentiation.
+       dmem[work_buf] = dmem[inout]^65537 mod dmem[modulus] */
+  la       x14, inout
+  la       x2, work_buf
   jal      x1, modexp_65537
+
+  /* dmem[inout] <= dmem[work_buf] */
   jal      x1, cp_work_buf
   ecall
 
@@ -33,9 +53,17 @@
  * RSA decryption
  */
 rsa_decrypt:
-  jal      x1, zero_work_buf
+  /* Compute Montgomery constants. */
   jal      x1, modload
+
+  /* Run exponentiation.
+       dmem[work_buf] = dmem[inout]^dmem[exp] mod dmem[modulus] */
+  la       x14, inout
+  la       x15, exp
+  la       x2, work_buf
   jal      x1, modexp
+
+  /* dmem[inout] <= dmem[work_buf] */
   jal      x1, cp_work_buf
   ecall
 
@@ -55,23 +83,19 @@
 /**
  * Copy the contents of work_buf onto inout
  *
- * clobbered registers: x3, x4, w0
+ * clobbered registers: x2, x3, x4, x30, w0
  */
 cp_work_buf:
-  la  x3, work_buf
-  la  x4, inout
-  /* The buffers are 512 bytes long, which we can load/store with
-     sixteen 256b words. */
-  loopi 16, 2
+  la    x2, n_limbs
+  lw    x30, 0(x2)
+  la    x3, work_buf
+  la    x4, inout
+  loop  x30, 2
     bn.lid x0, 0(x3++)
     bn.sid x0, 0(x4++)
   ret
 
-.data
-/*
-The structure of the 256b below are mandated by the calling convention of the
-RSA library.
-*/
+.bss
 
 /* Mode (1 = encrypt; 2 = decrypt) */
 .globl mode
@@ -83,32 +107,6 @@
 n_limbs:
   .word 0x00000000
 
-/* pointer to m0' (dptr_m0d) */
-dptr_m0d:
-  .word m0d
-
-/* pointer to RR (dptr_rr) */
-dptr_rr:
-  .word RR
-
-/* load pointer to modulus (dptr_m) */
-dptr_m:
-  .word modulus
-
-/* pointer to base bignum buffer (dptr_in) */
-dptr_in:
-  .word inout
-
-/* pointer to exponent buffer (dptr_exp, unused for encrypt) */
-dptr_exp:
-  .word exp
-
-/* pointer to out buffer (dptr_out) */
-dptr_out:
-  .word work_buf
-
-/* (End of fixed-layout section) */
-
 /* Modulus (n) */
 .balign 32
 .globl modulus
@@ -127,17 +125,17 @@
 inout:
   .zero 512
 
+/* Montgomery constant m0'. Filled by `modload`. */
 .balign 32
 m0d:
-  /* filled by modload */
   /* could go in scratchpad if there was space */
   .zero 32
 
 .section .scratchpad
 
+/* Montgomery constant RR. Filled by `modload`. */
 .balign 32
 RR:
-  /* filled by modload */
   .zero 512
 
 /* working data */
diff --git a/sw/otbn/crypto/tests/BUILD b/sw/otbn/crypto/tests/BUILD
index 7e7545d..ad948bc 100644
--- a/sw/otbn/crypto/tests/BUILD
+++ b/sw/otbn/crypto/tests/BUILD
@@ -412,6 +412,7 @@
 
 otbn_sim_test(
     name = "rsa_1024_dec_test",
+    timeout = "long",
     srcs = [
         "rsa_1024_dec_test.s",
     ],
diff --git a/sw/otbn/crypto/tests/rsa_1024_dec_test.s b/sw/otbn/crypto/tests/rsa_1024_dec_test.s
index 4fe9ea6..fa8a5c5 100644
--- a/sw/otbn/crypto/tests/rsa_1024_dec_test.s
+++ b/sw/otbn/crypto/tests/rsa_1024_dec_test.s
@@ -15,12 +15,29 @@
  * w0). See comment at the end of the file for expected values.
  */
  run_rsa_1024_dec:
+  /* Init all-zero register. */
+  bn.xor  w31, w31, w31
+
+  /* Load number of limbs. */
+  li    x30, 4
+
+  /* Load pointers to modulus and Montgomery constant buffers. */
+  la    x16, modulus
+  la    x17, m0inv
+  la    x18, RR
+
+  /* Compute Montgomery constants. */
   jal      x1, modload
+
+  /* Run exponentiation.
+       dmem[plaintext] = dmem[ciphertext]^dmem[exp] mod dmem[modulus] */
+  la       x14, ciphertext
+  la       x15, exp
+  la       x2, plaintext
   jal      x1, modexp
-  /* pointer to out buffer */
-  lw        x21, 28(x0)
 
   /* copy all limbs of result to wide reg file */
+  la       x21, plaintext
   li       x8, 0
   loop     x30, 2
     bn.lid   x8, 0(x21++)
@@ -31,44 +48,9 @@
 
 .data
 
-/*
- * The words below are used by the code above, but the linker can't tell
- * because we reference them by absolute address. Make a global symbol
- * (cfg_data), which will refer to the whole lot and ensure that gc-sections
- * doesn't discard them.
- */
-.globl cfg_data
-cfg_data:
-
-/* reserved */
-.word 0x00000000
-
-/* number of limbs (N) */
-.word 0x00000004
-
-/* pointer to m0' (dptr_m0d) */
-.word 0x00000280
-
-/* pointer to RR (dptr_rr) */
-.word 0x000002c0
-
-/* load pointer to modulus (dptr_m) */
-.word 0x00000080
-
-/* pointer to base bignum buffer (dptr_in) */
-.word 0x000004c0
-
-/* pointer to exponent buffer (dptr_exp) */
-.word 0x000006c0
-
-/* pointer to out buffer (dptr_out) */
-.word 0x000008c0
-
-
 /* Modulus */
-/* skip to 128 */
-.skip 96
-
+.balign 32
+modulus:
 .word 0xc28cf49f
 .word 0xb6e64c3b
 .word 0xa21417f1
@@ -107,9 +89,8 @@
 
 
 /* encrypted message */
-/* skip to 1216 */
-.skip 960
-
+.balign 32
+ciphertext:
 .word 0xe0e14a9b
 .word 0x7ae96741
 .word 0x4a430036
@@ -148,9 +129,8 @@
 
 
 /* private exponent */
-/* skip to 1728 */
-.skip 384
-
+.balign 32
+exp:
 .word 0x93a8cd95
 .word 0x24a2614b
 .word 0xeeb788b3
@@ -186,3 +166,18 @@
 .word 0x511778ce
 .word 0x42209f7b
 .word 0x41b468dc
+
+/* output buffer */
+.balign 32
+plaintext:
+.zero 128
+
+/* buffer for Montgomery constant RR */
+.balign 32
+RR:
+.zero 128
+
+/* buffer for Montgomery constant m0inv */
+.balign 32
+m0inv:
+.zero 32
diff --git a/sw/otbn/crypto/tests/rsa_1024_enc_test.s b/sw/otbn/crypto/tests/rsa_1024_enc_test.s
index db28110..c9419ec 100644
--- a/sw/otbn/crypto/tests/rsa_1024_enc_test.s
+++ b/sw/otbn/crypto/tests/rsa_1024_enc_test.s
@@ -16,12 +16,28 @@
  * w0). See comment at the end of the file for expected values.
  */
 run_rsa_1024_enc:
+  /* Init all-zero register. */
+  bn.xor  w31, w31, w31
+
+  /* Load number of limbs. */
+  li    x30, 4
+
+  /* Load pointers to modulus and Montgomery constant buffers. */
+  la    x16, modulus
+  la    x17, m0inv
+  la    x18, RR
+
+  /* Compute Montgomery constants. */
   jal      x1, modload
+
+  /* Run exponentiation.
+       dmem[ciphertext] = dmem[plaintext]^65537 mod dmem[modulus] */
+  la       x14, plaintext
+  la       x2, ciphertext
   jal      x1, modexp_65537
-  /* pointer to out buffer */
-  lw        x21, 28(x0)
 
   /* copy all limbs of result to wide reg file */
+  la       x21, ciphertext
   li       x8, 0
   loop     x30, 2
     bn.lid   x8, 0(x21++)
@@ -29,47 +45,11 @@
 
   ecall
 
-
 .data
 
-/*
- * The words below are used by the code above, but the linker can't tell
- * because we reference them by absolute address. Make a global symbol
- * (cfg_data), which will refer to the whole lot and ensure that gc-sections
- * doesn't discard them.
- */
-.globl cfg_data
-cfg_data:
-
-/* reserved */
-.word 0x00000000
-
-/* number of limbs (N) */
-.word 0x00000004
-
-/* pointer to m0' (dptr_m0d) */
-.word 0x00000280
-
-/* pointer to RR (dptr_rr) */
-.word 0x000002c0
-
-/* load pointer to modulus (dptr_m) */
-.word 0x00000080
-
-/* pointer to base bignum buffer (dptr_in) */
-.word 0x000004c0
-
-/* pointer to exponent buffer (dptr_exp, unused for encrypt) */
-.word 0x000006c0
-
-/* pointer to out buffer (dptr_out) */
-.word 0x000008c0
-
-
 /* Modulus */
-/* skip to 128 */
-.skip 96
-
+.balign 32
+modulus:
 .word 0xc28cf49f
 .word 0xb6e64c3b
 .word 0xa21417f1
@@ -108,9 +88,8 @@
 
 
 /* Message */
-/* skip to 1216 */
-.skip 960
-
+.balign 32
+plaintext:
 .word 0x206d653f
 .word 0x20666f72
 .word 0x74686973
@@ -146,3 +125,18 @@
 .word 0x00000000
 .word 0x00000000
 .word 0x00000000
+
+/* output buffer */
+.balign 32
+ciphertext:
+.zero 128
+
+/* buffer for Montgomery constant RR */
+.balign 32
+RR:
+.zero 128
+
+/* buffer for Montgomery constant m0inv */
+.balign 32
+m0inv:
+.zero 32