[otbn] Rewrite CSR reads and writes in operation docs

Signed-off-by: Rupert Swarbrick <rswarbrick@lowrisc.org>
diff --git a/hw/ip/otbn/util/docs/get_impl.py b/hw/ip/otbn/util/docs/get_impl.py
index 7a36bc4..42407b8 100644
--- a/hw/ip/otbn/util/docs/get_impl.py
+++ b/hw/ip/otbn/util/docs/get_impl.py
@@ -148,42 +148,41 @@
 
         return make_aref(regfile_uname, getreg_idx)
 
-    def leave_Call(self,
-                   orig: cst.Call,
-                   updated: cst.Call) -> cst.BaseExpression:
+    @staticmethod
+    def _spot_reg_read(node: cst.Call) -> Optional[cst.BaseExpression]:
         # Detect
         #
         #    state.gprs.get_reg(FOO).read_unsigned()
         #    state.gprs.get_reg(FOO).read_signed()
         #
-        # and replace it with the expressions
+        # and replace with the expressions
         #
         #    GPRs[FOO]
         #    from_2s_complement(GPRs[FOO])
         #
         # respectively.
 
-        # In either case, we expect updated.func to be some long attribute
+        # In either case, we expect node.func to be some long attribute
         # (representing state.gprs.get_reg(FOO).read_X). For unsigned or
         # signed, we can check that it is indeed an Attribute and that
-        # updated.args is empty (neither function takes arguments).
-        if updated.args or not isinstance(updated.func, cst.Attribute):
-            return updated
+        # node.args is empty (neither function takes arguments).
+        if node.args or not isinstance(node.func, cst.Attribute):
+            return None
 
         # Now, check whether we're calling one of the functions we're
         # interested in.
-        if updated.func.attr.value == 'read_signed':
+        if node.func.attr.value == 'read_signed':
             signed = True
-        elif updated.func.attr.value == 'read_unsigned':
+        elif node.func.attr.value == 'read_unsigned':
             signed = False
         else:
-            return updated
+            return None
 
-        # Check that updated.func.value really does represent something of the
+        # Check that node.func.value really does represent something of the
         # form "state.gprs.get_reg(FOO)".
-        ret = ImplTransformer.match_get_reg(updated.func.value)
+        ret = ImplTransformer.match_get_reg(node.func.value)
         if ret is None:
-            return updated
+            return None
 
         if signed:
             # If this is a call to read_signed, we want to wrap the returned
@@ -193,26 +192,65 @@
         else:
             return ret
 
-    def leave_Expr(self,
-                   orig: cst.Expr,
-                   updated: cst.Expr) -> cst.BaseSmallStatement:
-        # This is called when leaving statements that are just expressions. We
-        # use it to spot
+    @staticmethod
+    def _spot_csr_read(node: cst.Call) -> Optional[cst.BaseExpression]:
+        # Detect
+        #
+        #    state.read_csr(FOO)
+        #
+        # and replace it with the expression
+        #
+        #    CSRs[FOO]
+
+        # Check we have exactly one argument
+        if len(node.args) != 1:
+            return None
+
+        # Check this is state.read_csr
+        if not (isinstance(node.func, cst.Attribute) and
+                isinstance(node.func.value, cst.Name) and
+                node.func.value.value == 'state' and
+                node.func.attr.value == 'read_csr'):
+            return None
+
+        return make_aref('CSRs', node.args[0].value)
+
+    def leave_Call(self,
+                   orig: cst.Call,
+                   updated: cst.Call) -> cst.BaseExpression:
+        # Handle:
+        #
+        #    state.gprs.get_reg(FOO).read_unsigned()
+        #    state.gprs.get_reg(FOO).read_signed()
+        #
+        reg_read = ImplTransformer._spot_reg_read(updated)
+        if reg_read is not None:
+            return reg_read
+
+        csr_read = ImplTransformer._spot_csr_read(updated)
+        if csr_read is not None:
+            return csr_read
+
+        return updated
+
+    @staticmethod
+    def _spot_reg_write(node: cst.Expr) -> Optional[NBAssign]:
+        # Spot
         #
         #   state.gprs.get_reg(foo).write_unsigned(bar)
         #   state.gprs.get_reg(foo).write_signed(bar)
         #
-        # and turn it into
+        # and turn them into
         #
         #   GPRs[FOO] = bar
         #   GPRs[FOO] = to_2s_complement(bar)
 
-        if not isinstance(updated.value, cst.Call):
-            return updated
+        if not isinstance(node.value, cst.Call):
+            return None
 
-        call = updated.value
+        call = node.value
         if len(call.args) != 1 or not isinstance(call.func, cst.Attribute):
-            return updated
+            return None
 
         value = call.args[0].value
 
@@ -222,16 +260,56 @@
             rhs = cst.Call(func=cst.Name('to_2s_complement'),
                            args=[cst.Arg(value=value)])
         else:
-            return updated
+            return None
 
         # We expect call.func.value to be match state.gprs.get_reg(foo).
         # Extract the array reference if we can.
         reg_ref = ImplTransformer.match_get_reg(call.func.value)
         if reg_ref is None:
-            return updated
+            return None
 
         return NBAssign.make(reg_ref, rhs)
 
+    @staticmethod
+    def _spot_csr_write(node: cst.Expr) -> Optional[NBAssign]:
+        # Spot
+        #
+        #   state.write_csr(csr, new_val)
+        #
+        # and turn it into
+        #
+        #   CSRs[csr] = new_val
+
+        if not isinstance(node.value, cst.Call):
+            return None
+
+        call = node.value
+        if len(call.args) != 2 or not isinstance(call.func, cst.Attribute):
+            return None
+
+        if not (isinstance(call.func.value, cst.Name) and
+                call.func.value.value == 'state' and
+                call.func.attr.value == 'write_csr'):
+            return None
+
+        idx = call.args[0].value
+        rhs = call.args[1].value
+
+        return NBAssign.make(make_aref('CSRs', idx), rhs)
+
+    def leave_Expr(self,
+                   orig: cst.Expr,
+                   updated: cst.Expr) -> cst.BaseSmallStatement:
+        reg_write = ImplTransformer._spot_reg_write(updated)
+        if reg_write is not None:
+            return reg_write
+
+        csr_write = ImplTransformer._spot_csr_write(updated)
+        if csr_write is not None:
+            return csr_write
+
+        return updated
+
     def leave_Assign(self,
                      orig: cst.Assign,
                      updated: cst.Assign) -> cst.BaseSmallStatement: