[sw,otbn] Implement e=3 signature verification for RSA-3072.

Support RSA-3072 public keys with exponent 3.

Signed-off-by: Jade Philipoom <jadep@google.com>
diff --git a/sw/otbn/crypto/rsa_verify_3072.s b/sw/otbn/crypto/rsa_verify_3072.s
index a2d669d..36214ff 100644
--- a/sw/otbn/crypto/rsa_verify_3072.s
+++ b/sw/otbn/crypto/rsa_verify_3072.s
@@ -335,6 +335,100 @@
 
   ret
 
+/**
+ * Variable time 3072-bit modular exponentiation with exponent 3
+ *
+ * Returns: C = modexp(A,3) = A^3 mod M
+ *
+ * The squared Montgomery modulus RR and the Montgomery constant m0' have to
+ * be provided at the appropriate locations in dmem.
+ *
+ * Flags: Flags have no meaning beyond the scope of this subroutine.
+ *
+ * The base bignum A is expected in the input buffer, the result C is written
+ * to the output buffer.
+ *
+ * @param[in]  dmem[m0inv] pointer to m0' in dmem
+ * @param[in]  dmem[rr] pointer to RR in dmem
+ * @param[in]  dmem[in_mod] pointer to first limb of modulus M in dmem
+ * @param[in]  dmem[in_buf] pointer to buffer with base bignum
+ * @param[in]  dmem[out_buf] pointer to output buffer
+ *
+ * clobbered registers: x2, x5 to x13, x16 to x21, x23, x24, x26, x29
+ *                      w2 to w31
+ * clobbered Flag Groups: FG0, FG1
+ */
+ .globl modexp_var_3072_3
+modexp_var_3072_3:
+  /* Prepare all-zero reg. */
+  bn.xor    w31, w31, w31
+
+  /* Prepare pointers to temp regs: x8->w4, x9->w3, x10->w4, x11->w2. */
+  li         x8, 4
+  li         x9, 3
+  li        x10, 4
+  li        x11, 2
+
+  /* Set pointers to buffers. */
+  la        x24, out_buf
+  la        x16, in_mod
+  la        x23, in_buf
+  la        x26, rr
+  la        x17, m0inv
+
+  /* Convert input to Montgomery domain and store in dmem.
+     dmem[out_buf] <= montmul(dmem[in_buf], dmem[rr]) = A*R mod M */
+  addi      x19, x23, 0
+  addi      x20, x26, 0
+  addi      x21, x24, 0
+  jal       x1, montmul
+  loopi     12, 2
+    bn.sid    x8, 0(x21++)
+    addi      x8, x8, 1
+
+  /* Square the output buffer.
+     dmem[out_buf] <= montmul(dmem[out_buf], dmem[out_buf]) = (A^2)*R mod M */
+  addi      x19, x24, 0
+  addi      x20, x24, 0
+  addi      x21, x24, 0
+  jal       x1, montmul
+  loopi     12, 2
+    bn.sid    x8, 0(x21++)
+    addi      x8, x8, 1
+
+  /* Final multiplication and conversion of result from Montgomery domain.
+     dmem[out_buf]  <= montmul(dmem[in_buf], dmem[out_buf]) = (A^3) mod M */
+  addi      x19, x23, 0
+  addi      x20, x24, 0
+  addi      x21, x24, 0
+  jal       x1, montmul
+
+  /* Final conditional subtraction: subtract modulus if result >= mod. */
+  bn.add    w31, w31, w31       /* clears carry if w31 == 0 */
+  li        x17, 16             /* differences go to w16..w27 */
+  loopi     12, 4
+    bn.movr   x11, x8++         /* w2 <= w[x8] */
+    bn.lid    x9, 0(x16++)      /* w3 <= dmem[x16] */
+    bn.subb   w2, w2, w3        /* w2 <= w2 - w3 - borrow */
+    bn.movr   x17++, x11        /* w[x17] <= w2 */
+  csrrs     x2, 0x7c0, x0       /* x2 <= CSR 0x7c0 (presumably FG0) */
+  /* TODO: currently we subtract the modulus if out_buf == M. This should
+            never happen in an RSA context. We could catch this and raise an
+            alert. */
+  andi      x2, x2, 1           /* keep bit 0 (presumably carry) */
+  li        x8, 4               /* default: store w4..w15 */
+  bne       x2, x0, e3_no_sub   /* bit set: keep default */
+  li        x8, 16              /* else: store w16..w27 */
+
+  e3_no_sub:
+
+  /* Store result in output buffer. */
+  addi      x21, x24, 0
+  loopi     12, 2
+    bn.sid    x8, 0(x21++)
+    addi      x8, x8, 1
+
+  ret
 
 /**
  * Variable time 3072-bit modular exponentiation with exponent 65537
@@ -350,8 +444,7 @@
  * Flags: Flags have no meaning beyond the scope of this subroutine.
  *
  * The base bignum A is expected in the input buffer, the result C is written
- * to the output buffer. Note, that the content of the input buffer is
- * modified during execution.
+ * to the output buffer.
  *
  * @param[in]  dmem[m0inv] pointer to m0' in dmem
  * @param[in]  dmem[rr] pointer to RR in dmem
@@ -365,24 +458,24 @@
  */
  .globl modexp_var_3072_f4
 modexp_var_3072_f4:
