# blob: ddef158c6f8172b2877246ea211b85f4687f4514 [file] [log] [blame]
# Copyright lowRISC contributors.
# Licensed under the Apache License, Version 2.0, see LICENSE for details.
# SPDX-License-Identifier: Apache-2.0
from random import getrandbits
from typing import List, Optional, Tuple, cast
from attrdict import AttrDict # type: ignore
from riscvmodel.model import (Model, State, # type: ignore
Environment, TerminateException)
from riscvmodel.isa import Instruction # type: ignore
from riscvmodel.types import (RegisterFile, Register, # type: ignore
SingleRegister, Trace, BitflagRegister)
from .variant import RV32IXotbn
class TraceCallStackPush(Trace):  # type: ignore
    '''Trace entry describing a value pushed onto the call stack.

    Rendered as "RAS push" followed by the value as 8 hex digits.
    '''

    def __init__(self, value: int):
        self.value = value

    def __str__(self) -> str:
        template = "RAS push {:08x}"
        return template.format(self.value)
class TraceCallStackPop(Trace):  # type: ignore
    '''Trace entry describing a value popped from the call stack.

    Rendered as "RAS pop" followed by the value as 8 hex digits.
    '''

    def __init__(self, value: int):
        self.value = value

    def __str__(self) -> str:
        template = "RAS pop {:08x}"
        return template.format(self.value)
class TraceLoopStart(Trace):  # type: ignore
    '''Trace entry emitted when a loop is started.

    iterations is the requested loop count; bodysize is the number of
    instructions in the loop body.
    '''

    def __init__(self, iterations: int, bodysize: int):
        self.iterations = iterations
        self.bodysize = bodysize

    def __str__(self) -> str:
        fmt = "Start LOOP, {} iterations, bodysize: {}"
        return fmt.format(self.iterations, self.bodysize)
class TraceLoopIteration(Trace):  # type: ignore
    '''Trace entry emitted when a loop jumps back to start another iteration.

    iteration is the 1-based number of the iteration being started and total
    is the loop's overall iteration count, so the entry renders as e.g.
    "LOOP iteration 2/3".
    '''

    def __init__(self, iteration: int, total: int):
        # Store both values separately (they were previously chained into a
        # single assignment, losing `iteration` and never setting `total`).
        self.iteration = iteration
        self.total = total

    def __str__(self) -> str:
        return "LOOP iteration {}/{}".format(self.iteration, self.total)
class OTBNIntRegisterFile(RegisterFile):  # type: ignore
    '''The 32 x 32-bit integer register file, with x1 backed by a call stack.

    x0 is immutable zero. Accesses to x1 are not applied directly: reads and
    writes are recorded during execution and folded into self.callstack by
    post_insn(), which then makes x1 show the top of that stack.
    '''

    def __init__(self) -> None:
        super().__init__(num=32, bits=32, immutable={0: 0})

        # The call stack for x1 and its pending updates
        self.callstack = []  # type: List[int]
        self.have_read_callstack = False
        self.callstack_push_val = None  # type: Optional[int]

    def __setitem__(self, key: int, value: int) -> None:
        # Special handling for the callstack in x1: just remember the value to
        # push; post_insn() applies it and updates the register itself.
        if key == 1:
            self.callstack_push_val = value
            return

        # Otherwise, use the base class implementation
        super().__setitem__(key, value)

    def __getitem__(self, key: int) -> int:
        # Special handling for the callstack in x1: note the read so that
        # post_insn() pops the stack. The returned value is the current x1,
        # which post_insn() keeps equal to the top of the stack.
        if key == 1:
            self.have_read_callstack = True

        return cast(int, super().__getitem__(key))

    def post_insn(self) -> None:
        '''Update the x1 call stack after an instruction executes

        This needs to run after execution (which sets up callstack_push_val and
        have_read_callstack) but before we print the instruction in
        State.issue, because any changes to x1 need to be reflected there.

        A read of x1 pops the stack (if it is non-empty) and a write pushes the
        written value. The pop is applied first, so an instruction that both
        reads and writes x1 replaces the top entry.
        '''
        # NOTE(review): the pop()/append() calls below were reconstructed;
        # without them the stack could never change. Confirm against upstream.
        cs_changed = False
        if self.have_read_callstack and self.callstack:
            self.callstack.pop()
            cs_changed = True

        if self.callstack_push_val is not None:
            self.callstack.append(self.callstack_push_val)
            cs_changed = True

        # Update self.regs[1] so that it always points at the top of the stack
        # (the last element, since pushes append). If the stack is empty, set
        # it to zero (we need to decide what happens in this case: see issue
        # #3239)
        if cs_changed:
            cs_val = self.callstack[-1] if self.callstack else 0
            super().__setitem__(1, cs_val)

        self.have_read_callstack = False
        self.callstack_push_val = None
