[otbn] Rewrite CSR reads and writes in operation docs
Signed-off-by: Rupert Swarbrick <rswarbrick@lowrisc.org>
diff --git a/hw/ip/otbn/util/docs/get_impl.py b/hw/ip/otbn/util/docs/get_impl.py
index 7a36bc4..42407b8 100644
--- a/hw/ip/otbn/util/docs/get_impl.py
+++ b/hw/ip/otbn/util/docs/get_impl.py
@@ -148,42 +148,41 @@
return make_aref(regfile_uname, getreg_idx)
- def leave_Call(self,
- orig: cst.Call,
- updated: cst.Call) -> cst.BaseExpression:
+ @staticmethod
+ def _spot_reg_read(node: cst.Call) -> Optional[cst.BaseExpression]:
# Detect
#
# state.gprs.get_reg(FOO).read_unsigned()
# state.gprs.get_reg(FOO).read_signed()
#
- # and replace it with the expressions
+ # and replace with the expressions
#
# GPRs[FOO]
# from_2s_complement(GPRs[FOO])
#
# respectively.
- # In either case, we expect updated.func to be some long attribute
+ # In either case, we expect node.func to be some long attribute
# (representing state.gprs.get_reg(FOO).read_X). For unsigned or
# signed, we can check that it is indeed an Attribute and that
- # updated.args is empty (neither function takes arguments).
- if updated.args or not isinstance(updated.func, cst.Attribute):
- return updated
+ # node.args is empty (neither function takes arguments).
+ if node.args or not isinstance(node.func, cst.Attribute):
+ return None
# Now, check whether we're calling one of the functions we're
# interested in.
- if updated.func.attr.value == 'read_signed':
+ if node.func.attr.value == 'read_signed':
signed = True
- elif updated.func.attr.value == 'read_unsigned':
+ elif node.func.attr.value == 'read_unsigned':
signed = False
else:
- return updated
+ return None
- # Check that updated.func.value really does represent something of the
+ # Check that node.func.value really does represent something of the
# form "state.gprs.get_reg(FOO)".
- ret = ImplTransformer.match_get_reg(updated.func.value)
+ ret = ImplTransformer.match_get_reg(node.func.value)
if ret is None:
- return updated
+ return None
if signed:
# If this is a call to read_signed, we want to wrap the returned
@@ -193,26 +192,65 @@
else:
return ret
- def leave_Expr(self,
- orig: cst.Expr,
- updated: cst.Expr) -> cst.BaseSmallStatement:
- # This is called when leaving statements that are just expressions. We
- # use it to spot
+ @staticmethod
+ def _spot_csr_read(node: cst.Call) -> Optional[cst.BaseExpression]:
+ # Detect
+ #
+ # state.read_csr(FOO)
+ #
+ # and replace it with the expression
+ #
+ # CSRs[FOO]
+
+ # Check we have exactly one argument
+ if len(node.args) != 1:
+ return None
+
+ # Check this is state.read_csr
+ if not (isinstance(node.func, cst.Attribute) and
+ isinstance(node.func.value, cst.Name) and
+ node.func.value.value == 'state' and
+ node.func.attr.value == 'read_csr'):
+ return None
+
+ return make_aref('CSRs', node.args[0].value)
+
+ def leave_Call(self,
+ orig: cst.Call,
+ updated: cst.Call) -> cst.BaseExpression:
+ # Handle:
+ #
+ # state.gprs.get_reg(FOO).read_unsigned()
+ # state.gprs.get_reg(FOO).read_signed()
+ #
+ reg_read = ImplTransformer._spot_reg_read(updated)
+ if reg_read is not None:
+ return reg_read
+
+ csr_read = ImplTransformer._spot_csr_read(updated)
+ if csr_read is not None:
+ return csr_read
+
+ return updated
+
+ @staticmethod
+ def _spot_reg_write(node: cst.Expr) -> Optional[NBAssign]:
+ # Spot
#
# state.gprs.get_reg(foo).write_unsigned(bar)
# state.gprs.get_reg(foo).write_signed(bar)
#
- # and turn it into
+ # and turn them into
#
# GPRs[FOO] = bar
# GPRs[FOO] = to_2s_complement(bar)
- if not isinstance(updated.value, cst.Call):
- return updated
+ if not isinstance(node.value, cst.Call):
+ return None
- call = updated.value
+ call = node.value
if len(call.args) != 1 or not isinstance(call.func, cst.Attribute):
- return updated
+ return None
value = call.args[0].value
@@ -222,16 +260,56 @@
rhs = cst.Call(func=cst.Name('to_2s_complement'),
args=[cst.Arg(value=value)])
else:
- return updated
+ return None
# We expect call.func.value to be match state.gprs.get_reg(foo).
# Extract the array reference if we can.
reg_ref = ImplTransformer.match_get_reg(call.func.value)
if reg_ref is None:
- return updated
+ return None
return NBAssign.make(reg_ref, rhs)
+ @staticmethod
+ def _spot_csr_write(node: cst.Expr) -> Optional[NBAssign]:
+ # Spot
+ #
+ # state.write_csr(csr, new_val)
+ #
+ # and turn it into
+ #
+ # CSRs[csr] = new_val
+
+ if not isinstance(node.value, cst.Call):
+ return None
+
+ call = node.value
+ if len(call.args) != 2 or not isinstance(call.func, cst.Attribute):
+ return None
+
+ if not (isinstance(call.func.value, cst.Name) and
+ call.func.value.value == 'state' and
+ call.func.attr.value == 'write_csr'):
+ return None
+
+ idx = call.args[0].value
+ rhs = call.args[1].value
+
+ return NBAssign.make(make_aref('CSRs', idx), rhs)
+
+ def leave_Expr(self,
+ orig: cst.Expr,
+ updated: cst.Expr) -> cst.BaseSmallStatement:
+ reg_write = ImplTransformer._spot_reg_write(updated)
+ if reg_write is not None:
+ return reg_write
+
+ csr_write = ImplTransformer._spot_csr_write(updated)
+ if csr_write is not None:
+ return csr_write
+
+ return updated
+
def leave_Assign(self,
orig: cst.Assign,
updated: cst.Assign) -> cst.BaseSmallStatement: