[otbn] Wrap memutil_dpi for OTBN to be accessible through DPI

This patch wraps DpiMemUtil up in an interface that's convenient for
SV code to call through DPI to backdoor-load ELF code into memory.
Moving the address setup into a separate class also avoids us having
to duplicate it between e.g. otbn_top_sim and the DV code.

Signed-off-by: Rupert Swarbrick <rswarbrick@lowrisc.org>
diff --git a/hw/ip/otbn/dv/memutil/README.md b/hw/ip/otbn/dv/memutil/README.md
new file mode 100644
index 0000000..0ec1c1b
--- /dev/null
+++ b/hw/ip/otbn/dv/memutil/README.md
@@ -0,0 +1,9 @@
+# OTBN memutil wrapper
+
+This code is a wrapper around `DpiMemUtil` (defined in
+`hw/dv/verilator/cpp`).
+
+To use it, depend on the core file. If you're using dvsim, you'll also
+need to include `otbn_memutil_sim_opts.hjson` in your simulator
+configuration and add `"{tool}_otbn_memutil_build_opts"` to the
+`en_build_modes` variable.
diff --git a/hw/ip/otbn/dv/memutil/otbn_memutil.cc b/hw/ip/otbn/dv/memutil/otbn_memutil.cc
new file mode 100644
index 0000000..8e05ca7
--- /dev/null
+++ b/hw/ip/otbn/dv/memutil/otbn_memutil.cc
@@ -0,0 +1,181 @@
+// Copyright lowRISC contributors.
+// Licensed under the Apache License, Version 2.0, see LICENSE for details.
+// SPDX-License-Identifier: Apache-2.0
+
+#include "otbn_memutil.h"
+
+#include <cassert>
+#include <cstring>
+#include <iostream>
+#include <limits>
+#include <stdexcept>
+
+OtbnMemUtil::OtbnMemUtil(const std::string &top_scope) {
+  MemAreaLoc imem_loc = {.base = 0x100000, .size = 4096};
+  std::string imem_scope =
+      top_scope + ".u_imem.u_mem.gen_generic.u_impl_generic";
+  if (!RegisterMemoryArea("imem", imem_scope, 32, &imem_loc)) {
+    throw std::runtime_error("Failed to register IMEM OTBN memory area.");
+  }
+
+  MemAreaLoc dmem_loc = {.base = 0x200000, .size = 4096};
+  std::string dmem_scope =
+      top_scope + ".u_dmem.u_mem.gen_generic.u_impl_generic";
+  if (!RegisterMemoryArea("dmem", dmem_scope, 256, &dmem_loc)) {
+    throw std::runtime_error("Failed to register DMEM OTBN memory area.");
+  }
+}
+
+void OtbnMemUtil::LoadElf(const std::string &elf_path) {
+  LoadElfToMemories(false, elf_path);
+}
+
+const StagedMem::SegMap &OtbnMemUtil::GetSegs(bool is_imem) const {
+  return GetMemoryData(is_imem ? "imem" : "dmem").GetSegs();
+}
+
+extern "C" OtbnMemUtil *OtbnMemUtilMake(const char *top_scope) {
+  try {
+    return new OtbnMemUtil(top_scope);
+  } catch (const std::exception &err) {
+    std::cerr << "Failed to create OtbnMemUtil: " << err.what() << "\n";
+    return nullptr;
+  }
+}
+
+extern "C" void OtbnMemUtilFree(OtbnMemUtil *mem_util) { delete mem_util; }
+
+extern "C" svBit OtbnMemUtilLoadElf(OtbnMemUtil *mem_util,
+                                    const char *elf_path) {
+  assert(mem_util);
+  assert(elf_path);
+  try {
+    mem_util->LoadElf(elf_path);
+    return sv_1;
+  } catch (const std::exception &err) {
+    std::cerr << "Failed to load ELF file from `" << elf_path
+              << "': " << err.what() << "\n";
+    return sv_0;
+  }
+}
+
+extern "C" svBit OtbnMemUtilStageElf(OtbnMemUtil *mem_util,
+                                     const char *elf_path) {
+  assert(mem_util);
+  assert(elf_path);
+  try {
+    mem_util->StageElf(false, elf_path);
+    return sv_1;
+  } catch (const std::exception &err) {
+    std::cerr << "Failed to load ELF file from `" << elf_path
+              << "': " << err.what() << "\n";
+    return sv_0;
+  }
+}
+
+extern "C" int OtbnMemUtilGetSegCount(OtbnMemUtil *mem_util, svBit is_imem) {
+  assert(mem_util);
+  const StagedMem::SegMap &segs = mem_util->GetSegs(is_imem);
+  size_t num_segs = segs.size();
+
+  // Since the segments are disjoint and 32-bit aligned, there are at most 2^30
+  // of them (this, admittedly, would mean an ELF file with a billion segments,
+  // but it's theoretically possible). Fortunately, that number is still
+  // representable with a signed 32-bit integer, so this assertion shouldn't
+  // ever fire.
+  assert(num_segs < std::numeric_limits<int>::max());
+
+  return num_segs;
+}
+
+// DPI entry point: reports a staged segment's word offset and word count.
+extern "C" svBit OtbnMemUtilGetSegInfo(OtbnMemUtil *mem_util, svBit is_imem,
+                                       int seg_idx, svBitVecVal *seg_off,
+                                       svBitVecVal *seg_size) {
+  assert(mem_util);
+  assert(seg_off);
+  assert(seg_size);
+
+  const StagedMem::SegMap &segs = mem_util->GetSegs(is_imem);
+
+  // Valid indices are 0 .. segs.size() - 1; anything else is rejected.
+  if ((seg_idx < 0) || (static_cast<size_t>(seg_idx) >= segs.size())) {
+    std::cerr << "Invalid segment index: " << seg_idx << ". "
+              << (is_imem ? 'I' : 'D') << "MEM has " << segs.size()
+              << " segments.\n";
+    return sv_0;
+  }
+
+  // Walk to the desired segment; safe because the index was just validated.
+  auto it = std::next(segs.begin(), seg_idx);
+
+  uint32_t seg_addr = it->first.lo;
+
+  // We know that seg_addr must be 32 bit aligned because DpiMemUtil checks its
+  // segments are aligned to the memory word size (which is 32 or 256 bits for
+  // imem/dmem, respectively).
+  assert(seg_addr % 4 == 0);
+
+  uint32_t word_seg_addr = seg_addr / 4;
+
+  size_t size_bytes = it->second.size();
+
+  // We know the size can't be too enormous, because the segment fits in a
+  // 32-bit address space.
+  assert(size_bytes < std::numeric_limits<uint32_t>::max());
+
+  // Divide by 4 to get the number of 32 bit words. Round up: we'll pad out the
+  // data with zeros if there's a ragged edge. (We know this is valid because
+  // any next range is also 32 bit aligned).
+  uint32_t size_words = ((uint64_t)size_bytes + 3) / 4;
+
+  memcpy(seg_off, &word_seg_addr, sizeof(uint32_t));
+  memcpy(seg_size, &size_words, sizeof(uint32_t));
+
+  return sv_1;
+}
+
+extern "C" svBit OtbnMemUtilGetSegData(
+    OtbnMemUtil *mem_util, svBit is_imem, int word_off,
+    /* output bit[31:0] */ svBitVecVal *data_value) {
+  assert(mem_util);
+  assert(data_value);
+
+  if ((word_off < 0) ||
+      ((unsigned)word_off > std::numeric_limits<uint32_t>::max() / 4)) {
+    std::cerr << "Invalid word offset: " << word_off << ".\n";
+    return sv_0;
+  }
+
+  uint32_t byte_addr = (unsigned)word_off * 4;
+
+  const StagedMem::SegMap &segs = mem_util->GetSegs(is_imem);
+
+  auto it = segs.find(byte_addr);
+  if (it == segs.end()) {
+    return sv_0;
+  }
+
+  uint32_t seg_addr = it->first.lo;
+  assert(seg_addr <= byte_addr);
+
+  // The offset within the segment
+  uint32_t seg_off = byte_addr - seg_addr;
+  assert(seg_off < it->second.size());
+
+  // How many bytes are available in the segment, starting at seg_off? We know
+  // this will be positive (because RangeMap::find finds us a range that
+  // includes byte_addr and then DpiMemUtil makes sure that the value at that
+  // range is the right length).
+  size_t avail = it->second.size() - seg_off;
+  size_t to_copy = std::min(avail, (size_t)4);
+
+  // Copy data from the segment into a uint32_t. Zero-initialize it, in case
+  // to_copy < 4.
+  uint32_t data = 0;
+  memcpy(&data, &it->second[seg_off], to_copy);
+
+  // Now copy that uint32_t into data_value and return success.
+  memcpy(data_value, &data, 4);
+  return sv_1;
+}
diff --git a/hw/ip/otbn/dv/memutil/otbn_memutil.core b/hw/ip/otbn/dv/memutil/otbn_memutil.core
new file mode 100644
index 0000000..8a2ce93
--- /dev/null
+++ b/hw/ip/otbn/dv/memutil/otbn_memutil.core
@@ -0,0 +1,22 @@
+CAPI=2:
+# Copyright lowRISC contributors.
+# Licensed under the Apache License, Version 2.0, see LICENSE for details.
+# SPDX-License-Identifier: Apache-2.0
+
+name: "lowrisc:dv:otbn_memutil"
+description: "A wrapper around memutil_dpi for OTBN"
+
+filesets:
+  files_cpp:
+    depend:
+      - lowrisc:dv_verilator:memutil_dpi
+    files:
+      - otbn_memutil.cc
+      - otbn_memutil.h: { is_include_file: true }
+      - otbn_memutil.svh: { file_type: systemVerilogSource, is_include_file: true }
+    file_type: cppSource
+
+targets:
+  default:
+    filesets:
+      - files_cpp
diff --git a/hw/ip/otbn/dv/memutil/otbn_memutil.h b/hw/ip/otbn/dv/memutil/otbn_memutil.h
new file mode 100644
index 0000000..a33717d
--- /dev/null
+++ b/hw/ip/otbn/dv/memutil/otbn_memutil.h
@@ -0,0 +1,60 @@
+// Copyright lowRISC contributors.
+// Licensed under the Apache License, Version 2.0, see LICENSE for details.
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include <svdpi.h>
+
+#include "dpi_memutil.h"
+
+class OtbnMemUtil : public DpiMemUtil {
+ public:
+  // Constructor. top_scope is the SV scope that contains IMEM and
+  // DMEM memories as u_imem and u_dmem, respectively.
+  OtbnMemUtil(const std::string &top_scope);
+
+  // Load an ELF file at the given path and backdoor load it into the
+  // attached memories.
+  //
+  // If something goes wrong, throws a std::exception.
+  void LoadElf(const std::string &elf_path);
+
+  // Get access to the segments currently staged for imem/dmem
+  const StagedMem::SegMap &GetSegs(bool is_imem) const;
+};
+
+// DPI-accessible wrappers
+extern "C" {
+OtbnMemUtil *OtbnMemUtilMake(const char *top_scope);
+void OtbnMemUtilFree(OtbnMemUtil *mem_util);
+
+// Loads an ELF file into memory via the backdoor. Returns 1'b1 on success.
+// Prints a message to stderr and returns 1'b0 on failure.
+svBit OtbnMemUtilLoadElf(OtbnMemUtil *mem_util, const char *elf_path);
+
+// Loads an ELF file into the OtbnMemUtil object, but doesn't touch the
+// simulated memory. Returns 1'b1 on success. Prints a message to stderr and
+// returns 1'b0 on failure.
+svBit OtbnMemUtilStageElf(OtbnMemUtil *mem_util, const char *elf_path);
+
+// Returns the number of segments currently staged in imem/dmem.
+int OtbnMemUtilGetSegCount(OtbnMemUtil *mem_util, svBit is_imem);
+
+// Gets offset and size (both in 32-bit words) for a segment currently staged
+// in imem/dmem. Both are returned with output arguments. Returns 1'b1 on
+// success. Prints a message to stderr and returns 1'b0 on failure.
+svBit OtbnMemUtilGetSegInfo(OtbnMemUtil *mem_util, svBit is_imem, int seg_idx,
+                            /* output bit[31:0] */ svBitVecVal *seg_off,
+                            /* output bit[31:0] */ svBitVecVal *seg_size);
+
+// Gets a word of data from segments currently staged in imem/dmem. If there
+// is a word at that address, the function writes its value to the output
+// argument and then returns 1'b1. If there is no word at that address, the
+// output argument is untouched and the function returns 1'b0.
+//
+// If word_off is invalid (negative or enormous), the function writes a
+// message to stderr and returns 1'b0.
+svBit OtbnMemUtilGetSegData(OtbnMemUtil *mem_util, svBit is_imem, int word_off,
+                            /* output bit[31:0] */ svBitVecVal *data_value);
+}
diff --git a/hw/ip/otbn/dv/memutil/otbn_memutil.svh b/hw/ip/otbn/dv/memutil/otbn_memutil.svh
new file mode 100644
index 0000000..f1ef927
--- /dev/null
+++ b/hw/ip/otbn/dv/memutil/otbn_memutil.svh
@@ -0,0 +1,29 @@
+// Copyright lowRISC contributors.
+// Licensed under the Apache License, Version 2.0, see LICENSE for details.
+// SPDX-License-Identifier: Apache-2.0
+
+`ifndef __OTBN_MEMUTIL_SVH__
+`define __OTBN_MEMUTIL_SVH__
+
+// Imports for the functions defined in otbn_memutil.h. There are documentation comments explaining
+// what the functions do there.
+
+import "DPI-C" function chandle OtbnMemUtilMake(string top_scope);
+
+import "DPI-C" function void OtbnMemUtilFree(chandle mem_util);
+
+import "DPI-C" function bit OtbnMemUtilLoadElf(chandle mem_util, string elf_path);
+
+import "DPI-C" function bit OtbnMemUtilStageElf(chandle mem_util, string elf_path);
+
+import "DPI-C" function int OtbnMemUtilGetSegCount(chandle mem_util, bit is_imem);
+
+import "DPI-C" function bit OtbnMemUtilGetSegInfo(chandle mem_util, bit is_imem, int seg_idx,
+                                                  output bit [31:0] seg_off,
+                                                  output bit [31:0] seg_size);
+
+import "DPI-C" function bit OtbnMemUtilGetSegData(chandle mem_util, bit is_imem, int word_off,
+                                                  output bit[31:0] data_value);
+
+
+`endif // __OTBN_MEMUTIL_SVH__
diff --git a/hw/ip/otbn/dv/memutil/otbn_memutil_sim_opts.hjson b/hw/ip/otbn/dv/memutil/otbn_memutil_sim_opts.hjson
new file mode 100644
index 0000000..a7e7283
--- /dev/null
+++ b/hw/ip/otbn/dv/memutil/otbn_memutil_sim_opts.hjson
@@ -0,0 +1,24 @@
+// Copyright lowRISC contributors.
+// Licensed under the Apache License, Version 2.0, see LICENSE for details.
+// SPDX-License-Identifier: Apache-2.0
+{
+  // Additional build-time options needed to compile C++ sources in
+  // simulators such as VCS and Xcelium for anything that uses
+  // otbn_memutil.
+  memutil_dpi_core: "lowrisc:dv_verilator:memutil_dpi:0"
+  memutil_dpi_src_dir: "{eval_cmd} echo \"{memutil_dpi_core}\" | tr ':' '_'"
+
+  build_modes: [
+    {
+      name: vcs_otbn_memutil_build_opts
+      build_opts: ["-CFLAGS -I{build_dir}/src/{memutil_dpi_src_dir}/cpp",
+                   "-lelf"]
+    }
+
+    {
+      name: xcelium_otbn_memutil_build_opts
+      build_opts: ["-I{build_dir}/src/{memutil_dpi_src_dir}/cpp",
+                   "-lelf"]
+    }
+  ]
+}
diff --git a/hw/ip/otbn/dv/verilator/otbn_top_sim.cc b/hw/ip/otbn/dv/verilator/otbn_top_sim.cc
index d6c3039..7812ac4 100644
--- a/hw/ip/otbn/dv/verilator/otbn_top_sim.cc
+++ b/hw/ip/otbn/dv/verilator/otbn_top_sim.cc
@@ -7,6 +7,7 @@
 #include <iostream>
 #include <svdpi.h>
 
+#include "otbn_memutil.h"
 #include "verilated_toplevel.h"
 #include "verilator_memutil.h"
 #include "verilator_sim_ctrl.h"
@@ -18,20 +19,11 @@
 
 int main(int argc, char **argv) {
   otbn_top_sim top;
-  VerilatorMemUtil memutil;
+  VerilatorMemUtil memutil(new OtbnMemUtil("TOP.otbn_top_sim"));
+
   VerilatorSimCtrl &simctrl = VerilatorSimCtrl::GetInstance();
   simctrl.SetTop(&top, &top.IO_CLK, &top.IO_RST_N,
                  VerilatorSimCtrlFlags::ResetPolarityNegative);
-
-  MemAreaLoc imem_loc = {.base = 0x100000, .size = 4096};
-  MemAreaLoc dmem_loc = {.base = 0x200000, .size = 4096};
-
-  memutil.RegisterMemoryArea(
-      "dmem", "TOP.otbn_top_sim.u_dmem.u_mem.gen_generic.u_impl_generic", 256,
-      &dmem_loc);
-  memutil.RegisterMemoryArea(
-      "imem", "TOP.otbn_top_sim.u_imem.u_mem.gen_generic.u_impl_generic", 32,
-      &imem_loc);
   simctrl.RegisterExtension(&memutil);
 
   bool exit_app = false;
diff --git a/hw/ip/otbn/dv/verilator/otbn_top_sim.core b/hw/ip/otbn/dv/verilator/otbn_top_sim.core
index 7eaf200..db0af92 100644
--- a/hw/ip/otbn/dv/verilator/otbn_top_sim.core
+++ b/hw/ip/otbn/dv/verilator/otbn_top_sim.core
@@ -11,7 +11,7 @@
       - lowrisc:ip:otbn:0.1
   files_verilator:
     depend:
-      - lowrisc:dv_verilator:memutil_verilator
+      - lowrisc:dv:otbn_memutil
       - lowrisc:dv_verilator:simutil_verilator
     files:
       - otbn_top_sim.cc: { file_type: cppSource }
diff --git a/hw/top_earlgrey/top_earlgrey_verilator.cc b/hw/top_earlgrey/top_earlgrey_verilator.cc
index 6536c18..15f3401 100644
--- a/hw/top_earlgrey/top_earlgrey_verilator.cc
+++ b/hw/top_earlgrey/top_earlgrey_verilator.cc
@@ -5,12 +5,12 @@
 #include <iostream>
 
 #include "verilated_toplevel.h"
-#include "verilator_memutil.h"
+#include "verilator_memutil_extension.h"
 #include "verilator_sim_ctrl.h"
 
 int main(int argc, char **argv) {
   top_earlgrey_verilator top;
-  VerilatorMemUtil memutil;
+  VerilatorMemUtilExtension memutil;
   VerilatorSimCtrl &simctrl = VerilatorSimCtrl::GetInstance();
   simctrl.SetTop(&top, &top.clk_i, &top.rst_ni,
                  VerilatorSimCtrlFlags::ResetPolarityNegative);