fetch and sched: models for fetch and dispatch queues

Also includes the TBM model of a queue, and a class for reading traces.

Change-Id: I9dbcb5fe94611935e1f0b61d3f2da38d522c71a5
diff --git a/buffered_queue.py b/buffered_queue.py
new file mode 100644
index 0000000..fe3e592
--- /dev/null
+++ b/buffered_queue.py
@@ -0,0 +1,96 @@
+"""Queue Module."""
+
+import collections
+import itertools
+from typing import Optional, Sequence, TypeVar
+
+import interfaces
+
+
+# Declare type variable
+T = TypeVar('T')
+
+
+# pylint: disable-next=abstract-method
+# __len__ and __iter__ are implemented by deque
+class BufferedQueue(collections.deque[T], interfaces.ConsumableQueue[T]):
+    """Queue model.
+
+    For the owner this is a buffered queue; for the consumer this is
+    ConsumableQueue. Elements that should not be visible from outside (e.g.
+    during the computation of next state) can be added to the queue using
+    `buffer(e)`. Call `flush` to make them visible.
+    """
+
+    def __init__(self, size: Optional[int]) -> None:
+        """Construct a queue.
+
+        Args:
+          size: the size of the queue. `None` (or -1) for infinite queue.
+        """
+        super().__init__([])
+        self._size = size if size != -1 else None
+        self._buff = collections.deque()
+
+    def is_buffer_full(self) -> bool:
+        if self._size is not None:
+            return len(self) + len(self._buff) >= self._size
+
+        return False
+
+    def buffer(self, item) -> None:
+        self._buff.append(item)
+
+    def flush(self) -> None:
+        if self._size is None or len(self) + len(self._buff) <= self._size:
+            self.extend(self._buff)
+            self._buff.clear()
+        else:
+            for _ in range(self._size - len(self)):
+                self.append(self._buff.popleft())
+
+    def chain(self):
+        return itertools.chain(self, self._buff)
+
+    def pp_three_valued(self, vals: Sequence[str]) -> str:
+        if self.is_buffer_full():
+            # Full
+            return vals[2]
+
+        if any(self.chain()):
+            # Partial
+            return vals[1]
+
+        # Empty
+        return vals[0]
+
+    # Implements interfaces.ConsumableQueue
+    @property
+    def size(self) -> Optional[int]:
+        return self._size
+
+    # Implements interfaces.ConsumableQueue
+    def full(self) -> bool:
+        return self._size is not None and len(self) >= self._size
+
+    # Implements interfaces.ConsumableQueue
+    def dequeue(self) -> Optional[T]:
+        return self.popleft()
+
+    # Implements interfaces.ConsumableQueue
+    def peek(self) -> Optional[T]:
+        return self[0] if self else None
+
+    # Implements interfaces.ConsumableQueue
+    # The following is useless, but without it pylint will issue an error
+    # (E0110) for every instantiation of BufferedQueue.
+    # pylint: disable-next=useless-parent-delegation
+    def __len__(self) -> int:
+        return super().__len__()
+
+    # Implements interfaces.ConsumableQueue
+    # The following is useless, but without it pylint will issue an error
+    # (E0110) for every instantiation of BufferedQueue.
+    # pylint: disable-next=useless-parent-delegation
+    def __iter__(self):
+        return super().__iter__()
diff --git a/fetch_unit.py b/fetch_unit.py
new file mode 100644
index 0000000..383e5ae
--- /dev/null
+++ b/fetch_unit.py
@@ -0,0 +1,216 @@
+"""Fetch Unit module."""
+
+from typing import Any, Dict, Optional, Sequence
+
+from buffered_queue import BufferedQueue
+import counter
+from counter import Counter
+from functional_trace import FunctionalTrace
+import interfaces
+
+class NextFetch:
+    """Hold the sate of next-addr fetching.
+
+    `addr` is the memory location from which the next batch of instructions
+    should be fetched from. This can be None if there are no more instructions
+    in the trace, or when the next instruction (after a branch) is not the
+    normal +4 bytes successor.
+    """
+
+    def __init__(self) -> None:
+        self._addr = None
+        self._stall = False
+
+    @property
+    def addr(self) -> Optional[int]:
+        return self._addr
+
+    @addr.setter
+    def addr(self, val: int) -> None:
+        self._addr = val
+        self._stall = False
+
+    @property
+    def stall(self) -> bool:
+        return self._stall
+
+    @stall.setter
+    def stall(self, val: bool) -> None:
+        self._addr = None
+        self._stall = val
+
+
+class FetchUnit(interfaces.FetchUnit):
+    def __init__(self, config: Dict[str, Any], trace: FunctionalTrace):
+        super().__init__("FE")
+
+        self._trace = trace
+        self._branch_prediction = config["branch_prediction"]
+        self._fetch_rate = config["fetch_rate"]
+
+
+        ## Current state
+        # The queue from which `SchedUnit` reads.
+        self._queue = BufferedQueue(config.get("fetch_queue_size"))
+
+        # The next address to fetch a batch from, or indicate a stall (waiting
+        # for branch target to be computed).
+        self._next_fetch_addr = NextFetch()
+
+        ## Next state
+        self._next_fetch_stall = None
+
+    # Implements interfaces.FetchUnit
+    @property
+    def queue(self) -> interfaces.ConsumableQueue:
+        return self._queue
+
+    # Implements interfaces.FetchUnit
+    def eof(self) -> bool:
+        return self._trace.eof()
+
+    # Implements interfaces.FetchUnit
+    def pending(self) -> int:
+        return len(self._queue)
+
+    # Implements interfaces.FetchUnit
+    def reset(self, cntr: Counter) -> None:
+        super().reset(cntr)
+        # TODO(sflur): implement proper reset
+        cntr.stalls[self.name] = 0
+        cntr.utilizations[self.name] = counter.Utilization(self.queue.size)
+
+    # Implements interfaces.FetchUnit
+    def tick(self, cntr: Counter) -> None:
+        super().tick(cntr)
+
+        if self._trace.eof():
+            self.log("can't fetch new instructions:"
+                     " no more instructions in trace.")
+            return
+
+        if (self._queue.size is not None and
+                len(self._queue) + self._fetch_rate >
+                self._queue.size):
+            self.log("can't fetch new instructions:"
+                     " not enough room in the fetch queue.")
+            cntr.stalls[self.name] += 1
+            return
+
+        # TODO(sflur): make `inst_size` configurable.
+        inst_size = 4  # bytes
+
+        if self._next_fetch_addr.addr is not None:
+            if self._trace.next_addr() != self._next_fetch_addr.addr:
+                if self._branch_prediction == "none":
+                    self.log(
+                        "generating memory accesses for"
+                        f" {self._next_fetch_addr.addr} (but next trace"
+                        f" instruction is at {self._trace.next_addr()})")
+
+                    # TODO(sflur): generate memory accesses for the whole batch.
+
+                    self._next_fetch_addr.stall = True
+                    return
+
+                assert self._branch_prediction == "perfect", (
+                        # pylint: disable-next=consider-using-f-string
+                        "Error: Unknown branch prediction option %s" %
+                        self._branch_prediction)
+
+        elif self._next_fetch_addr.stall:
+            self.log("stalling")
+            cntr.stalls[self.name] += 1
+            return
+
+        # The first address of the current batch. After a branch this might not
+        # be properly aligned. We should still generate memory accesses for the
+        # missing lower bytes!
+        fetch_addr = self._trace.next_addr()
+        # TODO(sflur): generate memory accesses for the whole batch.
+
+        # TODO(sflur): handle compressed instructions, and misaligned
+        # instructions?
+
+        # Set the address for the next batch, and force it to be aligned.
+        next_addr = fetch_addr + (inst_size * self._fetch_rate)
+        next_addr -= next_addr % (inst_size * self._fetch_rate)
+        self._next_fetch_addr.addr = next_addr
+
+        # Buffer the current batch of instructions.
+        for fetch_addr in range(fetch_addr, next_addr, inst_size):
+            if fetch_addr != self._trace.next_addr():
+                # This instruction was not executed in the functional trace,
+                # hence it's not in the trace. But, a uarch would fetch this
+                # instruction from memory, and it would occupy a place in the
+                # queue, so we simulate that (with a None).
+                self._queue.buffer(None)
+                continue
+
+            inst = self._trace.dequeue()
+            if inst is None:
+                self.log("no more instructions in trace")
+                break
+
+            self.log(inst.mnemonic + " from mem/trace")
+            self._queue.buffer(inst)
+
+            if (not inst.is_branch and
+                    inst.addr + inst_size != self._trace.next_addr()):
+                # This could happen when an exception is taken
+                # TODO(sflur): what do we need to do to handle an exception?
+                self.log("next fetch is an exception handler?")
+                self._next_fetch_addr.addr = self._trace.next_addr()
+
+        # We count all the instructions a uarch would actually fetch.
+        cntr.utilizations[self.name].count += self._fetch_rate
+
+    # Implements interfaces.FetchUnit
+    def tock(self, cntr: Counter) -> None:
+        super().tock(cntr)
+
+        self._queue.flush()
+
+        if self._next_fetch_stall is not None:
+            self._next_fetch_addr.stall = self._next_fetch_stall
+            self._next_fetch_stall = None
+
+        cntr.utilizations[self.name].occupied += len(self._queue)
+
+    # Implements interfaces.FetchUnit
+    def branch_resolved(self) -> None:
+        """Inform the FU that branch target is now avilable."""
+        self.log("branch resolved")
+
+        assert self._branch_prediction == "none"
+
+        # The branch target might have already been placed in the the fetch
+        # queue, so we only clean Nones (fake instructions) from the queue.
+        while self._queue.buff and self._queue.buff[0] is None:
+            self._queue.buff.popleft()
+        while self._queue and self._queue[0] is None:
+            self._queue.popleft()
+
+        if self.phase == interfaces.CyclePhase.TICK:
+            self._next_fetch_stall = False
+        else:
+            assert self.phase == interfaces.CyclePhase.TOCK
+            self._next_fetch_addr.stall = False
+
+    # Implements interfaces.FetchUnit
+    def print_state_detailed(self, file) -> None:
+        if self._queue:
+            queue_str = ", ".join(str(i) if i else "X"
+                                  # pylint: disable-next=bad-reversed-sequence
+                                  for i in reversed(self._queue))
+        else:
+            queue_str = "-"
+        print(f"[{self.name}] {queue_str}", file=file)
+
+    # Implements interfaces.FetchUnit
+    def get_state_three_valued_header(self) -> Sequence[str]:
+        return [self.name]
+
+    # Implements interfaces.FetchUnit
+    def get_state_three_valued(self,vals: Sequence[str]) ->  Sequence[str]:
+        return [self._queue.pp_three_valued(vals)]
diff --git a/functional_trace.py b/functional_trace.py
new file mode 100644
index 0000000..3c1da0f
--- /dev/null
+++ b/functional_trace.py
@@ -0,0 +1,120 @@
+"""Trace module."""
+
+from typing import Any, IO, Optional
+
+# Generated by `flatc`.
+import FBInstruction.Instructions as FBInstrs
+from instruction import Instruction
+import tbm_options
+from utilities import FileFormat
+
+class FunctionalTrace:
+    """Class representing a functional simulator trace."""
+
+    def __init__(self, input_file: IO[Any], input_format: FileFormat,
+                 instructions_range: str) -> None:
+        self.input_file = input_file
+
+        self.input_format = input_format
+        if input_format == FileFormat.JSON:
+            self.read_instructions = self.read_json_instructions
+        else:
+            assert input_format == FileFormat.FLATBUFFERS
+            self.read_instructions = self.read_fb_instructions
+
+        start, _, end = instructions_range.partition(":")
+        self.end = int(end) if end else None
+
+        self.instr_count = 0
+        self.instrs = []
+        self.read_instructions()
+        self.skip(int(start))
+
+    @classmethod
+    def from_json(cls, input_file: IO[Any], instructions_range: str):
+        return cls(input_file, FileFormat.JSON, instructions_range)
+
+    @classmethod
+    def from_fb(cls, input_file: IO[Any], instructions_range: str):
+        return cls(input_file, FileFormat.FLATBUFFERS, instructions_range)
+
+    def next_addr(self) -> Optional[int]:
+        return self.instrs[-1].addr if self.instrs else None
+
+    def dequeue(self) -> Optional[Instruction]:
+        if self.eof():
+            return None
+
+        return_val = self.instrs.pop()
+        if not self.instrs:
+            self.read_instructions()
+        self.instr_count += 1
+
+        return return_val
+
+    def skip(self, n: int) -> None:
+        if self.eof() or n <= 0:
+            return
+
+        if len(self.instrs) >= n:
+            self.instr_count += n
+            self.instrs = self.instrs[:-n]
+            if not self.instrs:
+                self.read_instructions()
+            return
+
+        n -= len(self.instrs)
+        self.instr_count += len(self.instrs)
+        self.instrs.clear()
+
+        if self.input_format == FileFormat.JSON:
+            for _ in range(n):
+                self.input_file.readline()
+            self.instr_count += n
+            self.read_instructions()
+            return
+
+        assert self.input_format == FileFormat.FLATBUFFERS
+
+        while len(self.instrs) < n:
+            n -= len(self.instrs)
+            self.instr_count += len(self.instrs)
+            self.instrs.clear()
+            # TODO(sflur): self.read_fb_instructions() constructs the
+            # instructions but we don't use most of them.
+            self.read_fb_instructions()
+            if not self.instrs:
+                return
+
+        if n > 0:
+            self.instr_count += n
+            self.instrs = self.instrs[:-n]
+
+    def eof(self) -> bool:
+        return (not self.instrs or
+                (self.end is not None and self.instr_count >= self.end))
+
+    def read_json_instructions(self) -> None:
+        instrs = []
+        for _ in range(tbm_options.args.json_trace_buffer_size):
+            line = self.input_file.readline()
+            if line:
+                instrs.append(Instruction.from_json(line))
+            else:
+                break
+        instrs.reverse()
+        self.instrs = instrs
+
+    def read_fb_instructions(self) -> None:
+        length = self.input_file.read(4)
+        if not length:
+            return
+
+        length = int.from_bytes(length, byteorder="little")
+
+        buf = bytearray(self.input_file.read(length))
+        instrs = FBInstrs.Instructions.GetRootAsInstructions(buf, 0)
+        self.instrs = [
+            Instruction.from_fb(instrs.Instructions(i))
+            for i in reversed(range(instrs.InstructionsLength()))
+        ]
diff --git a/sched_unit.py b/sched_unit.py
new file mode 100644
index 0000000..11998db
--- /dev/null
+++ b/sched_unit.py
@@ -0,0 +1,191 @@
+from typing import Any, Dict, Iterable, Sequence
+
+from buffered_queue import BufferedQueue
+import counter
+from counter import Counter
+from instruction import Instruction
+import interfaces
+
+
+# TODO(b/261690182): rename the SchedUnit
+class SchedUnit(interfaces.SchedUnit):
+    """Issue unit model."""
+
+    def __init__(self, config: Dict[str, Any]):
+        super().__init__("SC")
+
+        self._decode_rate = config.get("decode_rate")
+        self._branch_prediction = config["branch_prediction"]
+
+        self._fetch_unit = None
+        self._exec_unit = None
+
+        ## Current states
+        self._queues = {}
+        self._branch_stalling = False
+
+        ## Next state
+        self._next_branch_stalling = None
+
+    def add_queue(self, uid: str, desc) -> None:
+        self._queues[uid] = BufferedQueue(desc.get("size"))
+
+    def connect(self, fetch_unit: interfaces.FetchUnit,
+                exec_unit: interfaces.ExecUnit) -> None:
+        self._fetch_unit = fetch_unit
+        self._exec_unit = exec_unit
+
+    # Implements interfaces.SchedUnit
+    @property
+    def queues(self) -> Iterable[BufferedQueue[Instruction]]:
+        return self._queues.values()
+
+    # Implements interfaces.SchedUnit
+    def pending(self) -> int:
+        return sum(len(q) for q in self._queues.values())
+
+    # Implements interfaces.SchedUnit
+    def reset(self, cntr: Counter) -> None:
+        super().reset(cntr)
+        # TODO(sflur): implement proper reset
+        cntr.stalls[self.name] = 0
+        for uid, q in self._queues.items():
+            cntr.utilizations[uid] = counter.Utilization(q.size)
+
+    # Implements interfaces.SchedUnit
+    def tick(self, cntr: Counter) -> None:
+        super().tick(cntr)
+
+        if self._branch_stalling:
+            self.log("queuing stalled: unresolved branch")
+            return
+
+        for _ in range(self._decode_rate if self._decode_rate
+                       else len(self._fetch_unit.queue)):
+            if not self._fetch_unit.queue:
+                # Fetch queue is empty
+                break
+
+            fetched_instr = self._fetch_unit.queue.peek()
+            if not fetched_instr:
+                # A None instruction in the fetch queue stands for instruction
+                # that the functional simulator did not execute (or fetch), so
+                # we don't know what instruction that was, or how it behaved.
+                # In a real uarch this instruction will take some resources
+                # until the uarch figures out it should be evicted.
+                # TODO(sflur): count these instructions and apply some
+                # proportional penalty to the performance TBM reports?
+                self._fetch_unit.queue.dequeue()
+                continue
+
+            # Check if we need to flush pending instructions.
+            if fetched_instr.is_flush and (self.pending() or
+                                           self._exec_unit.pending()):
+                # TODO(sflur): Currently flush instructions wait in the fetch
+                # queue, is that the right place to wait in?
+                cntr.stalls[self.name] += 1
+                self.log(f"queueing stalled: flush in effect: {fetched_instr}")
+                break
+
+            if fetched_instr.is_nop:
+                self.log(f"retired NOP instruction: {fetched_instr}")
+                self._fetch_unit.queue.dequeue()
+                cntr.retired_instruction_count += 1
+                continue
+
+            qid = self._exec_unit.get_issue_queue_id(fetched_instr)
+
+            # Check if the queue is available.
+            if self._queues[qid].is_buffer_full():
+                cntr.stalls[self.name] += 1
+                self.log(f"queueing stalled: '{qid}' is full")
+                break
+
+            # TODO(sflur): instead of check_conflicts, we could add the
+            # instructions to the scoreboard at this point.
+            if not self.check_conflicts(fetched_instr, qid):
+                # TODO(sflur): the blocking instruction is still in the fetch
+                # queue, maybe move it somewhere else?
+                cntr.stalls[self.name] += 1
+                self.log("queueing stalled: conflict with queued instruction")
+                break
+
+            # It is safe to queue the instruction.
+            self._queues[qid].buffer(fetched_instr)
+            self._fetch_unit.queue.dequeue()
+            cntr.utilizations[qid].count += 1
+            self.log(f"instruction '{fetched_instr}' queued")
+
+            if fetched_instr.is_branch:
+                cntr.branch_count += 1
+
+                if self._branch_prediction == "none":
+                    self._branch_stalling = True
+                    break
+
+    # Implements interfaces.SchedUnit
+    def tock(self, cntr: Counter) -> None:
+        super().tock(cntr)
+
+        for q in self._queues.values():
+            q.flush()
+
+        if self._next_branch_stalling is not None:
+            self._branch_stalling = self._next_branch_stalling
+            self._next_branch_stalling = None
+
+        for uid, q in self._queues.items():
+            cntr.utilizations[uid].occupied += len(q)
+
+    def check_conflicts(self, new_instr: Instruction, qid: str) -> bool:
+        """Check if `instr` conflicts with other instructions.
+
+        Check whether it is safe to reorder `instr` wrt the instructions
+        already in other queues.
+        There is no need to check conflicts with instructions that are already
+        in execution pipes, as that is handled by the scoreboard.
+
+        Args:
+          new_instr: fetched instructions.
+          qid: the dispatch queue the instruction will be placed in.
+        Returns:
+          True if there are no conflicts, False otherwise.
+        """
+
+        for name, q in self._queues.items():
+            if name == qid:
+                # skip the queue new_instr is going to, as it's an in-order
+                # queue.
+                continue
+
+            for instr in q.chain():
+                if new_instr.conflicts_with(instr):
+                    return False
+
+        return True
+
+    # Implements interfaces.SchedUnit
+    def branch_resolved(self) -> None:
+        if self.phase == interfaces.CyclePhase.TICK:
+            self._next_branch_stalling = False
+        else:
+            assert self.phase == interfaces.CyclePhase.TOCK
+            self._branch_stalling = False
+
+    # Implements interfaces.SchedUnit
+    def print_state_detailed(self, file) -> None:
+        for uid, dq in self._queues.items():
+            if dq:
+                queue_str = ", ".join(str(i) for i in reversed(dq))
+            else:
+                queue_str = "-"
+            print(f"[qu-{uid}] {queue_str}", file=file)
+
+
+    # Implements interfaces.SchedUnit
+    def get_state_three_valued_header(self) -> Sequence[str]:
+        return self._queues.keys()
+
+    # Implements interfaces.SchedUnit
+    def get_state_three_valued(self,vals: Sequence[str]) ->  Sequence[str]:
+        return [q.pp_three_valued(vals) for q in self._queues.values()]