Add low level interface for defining VM modules in Python. (#9964)

* Add low level interface for defining VM modules in Python.

Inspired by #9947, which shows how to extend IREE with custom C modules,
I just had to do the plumbing to also allow extension from Python. Usual
caveats apply about Python being a terrible thing to implement real
things in, but there is no escaping how easy it is -- in my experience,
people will start by defining simple Python modules and then if we line
that up seamlessly with how you do it in C, it creates a good glide
path to progressively enhance.

This is just the start. Specifically, the following are not yet done:

* High level sugar: this is a low level interface that is aligned with
  the VM. It is TBD to add some nice-feeling decorators and such.
* Custom types not yet plugged through, but I expect that this could be
  pretty nicely done and interop well with Python.
* There's likely a nice story for integrating with the compiler side,
  but it needs more thought (i.e. custom ops directly from Python, etc)
  to make automatic. Also involves the usual grind of handling the ABI
  type grid, which I don't have time for today.
diff --git a/runtime/bindings/python/CMakeLists.txt b/runtime/bindings/python/CMakeLists.txt
index b00bab3..888149f 100644
--- a/runtime/bindings/python/CMakeLists.txt
+++ b/runtime/bindings/python/CMakeLists.txt
@@ -34,6 +34,8 @@
     "invoke.cc"
     "hal.h"
     "hal.cc"
+    "py_module.h"
+    "py_module.cc"
     "status_utils.cc"
     "status_utils.h"
     "vm.h"
@@ -143,6 +145,13 @@
 
 iree_py_test(
   NAME
+    py_module_test
+  SRCS
+    "tests/py_module_test.py"
+)
+
+iree_py_test(
+  NAME
     system_api_test
   SRCS
     "tests/system_api_test.py"
@@ -162,6 +171,13 @@
     "tests/vm_test.py"
 )
 
+iree_py_test(
+  NAME
+    vm_types_test
+  SRCS
+    "tests/vm_types_test.py"
+)
+
 # TODO: Enable this once the CI bots are updated to install the python3-venv
 # apt package. https://github.com/iree-org/iree/issues/9080
 # iree_py_test(
diff --git a/runtime/bindings/python/binding.h b/runtime/bindings/python/binding.h
index 64240d4..feb1722 100644
--- a/runtime/bindings/python/binding.h
+++ b/runtime/bindings/python/binding.h
@@ -25,6 +25,7 @@
 template <typename Self, typename T>
 class ApiRefCounted {
  public:
+  using RawPtrType = T*;
   ApiRefCounted() : instance_(nullptr) {}
   ApiRefCounted(ApiRefCounted& other) : instance_(other.instance_) { Retain(); }
   ApiRefCounted(ApiRefCounted&& other) : instance_(other.instance_) {
diff --git a/runtime/bindings/python/initialize_module.cc b/runtime/bindings/python/initialize_module.cc
index db9d242..09e9bec 100644
--- a/runtime/bindings/python/initialize_module.cc
+++ b/runtime/bindings/python/initialize_module.cc
@@ -7,6 +7,7 @@
 #include "./binding.h"
 #include "./hal.h"
 #include "./invoke.h"
+#include "./py_module.h"
 #include "./status_utils.h"
 #include "./vm.h"
 #include "iree/base/internal/flags.h"
@@ -23,6 +24,7 @@
   m.doc() = "IREE Binding Backend Helpers";
   SetupHalBindings(m);
   SetupInvokeBindings(m);
+  SetupPyModuleBindings(m);
   SetupVmBindings(m);
 
   m.def("parse_flags", [](py::args py_flags) {
diff --git a/runtime/bindings/python/iree/runtime/__init__.py b/runtime/bindings/python/iree/runtime/__init__.py
index 660c53f..c7d22c3 100644
--- a/runtime/bindings/python/iree/runtime/__init__.py
+++ b/runtime/bindings/python/iree/runtime/__init__.py
@@ -25,6 +25,7 @@
     HalElementType,
     MemoryAccess,
     MemoryType,
+    PyModuleInterface,
     Shape,
 )
 
diff --git a/runtime/bindings/python/py_module.cc b/runtime/bindings/python/py_module.cc
new file mode 100644
index 0000000..0f3dd83
--- /dev/null
+++ b/runtime/bindings/python/py_module.cc
@@ -0,0 +1,462 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "./py_module.h"
+
+#include <string_view>
+
+#include "./vm.h"
+#include "iree/vm/native_module.h"
+
+namespace iree::python {
+
+// Low level class for constructing a native VM module from Python. This
+// class is mutable while the module is being setup and will typically
+// produce a module instance when ready to be used.
+//
+// This class has a complicated life-cycle and can be in one of several
+// states:
+//   UNINITIALIZED: Prior to calling Create(). Mutable.
+//   INITIALIZED: After calling Create() and prior to the returned reference
+//     being released. Immutable.
+//   DESTROYED: After the reference from Create() is released. Nothing
+//     more can be done with the instance but it is still live until the
+//     Python reference to it is released.
+class PyModuleInterface {
+ public:
+  // Creates a mutable, not-yet-initialized module definition.
+  //   module_name: name handed back to the VM by ModuleName().
+  //   ctor: Python callable that ModuleAllocState() invokes with the
+  //     retained Python reference to this interface; the object it returns
+  //     is used as the module state.
+  // Throws (via CheckApiStatus) if the interface struct cannot be set up.
+  PyModuleInterface(std::string module_name, py::object ctor)
+      : module_name_(std::move(module_name)), ctor_(std::move(ctor)) {
+    iree_vm_module_t *vtable = &interface_;
+    // Must run before the callbacks are assigned (presumably it resets the
+    // struct -- TODO confirm against iree_vm_module_initialize).
+    CheckApiStatus(iree_vm_module_initialize(vtable, this),
+                   "Failed to initialize vm_module");
+    // Install the static callbacks that implement the module.
+    vtable->begin_call = &PyModuleInterface::ModuleBeginCall;
+    vtable->notify = &PyModuleInterface::ModuleNotify;
+    vtable->resolve_import = &PyModuleInterface::ModuleResolveImport;
+    vtable->free_state = &PyModuleInterface::ModuleFreeState;
+    vtable->alloc_state = &PyModuleInterface::ModuleAllocState;
+    vtable->lookup_function = &PyModuleInterface::ModuleLookupFunction;
+    vtable->get_function = &PyModuleInterface::ModuleGetFunction;
+    vtable->signature = &PyModuleInterface::ModuleSignature;
+    vtable->name = &PyModuleInterface::ModuleName;
+    vtable->destroy = &PyModuleInterface::ModuleDestroy;
+  }
+  PyModuleInterface(const PyModuleInterface &) = delete;
+  ~PyModuleInterface() = default;
+
+  // Recovers the typed instance from the opaque pointer that the module
+  // callbacks receive as their first argument.
+  static PyModuleInterface *AsSelf(void *vself) {
+    auto *typed_self = static_cast<PyModuleInterface *>(vself);
+    return typed_self;
+  }
+
+  // destroy callback: drops the Python self-reference taken in Initialize().
+  // The GIL is held because clearing a py::object changes a Python refcount.
+  static void ModuleDestroy(void *vself) {
+    PyModuleInterface *self = AsSelf(vself);
+    py::gil_scoped_acquire acquire;
+    self->retained_self_ref_ = py::object();
+  }
+
+  // name callback: returns a non-owning view of module_name_.
+  static iree_string_view_t ModuleName(void *vself) {
+    const std::string &name = AsSelf(vself)->module_name_;
+    iree_string_view_t view = {name.data(), name.size()};
+    return view;
+  }
+
+  // signature callback: reports the number of exports registered so far.
+  // Import and internal function counts are always zero.
+  static iree_vm_module_signature_t ModuleSignature(void *vself) {
+    iree_vm_module_signature_t sig = {0};
+    sig.internal_function_count = 0;
+    sig.import_function_count = 0;
+    sig.export_function_count = AsSelf(vself)->exports_.size();
+    return sig;
+  }
+
+  // get_function callback: describes the export at `ordinal`.
+  // Only IREE_VM_FUNCTION_LINKAGE_EXPORT is served; any other linkage or an
+  // out-of-range ordinal yields NOT_FOUND. Each out-parameter may be null, in
+  // which case it is skipped.
+  static iree_status_t ModuleGetFunction(
+      void *vself, iree_vm_function_linkage_t linkage, iree_host_size_t ordinal,
+      iree_vm_function_t *out_function, iree_string_view_t *out_name,
+      iree_vm_function_signature_t *out_signature) {
+    auto self = AsSelf(vself);
+    const bool is_known_export =
+        IREE_LIKELY(linkage == IREE_VM_FUNCTION_LINKAGE_EXPORT) &&
+        IREE_LIKELY(ordinal < self->export_functions_.size());
+    if (!is_known_export) {
+      return iree_make_status(IREE_STATUS_NOT_FOUND);
+    }
+
+    PyFunction &fn = *self->export_functions_[ordinal];
+    if (IREE_LIKELY(out_function)) {
+      out_function->linkage = linkage;
+      out_function->module = &self->interface_;
+      out_function->ordinal = ordinal;
+    }
+    if (IREE_LIKELY(out_name)) {
+      *out_name = {fn.name.data(), fn.name.size()};
+    }
+    if (IREE_LIKELY(out_signature)) {
+      out_signature->calling_convention = {fn.cconv.data(), fn.cconv.size()};
+    }
+    return iree_ok_status();
+  }
+
+  // lookup_function callback: resolves an exported function by name.
+  // Fills `out_function` and returns OK on a hit; otherwise (unknown name or
+  // non-export linkage) returns NOT_FOUND naming the requested function.
+  static iree_status_t ModuleLookupFunction(void *vself,
+                                            iree_vm_function_linkage_t linkage,
+                                            iree_string_view_t name,
+                                            iree_vm_function_t *out_function) {
+    auto self = AsSelf(vself);
+    if (linkage == IREE_VM_FUNCTION_LINKAGE_EXPORT) {
+      const auto &ordinals = self->export_name_to_ordinals_;
+      auto entry = ordinals.find(std::string_view(name.data, name.size));
+      if (entry != ordinals.end()) {
+        out_function->ordinal = entry->second;
+        out_function->module = &self->interface_;
+        out_function->linkage = linkage;
+        return iree_ok_status();
+      }
+    }
+    return iree_make_status(IREE_STATUS_NOT_FOUND, "function %.*s not exported",
+                            (int)name.size, name.data);
+  }
+
+  // alloc_state callback: calls the Python `ctor_` with the retained
+  // reference to this interface and hands the resulting object to the VM as
+  // the opaque module state. A C++ exception from the call becomes an UNKNOWN
+  // status and leaves `*out_module_state` null.
+  static iree_status_t ModuleAllocState(
+      void *vself, iree_allocator_t allocator,
+      iree_vm_module_state_t **out_module_state) {
+    auto self = AsSelf(vself);
+    *out_module_state = nullptr;
+    py::gil_scoped_acquire acquire;
+    try {
+      py::object state = self->ctor_(self->retained_self_ref_);
+      // release() hands over the reference without decrementing it; the raw
+      // PyObject* is the state and ModuleFreeState drops the reference.
+      PyObject *raw_state = state.release().ptr();
+      *out_module_state =
+          reinterpret_cast<iree_vm_module_state_t *>(raw_state);
+    } catch (std::exception &e) {
+      return iree_make_status(IREE_STATUS_UNKNOWN,
+                              "Exception in call to PyModule constructor: %s",
+                              e.what());
+    }
+    return iree_ok_status();
+  }
+
+  // free_state callback: drops the Python reference that ModuleAllocState
+  // stored as the module state. Runs with the GIL held.
+  static void ModuleFreeState(void *vself,
+                              iree_vm_module_state_t *module_state) {
+    py::gil_scoped_acquire acquire;
+    PyObject *raw_state = reinterpret_cast<PyObject *>(module_state);
+    py::handle(raw_state).dec_ref();
+  }
+
+  // resolve_import callback: imports are not supported, so this always
+  // fails with UNIMPLEMENTED.
+  static iree_status_t ModuleResolveImport(
+      void *vself, iree_vm_module_state_t *module_state,
+      iree_host_size_t ordinal, const iree_vm_function_t *function,
+      const iree_vm_function_signature_t *signature) {
+    iree_status_t unsupported = iree_make_status(
+        IREE_STATUS_UNIMPLEMENTED, "Python API does not support imports");
+    return unsupported;
+  }
+
+  // notify callback: signals are not handled; always returns UNIMPLEMENTED.
+  static iree_status_t ModuleNotify(void *vself,
+                                    iree_vm_module_state_t *module_state,
+                                    iree_vm_signal_t signal) {
+    iree_status_t unsupported = iree_make_status(
+        IREE_STATUS_UNIMPLEMENTED, "ModuleNotify not implemented");
+    return unsupported;
+  }
+
+  // begin_call callback: invokes the exported function selected by
+  // `call.function.ordinal`.
+  //
+  // Enters a native stack frame, reads the module state (the Python object
+  // produced by ModuleAllocState) from that frame, runs the Python callable
+  // through Invoke() and, on success, leaves the frame again.
+  //
+  // Returns INVALID_ARGUMENT for an out-of-range ordinal, any status produced
+  // by entering/leaving the frame or by Invoke(), and UNKNOWN when Invoke()
+  // throws a C++ exception. On failure the frame is intentionally not popped
+  // here; the error is returned with the frame still on the stack.
+  static iree_status_t ModuleBeginCall(void *vself, iree_vm_stack_t *stack,
+                                       iree_vm_function_call_t call) {
+    auto self = AsSelf(vself);
+    if (IREE_UNLIKELY(call.function.ordinal >=
+                      self->export_functions_.size())) {
+      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                              "function ordinal out of bounds: 0 < %u < %zu",
+                              call.function.ordinal,
+                              self->export_functions_.size());
+    }
+
+    auto &f = self->export_functions_[call.function.ordinal];
+    iree_host_size_t frame_size = 0;
+    iree_vm_stack_frame_t *callee_frame = nullptr;
+    IREE_RETURN_IF_ERROR(iree_vm_stack_function_enter(
+        stack, &call.function, IREE_VM_STACK_FRAME_NATIVE, frame_size,
+        /*frame_cleanup_fn=*/nullptr, &callee_frame));
+    auto state_object =
+        py::handle(reinterpret_cast<PyObject *>(callee_frame->module_state));
+
+    try {
+      IREE_RETURN_IF_ERROR(self->Invoke(*f, state_object, stack, call));
+    } catch (std::exception &e) {
+      return iree_make_status(IREE_STATUS_UNKNOWN,
+                              "Exception raised from Python module: %s",
+                              e.what());
+    }
+
+    // Balance the iree_vm_stack_function_enter() above; without this the
+    // callee frame stays on the VM stack after a successful call.
+    return iree_vm_stack_function_leave(stack);
+  }
+
+  // Human readable description used for Python's __str__. Once initialized
+  // the text also says whether the module is still live (" initialized") or
+  // its self-reference has been dropped (" destroyed").
+  std::string ToString() {
+    std::string repr = "<iree.runtime.PyModuleInterface '";
+    repr += module_name_;
+    repr += "'";
+    if (initialized_) {
+      repr += retained_self_ref_ ? " initialized" : " destroyed";
+    }
+    repr += ">";
+    return repr;
+  }
+
+  // True once Initialize() has run (i.e. after Create()).
+  bool initialized() { return initialized_; }
+
+  // True once the interface was initialized and the retained Python
+  // self-reference has since been cleared (see ModuleDestroy).
+  bool destroyed() {
+    if (!initialized_) return false;
+    return !retained_self_ref_;
+  }
+
+  // Throws std::runtime_error if the interface has already been initialized.
+  void AssertMutable() {
+    if (!initialized_) return;
+    throw std::runtime_error("Attempt to mutate a frozen PyModuleInterface");
+  }
+
+  // Registers `callable` as an exported function.
+  //   name: export name; must be unique within this module.
+  //   cconv: VM calling convention string; must parse and be non-variadic.
+  //   callable: Python callable invoked as callable(state, *arguments).
+  //
+  // Throws std::runtime_error if the interface is already initialized,
+  // std::invalid_argument if `name` is already exported, and whatever
+  // CheckApiStatus raises if `cconv` is rejected by ParseCconv(). A rejected
+  // export leaves the module unchanged.
+  void ExportFunction(std::string name, std::string cconv,
+                      py::object callable) {
+    // Exports may only be added while the definition is still mutable: after
+    // Initialize() the descriptor points into exports_, which a push_back
+    // could reallocate.
+    AssertMutable();
+
+    // Make sure not already defined.
+    if (export_name_to_ordinals_.count(name)) {
+      std::string msg("PyModule function already defined: ");
+      msg.append(name);
+      throw std::invalid_argument(std::move(msg));
+    }
+
+    // Heap allocate the backing PyFunction so we can reference its pointers:
+    // moving the unique_ptr below does not move the strings it owns.
+    auto py_function = std::make_unique<PyFunction>(
+        std::move(name), std::move(cconv), std::move(callable));
+
+    // Validate before touching any container so that a failure cannot leave
+    // exports_, export_functions_ and the name map out of sync.
+    CheckApiStatus(py_function->ParseCconv(), "Unparseable calling convention");
+
+    size_t ordinal = exports_.size();
+    exports_.push_back({});
+    iree_vm_native_export_descriptor_t &d = exports_.back();
+    d.local_name = {py_function->name.data(), py_function->name.size()};
+    d.calling_convention = {py_function->cconv.data(),
+                            py_function->cconv.size()};
+    d.attr_count = 0;
+    d.attrs = nullptr;
+
+    // Transfer the PyFunction to its vector now that we are done touching it.
+    // The map key is a view of the name owned by that PyFunction.
+    std::string_view name_key(py_function->name);
+    export_functions_.push_back(std::move(py_function));
+    export_name_to_ordinals_.insert(std::make_pair(name_key, ordinal));
+  }
+
+  // Freezes the definition: marks the interface initialized, fills in
+  // descriptor_ from the accumulated vectors and takes a Python reference to
+  // this object (retained_self_ref_), which ModuleDestroy later clears.
+  // Throws (via AssertMutable) if the interface was already initialized.
+  void Initialize() {
+    AssertMutable();
+    initialized_ = true;
+
+    // Empty vectors are published as nullptr rather than data().
+    auto data_or_null = [](auto &vec) {
+      return vec.empty() ? nullptr : vec.data();
+    };
+
+    memset(&descriptor_, 0, sizeof(descriptor_));
+    descriptor_.module_name = {module_name_.data(), module_name_.size()};
+    descriptor_.module_attrs = data_or_null(attrs_);
+    descriptor_.module_attr_count = attrs_.size();
+    descriptor_.imports = data_or_null(imports_);
+    descriptor_.import_count = imports_.size();
+    descriptor_.exports = data_or_null(exports_);
+    descriptor_.export_count = exports_.size();
+    descriptor_.functions = data_or_null(functions_);
+    descriptor_.function_count = functions_.size();
+
+    retained_self_ref_ = py::cast(this);
+  }
+
+  // Initializes this interface and wraps it in a VmModule. Because
+  // Initialize() rejects a second call, this can only be called once.
+  VmModule Create() {
+    Initialize();
+    iree_vm_module_t *raw_module = &interface_;
+    return VmModule::StealFromRawPtr(raw_module);
+  }
+
+ private:
+  // Book-keeping for one exported function: its name, its calling convention
+  // string and the Python callable that implements it.
+  struct PyFunction {
+    std::string name;
+    std::string cconv;
+    py::object callable;
+
+    // Filled in by ParseCconv(); presumably views into `cconv` -- TODO
+    // confirm against iree_vm_function_call_get_cconv_fragments.
+    iree_string_view_t cconv_arguments;
+    iree_string_view_t cconv_results;
+
+    PyFunction(std::string name, std::string cconv, py::object callable)
+        : name(std::move(name)),
+          cconv(std::move(cconv)),
+          callable(std::move(callable)) {}
+
+    // Splits `cconv` into its argument and result fragments. Returns the
+    // error from the split, or INVALID_ARGUMENT if either fragment is
+    // variadic.
+    iree_status_t ParseCconv() {
+      iree_vm_function_signature_t sig;
+      memset(&sig, 0, sizeof(sig));
+      sig.calling_convention = {cconv.data(), cconv.size()};
+      IREE_RETURN_IF_ERROR(iree_vm_function_call_get_cconv_fragments(
+          &sig, &cconv_arguments, &cconv_results));
+
+      const bool any_variadic =
+          iree_vm_function_call_is_variadic_cconv(cconv_arguments) ||
+          iree_vm_function_call_is_variadic_cconv(cconv_results);
+      if (!any_variadic) {
+        return iree_ok_status();
+      }
+      return iree_make_status(
+          IREE_STATUS_INVALID_ARGUMENT,
+          "PyModules do not yet support variadic arguments/results");
+    }
+  };
+
+  // Calls the Python callable behind `f` with the GIL held.
+  //
+  // Arguments are decoded from the packed `call.arguments` buffer according
+  // to `f.cconv_arguments` and passed positionally after `state_object`.
+  // Results are encoded into `call.results` according to `f.cconv_results`:
+  // nothing is written when there are no results, a single result is taken
+  // from the return value as-is, and multiple results are read by index from
+  // the returned sequence.
+  //
+  // Returns a non-OK status when the packed argument size does not match the
+  // calling convention, when a cconv type is unsupported, or when None is
+  // returned where a ref result is expected. Python errors and failed casts
+  // propagate as C++ exceptions for the caller to translate.
+  iree_status_t Invoke(PyFunction &f, py::handle state_object,
+                       iree_vm_stack_t *stack, iree_vm_function_call_t call) {
+    py::gil_scoped_acquire acquire;
+    uint8_t *packed_arguments = call.arguments.data;
+    iree_host_size_t packed_arguments_required_size;
+    // TODO: Is this validation needed or do we assume it from up-stack?
+    IREE_RETURN_IF_ERROR(iree_vm_function_call_compute_cconv_fragment_size(
+        f.cconv_arguments, /*segment_size_list=*/nullptr,
+        &packed_arguments_required_size));
+    if (IREE_UNLIKELY(packed_arguments_required_size !=
+                      call.arguments.data_length)) {
+      return iree_make_status(
+          IREE_STATUS_INVALID_ARGUMENT,
+          "mismatched packed argument size: actual=%zu, required=%zu",
+          call.arguments.data_length, packed_arguments_required_size);
+    }
+
+    // Unpack arguments. The cursor advances by the size of each decoded
+    // value.
+    py::list arguments;
+    for (iree_host_size_t i = 0; i < f.cconv_arguments.size; ++i) {
+      switch (f.cconv_arguments.data[i]) {
+        case IREE_VM_CCONV_TYPE_VOID:
+          break;
+        case IREE_VM_CCONV_TYPE_I32:
+          arguments.append(
+              py::cast(*reinterpret_cast<int32_t *>(packed_arguments)));
+          packed_arguments += sizeof(int32_t);
+          break;
+        case IREE_VM_CCONV_TYPE_F32:
+          arguments.append(
+              py::cast(*reinterpret_cast<float *>(packed_arguments)));
+          packed_arguments += sizeof(float);
+          break;
+        case IREE_VM_CCONV_TYPE_I64:
+          arguments.append(
+              py::cast(*reinterpret_cast<int64_t *>(packed_arguments)));
+          packed_arguments += sizeof(int64_t);
+          break;
+        case IREE_VM_CCONV_TYPE_F64:
+          arguments.append(
+              py::cast(*reinterpret_cast<double *>(packed_arguments)));
+          packed_arguments += sizeof(double);
+          break;
+        case IREE_VM_CCONV_TYPE_REF: {
+          iree_vm_ref_t ref =
+              *reinterpret_cast<iree_vm_ref_t *>(packed_arguments);
+          // Since the Python level VmRef can escape, it needs its own ref
+          // count.
+          VmRef py_ref;
+          iree_vm_ref_retain(&ref, &py_ref.ref());
+          arguments.append(py::cast(py_ref, py::return_value_policy::move));
+          packed_arguments += sizeof(iree_vm_ref_t);
+          break;
+        }
+        // TODO: Variadic segments.
+        default:
+          return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                                  "unsupported cconv type %c",
+                                  f.cconv_arguments.data[i]);
+      }
+    }
+
+    auto results = f.callable(state_object, *arguments);
+
+    // Pack results.
+    if (f.cconv_results.size == 0) {
+      return iree_ok_status();
+    }
+    uint8_t *packed_results = call.results.data;
+    bool unary_result = f.cconv_results.size == 1;
+    // Encodes one Python value at the current result cursor and advances it.
+    auto pack_result = [&](py::object &value,
+                           char cconv_type) -> iree_status_t {
+      switch (cconv_type) {
+        case IREE_VM_CCONV_TYPE_VOID:
+          break;
+        case IREE_VM_CCONV_TYPE_I32:
+          *reinterpret_cast<int32_t *>(packed_results) =
+              py::cast<int32_t>(value);
+          packed_results += sizeof(int32_t);
+          break;
+        case IREE_VM_CCONV_TYPE_F32:
+          *reinterpret_cast<float *>(packed_results) = py::cast<float>(value);
+          packed_results += sizeof(float);
+          break;
+        case IREE_VM_CCONV_TYPE_I64:
+          *reinterpret_cast<int64_t *>(packed_results) =
+              py::cast<int64_t>(value);
+          packed_results += sizeof(int64_t);
+          break;
+        case IREE_VM_CCONV_TYPE_F64:
+          *reinterpret_cast<double *>(packed_results) = py::cast<double>(value);
+          packed_results += sizeof(double);
+          break;
+        case IREE_VM_CCONV_TYPE_REF: {
+          iree_vm_ref_t *result_ref =
+              reinterpret_cast<iree_vm_ref_t *>(packed_results);
+          // Casting None to VmRef* yields a null pointer; reject it instead
+          // of dereferencing it below.
+          if (value.is_none()) {
+            return iree_make_status(
+                IREE_STATUS_FAILED_PRECONDITION,
+                "expected ref returned from Python function but got None");
+          }
+          VmRef *py_ref = py::cast<VmRef *>(value);
+          iree_vm_ref_retain(&py_ref->ref(), result_ref);
+          packed_results += sizeof(iree_vm_ref_t);
+          break;
+        }
+        // TODO: Variadic segments.
+        default:
+          return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                                  "unsupported cconv type %c", cconv_type);
+      }
+      return iree_ok_status();
+    };
+
+    if (unary_result) {
+      // A single result is the return value itself, not a 1-element sequence.
+      return pack_result(results, f.cconv_results.data[0]);
+    } else {
+      py::sequence results_seq = py::cast<py::sequence>(results);
+      int result_index = 0;
+      for (iree_host_size_t i = 0; i < f.cconv_results.size; ++i) {
+        py::object next_result = results_seq[result_index++];
+        IREE_RETURN_IF_ERROR(pack_result(next_result, f.cconv_results.data[i]));
+      }
+      return iree_ok_status();
+    }
+  }
+
+  // Descriptor state is built up when mutable and then will be populated
+  // on the descriptor when frozen.
+  std::string module_name_;
+  py::object ctor_;
+  std::vector<iree_string_pair_t> attrs_;
+  std::vector<iree_vm_native_import_descriptor_t> imports_;
+  std::vector<iree_vm_native_export_descriptor_t> exports_;
+  std::vector<std::unique_ptr<PyFunction>> export_functions_;
+  std::vector<iree_vm_native_function_ptr_t> functions_;
+
+  // Map of names to ordinals.
+  std::unordered_map<std::string_view, int> export_name_to_ordinals_;
+
+  // Once the builder is frozen, the descriptor will be valid.
+  iree_vm_module_t interface_;
+  iree_vm_native_module_descriptor_t descriptor_;
+
+  // Read-only and descriptor populated when frozen.
+  bool initialized_ = false;
+  py::object retained_self_ref_;
+};
+
+// Registers the PyModuleInterface class and its Python-visible methods and
+// properties on the extension module `m`.
+void SetupPyModuleBindings(py::module &m) {
+  py::class_<PyModuleInterface> cls(m, "PyModuleInterface");
+  cls.def(py::init<std::string, py::object>(), py::arg("module_name"),
+          py::arg("ctor"));
+  cls.def("__str__", &PyModuleInterface::ToString);
+  cls.def_property_readonly("initialized", &PyModuleInterface::initialized);
+  cls.def_property_readonly("destroyed", &PyModuleInterface::destroyed);
+  cls.def("create", &PyModuleInterface::Create);
+  cls.def("export", &PyModuleInterface::ExportFunction, py::arg("name"),
+          py::arg("cconv"), py::arg("callable"));
+}
+
+}  // namespace iree::python
\ No newline at end of file
diff --git a/runtime/bindings/python/py_module.h b/runtime/bindings/python/py_module.h
new file mode 100644
index 0000000..ee04d7e
--- /dev/null
+++ b/runtime/bindings/python/py_module.h
@@ -0,0 +1,20 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_BINDINGS_PYTHON_IREE_RT_PY_MODULE_H_
+#define IREE_BINDINGS_PYTHON_IREE_RT_PY_MODULE_H_
+
+#include <vector>
+
+#include "./binding.h"
+
+namespace iree::python {
+
+void SetupPyModuleBindings(py::module &m);
+
+}  // namespace iree::python
+
+#endif  // IREE_BINDINGS_PYTHON_IREE_RT_PY_MODULE_H_
\ No newline at end of file
diff --git a/runtime/bindings/python/tests/py_module_test.py b/runtime/bindings/python/tests/py_module_test.py
new file mode 100644
index 0000000..831b510
--- /dev/null
+++ b/runtime/bindings/python/tests/py_module_test.py
@@ -0,0 +1,304 @@
+# Copyright 2022 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import gc
+import unittest
+
+import iree.runtime as rt
+
+NONE_CTOR = lambda iface: None
+
+
+class PyModuleInterfaceTest(unittest.TestCase):
+
+  # Creates a fresh VmInstance for every test; the tests build their
+  # VmContexts on top of it.
+  def setUp(self):
+    self._instance = rt.VmInstance()
+
+  # Walks an export-less interface through its states: not initialized before
+  # create(), initialized afterwards, and destroyed once the only reference
+  # to the created module has been dropped and collected.
+  def testEmptyModuleLifecycle(self):
+    iface = rt.PyModuleInterface("test1", NONE_CTOR)
+    print(iface)
+    self.assertFalse(iface.initialized)
+    m = iface.create()
+    print(iface)
+    self.assertTrue(iface.initialized)
+    print(m)
+    # Drop the module reference so the interface can be destroyed.
+    m = None
+    gc.collect()
+    print(iface)
+    self.assertTrue(iface.destroyed)
+
+  # Loads an export-less module into a VmContext and checks that the
+  # interface is destroyed after both the context and the module are released.
+  def testEmptyModuleInstance(self):
+    iface = rt.PyModuleInterface("test1", NONE_CTOR)
+    m = iface.create()
+    context = rt.VmContext(self._instance, modules=(m,))
+    self.assertTrue(iface.initialized)
+    print(context)
+
+    # Make sure no circular refs and that everything frees.
+    context = None
+    m = None
+    gc.collect()
+    self.assertTrue(iface.destroyed)
+
+  # Loads the same module into two contexts and checks that the state
+  # constructor ran once per context (two calls in total).
+  def testMultiModuleInstance(self):
+    calls = []
+
+    # Records every invocation of the state constructor.
+    def ctor(iface):
+      calls.append(iface)
+      return None
+
+    iface = rt.PyModuleInterface("test1", ctor)
+    m = iface.create()
+    context1 = rt.VmContext(self._instance, modules=(m,))
+    self.assertTrue(iface.initialized)
+    context2 = rt.VmContext(self._instance, modules=(m,))
+    self.assertTrue(iface.initialized)
+    self.assertEqual(2, len(calls))
+
+    # Make sure no circular refs and that everything frees.
+    # `calls` holds references to the interface, so it is released as well.
+    calls = None
+    context1 = None
+    m = None
+    context2 = None
+    gc.collect()
+    self.assertTrue(iface.destroyed)
+
+  # Exports a function taking and returning nothing ("0v") and invokes it
+  # twice, checking that the per-context state object (a Methods instance)
+  # keeps its counter between calls.
+  def testVoidFunctionExport(self):
+    messages = []
+
+    class Methods:
+
+      def __init__(self, iface):
+        self.iface = iface
+        self.counter = 0
+
+      def say_hello(self):
+        messages.append(f"Hello! Your number is {self.counter}")
+        print(messages[-1])
+        self.counter += 1
+
+    # The class itself is the state constructor; the unbound method receives
+    # the constructed state object as `self`.
+    iface = rt.PyModuleInterface("test1", Methods)
+    iface.export("say_hello", "0v", Methods.say_hello)
+    m = iface.create()
+    context = rt.VmContext(self._instance, modules=(m,))
+    f = m.lookup_function("say_hello")
+    self.assertIsNotNone(f)
+    args = rt.VmVariantList(0)
+    results = rt.VmVariantList(0)
+
+    # Invoke twice - should produce two messages.
+    context.invoke(f, args, results)
+    context.invoke(f, args, results)
+    self.assertListEqual(messages, [
+        "Hello! Your number is 0",
+        "Hello! Your number is 1",
+    ])
+
+    # Make sure no circular refs and that everything frees.
+    context = None
+    m = None
+    gc.collect()
+    self.assertTrue(iface.destroyed)
+
+  # Checks that an exception raised inside an exported Python function
+  # surfaces from context.invoke() as a RuntimeError carrying the original
+  # exception text.
+  def testPythonException(self):
+
+    class Methods:
+
+      def __init__(self, iface):
+        pass
+
+      def do_it(self):
+        raise ValueError("This is from Python")
+
+    iface = rt.PyModuleInterface("test1", Methods)
+    iface.export("do_it", "0v", Methods.do_it)
+    m = iface.create()
+    context = rt.VmContext(self._instance, modules=(m,))
+    f = m.lookup_function("do_it")
+    self.assertIsNotNone(f)
+    args = rt.VmVariantList(0)
+    results = rt.VmVariantList(0)
+
+    # We are testing here that the Python level exception is caught and
+    # translated to an IREE status (surfacing as a RuntimeError) vs percolating
+    # through the C call stack.
+    with self.assertRaisesRegex(RuntimeError,
+                                "ValueError: This is from Python"):
+      context.invoke(f, args, results)
+
+    # Make sure no circular refs and that everything frees.
+    context = None
+    m = None
+    gc.collect()
+    self.assertTrue(iface.destroyed)
+
+  # Passes pairs of primitive arguments through exports declared with the
+  # i32 ("ii"), i64 ("II"), f32 ("ff") and f64 ("FF") calling conventions and
+  # records what the Python side receives.
+  def testPrimitiveArguments(self):
+    values = []
+
+    class Methods:
+
+      def __init__(self, iface):
+        pass
+
+      def do_it(self, a, b):
+        values.append((a, b))
+
+    iface = rt.PyModuleInterface("test1", Methods)
+    iface.export("do_it_i32", "0ii", Methods.do_it)
+    iface.export("do_it_i64", "0II", Methods.do_it)
+    iface.export("do_it_f32", "0ff", Methods.do_it)
+    iface.export("do_it_f64", "0FF", Methods.do_it)
+    m = iface.create()
+    context = rt.VmContext(self._instance, modules=(m,))
+
+    args = rt.VmVariantList(2)
+    results = rt.VmVariantList(0)
+    args.push_int(42)
+    args.push_int(43)
+    context.invoke(m.lookup_function("do_it_i32"), args, results)
+    context.invoke(m.lookup_function("do_it_i64"), args, results)
+
+    args = rt.VmVariantList(2)
+    args.push_float(2.0)
+    args.push_float(4.0)
+    # TODO: Python doesn't have 32bit floats, so we are populating f64 args.
+    # These are coming back as zeros, and I expected something to be
+    # doing a conversion? The same is being done with i64 above but is
+    # working there.
+    context.invoke(m.lookup_function("do_it_f32"), args, results)
+    context.invoke(m.lookup_function("do_it_f64"), args, results)
+
+    print(values)
+    # NOTE: the (0.0, 0.0) entry pins the f32 behavior described in the TODO
+    # above rather than the values that were pushed.
+    self.assertEqual(repr(values),
+                     "[(42, 43), (42, 43), (0.0, 0.0), (2.0, 4.0)]")
+
+    # Make sure no circular refs and that everything frees.
+    context = None
+    m = None
+    gc.collect()
+    self.assertTrue(iface.destroyed)
+
+  # Returns primitive results through exports declared with two i32, i64,
+  # f32 or f64 results, plus a single-i32 export, and checks the contents of
+  # the result list after each call.
+  def testPrimitiveResults(self):
+    # Value handed back by do_it(); reassigned before each invocation.
+    next_results = None
+
+    class Methods:
+
+      def __init__(self, iface):
+        pass
+
+      def do_it(self):
+        return next_results
+
+    iface = rt.PyModuleInterface("test1", Methods)
+    iface.export("do_it_i32", "0v_ii", Methods.do_it)
+    iface.export("do_it_i64", "0v_II", Methods.do_it)
+    iface.export("do_it_f32", "0v_ff", Methods.do_it)
+    iface.export("do_it_f64", "0v_FF", Methods.do_it)
+    iface.export("do_it_unary_i32", "0v_i", Methods.do_it)
+    m = iface.create()
+    context = rt.VmContext(self._instance, modules=(m,))
+
+    args = rt.VmVariantList(0)
+
+    # i32
+    results = rt.VmVariantList(2)
+    next_results = (42, 43)
+    context.invoke(m.lookup_function("do_it_i32"), args, results)
+    self.assertEqual(repr(results), "<VmVariantList(2): [42, 43]>")
+
+    # i64
+    results = rt.VmVariantList(2)
+    next_results = (42, 43)
+    context.invoke(m.lookup_function("do_it_i64"), args, results)
+    self.assertEqual(repr(results), "<VmVariantList(2): [42, 43]>")
+
+    # f32
+    results = rt.VmVariantList(2)
+    next_results = (2.0, 4.0)
+    context.invoke(m.lookup_function("do_it_f32"), args, results)
+    self.assertEqual(repr(results), "<VmVariantList(2): [2.000000, 4.000000]>")
+
+    # f64
+    results = rt.VmVariantList(2)
+    next_results = (2.0, 4.0)
+    context.invoke(m.lookup_function("do_it_f64"), args, results)
+    self.assertEqual(repr(results), "<VmVariantList(2): [2.000000, 4.000000]>")
+
+    # Unary special case.
+    # A single result is returned as a bare value: (42) is the int 42, not a
+    # tuple.
+    results = rt.VmVariantList(1)
+    next_results = (42)
+    context.invoke(m.lookup_function("do_it_unary_i32"), args, results)
+    self.assertEqual(repr(results), "<VmVariantList(1): [42]>")
+
+    # Make sure no circular refs and that everything frees.
+    context = None
+    m = None
+    gc.collect()
+    self.assertTrue(iface.destroyed)
+
+  # Passes two reference arguments ("rr") and checks that the Python side can
+  # deref each received ref back into a VmVariantList with the same contents.
+  def testRefArguments(self):
+    values = []
+
+    class Methods:
+
+      def __init__(self, iface):
+        pass
+
+      def do_it(self, a, b):
+        values.append((a.deref(rt.VmVariantList), b.deref(rt.VmVariantList)))
+
+    iface = rt.PyModuleInterface("test1", Methods)
+    iface.export("do_it", "0rr", Methods.do_it)
+    m = iface.create()
+    context = rt.VmContext(self._instance, modules=(m,))
+
+    # These lists just happen to be reference objects we know how to
+    # create.
+    arg0 = rt.VmVariantList(1)
+    arg0.push_int(42)
+    arg1 = rt.VmVariantList(1)
+    arg1.push_int(84)
+
+    args = rt.VmVariantList(2)
+    args.push_list(arg0)
+    args.push_list(arg1)
+    results = rt.VmVariantList(2)
+    context.invoke(m.lookup_function("do_it"), args, results)
+    print("REF VALUES:", values)
+    self.assertEqual(repr(values),
+                     "[(<VmVariantList(1): [42]>, <VmVariantList(1): [84]>)]")
+
+  # Returns two references ("v_rr") from Python and checks that they arrive
+  # in the result list as the lists that were created.
+  def testRefResults(self):
+
+    class Methods:
+
+      def __init__(self, iface):
+        pass
+
+      def do_it(self):
+        # These lists just happen to be reference objects we know how to
+        # create.
+        r0 = rt.VmVariantList(1)
+        r0.push_int(42)
+        r1 = rt.VmVariantList(1)
+        r1.push_int(84)
+        return r0.ref, r1.ref
+
+    iface = rt.PyModuleInterface("test1", Methods)
+    iface.export("do_it", "0v_rr", Methods.do_it)
+    m = iface.create()
+    context = rt.VmContext(self._instance, modules=(m,))
+
+    args = rt.VmVariantList(0)
+    results = rt.VmVariantList(2)
+    context.invoke(m.lookup_function("do_it"), args, results)
+    print("REF RESULTS:", results)
+    self.assertEqual(repr(results), "<VmVariantList(2): [List[42], List[84]]>")
+
+
+if __name__ == "__main__":
+  unittest.main()
diff --git a/runtime/bindings/python/tests/vm_types_test.py b/runtime/bindings/python/tests/vm_types_test.py
new file mode 100644
index 0000000..ada9b5d
--- /dev/null
+++ b/runtime/bindings/python/tests/vm_types_test.py
@@ -0,0 +1,33 @@
+# Copyright 2022 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import unittest
+
+import iree.runtime as rt
+
+
+class VmTypesTest(unittest.TestCase):
+
+  def testRefProtocol(self):
+    lst1 = rt.VmVariantList(0)
+    ref = lst1.__iree_vm_ref__
+    ref2 = lst1.ref
+    print(ref)
+    print(ref2)
+    self.assertEqual(ref, ref2)
+    self.assertNotEqual(ref, False)
+    lst2 = rt.VmVariantList.__iree_vm_cast__(ref)
+    print(lst2)
+    lst3 = ref.deref(rt.VmVariantList)
+    print(lst3)
+    self.assertEqual(lst1, lst2)
+    self.assertEqual(lst2, lst3)
+    self.assertNotEqual(lst1, False)
+    self.assertTrue(ref.isinstance(rt.VmVariantList))
+
+
+if __name__ == "__main__":
+  unittest.main()  # Run this module's tests when executed directly.
diff --git a/runtime/bindings/python/vm.cc b/runtime/bindings/python/vm.cc
index 9826530..5ffeede 100644
--- a/runtime/bindings/python/vm.cc
+++ b/runtime/bindings/python/vm.cc
@@ -175,6 +175,37 @@
 }
 
 //------------------------------------------------------------------------------
