memory: model of a memory system This includes classes for main memory and cache elements. Change-Id: I11896177e5503b9225c5f54ff9d57ae79005e934
diff --git a/memory_system.py b/memory_system.py new file mode 100644 index 0000000..8c531a6 --- /dev/null +++ b/memory_system.py
@@ -0,0 +1,546 @@ +"""MemorySystem module.""" + +import collections +import math +import re +import sys +from typing import Optional, Sequence, Union + +from counter import Counter +import interfaces + +class LRUSet: + """LRU set.""" + + def __init__(self, size: int) -> None: + self.the_set = collections.deque(maxlen=size) + self.dirty = set() + + def try_access(self, tag: int, set_dirty: bool) -> bool: + try: + self.the_set.remove(tag) + self.the_set.append(tag) + if set_dirty: + self.dirty.add(tag) + return True + except ValueError: # remove failed, hence tag not in set + return False + + def evict(self) -> Optional[int]: + if len(self.the_set) == self.the_set.maxlen: + tag = self.the_set.popleft() + try: + self.dirty.remove(tag) + return tag + except KeyError: + pass + return None + + def insert(self, tag: int, dirty: bool) -> None: + if self.the_set: + self.dirty.discard(self.the_set[0]) + self.the_set.append(tag) + if dirty: + self.dirty.add(tag) + + def take(self, tag: int) -> bool: + if self.the_set[-1] == tag: + # This is an optimisation (take is always called after try_access) + self.the_set.pop() + else: + self.the_set.remove(tag) + try: + self.dirty.remove(tag) + return True + except KeyError: + return False + + +class DirectMapMem: + """Direct Map.""" + + def __init__(self, line_size_log2: int, + size_log2: int) -> None: + + self.line_size_log2 = line_size_log2 + self.index_size_log2 = size_log2 - self.line_size_log2 + self.tags = [None] * (2**self.index_size_log2) + self.dirty = set() + + def index(self, addr: int) -> int: + mask = (1 << self.index_size_log2) - 1 + return (addr >> self.line_size_log2) & mask + + def tag(self, addr: int) -> int: + return addr >> (self.line_size_log2 + self.index_size_log2) + + def line_addr(self, addr: int) -> int: + return (addr >> self.line_size_log2) << self.line_size_log2 + + def try_access(self, addr: int, set_dirty: bool) -> bool: + if self.tags[self.index(addr)] == self.tag(addr): + if set_dirty: + self.dirty.add(self.index(addr)) + return True + + return False + + def evict_for(self, addr: int) -> Optional[int]: + i = self.index(addr) + if self.tags[i] is not None: + tag = self.tags[i] + self.tags[i] = None + try: + self.dirty.remove(tag) + return ( + (tag << self.index_size_log2) | i) << self.line_size_log2 + except KeyError: + pass + return None + + def insert(self, addr: int, dirty: bool) -> None: + self.tags[self.index(addr)] = self.tag(addr) + if dirty: + self.dirty.add(self.index(addr)) + else: + self.dirty.discard(self.index(addr)) + + def take(self, addr: int) -> bool: + self.tags[self.index(addr)] = None + try: + self.dirty.remove(self.tag(addr)) + return True + except KeyError: + return False + + +class SetAssocMem: + """Set associative.""" + + def __init__(self, desc, line_size_log2: int, size_log2: int) -> None: + self.line_size_log2 = line_size_log2 + + # TODO(sflur): handle size that is not power of 2? + set_size_log2 = int(math.log2(desc["set_size"])) + self.index_size_log2 = size_log2 - self.line_size_log2 - set_size_log2 + + if desc["replacement"] == "LRU": + self.tags = [ + LRUSet(desc["set_size"]) for _ in range(2**self.index_size_log2) + ] + else: + assert False + + def index(self, addr: int) -> int: + mask = (1 << self.index_size_log2) - 1 + return (addr >> self.line_size_log2) & mask + + def tag(self, addr: int) -> int: + return addr >> (self.line_size_log2 + self.index_size_log2) + + def line_addr(self, addr: int) -> int: + return (addr >> self.line_size_log2) << self.line_size_log2 + + def try_access(self, addr: int, set_dirty: bool) -> bool: + return self.tags[self.index(addr)].try_access(self.tag(addr), set_dirty) + + def insert(self, addr: int, dirty: bool) -> None: + self.tags[self.index(addr)].insert(self.tag(addr), dirty) + + def take(self, addr: int) -> bool: + return self.tags[self.index(addr)].take(self.tag(addr)) + + def evict_for(self, addr: int) -> Optional[int]: + i = self.index(addr) + tag = self.tags[i].evict() + return (((tag << self.index_size_log2) | i) << self.line_size_log2 + if tag is not None else None) + + +def load_mem(desc, line_size_log2: int, size_log2: int): + if desc["type"] == "set_assoc": + return SetAssocMem(desc, line_size_log2, size_log2) + + if desc["type"] == "direct_map": + return DirectMapMem(line_size_log2, size_log2) + + # TODO(sflur): report error + assert False + + +BYTE_UNITS = {"b": 0, "kb": 10, "mb": 20, "gb": 30, "tb": 40} + + +def parse_bytes_to_log2(x: Union[int, str]) -> int: + u = 0 + if isinstance(x, str): + m = re.match(r"^(\d+)\s*(.*)", x) + assert m + x = int(m.group(1)) + if m.group(2): + u = BYTE_UNITS[m.group(2).lower()] + + assert x > 0 + return int(math.log2(x)) + u + + +class CacheFront: + """Cache that supports load/store.""" + + def __init__(self, desc, parent) -> None: + self.parent = parent + + # TODO(sflur): report an error if not divisible by 8? + line_size_log2 = int( + math.log2(desc["line_size"] // 8)) + size_log2 = parse_bytes_to_log2(desc["size"]) + self.mem = load_mem(desc["placement"], line_size_log2, size_log2) + + self.write_policy = desc["write_policy"] + self.latencies = desc["latencies"] + + self.front_reqs = collections.deque() + self.front_replys = collections.defaultdict(collections.deque) + self.state = None + + def issue_load(self, uid, addr) -> None: + self.front_reqs.append(("read", uid, addr)) + + def issue_store(self, uid, addr) -> None: + self.front_reqs.append(("write", uid, addr)) + + def take_load_replys(self, uid) -> Sequence[int]: + res = collections.deque() + replys = self.front_replys[uid] + for _ in range(len(replys)): + if replys[0][0] == "read": + res.append(replys.popleft()[2]) + else: + replys.rotate() + + if not replys: + del self.front_replys[uid] + + return res + + def take_store_replys(self, uid) -> Sequence[int]: + res = collections.deque() + replys = self.front_replys[uid] + for _ in range(len(replys)): + if replys[0][0] == "write": + res.append(replys.popleft()[2]) + else: + replys.rotate() + + if not replys: + del self.front_replys[uid] + + return res + + def tick(self) -> None: + if self.state: + if self.state[0] == "stall": + _, delay, res = self.state + if delay > 0: + self.state = ("stall", delay - 1, res) + else: + self.front_replys[res[1]].append(res) + self.state = None + + elif self.state[0] == "miss": + _, req = self.state + cmd, _, addr = req + write_back_addr = self.mem.evict_for(addr) + if write_back_addr is not None: + self.parent.front_reqs.append( + ("write", self, write_back_addr)) + self.parent.front_reqs.append((f"fetch_{cmd}", self, addr)) + self.state = ("stall-parent", req) + + elif self.state[0] == "write-through": + _, req = self.state + _, _, addr = req + self.parent.front_reqs.append(("write", self, addr)) + self.state = ("stall-parent", req) + + def tock(self) -> None: + if not self.state and self.front_reqs: + req = self.front_reqs.popleft() + if req[0] in ["read", "write"]: + cmd, _, addr = req + if self.mem.try_access( + addr, cmd == "write" and + self.write_policy == "write_back"): + if cmd == "write" and self.write_policy == "write_through": + self.state = ("write-through", req) + else: + self.state = ("stall", self.latencies[cmd] - 1, req) + else: + self.state = ("miss", req) + else: + assert False + + if (self.state and self.state[0] == "stall-parent" and + self.parent.front_replys[self]): + reply = self.parent.front_replys[self].popleft() + if reply[0] == "write": + _, req = self.state + if self.write_policy == "write_through": + self.state = ("stall", self.latencies[req[0]] - 1, req) + # else: it's a write-back, we still need to wait for the fetch, + # hence `self.state` stays the same. + elif reply[0] in ["fetch_read", "fetch_write"]: + _, req = self.state + cmd, _, addr = req + self.mem.insert( + addr, cmd == "write" and self.write_policy == "write_back") + + if cmd == "write" and self.write_policy == "write_through": + self.state = ("write-through", req) + else: + self.state = ("stall", self.latencies[cmd] - 1, req) + else: + assert False + + +class Cache: + """Cache that is part of a hierarchy (not front).""" + + def __init__(self, desc, parent) -> None: + self.parent = parent + self.children = [] + + # TODO(sflur): report an error if not divisible by 8? + line_size_log2 = int(math.log2(desc["line_size"] // 8)) + size_log2 = parse_bytes_to_log2(desc["size"]) + self.mem = load_mem(desc["placement"], line_size_log2, size_log2) + + self.inclusion = desc["inclusion"] + self.write_policy = desc["write_policy"] + self.latencies = desc["latencies"] + + self.front_reqs = collections.deque() + self.front_replys = collections.defaultdict(collections.deque) + self.state = None + + def tick(self) -> None: + if self.state: + if self.state[0] == "stall": + _, delay, res = self.state + if delay > 0: + self.state = ("stall", delay - 1, res) + else: + self.front_replys[res[1]].append(res) + self.state = None + + elif self.state[0] == "miss": + _, req = self.state + if req[0] in ["fetch_read", "fetch_write"]: + cmd, _, addr = req + if self.inclusion == "inclusive": + write_back_addr = self.mem.evict_for(addr) + if write_back_addr is not None: + self.parent.front_reqs.append( + ("write", self, write_back_addr)) + self.parent.front_reqs.append((cmd, self, addr)) + self.state = ("stall-parent", req) + + elif req[0] == "write": + _, _, addr = req + write_back_addr = self.mem.evict_for(addr) + if write_back_addr is not None: + self.parent.front_reqs.append( + ("write", self, write_back_addr)) + self.parent.front_reqs.append(("fetch_write", self, addr)) + self.state = ("stall-parent", req) + + elif self.state[0] == "write-through": + _, req = self.state + _, _, addr = req + self.parent.front_reqs.append(("write", self, addr)) + self.state = ("stall-parent", req) + + def tock(self) -> None: + if not self.state and self.front_reqs: + req = self.front_reqs.popleft() + if req[0] in ["fetch_read", "fetch_write"]: + cmd, _, addr = req + if self.mem.try_access(addr, False): + if self.inclusion == "exclusive": + _dirty = self.mem.take(addr) + # TODO(sflur): pass the dirty bit + self.state = ("stall", self.latencies[cmd] - 1, req) + else: + self.state = ("miss", req) + + elif req[0] == "write": + cmd, _, addr = req + if self.mem.try_access(addr, self.write_policy == "write_back"): + if self.write_policy == "write_back": + self.state = ("stall", self.latencies[cmd] - 1, req) + if self.write_policy == "write_through": + self.state = ("write-through", req) + else: + self.state = ("miss", req) + + else: + assert False + + if (self.state and self.state[0] == "stall-parent" and + self.parent.front_replys[self]): + reply = self.parent.front_replys[self].popleft() + if reply[0] == "write": + _, req = self.state + if self.write_policy == "write_through": + self.state = ("stall", self.latencies[req[0]] - 1, req) + # else: it's a write-back, we still need to wait for the fetch, + # hence `self.state` stays the same. + elif reply[0] in ["fetch_read", "fetch_write"]: + _, req = self.state + cmd, _, addr = req + self.mem.insert( + addr, cmd == "write" and self.write_policy == "write_back") + + if cmd == "write" and self.write_policy == "write_through": + self.state = ("write-through", req) + else: + self.state = ("stall", self.latencies[cmd] - 1, req) + else: + assert False + + +class MainMemory: + """Main memory.""" + + def __init__(self, desc) -> None: + self.children = [] + + self.latencies = desc["latencies"] + + self.front_reqs = collections.deque() + self.front_replys = collections.defaultdict(collections.deque) + self.state = None + + def issue_load(self, uid, addr) -> None: + self.front_reqs.append(("read", uid, addr)) + + def issue_store(self, uid, addr) -> None: + self.front_reqs.append(("write", uid, addr)) + + def take_load_replys(self, uid) -> Sequence[int]: + res = collections.deque() + replys = self.front_replys[uid] + for _ in range(len(replys)): + if replys[0][0] == "read": + res.append(replys.popleft()[2]) + else: + replys.rotate() + + if not replys: + del self.front_replys[uid] + + return res + + def take_store_replys(self, uid) -> Sequence[int]: + res = collections.deque() + replys = self.front_replys[uid] + for _ in range(len(replys)): + if replys[0][0] == "write": + res.append(replys.popleft()[2]) + else: + replys.rotate() + + if not replys: + del self.front_replys[uid] + + return res + + def tick(self) -> None: + if self.state: + if self.state[0] == "stall": + _, delay, res = self.state + if delay > 0: + self.state = ("stall", delay - 1, res) + else: + self.front_replys[res[1]].append(res) + self.state = None + + def tock(self) -> None: + if self.state is None and self.front_reqs: + req = self.front_reqs.popleft() + if req[0] in ["read", "write", "fetch_read", "fetch_write"]: + self.state = ("stall", self.latencies[req[0]] - 1, req) + else: + assert False + + +class MemorySystem(interfaces.Module): + """Memory system model.""" + + def __init__(self, desc) -> None: + super().__init__("MS") + + self.elements = {"main": MainMemory(desc)} + + if "levels" in desc: + for uid, d in desc["levels"].items(): + self.load_element(uid, d, self.elements["main"]) + + def load_element(self, uid, desc, parent) -> None: + front = "levels" not in desc + if desc["type"] == "unified": + e = CacheFront(desc, parent) if front else Cache(desc, parent) + elif desc["type"] == "dcache": + e = CacheFront(desc, parent) if front else Cache(desc, parent) + elif desc["type"] == "icache": + e = CacheFront(desc, parent) if front else Cache(desc, parent) + else: + self.logger.error("unknown cache type: %s", desc['type']) + sys.exit(1) + + parent.children.append(e) + + self.elements[uid] = e + + if "levels" in desc: + for u, d in desc["levels"].items(): + self.load_element(u, d, e) + + # Implements interfaces.Module + # pylint: disable-next=useless-parent-delegation + def reset(self, cntr: Counter) -> None: + super().reset(cntr) + # TODO(sflur): implement proper reset + + # Implements interfaces.Module + def tick(self, cntr: Counter) -> None: + super().tick(cntr) + + for e in self.elements.values(): + e.tick() + + # Implements interfaces.Module + def tock(self, cntr: Counter) -> None: + super().tock(cntr) + + for e in self.elements.values(): + e.tock() + + # Implements interfaces.Module + def pending(self) -> int: + # TODO(sflur): maybe return the number of outstanding accesses? + return 0 + + # Implements interfaces.Module + def print_state_detailed(self, file) -> None: + # TODO(sflur): what would be useful to print here? + pass + + # Implements interfaces.Module + def get_state_three_valued_header(self) -> Sequence[str]: + return [] + + # Implements interfaces.Module + def get_state_three_valued(self, vals: Sequence[str]) -> Sequence[str]: + # TODO(sflur): what would be useful to print here? + return []