class LoopLevel:
    '''An object representing a level in the current loop stack

    start_addr is the first instruction inside the loop (the instruction
    following the loop instruction). insn_count is the number of instructions
    in the loop (and must be positive). restarts is one less than the number of
    iterations, and must be positive.

    '''
    def __init__(self, start_addr: int, insn_count: int, restarts: int):
        assert 0 <= start_addr
        assert 0 < insn_count
        assert 0 < restarts

        # Total number of iterations, used for 1-based iteration numbering
        self.loop_count = 1 + restarts
        # Back edges still to be taken; decremented on each jump back
        self.restarts_left = restarts
        self.start_addr = start_addr
        # Address just past the last body instruction (4-byte instructions):
        # reaching it means the end of the loop body.
        self.match_addr = start_addr + 4 * insn_count
class LoopStack:
    '''An object representing the loop stack

    An entry on the loop stack represents a possible back edge: the
    restarts_left counter tracks the number of these back edges. The entry is
    removed when the counter gets to zero.

    '''
    def __init__(self) -> None:
        self.stack = []  # type: List[LoopLevel]
        # Trace entries produced since the last commit()
        self.trace = []  # type: List[Trace]

    def start_loop(self,
                   next_addr: int,
                   insn_count: int,
                   loop_count: int) -> Optional[int]:
        '''Start a loop.

        Adds the loop to the stack and returns the next PC if it's not
        straight-line. If the loop count is one, this acts as a NOP (and
        doesn't change the stack). If the loop count is zero, this doesn't
        change the stack but the next PC will be the match address.

        '''
        assert 0 <= next_addr
        assert 0 < insn_count
        assert 0 <= loop_count

        self.trace.append(TraceLoopStart(loop_count, insn_count))

        if loop_count == 0:
            # Skip the whole body: jump to the first instruction after it.
            return next_addr + 4 * insn_count

        if loop_count > 1:
            self.stack.append(LoopLevel(next_addr, insn_count, loop_count - 1))

        return None

    def step(self, cur_pc: int) -> int:
        '''Calculate the next PC and update loop stack

        If the instruction at cur_pc is the last one in the innermost loop
        body, take one back edge and return the loop's start address.
        Otherwise return cur_pc + 4.

        '''
        next_pc = cur_pc + 4
        if not self.stack:
            return next_pc

        top = self.stack[-1]
        if next_pc != top.match_addr:
            return next_pc

        assert top.restarts_left > 0
        top.restarts_left -= 1
        if not top.restarts_left:
            # That was the last back edge: the final iteration runs straight
            # through to match_addr, so this level can be dropped now.
            self.stack.pop()

        # 1-based number of the iteration we are about to start
        idx = top.loop_count - top.restarts_left
        self.trace.append(TraceLoopIteration(idx, top.loop_count))
        return top.start_addr

    def changes(self) -> List[Trace]:
        return self.trace

    def commit(self) -> None:
        self.trace = []
class FlagGroups:
    '''The two flag groups, FG0 and FG1, each holding flags C, L, M and Z.

    Groups are addressed by index 0 or 1.
    '''

    def __init__(self) -> None:
        self.groups = {
            0: BitflagRegister(["C", "L", "M", "Z"], prefix="FG0."),
            1: BitflagRegister(["C", "L", "M", "Z"], prefix="FG1."),
        }

    def __getitem__(self, key: int) -> BitflagRegister:
        assert 0 <= key <= 1
        return self.groups[key]

    def __setitem__(self, key: int, value: int) -> None:
        assert 0 <= key <= 1
        # NOTE(review): the write was missing (assignments were silently
        # dropped). This assumes BitflagRegister exposes set() like the
        # Register objects used elsewhere in this file. Confirm.
        self.groups[key].set(value)

    def changes(self) -> List[Trace]:
        return cast(List[Trace],
                    self.groups[0].changes() + self.groups[1].changes())

    def commit(self) -> None:
        # Commit both groups, mirroring changes()
        for group in self.groups.values():
            group.commit()
class OTBNState(State):  # type: ignore
    '''Architectural state of the OTBN simulator.

    Extends riscvmodel's State with:
    - a 32 x 256-bit wide register file (wreg);
    - the ACC and MOD special registers;
    - the two flag groups;
    - the hardware loop stack.
    It also installs OTBNIntRegisterFile as the integer register file.
    '''
    def __init__(self) -> None:
        # NOTE(review): the base constructor provides attributes used
        # elsewhere (e.g. self.memory in OTBNModel). Passing RV32IXotbn as the
        # variant is assumed. Confirm against riscvmodel's State signature.
        super().__init__(RV32IXotbn)

        # Hack: this matches the superclass constructor, but you need it to
        # explain to mypy what self.pc is (because mypy can't peek into
        # riscvmodel without throwing up lots of errors)
        self.pc = Register(32)

        self.intreg = OTBNIntRegisterFile()
        self.wreg = RegisterFile(num=32, bits=256, immutable={}, prefix="w")
        self.single_regs = {
            'acc': SingleRegister(256, "ACC"),
            'mod': SingleRegister(256, "MOD"),
        }
        self.flags = FlagGroups()
        self.loop_stack = LoopStack()

    def csr_read(self, index: int) -> int:
        '''Read a 32-bit CSR.

        0x7D0-0x7D7 return successive 32-bit slices of MOD (0x7D0 is the
        least significant). 0xFC0 returns 32 random bits. Other indices,
        apart from 0x7C0, are delegated to the base class.
        '''
        if index == 0x7C0:
            # NOTE(review): converting the whole wide register file to an int
            # looks wrong for this CSR (presumably the flags). Confirm.
            return int(self.wreg)

        if 0x7D0 <= index <= 0x7D7:
            bit_shift = 32 * (index - 0x7D0)
            mask32 = (1 << 32) - 1
            return (int(self.single_regs['mod']) >> bit_shift) & mask32

        if index == 0xFC0:
            return getrandbits(32)

        # super() already binds self, so it must not be passed again.
        return cast(int, super().csr_read(index))

    def wcsr_read(self, index: int) -> int:
        '''Read a 256-bit WSR: 0 is MOD, 1 is random, 2 is ACC.'''
        assert 0 <= index <= 2
        if index == 0:
            return int(self.single_regs['mod'])
        if index == 1:
            return getrandbits(256)
        assert index == 2
        return int(self.single_regs['acc'])

    def wcsr_write(self, index: int, value: int) -> None:
        '''Write a 256-bit WSR, using the same numbering as wcsr_read.

        Writes to index 1 (the random source) and unknown indices are ignored.
        '''
        # NOTE(review): assumes SingleRegister exposes set() like Register.
        # The ACC write mirrors wcsr_read's index map. Confirm ACC is writable.
        if index == 0:
            self.single_regs['mod'].set(value)
        elif index == 2:
            self.single_regs['acc'].set(value)

    def loop_start(self, iterations: int, bodysize: int) -> None:
        '''Start a loop whose body begins at the instruction after this one.

        A zero iteration count redirects the PC past the body.
        '''
        # Assumes int(self.pc) is still the current instruction's address,
        # i.e. PC updates are deferred until commit().
        next_pc = int(self.pc) + 4
        skip_pc = self.loop_stack.start_loop(next_pc, bodysize, iterations)
        if skip_pc is not None:
            # NOTE(review): body reconstructed. Confirm the PC update API.
            self.pc.set(skip_pc)

    def loop_step(self) -> None:
        '''Let the loop stack take a back edge after an instruction.'''
        # NOTE(review): body reconstructed. Only a back edge changes the PC,
        # so any PC update made by the instruction itself is left alone for
        # straight-line execution.
        cur_pc = int(self.pc)
        next_pc = self.loop_stack.step(cur_pc)
        if next_pc != cur_pc + 4:
            self.pc.set(next_pc)

    def changes(self) -> List[Trace]:
        c = cast(List[Trace], super().changes())
        c += self.loop_stack.changes()
        c += self.wreg.changes()
        c += self.flags.changes()
        # Sorted by name so the trace order is deterministic
        for _, reg in sorted(self.single_regs.items()):
            c += reg.changes()
        return c

    def commit(self) -> None:
        # Commit every component whose changes() feed into changes() above.
        # NOTE(review): body partly reconstructed.
        super().commit()
        self.loop_stack.commit()
        self.wreg.commit()
        self.flags.commit()
        for reg in self.single_regs.values():
            reg.commit()
class OTBNEnvironment(Environment):  # type: ignore
    '''riscvmodel environment hook for OTBN.

    Any call into the environment stops the simulation with exit code 0.
    NOTE(review): presumably invoked by riscvmodel for ECALL. Confirm.
    '''

    def call(self, state: OTBNState) -> None:
        exit_code = 0
        raise TerminateException(exit_code)
class OTBNModel(Model):  # type: ignore
    '''riscvmodel Model specialised for OTBN.

    Adds helpers for the 256-bit wide registers and memory, and overrides
    issue() to apply call stack and loop stack updates.
    '''
    def __init__(self, verbose: bool):
        # NOTE(review): the base constructor is what sets self.verbose (read
        # in issue()). Keyword names follow riscvmodel's Model. Confirm.
        super().__init__(RV32IXotbn,
                         environment=OTBNEnvironment(),
                         verbose=verbose)
        self.state = OTBNState()

    def get_wr_quarterword(self, wridx: int, qwsel: int) -> int:
        '''Return 64-bit quarter qwsel (0 = least significant) of wreg[wridx].'''
        assert 0 <= wridx <= 31
        assert 0 <= qwsel <= 3
        mask = (1 << 64) - 1
        return (int(self.state.wreg[wridx]) >> (qwsel * 64)) & mask

    def set_wr_halfword(self, wridx: int, value: int, hwsel: int) -> None:
        '''Replace one 128-bit half of wreg[wridx] (hwsel 1 = upper half).'''
        assert 0 <= wridx <= 31
        assert (value >> 128) == 0
        assert 0 <= hwsel <= 1

        # Keep the half that is *not* being written
        mask = ((1 << 128) - 1) << (0 if hwsel else 128)
        curr = int(self.state.wreg[wridx]) & mask
        valpos = value << 128 if hwsel else value
        self.state.wreg[wridx].set(curr | valpos)

    def load_wlen_word_from_memory(self, addr: int) -> int:
        '''Load a 256-bit word as eight 32-bit loads, least significant first.'''
        assert 0 <= addr
        word = 0
        for byte_off in range(0, 32, 4):
            bit_off = byte_off * 8
            word += cast(int, self.state.memory.lw(addr + byte_off)) << bit_off
        return word

    def store_wlen_word_to_memory(self, addr: int, word: int) -> None:
        '''Store a 256-bit word as eight 32-bit stores, least significant first.'''
        assert 0 <= addr
        assert 0 <= word
        assert (word >> 256) == 0
        mask32 = (1 << 32) - 1
        for byte_off in range(0, 32, 4):
            bit_off = byte_off * 8
            self.state.memory.sw(addr + byte_off, (word >> bit_off) & mask32)

    @staticmethod
    def add_with_carry(a: int, b: int, carry_in: int) -> Tuple[int, AttrDict]:
        '''Compute a + b + carry_in as a 256-bit addition.

        Returns the sum truncated to 256 bits and an AttrDict of flags:
        - C: carry out of bit 255;
        - L: bit 0 of the result;
        - M: bit 255 of the result;
        - Z: 1 when the truncated result is zero.
        '''
        full = a + b + carry_in
        result = full & ((1 << 256) - 1)
        flags_out = AttrDict({"C": (full >> 256) & 1,
                              "L": result & 1,
                              "M": (result >> 255) & 1,
                              # Z must look at the 256-bit result: a sum that
                              # wraps to exactly 2**256 is zero with carry.
                              "Z": 1 if result == 0 else 0})
        return (result, flags_out)

    def issue(self, insn: Instruction) -> List[Trace]:
        '''An overridden version of riscvmodel's Model.issue

        We have to override this to allow the loop stack to jump in between
        instruction execution and calculating the trace of changes.

        '''
        self.state.pc += 4
        # NOTE(review): the steps between the PC increment and changes() were
        # reconstructed:
        # - execute the instruction;
        # - settle the x1 call stack (post_insn must run before the trace is
        #   printed);
        # - let the loop stack redirect the PC.
        insn.execute(self)
        self.state.intreg.post_insn()
        self.state.loop_step()
        trace = self.state.changes()
        # self.verbose is False or a callable that receives the trace line
        if self.verbose is not False:
            self.verbose("{:20} | {}".format(
                str(insn), ", ".join([str(t) for t in trace])))
        self.state.commit()
        return trace