Modernize and relocate iree/runtime Python package to runtime/bindings/python. (#8912)

* Modernize and relocate iree/runtime Python package to runtime/bindings/python.

Non-functional changes:

* Moves `bindings/python/iree/runtime` -> `runtime/bindings/python/iree/runtime`.
* Fixes a dash vs. underscore inconsistency in the compiler install path. Both now use underscores (python_packages/iree_compiler and python_packages/iree_runtime).
* Moves the build directory for iree/compiler/python to iree/compiler/python (it was previously output to bindings/python) and updates locations that were hard-coded.

Functional changes:

* Removes the old build-dir-only setup.py in favor of a runtime/setup.py that works from either the source or build dir.
* Reworks the releases to use the new setup.py as-is instead of scripting the build manually.
* iree.runtime.version is now generated in the same way as iree.compiler.version.
* Users can now run runtime/setup.py with pip themselves to generate a wheel (e.g. `pip wheel runtime/`).

It is now possible to integrate Python package testing into the presubmit and to have build jobs that generate Python-installable binaries for subsequent steps.

The only file left in bindings/python is build_requirements.txt. It is referenced by some docs and CI jobs, so it is left as-is for the moment (it will get a new home in a follow-up).
diff --git a/runtime/CMakeLists.txt b/runtime/CMakeLists.txt
new file mode 100644
index 0000000..8a3ce7a
--- /dev/null
+++ b/runtime/CMakeLists.txt
@@ -0,0 +1,14 @@
+# Copyright 2022 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+if(IREE_BUILD_PYTHON_BINDINGS)
+  # Copy Python packaging files to the build dir so that we can install from
+  # there.
+  configure_file(pyproject.toml pyproject.toml COPYONLY)
+  configure_file(setup.py setup.py @ONLY)
+
+  add_subdirectory(bindings/python/iree/runtime)
+endif()
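For readers less familiar with CMake: `configure_file(... COPYONLY)` copies `pyproject.toml` verbatim into the build directory, while `@ONLY` substitutes only `@VAR@` references in `setup.py` and leaves any `${VAR}` text alone. The Python sketch below approximates that `@ONLY` rule (restricted to identifier-like names); the variable `IREE_RELEASE_VERSION` and the template line are hypothetical, since this diff does not show which variables `setup.py` actually references.

```python
import re
from typing import Mapping

# Matches CMake-style @VAR@ references, the only form that @ONLY substitutes.
# Real CMake accepts a slightly wider character set in VAR.
_AT_VAR = re.compile(r"@([A-Za-z_][A-Za-z0-9_]*)@")


def configure_at_only(template: str, variables: Mapping[str, str]) -> str:
    """Approximates configure_file(... @ONLY).

    Replaces @VAR@ with the value of VAR (empty string if undefined) and
    leaves ${VAR} untouched.
    """
    return _AT_VAR.sub(lambda m: variables.get(m.group(1), ""), template)


# Hypothetical template line; the real setup.py contents are not in this diff.
template = 'version = "@IREE_RELEASE_VERSION@"  # ${NOT_TOUCHED}'
print(configure_at_only(template, {"IREE_RELEASE_VERSION": "1.2.3"}))
# prints: version = "1.2.3"  # ${NOT_TOUCHED}
```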
diff --git a/runtime/README.md b/runtime/README.md
new file mode 100644
index 0000000..e395753
--- /dev/null
+++ b/runtime/README.md
@@ -0,0 +1,43 @@
+# IREE runtime
+
+Note that this directory is in a transitional state. The C code still lives
+in directories under `iree/` and will be relocated here in the future.
+
+## Language Bindings
+
+### Python
+
+The included `setup.py` file can be used to build Python binaries or to directly
+install the IREE runtime API. Note that the runtime is heavy to build: unless
+you are developing it on a capable machine, you will likely want to use the
+released binaries.
+
+There are two ways to build/install Python packages:
+
+* Directly from the source tree (this is how official releases are done).
+* From the build directory while developing.
+
+It is recommended to use your favorite method for managing
+[virtual environments](https://docs.python.org/3/library/venv.html) instead
+of modifying the system installation.
+
+Only relatively recent versions of `pip` are supported; upgrade to the latest
+with `python -m pip install --upgrade pip`.
+
+You can build from either the source tree or the build tree. Building from the
+build tree assumes that CMake has been configured and the project built; it is
+typically used by project developers who are already set up for development
+and want to incrementally generate Python packages without rebuilding.
+
+To build a wheel that can be installed on the same Python version and OS:
+
+```
+python -m pip wheel runtime/
+```
+
+To directly install:
+
+```
+python -m pip install runtime/
+```
+
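A minimal sketch of scripting the wheel build from Python, for example in a CI job that hands the resulting wheel to later steps. It assumes it is run from the repository root; `pip_wheel_command`, `build_runtime_wheel`, and the `dist` output directory are illustrative names, not part of the repository. `--wheel-dir` is a standard `pip wheel` option.

```python
import subprocess
import sys
from typing import List


def pip_wheel_command(source_dir: str = "runtime/",
                      wheel_dir: str = "dist") -> List[str]:
    """Returns the `pip wheel` command line for the runtime package."""
    # sys.executable ties pip to the interpreter running this script, so the
    # wheel targets that Python version (e.g. the active virtual environment).
    return [sys.executable, "-m", "pip", "wheel", source_dir,
            "--wheel-dir", wheel_dir]


def build_runtime_wheel(source_dir: str = "runtime/",
                        wheel_dir: str = "dist") -> None:
    """Builds the wheel; check=True raises CalledProcessError on failure."""
    subprocess.run(pip_wheel_command(source_dir, wheel_dir), check=True)
```

Calling `build_runtime_wheel()` is equivalent to the `python -m pip wheel runtime/` command above, with the output collected in `dist/`. Note that `pip wheel` also builds or downloads wheels for dependencies unless `--no-deps` is passed.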
diff --git a/runtime/bindings/python/iree/runtime/CMakeLists.txt b/runtime/bindings/python/iree/runtime/CMakeLists.txt
new file mode 100644
index 0000000..e93f41f
--- /dev/null
+++ b/runtime/bindings/python/iree/runtime/CMakeLists.txt
@@ -0,0 +1,185 @@
+# Copyright 2020 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+set(NUMPY_DEPS "")
+set(PYBIND_COPTS "-fexceptions")
+set(PYBIND_EXTENSION_COPTS "-fvisibility=hidden")
+
+set(_python_extra_srcs)
+set(_extra_install_tool_targets)
+set(_tracy_enabled OFF)
+
+if(TARGET IREETracyCaptureServer)
+  message(STATUS "Bundling Tracy CLI tools with Python API")
+  set(_tracy_enabled ON)
+  list(APPEND _python_extra_srcs "scripts/iree-tracy-capture")
+  list(APPEND _extra_install_tool_targets "IREETracyCaptureServer")
+endif()
+
+################################################################################
+# Package
+################################################################################
+
+iree_pyext_module(
+  NAME
+    PyExtRt
+  MODULE_NAME binding
+  SRCS
+    "binding.h"
+    "initialize_module.cc"
+    "invoke.h"
+    "invoke.cc"
+    "hal.h"
+    "hal.cc"
+    "status_utils.cc"
+    "status_utils.h"
+    "vm.h"
+    "vm.cc"
+  UNIX_LINKER_SCRIPT
+    "unix_version.lds"
+  DEFINES
+    # Pybind code seems to be incompatible with C++ allocation tracing
+    # hooks so disable it.
+    IREE_TRACING_HOOK_CPP_NEW_DELETE=0
+  DEPS
+    iree::base
+    iree::base::cc
+    iree::base::internal::flags
+    iree::base::tracing
+    iree::hal
+    iree::hal::drivers
+    iree::modules::hal
+    iree::vm
+    iree::vm::bytecode_module
+)
+
+iree_py_library(
+  NAME
+    runtime
+  SRCS
+    "__init__.py"
+    "array_interop.py"
+    "flags.py"
+    "function.py"
+    "system_api.py"
+    "tracing.py"
+    "scripts/iree_benchmark_trace/__main__.py"
+    "scripts/iree_run_trace/__main__.py"
+    "scripts/iree_run_module/__main__.py"
+    ${_python_extra_srcs}
+  PYEXT_DEPS
+    ::PyExtRt
+)
+
+iree_symlink_tool(
+  TARGET runtime
+  FROM_TOOL_TARGET iree_tools_iree-benchmark-trace
+  TO_EXE_NAME iree-benchmark-trace
+)
+
+iree_symlink_tool(
+  TARGET runtime
+  FROM_TOOL_TARGET iree_tools_iree-run-trace
+  TO_EXE_NAME iree-run-trace
+)
+
+iree_symlink_tool(
+  TARGET runtime
+  FROM_TOOL_TARGET iree_tools_iree-run-module
+  TO_EXE_NAME iree-run-module
+)
+
+if(_tracy_enabled)
+  iree_symlink_tool(
+    TARGET runtime
+    FROM_TOOL_TARGET IREETracyCaptureServer
+    TO_EXE_NAME iree-tracy-capture
+  )
+endif()
+
+################################################################################
+# Tests
+################################################################################
+
+iree_py_test(
+  NAME
+    array_interop_test
+  SRCS
+    "array_interop_test.py"
+)
+
+iree_py_test(
+  NAME
+    flags_test
+  SRCS
+    "flags_test.py"
+)
+
+iree_py_test(
+  NAME
+    function_test
+  SRCS
+    "function_test.py"
+)
+
+iree_py_test(
+  NAME
+    hal_test
+  SRCS
+    "hal_test.py"
+)
+
+iree_py_test(
+  NAME
+    system_api_test
+  SRCS
+    "system_api_test.py"
+)
+
+iree_py_test(
+  NAME
+    vm_test
+  SRCS
+    "vm_test.py"
+)
+
+################################################################################
+# Install
+################################################################################
+
+iree_py_install_package(
+  COMPONENT IreePythonPackage-runtime
+  PACKAGE_NAME iree_runtime
+  MODULE_PATH iree/runtime
+  DEPS
+    # TODO: Update CMake target/path mangling rules to make this syntactically
+    # rooted on "iree" in some way.
+    runtime_bindings_python_iree_runtime_PyExtRt
+    iree_tools_iree-benchmark-trace
+    iree_tools_iree-run-module
+    iree_tools_iree-run-trace
+    ${_extra_install_tool_targets}
+  ADDL_PACKAGE_FILES
+    ${CMAKE_CURRENT_SOURCE_DIR}/README.md
+)
+
+install(
+  # TODO: Update CMake target/path mangling rules to make this syntactically
+  # rooted on "iree" in some way.
+  TARGETS runtime_bindings_python_iree_runtime_PyExtRt
+  COMPONENT ${PY_INSTALL_COMPONENT}
+  DESTINATION "${PY_INSTALL_MODULE_DIR}"
+)
+
+install(
+  TARGETS
+    iree_tools_iree-benchmark-trace
+    iree_tools_iree-run-module
+    iree_tools_iree-run-trace
+    ${_extra_install_tool_targets}
+  DESTINATION "${PY_INSTALL_MODULE_DIR}"
+  COMPONENT "${PY_INSTALL_COMPONENT}"
+)
diff --git a/runtime/bindings/python/iree/runtime/README.md b/runtime/bindings/python/iree/runtime/README.md
new file mode 100644
index 0000000..cd21ffe
--- /dev/null
+++ b/runtime/bindings/python/iree/runtime/README.md
@@ -0,0 +1,23 @@
+# IREE Python Runtime Components
+
+This package provides an API for running compiled IREE binaries and interfacing with the hardware abstraction layer (HAL).
+
+## Tracing
+
+Execution of calls against binaries can be traced for later replay (e.g. with
+tools like `iree-run-module`). Tracing can be set up either explicitly or
+via environment variables.
+
+To trace via environment variable, set `IREE_SAVE_CALLS` to a directory to dump
+traces into. Each created `SystemContext` will result in one `calls.yaml`
+file (with an index appended to the stem when there are several). Any
+referenced module binaries will be dumped into the same directory and
+referenced by the YAML file.
+
+### Explicit API
+
+```python
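+import iree.runtime
+
+# some_dir: directory that receives calls.yaml and any referenced module binaries.
+# driver: name of the HAL driver to use (e.g. "vmvx").
+tracer = iree.runtime.Tracer(some_dir)
+config = iree.runtime.Config(driver, tracer)
+...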
+```
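A sketch of enabling tracing for a whole script through the environment and then collecting what it wrote. `run_with_call_tracing` and `find_call_traces` are hypothetical helpers. The `calls*.yaml` glob is an assumption meant to cover both `calls.yaml` and the indexed names used when several contexts are created; whether the runtime creates the directory itself is not stated, so the helper creates it up front.

```python
import os
import subprocess
import sys
from pathlib import Path
from typing import List


def run_with_call_tracing(script: str, trace_dir: str) -> None:
    """Runs a Python script with IREE_SAVE_CALLS pointing at trace_dir."""
    Path(trace_dir).mkdir(parents=True, exist_ok=True)
    # Copy the current environment and add the tracing variable.
    env = dict(os.environ, IREE_SAVE_CALLS=trace_dir)
    subprocess.run([sys.executable, script], env=env, check=True)


def find_call_traces(trace_dir: str) -> List[Path]:
    """Lists the call traces written into trace_dir, sorted by name."""
    # Assumed naming: calls.yaml plus indexed variants for extra contexts.
    return sorted(Path(trace_dir).glob("calls*.yaml"))
```

Module binaries dumped next to the traces are skipped by the glob, since only the YAML files are entry points for replay.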
diff --git a/runtime/bindings/python/iree/runtime/__init__.py b/runtime/bindings/python/iree/runtime/__init__.py
new file mode 100644
index 0000000..e0d8643
--- /dev/null
+++ b/runtime/bindings/python/iree/runtime/__init__.py
@@ -0,0 +1,47 @@
+"""Module init for the python bindings."""
+
+# Copyright 2019 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# pylint: disable=g-multiple-import
+# pylint: disable=g-bad-import-order
+# pylint: disable=wildcard-import
+
+from . import binding
+
+# Pull some of the native symbols into the public API.
+# Hal imports
+from .binding import (
+    BufferCompatibility,
+    BufferUsage,
+    HalAllocator,
+    HalBuffer,
+    HalBufferView,
+    HalDevice,
+    HalDriver,
+    HalElementType,
+    MemoryAccess,
+    MemoryType,
+    Shape,
+)
+
+# Vm imports
+from .binding import (
+    create_hal_module,
+    Linkage,
+    VmVariantList,
+    VmFunction,
+    VmInstance,
+    VmContext,
+    VmModule,
+)
+
+from .array_interop import *
+from .system_api import *
+from .function import *
+from .tracing import *
+
+from . import flags
diff --git a/runtime/bindings/python/iree/runtime/array_interop.py b/runtime/bindings/python/iree/runtime/array_interop.py
new file mode 100644
index 0000000..d761274
--- /dev/null
+++ b/runtime/bindings/python/iree/runtime/array_interop.py
@@ -0,0 +1,253 @@
+# Copyright 2022 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+"""BufferView and Python Array Protocol interop."""
+
+from typing import Optional, Tuple
+import logging
+import numpy as np
+import numpy.lib.mixins
+
+from .binding import (
+    BufferUsage,
+    HalBufferView,
+    HalDevice,
+    HalElementType,
+    MappedMemory,
+    MemoryType,
+)
+
+__all__ = [
+    "asdevicearray",
+    "DeviceArray",
+]
+
+_DEVICE_HANDLED_FUNCTIONS = {}
+
+
+def _device_implements(np_function):
+  """Decorator that registers a DeviceArray implementation of a NumPy function."""
+
+  def decorator(func):
+    _DEVICE_HANDLED_FUNCTIONS[np_function] = func
+    return func
+
+  return decorator
+
+
+class DeviceArray(numpy.lib.mixins.NDArrayOperatorsMixin):
+  """An IREE device array.
+
+  Device arrays can be in one of two states:
+    1. Host accessible: The array will be backed by host accessible memory
+       and can have the usual things done with it that one expects to be
+       able to do with an ndarray.
+    2. Device resident: The array is just a handle to a device resident
+       Buffer (and BufferView wrapper). Metadata about the array is accessible
+       (shape and dtype), but the data itself cannot be accessed in this
+       state.
+
+  How a device array comes into existence controls how it can transition
+  between these states:
+    * A user can create a DeviceArray explicitly with a device allocator.
+      Such an array will not be implicitly convertible to host accessible,
+      although accessors exist to do so.
+    * When created by the platform with a synchronization policy, then
+      implicit transfer back to the host will trigger appropriate waits and
+      be performed automatically (this is the common case for function return
+      values if not otherwise configured, as an example).
+  """
+
+  def __init__(self,
+               device: HalDevice,
+               buffer_view: HalBufferView,
+               implicit_host_transfer: bool = False,
+               override_dtype=None):
+    self._device = device
+    self._buffer_view = buffer_view
+    self._implicit_host_transfer = implicit_host_transfer
+    self._override_dtype = override_dtype
+
+    # If the array is host accessible, these will be non-None.
+    self._mapped_memory: Optional[MappedMemory] = None
+    self._host_array: Optional[np.ndarray] = None
+
+  def __array__(self, dtype=None):
+    self._transfer_to_host(True)
+    if dtype is None:
+      return self._host_array
+    else:
+      return self._host_array.__array__(dtype)  # pytype: disable=attribute-error
+
+  def __array_function__(self, func, types, args, kwargs):
+    if func in _DEVICE_HANDLED_FUNCTIONS:
+      return _DEVICE_HANDLED_FUNCTIONS[func](*args, **kwargs)
+
+    # Anything else forces a transfer to host and then delegates to the
+    # host array.
+    host_array = self.to_host()
+    return host_array.__array_function__(func, types, args, kwargs)  # pytype: disable=attribute-error
+
+  def __repr__(self):
+    return f"<IREE DeviceArray: shape={np.shape(self)}, dtype={self.dtype}>"
+
+  @property
+  def is_host_accessible(self):
+    """Whether this array is currently host accessible."""
+    return self._host_array is not None
+
+  def to_host(self) -> np.ndarray:
+    self._transfer_to_host(False)
+    return self._host_array
+
+  def _transfer_to_host(self, implicit):
+    if self._host_array is not None:
+      return
+    if implicit and not self._implicit_host_transfer:
+      raise ValueError(
+          "DeviceArray cannot be implicitly transferred to the host: "
+          "if necessary, do an explicit transfer via .to_host()")
+    self._mapped_memory, self._host_array = self._map_to_host()
+
+  def _map_to_host(self) -> Tuple[MappedMemory, np.ndarray]:
+    # TODO: When synchronization is enabled, need to block here.
+    raw_dtype = self._get_raw_dtype()
+    mapped_memory = self._buffer_view.map()
+    host_array = mapped_memory.asarray(self._buffer_view.shape, raw_dtype)
+    # Detect if we need to force an explicit conversion. This happens when
+    # we were requested to pretend that the array is in a specific dtype,
+    # even if that is not representable on the device. You guessed it:
+    # this is to support bools.
+    if self._override_dtype is not None and self._override_dtype != raw_dtype:
+      host_array = host_array.astype(self._override_dtype)
+    return mapped_memory, host_array
+
+  def _get_raw_dtype(self):
+    return HalElementType.map_to_dtype(self._buffer_view.element_type)
+
+  @property
+  def dtype(self):
+    if self._override_dtype is not None:
+      return self._override_dtype
+    return self._get_raw_dtype()
+
+  @property
+  def shape(self):
+    return np.shape(self)
+
+  def astype(self, dtype, casting="unsafe", copy=True):
+    if self.dtype == dtype and not copy:
+      return self
+    host_ary = self.to_host()
+    return host_ary.astype(dtype, casting=casting, copy=copy)
+
+  def reshape(self, *args):
+    # TODO(scotttodd): add a native impl with a new buffer_view of the same data
+    # TODO(scotttodd): return DeviceArray instead of host ndarray?
+    host_ary = self.to_host()
+    return host_ary.reshape(*args)
+
+  def __iter__(self):
+    host_ary = self.to_host()
+    return host_ary.__iter__()
+
+  def __getitem__(self, index):
+    host_ary = self.to_host()
+    return host_ary.__getitem__(index)
+
+  def __reduce__(self):
+    # Since this is used for making deep copies and pickling, we map
+    # separately from any interactive state. We just reduce to the actual
+    # host ndarray, which supports the necessary serialization protocols.
+    _, host_array = self._map_to_host()
+    return _restore_reduced_array, (host_array,)
+
+
+def _restore_reduced_array(ary):
+  return ary
+
+
+# Function implementations with custom behavior.
+@_device_implements(np.shape)
+def _(arr: DeviceArray):
+  return arr._buffer_view.shape
+
+
+@_device_implements(np.reshape)
+def _(arr: DeviceArray, *args):
+  return arr.reshape(*args)
+
+
+def asdevicearray(device: HalDevice,
+                  a,
+                  dtype=None,
+                  *,
+                  implicit_host_transfer: bool = False,
+                  memory_type=MemoryType.DEVICE_LOCAL,
+                  allowed_usage=(BufferUsage.DISPATCH | BufferUsage.TRANSFER |
+                                 BufferUsage.MAPPING),
+                  element_type: Optional[HalElementType] = None) -> DeviceArray:
+  """Helper to create a DeviceArray from an arbitrary array like.
+
+  This is similar in purpose and usage to np.asarray, except that it takes
+  a device as the first argument. This may not be the best mechanism for
+  getting a DeviceArray, depending on your use case, but it is reliable
+  and simple. This function may make a defensive copy or cause implicit
+  transfers to satisfy the request. If this is important to you, then a lower
+  level API is likely more appropriate.
+
+  Note that additional flags `memory_type`, `allowed_usage` and `element_type`
+  are only hints if creating a new DeviceArray. If `a` is already a DeviceArray,
+  they are ignored.
+  """
+  if isinstance(a, DeviceArray):
+    if dtype is None:
+      return a
+    # Need to do a conversion, which we currently do not support on the
+    # device, so transfer back to the host.
+    logging.warning(
+        "Implicit dtype conversion of a DeviceArray forces a host transfer")
+  # First get an ndarray.
+  a = np.asarray(a, dtype=dtype)
+  if element_type is None:
+    element_type = map_dtype_to_element_type(a.dtype)
+  if element_type is None:
+    raise ValueError(f"Could not map dtype {a.dtype} to IREE element type")
+  buffer_view = device.allocator.allocate_buffer_copy(
+      memory_type=memory_type,
+      allowed_usage=allowed_usage,
+      buffer=a,
+      element_type=element_type)
+  return DeviceArray(device,
+                     buffer_view,
+                     implicit_host_transfer=implicit_host_transfer,
+                     override_dtype=a.dtype)
+
+
+# NOTE: Numpy dtypes are not hashable and exist in a hierarchy that should
+# be queried via isinstance checks. This should be done as a fallback but
+# this is a linear list for quick access to the most common. There may also
+# be a better way to do this.
+_DTYPE_TO_HAL_ELEMENT_TYPE = (
+    (np.float32, HalElementType.FLOAT_32),
+    (np.float64, HalElementType.FLOAT_64),
+    (np.float16, HalElementType.FLOAT_16),
+    (np.int32, HalElementType.SINT_32),
+    (np.int64, HalElementType.SINT_64),
+    (np.int16, HalElementType.SINT_16),
+    (np.int8, HalElementType.SINT_8),
+    (np.uint32, HalElementType.UINT_32),
+    (np.uint64, HalElementType.UINT_64),
+    (np.uint16, HalElementType.UINT_16),
+    (np.uint8, HalElementType.UINT_8),
+    (np.bool_, HalElementType.BOOL_8),
+)
+
+
+def map_dtype_to_element_type(dtype) -> Optional[HalElementType]:
+  """Returns the HalElementType for a NumPy dtype, or None if unsupported."""
+  for match_dtype, element_type in _DTYPE_TO_HAL_ELEMENT_TYPE:
+    if match_dtype == dtype:
+      return element_type
+  return None
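The two-state model described in the `DeviceArray` docstring (host accessible vs. device resident, with implicit transfers gated by `implicit_host_transfer`) can be hard to follow across `__array__`, `to_host()` and `_transfer_to_host()`. The pure-Python toy below mirrors that control flow without IREE or NumPy; `ToyDeviceArray` and `implicit_access` are hypothetical stand-ins, and a list copy plays the role of mapping device memory.

```python
from typing import List, Optional


class ToyDeviceArray:
    """Models the two DeviceArray states (illustrative, not the real class).

    `_device_data` stands in for a device-resident buffer; `_host_copy` is
    only populated once the array has been made host accessible.
    """

    def __init__(self, device_data: List[int],
                 implicit_host_transfer: bool = False):
        self._device_data = device_data
        self._implicit_host_transfer = implicit_host_transfer
        self._host_copy: Optional[List[int]] = None

    @property
    def is_host_accessible(self) -> bool:
        return self._host_copy is not None

    def to_host(self) -> List[int]:
        # Explicit transfers are always allowed.
        self._transfer_to_host(implicit=False)
        return self._host_copy

    def implicit_access(self) -> List[int]:
        # Plays the role of __array__: allowed only if the creator opted in.
        self._transfer_to_host(implicit=True)
        return self._host_copy

    def _transfer_to_host(self, implicit: bool) -> None:
        if self._host_copy is not None:
            return  # Already host accessible; nothing to do.
        if implicit and not self._implicit_host_transfer:
            raise ValueError("implicit host transfer not allowed; call .to_host()")
        self._host_copy = list(self._device_data)
```

As in the real `_transfer_to_host`, the early return happens before the policy check, so an explicit `to_host()` also makes later implicit access succeed.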
diff --git a/runtime/bindings/python/iree/runtime/array_interop_test.py b/runtime/bindings/python/iree/runtime/array_interop_test.py
new file mode 100644
index 0000000..032f296
--- /dev/null
+++ b/runtime/bindings/python/iree/runtime/array_interop_test.py
@@ -0,0 +1,149 @@
+# Copyright 2022 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import copy
+import numpy as np
+import unittest
+
+import iree.runtime
+
+
+class DeviceHalTest(unittest.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self.driver = iree.runtime.HalDriver.create("vmvx")
+    self.device = self.driver.create_default_device()
+    self.allocator = self.device.allocator
+
+  def testMetadataAttributes(self):
+    init_ary = np.zeros([3, 4], dtype=np.int32) + 2
+    ary = iree.runtime.asdevicearray(self.device, init_ary)
+    self.assertEqual([3, 4], ary.shape)
+    self.assertEqual(np.int32, ary.dtype)
+
+  def testExplicitHostTransfer(self):
+    init_ary = np.zeros([3, 4], dtype=np.int32) + 2
+    ary = iree.runtime.asdevicearray(self.device, init_ary)
+    self.assertEqual(repr(ary), "<IREE DeviceArray: shape=[3, 4], dtype=int32>")
+    self.assertFalse(ary.is_host_accessible)
+
+    # Explicit transfer.
+    cp = ary.to_host()
+    np.testing.assert_array_equal(cp, init_ary)
+    self.assertTrue(ary.is_host_accessible)
+
+  def testOverrideDtype(self):
+    init_ary = np.zeros([3, 4], dtype=np.int32) + 2
+    buffer_view = self.allocator.allocate_buffer_copy(
+        memory_type=iree.runtime.MemoryType.DEVICE_LOCAL,
+        allowed_usage=iree.runtime.BufferUsage.CONSTANT,
+        buffer=init_ary,
+        element_type=iree.runtime.HalElementType.SINT_32)
+
+    ary = iree.runtime.DeviceArray(self.device,
+                                   buffer_view,
+                                   override_dtype=np.float32)
+
+    # Explicit transfer.
+    cp = ary.to_host()
+    self.assertEqual(cp.dtype, np.float32)
+    np.testing.assert_array_equal(cp, init_ary.astype(np.float32))
+    self.assertTrue(ary.is_host_accessible)
+
+  def testIllegalImplicitHostTransfer(self):
+    init_ary = np.zeros([3, 4], dtype=np.int32) + 2
+    ary = iree.runtime.asdevicearray(self.device, init_ary)
+    # Implicit transfer.
+    with self.assertRaises(ValueError):
+      _ = np.asarray(ary)
+
+  def testImplicitHostArithmetic(self):
+    init_ary = np.zeros([3, 4], dtype=np.int32) + 2
+    ary = iree.runtime.asdevicearray(self.device,
+                                     init_ary,
+                                     implicit_host_transfer=True)
+    sum = ary + init_ary
+    np.testing.assert_array_equal(sum, init_ary + 2)
+    self.assertTrue(ary.is_host_accessible)
+
+  def testArrayFunctions(self):
+    init_ary = np.zeros([3, 4], dtype=np.float32) + 2
+    ary = iree.runtime.asdevicearray(self.device,
+                                     init_ary,
+                                     implicit_host_transfer=True)
+    f = np.isfinite(ary)
+    self.assertTrue(f.all())
+
+  def testIteration(self):
+    init_ary = np.array([0, 1, 2, 3, 4, 5])
+    ary = iree.runtime.asdevicearray(self.device,
+                                     init_ary,
+                                     implicit_host_transfer=True)
+
+    for index, value in enumerate(ary):
+      self.assertEqual(index, value)
+
+  def testSubscriptable(self):
+    init_ary = np.array([0, 1, 2, 3, 4, 5])
+    ary = iree.runtime.asdevicearray(self.device,
+                                     init_ary,
+                                     implicit_host_transfer=True)
+
+    for index in range(0, 6):
+      value = ary[index]
+      self.assertEqual(index, value)
+
+  def testReshape(self):
+    init_ary = np.zeros([3, 4], dtype=np.float32) + 2
+    ary = iree.runtime.asdevicearray(self.device,
+                                     init_ary,
+                                     implicit_host_transfer=True)
+    reshaped = ary.reshape((4, 3))
+    self.assertEqual((4, 3), reshaped.shape)
+
+    np_reshaped = np.reshape(ary, (2, 2, 3))
+    self.assertEqual((2, 2, 3), np_reshaped.shape)
+
+  def testDeepcopy(self):
+    init_ary = np.zeros([3, 4], dtype=np.float32) + 2
+    orig_ary = iree.runtime.asdevicearray(self.device,
+                                          init_ary,
+                                          implicit_host_transfer=True)
+    copy_ary = copy.deepcopy(orig_ary)
+    self.assertIsNot(orig_ary, copy_ary)
+    np.testing.assert_array_equal(orig_ary, copy_ary)
+
+  def testAsType(self):
+    init_ary = np.zeros([3, 4], dtype=np.int32) + 2
+    orig_ary = iree.runtime.asdevicearray(self.device,
+                                          init_ary,
+                                          implicit_host_transfer=True)
+    # Same dtype, no copy.
+    i32_nocopy = orig_ary.astype(np.int32, copy=False)
+    self.assertIs(orig_ary, i32_nocopy)
+
+    # Same dtype, copy.
+    i32_copy = orig_ary.astype(np.int32)
+    self.assertIsNot(orig_ary, i32_copy)
+    np.testing.assert_array_equal(orig_ary, i32_copy)
+
+    # Different dtype, copy.
+    f32_copy = orig_ary.astype(np.float32)
+    self.assertIsNot(orig_ary, f32_copy)
+    self.assertEqual(f32_copy.dtype, np.float32)
+    np.testing.assert_array_equal(orig_ary.astype(np.float32), f32_copy)
+
+  def testBool(self):
+    init_ary = np.zeros([3, 4], dtype=np.bool_)
+    init_ary[1] = True  # Set some non-zero value.
+    ary = iree.runtime.asdevicearray(self.device, init_ary)
+    self.assertEqual(repr(ary), "<IREE DeviceArray: shape=[3, 4], dtype=bool>")
+    np.testing.assert_array_equal(ary.to_host(), init_ary)
+
+
+if __name__ == "__main__":
+  unittest.main()
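`testDeepcopy` relies on `DeviceArray.__reduce__` reducing to a plain host array. The standard-library sketch below shows the same pattern in isolation; `ToyHandle` and `_restore` are hypothetical stand-ins for `DeviceArray` and `_restore_reduced_array`, with a list in place of the mapped `ndarray`.

```python
import copy
from typing import List


def _restore(data: List[float]) -> List[float]:
    # Mirrors _restore_reduced_array: hand back the host-side copy unchanged.
    return data


class ToyHandle:
    """Stand-in for DeviceArray: wraps data that only the "device" owns."""

    def __init__(self, data: List[float]):
        self._device_data = data

    def __reduce__(self):
        # Reduce to a plain host container, which already supports deep copies
        # and pickling. The copy is therefore a list, not a ToyHandle.
        return _restore, (list(self._device_data),)
```

One consequence, also true of the real class: `copy.deepcopy` of a `DeviceArray` yields a host `ndarray` rather than another `DeviceArray`, which is why the test only checks identity and element equality.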
diff --git a/runtime/bindings/python/iree/runtime/binding.h b/runtime/bindings/python/iree/runtime/binding.h
new file mode 100644
index 0000000..64240d4
--- /dev/null
+++ b/runtime/bindings/python/iree/runtime/binding.h
@@ -0,0 +1,99 @@
+// Copyright 2019 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_BINDINGS_PYTHON_IREE_BINDING_H_
+#define IREE_BINDINGS_PYTHON_IREE_BINDING_H_
+
+#include <optional>
+#include <stdexcept>
+#include <vector>
+
+#include "iree/base/api.h"
+#include "pybind11/pybind11.h"
+#include "pybind11/stl.h"
+
+namespace iree {
+namespace python {
+
+namespace py = pybind11;
+
+template <typename T>
+struct ApiPtrAdapter {};
+
+template <typename Self, typename T>
+class ApiRefCounted {
+ public:
+  ApiRefCounted() : instance_(nullptr) {}
+  ApiRefCounted(ApiRefCounted& other) : instance_(other.instance_) { Retain(); }
+  ApiRefCounted(ApiRefCounted&& other) : instance_(other.instance_) {
+    other.instance_ = nullptr;
+  }
+  ApiRefCounted& operator=(ApiRefCounted&& other) {
+    // Take other's reference first so that self-move is safe, then drop ours.
+    T* incoming = other.steal_raw_ptr();
+    Release();
+    instance_ = incoming;
+    return *this;
+  }
+  void operator=(const ApiRefCounted&) = delete;
+
+  ~ApiRefCounted() { Release(); }
+
+  // Steals the reference to the object referenced by the given raw pointer and
+  // returns a wrapper (transfers ownership).
+  static Self StealFromRawPtr(T* retained_inst) {
+    auto self = Self();
+    self.instance_ = retained_inst;
+    return self;
+  }
+
+  // Retains the object referenced by the given raw pointer and returns
+  // a wrapper.
+  static Self BorrowFromRawPtr(T* non_retained_inst) {
+    auto self = Self();
+    self.instance_ = non_retained_inst;
+    if (non_retained_inst) {
+      ApiPtrAdapter<T>::Retain(non_retained_inst);
+    }
+    return self;
+  }
+
+  // Whether it is nullptr.
+  operator bool() const { return instance_; }
+
+  T* steal_raw_ptr() {
+    T* ret = instance_;
+    instance_ = nullptr;
+    return ret;
+  }
+
+  T* raw_ptr() {
+    if (!instance_) {
+      throw std::invalid_argument("API object is null");
+    }
+    return instance_;
+  }
+
+  const T* raw_ptr() const {
+    return const_cast<ApiRefCounted*>(this)->raw_ptr();
+  }
+
+  void Retain() {
+    if (instance_) {
+      ApiPtrAdapter<T>::Retain(instance_);
+    }
+  }
+  void Release() {
+    if (instance_) {
+      ApiPtrAdapter<T>::Release(instance_);
+    }
+  }
+
+ private:
+  T* instance_;
+};
+
+}  // namespace python
+}  // namespace iree
+
+#endif  // IREE_BINDINGS_PYTHON_IREE_BINDING_H_
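The difference between `StealFromRawPtr` and `BorrowFromRawPtr` is easy to get wrong. The Python model below traces a reference count through each path; `FakeApiObject` and `RefWrapper` are hypothetical classes for illustration only, not part of the C++ API or the Python bindings.

```python
class FakeApiObject:
    """Stand-in for a C API object with an intrusive reference count."""

    def __init__(self):
        self.ref_count = 1  # Creation hands the caller one reference.


class RefWrapper:
    """Python model of ApiRefCounted's ownership rules (illustrative only)."""

    def __init__(self, obj=None):
        self._obj = obj

    @classmethod
    def steal(cls, retained_obj):
        # StealFromRawPtr: adopt the caller's existing reference; no retain.
        return cls(retained_obj)

    @classmethod
    def borrow(cls, non_retained_obj):
        # BorrowFromRawPtr: take a new reference of our own.
        if non_retained_obj is not None:
            non_retained_obj.ref_count += 1
        return cls(non_retained_obj)

    def release(self):
        # Mirrors the destructor: drop our reference exactly once.
        if self._obj is not None:
            self._obj.ref_count -= 1
            self._obj = None
```

Stealing a reference the caller never owned would drive the count to zero too early, while borrowing one the caller already transferred would leak a reference.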
diff --git a/runtime/bindings/python/iree/runtime/export.def b/runtime/bindings/python/iree/runtime/export.def
new file mode 100644
index 0000000..24e1f7b
--- /dev/null
+++ b/runtime/bindings/python/iree/runtime/export.def
@@ -0,0 +1,9 @@
+;; Copyright 2019 The IREE Authors
+;;
+;; Licensed under the Apache License v2.0 with LLVM Exceptions.
+;; See https://llvm.org/LICENSE.txt for license information.
+;; SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+LIBRARY BINDING
+EXPORTS
+  PyInit_binding @1
diff --git a/runtime/bindings/python/iree/runtime/flags.py b/runtime/bindings/python/iree/runtime/flags.py
new file mode 100644
index 0000000..a7b1020
--- /dev/null
+++ b/runtime/bindings/python/iree/runtime/flags.py
@@ -0,0 +1,12 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from .binding import parse_flags
+
+# When enabled, performs additional function input validation checks. In the
+# event of errors, this will yield nicer error messages but comes with a
+# runtime cost.
+FUNCTION_INPUT_VALIDATION = True
diff --git a/runtime/bindings/python/iree/runtime/flags_test.py b/runtime/bindings/python/iree/runtime/flags_test.py
new file mode 100644
index 0000000..886176a
--- /dev/null
+++ b/runtime/bindings/python/iree/runtime/flags_test.py
@@ -0,0 +1,24 @@
+# Copyright 2019 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from iree import runtime as rt
+import numpy as np
+import unittest
+
+
+class FlagsTest(unittest.TestCase):
+
+  def testParse(self):
+    # We always have the logging verbose level available so use it.
+    rt.flags.parse_flags("--iree_v=1")
+
+  def testParseError(self):
+    with self.assertRaisesRegex(ValueError, "flag 'barbar' not recognized"):
+      rt.flags.parse_flags("--barbar")
+
+
+if __name__ == "__main__":
+  unittest.main()
diff --git a/runtime/bindings/python/iree/runtime/function.py b/runtime/bindings/python/iree/runtime/function.py
new file mode 100644
index 0000000..40eb89c
--- /dev/null
+++ b/runtime/bindings/python/iree/runtime/function.py
@@ -0,0 +1,390 @@
+# Lint as: python3
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import Dict, Optional
+
+import json
+import logging
+
+import numpy as np
+
+from .binding import (
+    _invoke_statics,
+    ArgumentPacker,
+    BufferUsage,
+    HalBufferView,
+    HalDevice,
+    InvokeContext,
+    MemoryType,
+    VmContext,
+    VmFunction,
+    VmVariantList,
+)
+
+from . import tracing
+from .array_interop import (
+    map_dtype_to_element_type,
+    DeviceArray,
+)
+from .flags import (
+    FUNCTION_INPUT_VALIDATION,)
+
+__all__ = [
+    "FunctionInvoker",
+]
+
+
+class Invocation:
+  __slots__ = [
+      "current_arg",
+      "current_desc",
+      "current_return_list",
+      "current_return_index",
+      "device",
+  ]
+
+  def __init__(self, device: HalDevice):
+    self.device = device
+    # Captured during arg/ret processing to emit better error messages.
+    self.current_arg = None
+    self.current_desc = None
+    self.current_return_list = None
+    self.current_return_index = 0
+
+  def summarize_arg_error(self) -> str:
+    if self.current_arg is None:
+      return ""
+    if isinstance(self.current_arg, np.ndarray):
+      current_arg_repr = (
+          f"ndarray({self.current_arg.shape}, {self.current_arg.dtype})")
+    else:
+      current_arg_repr = repr(self.current_arg)
+    return f"{current_arg_repr} with description {self.current_desc}"
+
+  def summarize_return_error(self) -> str:
+    if self.current_return_list is None:
+      return ""
+    try:
+      vm_repr = f"{self.current_return_index}@{self.current_return_list}"
+    except Exception:
+      vm_repr = "<error printing list item>"
+    return f"{vm_repr} with description {self.current_desc}"
+
+
+class FunctionInvoker:
+  """Wraps a VmFunction, enabling invocations against it."""
+  __slots__ = [
+      "_vm_context",
+      "_device",
+      "_vm_function",
+      "_abi_dict",
+      "_arg_descs",
+      "_arg_packer",
+      "_ret_descs",
+      "_has_inlined_results",
+      "_tracer",
+  ]
+
+  def __init__(self, vm_context: VmContext, device: HalDevice,
+               vm_function: VmFunction,
+               tracer: Optional[tracing.ContextTracer]):
+    self._vm_context = vm_context
+    # TODO: Needing to know the precise device to allocate on here is bad
+    # layering and will need to be fixed in some fashion if/when doing
+    # heterogeneous dispatch.
+    self._device = device
+    self._vm_function = vm_function
+    self._tracer = tracer
+    self._abi_dict = None
+    self._arg_descs = None
+    self._ret_descs = None
+    self._has_inlined_results = False
+    self._parse_abi_dict(vm_function)
+    self._arg_packer = ArgumentPacker(_invoke_statics, self._arg_descs)
+
+  @property
+  def vm_function(self) -> VmFunction:
+    return self._vm_function
+
+  def __call__(self, *args, **kwargs):
+    invoke_context = InvokeContext(self._device)
+    arg_list = self._arg_packer.pack(invoke_context, args, kwargs)
+
+    call_trace = None  # type: Optional[tracing.CallTrace]
+    if self._tracer:
+      call_trace = self._tracer.start_call(self._vm_function)
+    try:
+      inv = Invocation(self._device)
+      ret_descs = self._ret_descs
+
+      # Initialize the list capacity to the number of declared results (or 1
+      # in dynamic mode); the list grows as needed. May want to be more
+      # conservative here when considering nesting.
+      ret_list = VmVariantList(len(ret_descs) if ret_descs is not None else 1)
+      if call_trace:
+        call_trace.add_vm_list(arg_list, "args")
+      self._invoke(arg_list, ret_list)
+      if call_trace:
+        call_trace.add_vm_list(ret_list, "results")
+
+      # Un-inline the results to align with reflection, as needed.
+      reflection_aligned_ret_list = ret_list
+      if self._has_inlined_results:
+        reflection_aligned_ret_list = VmVariantList(1)
+        reflection_aligned_ret_list.push_list(ret_list)
+      returns = _extract_vm_sequence_to_python(inv, reflection_aligned_ret_list,
+                                               ret_descs)
+      return_arity = len(returns)
+      if return_arity == 1:
+        return returns[0]
+      elif return_arity == 0:
+        return None
+      else:
+        return tuple(returns)
+    finally:
+      if call_trace:
+        call_trace.end_call()
+
+  # Break out invoke so it shows up in profiles.
+  def _invoke(self, arg_list, ret_list):
+    self._vm_context.invoke(self._vm_function, arg_list, ret_list)
+
+  def _parse_abi_dict(self, vm_function: VmFunction):
+    reflection = vm_function.reflection
+    abi_json = reflection.get("iree.abi")
+    if abi_json is None:
+      # It is valid to have no reflection data, and rely on pure dynamic
+      # dispatch.
+      logging.debug(
+          "Function lacks reflection data. Interop will be limited: %r",
+          vm_function)
+      return
+    try:
+      self._abi_dict = json.loads(abi_json)
+    except json.JSONDecodeError as e:
+      raise RuntimeError(
+          f"Reflection metadata is not valid JSON: {abi_json}") from e
+    try:
+      self._arg_descs = self._abi_dict["a"]
+      self._ret_descs = self._abi_dict["r"]
+    except KeyError as e:
+      raise RuntimeError(
+          f"Malformed function reflection metadata: {reflection}") from e
+    if not isinstance(self._arg_descs, list) or not isinstance(
+        self._ret_descs, list):
+      raise RuntimeError(
+          f"Malformed function reflection metadata structure: {reflection}")
+
+    # Detect whether the single result is an slist/stuple/sdict, which
+    # indicates that its elements are inlined into the function's results.
+    if len(self._ret_descs) == 1:
+      maybe_inlined = self._ret_descs[0]
+      if maybe_inlined and maybe_inlined[0] in ["slist", "stuple", "sdict"]:
+        self._has_inlined_results = True
+
+  def __repr__(self):
+    return repr(self._vm_function)
+
+
+# VM to Python converters. All take:
+#   inv: Invocation
+#   vm_list: VmVariantList to read from
+#   vm_index: Index in the vm_list to extract
+#   desc: The ABI descriptor list (or None if in dynamic mode)
+# Return the corresponding Python object.
+
+
+def _vm_to_ndarray(inv: Invocation, vm_list: VmVariantList, vm_index: int,
+                   desc):
+  # The descriptor for an ndarray is like:
+  #   ["ndarray", "<dtype>", <rank>, <dim>...]
+  #   ex: ['ndarray', 'i32', 1, 25948]
+  buffer_view = vm_list.get_as_buffer_view(vm_index)
+  dtype_str = desc[1]
+  try:
+    dtype = ABI_TYPE_TO_DTYPE[dtype_str]
+  except KeyError:
+    _raise_return_error(inv, f"unrecognized dtype '{dtype_str}'")
+  x = DeviceArray(inv.device,
+                  buffer_view,
+                  implicit_host_transfer=True,
+                  override_dtype=dtype)
+  return x
+
+
+def _vm_to_sdict(inv: Invocation, vm_list: VmVariantList, vm_index: int, desc):
+  # The descriptor for an sdict is like:
+  #   ['sdict', ['key1', value1], ...]
+  sub_vm_list = vm_list.get_as_list(vm_index)
+  item_keys = []
+  item_descs = []
+  for k, d in desc[1:]:
+    item_keys.append(k)
+    item_descs.append(d)
+  py_items = _extract_vm_sequence_to_python(inv, sub_vm_list, item_descs)
+  return dict(zip(item_keys, py_items))
+
+
+def _vm_to_slist(inv: Invocation, vm_list: VmVariantList, vm_index: int, desc):
+  # The descriptor for an slist is like:
+  #   ['slist', item1, ...]
+  sub_vm_list = vm_list.get_as_list(vm_index)
+  item_descs = desc[1:]
+  py_items = _extract_vm_sequence_to_python(inv, sub_vm_list, item_descs)
+  return py_items
+
+
+def _vm_to_stuple(inv: Invocation, vm_list: VmVariantList, vm_index: int, desc):
+  return tuple(_vm_to_slist(inv, vm_list, vm_index, desc))
+
+
+def _vm_to_scalar(type_bound: type):
+
+  def convert(inv: Invocation, vm_list: VmVariantList, vm_index: int, desc):
+    value = vm_list.get_variant(vm_index)
+    if not isinstance(value, type_bound):
+      _raise_return_error(inv, f"expected {type_bound.__name__} value but "
+                          f"got {type(value).__name__}")
+    return value
+
+  return convert
+
+
+def _vm_to_pylist(inv: Invocation, vm_list: VmVariantList, vm_index: int, desc):
+  # The descriptor for a homogeneous list is like:
+  #   ['py_homogeneous_list', element_type]
+  sub_vm_list = vm_list.get_as_list(vm_index)
+  element_type_desc = desc[1:]
+  py_items = _extract_vm_sequence_to_python(
+      inv, sub_vm_list, element_type_desc * len(sub_vm_list))
+  return py_items
+
+
+VM_TO_PYTHON_CONVERTERS = {
+    "ndarray": _vm_to_ndarray,
+    "sdict": _vm_to_sdict,
+    "slist": _vm_to_slist,
+    "stuple": _vm_to_stuple,
+    "py_homogeneous_list": _vm_to_pylist,
+
+    # Scalars.
+    "i8": _vm_to_scalar(int),
+    "i16": _vm_to_scalar(int),
+    "i32": _vm_to_scalar(int),
+    "i64": _vm_to_scalar(int),
+    "f16": _vm_to_scalar(float),
+    "f32": _vm_to_scalar(float),
+    "f64": _vm_to_scalar(float),
+    "bf16": _vm_to_scalar(float),
+}
+
+ABI_TYPE_TO_DTYPE = {
+    # TODO: Others.
+    "f32": np.float32,
+    "i32": np.int32,
+    "i64": np.int64,
+    "f64": np.float64,
+    "i16": np.int16,
+    "i8": np.int8,
+    "i1": np.bool_,
+}
+
+# Memory type and usage flags used when an ndarray argument is implicitly
+# mapped to a buffer view.
+IMPLICIT_BUFFER_ARG_MEMORY_TYPE = MemoryType.DEVICE_LOCAL
+IMPLICIT_BUFFER_ARG_USAGE = (BufferUsage.DISPATCH | BufferUsage.TRANSFER |
+                             BufferUsage.MAPPING)
+
+
+def _is_ndarray_descriptor(desc):
+  return desc and desc[0] == "ndarray"
+
+
+def _is_0d_ndarray_descriptor(desc):
+  # Example: ["ndarray", "f32", 0]
+  return desc and desc[0] == "ndarray" and desc[2] == 0
+
+
+def _cast_scalar_to_ndarray(inv: Invocation, x, desc):
+  # Example descriptor: ["ndarray", "f32", 0]
+  dtype_str = desc[1]
+  try:
+    dtype = ABI_TYPE_TO_DTYPE[dtype_str]
+  except KeyError:
+    _raise_argument_error(inv, f"unrecognized dtype '{dtype_str}'")
+  return dtype(x)
+
+
+class ArgumentError(ValueError):
+  pass
+
+
+class ReturnError(ValueError):
+  pass
+
+
+def _raise_argument_error(inv: Invocation,
+                          summary: str,
+                          e: Optional[Exception] = None):
+  new_e = ArgumentError(
+      f"Error passing argument: {summary} "
+      f"(while encoding argument {inv.summarize_arg_error()})")
+  if e:
+    raise new_e from e
+  else:
+    raise new_e
+
+
+def _raise_return_error(inv: Invocation,
+                        summary: str,
+                        e: Optional[Exception] = None):
+  new_e = ReturnError(f"Error processing function return: {summary} "
+                      f"(while decoding return {inv.summarize_return_error()})")
+  if e:
+    raise new_e from e
+  else:
+    raise new_e
+
+
+def _extract_vm_sequence_to_python(inv: Invocation, vm_list, descs):
+  vm_list_arity = len(vm_list)
+  if descs is None:
+    descs = [None] * vm_list_arity
+  elif vm_list_arity != len(descs):
+    _raise_return_error(
+        inv, f"mismatched return arity: {vm_list_arity} vs {len(descs)}")
+  results = []
+  for vm_index, desc in enumerate(descs):
+    inv.current_return_list = vm_list
+    inv.current_return_index = vm_index
+    inv.current_desc = desc
+    if desc is None:
+      # Dynamic (non reflection mode).
+      converted = vm_list.get_variant(vm_index)
+      # Special case: Upgrade HalBufferView to a DeviceArray. We do that here
+      # since this is higher level and it preserves layering. Note that
+      # the reflection case also does this conversion.
+      if isinstance(converted, HalBufferView):
+        converted = DeviceArray(inv.device,
+                                converted,
+                                implicit_host_transfer=True)
+    else:
+      # Known type descriptor.
+      vm_type = desc if isinstance(desc, str) else desc[0]
+      try:
+        converter = VM_TO_PYTHON_CONVERTERS[vm_type]
+      except KeyError:
+        _raise_return_error(inv, f"cannot map VM type to Python: {vm_type}")
+      try:
+        converted = converter(inv, vm_list, vm_index, desc)
+      except ReturnError:
+        raise
+      except Exception as e:
+        _raise_return_error(inv, "exception converting from VM type to Python",
+                            e)
+    results.append(converted)
+  return results
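The descriptor-driven return decoding in `function.py` can be illustrated without a runtime. The sketch below is a pure-Python approximation, not the module's implementation: it assumes a `VmVariantList` can be modelled as a nested Python list, covers only a few scalar types, and the names `decode_sequence` and `decode_results` are hypothetical. It mirrors the dispatch on `desc[0]`, the arity check, the un-inlining of a single `slist`/`stuple`/`sdict` result, and the 0/1/N return-arity handling.

```python
from typing import Any, Callable, Dict, List, Optional

# Hypothetical model: a VM list is a plain Python list whose elements are
# scalars or nested lists (standing in for slist/stuple/sdict values).
Converter = Callable[[Any, Any], Any]


def _scalar(kind: type) -> Converter:
  def convert(value: Any, desc: Any) -> Any:
    if not isinstance(value, kind):
      raise ValueError(
          f"expected {kind.__name__} value but got {type(value).__name__}")
    return value
  return convert


def _slist(value: List[Any], desc: List[Any]) -> List[Any]:
  # ['slist', item_desc1, ...]
  return decode_sequence(value, desc[1:])


def _stuple(value: List[Any], desc: List[Any]) -> tuple:
  return tuple(decode_sequence(value, desc[1:]))


def _sdict(value: List[Any], desc: List[Any]) -> Dict[str, Any]:
  # ['sdict', [key1, desc1], ...]; values are stored in descriptor order.
  keys = [key for key, _ in desc[1:]]
  item_descs = [item_desc for _, item_desc in desc[1:]]
  return dict(zip(keys, decode_sequence(value, item_descs)))


def _homogeneous_list(value: List[Any], desc: List[Any]) -> List[Any]:
  # ['py_homogeneous_list', element_desc]
  return decode_sequence(value, [desc[1]] * len(value))


CONVERTERS: Dict[str, Converter] = {
    "i32": _scalar(int),
    "i64": _scalar(int),
    "f32": _scalar(float),
    "slist": _slist,
    "stuple": _stuple,
    "sdict": _sdict,
    "py_homogeneous_list": _homogeneous_list,
}


def decode_sequence(values: List[Any],
                    descs: Optional[List[Any]]) -> List[Any]:
  if descs is None:
    return list(values)  # Dynamic mode: no reflection, pass values through.
  if len(values) != len(descs):
    raise ValueError(f"mismatched return arity: {len(values)} vs {len(descs)}")
  results = []
  for value, desc in zip(values, descs):
    vm_type = desc if isinstance(desc, str) else desc[0]
    results.append(CONVERTERS[vm_type](value, desc))
  return results


def decode_results(flat_values: List[Any], ret_descs: Optional[List[Any]]):
  values = flat_values
  # A single aggregate result is returned inlined (flattened), so wrap it
  # back into one element before decoding.
  if ret_descs is not None and len(ret_descs) == 1:
    only = ret_descs[0]
    if only and only[0] in ("slist", "stuple", "sdict"):
      values = [flat_values]
  returns = decode_sequence(values, ret_descs)
  if len(returns) == 1:
    return returns[0]
  if not returns:
    return None
  return tuple(returns)


if __name__ == "__main__":
  nested = ["i32", ["slist", ["sdict", ["bar", "i32"], ["foo", "i32"]], "i64"]]
  print(decode_results([3, [[100, 200], 6]], nested))
  # (3, [{'bar': 100, 'foo': 200}, 6])
```

With an inlined result descriptor such as `[["slist", "i32", "i32"]]`, `decode_results([3, 4], ...)` yields `[3, 4]`, which is the shape the inlined-results test in `function_test.py` expects from the real invoker.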
diff --git a/runtime/bindings/python/iree/runtime/function_test.py b/runtime/bindings/python/iree/runtime/function_test.py
new file mode 100644
index 0000000..57b8481
--- /dev/null
+++ b/runtime/bindings/python/iree/runtime/function_test.py
@@ -0,0 +1,596 @@
+# Lint as: python3
+# Copyright 2019 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import json
+import numpy as np
+
+from absl.testing import absltest
+
+from iree import runtime as rt
+from iree.runtime.function import (
+    FunctionInvoker,
+    IMPLICIT_BUFFER_ARG_MEMORY_TYPE,
+    IMPLICIT_BUFFER_ARG_USAGE,
+)
+from iree.runtime.binding import VmVariantList
+
+
+class MockVmContext:
+
+  def __init__(self, invoke_callback):
+    self._invoke_callback = invoke_callback
+    self.invocations = []
+
+  def invoke(self, vm_function, arg_list, ret_list):
+    self._invoke_callback(arg_list, ret_list)
+    self.invocations.append((vm_function, arg_list, ret_list))
+    print(f"INVOKE: {arg_list} -> {ret_list}")
+
+  @property
+  def mock_arg_reprs(self):
+    return repr([arg_list for _, arg_list, _ in self.invocations])
+
+
+class MockVmFunction:
+
+  def __init__(self, reflection):
+    self.reflection = reflection
+
+
+class FunctionTest(absltest.TestCase):
+
+  @classmethod
+  def setUpClass(cls):
+    # Doesn't matter what device. We just need one.
+    config = rt.Config("vmvx")
+    cls.device = config.device
+
+  def testNoReflectionScalars(self):
+
+    def invoke(arg_list, ret_list):
+      ret_list.push_int(3)
+      ret_list.push_int(4)
+
+    vm_context = MockVmContext(invoke)
+    vm_function = MockVmFunction(reflection={})
+    invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
+    result = invoker(1, 2)
+    self.assertEqual("[<VmVariantList(2): [1, 2]>]", vm_context.mock_arg_reprs)
+    self.assertEqual((3, 4), result)
+
+  def testKeywordArgs(self):
+
+    def invoke(arg_list, ret_list):
+      ret_list.push_int(3)
+
+    vm_context = MockVmContext(invoke)
+    vm_function = MockVmFunction(
+        reflection={
+            "iree.abi":
+                json.dumps({
+                    "a": [
+                        "i32",
+                        ["named", "a", "i32"],
+                        ["named", "b", "i32"],
+                    ],
+                    "r": ["i32",],
+                })
+        })
+    invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
+    result = invoker(-1, a=1, b=2)
+    self.assertEqual("[<VmVariantList(3): [-1, 1, 2]>]",
+                     vm_context.mock_arg_reprs)
+    self.assertEqual(3, result)
+
+  def testListArg(self):
+
+    def invoke(arg_list, ret_list):
+      ret_list.push_int(3)
+
+    vm_context = MockVmContext(invoke)
+    vm_function = MockVmFunction(reflection={
+        "iree.abi":
+            json.dumps({
+                "a": [["slist", "i32", "i32"],],
+                "r": ["i32",],
+            })
+    })
+    invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
+    _ = invoker([2, 3])
+    self.assertEqual("[<VmVariantList(1): [List[2, 3]]>]",
+                     vm_context.mock_arg_reprs)
+
+  def testListArgNoReflection(self):
+
+    def invoke(arg_list, ret_list):
+      ret_list.push_int(3)
+
+    vm_context = MockVmContext(invoke)
+    vm_function = MockVmFunction(reflection={})
+    invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
+    _ = invoker([2, 3])
+    self.assertEqual("[<VmVariantList(1): [List[2, 3]]>]",
+                     vm_context.mock_arg_reprs)
+
+  def testListArgArityMismatch(self):
+
+    def invoke(arg_list, ret_list):
+      ret_list.push_int(3)
+
+    vm_context = MockVmContext(invoke)
+    vm_function = MockVmFunction(reflection={
+        "iree.abi":
+            json.dumps({
+                "a": [["slist", "i32", "i32"],],
+                "r": ["i32",],
+            })
+    })
+    invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
+    with self.assertRaisesRegex(ValueError,
+                                "expected a sequence with 2 values. got:"):
+      _ = invoker([2, 3, 4])
+
+  def testTupleArg(self):
+
+    def invoke(arg_list, ret_list):
+      ret_list.push_int(3)
+
+    vm_context = MockVmContext(invoke)
+    vm_function = MockVmFunction(reflection={
+        "iree.abi":
+            json.dumps({
+                "a": [["stuple", "i32", "i32"],],
+                "r": ["i32",],
+            })
+    })
+    invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
+    _ = invoker((2, 3))
+    self.assertEqual("[<VmVariantList(1): [List[2, 3]]>]",
+                     vm_context.mock_arg_reprs)
+
+  def testDictArg(self):
+
+    def invoke(arg_list, ret_list):
+      ret_list.push_int(3)
+
+    vm_context = MockVmContext(invoke)
+    vm_function = MockVmFunction(
+        reflection={
+            "iree.abi":
+                json.dumps({
+                    "a": [["sdict", ["a", "i32"], ["b", "i32"]],],
+                    "r": ["i32",],
+                })
+        })
+    invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
+    _ = invoker({"b": 3, "a": 2})
+    self.assertEqual("[<VmVariantList(1): [List[2, 3]]>]",
+                     vm_context.mock_arg_reprs)
+
+  def testDictArgArityMismatch(self):
+
+    def invoke(arg_list, ret_list):
+      ret_list.push_int(3)
+
+    vm_context = MockVmContext(invoke)
+    vm_function = MockVmFunction(
+        reflection={
+            "iree.abi":
+                json.dumps({
+                    "a": [["sdict", ["a", "i32"], ["b", "i32"]],],
+                    "r": ["i32",],
+                })
+        })
+    invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
+    with self.assertRaisesRegex(ValueError,
+                                "expected a dict with 2 values. got:"):
+      _ = invoker({"a": 2, "b": 3, "c": 4})
+
+  def testDictArgKeyError(self):
+
+    def invoke(arg_list, ret_list):
+      ret_list.push_int(3)
+
+    vm_context = MockVmContext(invoke)
+    vm_function = MockVmFunction(
+        reflection={
+            "iree.abi":
+                json.dumps({
+                    "a": [["sdict", ["a", "i32"], ["b", "i32"]],],
+                    "r": ["i32",],
+                })
+        })
+    invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
+    with self.assertRaisesRegex(ValueError, "could not get item 'b' from: "):
+      _ = invoker({"a": 2, "c": 3})
+
+  def testDictArgNoReflection(self):
+
+    def invoke(arg_list, ret_list):
+      ret_list.push_int(3)
+
+    vm_context = MockVmContext(invoke)
+    vm_function = MockVmFunction(reflection={})
+    invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
+    _ = invoker({"b": 3, "a": 2})
+    self.assertEqual("[<VmVariantList(1): [List[2, 3]]>]",
+                     vm_context.mock_arg_reprs)
+
+  def testInlinedResults(self):
+
+    def invoke(arg_list, ret_list):
+      ret_list.push_int(3)
+      ret_list.push_int(4)
+
+    vm_context = MockVmContext(invoke)
+    vm_function = MockVmFunction(reflection={
+        "iree.abi": json.dumps({
+            "a": [],
+            "r": [["slist", "i32", "i32"]],
+        })
+    })
+    invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
+    result = invoker()
+    self.assertEqual([3, 4], result)
+
+  def testNestedResults(self):
+
+    def invoke(arg_list, ret_list):
+      ret_list.push_int(3)
+      sub_list = VmVariantList(2)
+      sub_dict = VmVariantList(2)
+      sub_dict.push_int(100)
+      sub_dict.push_int(200)
+      sub_list.push_list(sub_dict)
+      sub_list.push_int(6)
+      ret_list.push_list(sub_list)
+
+    vm_context = MockVmContext(invoke)
+    vm_function = MockVmFunction(
+        reflection={
+            "iree.abi":
+                json.dumps({
+                    "a": [],
+                    "r": [
+                        "i32",
+                        [
+                            "slist",
+                            ["sdict", ["bar", "i32"], ["foo", "i32"]],
+                            "i64",
+                        ]
+                    ],
+                })
+        })
+    invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
+    result = invoker()
+    self.assertEqual((3, [{"bar": 100, "foo": 200}, 6]), result)
+
+  def testMissingPositional(self):
+
+    def invoke(arg_list, ret_list):
+      ret_list.push_int(3)
+
+    vm_context = MockVmContext(invoke)
+    vm_function = MockVmFunction(
+        reflection={
+            "iree.abi":
+                json.dumps({
+                    "a": [
+                        "i32",
+                        ["named", "a", "i32"],
+                        ["named", "b", "i32"],
+                    ],
+                    "r": ["i32",],
+                })
+        })
+    invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
+    with self.assertRaisesRegex(ValueError, "mismatched call arity:"):
+      result = invoker(a=1, b=1)
+
+  def testMissingPositionalNdarray(self):
+
+    def invoke(arg_list, ret_list):
+      ret_list.push_int(3)
+
+    vm_context = MockVmContext(invoke)
+    vm_function = MockVmFunction(
+        reflection={
+            "iree.abi":
+                json.dumps({
+                    "a": [
+                        ["ndarray", "i32", 1, 1],
+                        ["named", "a", ["ndarray", "i32", 1, 1]],
+                        ["named", "b", ["ndarray", "i32", 1, 1]],
+                    ],
+                    "r": ["i32",],
+                })
+        })
+    invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
+    with self.assertRaisesRegex(ValueError, "mismatched call arity:"):
+      result = invoker(a=1, b=1)
+
+  def testMissingKeyword(self):
+
+    def invoke(arg_list, ret_list):
+      ret_list.push_int(3)
+
+    vm_context = MockVmContext(invoke)
+    vm_function = MockVmFunction(
+        reflection={
+            "iree.abi":
+                json.dumps({
+                    "a": [
+                        "i32",
+                        ["named", "a", "i32"],
+                        ["named", "b", "i32"],
+                    ],
+                    "r": ["i32",],
+                })
+        })
+    invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
+    with self.assertRaisesRegex(ValueError, "mismatched call arity:"):
+      result = invoker(-1, a=1)
+
+  def testMissingKeywordNdArray(self):
+
+    def invoke(arg_list, ret_list):
+      ret_list.push_int(3)
+
+    vm_context = MockVmContext(invoke)
+    vm_function = MockVmFunction(
+        reflection={
+            "iree.abi":
+                json.dumps({
+                    "a": [
+                        ["ndarray", "i32", 1, 1],
+                        ["named", "a", ["ndarray", "i32", 1, 1]],
+                        ["named", "b", ["ndarray", "i32", 1, 1]],
+                    ],
+                    "r": ["i32",],
+                })
+        })
+    invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
+    with self.assertRaisesRegex(ValueError, "mismatched call arity:"):
+      result = invoker(-1, a=1)
+
+  def testExtraKeyword(self):
+
+    def invoke(arg_list, ret_list):
+      ret_list.push_int(3)
+
+    vm_context = MockVmContext(invoke)
+    vm_function = MockVmFunction(
+        reflection={
+            "iree.abi":
+                json.dumps({
+                    "a": [
+                        "i32",
+                        ["named", "a", "i32"],
+                        ["named", "b", "i32"],
+                    ],
+                    "r": ["i32",],
+                })
+        })
+    invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
+    with self.assertRaisesRegex(ValueError, "specified kwarg 'c' is unknown"):
+      result = invoker(-1, a=1, b=2, c=3)
+
+  def testNdarrayArg(self):
+    arg_array = np.asarray([1, 0], dtype=np.int32)
+
+    invoked_arg_list = None
+
+    def invoke(arg_list, ret_list):
+      nonlocal invoked_arg_list
+      invoked_arg_list = arg_list
+
+    vm_context = MockVmContext(invoke)
+    vm_function = MockVmFunction(reflection={
+        "iree.abi": json.dumps({
+            "a": [["ndarray", "i32", 1, 2]],
+            "r": [],
+        })
+    })
+    invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
+    result = invoker(arg_array)
+    self.assertEqual("<VmVariantList(1): [HalBufferView(2:0x20000011)]>",
+                     repr(invoked_arg_list))
+
+  def testDeviceArrayArg(self):
+    # Note that since the device array is set up to disallow implicit host
+    # transfers, this also verifies that no accidental/automatic transfers
+    # are done as part of marshalling the array to the function.
+    arg_array = rt.asdevicearray(self.device,
+                                 np.asarray([1, 0], dtype=np.int32),
+                                 implicit_host_transfer=False)
+
+    invoked_arg_list = None
+
+    def invoke(arg_list, ret_list):
+      nonlocal invoked_arg_list
+      invoked_arg_list = arg_list
+
+    vm_context = MockVmContext(invoke)
+    vm_function = MockVmFunction(reflection={
+        "iree.abi": json.dumps({
+            "a": [["ndarray", "i32", 1, 2]],
+            "r": [],
+        })
+    })
+    invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
+    result = invoker(arg_array)
+    self.assertEqual("<VmVariantList(1): [HalBufferView(2:0x20000011)]>",
+                     repr(invoked_arg_list))
+
+  def testBufferViewArg(self):
+    arg_buffer_view = self.device.allocator.allocate_buffer_copy(
+        memory_type=IMPLICIT_BUFFER_ARG_MEMORY_TYPE,
+        allowed_usage=IMPLICIT_BUFFER_ARG_USAGE,
+        buffer=np.asarray([1, 0], dtype=np.int32),
+        element_type=rt.HalElementType.SINT_32)
+
+    invoked_arg_list = None
+
+    def invoke(arg_list, ret_list):
+      nonlocal invoked_arg_list
+      invoked_arg_list = arg_list
+
+    vm_context = MockVmContext(invoke)
+    vm_function = MockVmFunction(reflection={
+        "iree.abi": json.dumps({
+            "a": [["ndarray", "i32", 1, 2]],
+            "r": [],
+        })
+    })
+    invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
+    _ = invoker(arg_buffer_view)
+    self.assertEqual("<VmVariantList(1): [HalBufferView(2:0x20000011)]>",
+                     repr(invoked_arg_list))
+
+  def testNdarrayArgNoReflection(self):
+    arg_array = np.asarray([1, 0], dtype=np.int32)
+
+    invoked_arg_list = None
+
+    def invoke(arg_list, ret_list):
+      nonlocal invoked_arg_list
+      invoked_arg_list = arg_list
+
+    vm_context = MockVmContext(invoke)
+    vm_function = MockVmFunction(reflection={})
+    invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
+    result = invoker(arg_array)
+    self.assertEqual("<VmVariantList(1): [HalBufferView(2:0x20000011)]>",
+                     repr(invoked_arg_list))
+
+  def testDeviceArrayArgNoReflection(self):
+    # Note that since the device array is set up to disallow implicit host
+    # transfers, this also verifies that no accidental/automatic transfers
+    # are done as part of marshalling the array to the function.
+    arg_array = rt.asdevicearray(self.device,
+                                 np.asarray([1, 0], dtype=np.int32),
+                                 implicit_host_transfer=False)
+
+    invoked_arg_list = None
+
+    def invoke(arg_list, ret_list):
+      nonlocal invoked_arg_list
+      invoked_arg_list = arg_list
+
+    vm_context = MockVmContext(invoke)
+    vm_function = MockVmFunction(reflection={})
+    invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
+    result = invoker(arg_array)
+    self.assertEqual("<VmVariantList(1): [HalBufferView(2:0x20000011)]>",
+                     repr(invoked_arg_list))
+
+  def testBufferViewArgNoReflection(self):
+    arg_buffer_view = self.device.allocator.allocate_buffer_copy(
+        memory_type=IMPLICIT_BUFFER_ARG_MEMORY_TYPE,
+        allowed_usage=IMPLICIT_BUFFER_ARG_USAGE,
+        buffer=np.asarray([1, 0], dtype=np.int32),
+        element_type=rt.HalElementType.SINT_32)
+
+    invoked_arg_list = None
+
+    def invoke(arg_list, ret_list):
+      nonlocal invoked_arg_list
+      invoked_arg_list = arg_list
+
+    vm_context = MockVmContext(invoke)
+    vm_function = MockVmFunction(reflection={})
+    invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
+    _ = invoker(arg_buffer_view)
+    self.assertEqual("<VmVariantList(1): [HalBufferView(2:0x20000011)]>",
+                     repr(invoked_arg_list))
+
+  def testReturnBufferView(self):
+    result_array = np.asarray([1, 0], dtype=np.int32)
+
+    def invoke(arg_list, ret_list):
+      buffer_view = self.device.allocator.allocate_buffer_copy(
+          memory_type=IMPLICIT_BUFFER_ARG_MEMORY_TYPE,
+          allowed_usage=IMPLICIT_BUFFER_ARG_USAGE,
+          buffer=result_array,
+          element_type=rt.HalElementType.SINT_32)
+      ret_list.push_buffer_view(buffer_view)
+
+    vm_context = MockVmContext(invoke)
+    vm_function = MockVmFunction(reflection={
+        "iree.abi": json.dumps({
+            "a": [],
+            "r": [["ndarray", "i32", 1, 2]],
+        })
+    })
+    invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
+    result = invoker()
+    np.testing.assert_array_equal([1, 0], result)
+
+  def testReturnBufferViewNoReflection(self):
+    result_array = np.asarray([1, 0], dtype=np.int32)
+
+    def invoke(arg_list, ret_list):
+      buffer_view = self.device.allocator.allocate_buffer_copy(
+          memory_type=IMPLICIT_BUFFER_ARG_MEMORY_TYPE,
+          allowed_usage=IMPLICIT_BUFFER_ARG_USAGE,
+          buffer=result_array,
+          element_type=rt.HalElementType.SINT_32)
+      ret_list.push_buffer_view(buffer_view)
+
+    vm_context = MockVmContext(invoke)
+    vm_function = MockVmFunction(reflection={})
+    invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
+    result = invoker()
+    np.testing.assert_array_equal([1, 0], result)
+
+  # TODO: Fill out all return types.
+  def testReturnTypeNdArrayBool(self):
+    result_array = np.asarray([1, 0], dtype=np.int8)
+
+    def invoke(arg_list, ret_list):
+      buffer_view = self.device.allocator.allocate_buffer_copy(
+          memory_type=IMPLICIT_BUFFER_ARG_MEMORY_TYPE,
+          allowed_usage=IMPLICIT_BUFFER_ARG_USAGE,
+          buffer=result_array,
+          element_type=rt.HalElementType.UINT_8)
+      ret_list.push_buffer_view(buffer_view)
+
+    vm_context = MockVmContext(invoke)
+    vm_function = MockVmFunction(reflection={
+        "iree.abi": json.dumps({
+            "a": [],
+            "r": [["ndarray", "i1", 1, 2]],
+        })
+    })
+    invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
+    result = invoker()
+    # assertEqual on arrays is unreliable (== is elementwise); use numpy's.
+    np.testing.assert_array_equal([True, False], result)
+
+  def testReturnTypeList(self):
+    vm_list = VmVariantList(2)
+    vm_list.push_int(1)
+    vm_list.push_int(2)
+
+    def invoke(arg_list, ret_list):
+      ret_list.push_list(vm_list)
+
+    vm_context = MockVmContext(invoke)
+    vm_function = MockVmFunction(reflection={
+        "iree.abi":
+            json.dumps({
+                "a": [],
+                "r": [["py_homogeneous_list", "i64"]],
+            })
+    })
+    invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
+    result = invoker()
+    self.assertEqual("[1, 2]", repr(result))
+
+
+if __name__ == "__main__":
+  absltest.main()
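The tests above pin down the argument-packing contract that `FunctionInvoker` relies on: positional values fill the unnamed descriptors, `["named", name, desc]` entries must be supplied as keywords, `sdict` values are packed in descriptor key order rather than dict order, and arity or key mismatches raise `ValueError`. The actual packing is done by `ArgumentPacker` from the `binding` module; the pure-Python sketch below only models the observable behavior these tests check. It assumes unnamed descriptors precede named ones, passes scalars (and ndarray descriptors) through unchanged, and `pack_args` is a hypothetical name.

```python
from typing import Any, Dict, List, Sequence


def _is_named(desc: Any) -> bool:
  return isinstance(desc, list) and bool(desc) and desc[0] == "named"


def _pack_value(value: Any, desc: Any) -> Any:
  kind = desc if isinstance(desc, str) else desc[0]
  if kind in ("slist", "stuple"):
    item_descs = desc[1:]
    if len(value) != len(item_descs):
      raise ValueError(
          f"expected a sequence with {len(item_descs)} values. got: {value!r}")
    return [_pack_value(v, d) for v, d in zip(value, item_descs)]
  if kind == "sdict":
    items = desc[1:]
    if len(value) != len(items):
      raise ValueError(
          f"expected a dict with {len(items)} values. got: {value!r}")
    packed = []
    for key, item_desc in items:  # Descriptor order, not dict order.
      try:
        item = value[key]
      except KeyError:
        raise ValueError(
            f"could not get item {key!r} from: {value!r}") from None
      packed.append(_pack_value(item, item_desc))
    return packed
  return value  # Scalars and everything else pass through in this sketch.


def pack_args(arg_descs: List[Any], args: Sequence[Any],
              kwargs: Dict[str, Any]) -> List[Any]:
  positional = [d for d in arg_descs if not _is_named(d)]
  named = [(d[1], d[2]) for d in arg_descs if _is_named(d)]
  known_names = {name for name, _ in named}
  for key in kwargs:
    if key not in known_names:
      raise ValueError(f"specified kwarg '{key}' is unknown")
  if len(args) != len(positional) or len(kwargs) != len(named):
    raise ValueError(f"mismatched call arity: expected {len(positional)} "
                     f"positional and {len(named)} keyword arguments")
  packed = [_pack_value(a, d) for a, d in zip(args, positional)]
  packed.extend(_pack_value(kwargs[name], d) for name, d in named)
  return packed
```

For example, `pack_args(["i32", ["named", "a", "i32"], ["named", "b", "i32"]], (-1,), {"a": 1, "b": 2})` returns `[-1, 1, 2]`, the same flat order that `testKeywordArgs` observes in the mock VM context. Checking unknown keywords before arity mirrors the tests, where an extra `c=3` reports the unknown kwarg rather than an arity mismatch.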
diff --git a/runtime/bindings/python/iree/runtime/hal.cc b/runtime/bindings/python/iree/runtime/hal.cc
new file mode 100644
index 0000000..1d577a6
--- /dev/null
+++ b/runtime/bindings/python/iree/runtime/hal.cc
@@ -0,0 +1,514 @@
+// Copyright 2019 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "./hal.h"
+
+#include "iree/base/tracing.h"
+#include "iree/hal/api.h"
+#include "pybind11/numpy.h"
+
+namespace iree {
+namespace python {
+
+namespace {
+
+// RAII wrapper for a Py_buffer which calls PyBuffer_Release when it goes
+// out of scope.
+class PyBufferReleaser {
+ public:
+  explicit PyBufferReleaser(Py_buffer& b) : b_(b) {}
+  ~PyBufferReleaser() { PyBuffer_Release(&b_); }
+
+ private:
+  Py_buffer& b_;
+};
+
+static std::string ToHexString(const uint8_t* data, size_t length) {
+  static constexpr char kHexChars[] = {'0', '1', '2', '3', '4', '5', '6', '7',
+                                       '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'};
+  std::string s(length * 2, ' ');
+  for (size_t i = 0; i < length; ++i) {
+    s[2 * i + 0] = kHexChars[(data[i] & 0xF0) >> 4];
+    s[2 * i + 1] = kHexChars[(data[i] & 0x0F) >> 0];
+  }
+  return s;
+}
+static std::string ToHexString(uint32_t value) {
+  return ToHexString((const uint8_t*)&value, sizeof(value));
+}
+
+}  // namespace
+
+//------------------------------------------------------------------------------
+// HalAllocator
+//------------------------------------------------------------------------------
+
+py::dict HalAllocator::QueryStatistics() {
+  py::dict items;
+  iree_hal_allocator_statistics_t stats;
+  iree_hal_allocator_query_statistics(raw_ptr(), &stats);
+#if IREE_STATISTICS_ENABLE
+  items["host_bytes_peak"] = stats.host_bytes_peak;
+  items["host_bytes_allocated"] = stats.host_bytes_allocated;
+  items["host_bytes_freed"] = stats.host_bytes_freed;
+  items["device_bytes_peak"] = stats.device_bytes_peak;
+  items["device_bytes_allocated"] = stats.device_bytes_allocated;
+  items["device_bytes_freed"] = stats.device_bytes_freed;
+#endif
+  return items;
+}
+
+py::str HalAllocator::FormattedStatistics() {
+  // Perform all allocating string manipulation without early exit.
+  iree_string_builder_t builder;
+  iree_string_builder_initialize(iree_allocator_system(), &builder);
+  iree_hal_allocator_statistics_t stats;
+  iree_hal_allocator_query_statistics(raw_ptr(), &stats);
+  auto status = iree_hal_allocator_statistics_format(&stats, &builder);
+  iree_string_view_t view = iree_string_builder_view(&builder);
+  py::str result = py::str(view.data, view.size);
+  iree_string_builder_deinitialize(&builder);
+
+  // Check/raise after all memory alloc/dealloc.
+  CheckApiStatus(status, "unable to format statistics");
+  return result;
+}
+
+py::object HalAllocator::AllocateBufferCopy(
+    int memory_type, int allowed_usage, py::object buffer,
+    std::optional<iree_hal_element_types_t> element_type) {
+  IREE_TRACE_SCOPE0("HalAllocator::AllocateBufferCopy");
+  // Request a view of the buffer (use the raw python C API to avoid
+  // some allocation and copying at the pybind level).
+  Py_buffer py_view;
+  // Note that only C-Contiguous ND-arrays are presently supported, so
+  // only request that via PyBUF_ND. Long term, we should consult an
+  // "oracle" in the runtime to determine the precise required format
+  // and set flags accordingly (and fallback/copy on failure).
+  int flags = PyBUF_FORMAT | PyBUF_ND;
+
+  // Acquire the backing buffer and set up RAII release.
+  if (PyObject_GetBuffer(buffer.ptr(), &py_view, flags) != 0) {
+    // The GetBuffer call is required to set an appropriate error.
+    throw py::error_already_set();
+  }
+  PyBufferReleaser py_view_releaser(py_view);
+
+  iree_hal_buffer_params_t params = {0};
+  // TODO: Should not require host visible :(
+  params.type = memory_type | IREE_HAL_MEMORY_TYPE_HOST_VISIBLE;
+  params.usage = allowed_usage;
+
+  iree_hal_buffer_t* hal_buffer = nullptr;
+  iree_status_t status = iree_ok_status();
+  {
+    py::gil_scoped_release release;
+    status = iree_hal_allocator_allocate_buffer(
+        raw_ptr(), params, py_view.len,
+        iree_make_const_byte_span(py_view.buf, py_view.len), &hal_buffer);
+  }
+  CheckApiStatus(status, "Failed to allocate device visible buffer");
+
+  if (!element_type) {
+    return py::cast(HalBuffer::StealFromRawPtr(hal_buffer),
+                    py::return_value_policy::move);
+  }
+
+  // Create the buffer_view. Note that the numpy shape is ssize_t, so the dims
+  // must be copied into an iree_hal_dim_t vector.
+  iree_hal_encoding_type_t encoding_type =
+      IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR;
+  std::vector<iree_hal_dim_t> dims(py_view.ndim);
+  std::copy(py_view.shape, py_view.shape + py_view.ndim, dims.begin());
+  iree_hal_buffer_view_t* hal_buffer_view = nullptr;
+  iree_status_t view_status = iree_hal_buffer_view_create(
+      hal_buffer, dims.data(), dims.size(), *element_type, encoding_type,
+      iree_hal_allocator_host_allocator(raw_ptr()), &hal_buffer_view);
+  // The view (if created) retains the buffer; release our reference either way.
+  iree_hal_buffer_release(hal_buffer);
+  CheckApiStatus(view_status, "Error allocating buffer_view");
+
+  return py::cast(HalBufferView::StealFromRawPtr(hal_buffer_view),
+                  py::return_value_policy::move);
+}
+
+//------------------------------------------------------------------------------
+// HalBuffer
+//------------------------------------------------------------------------------
+
+namespace {
+
+void AppendHalBufferRepr(iree_hal_buffer_t* buffer, std::string& repr) {
+  repr.append(std::to_string(iree_hal_buffer_byte_length(buffer)));
+  repr.append(" bytes (at offset ");
+  repr.append(std::to_string(iree_hal_buffer_byte_offset(buffer)));
+  repr.append(" into ");
+  repr.append(std::to_string(iree_hal_buffer_allocation_size(buffer)));
+  repr.append("), memory_type=");
+
+  // Memory type.
+  iree_bitfield_string_temp_t tmp;
+  iree_string_view_t sv;
+  sv = iree_hal_memory_type_format(iree_hal_buffer_memory_type(buffer), &tmp);
+  repr.append(sv.data, sv.size);
+
+  // Allowed access.
+  repr.append(", allowed_access=");
+  sv = iree_hal_memory_access_format(iree_hal_buffer_allowed_access(buffer),
+                                     &tmp);
+  repr.append(sv.data, sv.size);
+
+  // Allowed usage.
+  repr.append(", allowed_usage=");
+  sv =
+      iree_hal_buffer_usage_format(iree_hal_buffer_allowed_usage(buffer), &tmp);
+  repr.append(sv.data, sv.size);
+}
+
+}  // namespace
+
+py::str HalBuffer::Repr() {
+  std::string repr("<HalBuffer ");
+  AppendHalBufferRepr(raw_ptr(), repr);
+  repr.append(">");
+  return py::str(repr);
+}
+
+//------------------------------------------------------------------------------
+// HalBufferView
+//------------------------------------------------------------------------------
+
+py::str HalBufferView::Repr() {
+  std::string repr("<HalBufferView (");
+
+  // Shape.
+  iree_host_size_t rank = iree_hal_buffer_view_shape_rank(raw_ptr());
+  for (iree_host_size_t i = 0; i < rank; ++i) {
+    if (i > 0) {
+      repr.append(", ");
+    }
+    repr.append(std::to_string(iree_hal_buffer_view_shape_dim(raw_ptr(), i)));
+  }
+  repr.append(")");
+
+  // Element type.
+  repr.append(", element_type=0x");
+  auto element_type = iree_hal_buffer_view_element_type(raw_ptr());
+  repr.append(ToHexString(static_cast<uint32_t>(element_type)));
+
+  repr.append(", ");
+  AppendHalBufferRepr(iree_hal_buffer_view_buffer(raw_ptr()), repr);
+  repr.append(">");
+  return py::str(repr);
+}
+
+//------------------------------------------------------------------------------
+// HalDriver
+//------------------------------------------------------------------------------
+
+std::vector<std::string> HalDriver::Query() {
+  iree_hal_driver_info_t* driver_infos = NULL;
+  iree_host_size_t driver_info_count = 0;
+  CheckApiStatus(
+      iree_hal_driver_registry_enumerate(iree_hal_driver_registry_default(),
+                                         iree_allocator_system(), &driver_infos,
+                                         &driver_info_count),
+      "Error enumerating HAL drivers");
+  std::vector<std::string> driver_names(driver_info_count);
+  for (iree_host_size_t i = 0; i < driver_info_count; ++i) {
+    driver_names[i] = std::string(driver_infos[i].driver_name.data,
+                                  driver_infos[i].driver_name.size);
+  }
+  iree_allocator_free(iree_allocator_system(), driver_infos);
+  return driver_names;
+}
+
+HalDriver HalDriver::Create(const std::string& driver_name) {
+  iree_hal_driver_t* driver;
+  CheckApiStatus(iree_hal_driver_registry_try_create_by_name(
+                     iree_hal_driver_registry_default(),
+                     {driver_name.data(), driver_name.size()},
+                     iree_allocator_system(), &driver),
+                 "Error creating driver");
+  return HalDriver::StealFromRawPtr(driver);
+}
+
+HalDevice HalDriver::CreateDefaultDevice() {
+  iree_hal_device_t* device;
+  CheckApiStatus(iree_hal_driver_create_default_device(
+                     raw_ptr(), iree_allocator_system(), &device),
+                 "Error creating default device");
+  return HalDevice::StealFromRawPtr(device);
+}
+
+//------------------------------------------------------------------------------
+// Enum helpers
+//------------------------------------------------------------------------------
+
+namespace {
+
+py::object MapElementTypeToDType(iree_hal_element_type_t element_type) {
+  // See: https://docs.python.org/3/c-api/arg.html#numbers
+  // TODO: Handle dtypes that do not map to a code (e.g. fp16).
+  const char* dtype_code;
+  switch (element_type) {
+    case IREE_HAL_ELEMENT_TYPE_INT_8:
+    case IREE_HAL_ELEMENT_TYPE_SINT_8:
+      dtype_code = "b";
+      break;
+    case IREE_HAL_ELEMENT_TYPE_UINT_8:
+      dtype_code = "B";
+      break;
+    case IREE_HAL_ELEMENT_TYPE_INT_16:
+    case IREE_HAL_ELEMENT_TYPE_SINT_16:
+      dtype_code = "h";
+      break;
+    case IREE_HAL_ELEMENT_TYPE_UINT_16:
+      dtype_code = "H";
+      break;
+    case IREE_HAL_ELEMENT_TYPE_INT_32:
+    case IREE_HAL_ELEMENT_TYPE_SINT_32:
+      dtype_code = "i";
+      break;
+    case IREE_HAL_ELEMENT_TYPE_UINT_32:
+      dtype_code = "I";
+      break;
+    case IREE_HAL_ELEMENT_TYPE_INT_64:
+    case IREE_HAL_ELEMENT_TYPE_SINT_64:
+      dtype_code = "q";
+      break;
+    case IREE_HAL_ELEMENT_TYPE_UINT_64:
+      dtype_code = "Q";
+      break;
+    case IREE_HAL_ELEMENT_TYPE_FLOAT_32:
+      dtype_code = "f";
+      break;
+    case IREE_HAL_ELEMENT_TYPE_FLOAT_64:
+      dtype_code = "d";
+      break;
+    case IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_INTEGER, 1):
+      dtype_code = "?";
+      break;
+    default:
+      throw RaiseValueError("Unsupported VM Buffer -> numpy dtype mapping");
+  }
+  return py::dtype(dtype_code);
+}
+
+}  // namespace
+
+//------------------------------------------------------------------------------
+// Bindings
+//------------------------------------------------------------------------------
+
+void SetupHalBindings(pybind11::module m) {
+  // Enums.
+  py::enum_<enum iree_hal_memory_type_bits_t>(m, "MemoryType")
+      .value("NONE", IREE_HAL_MEMORY_TYPE_NONE)
+      .value("TRANSIENT", IREE_HAL_MEMORY_TYPE_TRANSIENT)
+      .value("HOST_VISIBLE", IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)
+      .value("HOST_COHERENT", IREE_HAL_MEMORY_TYPE_HOST_COHERENT)
+      .value("HOST_CACHED", IREE_HAL_MEMORY_TYPE_HOST_CACHED)
+      .value("HOST_LOCAL", IREE_HAL_MEMORY_TYPE_HOST_LOCAL)
+      .value("DEVICE_VISIBLE", IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE)
+      .value("DEVICE_LOCAL", IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL)
+      .export_values()
+      .def("__or__",
+           [](enum iree_hal_memory_type_bits_t self,
+              enum iree_hal_memory_type_bits_t other) { return self | other; })
+      .def("__and__",
+           [](enum iree_hal_memory_type_bits_t self,
+              enum iree_hal_memory_type_bits_t other) { return self & other; });
+
+  py::enum_<enum iree_hal_buffer_compatibility_bits_t>(m, "BufferCompatibility")
+      .value("NONE", IREE_HAL_BUFFER_COMPATIBILITY_NONE)
+      .value("ALLOCATABLE", IREE_HAL_BUFFER_COMPATIBILITY_ALLOCATABLE)
+      .value("IMPORTABLE", IREE_HAL_BUFFER_COMPATIBILITY_IMPORTABLE)
+      .value("EXPORTABLE", IREE_HAL_BUFFER_COMPATIBILITY_EXPORTABLE)
+      .value("QUEUE_TRANSFER", IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_TRANSFER)
+      .value("QUEUE_DISPATCH", IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_DISPATCH)
+      .export_values()
+      .def("__or__",
+           [](enum iree_hal_buffer_compatibility_bits_t self,
+              enum iree_hal_buffer_compatibility_bits_t other) {
+             return self | other;
+           })
+      .def("__and__", [](enum iree_hal_buffer_compatibility_bits_t self,
+                         enum iree_hal_buffer_compatibility_bits_t other) {
+        return self & other;
+      });
+
+  py::enum_<enum iree_hal_buffer_usage_bits_t>(m, "BufferUsage")
+      .value("NONE", IREE_HAL_BUFFER_USAGE_NONE)
+      .value("CONSTANT", IREE_HAL_BUFFER_USAGE_CONSTANT)
+      .value("TRANSFER", IREE_HAL_BUFFER_USAGE_TRANSFER)
+      .value("MAPPING", IREE_HAL_BUFFER_USAGE_MAPPING)
+      .value("DISPATCH", IREE_HAL_BUFFER_USAGE_DISPATCH)
+      .export_values()
+      .def("__or__",
+           [](enum iree_hal_buffer_usage_bits_t self,
+              enum iree_hal_buffer_usage_bits_t other) {
+             return (enum iree_hal_buffer_usage_bits_t)(self | other);
+           })
+      .def("__and__", [](enum iree_hal_buffer_usage_bits_t self,
+                         enum iree_hal_buffer_usage_bits_t other) {
+        return (enum iree_hal_buffer_usage_bits_t)(self & other);
+      });
+
+  py::enum_<enum iree_hal_memory_access_bits_t>(m, "MemoryAccess")
+      .value("NONE", IREE_HAL_MEMORY_ACCESS_NONE)
+      .value("READ", IREE_HAL_MEMORY_ACCESS_READ)
+      .value("WRITE", IREE_HAL_MEMORY_ACCESS_WRITE)
+      .value("DISCARD", IREE_HAL_MEMORY_ACCESS_DISCARD)
+      .value("DISCARD_WRITE", IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE)
+      .value("ALL", IREE_HAL_MEMORY_ACCESS_ALL)
+      .export_values()
+      .def(
+          "__or__",
+          [](enum iree_hal_memory_access_bits_t self,
+             enum iree_hal_memory_access_bits_t other) { return self | other; })
+      .def("__and__", [](enum iree_hal_memory_access_bits_t self,
+                         enum iree_hal_memory_access_bits_t other) {
+        return self & other;
+      });
+
+  py::enum_<enum iree_hal_element_types_t>(m, "HalElementType")
+      .value("NONE", IREE_HAL_ELEMENT_TYPE_NONE)
+      .value("OPAQUE_8", IREE_HAL_ELEMENT_TYPE_OPAQUE_8)
+      .value("OPAQUE_16", IREE_HAL_ELEMENT_TYPE_OPAQUE_16)
+      .value("OPAQUE_32", IREE_HAL_ELEMENT_TYPE_OPAQUE_32)
+      .value("OPAQUE_64", IREE_HAL_ELEMENT_TYPE_OPAQUE_64)
+      .value("INT_4", IREE_HAL_ELEMENT_TYPE_INT_4)
+      .value("INT_8", IREE_HAL_ELEMENT_TYPE_INT_8)
+      .value("INT_16", IREE_HAL_ELEMENT_TYPE_INT_16)
+      .value("INT_32", IREE_HAL_ELEMENT_TYPE_INT_32)
+      .value("INT_64", IREE_HAL_ELEMENT_TYPE_INT_64)
+      .value("SINT_4", IREE_HAL_ELEMENT_TYPE_SINT_4)
+      .value("SINT_8", IREE_HAL_ELEMENT_TYPE_SINT_8)
+      .value("SINT_16", IREE_HAL_ELEMENT_TYPE_SINT_16)
+      .value("SINT_32", IREE_HAL_ELEMENT_TYPE_SINT_32)
+      .value("SINT_64", IREE_HAL_ELEMENT_TYPE_SINT_64)
+      .value("UINT_4", IREE_HAL_ELEMENT_TYPE_UINT_4)
+      .value("UINT_8", IREE_HAL_ELEMENT_TYPE_UINT_8)
+      .value("UINT_16", IREE_HAL_ELEMENT_TYPE_UINT_16)
+      .value("UINT_32", IREE_HAL_ELEMENT_TYPE_UINT_32)
+      .value("UINT_64", IREE_HAL_ELEMENT_TYPE_UINT_64)
+      .value("FLOAT_16", IREE_HAL_ELEMENT_TYPE_FLOAT_16)
+      .value("FLOAT_32", IREE_HAL_ELEMENT_TYPE_FLOAT_32)
+      .value("FLOAT_64", IREE_HAL_ELEMENT_TYPE_FLOAT_64)
+      .value("BFLOAT_16", IREE_HAL_ELEMENT_TYPE_BFLOAT_16)
+      .value("BOOL_8",
+             static_cast<iree_hal_element_types_t>(IREE_HAL_ELEMENT_TYPE_VALUE(
+                 IREE_HAL_NUMERICAL_TYPE_INTEGER, 1)))
+      .export_values()
+      .def_static("map_to_dtype", &MapElementTypeToDType);
+
+  py::class_<HalDevice>(m, "HalDevice")
+      .def_property_readonly("allocator", [](HalDevice& self) {
+        return HalAllocator::BorrowFromRawPtr(self.allocator());
+      });
+
+  py::class_<HalDriver>(m, "HalDriver")
+      .def_static("query", &HalDriver::Query)
+      .def_static("create", &HalDriver::Create, py::arg("driver_name"))
+      .def("create_default_device", &HalDriver::CreateDefaultDevice);
+
+  py::class_<HalAllocator>(m, "HalAllocator")
+      .def("trim",
+           [](HalAllocator& self) {
+             CheckApiStatus(iree_hal_allocator_trim(self.raw_ptr()),
+                            "Error trim()'ing HAL allocator");
+           })
+      .def_property_readonly(
+          "has_statistics",
+          [](HalAllocator& self) -> bool { return IREE_STATISTICS_ENABLE; })
+      .def_property_readonly("statistics", &HalAllocator::QueryStatistics)
+      .def_property_readonly("formatted_statistics",
+                             &HalAllocator::FormattedStatistics)
+      .def(
+          "query_compatibility",
+          [](HalAllocator& self, int memory_type, int allowed_usage,
+             int intended_usage, iree_device_size_t allocation_size) -> int {
+            iree_hal_buffer_params_t params = {0};
+            params.type = memory_type;
+            params.usage = allowed_usage & intended_usage;
+            return iree_hal_allocator_query_compatibility(
+                self.raw_ptr(), params, allocation_size);
+          },
+          py::arg("memory_type"), py::arg("allowed_usage"),
+          py::arg("intended_usage"), py::arg("allocation_size"))
+      .def(
+          "allocate_buffer",
+          [](HalAllocator& self, int memory_type, int allowed_usage,
+             iree_device_size_t allocation_size) {
+            iree_hal_buffer_params_t params = {0};
+            params.type = memory_type;
+            params.usage = allowed_usage;
+            iree_hal_buffer_t* buffer = nullptr;
+            iree_const_byte_span_t empty_initial_data{nullptr, 0};
+            CheckApiStatus(iree_hal_allocator_allocate_buffer(
+                               self.raw_ptr(), params, allocation_size,
+                               empty_initial_data, &buffer),
+                           "could not allocate buffer");
+            return HalBuffer::StealFromRawPtr(buffer);
+          },
+          py::arg("memory_type"), py::arg("allowed_usage"),
+          py::arg("allocation_size"),
+          "Allocates a new buffer with requested characteristics (does not "
+          "initialize with specific data).")
+      .def("allocate_buffer_copy", &HalAllocator::AllocateBufferCopy,
+           py::arg("memory_type"), py::arg("allowed_usage"), py::arg("buffer"),
+           py::arg("element_type") = py::none(),
+           "Allocates a new buffer and initializes it from a Python buffer "
+           "object. If an element type is specified, wraps in a BufferView "
+           "matching the characteristics of the Python buffer. The buffer is "
+           "requested as ND/C-contiguous; non-contiguous inputs raise an "
+           "error rather than being copied.");
+
+  py::class_<HalBuffer>(m, "HalBuffer")
+      .def("fill_zero", &HalBuffer::FillZero, py::arg("byte_offset"),
+           py::arg("byte_length"))
+      .def("create_view", &HalBuffer::CreateView, py::arg("shape"),
+           py::arg("element_size"))
+      .def("__repr__", &HalBuffer::Repr);
+
+  py::class_<HalBufferView>(m, "HalBufferView")
+      .def("map", HalMappedMemory::Create)
+      .def_property_readonly(
+          "shape",
+          [](HalBufferView& self) {
+            iree_host_size_t rank =
+                iree_hal_buffer_view_shape_rank(self.raw_ptr());
+            auto* dims = iree_hal_buffer_view_shape_dims(self.raw_ptr());
+            py::list result;
+            for (iree_host_size_t i = 0; i < rank; ++i) {
+              result.append(dims[i]);
+            }
+            return result;
+          })
+      .def_property_readonly(
+          "element_type",
+          [](HalBufferView& self) {
+            return iree_hal_buffer_view_element_type(self.raw_ptr());
+          })
+      .def("__repr__", &HalBufferView::Repr);
+
+  py::class_<HalMappedMemory>(m, "MappedMemory", py::buffer_protocol())
+      .def_buffer(&HalMappedMemory::ToBufferInfo)
+      .def("asarray",
+           [](HalMappedMemory& self, std::vector<iree_host_size_t> shape,
+              py::object dtype) {
+             py::object py_mapped_memory = py::cast(self);
+             return py::array(std::move(dtype), shape,
+                              self.mapped_memory().contents.data,
+                              std::move(py_mapped_memory) /* base */);
+           });
+
+  py::class_<HalShape>(m, "Shape").def(py::init(&HalShape::FromIntVector));
+}
+
+}  // namespace python
+}  // namespace iree
diff --git a/runtime/bindings/python/iree/runtime/hal.h b/runtime/bindings/python/iree/runtime/hal.h
new file mode 100644
index 0000000..02201a9
--- /dev/null
+++ b/runtime/bindings/python/iree/runtime/hal.h
@@ -0,0 +1,204 @@
+// Copyright 2019 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_BINDINGS_PYTHON_IREE_RT_HAL_H_
+#define IREE_BINDINGS_PYTHON_IREE_RT_HAL_H_
+
+#include <vector>
+
+#include "./binding.h"
+#include "./status_utils.h"
+#include "iree/hal/api.h"
+
+namespace iree {
+namespace python {
+
+//------------------------------------------------------------------------------
+// Retain/release bindings
+//------------------------------------------------------------------------------
+
+template <>
+struct ApiPtrAdapter<iree_hal_driver_t> {
+  static void Retain(iree_hal_driver_t* d) { iree_hal_driver_retain(d); }
+  static void Release(iree_hal_driver_t* d) { iree_hal_driver_release(d); }
+};
+
+template <>
+struct ApiPtrAdapter<iree_hal_device_t> {
+  static void Retain(iree_hal_device_t* d) { iree_hal_device_retain(d); }
+  static void Release(iree_hal_device_t* d) { iree_hal_device_release(d); }
+};
+
+template <>
+struct ApiPtrAdapter<iree_hal_allocator_t> {
+  static void Retain(iree_hal_allocator_t* d) { iree_hal_allocator_retain(d); }
+  static void Release(iree_hal_allocator_t* d) {
+    iree_hal_allocator_release(d);
+  }
+};
+
+template <>
+struct ApiPtrAdapter<iree_hal_buffer_t> {
+  static void Retain(iree_hal_buffer_t* b) { iree_hal_buffer_retain(b); }
+  static void Release(iree_hal_buffer_t* b) { iree_hal_buffer_release(b); }
+};
+
+template <>
+struct ApiPtrAdapter<iree_hal_buffer_view_t> {
+  static void Retain(iree_hal_buffer_view_t* bv) {
+    iree_hal_buffer_view_retain(bv);
+  }
+  static void Release(iree_hal_buffer_view_t* bv) {
+    iree_hal_buffer_view_release(bv);
+  }
+};
+
+//------------------------------------------------------------------------------
+// ApiRefCounted types
+//------------------------------------------------------------------------------
+
+class HalDevice : public ApiRefCounted<HalDevice, iree_hal_device_t> {
+ public:
+  iree_hal_allocator_t* allocator() {
+    return iree_hal_device_allocator(raw_ptr());
+  }
+};
+
+class HalDriver : public ApiRefCounted<HalDriver, iree_hal_driver_t> {
+ public:
+  static std::vector<std::string> Query();
+  static HalDriver Create(const std::string& driver_name);
+
+  HalDevice CreateDefaultDevice();
+};
+
+class HalAllocator : public ApiRefCounted<HalAllocator, iree_hal_allocator_t> {
+ public:
+  py::dict QueryStatistics();
+  py::str FormattedStatistics();
+
+  py::object AllocateBufferCopy(
+      int memory_type, int allowed_usage, py::object buffer,
+      std::optional<iree_hal_element_types_t> element_type);
+};
+
+struct HalShape {
+ public:
+  static HalShape FromIntVector(std::vector<int32_t> indices) {
+    HalShape s;
+    s.s = {indices.begin(), indices.end()};
+    return s;
+  }
+
+  std::vector<int32_t> s;
+};
+
+class HalBufferView
+    : public ApiRefCounted<HalBufferView, iree_hal_buffer_view_t> {
+ public:
+  py::str Repr();
+};
+
+class HalBuffer : public ApiRefCounted<HalBuffer, iree_hal_buffer_t> {
+ public:
+  iree_device_size_t byte_length() const {
+    return iree_hal_buffer_byte_length(raw_ptr());
+  }
+
+  void FillZero(iree_device_size_t byte_offset,
+                iree_device_size_t byte_length) {
+    CheckApiStatus(
+        iree_hal_buffer_map_zero(raw_ptr(), byte_offset, byte_length),
+        "Error zero filling buffer");
+  }
+
+  // TODO(laurenzo): make this take element_type instead.
+  HalBufferView CreateView(HalShape& shape, size_t element_size) {
+    iree_hal_buffer_view_t* bv;
+    iree_hal_element_type_t element_type = iree_hal_make_element_type(
+        IREE_HAL_ELEMENT_TYPE_NONE, element_size * 8);
+    iree_hal_encoding_type_t encoding_type =
+        IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR;
+    CheckApiStatus(iree_hal_buffer_view_create(
+                       raw_ptr(), shape.s.data(), shape.s.size(), element_type,
+                       encoding_type, iree_allocator_system(), &bv),
+                   "Error creating buffer view");
+    return HalBufferView::StealFromRawPtr(bv);
+  }
+
+  py::str Repr();
+};
+
+// Wrapper around an iree_hal_buffer_mapping_t and iree_hal_buffer_view_t
+// which retains the latter and unmaps/releases on deallocation.
+class HalMappedMemory {
+ public:
+  HalMappedMemory(iree_hal_buffer_mapping_t mapped_memory,
+                  iree_hal_buffer_view_t* bv)
+      : mapped_memory_(mapped_memory), bv_(bv) {
+    iree_hal_buffer_view_retain(bv_);
+  }
+  ~HalMappedMemory() {
+    if (bv_) {
+      iree_hal_buffer_unmap_range(&mapped_memory_);
+      iree_hal_buffer_view_release(bv_);
+    }
+  }
+  HalMappedMemory(HalMappedMemory&& other)
+      : mapped_memory_(other.mapped_memory_), bv_(other.bv_) {
+    other.bv_ = nullptr;
+  }
+
+  static HalMappedMemory Create(HalBufferView& bv) {
+    iree_hal_buffer_t* buffer = iree_hal_buffer_view_buffer(bv.raw_ptr());
+    iree_device_size_t byte_length = iree_hal_buffer_byte_length(buffer);
+    iree_hal_buffer_mapping_t mapped_memory = {{0}};
+    CheckApiStatus(
+        iree_hal_buffer_map_range(buffer, IREE_HAL_MAPPING_MODE_SCOPED,
+                                  IREE_HAL_MEMORY_ACCESS_READ, 0, byte_length,
+                                  &mapped_memory),
+        "Could not map memory");
+    return HalMappedMemory(mapped_memory, bv.raw_ptr());
+  }
+
+  py::buffer_info ToBufferInfo() {
+    std::vector<int32_t> shape(iree_hal_buffer_view_shape_rank(bv_));
+    CheckApiStatus(
+        iree_hal_buffer_view_shape(bv_, shape.size(), shape.data(), nullptr),
+        "Error getting buffer view shape");
+    iree_hal_element_type_t element_type =
+        iree_hal_buffer_view_element_type(bv_);
+    int32_t element_size = iree_hal_element_dense_byte_count(element_type);
+    std::vector<py::ssize_t> dims(shape.size());
+    for (int i = 0; i < shape.size(); ++i) {
+      dims[i] = shape[i];
+    }
+    std::vector<py::ssize_t> strides(shape.size());
+    if (!strides.empty()) {
+      strides[shape.size() - 1] = element_size;
+      for (int i = shape.size() - 2; i >= 0; --i) {
+        strides[i] = strides[i + 1] * shape[i + 1];
+      }
+    }
+
+    return py::buffer_info(mapped_memory_.contents.data, element_size,
+                           py::format_descriptor<float>::format(), shape.size(),
+                           dims, strides);
+  }
+
+  iree_hal_buffer_mapping_t& mapped_memory() { return mapped_memory_; }
+
+ private:
+  iree_hal_buffer_mapping_t mapped_memory_ = {{0}};
+  iree_hal_buffer_view_t* bv_ = nullptr;
+};
+
+void SetupHalBindings(pybind11::module m);
+
+}  // namespace python
+}  // namespace iree
+
+#endif  // IREE_BINDINGS_PYTHON_IREE_RT_HAL_H_
diff --git a/runtime/bindings/python/iree/runtime/hal_test.py b/runtime/bindings/python/iree/runtime/hal_test.py
new file mode 100644
index 0000000..f9f27e9
--- /dev/null
+++ b/runtime/bindings/python/iree/runtime/hal_test.py
@@ -0,0 +1,132 @@
+# Copyright 2019 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import iree.runtime
+import numpy as np
+import unittest
+
+
+class NonDeviceHalTest(unittest.TestCase):
+
+  def testEnums(self):
+    print("MemoryType:", iree.runtime.MemoryType)
+    print("HOST_VISIBLE:", int(iree.runtime.MemoryType.HOST_VISIBLE))
+
+    # Enum and/or operations on BufferCompatibility.
+    self.assertEqual(
+        iree.runtime.BufferCompatibility.IMPORTABLE |
+        iree.runtime.BufferCompatibility.EXPORTABLE,
+        int(iree.runtime.BufferCompatibility.IMPORTABLE) |
+        int(iree.runtime.BufferCompatibility.EXPORTABLE))
+    self.assertEqual(
+        iree.runtime.BufferCompatibility.EXPORTABLE &
+        iree.runtime.BufferCompatibility.EXPORTABLE,
+        int(iree.runtime.BufferCompatibility.EXPORTABLE))
+
+    # Enum and/or operations on BufferUsage.
+    self.assertEqual(
+        iree.runtime.BufferUsage.CONSTANT | iree.runtime.BufferUsage.TRANSFER,
+        int(iree.runtime.BufferUsage.CONSTANT) |
+        int(iree.runtime.BufferUsage.TRANSFER))
+    self.assertEqual(
+        iree.runtime.BufferUsage.CONSTANT & iree.runtime.BufferUsage.CONSTANT,
+        int(iree.runtime.BufferUsage.CONSTANT))
+
+    # Enum and/or operations on MemoryAccess.
+    self.assertEqual(
+        iree.runtime.MemoryAccess.READ | iree.runtime.MemoryAccess.WRITE,
+        int(iree.runtime.MemoryAccess.READ) |
+        int(iree.runtime.MemoryAccess.WRITE))
+    self.assertEqual(
+        iree.runtime.MemoryAccess.ALL & iree.runtime.MemoryAccess.READ,
+        int(iree.runtime.MemoryAccess.READ))
+
+    # Enum and/or operations on MemoryType.
+    self.assertEqual(
+        iree.runtime.MemoryType.TRANSIENT |
+        iree.runtime.MemoryType.HOST_VISIBLE,
+        int(iree.runtime.MemoryType.TRANSIENT) |
+        int(iree.runtime.MemoryType.HOST_VISIBLE))
+    self.assertEqual(
+        iree.runtime.MemoryType.TRANSIENT & iree.runtime.MemoryType.TRANSIENT,
+        int(iree.runtime.MemoryType.TRANSIENT))
+
+
+class DeviceHalTest(unittest.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self.driver = iree.runtime.HalDriver.create("vmvx")
+    self.device = self.driver.create_default_device()
+    self.allocator = self.device.allocator
+
+  def testTrim(self):
+    self.allocator.trim()
+    # Just running is sufficient.
+
+  def testStatistics(self):
+    stats_dict = self.allocator.statistics
+    stats_str = self.allocator.formatted_statistics
+    if self.allocator.has_statistics:
+      self.assertIn("host_bytes_peak", stats_dict)
+      self.assertIn("host_bytes_allocated", stats_dict)
+      self.assertIn("host_bytes_freed", stats_dict)
+      self.assertIn("device_bytes_peak", stats_dict)
+      self.assertIn("device_bytes_allocated", stats_dict)
+      self.assertIn("device_bytes_freed", stats_dict)
+      self.assertIn("HOST_LOCAL", stats_str)
+
+  def testQueryCompatibility(self):
+    compat = self.allocator.query_compatibility(
+        memory_type=iree.runtime.MemoryType.DEVICE_LOCAL,
+        allowed_usage=iree.runtime.BufferUsage.CONSTANT,
+        intended_usage=iree.runtime.BufferUsage.CONSTANT |
+        iree.runtime.BufferUsage.TRANSFER,
+        allocation_size=1024)
+    print("COMPAT:", compat)
+    self.assertTrue(
+        bool(compat & int(iree.runtime.BufferCompatibility.ALLOCATABLE)),
+        "should be allocatable")
+    self.assertTrue(
+        bool(compat & int(iree.runtime.BufferCompatibility.IMPORTABLE)),
+        "should be importable")
+    self.assertTrue(
+        bool(compat & int(iree.runtime.BufferCompatibility.EXPORTABLE)),
+        "should be exportable")
+
+  def testAllocateBuffer(self):
+    buffer = self.allocator.allocate_buffer(
+        memory_type=iree.runtime.MemoryType.DEVICE_LOCAL,
+        allowed_usage=iree.runtime.BufferUsage.CONSTANT,
+        allocation_size=13)
+    print("BUFFER:", buffer)
+
+  def testAllocateBufferCopy(self):
+    ary = np.zeros([3, 4], dtype=np.int32) + 2
+    buffer = self.allocator.allocate_buffer_copy(
+        memory_type=iree.runtime.MemoryType.DEVICE_LOCAL,
+        allowed_usage=iree.runtime.BufferUsage.CONSTANT,
+        buffer=ary)
+    self.assertEqual(
+        repr(buffer),
+        "<HalBuffer 48 bytes (at offset 0 into 48), memory_type=DEVICE_LOCAL|HOST_VISIBLE, allowed_access=ALL, allowed_usage=CONSTANT|TRANSFER|MAPPING>"
+    )
+
+  def testAllocateBufferViewCopy(self):
+    ary = np.zeros([3, 4], dtype=np.int32) + 2
+    buffer = self.allocator.allocate_buffer_copy(
+        memory_type=iree.runtime.MemoryType.DEVICE_LOCAL,
+        allowed_usage=iree.runtime.BufferUsage.CONSTANT,
+        buffer=ary,
+        element_type=iree.runtime.HalElementType.SINT_32)
+    self.assertEqual(
+        repr(buffer),
+        "<HalBufferView (3, 4), element_type=0x20000011, 48 bytes (at offset 0 into 48), memory_type=DEVICE_LOCAL|HOST_VISIBLE, allowed_access=ALL, allowed_usage=CONSTANT|TRANSFER|MAPPING>"
+    )
+
+
+if __name__ == "__main__":
+  unittest.main()
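The expected `repr` strings above hard-code a 48-byte buffer, and `testQueryCompatibility` checks capabilities by AND-ing a bitmask. The standalone sketch below shows both and needs only NumPy. The `Compat` flag values are hypothetical stand-ins for `iree.runtime.BufferCompatibility`, whose real bit values may differ.

```python
import enum

import numpy as np


class Compat(enum.IntFlag):
  # Hypothetical bit values; the real iree.runtime.BufferCompatibility
  # enum may assign different ones.
  ALLOCATABLE = 1 << 0
  IMPORTABLE = 1 << 1
  EXPORTABLE = 1 << 2


def has_flag(mask: int, flag: Compat) -> bool:
  # Same check as the unit test: AND the mask with the flag's integer value.
  return bool(mask & int(flag))


# A 3x4 int32 array holds 3 * 4 elements * 4 bytes = 48 bytes, which is the
# size reported in the HalBuffer/HalBufferView reprs above.
ary = np.zeros([3, 4], dtype=np.int32) + 2
assert ary.nbytes == 3 * 4 * np.dtype(np.int32).itemsize == 48

mask = int(Compat.ALLOCATABLE | Compat.IMPORTABLE)
print(has_flag(mask, Compat.IMPORTABLE), has_flag(mask, Compat.EXPORTABLE))
# prints: True False
```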
diff --git a/runtime/bindings/python/iree/runtime/initialize_module.cc b/runtime/bindings/python/iree/runtime/initialize_module.cc
new file mode 100644
index 0000000..211e7dc
--- /dev/null
+++ b/runtime/bindings/python/iree/runtime/initialize_module.cc
@@ -0,0 +1,50 @@
+// Copyright 2019 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "./binding.h"
+#include "./hal.h"
+#include "./invoke.h"
+#include "./status_utils.h"
+#include "./vm.h"
+#include "iree/base/internal/flags.h"
+#include "iree/base/status_cc.h"
+#include "iree/hal/drivers/init.h"
+
+namespace iree {
+namespace python {
+
+PYBIND11_MODULE(binding, m) {
+  IREE_CHECK_OK(iree_hal_register_all_available_drivers(
+      iree_hal_driver_registry_default()));
+
+  m.doc() = "IREE Binding Backend Helpers";
+  SetupHalBindings(m);
+  SetupInvokeBindings(m);
+  SetupVmBindings(m);
+
+  m.def("parse_flags", [](py::args py_flags) {
+    std::vector<std::string> alloced_flags;
+    alloced_flags.push_back("python");
+    for (auto &py_flag : py_flags) {
+      alloced_flags.push_back(py::cast<std::string>(py_flag));
+    }
+
+    // Must build pointer vector after filling so pointers are stable.
+    std::vector<char *> flag_ptrs;
+    for (auto &alloced_flag : alloced_flags) {
+      flag_ptrs.push_back(const_cast<char *>(alloced_flag.c_str()));
+    }
+
+    char **argv = &flag_ptrs[0];
+    int argc = flag_ptrs.size();
+    CheckApiStatus(
+        iree_flags_parse(IREE_FLAGS_PARSE_MODE_DEFAULT, &argc, &argv),
+        "Error parsing flags");
+  });
+}
+
+}  // namespace python
+}  // namespace iree
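`parse_flags` passes `iree_flags_parse` a C-style `argc`/`argv`. The first entry is the placeholder program name `"python"`, and each Python argument follows as a string. Below is a pure-Python sketch of the argv it builds. The helper name `build_flag_argv` and the example flag names are hypothetical. The real binding does this in C++ and then hands the pointers to IREE.

```python
def build_flag_argv(*py_flags):
  # argv[0] is conventionally the program name, which flag parsers skip;
  # the binding uses the fixed placeholder "python".
  argv = ["python"]
  for flag in py_flags:
    # This sketch only accepts str; the C++ binding converts each value
    # with py::cast<std::string>.
    if not isinstance(flag, str):
      raise TypeError(f"flag must be a str, got {flag!r}")
    argv.append(flag)
  return argv


# Hypothetical flag names, for illustration only.
print(build_flag_argv("--flag_a=1", "--flag_b"))
# prints: ['python', '--flag_a=1', '--flag_b']
```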
diff --git a/runtime/bindings/python/iree/runtime/invoke.cc b/runtime/bindings/python/iree/runtime/invoke.cc
new file mode 100644
index 0000000..b0b9a59
--- /dev/null
+++ b/runtime/bindings/python/iree/runtime/invoke.cc
@@ -0,0 +1,713 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "./invoke.h"
+
+#include "./hal.h"
+#include "./vm.h"
+#include "iree/base/api.h"
+#include "iree/base/tracing.h"
+#include "iree/hal/api.h"
+#include "iree/modules/hal/module.h"
+#include "iree/vm/api.h"
+
+namespace iree {
+namespace python {
+
+namespace {
+
+class InvokeContext {
+ public:
+  InvokeContext(HalDevice &device) : device_(device) {}
+
+  HalDevice &device() { return device_; }
+  HalAllocator allocator() {
+    // TODO: Unfortunate that we inc ref here but that is how our object model
+    // is set up.
+    return HalAllocator::BorrowFromRawPtr(device().allocator());
+  }
+
+ private:
+  HalDevice device_;
+};
+
+using PackCallback =
+    std::function<void(InvokeContext &, iree_vm_list_t *, py::handle)>;
+
+class InvokeStatics {
+ public:
+  ~InvokeStatics() {
+    for (auto it : py_type_to_pack_callbacks_) {
+      py::handle(it.first).dec_ref();
+    }
+  }
+
+  py::str kNamedTag = py::str("named");
+  py::str kSlistTag = py::str("slist");
+  py::str kStupleTag = py::str("stuple");
+  py::str kSdictTag = py::str("sdict");
+
+  py::int_ kZero = py::int_(0);
+  py::int_ kOne = py::int_(1);
+  py::int_ kTwo = py::int_(2);
+  py::str kAsArray = py::str("asarray");
+  py::str kMapDtypeToElementTypeAttr = py::str("map_dtype_to_element_type");
+  py::str kContiguousArg = py::str("C");
+  py::str kArrayProtocolAttr = py::str("__array__");
+  py::str kDtypeAttr = py::str("dtype");
+
+  // Primitive type names.
+  py::str kF32 = py::str("f32");
+  py::str kF64 = py::str("f64");
+  py::str kI1 = py::str("i1");
+  py::str kI8 = py::str("i8");
+  py::str kI16 = py::str("i16");
+  py::str kI32 = py::str("i32");
+  py::str kI64 = py::str("i64");
+
+  // Compound types names.
+  py::str kNdarray = py::str("ndarray");
+
+  // Attribute names.
+  py::str kAttrBufferView = py::str("_buffer_view");
+
+  // Module 'numpy'.
+  py::module &numpy_module() { return numpy_module_; }
+
+  py::object &runtime_module() {
+    if (!runtime_module_) {
+      runtime_module_ = py::module::import("iree.runtime");
+    }
+    return *runtime_module_;
+  }
+
+  py::module &array_interop_module() {
+    if (!array_interop_module_) {
+      array_interop_module_ = py::module::import("iree.runtime.array_interop");
+    }
+    return *array_interop_module_;
+  }
+
+  py::object &device_array_type() {
+    if (!device_array_type_) {
+      device_array_type_ = runtime_module().attr("DeviceArray");
+    }
+    return *device_array_type_;
+  }
+
+  py::type &hal_buffer_view_type() { return hal_buffer_view_type_; }
+
+  py::object MapElementAbiTypeToDtype(py::object &element_abi_type) {
+    try {
+      return abi_type_to_dtype_[element_abi_type];
+    } catch (std::exception &) {
+      std::string msg("could not map abi type ");
+      msg.append(py::cast<std::string>(py::repr(element_abi_type)));
+      msg.append(" to numpy dtype");
+      throw std::invalid_argument(std::move(msg));
+    }
+  }
+
+  enum iree_hal_element_types_t MapDtypeToElementType(py::object dtype) {
+    // TODO: Consider porting this from a py func to C++ as it can be on
+    // the critical path.
+    try {
+      py::object element_type =
+          array_interop_module().attr(kMapDtypeToElementTypeAttr)(dtype);
+      if (element_type.is_none()) {
+        throw std::invalid_argument("mapping not found");
+      }
+      return py::cast<enum iree_hal_element_types_t>(element_type);
+    } catch (std::exception &e) {
+      std::string msg("could not map dtype ");
+      msg.append(py::cast<std::string>(py::repr(dtype)));
+      msg.append(" to element type: ");
+      msg.append(e.what());
+      throw std::invalid_argument(std::move(msg));
+    }
+  }
+
+  PackCallback AbiTypeToPackCallback(py::handle desc) {
+    return AbiTypeToPackCallback(
+        std::move(desc), /*desc_is_list=*/py::isinstance<py::list>(desc));
+  }
+
+  // Given an ABI desc, returns a callback that packs a corresponding Python
+  // value into a VM list. For efficiency, the caller passes in whether the
+  // desc is a list, since it has typically already performed that check
+  // itself.
+  PackCallback AbiTypeToPackCallback(py::handle desc, bool desc_is_list) {
+    // Switch based on descriptor type.
+    if (desc_is_list) {
+      // Compound type.
+      py::object compound_type = desc[kZero];
+      if (compound_type.equal(kNdarray)) {
+        // Has format:
+        //   ["ndarray", "f32", dim0, dim1, ...]
+        // Extract static information about the target.
+        std::vector<int64_t> abi_shape(py::len(desc) - 2);
+        for (size_t i = 0, e = abi_shape.size(); i < e; ++i) {
+          py::handle dim = desc[py::int_(i + 2)];
+          abi_shape[i] = dim.is_none() ? -1 : py::cast<int64_t>(dim);
+        }
+
+        // Map abi element type to dtype.
+        py::object abi_type = desc[kOne];
+        py::object target_dtype = MapElementAbiTypeToDtype(abi_type);
+        auto hal_element_type = MapDtypeToElementType(target_dtype);
+
+        return [this, target_dtype = std::move(target_dtype), hal_element_type,
+                abi_shape = std::move(abi_shape)](InvokeContext &c,
+                                                  iree_vm_list_t *list,
+                                                  py::handle py_value) {
+          IREE_TRACE_SCOPE0("ArgumentPacker::ReflectionNdarray");
+          HalBufferView *bv = nullptr;
+          py::object retained_bv;
+          if (py::isinstance(py_value, device_array_type())) {
+            // Short-circuit: If a DeviceArray is provided, assume it is
+            // correct.
+            IREE_TRACE_SCOPE0("PackDeviceArray");
+            bv = py::cast<HalBufferView *>(py_value.attr(kAttrBufferView));
+          } else if (py::isinstance(py_value, hal_buffer_view_type())) {
+            // Short-circuit: If a HalBufferView is provided directly.
+            IREE_TRACE_SCOPE0("PackBufferView");
+            bv = py::cast<HalBufferView *>(py_value);
+          } else {
+            // Fall back to the array protocol to generate a host side
+            // array and then convert that.
+            IREE_TRACE_SCOPE0("PackHostArray");
+            py::object host_array;
+            try {
+              host_array = numpy_module().attr(kAsArray)(py_value, target_dtype,
+                                                         kContiguousArg);
+            } catch (std::exception &e) {
+              std::string msg("could not convert value to numpy array: dtype=");
+              msg.append(py::cast<std::string>(py::repr(target_dtype)));
+              msg.append(", error='");
+              msg.append(e.what());
+              msg.append("', value=");
+              msg.append(py::cast<std::string>(py::repr(py_value)));
+              throw std::invalid_argument(std::move(msg));
+            }
+
+            retained_bv = c.allocator().AllocateBufferCopy(
+                IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL,
+                IREE_HAL_BUFFER_USAGE_DISPATCH |
+                    IREE_HAL_BUFFER_USAGE_TRANSFER |
+                    IREE_HAL_BUFFER_USAGE_MAPPING,
+                host_array, hal_element_type);
+            bv = py::cast<HalBufferView *>(retained_bv);
+          }
+
+          // TODO: Add some shape verification. Not strictly necessary as the VM
+          // will check, but may make error reporting nicer.
+          // TODO: It is theoretically possible to enqueue further conversions
+          // on the device, but for now we require things to line up closely.
+          // TODO: If adding further manipulation here, please make this common
+          // with the generic access case.
+          iree_vm_ref_t buffer_view_ref =
+              iree_hal_buffer_view_retain_ref(bv->raw_ptr());
+          CheckApiStatus(iree_vm_list_push_ref_move(list, &buffer_view_ref),
+                         "could not push buffer view to list");
+        };
+      } else if (compound_type.equal(kSlistTag) ||
+                 compound_type.equal(kStupleTag)) {
+        // Tuple/list extraction.
+        // When decoding a list or tuple, the desc object looks like:
+        //   ['slist', [...value_type_0...], ...]
+        // where the tag is either 'slist' or 'stuple'.
+        std::vector<PackCallback> sub_packers(py::len(desc) - 1);
+        for (size_t i = 0; i < sub_packers.size(); i++) {
+          sub_packers[i] = AbiTypeToPackCallback(desc[py::int_(i + 1)]);
+        }
+        return [sub_packers = std::move(sub_packers)](InvokeContext &c,
+                                                      iree_vm_list_t *list,
+                                                      py::handle py_value) {
+          if (py::len(py_value) != sub_packers.size()) {
+            std::string msg("expected a sequence with ");
+            msg.append(std::to_string(sub_packers.size()));
+            msg.append(" values. got: ");
+            msg.append(py::cast<std::string>(py::repr(py_value)));
+            throw std::invalid_argument(std::move(msg));
+          }
+          VmVariantList item_list = VmVariantList::Create(sub_packers.size());
+          for (size_t i = 0; i < sub_packers.size(); ++i) {
+            py::object item_py_value;
+            try {
+              item_py_value = py_value[py::int_(i)];
+            } catch (std::exception &e) {
+              std::string msg("could not get item ");
+              msg.append(std::to_string(i));
+              msg.append(" from: ");
+              msg.append(py::cast<std::string>(py::repr(py_value)));
+              msg.append(": ");
+              msg.append(e.what());
+              throw std::invalid_argument(std::move(msg));
+            }
+            sub_packers[i](c, item_list.raw_ptr(), item_py_value);
+          }
+
+          iree_vm_ref_t retained =
+              iree_vm_list_retain_ref(item_list.steal_raw_ptr());
+          CheckApiStatus(iree_vm_list_push_ref_move(list, &retained),
+                         "could not push sub list");
+        };
+      } else if (compound_type.equal(kSdictTag)) {
+        // Dict extraction.
+        // The descriptor for an sdict is like:
+        //   ['sdict', ['key1', value1], ...]
+        std::vector<std::pair<py::object, PackCallback>> sub_packers(
+            py::len(desc) - 1);
+        for (size_t i = 0; i < sub_packers.size(); i++) {
+          py::object sub_desc = desc[py::int_(i + 1)];
+          py::object key = sub_desc[kZero];
+          py::object value_desc = sub_desc[kOne];
+          sub_packers[i] =
+              std::make_pair(std::move(key), AbiTypeToPackCallback(value_desc));
+        }
+        return [sub_packers = std::move(sub_packers)](InvokeContext &c,
+                                                      iree_vm_list_t *list,
+                                                      py::handle py_value) {
+          if (py::len(py_value) != sub_packers.size()) {
+            std::string msg("expected a dict with ");
+            msg.append(std::to_string(sub_packers.size()));
+            msg.append(" values. got: ");
+            msg.append(py::cast<std::string>(py::repr(py_value)));
+            throw std::invalid_argument(std::move(msg));
+          }
+          VmVariantList item_list = VmVariantList::Create(sub_packers.size());
+          for (size_t i = 0; i < sub_packers.size(); ++i) {
+            py::object item_py_value;
+            try {
+              item_py_value = py_value[sub_packers[i].first];
+            } catch (std::exception &e) {
+              std::string msg("could not get item ");
+              msg.append(py::cast<std::string>(py::repr(sub_packers[i].first)));
+              msg.append(" from: ");
+              msg.append(py::cast<std::string>(py::repr(py_value)));
+              msg.append(": ");
+              msg.append(e.what());
+              throw std::invalid_argument(std::move(msg));
+            }
+            sub_packers[i].second(c, item_list.raw_ptr(), item_py_value);
+          }
+
+          iree_vm_ref_t retained =
+              iree_vm_list_retain_ref(item_list.steal_raw_ptr());
+          CheckApiStatus(iree_vm_list_push_ref_move(list, &retained),
+                         "could not push sub list");
+        };
+      } else {
+        std::string message("Unrecognized reflection compound type: ");
+        message.append(py::cast<std::string>(compound_type));
+        throw std::invalid_argument(message);
+      }
+    } else {
+      // Primitive type.
+      py::str prim_type = py::cast<py::str>(desc);
+      if (prim_type.equal(kF32)) {
+        // f32
+        return [](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) {
+          iree_vm_value_t vm_value =
+              iree_vm_value_make_f32(py::cast<float>(py_value));
+          CheckApiStatus(iree_vm_list_push_value(list, &vm_value),
+                         "could not append value");
+        };
+      } else if (prim_type.equal(kF64)) {
+        // f64
+        return [](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) {
+          iree_vm_value_t vm_value =
+              iree_vm_value_make_f64(py::cast<double>(py_value));
+          CheckApiStatus(iree_vm_list_push_value(list, &vm_value),
+                         "could not append value");
+        };
+      } else if (prim_type.equal(kI32)) {
+        // i32.
+        return [](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) {
+          iree_vm_value_t vm_value =
+              iree_vm_value_make_i32(py::cast<int32_t>(py_value));
+          CheckApiStatus(iree_vm_list_push_value(list, &vm_value),
+                         "could not append value");
+        };
+      } else if (prim_type.equal(kI64)) {
+        // i64.
+        return [](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) {
+          iree_vm_value_t vm_value =
+              iree_vm_value_make_i64(py::cast<int64_t>(py_value));
+          CheckApiStatus(iree_vm_list_push_value(list, &vm_value),
+                         "could not append value");
+        };
+      } else if (prim_type.equal(kI8)) {
+        // i8.
+        return [](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) {
+          iree_vm_value_t vm_value =
+              iree_vm_value_make_i8(py::cast<int8_t>(py_value));
+          CheckApiStatus(iree_vm_list_push_value(list, &vm_value),
+                         "could not append value");
+        };
+      } else if (prim_type.equal(kI16)) {
+        // i16.
+        return [](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) {
+          iree_vm_value_t vm_value =
+              iree_vm_value_make_i16(py::cast<int16_t>(py_value));
+          CheckApiStatus(iree_vm_list_push_value(list, &vm_value),
+                         "could not append value");
+        };
+      } else {
+        std::string message("Unrecognized reflection primitive type: ");
+        message.append(py::cast<std::string>(prim_type));
+        throw std::invalid_argument(message);
+      }
+    }
+  }
+
+  PackCallback GetGenericPackCallbackFor(py::handle arg) {
+    PopulatePyTypeToPackCallbacks();
+    py::type clazz = py::type::of(arg);
+    auto found_it = py_type_to_pack_callbacks_.find(clazz.ptr());
+    if (found_it == py_type_to_pack_callbacks_.end()) {
+      // Probe to see if we have a host array.
+      if (py::hasattr(arg, kArrayProtocolAttr)) {
+        return GetGenericPackCallbackForNdarray();
+      }
+      return {};
+    }
+
+    return found_it->second;
+  }
+
+ private:
+  PackCallback GetGenericPackCallbackForNdarray() {
+    return [this](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) {
+      IREE_TRACE_SCOPE0("ArgumentPacker::GenericNdarray");
+      py::object host_array;
+      try {
+        host_array = numpy_module().attr(kAsArray)(
+            py_value, /*dtype=*/py::none(), kContiguousArg);
+      } catch (std::exception &e) {
+        std::string msg("could not convert value to numpy array: ");
+        msg.append("error='");
+        msg.append(e.what());
+        msg.append("', value=");
+        msg.append(py::cast<std::string>(py::repr(py_value)));
+        throw std::invalid_argument(std::move(msg));
+      }
+
+      auto hal_element_type =
+          MapDtypeToElementType(host_array.attr(kDtypeAttr));
+
+      // Put it on the device.
+      py::object retained_bv = c.allocator().AllocateBufferCopy(
+          IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL,
+          IREE_HAL_BUFFER_USAGE_DISPATCH | IREE_HAL_BUFFER_USAGE_TRANSFER |
+              IREE_HAL_BUFFER_USAGE_MAPPING,
+          host_array, hal_element_type);
+      HalBufferView *bv = py::cast<HalBufferView *>(retained_bv);
+
+      // TODO: If adding further manipulation here, please make this common
+      // with the reflection access case.
+      iree_vm_ref_t buffer_view_ref =
+          iree_hal_buffer_view_retain_ref(bv->raw_ptr());
+      CheckApiStatus(iree_vm_list_push_ref_move(list, &buffer_view_ref),
+                     "could not append value");
+    };
+  }
+
+  void PopulatePyTypeToPackCallbacks() {
+    if (!py_type_to_pack_callbacks_.empty()) return;
+
+    // We only care about int and float in the numeric hierarchy. Since Python
+    // does not refine these further, pack them as VM 64-bit ints and floats
+    // and let the VM take care of it. There isn't much else we can do.
+    AddPackCallback(
+        py::type::of(py::cast(1)),
+        [](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) {
+          iree_vm_value_t vm_value =
+              iree_vm_value_make_i64(py::cast<int64_t>(py_value));
+          CheckApiStatus(iree_vm_list_push_value(list, &vm_value),
+                         "could not append value");
+        });
+
+    AddPackCallback(
+        py::type::of(py::cast(1.0)),
+        [](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) {
+          iree_vm_value_t vm_value =
+              iree_vm_value_make_f64(py::cast<double>(py_value));
+          CheckApiStatus(iree_vm_list_push_value(list, &vm_value),
+                         "could not append value");
+        });
+
+    // List/tuple.
+    auto sequence_callback = [this](InvokeContext &c, iree_vm_list_t *list,
+                                    py::handle py_value) {
+      auto py_seq = py::cast<py::sequence>(py_value);
+      VmVariantList item_list = VmVariantList::Create(py::len(py_seq));
+      for (py::object py_item : py_seq) {
+        PackCallback sub_packer = GetGenericPackCallbackFor(py_item);
+        if (!sub_packer) {
+          std::string message("could not convert python value to VM: ");
+          message.append(py::cast<std::string>(py::repr(py_item)));
+          throw std::invalid_argument(std::move(message));
+        }
+        sub_packer(c, item_list.raw_ptr(), py_item);
+      }
+      // Push the sub list.
+      iree_vm_ref_t retained =
+          iree_vm_list_retain_ref(item_list.steal_raw_ptr());
+      iree_vm_list_push_ref_move(list, &retained);
+    };
+    AddPackCallback(py::type::of(py::list{}), sequence_callback);
+    AddPackCallback(py::type::of(py::tuple{}), sequence_callback);
+
+    // Dict.
+    auto dict_callback = [this](InvokeContext &c, iree_vm_list_t *list,
+                                py::handle py_value) {
+      // Gets all dict items and sorts (by key).
+      auto py_dict = py::cast<py::dict>(py_value);
+      py::list py_keys;
+      for (std::pair<py::handle, py::handle> it : py_dict) {
+        py_keys.append(it.first);
+      }
+      py_keys.attr("sort")();
+
+      VmVariantList item_list = VmVariantList::Create(py_keys.size());
+      for (auto py_key : py_keys) {
+        py::object py_item = py_dict[py_key];
+        PackCallback sub_packer = GetGenericPackCallbackFor(py_item);
+        if (!sub_packer) {
+          std::string message("could not convert python value to VM: ");
+          message.append(py::cast<std::string>(py::repr(py_item)));
+          throw std::invalid_argument(std::move(message));
+        }
+        sub_packer(c, item_list.raw_ptr(), py_item);
+      }
+      iree_vm_ref_t retained =
+          iree_vm_list_retain_ref(item_list.steal_raw_ptr());
+      CheckApiStatus(iree_vm_list_push_ref_move(list, &retained),
+                     "could not push sub list");
+    };
+    AddPackCallback(py::type::of(py::dict{}), dict_callback);
+
+    // HalBufferView.
+    AddPackCallback(
+        py::type::of<HalBufferView>(),
+        [](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) {
+          HalBufferView *bv = py::cast<HalBufferView *>(py_value);
+          iree_vm_ref_t buffer_view_ref =
+              iree_hal_buffer_view_retain_ref(bv->raw_ptr());
+          CheckApiStatus(iree_vm_list_push_ref_move(list, &buffer_view_ref),
+                         "could not append value");
+        });
+
+    // DeviceArray.
+    AddPackCallback(
+        device_array_type(),
+        [this](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) {
+          HalBufferView *bv =
+              py::cast<HalBufferView *>(py_value.attr(kAttrBufferView));
+          iree_vm_ref_t buffer_view_ref =
+              iree_hal_buffer_view_retain_ref(bv->raw_ptr());
+          CheckApiStatus(iree_vm_list_push_ref_move(list, &buffer_view_ref),
+                         "could not append value");
+        });
+  }
+
+  void AddPackCallback(py::handle t, PackCallback pcb) {
+    assert(py_type_to_pack_callbacks_.count(t.ptr()) == 0 && "duplicate types");
+    t.inc_ref();
+    py_type_to_pack_callbacks_.insert(std::make_pair(t.ptr(), std::move(pcb)));
+  }
+
+  py::dict BuildAbiTypeToDtype() {
+    auto d = py::dict();
+    d[kF32] = numpy_module().attr("float32");
+    d[kF64] = numpy_module().attr("float64");
+    d[kI1] = numpy_module().attr("bool_");
+    d[kI8] = numpy_module().attr("int8");
+    d[kI16] = numpy_module().attr("int16");
+    d[kI64] = numpy_module().attr("int64");
+    d[kI32] = numpy_module().attr("int32");
+    return d;
+  }
+
+  // Cached modules and types. Lookups into our own top-level package are
+  // deferred to avoid recursive imports; everything else is cached eagerly.
+  py::module numpy_module_ = py::module::import("numpy");
+  std::optional<py::object> runtime_module_;
+  std::optional<py::module> array_interop_module_;
+  std::optional<py::object> device_array_type_;
+  py::type hal_buffer_view_type_ = py::type::of<HalBufferView>();
+
+  // Maps a Python type to a PackCallback that can generically pack it.
+  // Each type key is inc_ref()'d when added and dec_ref()'d in the destructor.
+  std::unordered_map<PyObject *, PackCallback> py_type_to_pack_callbacks_;
+
+  // Dict of str (ABI dtype like 'f32') to numpy dtype.
+  py::dict abi_type_to_dtype_ = BuildAbiTypeToDtype();
+};
+
+/// Object that can pack Python arguments into a VM List for a specific
+/// function.
+class ArgumentPacker {
+ public:
+  ArgumentPacker(InvokeStatics &statics, std::optional<py::list> arg_descs)
+      : statics_(statics) {
+    IREE_TRACE_SCOPE0("ArgumentPacker::Init");
+    if (!arg_descs) {
+      dynamic_dispatch_ = true;
+    } else {
+      // Reflection dispatch.
+      for (py::handle desc : *arg_descs) {
+        int arg_index = flat_arg_packers_.size();
+        std::optional<std::string> kwarg_name;
+        py::object retained_sub_desc;
+
+        bool desc_is_list = py::isinstance<py::list>(desc);
+
+        // Check if named.
+        //   ["named", "kwarg_name", sub_desc]
+        // If found, then we set kwarg_name and reset desc to the sub_desc.
+        if (desc_is_list) {
+          py::object maybe_named_field = desc[statics.kZero];
+          if (maybe_named_field.equal(statics.kNamedTag)) {
+            py::object name_field = desc[statics.kOne];
+            retained_sub_desc = desc[statics.kTwo];
+            kwarg_name = py::cast<std::string>(name_field);
+            desc = retained_sub_desc;
+            desc_is_list = py::isinstance<py::list>(desc);
+
+            kwarg_to_index_[name_field] = arg_index;
+          }
+        }
+
+        if (!kwarg_name) {
+          pos_only_arg_count_ += 1;
+        }
+
+        flat_arg_packers_.push_back(
+            statics.AbiTypeToPackCallback(desc, desc_is_list));
+      }
+    }
+  }
+
+  /// Packs positional/kw arguments into a suitable VmVariantList and returns
+  /// it.
+  VmVariantList Pack(InvokeContext &invoke_context, py::sequence pos_args,
+                     py::dict kw_args) {
+    // Dynamic dispatch.
+    if (dynamic_dispatch_) {
+      IREE_TRACE_SCOPE0("ArgumentPacker::PackDynamic");
+      if (!kw_args.empty()) {
+        throw std::invalid_argument(
+            "kwargs not supported for dynamic dispatch functions");
+      }
+
+      VmVariantList arg_list = VmVariantList::Create(pos_args.size());
+      for (py::handle py_arg : pos_args) {
+        PackCallback packer = statics_.GetGenericPackCallbackFor(py_arg);
+        if (!packer) {
+          std::string message("could not convert python value to VM: ");
+          message.append(py::cast<std::string>(py::repr(py_arg)));
+          throw std::invalid_argument(std::move(message));
+        }
+        // TODO: Better error handling by catching the exception and
+        // reporting which arg has a problem.
+        packer(invoke_context, arg_list.raw_ptr(), py_arg);
+      }
+      return arg_list;
+    } else {
+      IREE_TRACE_SCOPE0("ArgumentPacker::PackReflection");
+
+      // Reflection based dispatch.
+      std::vector<py::handle> py_args(flat_arg_packers_.size());
+
+      if (pos_args.size() > pos_only_arg_count_) {
+        std::string message("mismatched call arity: expected ");
+        message.append(std::to_string(pos_only_arg_count_));
+        message.append(" got ");
+        message.append(std::to_string(pos_args.size()));
+        throw std::invalid_argument(std::move(message));
+      }
+
+      // Positional args.
+      size_t pos_index = 0;
+      for (py::handle py_arg : pos_args) {
+        py_args[pos_index++] = py_arg;
+      }
+
+      // Keyword args.
+      for (auto it : kw_args) {
+        int found_index;
+        try {
+          found_index = py::cast<int>(kwarg_to_index_[it.first]);
+        } catch (std::exception &) {
+          std::string message("specified kwarg '");
+          message.append(py::cast<py::str>(it.first));
+          message.append("' is unknown");
+          throw std::invalid_argument(std::move(message));
+        }
+        if (py_args[found_index]) {
+          std::string message(
+              "mismatched call arity: duplicate keyword argument '");
+          message.append(py::cast<py::str>(it.first));
+          message.append("'");
+          throw std::invalid_argument(std::move(message));
+        }
+        py_args[found_index] = it.second;
+      }
+
+      // Now check to see that all args are set.
+      for (size_t i = 0; i < py_args.size(); ++i) {
+        if (!py_args[i]) {
+          std::string message(
+              "mismatched call arity: expected a value for argument ");
+          message.append(std::to_string(i));
+          throw std::invalid_argument(std::move(message));
+        }
+      }
+
+      // Start packing into the list.
+      VmVariantList arg_list = VmVariantList::Create(flat_arg_packers_.size());
+      for (size_t i = 0; i < py_args.size(); ++i) {
+        // TODO: Better error handling by catching the exception and
+        // reporting which arg has a problem.
+        flat_arg_packers_[i](invoke_context, arg_list.raw_ptr(), py_args[i]);
+      }
+      return arg_list;
+    }
+  }
+
+ private:
+  InvokeStatics &statics_;
+
+  size_t pos_only_arg_count_ = 0;
+
+  // Dictionary of py::str -> py::int_ mapping kwarg names to position in
+  // the argument list. We store this as a py::dict because it is optimized
+  // for py::str lookup.
+  py::dict kwarg_to_index_;
+
+  std::vector<PackCallback> flat_arg_packers_;
+
+  // If true, then there is no dispatch metadata and we process fully
+  // dynamically.
+  bool dynamic_dispatch_ = false;
+};
+
+}  // namespace
+
+void SetupInvokeBindings(pybind11::module &m) {
+  py::class_<InvokeStatics>(m, "_InvokeStatics");
+  py::class_<InvokeContext>(m, "InvokeContext").def(py::init<HalDevice &>());
+  py::class_<ArgumentPacker>(m, "ArgumentPacker")
+      .def(py::init<InvokeStatics &, std::optional<py::list>>())
+      .def("pack", &ArgumentPacker::Pack);
+
+  m.attr("_invoke_statics") = py::cast(InvokeStatics());
+}
+
+}  // namespace python
+}  // namespace iree
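The reflection descriptors consumed by `AbiTypeToPackCallback` are plain Python lists and strings. A descriptor is either a primitive such as `"f32"`, `["ndarray", elem, dim0, ...]`, `["slist" | "stuple", sub_desc, ...]`, or `["sdict", [key, sub_desc], ...]`. Below is a pure-Python sketch of the same recursive dispatch. It produces nested Python lists instead of VM lists, and NumPy arrays instead of device buffer views. The function name `pack_value` and the list-based output are hypothetical. Only the descriptor shapes, the dtype table, and the error messages come from invoke.cc.

```python
import operator

import numpy as np

# ABI element type -> NumPy dtype, as in BuildAbiTypeToDtype() above.
ABI_TO_DTYPE = {
    "f32": np.float32, "f64": np.float64, "i1": np.bool_, "i8": np.int8,
    "i16": np.int16, "i32": np.int32, "i64": np.int64,
}
# Approximate Python equivalents of the py::cast<> conversions used for
# scalar descriptors. invoke.cc has no scalar "i1" case, so neither does this.
SCALAR_CASTS = {
    "f32": float, "f64": float, "i8": operator.index, "i16": operator.index,
    "i32": operator.index, "i64": operator.index,
}


def pack_value(desc, value, out):
  """Appends `value` to `out` as described by reflection descriptor `desc`."""
  if not isinstance(desc, list):
    if desc not in SCALAR_CASTS:
      raise ValueError(f"Unrecognized reflection primitive type: {desc}")
    out.append(SCALAR_CASTS[desc](value))
    return
  kind = desc[0]
  if kind == "ndarray":
    # ["ndarray", elem_type, dim0, ...]: dims are not checked, matching the
    # TODO in invoke.cc. A C-contiguous host array stands in for the device
    # buffer view that the real packer allocates.
    out.append(np.asarray(value, ABI_TO_DTYPE[desc[1]], "C"))
  elif kind in ("slist", "stuple"):
    sub_descs = desc[1:]
    if len(value) != len(sub_descs):
      raise ValueError(f"expected a sequence with {len(sub_descs)} values")
    sub = []
    for sub_desc, item in zip(sub_descs, value):
      pack_value(sub_desc, item, sub)
    out.append(sub)
  elif kind == "sdict":
    # ["sdict", [key, sub_desc], ...]: values are packed in descriptor order.
    entries = desc[1:]
    if len(value) != len(entries):
      raise ValueError(f"expected a dict with {len(entries)} values")
    sub = []
    for key, sub_desc in entries:
      pack_value(sub_desc, value[key], sub)
    out.append(sub)
  else:
    raise ValueError(f"Unrecognized reflection compound type: {kind}")
```

For example, `pack_value(["slist", "i32", ["ndarray", "f32", 2]], [7, [1, 2]], out)` appends one nested list holding `7` and a `float32` array. Like the C++ code, the `sdict` branch takes its packing order from the descriptor, not from the order of keys in the dict.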
diff --git a/runtime/bindings/python/iree/runtime/invoke.h b/runtime/bindings/python/iree/runtime/invoke.h
new file mode 100644
index 0000000..206524f
--- /dev/null
+++ b/runtime/bindings/python/iree/runtime/invoke.h
@@ -0,0 +1,20 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_BINDINGS_PYTHON_IREE_RT_INVOKE_H_
+#define IREE_BINDINGS_PYTHON_IREE_RT_INVOKE_H_
+
+#include "./binding.h"
+
+namespace iree {
+namespace python {
+
+void SetupInvokeBindings(pybind11::module &m);
+
+}  // namespace python
+}  // namespace iree
+
+#endif  // IREE_BINDINGS_PYTHON_IREE_RT_INVOKE_H_
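In reflection mode, `ArgumentPacker::Pack` in invoke.cc first binds the call's arguments to flat slots. Positional arguments fill the leading slots. Keyword arguments are routed through the `["named", name, sub_desc]` entries. Any slot left unset, or set twice, is an arity error. Below is a pure-Python sketch of just that binding step. The function `bind_args` is hypothetical and returns the slots instead of packing them.

```python
def bind_args(arg_descs, pos_args, kw_args):
  """Binds call arguments to flat descriptor slots, mirroring Pack()."""
  kwarg_to_index = {}
  pos_only_count = 0
  for index, desc in enumerate(arg_descs):
    # ["named", "kwarg_name", sub_desc] entries may be passed by keyword.
    if isinstance(desc, list) and desc and desc[0] == "named":
      kwarg_to_index[desc[1]] = index
    else:
      pos_only_count += 1

  if len(pos_args) > pos_only_count:
    raise ValueError(f"mismatched call arity: expected {pos_only_count} "
                     f"got {len(pos_args)}")

  # Sentinel for "no value yet"; None is a legitimate argument value.
  unset = object()
  slots = [unset] * len(arg_descs)
  slots[:len(pos_args)] = pos_args

  for name, value in kw_args.items():
    if name not in kwarg_to_index:
      raise ValueError(f"specified kwarg '{name}' is unknown")
    index = kwarg_to_index[name]
    if slots[index] is not unset:
      raise ValueError(
          f"mismatched call arity: duplicate keyword argument '{name}'")
    slots[index] = value

  for i, slot in enumerate(slots):
    if slot is unset:
      raise ValueError(
          f"mismatched call arity: expected a value for argument {i}")
  return slots
```

For example, `bind_args(["i32", ["named", "scale", "f32"]], (1,), {"scale": 2.0})` returns `[1, 2.0]`. Omitting `scale` raises the "expected a value for argument 1" error.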
diff --git a/runtime/bindings/python/iree/runtime/scripts/iree_benchmark_trace/__main__.py b/runtime/bindings/python/iree/runtime/scripts/iree_benchmark_trace/__main__.py
new file mode 100644
index 0000000..007e96e
--- /dev/null
+++ b/runtime/bindings/python/iree/runtime/scripts/iree_benchmark_trace/__main__.py
@@ -0,0 +1,21 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import os
+import subprocess
+import sys
+
+
+def main(args=None):
+  if args is None:
+    args = sys.argv[1:]
+  exe = os.path.join(os.path.dirname(__file__), "..", "..",
+                     "iree-benchmark-trace")
+  return subprocess.call(args=[exe] + args)
+
+
+if __name__ == "__main__":
+  sys.exit(main())
diff --git a/runtime/bindings/python/iree/runtime/scripts/iree_run_module/__main__.py b/runtime/bindings/python/iree/runtime/scripts/iree_run_module/__main__.py
new file mode 100644
index 0000000..a5509a3
--- /dev/null
+++ b/runtime/bindings/python/iree/runtime/scripts/iree_run_module/__main__.py
@@ -0,0 +1,20 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import os
+import subprocess
+import sys
+
+
+def main(args=None):
+  if args is None:
+    args = sys.argv[1:]
+  exe = os.path.join(os.path.dirname(__file__), "..", "..", "iree-run-module")
+  return subprocess.call(args=[exe] + args)
+
+
+if __name__ == "__main__":
+  sys.exit(main())
diff --git a/runtime/bindings/python/iree/runtime/scripts/iree_run_trace/__main__.py b/runtime/bindings/python/iree/runtime/scripts/iree_run_trace/__main__.py
new file mode 100644
index 0000000..08dced3
--- /dev/null
+++ b/runtime/bindings/python/iree/runtime/scripts/iree_run_trace/__main__.py
@@ -0,0 +1,20 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import os
+import subprocess
+import sys
+
+
+def main(args=None):
+  if args is None:
+    args = sys.argv[1:]
+  exe = os.path.join(os.path.dirname(__file__), "..", "..", "iree-run-trace")
+  return subprocess.call(args=[exe] + args)
+
+
+if __name__ == "__main__":
+  sys.exit(main())
diff --git a/runtime/bindings/python/iree/runtime/scripts/iree_tracy_capture/__main__.py b/runtime/bindings/python/iree/runtime/scripts/iree_tracy_capture/__main__.py
new file mode 100644
index 0000000..58f2118
--- /dev/null
+++ b/runtime/bindings/python/iree/runtime/scripts/iree_tracy_capture/__main__.py
@@ -0,0 +1,21 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import os
+import subprocess
+import sys
+
+
+def main(args=None):
+  if args is None:
+    args = sys.argv[1:]
+  exe = os.path.join(os.path.dirname(__file__), "..", "..",
+                     "iree-tracy-capture")
+  return subprocess.call(args=[exe] + args)
+
+
+if __name__ == "__main__":
+  sys.exit(main())
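The four `__main__.py` shims above share one pattern. Each resolves a native executable two levels above the shim's own directory (that is, in the `iree/runtime` package directory), forwards the command-line arguments, and returns the child's exit code so that `sys.exit` propagates it. The sketch below factors that pattern into a single helper. `run_bundled_tool` and its `base_dir` override are hypothetical; the override exists only so the helper can be exercised outside an installed package.

```python
import os
import subprocess
import sys
from typing import List, Optional


def run_bundled_tool(tool_name: str,
                     args: Optional[List[str]] = None,
                     base_dir: Optional[str] = None) -> int:
  """Runs an executable shipped alongside the package; returns its exit code."""
  if args is None:
    args = sys.argv[1:]
  if base_dir is None:
    # Mirrors the shims: two directories above the directory of this file.
    base_dir = os.path.join(os.path.dirname(__file__), "..", "..")
  exe = os.path.join(base_dir, tool_name)
  # subprocess.call blocks until the child exits and returns its status.
  return subprocess.call([exe] + list(args))
```

Using `subprocess.call` rather than `os.exec*` keeps the Python process alive as the parent, so the tool's exit status comes back as an ordinary return value.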
diff --git a/runtime/bindings/python/iree/runtime/status_utils.cc b/runtime/bindings/python/iree/runtime/status_utils.cc
new file mode 100644
index 0000000..a05bfd5
--- /dev/null
+++ b/runtime/bindings/python/iree/runtime/status_utils.cc
@@ -0,0 +1,69 @@
+// Copyright 2019 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "./status_utils.h"
+
+namespace iree {
+namespace python {
+
+namespace {
+
+PyObject* ApiStatusToPyExcClass(iree_status_t status) {
+  switch (iree_status_code(status)) {
+    case IREE_STATUS_INVALID_ARGUMENT:
+      return PyExc_ValueError;
+    case IREE_STATUS_OUT_OF_RANGE:
+      return PyExc_IndexError;
+    case IREE_STATUS_UNIMPLEMENTED:
+      return PyExc_NotImplementedError;
+    default:
+      return PyExc_RuntimeError;
+  }
+}
+
+static std::string ApiStatusToString(iree_status_t status) {
+  iree_host_size_t buffer_length = 0;
+  if (IREE_UNLIKELY(!iree_status_format(status, /*buffer_capacity=*/0,
+                                        /*buffer=*/NULL, &buffer_length))) {
+    return "";
+  }
+  std::string result;
+  result.resize(buffer_length);
+  // NOTE: buffer capacity needs to be +1 for the NUL terminator in snprintf.
+  return iree_status_format(status, result.size() + 1,
+                            const_cast<char*>(result.data()), &buffer_length)
+             ? result
+             : "";
+}
+
+}  // namespace
+
+pybind11::error_already_set ApiStatusToPyExc(iree_status_t status,
+                                             const char* message) {
+  assert(!iree_status_is_ok(status));
+  std::string full_message;
+
+  auto status_str = ApiStatusToString(status);
+  if (status_str.empty()) {
+    full_message = std::string(message) + ": " +
+                   iree_status_code_string(iree_status_code(status));
+  } else {
+    full_message = std::string(message) + ": " + status_str;
+  }
+
+  PyErr_SetString(ApiStatusToPyExcClass(status), full_message.c_str());
+  iree_status_ignore(status);
+  return pybind11::error_already_set();
+}
+
+pybind11::error_already_set RaisePyError(PyObject* exc_class,
+                                         const char* message) {
+  PyErr_SetString(exc_class, message);
+  return pybind11::error_already_set();
+}
+
+}  // namespace python
+}  // namespace iree
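`ApiStatusToPyExcClass` maps three IREE status codes onto built-in Python exception types and falls back to `RuntimeError`. `ApiStatusToPyExc` then appends to the caller's message either the formatted status text or, if formatting fails, the status code string. The sketch below shows the same dispatch in Python. The `StatusCode` enum is a hypothetical stand-in for `iree_status_code_t`, and its string values are assumptions, not the exact output of `iree_status_code_string`.

```python
import enum
from typing import Dict, Optional, Type


class StatusCode(enum.Enum):
  # Hypothetical subset of iree_status_code_t; values stand in for the
  # strings returned by iree_status_code_string.
  INVALID_ARGUMENT = "INVALID_ARGUMENT"
  OUT_OF_RANGE = "OUT_OF_RANGE"
  UNIMPLEMENTED = "UNIMPLEMENTED"
  INTERNAL = "INTERNAL"


# Codes without an entry fall back to RuntimeError, as in the C++ switch.
_STATUS_TO_EXC: Dict[StatusCode, Type[Exception]] = {
    StatusCode.INVALID_ARGUMENT: ValueError,
    StatusCode.OUT_OF_RANGE: IndexError,
    StatusCode.UNIMPLEMENTED: NotImplementedError,
}


def status_to_exception(code: StatusCode,
                        message: str,
                        formatted: Optional[str] = None) -> Exception:
  """Builds, but does not raise, the exception for a failing status."""
  exc_class = _STATUS_TO_EXC.get(code, RuntimeError)
  # Use the formatted status text if available, else the code's name.
  detail = formatted if formatted else code.value
  return exc_class(f"{message}: {detail}")
```

The C++ helpers return a `pybind11::error_already_set` for the caller to `throw`. The sketch follows the same convention: it builds the exception and leaves raising it to the call site.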
diff --git a/runtime/bindings/python/iree/runtime/status_utils.h b/runtime/bindings/python/iree/runtime/status_utils.h
new file mode 100644
index 0000000..d87d308
--- /dev/null
+++ b/runtime/bindings/python/iree/runtime/status_utils.h
@@ -0,0 +1,48 @@
+// Copyright 2019 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_BINDINGS_PYTHON_IREE_COMMON_STATUS_UTILS_H_
+#define IREE_BINDINGS_PYTHON_IREE_COMMON_STATUS_UTILS_H_
+
+#include "iree/base/api.h"
+#include "pybind11/pybind11.h"
+
+namespace iree {
+namespace python {
+
+// Raises a Python exception of the given class with the given message.
+// Correct usage:
+//   throw RaisePyError(PyExc_ValueError, "Foobar'd");
+pybind11::error_already_set RaisePyError(PyObject* exc_class,
+                                         const char* message);
+
+// Raises a value error with the given message.
+// Correct usage:
+//   throw RaiseValueError("Foobar'd");
+inline pybind11::error_already_set RaiseValueError(const char* message) {
+  return RaisePyError(PyExc_ValueError, message);
+}
+
+pybind11::error_already_set ApiStatusToPyExc(iree_status_t status,
+                                             const char* message);
+
+inline void CheckApiStatus(iree_status_t status, const char* message) {
+  if (iree_status_is_ok(status)) {
+    return;
+  }
+  throw ApiStatusToPyExc(status, message);
+}
+
+inline void CheckApiNotNull(const void* p, const char* message) {
+  if (!p) {
+    throw RaiseValueError(message);
+  }
+}
+
+}  // namespace python
+}  // namespace iree
+
+#endif  // IREE_BINDINGS_PYTHON_IREE_COMMON_STATUS_UTILS_H_
diff --git a/runtime/bindings/python/iree/runtime/system_api.py b/runtime/bindings/python/iree/runtime/system_api.py
new file mode 100644
index 0000000..fc82a3b
--- /dev/null
+++ b/runtime/bindings/python/iree/runtime/system_api.py
@@ -0,0 +1,358 @@
+# Lint as: python3
+# Copyright 2019 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+"""Top-level python system API.
+
+This facility layers on top of the underlying binding native facilities and
+exposes them in a way that allows general operation against contexts, modules
+and functions.
+"""
+
+# pylint: disable=protected-access
+# pylint: disable=unused-argument
+# pylint: disable=g-explicit-length-test
+
+# TODO(#4131) python>=3.7: Use postponed type annotations.
+
+__all__ = [
+    "load_vm_flatbuffer",
+    "load_vm_flatbuffer_file",
+    "load_vm_module",
+    "load_vm_modules",
+    "normalize_value",
+    "Config",
+    "SystemContext",
+    "TARGET_BACKEND_TO_DRIVER",
+]
+
+import logging
+import os
+import sys
+
+from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union
+
+from . import binding as _binding
+from .function import FunctionInvoker
+from . import tracing
+
+import numpy as np
+
+# Environment key for a comma-delimited list of drivers to try to load.
+PREFERRED_DRIVER_ENV_KEY = "IREE_DEFAULT_DRIVER"
+
+# Default value for IREE_DEFAULT_DRIVER.
+DEFAULT_IREE_DRIVER_VALUE = "dylib,vulkan,vmvx"
+
+# Mapping from IREE target backends to their corresponding drivers.
+TARGET_BACKEND_TO_DRIVER = {
+    "dylib-llvm-aot": "dylib",
+    "vmvx": "vmvx",
+    "vulkan-spirv": "vulkan",
+}
+
+
+def _create_default_iree_driver(
+    driver_names: Optional[Sequence[str]] = None) -> _binding.HalDriver:
+  """Returns a default driver based on environment settings."""
+  # TODO(laurenzo): Ideally this should take a VmModule and join any explicitly
+  # provided driver list with environmental constraints and what the module
+  # was compiled for.
+  if driver_names is None:
+    # Read from environment.
+    driver_names = os.environ.get(PREFERRED_DRIVER_ENV_KEY)
+    if driver_names is None:
+      driver_names = DEFAULT_IREE_DRIVER_VALUE
+    driver_names = driver_names.split(",")
+  available_driver_names = _binding.HalDriver.query()
+  driver_exceptions = {}
+  for driver_name in driver_names:
+    if driver_name not in available_driver_names:
+      logging.error("Could not create driver %s (not registered)", driver_name)
+      continue
+    try:
+      driver = _binding.HalDriver.create(driver_name)
+    except Exception as ex:  # pylint: disable=broad-except
+      logging.exception("Could not create default driver %s", driver_name)
+      driver_exceptions[driver_name] = ex
+      continue
+
+    # Sanity check creation of the default device and skip the driver if
+    # this fails (this works around issues where the driver is present
+    # but there are no devices). This default initialization scheme needs
+    # to be improved.
+    try:
+      device = driver.create_default_device()
+    except Exception as ex:
+      logging.exception("Could not create default driver device %s",
+                        driver_name)
+      driver_exceptions[driver_name] = ex
+      continue
+
+    logging.debug("Created IREE driver %s: %r", driver_name, driver)
+    return driver
+
+  # All failed.
+  raise RuntimeError(
+      f"Could not create any requested driver {repr(driver_names)} (available="
+      f"{repr(available_driver_names)}) : {repr(driver_exceptions)}")
+
+
+class Config:
+  """System configuration."""
+
+  driver: _binding.HalDriver
+  device: _binding.HalDevice
+  vm_instance: _binding.VmInstance
+  default_vm_modules: Tuple[_binding.VmModule, ...]
+  tracer: Optional[tracing.Tracer]
+
+  def __init__(self,
+               driver_name: Optional[str] = None,
+               tracer: Optional[tracing.Tracer] = None):
+    self.vm_instance = _binding.VmInstance()
+    self.driver = _create_default_iree_driver(
+        driver_name.split(",") if driver_name is not None else None)
+    self.device = self.driver.create_default_device()
+    hal_module = _binding.create_hal_module(self.device)
+    self.default_vm_modules = (hal_module,)
+    self.tracer = tracer or tracing.get_default_tracer()
+    if self.tracer and self.tracer.enabled:
+      logging.info("IREE runtime tracing calls to path: %s",
+                   self.tracer.trace_path)
+    else:
+      self.tracer = None
+
+
+_global_config = None
+
+
+def _get_global_config():
+  global _global_config
+  if _global_config is None:
+    _global_config = Config()
+  return _global_config
+
+
+def _bool_to_int8(
+    array: Any) -> Optional[Union[np.ndarray, List[Any], Tuple[Any]]]:
+  if not isinstance(array, np.ndarray):
+    return array
+
+  # IREE models booleans as i8s.
+  # TODO(#5359): This cast should be moved into the function abi.
+  if array.dtype == np.bool_:
+    array = array.astype(np.int8)
+  return array
+
+
+def normalize_value(
+    value: Any) -> Optional[Union[np.ndarray, List[Any], Tuple[Any]]]:
+  """Normalizes the given value for input to (or comparison with) IREE."""
+  if value is None:
+    # Exclude None from falling through to blanket np.asarray conversion.
+    return value
+
+  if isinstance(value, (list, tuple, dict)):
+    return value
+
+  array = np.asarray(value)
+  # TODO(#5359): Move into the function abi.
+  if isinstance(value, (bool, int, float)):
+    # Manually convert ints and floats to 32 bits.
+    if array.dtype == np.float64:
+      array = array.astype(np.float32)
+    elif array.dtype == np.int64:
+      array = array.astype(np.int32)
+
+  return array
+
+
+def _convert_lists_to_tuples(pytree):
+  if isinstance(pytree, Sequence) and not isinstance(pytree, (str, bytes)):
+    return tuple(_convert_lists_to_tuples(leaf) for leaf in pytree)
+  elif isinstance(pytree, Mapping):
+    for key in pytree:
+      pytree[key] = _convert_lists_to_tuples(pytree[key])
+    return pytree
+  else:
+    return pytree
+
+
+class BoundModule:
+  """Wraps a VmModule with its context and provides nice python accessors.
+
+  Resolves item access (["foo"]) as function resolution.
+  """
+
+  def __init__(self, context: "SystemContext", vm_module: _binding.VmModule):
+    self._context = context
+    self._tracer = self._context._config.tracer
+    self._vm_module = vm_module
+    self._lazy_functions = dict()
+
+    # Let the tracing infra create a traced module.
+    self.traced_module = None
+    if self._tracer:
+      self.traced_module = self._tracer.persist_vm_module(vm_module)
+
+  @property
+  def name(self):
+    return self._vm_module.name
+
+  @property
+  def vm_module(self):
+    return self._vm_module
+
+  def __getattr__(self, name):
+    try:
+      return self[name]
+    except KeyError:
+      raise AttributeError(name)
+
+  def __getitem__(self, name):
+    vm_function = self._lazy_functions.get(name)
+    if vm_function is not None:
+      return vm_function
+
+    vm_function = self._vm_module.lookup_function(name)
+    if vm_function is None:
+      raise KeyError(f"Function '{name}' not found in module '{self}'")
+
+    # TODO: Needing to know the precise device to allocate on here is bad
+    # layering and will need to be fixed in some fashion if/when doing
+    # heterogenous dispatch.
+    return FunctionInvoker(self._context.vm_context,
+                           self._context.config.device, vm_function,
+                           self._context._tracer)
+
+  def __repr__(self):
+    return f"<BoundModule {repr(self._vm_module)}>"
+
+
+class BoundModules(dict):
+  """Provides nice python accessors for a dict of BoundModules."""
+
+  def __getattr__(self, name):
+    try:
+      return self[name]
+    except KeyError:
+      raise AttributeError(name)
+
+
+class SystemContext:
+  """A VmContext and its bound modules, created against a Config."""
+
+  def __init__(self, vm_modules=None, config: Optional[Config] = None):
+    self._config = config if config is not None else _get_global_config()
+    logging.debug("SystemContext driver=%r", self._config.driver)
+    self._is_dynamic = vm_modules is None
+    if self._is_dynamic:
+      init_vm_modules = None
+    else:
+      init_vm_modules = self._config.default_vm_modules + tuple(vm_modules)
+
+    self._vm_context = _binding.VmContext(instance=self._config.vm_instance,
+                                          modules=init_vm_modules)
+
+    if self._is_dynamic:
+      self._vm_context.register_modules(self._config.default_vm_modules)
+      self._bound_modules = BoundModules([
+          (m.name, BoundModule(self, m))
+          for m in self._config.default_vm_modules
+      ])
+    else:
+      self._bound_modules = BoundModules([
+          (m.name, BoundModule(self, m)) for m in init_vm_modules
+      ])
+
+    self._tracer = None  # type: Optional[tracing.ContextTracer]
+    if self._config.tracer:
+      self._tracer = tracing.ContextTracer(
+          self._config.tracer,
+          is_dynamic=self._is_dynamic,
+          modules=[bm.traced_module for bm in self._bound_modules.values()])
+
+  @property
+  def vm_context(self) -> _binding.VmContext:
+    return self._vm_context
+
+  @property
+  def is_dynamic(self) -> bool:
+    return self._is_dynamic
+
+  @property
+  def config(self) -> Config:
+    return self._config
+
+  @property
+  def instance(self) -> _binding.VmInstance:
+    return self._config.vm_instance
+
+  @property
+  def modules(self) -> BoundModules:
+    return self._bound_modules
+
+  def add_vm_modules(self, vm_modules):
+    assert self._is_dynamic, "Cannot 'add_vm_modules' on a static context"
+    for m in vm_modules:
+      if m.name in self._bound_modules:
+        raise ValueError(f"Attempt to register duplicate VmModule: '{m.name}'")
+      bound_module = BoundModule(self, m)
+      self._bound_modules[m.name] = bound_module
+      if self._tracer:
+        self._tracer.add_module(bound_module.traced_module)
+    self._vm_context.register_modules(vm_modules)
+
+  def add_vm_module(self, vm_module):
+    self.add_vm_modules((vm_module,))
+
+
+def load_vm_modules(*vm_modules, config: Optional[Config] = None):
+  """Loads VmModules into a new SystemContext and returns them."""
+  context = SystemContext(vm_modules=vm_modules, config=config)
+  bound_modules = [context.modules[m.name] for m in vm_modules]
+  return bound_modules
+
+
+def load_vm_module(vm_module, config: Optional[Config] = None):
+  """Loads a VmModule into a new SystemContext and returns it."""
+  return load_vm_modules(vm_module, config=config)[0]
+
+
+def load_vm_flatbuffer(vm_flatbuffer: bytes,
+                       *,
+                       driver: Optional[str] = None,
+                       backend: Optional[str] = None) -> BoundModule:
+  """Loads a VM Flatbuffer into a callable module.
+
+  Either 'driver' or 'backend' must be specified.
+  """
+  if driver is None and backend is None:
+    raise ValueError("Either 'driver' or 'backend' must be specified, but got "
+                     "'None' for both.")
+  if backend is not None and driver is not None:
+    raise ValueError("Cannot specify both 'driver' and a 'backend' to infer "
+                     "the driver from.")
+  if backend is not None:
+    driver = TARGET_BACKEND_TO_DRIVER[backend]
+  vm_module = _binding.VmModule.from_flatbuffer(vm_flatbuffer)
+  bound_module = load_vm_module(vm_module, Config(driver))
+  return bound_module
+
+
+# TODO: There should be an API for mmap'ing the file which should be used
+# instead of reading into memory.
+def load_vm_flatbuffer_file(path: str,
+                            *,
+                            driver: Optional[str] = None,
+                            backend: Optional[str] = None) -> BoundModule:
+  """Loads a file containing a VM Flatbuffer into a callable module.
+
+  Either 'driver' or 'backend' must be specified.
+  """
+  with open(path, "rb") as f:
+    vm_flatbuffer = f.read()
+  return load_vm_flatbuffer(vm_flatbuffer, driver=driver, backend=backend)
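`_create_default_iree_driver` reads a comma-separated preference list: the explicit argument, else `IREE_DEFAULT_DRIVER`, else `"dylib,vulkan,vmvx"`. It skips drivers that are not registered and drivers whose creation or default-device probe throws. It raises only after every candidate has failed. The sketch below isolates that control flow with injected callables, so it runs without the native bindings. `pick_driver` is hypothetical, and `create` is assumed to cover both driver creation and the device probe.

```python
import os
from typing import Callable, Dict, Iterable, Optional, Sequence, TypeVar

T = TypeVar("T")

ENV_KEY = "IREE_DEFAULT_DRIVER"
DEFAULT_VALUE = "dylib,vulkan,vmvx"


def pick_driver(create: Callable[[str], T],
                available: Iterable[str],
                driver_names: Optional[Sequence[str]] = None) -> T:
  """Returns the first driver, in preference order, that can be created."""
  if driver_names is None:
    driver_names = os.environ.get(ENV_KEY, DEFAULT_VALUE).split(",")
  available_set = set(available)
  errors: Dict[str, Exception] = {}
  for name in driver_names:
    if name not in available_set:
      continue  # Not registered in this build.
    try:
      # Stands in for driver creation plus the default-device probe.
      return create(name)
    except Exception as ex:  # pylint: disable=broad-except
      errors[name] = ex
  raise RuntimeError(f"Could not create any requested driver "
                     f"{list(driver_names)!r}: {errors!r}")
```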
diff --git a/runtime/bindings/python/iree/runtime/system_api_test.py b/runtime/bindings/python/iree/runtime/system_api_test.py
new file mode 100644
index 0000000..ed9a585
--- /dev/null
+++ b/runtime/bindings/python/iree/runtime/system_api_test.py
@@ -0,0 +1,141 @@
+# Lint as: python3
+# Copyright 2019 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# pylint: disable=unused-variable
+
+import os
+import re
+import tempfile
+
+from absl import logging
+from absl.testing import absltest
+import iree.compiler
+import iree.runtime
+import numpy as np
+
+
+def create_simple_mul_module():
+  binary = iree.compiler.compile_str(
+      """
+      module @arithmetic {
+        func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
+            %0 = "mhlo.multiply"(%arg0, %arg1) {name = "mul.1"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+            return %0 : tensor<4xf32>
+        }
+      }
+      """,
+      input_type="mhlo",
+      target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS,
+  )
+  m = iree.runtime.VmModule.from_flatbuffer(binary)
+  return m
+
+
+class SystemApiTest(absltest.TestCase):
+
+  def test_non_existing_driver(self):
+    with self.assertRaisesRegex(RuntimeError,
+                                "Could not create any requested driver"):
+      config = iree.runtime.Config("nothere1,nothere2")
+
+  def test_subsequent_driver(self):
+    config = iree.runtime.Config("nothere1,dylib")
+
+  def test_empty_dynamic(self):
+    ctx = iree.runtime.SystemContext()
+    self.assertTrue(ctx.is_dynamic)
+    self.assertIn("hal", ctx.modules)
+    self.assertEqual(ctx.modules.hal.name, "hal")
+
+  def test_empty_static(self):
+    ctx = iree.runtime.SystemContext(vm_modules=())
+    self.assertFalse(ctx.is_dynamic)
+    self.assertIn("hal", ctx.modules)
+    self.assertEqual(ctx.modules.hal.name, "hal")
+
+  def test_custom_dynamic(self):
+    ctx = iree.runtime.SystemContext()
+    self.assertTrue(ctx.is_dynamic)
+    ctx.add_vm_module(create_simple_mul_module())
+    self.assertEqual(ctx.modules.arithmetic.name, "arithmetic")
+    f = ctx.modules.arithmetic["simple_mul"]
+    f_repr = repr(f)
+    logging.info("f_repr: %s", f_repr)
+    self.assertEqual(f_repr, "<VmFunction simple_mul(0rr_r), reflection = {}>")
+
+  def test_duplicate_module(self):
+    ctx = iree.runtime.SystemContext()
+    self.assertTrue(ctx.is_dynamic)
+    ctx.add_vm_module(create_simple_mul_module())
+    with self.assertRaisesRegex(ValueError, "arithmetic"):
+      ctx.add_vm_module(create_simple_mul_module())
+
+  def test_static_invoke(self):
+    ctx = iree.runtime.SystemContext()
+    self.assertTrue(ctx.is_dynamic)
+    ctx.add_vm_module(create_simple_mul_module())
+    self.assertEqual(ctx.modules.arithmetic.name, "arithmetic")
+    f = ctx.modules.arithmetic["simple_mul"]
+    arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
+    arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)
+    results = f(arg0, arg1)
+    np.testing.assert_allclose(results, [4., 10., 18., 28.])
+
+  def test_chained_invoke(self):
+    # This ensures that everything works if DeviceArrays are returned
+    # and input to functions.
+    ctx = iree.runtime.SystemContext()
+    self.assertTrue(ctx.is_dynamic)
+    ctx.add_vm_module(create_simple_mul_module())
+    self.assertEqual(ctx.modules.arithmetic.name, "arithmetic")
+    f = ctx.modules.arithmetic["simple_mul"]
+    arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
+    arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)
+    results = f(arg0, arg1)
+    results2 = f(results, results)
+    np.testing.assert_allclose(results2, [16., 100., 324., 784.])
+
+  def test_tracing_explicit(self):
+    with tempfile.TemporaryDirectory() as temp_dir:
+      tracer = iree.runtime.Tracer(temp_dir)
+      config = iree.runtime.Config("dylib", tracer=tracer)
+      self.verify_tracing(config, temp_dir)
+
+  def test_tracing_from_environment(self):
+    original = os.environ.get(iree.runtime.TRACE_PATH_ENV_KEY)
+    try:
+      with tempfile.TemporaryDirectory() as temp_dir:
+        os.environ[iree.runtime.TRACE_PATH_ENV_KEY] = temp_dir
+        config = iree.runtime.Config("dylib")
+        self.verify_tracing(config, temp_dir)
+    finally:
+      os.environ.pop(iree.runtime.TRACE_PATH_ENV_KEY, None)
+      if original is not None:
+        os.environ[iree.runtime.TRACE_PATH_ENV_KEY] = original
+
+  def verify_tracing(self, config, temp_dir):
+    logging.info("Tracing test to: %s", temp_dir)
+    ctx = iree.runtime.SystemContext(config=config)
+    ctx.add_vm_module(create_simple_mul_module())
+    f = ctx.modules.arithmetic["simple_mul"]
+    arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
+    arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)
+    results = f(arg0, arg1)
+    self.assertTrue(os.path.exists(os.path.join(temp_dir, "arithmetic.vmfb")))
+    self.assertTrue(os.path.exists(os.path.join(temp_dir, "calls.yaml")))
+    # TODO: Once replay is possible, verify that.
+
+  def test_load_vm_module(self):
+    arithmetic = iree.runtime.load_vm_module(create_simple_mul_module())
+    arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
+    arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)
+    results = arithmetic.simple_mul(arg0, arg1)
+    logging.info("SIMPLE_MUL RESULTS: %s", results)
+    np.testing.assert_allclose(results, [4., 10., 18., 28.])
+
+
+if __name__ == "__main__":
+  absltest.main()
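`test_tracing_from_environment` overrides `IREE_SAVE_CALLS` for the duration of one test. Afterwards it has to restore the previous state, which includes deleting the variable if it was previously unset, so that the override does not leak into later tests. A small context manager captures that save/restore logic. `scoped_env` is a hypothetical helper; `unittest.mock.patch.dict(os.environ, {...})` from the standard library is an equivalent alternative.

```python
import contextlib
import os
from typing import Iterator


@contextlib.contextmanager
def scoped_env(key: str, value: str) -> Iterator[None]:
  """Sets os.environ[key] for the duration of the block, then restores it."""
  original = os.environ.get(key)
  os.environ[key] = value
  try:
    yield
  finally:
    if original is None:
      # The variable did not exist before; remove it entirely.
      os.environ.pop(key, None)
    else:
      os.environ[key] = original
```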
diff --git a/runtime/bindings/python/iree/runtime/tracing.py b/runtime/bindings/python/iree/runtime/tracing.py
new file mode 100644
index 0000000..ea804ac
--- /dev/null
+++ b/runtime/bindings/python/iree/runtime/tracing.py
@@ -0,0 +1,170 @@
+"""Tracing support."""
+
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import Dict, List, Optional, Sequence
+
+import logging
+import os
+import sys
+
+from . import binding as _binding
+
+try:
+  import yaml
+except ModuleNotFoundError:
+  _has_yaml = False
+else:
+  _has_yaml = True
+
+__all__ = [
+    "get_default_tracer",
+    "Tracer",
+    "TRACE_PATH_ENV_KEY",
+]
+
+TRACE_PATH_ENV_KEY = "IREE_SAVE_CALLS"
+
+
+class Tracer:
+  """Object for tracing calls made into the runtime."""
+
+  def __init__(self, trace_path: str):
+    if not _has_yaml:
+      self.enabled = False
+      logging.warning("PyYAML not installed: tracing will be disabled")
+      return
+    self.enabled = True
+    self.trace_path = trace_path
+    os.makedirs(trace_path, exist_ok=True)
+    self._name_count = dict()  # type: Dict[str, int]
+
+  def persist_vm_module(self, vm_module: _binding.VmModule) -> "TracedModule":
+    # Depending on how the module was created, there are different bits
+    # of information available to reconstruct.
+    name = vm_module.name
+    flatbuffer_blob = vm_module.stashed_flatbuffer_blob
+    if flatbuffer_blob:
+      save_path = os.path.join(self.trace_path,
+                               self.get_unique_name(f"{name}.vmfb"))
+      logging.info("Saving traced vmfb to %s", save_path)
+      with open(save_path, "wb") as f:
+        f.write(flatbuffer_blob)
+      return TracedModule(self, vm_module, save_path)
+
+    # No persistent form, but likely they are built-in modules.
+    return TracedModule(self, vm_module)
+
+  def get_unique_name(self, local_name: str) -> str:
+    if local_name not in self._name_count:
+      self._name_count[local_name] = 1
+      return local_name
+    stem, ext = os.path.splitext(local_name)
+    index = self._name_count[local_name]
+    self._name_count[local_name] += 1
+    unique_name = f"{stem}__{index}{ext}"
+    return unique_name
+
+
+class TracedModule:
+  """Wraps a VmModule with additional information for tracing."""
+
+  def __init__(self,
+               parent: Tracer,
+               vm_module: _binding.VmModule,
+               vmfb_path: Optional[str] = None):
+    self._parent = parent
+    self._vm_module = vm_module
+    self._vmfb_path = vmfb_path
+
+  def serialize(self):
+    module_record = {"name": self._vm_module.name}
+    if self._vmfb_path:
+      module_record["type"] = "bytecode"
+      module_record["path"] = os.path.relpath(self._vmfb_path,
+                                              self._parent.trace_path)
+    else:
+      module_record["type"] = "builtin"
+
+    return module_record
+
+
+class ContextTracer:
+  """Traces invocations against a context."""
+
+  def __init__(self, parent: Tracer, is_dynamic: bool,
+               modules: Sequence[TracedModule]):
+    self._parent = parent
+    self._modules = list(modules)  # type: List[TracedModule]
+    self._frame_count = 0
+    self._file_path = os.path.join(parent.trace_path,
+                                   parent.get_unique_name("calls.yaml"))
+    if os.path.exists(self._file_path):
+      # Truncate the file.
+      with open(self._file_path, "wt"):
+        pass
+    else:
+      os.makedirs(os.path.dirname(self._file_path), exist_ok=True)
+    logging.info("Tracing context events to: %s", self._file_path)
+    self.emit_frame({
+        "type": "context_load",
+    })
+    for module in self._modules:
+      self.emit_frame({
+          "type": "module_load",
+          "module": module.serialize(),
+      })
+
+  def add_module(self, module: TracedModule):
+    self._modules.append(module)
+    self.emit_frame({
+        "type": "module_load",
+        "module": module.serialize(),
+    })
+
+  def start_call(self, function: _binding.VmFunction):
+    logging.info("Tracing call to %s.%s", function.module_name, function.name)
+
+    # Start assembling the call record.
+    record = {
+        "type": "call",
+        "function": "%s.%s" % (function.module_name, function.name),
+    }
+    return CallTrace(self, record)
+
+  def emit_frame(self, frame: dict):
+    self._frame_count += 1
+    with open(self._file_path, "at") as f:
+      if self._frame_count != 1:
+        f.write("---\n")
+      contents = yaml.dump(frame, sort_keys=False)
+      f.write(contents)
+
+
+class CallTrace:
+
+  def __init__(self, parent: ContextTracer, record: dict):
+    self._parent = parent
+    self._record = record
+
+  def add_vm_list(self, vm_list: _binding.VmVariantList, key: str):
+    mapped = []
+    for i in range(len(vm_list)):
+      mapped.append(vm_list.get_serialized_trace_value(i))
+    self._record[key] = mapped
+
+  def end_call(self):
+    self._parent.emit_frame(self._record)
+
+
+def get_default_tracer() -> Optional[Tracer]:
+  """Gets a default call tracer based on environment variables."""
+  default_path = os.getenv(TRACE_PATH_ENV_KEY)
+  if not default_path:
+    return None
+  return Tracer(default_path)
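`Tracer.get_unique_name` keeps a per-name counter. The first request for `arithmetic.vmfb` returns it unchanged, and later requests return `arithmetic__1.vmfb`, `arithmetic__2.vmfb`, and so on. The standalone sketch below reproduces that naming scheme without PyYAML or the native bindings. The `UniqueNamer` class is hypothetical.

```python
import os
from typing import Dict


class UniqueNamer:
  """Hands out file names, suffixing repeats as <stem>__<n><ext>."""

  def __init__(self):
    self._name_count: Dict[str, int] = {}

  def get_unique_name(self, local_name: str) -> str:
    count = self._name_count.get(local_name, 0)
    self._name_count[local_name] = count + 1
    if count == 0:
      return local_name
    stem, ext = os.path.splitext(local_name)
    return f"{stem}__{count}{ext}"
```

This scheme shares one caveat with the original: an explicitly requested name such as `arithmetic__1.vmfb` is counted separately, so it can collide with a generated one.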
diff --git a/runtime/bindings/python/iree/runtime/unix_version.lds b/runtime/bindings/python/iree/runtime/unix_version.lds
new file mode 100644
index 0000000..fd766d1
--- /dev/null
+++ b/runtime/bindings/python/iree/runtime/unix_version.lds
@@ -0,0 +1,11 @@
+/* Copyright 2020 The IREE Authors
+ *
+ * Licensed under the Apache License v2.0 with LLVM Exceptions.
+ * See https://llvm.org/LICENSE.txt for license information.
+ * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+ */
+
+{
+  global: PyInit_binding;
+  local: *;
+};
diff --git a/runtime/bindings/python/iree/runtime/vm.cc b/runtime/bindings/python/iree/runtime/vm.cc
new file mode 100644
index 0000000..53dbbd1
--- /dev/null
+++ b/runtime/bindings/python/iree/runtime/vm.cc
@@ -0,0 +1,616 @@
+// Copyright 2019 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "./vm.h"
+
+#include "./status_utils.h"
+#include "iree/base/api.h"
+#include "iree/base/status_cc.h"
+#include "iree/base/tracing.h"
+#include "iree/hal/api.h"
+#include "iree/modules/hal/module.h"
+#include "iree/vm/api.h"
+#include "pybind11/numpy.h"
+
+namespace iree {
+namespace python {
+
+namespace {
+
+VmModule CreateHalModule(HalDevice* device) {
+  iree_vm_module_t* module;
+  CheckApiStatus(iree_hal_module_create(device->raw_ptr(),
+                                        iree_allocator_system(), &module),
+                 "Error creating hal module");
+  return VmModule::StealFromRawPtr(module);
+}
+
+// RAII wrapper for a Py_buffer which calls PyBuffer_Release when it goes
+// out of scope.
+class PyBufferReleaser {
+ public:
+  PyBufferReleaser(Py_buffer& b) : b_(b) {}
+  ~PyBufferReleaser() { PyBuffer_Release(&b_); }
+
+ private:
+  Py_buffer& b_;
+};
+
+py::dict GetFunctionReflectionDict(iree_vm_function_t& f) {
+  py::dict attrs;
+  for (int i = 0;; ++i) {
+    iree_string_view_t key;
+    iree_string_view_t value;
+    auto status = iree_vm_get_function_reflection_attr(f, i, &key, &value);
+    if (iree_status_is_not_found(status)) {
+      iree_status_ignore(status);
+      break;
+    }
+    CheckApiStatus(status, "Error getting reflection attr");
+    py::str key_str(key.data, key.size);
+    py::str value_str(value.data, value.size);
+    attrs[std::move(key_str)] = std::move(value_str);
+  }
+  return attrs;
+}
+
+}  // namespace
+
+//------------------------------------------------------------------------------
+// VmInstance
+//------------------------------------------------------------------------------
+
+VmInstance VmInstance::Create() {
+  IREE_TRACE_SCOPE0("VmInstance::Create");
+  iree_vm_instance_t* instance;
+  auto status = iree_vm_instance_create(iree_allocator_system(), &instance);
+  CheckApiStatus(status, "Error creating instance");
+  return VmInstance::StealFromRawPtr(instance);
+}
+
+//------------------------------------------------------------------------------
+// VmContext
+//------------------------------------------------------------------------------
+
+VmContext VmContext::Create(VmInstance* instance,
+                            std::optional<std::vector<VmModule*>> modules) {
+  IREE_TRACE_SCOPE0("VmContext::Create");
+  iree_vm_context_t* context;
+  if (!modules) {
+    // Open context: further modules may be registered via RegisterModules().
+    auto status =
+        iree_vm_context_create(instance->raw_ptr(), IREE_VM_CONTEXT_FLAG_NONE,
+                               iree_allocator_system(), &context);
+    CheckApiStatus(status, "Error creating vm context");
+  } else {
+    // Closed set of modules.
+    std::vector<iree_vm_module_t*> module_handles;
+    module_handles.resize(modules->size());
+    for (size_t i = 0, e = module_handles.size(); i < e; ++i) {
+      module_handles[i] = (*modules)[i]->raw_ptr();
+    }
+    auto status = iree_vm_context_create_with_modules(
+        instance->raw_ptr(), IREE_VM_CONTEXT_FLAG_NONE, module_handles.data(),
+        module_handles.size(), iree_allocator_system(), &context);
+    CheckApiStatus(status, "Error creating vm context with modules");
+  }
+
+  IREE_CHECK(context);
+  return VmContext::StealFromRawPtr(context);
+}
+
+void VmContext::RegisterModules(std::vector<VmModule*> modules) {
+  std::vector<iree_vm_module_t*> module_handles;
+  module_handles.resize(modules.size());
+  for (size_t i = 0, e = module_handles.size(); i < e; ++i) {
+    module_handles[i] = modules[i]->raw_ptr();
+  }
+  auto status = iree_vm_context_register_modules(
+      raw_ptr(), module_handles.data(), module_handles.size());
+  CheckApiStatus(status, "Error registering modules");
+}
+
+void VmContext::Invoke(iree_vm_function_t f, VmVariantList& inputs,
+                       VmVariantList& outputs) {
+  iree_status_t status;
+  {
+    py::gil_scoped_release release;
+    status = iree_vm_invoke(raw_ptr(), f, IREE_VM_INVOCATION_FLAG_NONE, nullptr,
+                            inputs.raw_ptr(), outputs.raw_ptr(),
+                            iree_allocator_system());
+  }
+  CheckApiStatus(status, "Error invoking function");
+}
+
+//------------------------------------------------------------------------------
+// VmModule
+//------------------------------------------------------------------------------
+
+VmModule VmModule::FromFlatbufferBlob(py::object flatbuffer_blob_object) {
+  IREE_TRACE_SCOPE0("VmModule::FromFlatbufferBlob");
+  auto flatbuffer_blob = py::cast<py::buffer>(flatbuffer_blob_object);
+  auto buffer_info = flatbuffer_blob.request();
+  iree_vm_module_t* module;
+
+  // Bridge to the C-based deallocator API.
+  auto* raw_ptr = flatbuffer_blob.ptr();
+  auto ctl_fn = +([](void* self, iree_allocator_command_t command,
+                     const void* params, void** inout_ptr) {
+    assert(command == IREE_ALLOCATOR_COMMAND_FREE);
+    PyObject* object_ptr = static_cast<PyObject*>(*inout_ptr);
+    Py_XDECREF(object_ptr);
+    return iree_ok_status();
+  });
+  flatbuffer_blob.inc_ref();
+  iree_allocator_t deallocator{/*self=*/NULL, /*ctl=*/ctl_fn};
+
+  auto status = iree_vm_bytecode_module_create(
+      {static_cast<const uint8_t*>(buffer_info.ptr),
+       static_cast<iree_host_size_t>(buffer_info.size)},
+      deallocator, iree_allocator_system(), &module);
+  if (!iree_status_is_ok(status)) {
+    iree_allocator_free(deallocator, raw_ptr);
+  }
+
+  CheckApiStatus(status, "Error creating vm module from flatbuffer");
+  auto py_module = VmModule::StealFromRawPtr(module);
+  py_module.stashed_flatbuffer_blob = flatbuffer_blob_object;
+  return py_module;
+}
+
+std::optional<iree_vm_function_t> VmModule::LookupFunction(
+    const std::string& name, iree_vm_function_linkage_t linkage) {
+  iree_vm_function_t f;
+  auto status = iree_vm_module_lookup_function_by_name(
+      raw_ptr(), linkage, {name.data(), name.size()}, &f);
+  if (iree_status_is_not_found(status)) {
+    iree_status_ignore(status);
+    return std::nullopt;
+  }
+  CheckApiStatus(status, "Error looking up function");
+  return f;
+}
+
+//------------------------------------------------------------------------------
+// VmVariantList
+//------------------------------------------------------------------------------
+
+void VmVariantList::PushFloat(double fvalue) {
+  // Note that Python floats are f64.
+  iree_vm_value_t value = iree_vm_value_make_f64(fvalue);
+  CheckApiStatus(iree_vm_list_push_value(raw_ptr(), &value),
+                 "Could not push float");
+}
+
+void VmVariantList::PushInt(int64_t ivalue) {
+  // Note that Python ints are unbounded, so just use the largest type we
+  // have.
+  iree_vm_value_t value = iree_vm_value_make_i64(ivalue);
+  CheckApiStatus(iree_vm_list_push_value(raw_ptr(), &value),
+                 "Could not push int");
+}
+
+void VmVariantList::PushList(VmVariantList& other) {
+  iree_vm_ref_t retained = iree_vm_list_retain_ref(other.raw_ptr());
+  CheckApiStatus(iree_vm_list_push_ref_move(raw_ptr(), &retained),
+                 "Error pushing list");
+}
+
+void VmVariantList::PushBufferView(HalBufferView& buffer_view) {
+  iree_vm_ref_t buffer_view_ref =
+      iree_hal_buffer_view_retain_ref(buffer_view.raw_ptr());
+  CheckApiStatus(iree_vm_list_push_ref_move(raw_ptr(), &buffer_view_ref),
+                 "Error moving buffer view");
+}
+
+py::object VmVariantList::GetAsList(int index) {
+  iree_vm_ref_t ref = {0};
+  CheckApiStatus(iree_vm_list_get_ref_assign(raw_ptr(), index, &ref),
+                 "Could not access list element");
+  iree_vm_list_t* sub_list = NULL;
+  CheckApiStatus(iree_vm_list_check_deref(ref, &sub_list),
+                 "Could not deref list (wrong type?)");
+  iree_vm_list_retain(sub_list);
+  return py::cast(VmVariantList(sub_list));
+}
+
+py::object VmVariantList::GetVariant(int index) {
+  iree_vm_variant_t v = iree_vm_variant_empty();
+  CheckApiStatus(iree_vm_list_get_variant(raw_ptr(), index, &v),
+                 "Could not access list element");
+  if (iree_vm_type_def_is_value(&v.type)) {
+    // Convert a value type.
+    switch (v.type.value_type) {
+      case IREE_VM_VALUE_TYPE_I8:
+        return py::cast(v.i8);
+      case IREE_VM_VALUE_TYPE_I16:
+        return py::cast(v.i16);
+      case IREE_VM_VALUE_TYPE_I32:
+        return py::cast(v.i32);
+      case IREE_VM_VALUE_TYPE_I64:
+        return py::cast(v.i64);
+      case IREE_VM_VALUE_TYPE_F32:
+        return py::cast(v.f32);
+      case IREE_VM_VALUE_TYPE_F64:
+        return py::cast(v.f64);
+      default:
+        throw RaiseValueError("Unsupported VM value type conversion");
+    }
+  } else if (v.type.ref_type == IREE_VM_REF_TYPE_NULL) {
+    return py::none();
+  } else if (iree_vm_type_def_is_ref(&v.type)) {
+    // Convert reference type.
+    if (iree_vm_list_isa(v.ref)) {
+      return GetAsList(index);
+    } else if (iree_hal_buffer_view_isa(v.ref)) {
+      return GetAsBufferView(index);
+    }
+  }
+
+  throw RaiseValueError("Unsupported VM to Python Type Conversion");
+}
+
+py::object VmVariantList::GetAsSerializedTraceValue(int index) {
+  iree_vm_variant_t v = iree_vm_variant_empty();
+  CheckApiStatus(iree_vm_list_get_variant(raw_ptr(), index, &v),
+                 "Could not access list element");
+  if (iree_vm_type_def_is_value(&v.type)) {
+    // Convert a value type.
+    py::dict record;
+    switch (v.type.value_type) {
+      case IREE_VM_VALUE_TYPE_I8:
+        record["i8"] = py::cast(v.i8);
+        break;
+      case IREE_VM_VALUE_TYPE_I16:
+        record["i16"] = py::cast(v.i16);
+        break;
+      case IREE_VM_VALUE_TYPE_I32:
+        record["i32"] = py::cast(v.i32);
+        break;
+      case IREE_VM_VALUE_TYPE_I64:
+        record["i64"] = py::cast(v.i64);
+        break;
+      case IREE_VM_VALUE_TYPE_F32:
+        record["f32"] = py::cast(v.f32);
+        break;
+      case IREE_VM_VALUE_TYPE_F64:
+        record["f64"] = py::cast(v.f64);
+        break;
+      default:
+        throw RaiseValueError("Unsupported VM value type conversion");
+    }
+    record["type"] = py::cast("value");
+    return std::move(record);
+  } else if (v.type.ref_type == IREE_VM_REF_TYPE_NULL) {
+    py::dict record;
+    record["type"] = "null";
+    return std::move(record);
+  } else if (iree_vm_type_def_is_ref(&v.type)) {
+    // Convert reference type.
+    if (iree_vm_list_isa(v.ref)) {
+      py::dict record;
+      record["type"] = "vm.list";
+      py::list items;
+      iree_vm_list_t* sub_list = NULL;
+      CheckApiStatus(iree_vm_list_check_deref(v.ref, &sub_list),
+                     "Could not deref list (wrong type?)");
+      iree_vm_list_retain(sub_list);
+      VmVariantList sub_list_object(sub_list);
+      for (int i = 0, e = sub_list_object.size(); i < e; ++i) {
+        items.append(sub_list_object.GetAsSerializedTraceValue(i));
+      }
+      record["items"] = std::move(items);
+      return std::move(record);
+    } else if (iree_hal_buffer_view_isa(v.ref)) {
+      py::dict record;
+      record["type"] = "hal.buffer_view";
+      iree_hal_buffer_view_t* buffer_view = iree_hal_buffer_view_deref(v.ref);
+      if (!buffer_view) {
+        throw RaiseValueError(
+            "Could not deref result buffer view (wrong type?)");
+      }
+      iree_hal_buffer_t* raw_buffer = iree_hal_buffer_view_buffer(buffer_view);
+      if (!raw_buffer) {
+        throw RaiseValueError("Could not deref result buffer (wrong type?)");
+      }
+
+      // Extract dims from the buffer view.
+      size_t rank = 0;
+      std::vector<int32_t> dims(6);
+      iree_status_t status = iree_hal_buffer_view_shape(
+          buffer_view, dims.size(), dims.data(), &rank);
+      if (iree_status_is_out_of_range(status)) {
+        // |rank| now holds the required rank; release the error and retry.
+        iree_status_ignore(status);
+        dims.resize(rank);
+        status = iree_hal_buffer_view_shape(buffer_view, dims.size(),
+                                            dims.data(), &rank);
+      }
+      CheckApiStatus(status, "Error extracting shape");
+      dims.resize(rank);
+      record["shape"] = py::cast(std::move(dims));
+
+      // Element type.
+      iree_hal_element_type_t element_type =
+          iree_hal_buffer_view_element_type(buffer_view);
+      // TODO: Would be nice to output as hex.
+      record["element_type"] = element_type;
+
+      // Map memory.
+      iree_device_size_t byte_length = iree_hal_buffer_byte_length(raw_buffer);
+      iree_hal_buffer_mapping_t mapped_memory = {{0}};
+      CheckApiStatus(iree_hal_buffer_map_range(
+                         raw_buffer, IREE_HAL_MAPPING_MODE_SCOPED,
+                         IREE_HAL_MEMORY_ACCESS_READ, /*byte_offset=*/0,
+                         byte_length, &mapped_memory),
+                     "Could not map memory");
+      record["contents"] =
+          py::bytes(reinterpret_cast<const char*>(mapped_memory.contents.data),
+                    mapped_memory.contents.data_length);
+      iree_hal_buffer_unmap_range(&mapped_memory);
+
+      return std::move(record);
+    }
+  }
+
+  throw RaiseValueError("Unsupported VM to Python Type Conversion");
+}
+
+py::object VmVariantList::GetAsBufferView(int index) {
+  iree_vm_variant_t v = iree_vm_variant_empty();
+  CheckApiStatus(iree_vm_list_get_variant(raw_ptr(), index, &v),
+                 "Could not access list element");
+  iree_hal_buffer_view_t* buffer_view = iree_hal_buffer_view_deref(v.ref);
+  if (!buffer_view) {
+    throw RaiseValueError("Could not deref result buffer view (wrong type?)");
+  }
+  return py::cast(HalBufferView::BorrowFromRawPtr(buffer_view),
+                  py::return_value_policy::move);
+}
+
+namespace {
+
+static std::string ToHexString(const uint8_t* data, size_t length) {
+  static constexpr char kHexChars[] = {'0', '1', '2', '3', '4', '5', '6', '7',
+                                       '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'};
+  std::string s(length * 2, ' ');
+  for (size_t i = 0; i < length; ++i) {
+    s[2 * i + 0] = kHexChars[(data[i] & 0xF0) >> 4];
+    s[2 * i + 1] = kHexChars[(data[i] & 0x0F) >> 0];
+  }
+  return s;
+}
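+// Note: hex-encodes the bytes of |value| in host memory order, so on
+// little-endian hosts the digit pairs appear byte-reversed relative to a
+// conventional hex literal (e.g. 0x11000020 prints as "20000011").
+static std::string ToHexString(uint32_t value) {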
+  return ToHexString((const uint8_t*)&value, sizeof(value));
+}
+
+void AppendListContents(std::string& out, iree_vm_list_t* list,
+                        std::unordered_set<iree_vm_list_t*>& visited) {
+  for (iree_host_size_t i = 0, e = iree_vm_list_size(list); i < e; ++i) {
+    if (i > 0) out.append(", ");
+    iree_vm_variant_t variant = iree_vm_variant_empty();
+    iree_status_t status = iree_vm_list_get_variant(list, i, &variant);
+    if (!iree_status_is_ok(status)) {
+      iree_status_ignore(status);
+      out.append("Error");
+      continue;
+    }
+
+    if (iree_vm_variant_is_value(variant)) {
+      // Convert a value type to a string.
+      switch (variant.type.value_type) {
+        case IREE_VM_VALUE_TYPE_I8: {
+          out += std::to_string(variant.i8);
+          break;
+        }
+        case IREE_VM_VALUE_TYPE_I16: {
+          out += std::to_string(variant.i16);
+          break;
+        }
+        case IREE_VM_VALUE_TYPE_I32: {
+          out += std::to_string(variant.i32);
+          break;
+        }
+        case IREE_VM_VALUE_TYPE_I64: {
+          out += std::to_string(variant.i64);
+          break;
+        }
+        case IREE_VM_VALUE_TYPE_F32: {
+          out += std::to_string(variant.f32);
+          break;
+        }
+        case IREE_VM_VALUE_TYPE_F64: {
+          out += std::to_string(variant.f64);
+          break;
+        }
+        default:
+          throw RaiseValueError("Unsupported VM value type to string");
+      }
+    } else if (iree_vm_variant_is_ref(variant)) {
+      // Pretty print a subset of ABI impacting known types.
+      if (iree_hal_buffer_isa(variant.ref)) {
+        auto* hal_buffer = iree_hal_buffer_deref(variant.ref);
+        assert(hal_buffer);
+        out += std::string("HalBuffer(") +
+               std::to_string(iree_hal_buffer_byte_length(hal_buffer)) + ")";
+      } else if (iree_hal_buffer_view_isa(variant.ref)) {
+        auto hal_bv = iree_hal_buffer_view_deref(variant.ref);
+        out += "HalBufferView(";
+        std::vector<int32_t> shape(iree_hal_buffer_view_shape_rank(hal_bv));
+        iree_status_ignore(iree_hal_buffer_view_shape(hal_bv, shape.size(),
+                                                      shape.data(), nullptr));
+        for (size_t dim = 0; dim < shape.size(); ++dim) {
+          if (dim > 0) out += 'x';
+          out += std::to_string(shape[dim]);
+        }
+        out += ":0x" +
+               ToHexString(static_cast<uint32_t>(
+                   iree_hal_buffer_view_element_type(hal_bv))) +
+               ")";
+      } else if (iree_vm_list_isa(variant.ref)) {
+        out.append("List[");
+        iree_vm_list_t* sub_list = iree_vm_list_deref(variant.ref);
+        if (visited.insert(sub_list).second) {
+          AppendListContents(out, sub_list, visited);
+        } else {
+          out.append("...circular...");
+        }
+        out.append("]");
+      } else {
+        out += "Unknown(" + std::to_string(variant.type.ref_type) + ")";
+      }
+    } else {
+      out.append("None");
+    }
+  }
+}
+
+}  // namespace
+
+std::string VmVariantList::DebugString() const {
+  // The variant list API requires a mutable list, so cast away const here to
+  // keep DebugString() const for callers.
+  auto mutable_this = const_cast<VmVariantList*>(this);
+  std::string s =
+      std::string("<VmVariantList(") + std::to_string(size()) + "): [";
+  iree_vm_list_t* list = mutable_this->raw_ptr();
+  std::unordered_set<iree_vm_list_t*> visited;
+  visited.insert(list);
+  AppendListContents(s, list, visited);
+  s.append("]>");
+  return s;
+}
+
+void SetupVmBindings(pybind11::module m) {
+  IREE_CHECK_OK(iree_vm_register_builtin_types());
+  IREE_CHECK_OK(iree_hal_module_register_types());
+
+  // Built-in module creation.
+  m.def("create_hal_module", &CreateHalModule);
+
+  py::enum_<enum iree_vm_function_linkage_e>(m, "Linkage")
+      .value("INTERNAL", IREE_VM_FUNCTION_LINKAGE_INTERNAL)
+      .value("IMPORT", IREE_VM_FUNCTION_LINKAGE_IMPORT)
+      .value("EXPORT", IREE_VM_FUNCTION_LINKAGE_EXPORT)
+      .export_values();
+
+  // Mutation and inspection of the variant list is mostly opaque to Python.
+  py::class_<VmVariantList>(m, "VmVariantList")
+      .def(py::init(&VmVariantList::Create))
+      .def_property_readonly("size", &VmVariantList::size)
+      .def("__len__", &VmVariantList::size)
+      .def("get_as_buffer_view", &VmVariantList::GetAsBufferView)
+      .def("get_as_list", &VmVariantList::GetAsList)
+      .def("get_variant", &VmVariantList::GetVariant)
+      .def("get_serialized_trace_value",
+           &VmVariantList::GetAsSerializedTraceValue)
+      .def("push_float", &VmVariantList::PushFloat)
+      .def("push_int", &VmVariantList::PushInt)
+      .def("push_list", &VmVariantList::PushList)
+      .def("push_buffer_view", &VmVariantList::PushBufferView)
+      .def("__repr__", &VmVariantList::DebugString);
+
+  py::class_<iree_vm_function_t>(m, "VmFunction")
+      .def_readonly("linkage", &iree_vm_function_t::linkage)
+      .def_readonly("ordinal", &iree_vm_function_t::ordinal)
+      .def_property_readonly("name",
+                             [](iree_vm_function_t& self) {
+                               iree_string_view_t name =
+                                   iree_vm_function_name(&self);
+                               return py::str(name.data, name.size);
+                             })
+      .def_property_readonly("module_name",
+                             [](iree_vm_function_t& self) {
+                               iree_string_view_t name =
+                                   iree_vm_module_name(self.module);
+                               return py::str(name.data, name.size);
+                             })
+      .def_property_readonly("reflection",
+                             [](iree_vm_function_t& self) {
+                               return GetFunctionReflectionDict(self);
+                             })
+      .def("__repr__", [](iree_vm_function_t& self) {
+        iree_string_view_t name = iree_vm_function_name(&self);
+        std::string repr("<VmFunction ");
+        repr.append(name.data, name.size);
+
+        iree_vm_function_signature_t sig = iree_vm_function_signature(&self);
+        repr.append("(");
+        repr.append(sig.calling_convention.data, sig.calling_convention.size);
+        repr.append("), reflection = ");
+        py::dict reflection = GetFunctionReflectionDict(self);
+        repr.append(py::cast<std::string>(py::repr(reflection)));
+        repr.append(">");
+        return repr;
+      });
+
+  py::class_<VmInstance>(m, "VmInstance").def(py::init(&VmInstance::Create));
+
+  py::class_<VmContext>(m, "VmContext")
+      .def(py::init(&VmContext::Create), py::arg("instance"),
+           py::arg("modules") = std::optional<std::vector<VmModule*>>())
+      .def("register_modules", &VmContext::RegisterModules)
+      .def_property_readonly("context_id", &VmContext::context_id)
+      .def("invoke", &VmContext::Invoke);
+
+  py::class_<VmModule>(m, "VmModule")
+      .def_static("from_flatbuffer", &VmModule::FromFlatbufferBlob)
+      .def_property_readonly("name", &VmModule::name)
+      .def("lookup_function", &VmModule::LookupFunction, py::arg("name"),
+           py::arg("linkage") = IREE_VM_FUNCTION_LINKAGE_EXPORT)
+      .def_property_readonly(
+          "stashed_flatbuffer_blob",
+          [](VmModule& self) { return self.get_stashed_flatbuffer_blob(); })
+      .def_property_readonly(
+          "function_names",
+          [](VmModule& self) {
+            py::list names;
+            iree_vm_module_signature_t sig =
+                iree_vm_module_signature(self.raw_ptr());
+            for (size_t ordinal = 0; ordinal < sig.export_function_count;
+                 ++ordinal) {
+              iree_vm_function_t f;
+              auto status = iree_vm_module_lookup_function_by_ordinal(
+                  self.raw_ptr(), IREE_VM_FUNCTION_LINKAGE_EXPORT, ordinal, &f);
+              if (iree_status_is_not_found(status)) {
+                iree_status_ignore(status);
+                break;
+              }
+              CheckApiStatus(status, "Error enumerating module");
+              iree_string_view_t fname = iree_vm_function_name(&f);
+              py::str name(fname.data, fname.size);
+              names.append(name);
+            }
+            return names;
+          })
+      .def("__repr__", [](VmModule& self) {
+        std::string repr("<VmModule ");
+        iree_string_view_t name = iree_vm_module_name(self.raw_ptr());
+        repr.append(name.data, name.size);
+
+        iree_vm_module_signature_t sig =
+            iree_vm_module_signature(self.raw_ptr());
+        repr.append(" : [");
+        for (size_t ordinal = 0; ordinal < sig.export_function_count;
+             ++ordinal) {
+          iree_vm_function_t f;
+          auto status = iree_vm_module_lookup_function_by_ordinal(
+              self.raw_ptr(), IREE_VM_FUNCTION_LINKAGE_EXPORT, ordinal, &f);
+          if (iree_status_is_not_found(status)) {
+            iree_status_ignore(status);
+            break;
+          }
+          CheckApiStatus(status, "Error enumerating module");
+          iree_string_view_t fname = iree_vm_function_name(&f);
+          if (ordinal > 0) {
+            repr.append(", ");
+          }
+          repr.append(fname.data, fname.size);
+        }
+        repr.append("]");
+        repr.append(">");
+        return repr;
+      });
+}
+
+}  // namespace python
+}  // namespace iree
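`VmVariantList.__repr__` builds its text in `AppendListContents` in vm.cc. To make that format concrete, here is a small, hypothetical Python model of the same recursion over nested Python lists (ints, floats, `None` and sublists). It is an illustration, not part of the bindings. It also reproduces one subtle C++ behavior: `visited` is only ever added to, so a sublist that appears twice without any real cycle is still reported as `...circular...` the second time.

```python
from typing import Any, List, Set


def _append_list_contents(out: List[str], items: List[Any],
                          visited: Set[int]) -> None:
  """Python analogue of AppendListContents() in vm.cc (illustrative only)."""
  for i, item in enumerate(items):
    if i > 0:
      out.append(", ")
    if item is None:
      out.append("None")
    elif isinstance(item, list):
      out.append("List[")
      # Like the C++ std::unordered_set, ids are added and never removed.
      if id(item) not in visited:
        visited.add(id(item))
        _append_list_contents(out, item, visited)
      else:
        out.append("...circular...")
      out.append("]")
    elif isinstance(item, float):
      out.append(f"{item:f}")  # std::to_string(double) uses "%f" formatting.
    elif isinstance(item, int):
      out.append(str(item))
    else:
      # The C++ prints the numeric ref type; a type name stands in for it here.
      out.append(f"Unknown({type(item).__name__})")


def debug_string(items: List[Any]) -> str:
  """Mirrors VmVariantList::DebugString() for nested Python lists."""
  visited = {id(items)}
  out = [f"<VmVariantList({len(items)}): ["]
  _append_list_contents(out, items, visited)
  out.append("]>")
  return "".join(out)
```

For example, `debug_string([10000000000])` yields the same `<VmVariantList(1): [10000000000]>` string that `test_variant_list_i64` expects from the real binding.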
diff --git a/runtime/bindings/python/iree/runtime/vm.h b/runtime/bindings/python/iree/runtime/vm.h
new file mode 100644
index 0000000..48fdbab
--- /dev/null
+++ b/runtime/bindings/python/iree/runtime/vm.h
@@ -0,0 +1,168 @@
+// Copyright 2019 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_BINDINGS_PYTHON_IREE_RT_VM_H_
+#define IREE_BINDINGS_PYTHON_IREE_RT_VM_H_
+
+#include <optional>
+
+#include "./binding.h"
+#include "./hal.h"
+#include "iree/base/api.h"
+#include "iree/vm/api.h"
+#include "iree/vm/bytecode_module.h"
+
+namespace iree {
+namespace python {
+
+class FunctionAbi;
+
+//------------------------------------------------------------------------------
+// Retain/release bindings
+//------------------------------------------------------------------------------
+
+template <>
+struct ApiPtrAdapter<iree_vm_instance_t> {
+  static void Retain(iree_vm_instance_t* b) { iree_vm_instance_retain(b); }
+  static void Release(iree_vm_instance_t* b) { iree_vm_instance_release(b); }
+};
+
+template <>
+struct ApiPtrAdapter<iree_vm_context_t> {
+  static void Retain(iree_vm_context_t* b) { iree_vm_context_retain(b); }
+  static void Release(iree_vm_context_t* b) { iree_vm_context_release(b); }
+};
+
+template <>
+struct ApiPtrAdapter<iree_vm_module_t> {
+  static void Retain(iree_vm_module_t* b) { iree_vm_module_retain(b); }
+  static void Release(iree_vm_module_t* b) { iree_vm_module_release(b); }
+};
+
+template <>
+struct ApiPtrAdapter<iree_vm_invocation_t> {
+  static void Retain(iree_vm_invocation_t* b) { iree_vm_invocation_retain(b); }
+  static void Release(iree_vm_invocation_t* b) {
+    iree_vm_invocation_release(b);
+  }
+};
+
+//------------------------------------------------------------------------------
+// VmVariantList
+//------------------------------------------------------------------------------
+
+class VmVariantList {
+ public:
+  VmVariantList() : list_(nullptr) {}
+  ~VmVariantList() {
+    if (list_) {
+      iree_vm_list_release(list_);
+    }
+  }
+
+  VmVariantList(VmVariantList&& other) {
+    list_ = other.list_;
+    other.list_ = nullptr;
+  }
+
+  VmVariantList& operator=(const VmVariantList&) = delete;
+  VmVariantList(const VmVariantList&) = delete;
+
+  static VmVariantList Create(iree_host_size_t capacity) {
+    iree_vm_list_t* list;
+    CheckApiStatus(iree_vm_list_create(/*element_type=*/nullptr, capacity,
+                                       iree_allocator_system(), &list),
+                   "Error allocating variant list");
+    return VmVariantList(list);
+  }
+
+  iree_host_size_t size() const { return iree_vm_list_size(list_); }
+
+  iree_vm_list_t* raw_ptr() { return list_; }
+  const iree_vm_list_t* raw_ptr() const { return list_; }
+  iree_vm_list_t* steal_raw_ptr() {
+    iree_vm_list_t* stolen = list_;
+    list_ = nullptr;
+    return stolen;
+  }
+  void AppendNullRef() {
+    iree_vm_ref_t null_ref = {0};
+    CheckApiStatus(iree_vm_list_push_ref_move(raw_ptr(), &null_ref),
+                   "Error appending to list");
+  }
+
+  std::string DebugString() const;
+  void PushFloat(double fvalue);
+  void PushInt(int64_t ivalue);
+  void PushList(VmVariantList& other);
+  void PushBufferView(HalBufferView& buffer_view);
+  py::object GetAsList(int index);
+  py::object GetAsBufferView(int index);
+  py::object GetVariant(int index);
+  py::object GetAsSerializedTraceValue(int index);
+
+ private:
+  VmVariantList(iree_vm_list_t* list) : list_(list) {}
+  iree_vm_list_t* list_;
+};
+
+//------------------------------------------------------------------------------
+// ApiRefCounted types
+//------------------------------------------------------------------------------
+
+class VmInstance : public ApiRefCounted<VmInstance, iree_vm_instance_t> {
+ public:
+  static VmInstance Create();
+};
+
+class VmModule : public ApiRefCounted<VmModule, iree_vm_module_t> {
+ public:
+  static VmModule FromFlatbufferBlob(py::object flatbuffer_blob_object);
+
+  std::optional<iree_vm_function_t> LookupFunction(
+      const std::string& name, iree_vm_function_linkage_t linkage);
+
+  std::string name() const {
+    auto name_sv = iree_vm_module_name(raw_ptr());
+    return std::string(name_sv.data, name_sv.size);
+  }
+
+  py::object get_stashed_flatbuffer_blob() { return stashed_flatbuffer_blob; }
+
+ private:
+  // If the module was created from a flatbuffer blob, we stash it here.
+  py::object stashed_flatbuffer_blob = py::none();
+};
+
+class VmContext : public ApiRefCounted<VmContext, iree_vm_context_t> {
+ public:
+  // Creates a context, optionally with modules, which will make the context
+  // static, disallowing further module registration (and may be more
+  // efficient).
+  static VmContext Create(VmInstance* instance,
+                          std::optional<std::vector<VmModule*>> modules);
+
+  // Registers additional modules. Only valid for non-static contexts (i.e.
+  // those created without modules).
+  void RegisterModules(std::vector<VmModule*> modules);
+
+  // Unique id for this context.
+  int context_id() const { return iree_vm_context_id(raw_ptr()); }
+
+  // Synchronously invokes the given function.
+  void Invoke(iree_vm_function_t f, VmVariantList& inputs,
+              VmVariantList& outputs);
+};
+
+class VmInvocation : public ApiRefCounted<VmInvocation, iree_vm_invocation_t> {
+};
+
+void SetupVmBindings(pybind11::module m);
+
+}  // namespace python
+}  // namespace iree
+
+#endif  // IREE_BINDINGS_PYTHON_IREE_RT_VM_H_
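The `HalBufferView(...)` entries in `VmVariantList.__repr__` print the element type through `ToHexString(uint32_t)` in vm.cc. That function hex-encodes the integer's bytes in host memory order rather than formatting it as a number, so the digits can look reversed. A short Python illustration of the difference follows. The function names and the value `0x11000020` are illustrative assumptions, not IREE API.

```python
import sys


def to_hex_string(data: bytes) -> str:
  """Like ToHexString(const uint8_t*, size_t): two lowercase hex digits per byte."""
  return data.hex()


def to_hex_string_u32(value: int) -> str:
  """Like ToHexString(uint32_t): encodes the value's 4 bytes in host byte order."""
  return to_hex_string(value.to_bytes(4, sys.byteorder))


value = 0x11000020  # Hypothetical 32-bit element-type code.
print(to_hex_string_u32(value))  # "20000011" on a little-endian host
print(f"{value:08x}")            # "11000020" on any host
```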
diff --git a/runtime/bindings/python/iree/runtime/vm_test.py b/runtime/bindings/python/iree/runtime/vm_test.py
new file mode 100644
index 0000000..d730442
--- /dev/null
+++ b/runtime/bindings/python/iree/runtime/vm_test.py
@@ -0,0 +1,193 @@
+# Lint as: python3
+# Copyright 2019 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# pylint: disable=unused-variable
+
+from absl import logging
+from absl.testing import absltest
+import iree.compiler
+import iree.runtime
+import numpy as np
+
+
+def create_add_scalar_module():
+  binary = iree.compiler.compile_str(
+      """
+      func.func @add_scalar(%arg0: i32, %arg1: i32) -> i32 {
+        %0 = arith.addi %arg0, %arg1 : i32
+        return %0 : i32
+      }
+      """,
+      input_type="mhlo",
+      target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS,
+  )
+  m = iree.runtime.VmModule.from_flatbuffer(binary)
+  return m
+
+
+def create_simple_static_mul_module():
+  binary = iree.compiler.compile_str(
+      """
+      func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
+          %0 = "mhlo.multiply"(%arg0, %arg1) {name = "mul.1"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+          return %0 : tensor<4xf32>
+      }
+      """,
+      input_type="mhlo",
+      target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS,
+  )
+  m = iree.runtime.VmModule.from_flatbuffer(binary)
+  return m
+
+
+def create_simple_dynamic_abs_module():
+  # TODO(laurenzo): Compile for more backends as dynamic shapes come online.
+  target_backends = iree.compiler.DEFAULT_TESTING_BACKENDS
+  binary = iree.compiler.compile_str(
+      """
+      func.func @simple_mul(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
+          %0 = "mhlo.abs"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+          return %0 : tensor<?x?xf32>
+      }
+      """,
+      input_type="mhlo",
+      target_backends=target_backends,
+  )
+  m = iree.runtime.VmModule.from_flatbuffer(binary)
+  return m
+
+
+class VmTest(absltest.TestCase):
+
+  @classmethod
+  def setUpClass(cls):
+    super().setUpClass()
+    driver_names = iree.runtime.HalDriver.query()
+    logging.info("driver_names: %s", driver_names)
+    cls.driver = iree.runtime.HalDriver.create(
+        iree.compiler.core.DEFAULT_TESTING_DRIVER)
+    cls.device = cls.driver.create_default_device()
+    cls.hal_module = iree.runtime.create_hal_module(cls.device)
+
+  def test_variant_list(self):
+    l = iree.runtime.VmVariantList(5)
+    logging.info("variant_list: %s", l)
+    self.assertEqual(l.size, 0)
+
+  def test_variant_list_i64(self):
+    l = iree.runtime.VmVariantList(5)
+    # Push a value that exceeds 32-bit range.
+    l.push_int(10 * 1000 * 1000 * 1000)
+    self.assertEqual(str(l), "<VmVariantList(1): [10000000000]>")
+
+  def test_variant_list_buffers(self):
+    ET = iree.runtime.HalElementType
+    for dt, et in ((np.int8, ET.SINT_8), (np.int16, ET.SINT_16),
+                   (np.int32, ET.SINT_32), (np.int64, ET.SINT_64),
+                   (np.uint8, ET.UINT_8), (np.uint16, ET.UINT_16),
+                   (np.uint32, ET.UINT_32), (np.uint64, ET.UINT_64),
+                   (np.float32, ET.FLOAT_32), (np.float64, ET.FLOAT_64)):
+      # TODO: Unimplemented: (np.float16, ET.FLOAT_16)
+      lst = iree.runtime.VmVariantList(5)
+      ary1 = np.asarray([1, 2, 3, 4], dtype=dt)
+      bv1 = self.device.allocator.allocate_buffer_copy(
+          memory_type=iree.runtime.MemoryType.DEVICE_LOCAL,
+          allowed_usage=(iree.runtime.BufferUsage.DISPATCH |
+                         iree.runtime.BufferUsage.TRANSFER |
+                         iree.runtime.BufferUsage.MAPPING),
+          buffer=ary1,
+          element_type=et)
+      lst.push_buffer_view(bv1)
+      ary2 = iree.runtime.DeviceArray(self.device,
+                                      lst.get_as_buffer_view(0),
+                                      override_dtype=dt,
+                                      implicit_host_transfer=True)
+      np.testing.assert_array_equal(ary1, ary2)
+      with self.assertRaises(IndexError):
+        lst.get_as_buffer_view(1)
+
+  def test_variant_list_list(self):
+    lst1 = iree.runtime.VmVariantList(5)
+    lst2 = iree.runtime.VmVariantList(5)
+    lst1.push_list(lst2)
+    self.assertEqual("<VmVariantList(1): [List[]]>", str(lst1))
+    lstout = lst1.get_as_list(0)
+    self.assertEqual("<VmVariantList(0): []>", str(lstout))
+    with self.assertRaises(IndexError):
+      lst1.get_as_list(1)
+
+  def test_context_id(self):
+    instance = iree.runtime.VmInstance()
+    context1 = iree.runtime.VmContext(instance)
+    context2 = iree.runtime.VmContext(instance)
+    self.assertGreater(context2.context_id, context1.context_id)
+
+  def test_module_basics(self):
+    m = create_simple_static_mul_module()
+    f = m.lookup_function("simple_mul")
+    self.assertGreaterEqual(f.ordinal, 0)
+    notfound = m.lookup_function("notfound")
+    self.assertIs(notfound, None)
+
+  def test_dynamic_module_context(self):
+    instance = iree.runtime.VmInstance()
+    context = iree.runtime.VmContext(instance)
+    m = create_simple_static_mul_module()
+    context.register_modules([self.hal_module, m])
+
+  def test_static_module_context(self):
+    m = create_simple_static_mul_module()
+    logging.info("module: %s", m)
+    instance = iree.runtime.VmInstance()
+    logging.info("instance: %s", instance)
+    context = iree.runtime.VmContext(instance, modules=[self.hal_module, m])
+    logging.info("context: %s", context)
+
+  def test_dynamic_shape_compile(self):
+    m = create_simple_dynamic_abs_module()
+    logging.info("module: %s", m)
+    instance = iree.runtime.VmInstance()
+    logging.info("instance: %s", instance)
+    context = iree.runtime.VmContext(instance, modules=[self.hal_module, m])
+    logging.info("context: %s", context)
+
+  def test_add_scalar_new_abi(self):
+    m = create_add_scalar_module()
+    instance = iree.runtime.VmInstance()
+    context = iree.runtime.VmContext(instance, modules=[self.hal_module, m])
+    f = m.lookup_function("add_scalar")
+    finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None)
+    result = finv(5, 6)
+    logging.info("result: %s", result)
+    self.assertEqual(result, 11)
+
+  def test_synchronous_dynamic_shape_invoke_function_new_abi(self):
+    m = create_simple_dynamic_abs_module()
+    instance = iree.runtime.VmInstance()
+    context = iree.runtime.VmContext(instance, modules=[self.hal_module, m])
+    f = m.lookup_function("simple_mul")
+    finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None)
+    arg0 = np.array([[-1., 2.], [3., -4.]], dtype=np.float32)
+    result = finv(arg0)
+    logging.info("result: %s", result)
+    np.testing.assert_allclose(result, [[1., 2.], [3., 4.]])
+
+  def test_synchronous_invoke_function_new_abi(self):
+    m = create_simple_static_mul_module()
+    instance = iree.runtime.VmInstance()
+    context = iree.runtime.VmContext(instance, modules=[self.hal_module, m])
+    f = m.lookup_function("simple_mul")
+    finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None)
+    arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
+    arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)
+    result = finv(arg0, arg1)
+    logging.info("result: %s", result)
+    np.testing.assert_allclose(result, [4., 10., 18., 28.])
+
+
+if __name__ == "__main__":
+  absltest.main()
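`test_variant_list_i64` pushes 10 * 1000 * 1000 * 1000 because that value does not fit in a signed 32-bit integer, so it exercises the 64-bit path of `VmVariantList`. The standalone sketch below uses only the standard library `struct` module (no IREE dependency) to confirm that boundary; the `fits` helper is hypothetical and exists only for this illustration.

```python
import struct

# The value pushed by test_variant_list_i64.
VALUE = 10 * 1000 * 1000 * 1000

INT32_MAX = 2**31 - 1
INT64_MAX = 2**63 - 1


def fits(fmt: str, value: int) -> bool:
  """Hypothetical helper: True if `value` packs with the struct format."""
  try:
    struct.pack(fmt, value)
    return True
  except struct.error:
    return False


print(VALUE > INT32_MAX)  # True
print(fits("<i", VALUE))  # False: too large for a signed 32-bit int
print(fits("<q", VALUE))  # True: fits in a signed 64-bit int
```

If the binding silently truncated to 32 bits, the list would not print `[10000000000]`, which is what the test's string comparison guards against.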
diff --git a/runtime/pyproject.toml b/runtime/pyproject.toml
new file mode 100644
index 0000000..22b4019
--- /dev/null
+++ b/runtime/pyproject.toml
@@ -0,0 +1,14 @@
+[build-system]
+requires = [
+    "setuptools>=42",
+    "wheel",
+    # There is no fundamental reason to pin this CMake version, beyond
+    # build stability.
+    "cmake==3.22.2",
+    "ninja==1.10.2",
+    "packaging",
+    # Version 2.7.0 excluded: https://github.com/pybind/pybind11/issues/3136
+    "pybind11>=2.6.0,!=2.7.0",
+    "PyYAML",
+]
+build-backend = "setuptools.build_meta"
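The `pybind11>=2.6.0,!=2.7.0` requirement above admits any release from 2.6.0 onward except 2.7.0. As an illustration only, the hypothetical helpers below mirror that specifier for plain dotted numeric versions. They are not PEP 440 compliant (no pre-releases, post-releases, or local versions), so real code should rely on `packaging.specifiers.SpecifierSet` instead.

```python
def parse_version(v: str) -> tuple:
  """Parses a plain dotted numeric version, padding to three components."""
  parts = tuple(int(p) for p in v.split("."))
  return parts + (0,) * (3 - len(parts))


def satisfies_pybind11_requirement(v: str) -> bool:
  """Mirrors "pybind11>=2.6.0,!=2.7.0" for simple numeric versions."""
  pv = parse_version(v)
  return pv >= (2, 6, 0) and pv != (2, 7, 0)


for candidate in ("2.5.0", "2.6.0", "2.7", "2.7.0", "2.7.1", "2.10.0"):
  print(candidate, satisfies_pybind11_requirement(candidate))
```

Comparing integer tuples keeps "2.10.0" above "2.6.0", which a plain string comparison would get wrong, and padding makes "2.7" equal to the excluded "2.7.0". With the `packaging` library (already listed in `requires`), the equivalent check is `"2.7.0" in SpecifierSet(">=2.6.0,!=2.7.0")`, which evaluates to `False`.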
diff --git a/runtime/setup.py b/runtime/setup.py
new file mode 100644
index 0000000..fa684bc
--- /dev/null
+++ b/runtime/setup.py
@@ -0,0 +1,389 @@
+# Copyright 2022 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# This builds just the runtime API and is relatively quick to build.
+# To install:
+#   pip install .
+# To build a wheel:
+#   pip wheel .
+#
+# It is recommended to build with Ninja and ccache. To enable ccache, prefix
+# the above invocations with these environment variables:
+#   CMAKE_C_COMPILER_LAUNCHER=ccache CMAKE_CXX_COMPILER_LAUNCHER=ccache
+#
+# On CI, it is often advantageous to reuse or control the CMake build
+# directory. This can be set with the IREE_RUNTIME_API_CMAKE_BUILD_DIR env var.
+#
+# A custom package suffix can be specified with the environment variable:
+#   IREE_RUNTIME_CUSTOM_PACKAGE_SUFFIX
+#
+# Select CMake options are available from environment variables:
+#   IREE_HAL_DRIVER_CUDA
+#   IREE_HAL_DRIVER_VULKAN
+#   IREE_ENABLE_RUNTIME_TRACING
+#   IREE_BUILD_TRACY
+
+import json
+import os
+import platform
+import re
+import shutil
+import subprocess
+import sys
+import sysconfig
+
+from distutils.command.build import build as _build
+from setuptools import find_namespace_packages, setup, Extension
+from setuptools.command.build_ext import build_ext as _build_ext
+from setuptools.command.build_py import build_py as _build_py
+
+
+def check_pip_version():
+  from packaging import version
+  # Pip versions < 21.3 default to out-of-tree builds, which are quite
+  # incompatible with what we do (and have other issues). Pip >= 22.0.4
+  # removed this option entirely and only performs in-tree builds. Since the
+  # old behavior can silently produce broken installations, we reject old
+  # pip versions outright.
+  try:
+    import pip
+  except ModuleNotFoundError:
+    # If pip not installed, we are obviously not trying to package via pip.
+    pass
+  else:
+    if version.parse(pip.__version__) < version.parse("21.3"):
+      print("ERROR: pip version >= 21.3 required")
+      print("Upgrade: pip install pip --upgrade")
+      sys.exit(2)
+
+
+check_pip_version()
+
+# This file can be run directly from the source tree or it can be CMake
+# configured so it can run from the build tree with an already existing
+# build tree. We detect the difference based on whether the following
+# are expanded by CMake.
+CONFIGURED_SOURCE_DIR = "@IREE_SOURCE_DIR@"
+CONFIGURED_BINARY_DIR = "@IREE_BINARY_DIR@"
+
+IREE_SOURCE_DIR = None
+IREE_BINARY_DIR = None
+
+# We must do the intermediate installation to a fixed location that agrees
+# between what we pass to setup() and cmake. So hard-code it here.
+# Note that setup() needs a relative path (to the setup.py file).
+SETUPPY_DIR = os.path.realpath(os.path.dirname(__file__))
+CMAKE_INSTALL_DIR_REL = os.path.join("build", "cmake_install")
+CMAKE_INSTALL_DIR_ABS = os.path.join(SETUPPY_DIR, CMAKE_INSTALL_DIR_REL)
+
+IS_CONFIGURED = CONFIGURED_SOURCE_DIR[0] != "@"
+if IS_CONFIGURED:
+  IREE_SOURCE_DIR = CONFIGURED_SOURCE_DIR
+  IREE_BINARY_DIR = CONFIGURED_BINARY_DIR
+  print(
+      f"Running setup.py from build tree: "
+      f"SOURCE_DIR = {IREE_SOURCE_DIR} "
+      f"BINARY_DIR = {IREE_BINARY_DIR}",
+      file=sys.stderr)
+else:
+  IREE_SOURCE_DIR = os.path.join(SETUPPY_DIR, "..")
+  IREE_BINARY_DIR = os.getenv("IREE_RUNTIME_API_CMAKE_BUILD_DIR")
+  if not IREE_BINARY_DIR:
+    # Note that setuptools always builds into a "build" directory that
+    # is a sibling of setup.py, so we just colonize a sub-directory of that
+    # by default.
+    IREE_BINARY_DIR = os.path.join(SETUPPY_DIR, "build", "cmake_build")
+  print(
+      f"Running setup.py from source tree: "
+      f"SOURCE_DIR = {IREE_SOURCE_DIR} "
+      f"BINARY_DIR = {IREE_BINARY_DIR}",
+      file=sys.stderr)
+
+# Setup and get version information.
+VERSION_INFO_FILE = os.path.join(IREE_SOURCE_DIR, "version_info.json")
+
+
+def load_version_info():
+  with open(VERSION_INFO_FILE, "rt") as f:
+    return json.load(f)
+
+
+try:
+  version_info = load_version_info()
+except FileNotFoundError:
+  print("version_info.json not found. Using defaults", file=sys.stderr)
+  version_info = {}
+
+PACKAGE_SUFFIX = version_info.get("package-suffix") or ""
+PACKAGE_VERSION = version_info.get("package-version") or "0.1dev1"
+
+
+def maybe_nuke_cmake_cache():
+  # From run to run under pip, we can end up with different paths to ninja,
+  # which isn't great and will confuse cmake. Detect if the location of
+  # ninja changes and force a cache flush.
+  ninja_path = ""
+  try:
+    import ninja
+  except ModuleNotFoundError:
+    pass
+  else:
+    ninja_path = ninja.__file__
+  expected_stamp_contents = f"{sys.executable}\n{ninja_path}"
+
+  # In order to speed things up on CI and not rebuild everything, we nuke
+  # the CMakeCache.txt file if the path to the Python interpreter changed.
+  # Ideally, CMake would let us reconfigure this dynamically... but it does
+  # not (and gets very confused).
+  PYTHON_STAMP_FILE = os.path.join(IREE_BINARY_DIR, "python_stamp.txt")
+  if os.path.exists(PYTHON_STAMP_FILE):
+    with open(PYTHON_STAMP_FILE, "rt") as f:
+      actual_stamp_contents = f.read()
+      if actual_stamp_contents == expected_stamp_contents:
+        # All good.
+        return
+
+  # Mismatch or not found. Clean it.
+  cmake_cache_file = os.path.join(IREE_BINARY_DIR, "CMakeCache.txt")
+  if os.path.exists(cmake_cache_file):
+    print("Removing CMakeCache.txt because the Python or ninja path changed",
+          file=sys.stderr)
+    os.remove(cmake_cache_file)
+
+  # And write.
+  with open(PYTHON_STAMP_FILE, "wt") as f:
+    f.write(expected_stamp_contents)
+
+
+def get_env_cmake_option(name: str, default_value: bool = False) -> str:
+  svalue = os.getenv(name)
+  if not svalue:
+    svalue = "ON" if default_value else "OFF"
+  return f"-D{name}={svalue}"
+
+
+def prepare_installation():
+  subprocess.check_call(["cmake", "--version"])
+  version_py_content = generate_version_py()
+  print(f"Generating version.py:\n{version_py_content}", file=sys.stderr)
+
+  if not IS_CONFIGURED:
+    # Build from source tree.
+    os.makedirs(IREE_BINARY_DIR, exist_ok=True)
+    maybe_nuke_cmake_cache()
+    print(f"CMake build dir: {IREE_BINARY_DIR}", file=sys.stderr)
+    print(f"CMake install dir: {CMAKE_INSTALL_DIR_ABS}", file=sys.stderr)
+    cfg = "Release"
+    cmake_args = [
+        "-GNinja",
+        "--log-level=VERBOSE",
+        "-DIREE_BUILD_PYTHON_BINDINGS=ON",
+        "-DIREE_BUILD_COMPILER=OFF",
+        "-DIREE_BUILD_SAMPLES=OFF",
+        "-DIREE_BUILD_TESTS=OFF",
+        "-DPython3_EXECUTABLE={}".format(sys.executable),
+        "-DCMAKE_BUILD_TYPE={}".format(cfg),
+        get_env_cmake_option("IREE_HAL_DRIVER_CUDA"),
+        get_env_cmake_option("IREE_HAL_DRIVER_VULKAN",
+                             platform.system() != "Darwin"),
+        get_env_cmake_option("IREE_ENABLE_RUNTIME_TRACING"),
+        get_env_cmake_option("IREE_BUILD_TRACY"),
+    ]
+
+    # Only do a from-scratch configure if not already configured.
+    cmake_cache_file = os.path.join(IREE_BINARY_DIR, "CMakeCache.txt")
+    if not os.path.exists(cmake_cache_file):
+      print(f"Configuring with: {cmake_args}", file=sys.stderr)
+      subprocess.check_call(["cmake", IREE_SOURCE_DIR] + cmake_args,
+                            cwd=IREE_BINARY_DIR)
+    else:
+      print("Not re-configuring (already configured)", file=sys.stderr)
+
+    # Build.
+    subprocess.check_call([
+        "cmake", "--build", ".", "--target",
+        "runtime/bindings/python/iree/runtime/all"
+    ],
+                          cwd=IREE_BINARY_DIR)
+    print("Build complete.", file=sys.stderr)
+
+  # Install the directory we care about.
+  install_subdirectory = os.path.join(IREE_BINARY_DIR, "runtime", "bindings",
+                                      "python", "iree", "runtime")
+  install_args = [
+      "-DCMAKE_INSTALL_DO_STRIP=ON",
+      f"-DCMAKE_INSTALL_PREFIX={CMAKE_INSTALL_DIR_ABS}/",
+      "-P",
+      os.path.join(install_subdirectory, "cmake_install.cmake"),
+  ]
+  print(f"Installing with: {install_args}", file=sys.stderr)
+  subprocess.check_call(["cmake"] + install_args, cwd=install_subdirectory)
+
+  # Write version.py directly into install dir.
+  version_py_file = os.path.join(CMAKE_INSTALL_DIR_ABS, "python_packages",
+                                 "iree_runtime", "iree", "runtime",
+                                 "version.py")
+  os.makedirs(os.path.dirname(version_py_file), exist_ok=True)
+  with open(version_py_file, "wt") as f:
+    f.write(version_py_content)
+
+  print(f"Installation prepared: {CMAKE_INSTALL_DIR_ABS}", file=sys.stderr)
+
+
+class CMakeBuildPy(_build_py):
+
+  def run(self):
+    # It is critical that the target directory contain all built extensions,
+    # or else setuptools will helpfully compile an empty binary for us
+    # (this is the **worst** possible thing it could do). We just copy
+    # everything. What's another hundred megs between friends?
+    target_dir = os.path.abspath(self.build_lib)
+    print(f"Building in target dir: {target_dir}", file=sys.stderr)
+    print("Copying install to target.", file=sys.stderr)
+    # copytree() requires that the destination not exist, so clear it first.
+    if os.path.exists(target_dir):
+      shutil.rmtree(target_dir)
+    shutil.copytree(os.path.join(CMAKE_INSTALL_DIR_ABS, "python_packages",
+                                 "iree_runtime"),
+                    target_dir,
+                    symlinks=False)
+    print("Target populated.", file=sys.stderr)
+
+
+class CustomBuild(_build):
+
+  def run(self):
+    self.run_command("build_py")
+    self.run_command("build_ext")
+    self.run_command("build_scripts")
+
+
+class CMakeExtension(Extension):
+
+  def __init__(self, name, sourcedir=""):
+    Extension.__init__(self, name, sources=[])
+    self.sourcedir = os.path.abspath(sourcedir)
+
+
+class NoopBuildExtension(_build_ext):
+
+  def __init__(self, *args, **kwargs):
+    assert False
+
+  def build_extension(self, ext):
+    pass
+
+
+def generate_version_py():
+  return f"""# Auto-generated version info.
+PACKAGE_SUFFIX = "{PACKAGE_SUFFIX}"
+VERSION = "{PACKAGE_VERSION}"
+REVISIONS = {json.dumps(find_git_versions())}
+"""
+
+
+def find_git_versions():
+  revisions = {}
+  try:
+    revisions["IREE"] = subprocess.check_output(
+        ["git", "rev-parse", "HEAD"],
+        cwd=IREE_SOURCE_DIR).decode("utf-8").strip()
+  except subprocess.SubprocessError as e:
+    print(f"ERROR: Could not get IREE revision: {e}", file=sys.stderr)
+  return revisions
+
+
+def find_git_submodule_revision(submodule_path):
+  try:
+    data = subprocess.check_output(["git", "ls-tree", "HEAD", submodule_path],
+                                   cwd=IREE_SOURCE_DIR).decode("utf-8").strip()
+    columns = re.split(r"\s+", data)
+    return columns[2]
+  except Exception as e:
+    print(
+        f"ERROR: Could not get submodule revision for {submodule_path}"
+        f" ({e})",
+        file=sys.stderr)
+    return ""
+
+
+prepare_installation()
+
+packages = find_namespace_packages(where=os.path.join(CMAKE_INSTALL_DIR_ABS,
+                                                      "python_packages",
+                                                      "iree_runtime"),
+                                   include=[
+                                       "iree.runtime",
+                                       "iree.runtime.*",
+                                   ])
+print(f"Found runtime packages: {packages}")
+
+with open(
+    os.path.join(IREE_SOURCE_DIR, "runtime", "bindings", "python", "iree",
+                 "runtime", "README.md"), "rt") as f:
+  README = f.read()
+
+custom_package_suffix = os.getenv("IREE_RUNTIME_CUSTOM_PACKAGE_SUFFIX")
+if not custom_package_suffix:
+  custom_package_suffix = ""
+
+setup(
+    name=f"iree-runtime{PACKAGE_SUFFIX}{custom_package_suffix}",
+    version=f"{PACKAGE_VERSION}",
+    author="IREE Authors",
+    author_email="iree-discuss@googlegroups.com",
+    description="IREE Python Runtime Components",
+    long_description=README,
+    long_description_content_type="text/markdown",
+    license="Apache-2.0",
+    classifiers=[
+        "Development Status :: 3 - Alpha",
+        "License :: OSI Approved :: Apache Software License",
+        "Programming Language :: Python :: 3.7",
+        "Programming Language :: Python :: 3.8",
+        "Programming Language :: Python :: 3.9",
+    ],
+    url="https://github.com/google/iree",
+    python_requires=">=3.7",
+    ext_modules=[
+        CMakeExtension("iree.runtime.binding"),
+    ],
+    cmdclass={
+        "build": CustomBuild,
+        "built_ext": NoopBuildExtension,
+        "build_py": CMakeBuildPy,
+    },
+    zip_safe=False,
+    package_dir={
+        # Note: Must be relative path, so we line this up with the absolute
+        # path built above. Note that this must exist prior to the call.
+        "": f"{CMAKE_INSTALL_DIR_REL}/python_packages/iree_runtime",
+    },
+    packages=packages,
+    # Matching the native extension as a data file keeps setuptools from
+    # "building" it (i.e. turning it into a static binary).
+    package_data={
+        "": [
+            f"*{sysconfig.get_config_var('EXT_SUFFIX')}",
+            "iree-run-module*",
+            "iree-run-trace*",
+            "iree-benchmark-trace*",
+            "iree-tracy-capture*",
+        ],
+    },
+    entry_points={
+        "console_scripts": [
+            "iree-run-module = iree.runtime.scripts.iree_run_module.__main__:main",
+            "iree-run-trace = iree.runtime.scripts.iree_run_trace.__main__:main",
+            "iree-benchmark-trace = iree.runtime.scripts.iree_benchmark_trace.__main__:main",
+            "iree-tracy-capture = iree.runtime.scripts.iree_tracy_capture.__main__:main",
+        ],
+    },
+    install_requires=[
+        "numpy",
+        "PyYAML",
+    ],
+)
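`get_env_cmake_option` chooses between ON and OFF from the truthiness of `default_value`, so a string default such as "OFF" (non-empty, and therefore truthy) still produces ON. The standalone sketch below copies the helper from setup.py (with a `str` return annotation, since it returns a flag string) and clears the environment variable first so that the default is actually consulted.

```python
import os


def get_env_cmake_option(name: str, default_value: bool = False) -> str:
  """Returns a -D<name>=ON/OFF flag, preferring a non-empty env value."""
  svalue = os.getenv(name)
  if not svalue:
    svalue = "ON" if default_value else "OFF"
  return f"-D{name}={svalue}"


# Ensure the environment does not override the default for this demo.
os.environ.pop("IREE_HAL_DRIVER_VULKAN", None)

# A non-empty string is truthy, so "OFF" behaves like True here.
print(get_env_cmake_option("IREE_HAL_DRIVER_VULKAN", "OFF"))
# -DIREE_HAL_DRIVER_VULKAN=ON

# Passing an actual bool gives the intended result.
print(get_env_cmake_option("IREE_HAL_DRIVER_VULKAN", False))
# -DIREE_HAL_DRIVER_VULKAN=OFF
```

Passing a boolean expression such as `platform.system() != "Darwin"` keeps Vulkan off by default on macOS and on elsewhere, while an explicit environment value (for example `IREE_HAL_DRIVER_VULKAN=OFF`) still takes precedence over the default.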