[sw/crypto] Clarify bounds assumptions for Montgomery R^2.

The Ibex implementation, on closer inspection, doesn't actually require
R / 2 < M, so the comment is rephrased to not state that it does. The
OTBN implementation is greatly simplified by the assumption that R / 2
< M, so it is modified to assume so and justify the assumption by citing
FIPS.

This commit also includes a minor adjustment to the R^2 loop, making it
a single loop instead of two nested loopi instructions so as not to
consume the loop stack unnecessarily.

Signed-off-by: Jade Philipoom <jadep@google.com>
diff --git a/sw/device/silicon_creator/lib/crypto/rsa_3072/rsa_3072.s b/sw/device/silicon_creator/lib/crypto/rsa_3072/rsa_3072.s
index 71dea5f..abce552 100644
--- a/sw/device/silicon_creator/lib/crypto/rsa_3072/rsa_3072.s
+++ b/sw/device/silicon_creator/lib/crypto/rsa_3072/rsa_3072.s
@@ -563,9 +563,11 @@
  * Software Perspective" (https://eprint.iacr.org/2017/1057). For the purposes
  * of RSA 3072, the parameters w and n from the paper are fixed (w=256, n=12).
  *
- * The algorithm from the paper assumes that 2^wn-1 <= M < 2^wn; we lightly
- * adapt the implementation to accept any positive M by subtracting M from c0
- * until c0 < M.
+ * A note on bounds: the paper assumes 2^(wn-1) <= M < 2^(wn) (in our case,
+ * 2^3071 <= M < 2^3072). We make that assumption here too, because it follows
+ * from FIPS 186-4 section B.3.1 (page 53): both prime factors of the RSA
+ * modulus must be at least sqrt(2)*2^(nlen/2-1), so their product M is at
+ * least 2^(nlen-1) (where nlen is the key length, 3072 in this case).
  *
  * The result is stored in dmem[in_rr]. This routine runs in variable time.
  *
@@ -573,7 +575,7 @@
  *
  * @param[in]  dmem[in_mod] pointer to first limb of modulus M in dmem
  *
- * clobbered registers: x2, x3, x7 to x11, x16, w2 to w27, w31
+ * clobbered registers: x9, x10, x16, w4 to w16, w31
  * clobbered Flag Groups: FG0
  */
  .globl precomp_rr
@@ -593,58 +595,26 @@
   /* w16 <= 1 */
   bn.addi   w16, w31, 1
 
-  /* Initialize c0
+  /* Initialize c0.
     c0 = [w4:w15] <= [w4:w16] >> 1 = 2^3071 */
   bn.rshi   w15, w16, w15 >> 1
 
-
-precomp_rr_sub_start:
-
-  /* Repeatedly subtract M until c0 < M.
-       [w16:w27] <= (c0 - M) */
-  jal       x1, subtract_modulus_var
-
-  /* Extract borrow bit from flags register. */
-  csrrs     x2, 0x7c0, x0
-  andi      x2, x2, 1
-
-  /* If borrow is set, then subtraction underflowed, meaning c0 < M; done. */
-  bne       x2, x0, precomp_rr_sub_done
-
-  /* If we got here, then c0 > M; set c0 = c0 - M and repeat.
-       c0 = [w4:w15] <= [w16:w27] = c0 - M */
-  li        x8, 4
-  li        x11, 16
-  loopi     12, 2
-    bn.movr   x8, x11++
-    addi      x8, x8, 1
-
-  /* Jump back to start of subtractions. */
-  beq       x0, x0, precomp_rr_sub_start
-
-precomp_rr_sub_done:
-
-  /* Now, we know that c0 = [w4:w15] = (2^3071) mod M. */
-
   /* One modular doubling to get c1 \equiv 2^3072 mod M.
      c1 = [w4:w15] <= ([w4:w15] * 2) mod M = (2^3072) mod M */
   jal     x1, double_mod_var
 
-  /* Compute (2^3072)^2 mod M by performing 3072 modular doublings.
-     Loop is nested only because #iterations must be < 1024 */
-  loopi     12, 4
-    loopi     256, 2
-      jal     x1, double_mod_var
-      /* Nop because inner loopi can't end on a jump instruction. */
-      nop
-    /* Nop because outer loopi can't end on a loop instruction. */
+  /* Compute (2^3072)^2 mod M by performing 3072 modular doublings. */
+  li      x9, 3072
+  loop    x9, 2
+    jal     x1, double_mod_var
+    /* Nop (2nd of the 2 body instructions): loop can't end on a jump. */
     nop
 
   /* Store result in dmem[in_rr]. */
   li        x9, 4
-  la        x26, in_rr
+  la        x10, in_rr
   loopi     12, 2
-    bn.sid    x9, 0(x26++)
+    bn.sid    x9, 0(x10++)
     addi      x9, x9, 1
 
   ret
diff --git a/sw/device/silicon_creator/lib/sigverify_mod_exp_ibex.c b/sw/device/silicon_creator/lib/sigverify_mod_exp_ibex.c
index e7a071e..8a15016 100644
--- a/sw/device/silicon_creator/lib/sigverify_mod_exp_ibex.c
+++ b/sw/device/silicon_creator/lib/sigverify_mod_exp_ibex.c
@@ -76,8 +76,9 @@
 static void calc_r_square(const sigverify_rsa_key_t *key,
                           sigverify_rsa_buffer_t *result) {
   memset(result->data, 0, sizeof(result->data));
-  // Since R/2 < n < R, this subtraction ensures that result = R mod n and
-  // fits in `kSigVerifyRsaNumWords` going into the loop.
+  // This subtraction sets result = -n mod R = R - n, which is equivalent to R
+  // modulo n and ensures that `result` fits in `kSigVerifyRsaNumWords` going
+  // into the loop.
   subtract_modulus(key, result);
 
   // Iteratively shift and reduce `result`.