Re-implement the python tracing facility. (#6318)

diff --git a/bindings/python/iree/runtime/CMakeLists.txt b/bindings/python/iree/runtime/CMakeLists.txt
index 225d7ad..d501783 100644
--- a/bindings/python/iree/runtime/CMakeLists.txt
+++ b/bindings/python/iree/runtime/CMakeLists.txt
@@ -36,6 +36,7 @@
     "__init__.py"
     "function.py"
     "system_api.py"
+    "tracing.py"
   PYEXT_DEPS
     ::PyExtRt
 )
@@ -67,6 +68,8 @@
   MODULE_PATH iree/runtime
   DEPS
     bindings_python_iree_runtime_PyExtRt
+  ADDL_PACKAGE_FILES
+    ${CMAKE_CURRENT_SOURCE_DIR}/README.md
 )
 
 install(
diff --git a/bindings/python/iree/runtime/README.md b/bindings/python/iree/runtime/README.md
new file mode 100644
index 0000000..cd21ffe
--- /dev/null
+++ b/bindings/python/iree/runtime/README.md
@@ -0,0 +1,23 @@
+# IREE Python Runtime Components
+
+This package provides an API for running compiled IREE binaries and interfacing with the hardware-abstraction-layer.
+
+## Tracing
+
+Execution of calls against binaries can be traced for later replay (i.e. via
+tools like `iree-run-module`). This can be set up either explicitly or
+via environment variables.
+
+To trace via environment variable, set `IREE_SAVE_CALLS` to a directory to dump
+traces into. Each created `SystemContext` will result in one `calls.yaml`
+file (with an index appended to the stem for multiples). Any referenced
+module binaries will be dumped into the same directory and referenced by the
+YAML file.
+
+### Explicit API
+
+```python
+tracer = iree.runtime.Tracer(some_dir)
+config = iree.runtime.Config(driver, tracer)
+...
+```
diff --git a/bindings/python/iree/runtime/__init__.py b/bindings/python/iree/runtime/__init__.py
index c0eb796..589cbf0 100644
--- a/bindings/python/iree/runtime/__init__.py
+++ b/bindings/python/iree/runtime/__init__.py
@@ -17,7 +17,6 @@
 from .binding import BufferUsage, HalBuffer, HalDevice, HalDriver, HalElementType, MemoryAccess, MemoryType, Shape
 # Vm imports
 from .binding import create_hal_module, Linkage, VmVariantList, VmFunction, VmInstance, VmContext, VmModule
-# SystemApi
 from .system_api import *
-# Function
 from .function import *
+from .tracing import *
diff --git a/bindings/python/iree/runtime/function.py b/bindings/python/iree/runtime/function.py
index 9f5b308..57537e7 100644
--- a/bindings/python/iree/runtime/function.py
+++ b/bindings/python/iree/runtime/function.py
@@ -5,12 +5,15 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+from typing import Optional
+
 import json
 import logging
 
 import numpy as np
 
 from .binding import HalDevice, HalElementType, VmContext, VmFunction, VmVariantList
+from . import tracing
 
 __all__ = [
     "FunctionInvoker",
@@ -64,16 +67,19 @@
       "_arg_descs",
       "_ret_descs",
       "_has_kwargs",
+      "_tracer",
   ]
 
   def __init__(self, vm_context: VmContext, device: HalDevice,
-               vm_function: VmFunction):
+               vm_function: VmFunction,
+               tracer: Optional[tracing.ContextTracer]):
     self._vm_context = vm_context
     # TODO: Needing to know the precise device to allocate on here is bad
     # layering and will need to be fixed in some fashion if/when doing
     # heterogenous dispatch.
     self._device = device
     self._vm_function = vm_function
+    self._tracer = tracer
     self._abi_dict = None
     self._arg_descs = None
     self._ret_descs = None
@@ -85,6 +91,10 @@
     return self._vm_function
 
   def __call__(self, *args, **kwargs):
+    call_trace = None  # type: Optional[tracing.CallTrace]
+    if self._tracer:
+      call_trace = self._tracer.start_call(self._vm_function)
+
     # Initialize the capacity to our total number of args, since we should
     # be below that when doing a flat invocation. May want to be more
     # conservative here when considering nesting.
@@ -105,8 +115,14 @@
     arg_list = VmVariantList(len(args))
     ret_list = VmVariantList(len(ret_descs) if ret_descs is not None else 1)
     _merge_python_sequence_to_vm(inv, arg_list, args, self._arg_descs)
+    if call_trace:
+      call_trace.add_vm_list(arg_list, "args")
     self._vm_context.invoke(self._vm_function, arg_list, ret_list)
+    if call_trace:
+      call_trace.add_vm_list(ret_list, "results")
     returns = _extract_vm_sequence_to_python(inv, ret_list, ret_descs)
+    if call_trace:
+      call_trace.end_call()
     return_arity = len(returns)
     if return_arity == 1:
       return returns[0]
diff --git a/bindings/python/iree/runtime/setup.py.in b/bindings/python/iree/runtime/setup.py.in
index fc1f02e..4c84cc2 100644
--- a/bindings/python/iree/runtime/setup.py.in
+++ b/bindings/python/iree/runtime/setup.py.in
@@ -11,12 +11,8 @@
 from setuptools import setup, find_namespace_packages, Extension
 import sysconfig
 
-
-README = """# IREE Python Runtime Components
-
-This package provides an API for running compiled IREE binaries and interfacing
-with the hardware-abstraction-layer.
-"""
+with open(os.path.join(os.path.dirname(__file__), "README.md"), "r") as f:
+  README = f.read()
 
 setup(
     name="iree-runtime@IREE_RELEASE_PACKAGE_SUFFIX@",
@@ -35,7 +31,8 @@
         "Development Status :: 3 - Alpha",
     ],
     python_requires=">=3.6",
-    packages=find_namespace_packages(include=["iree.runtime", "iree.runtime.*"]),
+    packages=find_namespace_packages(
+        include=["iree.runtime", "iree.runtime.*"]),
     ext_modules=[
         Extension(name="iree.runtime.binding", sources=[]),
     ],
diff --git a/bindings/python/iree/runtime/system_api.py b/bindings/python/iree/runtime/system_api.py
index cb7d0b3..3328ac3 100644
--- a/bindings/python/iree/runtime/system_api.py
+++ b/bindings/python/iree/runtime/system_api.py
@@ -28,6 +28,7 @@
     "TARGET_BACKEND_TO_DRIVER",
 ]
 
+import logging
 import os
 import sys
 
@@ -35,6 +36,7 @@
 
 from . import binding as _binding
 from .function import FunctionInvoker
+from . import tracing
 
 import numpy as np
 
@@ -109,14 +111,23 @@
   device: _binding.HalDevice
   vm_instance: _binding.VmInstance
   default_vm_modules: Tuple[_binding.VmModule, ...]
+  tracer: Optional[tracing.Tracer]
 
-  def __init__(self, driver_name: Optional[str] = None):
+  def __init__(self,
+               driver_name: Optional[str] = None,
+               tracer: Optional[tracing.Tracer] = None):
     self.vm_instance = _binding.VmInstance()
     self.driver = _create_default_iree_driver(
         driver_name.split(",") if driver_name is not None else None)
     self.device = self.driver.create_default_device()
     hal_module = _binding.create_hal_module(self.device)
     self.default_vm_modules = (hal_module,)
+    self.tracer = tracer or tracing.get_default_tracer()
+    if self.tracer and self.tracer.enabled:
+      logging.info("IREE runtime tracing calls to path: %s",
+                   self.tracer.trace_path)
+    else:
+      self.tracer = None
 
 
 _global_config = None
@@ -182,9 +193,15 @@
 
   def __init__(self, context: "SystemContext", vm_module: _binding.VmModule):
     self._context = context
+    self._tracer = self._context._config.tracer
     self._vm_module = vm_module
     self._lazy_functions = dict()
 
+    # Let the tracing infra create a traced module.
+    self.traced_module = None
+    if self._tracer:
+      self.traced_module = self._tracer.persist_vm_module(vm_module)
+
   @property
   def name(self):
     return self._vm_module.name
@@ -212,7 +229,8 @@
     # layering and will need to be fixed in some fashion if/when doing
     # heterogenous dispatch.
     return FunctionInvoker(self._context.vm_context,
-                           self._context.config.device, vm_function)
+                           self._context.config.device, vm_function,
+                           self._context._tracer)
 
   def __repr__(self):
     return f"<BoundModule {repr(self._vm_module)}>"
@@ -254,6 +272,13 @@
           (m.name, BoundModule(self, m)) for m in init_vm_modules
       ])
 
+    self._tracer = None  # type: Optional[tracing.ContextTracer]
+    if self._config.tracer:
+      self._tracer = tracing.ContextTracer(
+          self._config.tracer,
+          is_dynamic=self._is_dynamic,
+          modules=[bm.traced_module for bm in self._bound_modules.values()])
+
   @property
   def vm_context(self) -> _binding.VmContext:
     return self._vm_context
@@ -279,7 +304,10 @@
     for m in vm_modules:
       if m.name in self._bound_modules:
         raise ValueError(f"Attempt to register duplicate VmModule: '{m.name}'")
-      self._bound_modules[m.name] = BoundModule(self, m)
+      bound_module = BoundModule(self, m)
+      self._bound_modules[m.name] = bound_module
+      if self._tracer:
+        self._tracer.add_module(bound_module.traced_module)
     self._vm_context.register_modules(vm_modules)
 
   def add_vm_module(self, vm_module):
diff --git a/bindings/python/iree/runtime/system_api_test.py b/bindings/python/iree/runtime/system_api_test.py
index 61b6297..eeeab80 100644
--- a/bindings/python/iree/runtime/system_api_test.py
+++ b/bindings/python/iree/runtime/system_api_test.py
@@ -7,7 +7,9 @@
 
 # pylint: disable=unused-variable
 
+import os
 import re
+import tempfile
 
 from absl import logging
 from absl.testing import absltest
@@ -83,19 +85,34 @@
     results = f(arg0, arg1)
     np.testing.assert_allclose(results, [4., 10., 18., 28.])
 
-  # TODO: Re-implement tracing in a more sustainable fashion.
-  # def test_serialize_values(self):
-  #   ctx = iree.runtime.SystemContext()
-  #   self.assertTrue(ctx.is_dynamic)
-  #   ctx.add_vm_module(create_simple_mul_module())
-  #   self.assertEqual(ctx.modules.arithmetic.name, "arithmetic")
-  #   f = ctx.modules.arithmetic["simple_mul"]
-  #   arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
-  #   arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)
-  #   results = f(arg0, arg1)
-  #   inputs, outputs = f.get_serialized_values()
-  #   self.assertEqual(inputs, ("4xf32=1 2 3 4", "4xf32=4 5 6 7"))
-  #   self.assertEqual(outputs, ("4xf32=4 10 18 28",))
+  def test_tracing_explicit(self):
+    with tempfile.TemporaryDirectory() as temp_dir:
+      tracer = iree.runtime.Tracer(temp_dir)
+      config = iree.runtime.Config("dylib", tracer=tracer)
+      self.verify_tracing(config, temp_dir)
+
+  def test_tracing_from_environment(self):
+    original = os.environ.get(iree.runtime.TRACE_PATH_ENV_KEY)
+    try:
+      with tempfile.TemporaryDirectory() as temp_dir:
+        os.environ[iree.runtime.TRACE_PATH_ENV_KEY] = temp_dir
+        config = iree.runtime.Config("dylib")
+        self.verify_tracing(config, temp_dir)
+    finally:
+      if original:
+        os.environ[iree.runtime.TRACE_PATH_ENV_KEY] = original
+
+  def verify_tracing(self, config, temp_dir):
+    logging.info("Tracing test to: %s", temp_dir)
+    ctx = iree.runtime.SystemContext(config=config)
+    ctx.add_vm_module(create_simple_mul_module())
+    f = ctx.modules.arithmetic["simple_mul"]
+    arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
+    arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)
+    results = f(arg0, arg1)
+    self.assertTrue(os.path.exists(os.path.join(temp_dir, "arithmetic.vmfb")))
+    self.assertTrue(os.path.exists(os.path.join(temp_dir, "calls.yaml")))
+    # TODO: Once replay is possible, verify that.
 
   def test_load_vm_module(self):
     arithmetic = iree.runtime.load_vm_module(create_simple_mul_module())
diff --git a/bindings/python/iree/runtime/tracing.py b/bindings/python/iree/runtime/tracing.py
new file mode 100644
index 0000000..08b9891
--- /dev/null
+++ b/bindings/python/iree/runtime/tracing.py
@@ -0,0 +1,168 @@
+"""Tracing support."""
+
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from genericpath import exists
+from typing import Dict, List, Optional, Sequence
+
+import logging
+import os
+import sys
+
+from . import binding as _binding
+
+try:
+  import yaml
+except ModuleNotFoundError:
+  _has_yaml = False
+else:
+  _has_yaml = True
+
+__all__ = [
+    "get_default_tracer",
+    "Tracer",
+    "TRACE_PATH_ENV_KEY",
+]
+
+TRACE_PATH_ENV_KEY = "IREE_SAVE_CALLS"
+
+
+class Tracer:
+  """Object for tracing calls made into the runtime."""
+
+  def __init__(self, trace_path: str):
+    if not _has_yaml:
+      self.enabled = False
+      logging.warning("PyYAML not installed: tracing will be disabled")
+      return
+    self.enabled = True
+    self.trace_path = trace_path
+    os.makedirs(trace_path, exist_ok=True)
+    self._name_count = dict()  # type: Dict[str, int]
+
+  def persist_vm_module(self, vm_module: _binding.VmModule) -> "TracedModule":
+    # Depending on how the module was created, there are different bits
+    # of information available to reconstruct.
+    name = vm_module.name
+    flatbuffer_blob = vm_module.stashed_flatbuffer_blob
+    if flatbuffer_blob:
+      save_path = os.path.join(self.trace_path,
+                               self.get_unique_name(f"{name}.vmfb"))
+      logging.info("Saving traced vmfb to %s", save_path)
+      with open(save_path, "wb") as f:
+        f.write(flatbuffer_blob)
+      return TracedModule(self, vm_module, save_path)
+
+    # No persistent form, but likely they are built-in modules.
+    return TracedModule(self, vm_module)
+
+  def get_unique_name(self, local_name: str) -> str:
+    if local_name not in self._name_count:
+      self._name_count[local_name] = 1
+      return local_name
+    stem, ext = os.path.splitext(local_name)
+    index = self._name_count[local_name]
+    self._name_count[local_name] += 1
+    unique_name = f"{stem}__{index}{ext}"
+    return unique_name
+
+
+class TracedModule:
+  """Wraps a VmModule with additional information for tracing."""
+
+  def __init__(self,
+               parent: Tracer,
+               vm_module: _binding.VmModule,
+               vmfb_path: Optional[str] = None):
+    self._parent = parent
+    self._vm_module = vm_module
+    self._vmfb_path = vmfb_path
+
+  def serialize(self):
+    module_record = {"name": self._vm_module.name}
+    if self._vmfb_path:
+      module_record["type"] = "bytecode"
+      module_record["vmfb_path"] = os.path.relpath(self._vmfb_path,
+                                                   self._parent.trace_path)
+    else:
+      module_record["type"] = "native"
+
+    return module_record
+
+
+class ContextTracer:
+  """Traces invocations against a context."""
+
+  def __init__(self, parent: Tracer, is_dynamic: bool,
+               modules: Sequence[TracedModule]):
+    self._parent = parent
+    self._modules = list(modules)  # type: List[TracedModule]
+    self._frame_count = 0
+    self._file_path = os.path.join(parent.trace_path,
+                                   parent.get_unique_name("calls.yaml"))
+    if os.path.exists(self._file_path):
+      # Truncate the file.
+      with open(self._file_path, "wt"):
+        pass
+    else:
+      os.makedirs(os.path.dirname(parent.trace_path), exist_ok=True)
+    logging.info("Tracing context events to: %s", self._file_path)
+    self.emit_frame({
+        "type": "init_context",
+        "is_dynamic": is_dynamic,
+        "modules": [m.serialize() for m in self._modules],
+    })
+
+  def add_module(self, module: TracedModule):
+    self._modules.append(module)
+    self.emit_frame({
+        "type": "add_module",
+        "module": module.serialize(),
+    })
+
+  def start_call(self, function: _binding.VmFunction):
+    logging.info("Tracing call to %s.%s", function.module_name, function.name)
+
+    # Start assembling the call record.
+    record = {
+        "type": "call",
+        "module_name": function.module_name,
+        "function_name": function.name,
+    }
+    return CallTrace(self, record)
+
+  def emit_frame(self, frame: dict):
+    self._frame_count += 1
+    with open(self._file_path, "at") as f:
+      if self._frame_count != 1:
+        f.write("---\n")
+      contents = yaml.dump(frame, sort_keys=False)
+      f.write(contents)
+
+
+class CallTrace:
+
+  def __init__(self, parent: ContextTracer, record: dict):
+    self._parent = parent
+    self._record = record
+
+  def add_vm_list(self, vm_list: _binding.VmVariantList, key: str):
+    mapped = []
+    for i in range(len(vm_list)):
+      mapped.append(vm_list.get_serialized_trace_value(i))
+    self._record[key] = mapped
+
+  def end_call(self):
+    self._parent.emit_frame(self._record)
+
+
+def get_default_tracer() -> Optional[Tracer]:
+  """Gets a default call tracer based on environment variables."""
+  default_path = os.getenv(TRACE_PATH_ENV_KEY)
+  if not default_path:
+    return None
+  return Tracer(default_path)
diff --git a/bindings/python/iree/runtime/vm.cc b/bindings/python/iree/runtime/vm.cc
index bfcba19..1379669 100644
--- a/bindings/python/iree/runtime/vm.cc
+++ b/bindings/python/iree/runtime/vm.cc
@@ -120,7 +120,8 @@
 // VmModule
 //------------------------------------------------------------------------------
 
-VmModule VmModule::FromFlatbufferBlob(py::buffer flatbuffer_blob) {
+VmModule VmModule::FromFlatbufferBlob(py::object flatbuffer_blob_object) {
+  auto flatbuffer_blob = py::cast<py::buffer>(flatbuffer_blob_object);
   auto buffer_info = flatbuffer_blob.request();
   iree_vm_module_t* module;
 
@@ -143,7 +144,9 @@
   }
 
   CheckApiStatus(status, "Error creating vm module from flatbuffer");
-  return VmModule::CreateRetained(module);
+  auto py_module = VmModule::CreateRetained(module);
+  py_module.stashed_flatbuffer_blob = flatbuffer_blob_object;
+  return py_module;
 }
 
 std::optional<iree_vm_function_t> VmModule::LookupFunction(
@@ -294,6 +297,109 @@
   throw RaiseValueError("Unsupported VM to Python Type Conversion");
 }
 
+py::object VmVariantList::GetAsSerializedTraceValue(int index) {
+  iree_vm_variant_t v = iree_vm_variant_empty();
+  CheckApiStatus(iree_vm_list_get_variant(raw_ptr(), index, &v),
+                 "Could not access list element");
+  if (iree_vm_type_def_is_value(&v.type)) {
+    // Convert a value type.
+    py::dict record;
+    switch (v.type.value_type) {
+      case IREE_VM_VALUE_TYPE_I8:
+        record["i8"] = py::cast(v.i8);
+        break;
+      case IREE_VM_VALUE_TYPE_I16:
+        record["i16"] = py::cast(v.i16);
+        break;
+      case IREE_VM_VALUE_TYPE_I32:
+        record["i32"] = py::cast(v.i32);
+        break;
+      case IREE_VM_VALUE_TYPE_I64:
+        record["i64"] = py::cast(v.i64);
+        break;
+      case IREE_VM_VALUE_TYPE_F32:
+        record["f32"] = py::cast(v.f32);
+        break;
+      case IREE_VM_VALUE_TYPE_F64:
+        record["f64"] = py::cast(v.f64);
+        break;
+      default:
+        throw RaiseValueError("Unsupported VM value type conversion");
+    }
+    record["type"] = py::cast("value");
+    return std::move(record);
+  } else if (v.type.ref_type == IREE_VM_REF_TYPE_NULL) {
+    py::dict record;
+    record["type"] = "null";
+    return std::move(record);
+  } else if (iree_vm_type_def_is_ref(&v.type)) {
+    // Convert reference type.
+    if (iree_vm_list_isa(v.ref)) {
+      py::dict record;
+      record["type"] = "list";
+      py::list items;
+      iree_vm_list_t* sub_list = NULL;
+      CheckApiStatus(iree_vm_list_check_deref(v.ref, &sub_list),
+                     "Could not deref list (wrong type?)");
+      iree_vm_list_retain(sub_list);
+      VmVariantList sub_list_object(sub_list);
+      for (int i = 0, e = sub_list_object.size(); i < e; ++i) {
+        items.append(sub_list_object.GetAsSerializedTraceValue(i));
+      }
+      record["items"] = std::move(items);
+      return std::move(record);
+    } else if (iree_hal_buffer_view_isa(v.ref)) {
+      py::dict record;
+      record["type"] = "buffer_view";
+      iree_hal_buffer_view_t* buffer_view = iree_hal_buffer_view_deref(v.ref);
+      if (!buffer_view) {
+        throw RaiseValueError(
+            "Could not deref result buffer view (wrong type?)");
+      }
+      iree_hal_buffer_t* raw_buffer = iree_hal_buffer_view_buffer(buffer_view);
+      if (!raw_buffer) {
+        throw RaiseValueError("Could not deref result buffer (wrong type?)");
+      }
+
+      // Extract dims from the buffer view.
+      size_t rank = 0;
+      std::vector<int32_t> dims(6);
+      iree_status_t status = iree_hal_buffer_view_shape(
+          buffer_view, dims.capacity(), dims.data(), &rank);
+      if (iree_status_is_out_of_range(status)) {
+        dims.resize(rank);
+        status = iree_hal_buffer_view_shape(buffer_view, dims.capacity(),
+                                            dims.data(), &rank);
+      }
+      CheckApiStatus(status, "Error extracting shape");
+      dims.resize(rank);
+      record["shape"] = py::cast(std::move(dims));
+
+      // Element type.
+      iree_hal_element_type_t element_type =
+          iree_hal_buffer_view_element_type(buffer_view);
+      // TODO: Would be nice to output as hex.
+      record["element_type"] = element_type;
+
+      // Map memory.
+      iree_device_size_t byte_length = iree_hal_buffer_byte_length(raw_buffer);
+      iree_hal_buffer_mapping_t mapped_memory;
+      CheckApiStatus(iree_hal_buffer_map_range(
+                         raw_buffer, IREE_HAL_MEMORY_ACCESS_READ,
+                         0 /* element_offset */, byte_length, &mapped_memory),
+                     "Could not map memory");
+      record["contents"] =
+          py::bytes(reinterpret_cast<const char*>(mapped_memory.contents.data),
+                    mapped_memory.contents.data_length);
+      iree_hal_buffer_unmap_range(&mapped_memory);
+
+      return std::move(record);
+    }
+  }
+
+  throw RaiseValueError("Unsupported VM to Python Type Conversion");
+}
+
 py::object VmVariantList::GetAsNdarray(int index) {
   iree_vm_variant_t v = iree_vm_variant_empty();
   CheckApiStatus(iree_vm_list_get_variant(raw_ptr(), index, &v),
@@ -492,6 +598,8 @@
       .def("get_as_ndarray", &VmVariantList::GetAsNdarray)
       .def("get_as_list", &VmVariantList::GetAsList)
       .def("get_variant", &VmVariantList::GetVariant)
+      .def("get_serialized_trace_value",
+           &VmVariantList::GetAsSerializedTraceValue)
       .def("push_float", &VmVariantList::PushFloat)
       .def("push_int", &VmVariantList::PushInt)
       .def("push_list", &VmVariantList::PushList)
@@ -501,6 +609,18 @@
   py::class_<iree_vm_function_t>(m, "VmFunction")
       .def_readonly("linkage", &iree_vm_function_t::linkage)
       .def_readonly("ordinal", &iree_vm_function_t::ordinal)
+      .def_property_readonly("name",
+                             [](iree_vm_function_t& self) {
+                               iree_string_view_t name =
+                                   iree_vm_function_name(&self);
+                               return py::str(name.data, name.size);
+                             })
+      .def_property_readonly("module_name",
+                             [](iree_vm_function_t& self) {
+                               iree_string_view_t name =
+                                   iree_vm_module_name(self.module);
+                               return py::str(name.data, name.size);
+                             })
       .def_property_readonly("reflection",
                              [](iree_vm_function_t& self) {
                                return GetFunctionReflectionDict(self);
@@ -534,6 +654,9 @@
       .def_property_readonly("name", &VmModule::name)
       .def("lookup_function", &VmModule::LookupFunction, py::arg("name"),
            py::arg("linkage") = IREE_VM_FUNCTION_LINKAGE_EXPORT)
+      .def_property_readonly(
+          "stashed_flatbuffer_blob",
+          [](VmModule& self) { return self.get_stashed_flatbuffer_blob(); })
       .def("__repr__", [](VmModule& self) {
         std::string repr("<VmModule ");
         iree_string_view_t name = iree_vm_module_name(self.raw_ptr());
diff --git a/bindings/python/iree/runtime/vm.h b/bindings/python/iree/runtime/vm.h
index 33da644..caf69db 100644
--- a/bindings/python/iree/runtime/vm.h
+++ b/bindings/python/iree/runtime/vm.h
@@ -99,6 +99,7 @@
   py::object GetAsList(int index);
   py::object GetAsNdarray(int index);
   py::object GetVariant(int index);
+  py::object GetAsSerializedTraceValue(int index);
 
  private:
   VmVariantList(iree_vm_list_t* list) : list_(list) {}
@@ -116,7 +117,7 @@
 
 class VmModule : public ApiRefCounted<VmModule, iree_vm_module_t> {
  public:
-  static VmModule FromFlatbufferBlob(py::buffer flatbuffer_blob);
+  static VmModule FromFlatbufferBlob(py::object flatbuffer_blob_object);
 
   std::optional<iree_vm_function_t> LookupFunction(
       const std::string& name, iree_vm_function_linkage_t linkage);
@@ -125,6 +126,12 @@
     auto name_sv = iree_vm_module_name(raw_ptr());
     return std::string(name_sv.data, name_sv.size);
   }
+
+  py::object get_stashed_flatbuffer_blob() { return stashed_flatbuffer_blob; }
+
+ private:
+  // If the module was created from a flatbuffer blob, we stash it here.
+  py::object stashed_flatbuffer_blob = py::none();
 };
 
 class VmContext : public ApiRefCounted<VmContext, iree_vm_context_t> {
diff --git a/bindings/python/iree/runtime/vm_test.py b/bindings/python/iree/runtime/vm_test.py
index 81672c5..e29b2e1 100644
--- a/bindings/python/iree/runtime/vm_test.py
+++ b/bindings/python/iree/runtime/vm_test.py
@@ -140,38 +140,32 @@
     logging.info("context: %s", context)
 
   def test_add_scalar_new_abi(self):
-    # TODO: Enable with new ABI.
-    return
     m = create_add_scalar_module()
     instance = iree.runtime.VmInstance()
     context = iree.runtime.VmContext(instance, modules=[self.hal_module, m])
     f = m.lookup_function("add_scalar")
-    finv = iree.runtime.FunctionInvoker(context, self.device, f)
+    finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None)
     result = finv(5, 6)
     logging.info("result: %s", result)
     self.assertEqual(result, 11)
 
   def test_synchronous_dynamic_shape_invoke_function_new_abi(self):
-    # TODO: Enable with new ABI.
-    return
     m = create_simple_dynamic_abs_module()
     instance = iree.runtime.VmInstance()
     context = iree.runtime.VmContext(instance, modules=[self.hal_module, m])
     f = m.lookup_function("simple_mul")
-    finv = iree.runtime.FunctionInvoker(context, self.device, f)
+    finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None)
     arg0 = np.array([[-1., 2.], [3., -4.]], dtype=np.float32)
     result = finv(arg0)
     logging.info("result: %s", result)
     np.testing.assert_allclose(result, [[1., 2.], [3., 4.]])
 
   def test_synchronous_invoke_function_new_abi(self):
-    # TODO: Enable with new ABI.
-    return
     m = create_simple_static_mul_module()
     instance = iree.runtime.VmInstance()
     context = iree.runtime.VmContext(instance, modules=[self.hal_module, m])
     f = m.lookup_function("simple_mul")
-    finv = iree.runtime.FunctionInvoker(context, self.device, f)
+    finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None)
     arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
     arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)
     result = finv(arg0, arg1)