+// VmRef
+//------------------------------------------------------------------------------
+
+const char* const VmRef::kRefAttr = "__iree_vm_ref__";  // Property name.
+const char* const VmRef::kCastAttr = "__iree_vm_cast__";  // Static method.
+const char* const VmRef::kTypeIdAttr = "__iree_vm_type_id__";  // Static method.
+
+py::object VmRef::Deref(py::object ref_object_class) {
+  return ref_object_class.attr(kCastAttr)(*this);  // cls.__iree_vm_cast__(ref)
+}
+
+bool VmRef::IsInstance(py::object ref_object_class) {
+  // Ask the class for its type id and compare it with the one on this ref.
+  py::object py_type_id = ref_object_class.attr(kTypeIdAttr)();
+  return ref_.type == py::cast<iree_vm_ref_type_t>(py_type_id);
+}
+
+std::string VmRef::ToString() {
+  // A null ref has no type name or address to report.
+  if (ref_.ptr == nullptr) return "<VmRef NULL>";
+  const iree_string_view_t name = iree_vm_ref_type_name(ref_.type);
+  const auto address = reinterpret_cast<uintptr_t>(ref_.ptr);
+  std::stringstream out;
+  out << "<VmRef ";
+  // The name is a (data, size) view, so write it with an explicit length.
+  out.write(name.data, name.size);
+  out << " at " << std::hex << "0x" << address << ">";
+  return out.str();
+}
+
+//------------------------------------------------------------------------------
 // VmVariantList
 //------------------------------------------------------------------------------
 