-  /* prepare all-zero reg */
+  /* Prepare all-zero reg. */
   bn.xor    w31, w31, w31
 
-  /* prepare pointers to temp regs */
+  /* Prepare pointers to temp regs. */
   li         x8, 4
   li         x9, 3
   li        x10, 4
   li        x11, 2
 
-  /* set pointers to buffers */
+  /* Set pointers to buffers. */
   la        x24, out_buf
   la        x16, in_mod
   la        x23, in_buf
   la        x26, rr
   la        x17, m0inv
 
-  /* convert input to Montgomery domain and store in dmem
-     dmem[out_buf] = montmul(dmem[in_buf], dmem[in_RR]) = A*R mod M */
+  /* Convert input to Montgomery domain and store in dmem.
+     dmem[out_buf] <= montmul(dmem[in_buf], dmem[rr]) = A*R mod M */
   addi      x19, x23, 0
   addi      x20, x26, 0
   addi      x21, x24, 0
@@ -394,7 +487,7 @@
   /* 16 consecutive Montgomery squares on the outbut buffer, i.e. after loop:
      dmem[out_buf] <= dmem[out_buf]^65536*R mod M */
   loopi     16, 8
-    /* dmem[out_buf]  = montmul(dmem[out_buf], dmem[out_buf]) */
+    /* dmem[out_buf]  <= montmul(dmem[out_buf], dmem[out_buf]) */
     addi      x19, x24, 0
     addi      x20, x24, 0
     addi      x21, x24, 0
@@ -404,7 +497,7 @@
       addi      x8, x8, 1
     nop
 
-  /* final multiplication and conversion of result from Montgomery domain
+  /* Final multiplication and conversion of result from Montgomery domain.
      out_buf  <= montmul(*x28, *x20) = montmul(dmem[in_buf], dmem[out_buf]) */
   addi      x19, x23, 0
   addi      x20, x24, 0
@@ -425,12 +518,12 @@
             alert. */
   andi      x2, x2, 1
   li        x8, 4
-  bne       x2, x0, no_sub
+  bne       x2, x0, f4_no_sub
   li        x8, 16
 
-  no_sub:
+  f4_no_sub:
 
-  /* store result in output buffer */
+  /* Store result in output buffer. */
   addi      x21, x24, 0
   loopi     12, 2
     bn.sid    x8, 0(x21++)
diff --git a/sw/otbn/crypto/run_rsa_verify_3072.s b/sw/otbn/crypto/run_rsa_verify_3072.s
index f172e73..d112520 100644
--- a/sw/otbn/crypto/run_rsa_verify_3072.s
+++ b/sw/otbn/crypto/run_rsa_verify_3072.s
@@ -75,8 +75,9 @@
   ecall
 
 modexp_3:
-  /* e=3 exponentiation is unimplemented */
-  unimp
+  /* Run modular exponentiation with exponent 3. */
+  jal      x1, modexp_var_3072_3
+  ecall
 
 .data
 
diff --git a/sw/otbn/crypto/run_rsa_verify_3072_rr_modexp.s b/sw/otbn/crypto/run_rsa_verify_3072_rr_modexp.s
index 4aa1bca..7fb3b3e 100644
--- a/sw/otbn/crypto/run_rsa_verify_3072_rr_modexp.s
+++ b/sw/otbn/crypto/run_rsa_verify_3072_rr_modexp.s
@@ -44,8 +44,10 @@
   ecall
 
 modexp_3:
-  /* e=3 exponentiation is unimplemented */
-  unimp
+  /* Run modular exponentiation with exponent 3. */
+  jal      x1, modexp_var_3072_3
+
+  ecall
 
 .data