@@ -213,7 +244,7 @@
   CheckApiStatus(iree_vm_list_check_deref(ref, &sub_list),
                  "Could not deref list (wrong type?)");
   iree_vm_list_retain(sub_list);
-  return py::cast(VmVariantList(sub_list));
+  return py::cast(VmVariantList::StealFromRawPtr(sub_list));
 }
 
 py::object VmVariantList::GetVariant(int index) {
@@ -297,7 +328,7 @@
       CheckApiStatus(iree_vm_list_check_deref(v.ref, &sub_list),
                      "Could not deref list (wrong type?)");
       iree_vm_list_retain(sub_list);
-      VmVariantList sub_list_object(sub_list);
+      VmVariantList sub_list_object = VmVariantList::StealFromRawPtr(sub_list);
       for (int i = 0, e = sub_list_object.size(); i < e; ++i) {
         items.append(sub_list_object.GetAsSerializedTraceValue(i));
       }
@@ -496,7 +527,11 @@
       .export_values();
 
   // Mutation and inspection of the variant list is mostly opaque to python.
-  py::class_<VmVariantList>(m, "VmVariantList")
+  auto vm_list = py::class_<VmVariantList>(m, "VmVariantList");
+  VmRef::BindRefProtocol(vm_list, iree_vm_list_type_id, iree_vm_list_retain_ref,
+                         iree_vm_list_check_deref);
+  vm_list
+      // User Methods.
       .def(py::init(&VmVariantList::Create))
       .def_property_readonly("size", &VmVariantList::size)
       .def("__len__", &VmVariantList::size)
@@ -612,6 +647,16 @@
         repr.append(">");
         return repr;
       });
+
+  py::class_<VmRef>(m, "VmRef")
+      .def("isinstance", &VmRef::IsInstance)
+      .def("deref", &VmRef::Deref)
+      .def("__repr__", &VmRef::ToString)
+      .def("__eq__",
+           [](VmRef& self, VmRef& other) {
+             return self.ref().ptr == other.ref().ptr;
+           })
+      .def("__eq__", [](VmRef& self, py::object& other) { return false; });
 }
 
 }  // namespace python
diff --git a/runtime/bindings/python/vm.h b/runtime/bindings/python/vm.h
index e3632dd..135e98b 100644
--- a/runtime/bindings/python/vm.h
+++ b/runtime/bindings/python/vm.h
@@ -14,6 +14,7 @@
 #include "iree/base/api.h"
 #include "iree/vm/api.h"
 #include "iree/vm/bytecode_module.h"
+#include "iree/vm/ref.h"
 
 namespace iree {
 namespace python {
@@ -37,6 +38,12 @@
 };
 
 template <>
+struct ApiPtrAdapter<iree_vm_list_t> {  // Ref-count hooks for VM lists.
+  static void Retain(iree_vm_list_t* list) { iree_vm_list_retain(list); }
+  static void Release(iree_vm_list_t* list) { iree_vm_list_release(list); }
+};
+
+template <>
 struct ApiPtrAdapter<iree_vm_module_t> {
   static void Retain(iree_vm_module_t* b) { iree_vm_module_retain(b); }
   static void Release(iree_vm_module_t* b) { iree_vm_module_release(b); }
@@ -50,44 +57,33 @@
   }
 };
 
+template <>
+struct ApiPtrAdapter<iree_vm_ref_t> {
+  static void Retain(iree_vm_ref_t* ref) {
+    // `scratch` is never released, so the retain presumably stays in effect.
+    iree_vm_ref_t scratch = {0};
+    iree_vm_ref_retain(ref, &scratch);
+  }
+  static void Release(iree_vm_ref_t* ref) { iree_vm_ref_release(ref); }
+};
+
 //------------------------------------------------------------------------------
 // VmVariantList
+// TODO: Rename to VmList
 //------------------------------------------------------------------------------
 
-class VmVariantList {
+class VmVariantList : public ApiRefCounted<VmVariantList, iree_vm_list_t> {
  public:
-  VmVariantList() : list_(nullptr) {}
-  ~VmVariantList() {
-    if (list_) {
-      iree_vm_list_release(list_);
-    }
-  }
-
-  VmVariantList(VmVariantList&& other) {
-    list_ = other.list_;
-    other.list_ = nullptr;
-  }
-
-  VmVariantList& operator=(const VmVariantList&) = delete;
-  VmVariantList(const VmVariantList&) = delete;
-
   static VmVariantList Create(iree_host_size_t capacity) {
     iree_vm_list_t* list;
     CheckApiStatus(iree_vm_list_create(/*element_type=*/nullptr, capacity,
                                        iree_allocator_system(), &list),
                    "Error allocating variant list");
-    return VmVariantList(list);
+    return VmVariantList::StealFromRawPtr(list);
   }
 
-  iree_host_size_t size() const { return iree_vm_list_size(list_); }
+  iree_host_size_t size() const { return iree_vm_list_size(raw_ptr()); }
 
-  iree_vm_list_t* raw_ptr() { return list_; }
-  const iree_vm_list_t* raw_ptr() const { return list_; }
-  iree_vm_list_t* steal_raw_ptr() {
-    iree_vm_list_t* stolen = list_;
-    list_ = nullptr;
-    return stolen;
-  }
   void AppendNullRef() {
     iree_vm_ref_t null_ref = {0};
     CheckApiStatus(iree_vm_list_push_ref_move(raw_ptr(), &null_ref),
@@ -103,14 +99,10 @@
   py::object GetAsBufferView(int index);
   py::object GetVariant(int index);
   py::object GetAsSerializedTraceValue(int index);
-
- private:
-  VmVariantList(iree_vm_list_t* list) : list_(list) {}
-  iree_vm_list_t* list_;
 };
 
 //------------------------------------------------------------------------------
-// ApiRefCounted types
+// VmInstance
 //------------------------------------------------------------------------------
 
 class VmInstance : public ApiRefCounted<VmInstance, iree_vm_instance_t> {
@@ -118,6 +110,10 @@
   static VmInstance Create();
 };
 
+//------------------------------------------------------------------------------
+// VmModule
+//------------------------------------------------------------------------------
+
 class VmModule : public ApiRefCounted<VmModule, iree_vm_module_t> {
  public:
   static VmModule FromFlatbufferBlob(py::object flatbuffer_blob_object);
@@ -137,6 +133,10 @@
   py::object stashed_flatbuffer_blob = py::none();
 };
 
+//------------------------------------------------------------------------------
+// VmContext
+//------------------------------------------------------------------------------
+
 class VmContext : public ApiRefCounted<VmContext, iree_vm_context_t> {
  public:
   // Creates a context, optionally with modules, which will make the context
@@ -160,6 +160,90 @@
 class VmInvocation : public ApiRefCounted<VmInvocation, iree_vm_invocation_t> {
 };
 
+//------------------------------------------------------------------------------
+// VmRef (represents a pointer to an arbitrary reference object).
+//------------------------------------------------------------------------------
+
+class VmRef {
+ public:
+  //----------------------------------------------------------------------------
+  // BindRefProtocol() is used on a py::class_ defined for a ref-counted
+  // wrapper type. It takes the C helper functions that exist for each VM type
+  // and installs the reference protocol:
+  //   __iree_vm_type_id__()  [static]
+  //        Returns type_id().
+  //   __iree_vm_ref__  [readonly property], also exposed as "ref"
+  //        Returns a new VmRef owning the result of retain_ref(raw_ptr()).
+  //   __iree_vm_cast__(ref)  [static]
+  //        Returns a wrapper for the object behind `ref`. A failing
+  //        check_deref is reported via CheckApiStatus ("Incompatible type").
+  //   __eq__(other)
+  //        True when both wrappers hold the same raw pointer; False for any
+  //        other kind of operand.
+  //
+  // Ownership: check_deref only lends the pointer (callers of
+  // iree_vm_list_check_deref retain explicitly before stealing), so
+  // __iree_vm_cast__ must add its own reference. Stealing there would drop a
+  // count that still belongs to the VmRef and over-release the object.
+  //----------------------------------------------------------------------------
+  static const char* const kTypeIdAttr;
+  static const char* const kRefAttr;
+  static const char* const kCastAttr;
+
+  template <typename PyClass, typename TypeIdFunctor, typename RetainRefFunctor,
+            typename CheckDerefFunctor>
+  static void BindRefProtocol(PyClass& cls, TypeIdFunctor type_id,
+                              RetainRefFunctor retain_ref,
+                              CheckDerefFunctor check_deref) {
+    using WrapperType = typename PyClass::type;
+    using RawPtrType = typename WrapperType::RawPtrType;
+    auto ref_lambda = [=](WrapperType& self) {
+      return VmRef::Steal(retain_ref(self.raw_ptr()));  // Owns the new count.
+    };
+    cls.def_static(VmRef::kTypeIdAttr, [=]() { return type_id(); });
+    cls.def_property_readonly(VmRef::kRefAttr, ref_lambda);
+    cls.def_property_readonly("ref", ref_lambda);
+    cls.def_static(VmRef::kCastAttr, [=](VmRef& ref) {
+      RawPtrType casted = nullptr;  // Borrowed; `ref` keeps its reference.
+      CheckApiStatus(check_deref(ref.ref(), &casted), "Incompatible type");
+      return WrapperType::BorrowFromRawPtr(casted);
+    });
+    cls.def("__eq__", [](WrapperType& self, WrapperType& other) {
+      return self.raw_ptr() == other.raw_ptr();
+    });
+    cls.def("__eq__",
+            [](WrapperType& self, py::object& other) { return false; });
+  }
+
+  // Initializes a null ref.
+  VmRef() { std::memset(&ref_, 0, sizeof(ref_)); }
+  VmRef(VmRef&& other) noexcept : ref_(other.ref_) {
+    std::memset(&other.ref_, 0, sizeof(other.ref_));  // Source becomes null.
+  }
+  VmRef(const VmRef&) = delete;
+  VmRef& operator=(const VmRef&) = delete;
+  ~VmRef() {
+    if (ref_.ptr) {
+      iree_vm_ref_release(&ref_);
+    }
+  }
+
+  // Creates a VmRef from an owned ref, taking the reference count.
+  static VmRef Steal(iree_vm_ref_t ref) { return VmRef(ref); }
+
+  iree_vm_ref_t& ref() { return ref_; }  // Direct access; no retain.
+
+  py::object Deref(py::object ref_object_class);  // Cast via class hook.
+  bool IsInstance(py::object ref_object_class);  // Compares type ids.
+
+  std::string ToString();  // "<VmRef TYPE at 0xADDR>" or "<VmRef NULL>".
+
+ private:
+  // Initializes with an owned ref.
+  explicit VmRef(iree_vm_ref_t ref) : ref_(ref) {}
+  iree_vm_ref_t ref_;
+};
+
 void SetupVmBindings(pybind11::module m);
 
 }  // namespace python