Re-implement the python tracing facility. (#6318)
diff --git a/bindings/python/iree/runtime/CMakeLists.txt b/bindings/python/iree/runtime/CMakeLists.txt
index 225d7ad..d501783 100644
--- a/bindings/python/iree/runtime/CMakeLists.txt
+++ b/bindings/python/iree/runtime/CMakeLists.txt
@@ -36,6 +36,7 @@
"__init__.py"
"function.py"
"system_api.py"
+ "tracing.py"
PYEXT_DEPS
::PyExtRt
)
@@ -67,6 +68,8 @@
MODULE_PATH iree/runtime
DEPS
bindings_python_iree_runtime_PyExtRt
+ ADDL_PACKAGE_FILES
+ ${CMAKE_CURRENT_SOURCE_DIR}/README.md
)
install(
diff --git a/bindings/python/iree/runtime/README.md b/bindings/python/iree/runtime/README.md
new file mode 100644
index 0000000..cd21ffe
--- /dev/null
+++ b/bindings/python/iree/runtime/README.md
@@ -0,0 +1,23 @@
+# IREE Python Runtime Components
+
+This package provides an API for running compiled IREE binaries and interfacing with the hardware-abstraction-layer.
+
+## Tracing
+
+Execution of calls against binaries can be traced for later replay (i.e. via
+tools like `iree-run-module`). This can be set up either explicitly or
+via environment variables.
+
+To trace via environment variable, set `IREE_SAVE_CALLS` to a directory to dump
+traces into. Each created `SystemContext` will result in one `calls.yaml`
+file (with an index appended to the stem for multiples). Any referenced
+module binaries will be dumped into the same directory and referenced by the
+YAML file.
+
+### Explicit API
+
+```python
+tracer = iree.runtime.Tracer(some_dir)
+config = iree.runtime.Config(driver, tracer)
+...
+```
diff --git a/bindings/python/iree/runtime/__init__.py b/bindings/python/iree/runtime/__init__.py
index c0eb796..589cbf0 100644
--- a/bindings/python/iree/runtime/__init__.py
+++ b/bindings/python/iree/runtime/__init__.py
@@ -17,7 +17,6 @@
from .binding import BufferUsage, HalBuffer, HalDevice, HalDriver, HalElementType, MemoryAccess, MemoryType, Shape
# Vm imports
from .binding import create_hal_module, Linkage, VmVariantList, VmFunction, VmInstance, VmContext, VmModule
-# SystemApi
from .system_api import *
-# Function
from .function import *
+from .tracing import *
diff --git a/bindings/python/iree/runtime/function.py b/bindings/python/iree/runtime/function.py
index 9f5b308..57537e7 100644
--- a/bindings/python/iree/runtime/function.py
+++ b/bindings/python/iree/runtime/function.py
@@ -5,12 +5,15 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from typing import Optional
+
import json
import logging
import numpy as np
from .binding import HalDevice, HalElementType, VmContext, VmFunction, VmVariantList
+from . import tracing
__all__ = [
"FunctionInvoker",
@@ -64,16 +67,19 @@
"_arg_descs",
"_ret_descs",
"_has_kwargs",
+ "_tracer",
]
def __init__(self, vm_context: VmContext, device: HalDevice,
- vm_function: VmFunction):
+ vm_function: VmFunction,
+ tracer: Optional[tracing.ContextTracer]):
self._vm_context = vm_context
# TODO: Needing to know the precise device to allocate on here is bad
# layering and will need to be fixed in some fashion if/when doing
# heterogenous dispatch.
self._device = device
self._vm_function = vm_function
+ self._tracer = tracer
self._abi_dict = None
self._arg_descs = None
self._ret_descs = None
@@ -85,6 +91,10 @@
return self._vm_function
def __call__(self, *args, **kwargs):
+ call_trace = None # type: Optional[tracing.CallTrace]
+ if self._tracer:
+ call_trace = self._tracer.start_call(self._vm_function)
+
# Initialize the capacity to our total number of args, since we should
# be below that when doing a flat invocation. May want to be more
# conservative here when considering nesting.
@@ -105,8 +115,14 @@
arg_list = VmVariantList(len(args))
ret_list = VmVariantList(len(ret_descs) if ret_descs is not None else 1)
_merge_python_sequence_to_vm(inv, arg_list, args, self._arg_descs)
+ if call_trace:
+ call_trace.add_vm_list(arg_list, "args")
self._vm_context.invoke(self._vm_function, arg_list, ret_list)
+ if call_trace:
+ call_trace.add_vm_list(ret_list, "results")
returns = _extract_vm_sequence_to_python(inv, ret_list, ret_descs)
+ if call_trace:
+ call_trace.end_call()
return_arity = len(returns)
if return_arity == 1:
return returns[0]
diff --git a/bindings/python/iree/runtime/setup.py.in b/bindings/python/iree/runtime/setup.py.in
index fc1f02e..4c84cc2 100644
--- a/bindings/python/iree/runtime/setup.py.in
+++ b/bindings/python/iree/runtime/setup.py.in
@@ -11,12 +11,8 @@
from setuptools import setup, find_namespace_packages, Extension
import sysconfig
-
-README = """# IREE Python Runtime Components
-
-This package provides an API for running compiled IREE binaries and interfacing
-with the hardware-abstraction-layer.
-"""
+with open(os.path.join(os.path.dirname(__file__), "README.md"), "r") as f:
+ README = f.read()
setup(
name="iree-runtime@IREE_RELEASE_PACKAGE_SUFFIX@",
@@ -35,7 +31,8 @@
"Development Status :: 3 - Alpha",
],
python_requires=">=3.6",
- packages=find_namespace_packages(include=["iree.runtime", "iree.runtime.*"]),
+ packages=find_namespace_packages(
+ include=["iree.runtime", "iree.runtime.*"]),
ext_modules=[
Extension(name="iree.runtime.binding", sources=[]),
],
diff --git a/bindings/python/iree/runtime/system_api.py b/bindings/python/iree/runtime/system_api.py
index cb7d0b3..3328ac3 100644
--- a/bindings/python/iree/runtime/system_api.py
+++ b/bindings/python/iree/runtime/system_api.py
@@ -28,6 +28,7 @@
"TARGET_BACKEND_TO_DRIVER",
]
+import logging
import os
import sys
@@ -35,6 +36,7 @@
from . import binding as _binding
from .function import FunctionInvoker
+from . import tracing
import numpy as np
@@ -109,14 +111,23 @@
device: _binding.HalDevice
vm_instance: _binding.VmInstance
default_vm_modules: Tuple[_binding.VmModule, ...]
+ tracer: Optional[tracing.Tracer]
- def __init__(self, driver_name: Optional[str] = None):
+ def __init__(self,
+ driver_name: Optional[str] = None,
+ tracer: Optional[tracing.Tracer] = None):
self.vm_instance = _binding.VmInstance()
self.driver = _create_default_iree_driver(
driver_name.split(",") if driver_name is not None else None)
self.device = self.driver.create_default_device()
hal_module = _binding.create_hal_module(self.device)
self.default_vm_modules = (hal_module,)
+ self.tracer = tracer or tracing.get_default_tracer()
+ if self.tracer and self.tracer.enabled:
+ logging.info("IREE runtime tracing calls to path: %s",
+ self.tracer.trace_path)
+ else:
+ self.tracer = None
_global_config = None
@@ -182,9 +193,15 @@
def __init__(self, context: "SystemContext", vm_module: _binding.VmModule):
self._context = context
+ self._tracer = self._context._config.tracer
self._vm_module = vm_module
self._lazy_functions = dict()
+ # Let the tracing infra create a traced module.
+ self.traced_module = None
+ if self._tracer:
+ self.traced_module = self._tracer.persist_vm_module(vm_module)
+
@property
def name(self):
return self._vm_module.name
@@ -212,7 +229,8 @@
# layering and will need to be fixed in some fashion if/when doing
# heterogenous dispatch.
return FunctionInvoker(self._context.vm_context,
- self._context.config.device, vm_function)
+ self._context.config.device, vm_function,
+ self._context._tracer)
def __repr__(self):
return f"<BoundModule {repr(self._vm_module)}>"
@@ -254,6 +272,13 @@
(m.name, BoundModule(self, m)) for m in init_vm_modules
])
+ self._tracer = None # type: Optional[tracing.ContextTracer]
+ if self._config.tracer:
+ self._tracer = tracing.ContextTracer(
+ self._config.tracer,
+ is_dynamic=self._is_dynamic,
+ modules=[bm.traced_module for bm in self._bound_modules.values()])
+
@property
def vm_context(self) -> _binding.VmContext:
return self._vm_context
@@ -279,7 +304,10 @@
for m in vm_modules:
if m.name in self._bound_modules:
raise ValueError(f"Attempt to register duplicate VmModule: '{m.name}'")
- self._bound_modules[m.name] = BoundModule(self, m)
+ bound_module = BoundModule(self, m)
+ self._bound_modules[m.name] = bound_module
+ if self._tracer:
+ self._tracer.add_module(bound_module.traced_module)
self._vm_context.register_modules(vm_modules)
def add_vm_module(self, vm_module):
diff --git a/bindings/python/iree/runtime/system_api_test.py b/bindings/python/iree/runtime/system_api_test.py
index 61b6297..eeeab80 100644
--- a/bindings/python/iree/runtime/system_api_test.py
+++ b/bindings/python/iree/runtime/system_api_test.py
@@ -7,7 +7,9 @@
# pylint: disable=unused-variable
+import os
import re
+import tempfile
from absl import logging
from absl.testing import absltest
@@ -83,19 +85,34 @@
results = f(arg0, arg1)
np.testing.assert_allclose(results, [4., 10., 18., 28.])
- # TODO: Re-implement tracing in a more sustainable fashion.
- # def test_serialize_values(self):
- # ctx = iree.runtime.SystemContext()
- # self.assertTrue(ctx.is_dynamic)
- # ctx.add_vm_module(create_simple_mul_module())
- # self.assertEqual(ctx.modules.arithmetic.name, "arithmetic")
- # f = ctx.modules.arithmetic["simple_mul"]
- # arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
- # arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)
- # results = f(arg0, arg1)
- # inputs, outputs = f.get_serialized_values()
- # self.assertEqual(inputs, ("4xf32=1 2 3 4", "4xf32=4 5 6 7"))
- # self.assertEqual(outputs, ("4xf32=4 10 18 28",))
+ def test_tracing_explicit(self):
+ with tempfile.TemporaryDirectory() as temp_dir:
+ tracer = iree.runtime.Tracer(temp_dir)
+ config = iree.runtime.Config("dylib", tracer=tracer)
+ self.verify_tracing(config, temp_dir)
+
+ def test_tracing_from_environment(self):
+ original = os.environ.get(iree.runtime.TRACE_PATH_ENV_KEY)
+ try:
+ with tempfile.TemporaryDirectory() as temp_dir:
+ os.environ[iree.runtime.TRACE_PATH_ENV_KEY] = temp_dir
+ config = iree.runtime.Config("dylib")
+ self.verify_tracing(config, temp_dir)
+ finally:
+ if original:
+ os.environ[iree.runtime.TRACE_PATH_ENV_KEY] = original
+
+ def verify_tracing(self, config, temp_dir):
+ logging.info("Tracing test to: %s", temp_dir)
+ ctx = iree.runtime.SystemContext(config=config)
+ ctx.add_vm_module(create_simple_mul_module())
+ f = ctx.modules.arithmetic["simple_mul"]
+ arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
+ arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)
+ results = f(arg0, arg1)
+ self.assertTrue(os.path.exists(os.path.join(temp_dir, "arithmetic.vmfb")))
+ self.assertTrue(os.path.exists(os.path.join(temp_dir, "calls.yaml")))
+ # TODO: Once replay is possible, verify that.
def test_load_vm_module(self):
arithmetic = iree.runtime.load_vm_module(create_simple_mul_module())
diff --git a/bindings/python/iree/runtime/tracing.py b/bindings/python/iree/runtime/tracing.py
new file mode 100644
index 0000000..08b9891
--- /dev/null
+++ b/bindings/python/iree/runtime/tracing.py
@@ -0,0 +1,168 @@
+"""Tracing support."""
+
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from genericpath import exists
+from typing import Dict, List, Optional, Sequence
+
+import logging
+import os
+import sys
+
+from . import binding as _binding
+
+try:
+ import yaml
+except ModuleNotFoundError:
+ _has_yaml = False
+else:
+ _has_yaml = True
+
+__all__ = [
+ "get_default_tracer",
+ "Tracer",
+ "TRACE_PATH_ENV_KEY",
+]
+
+TRACE_PATH_ENV_KEY = "IREE_SAVE_CALLS"
+
+
+class Tracer:
+ """Object for tracing calls made into the runtime."""
+
+ def __init__(self, trace_path: str):
+ if not _has_yaml:
+ self.enabled = False
+ logging.warning("PyYAML not installed: tracing will be disabled")
+ return
+ self.enabled = True
+ self.trace_path = trace_path
+ os.makedirs(trace_path, exist_ok=True)
+ self._name_count = dict() # type: Dict[str, int]
+
+ def persist_vm_module(self, vm_module: _binding.VmModule) -> "TracedModule":
+ # Depending on how the module was created, there are different bits
+ # of information available to reconstruct.
+ name = vm_module.name
+ flatbuffer_blob = vm_module.stashed_flatbuffer_blob
+ if flatbuffer_blob:
+ save_path = os.path.join(self.trace_path,
+ self.get_unique_name(f"{name}.vmfb"))
+ logging.info("Saving traced vmfb to %s", save_path)
+ with open(save_path, "wb") as f:
+ f.write(flatbuffer_blob)
+ return TracedModule(self, vm_module, save_path)
+
+ # No persistent form, but likely they are built-in modules.
+ return TracedModule(self, vm_module)
+
+ def get_unique_name(self, local_name: str) -> str:
+ if local_name not in self._name_count:
+ self._name_count[local_name] = 1
+ return local_name
+ stem, ext = os.path.splitext(local_name)
+ index = self._name_count[local_name]
+ self._name_count[local_name] += 1
+ unique_name = f"{stem}__{index}{ext}"
+ return unique_name
+
+
+class TracedModule:
+ """Wraps a VmModule with additional information for tracing."""
+
+ def __init__(self,
+ parent: Tracer,
+ vm_module: _binding.VmModule,
+ vmfb_path: Optional[str] = None):
+ self._parent = parent
+ self._vm_module = vm_module
+ self._vmfb_path = vmfb_path
+
+ def serialize(self):
+ module_record = {"name": self._vm_module.name}
+ if self._vmfb_path:
+ module_record["type"] = "bytecode"
+ module_record["vmfb_path"] = os.path.relpath(self._vmfb_path,
+ self._parent.trace_path)
+ else:
+ module_record["type"] = "native"
+
+ return module_record
+
+
+class ContextTracer:
+ """Traces invocations against a context."""
+
+ def __init__(self, parent: Tracer, is_dynamic: bool,
+ modules: Sequence[TracedModule]):
+ self._parent = parent
+ self._modules = list(modules) # type: List[TracedModule]
+ self._frame_count = 0
+ self._file_path = os.path.join(parent.trace_path,
+ parent.get_unique_name("calls.yaml"))
+ if os.path.exists(self._file_path):
+ # Truncate the file.
+ with open(self._file_path, "wt"):
+ pass
+ else:
+ os.makedirs(os.path.dirname(parent.trace_path), exist_ok=True)
+ logging.info("Tracing context events to: %s", self._file_path)
+ self.emit_frame({
+ "type": "init_context",
+ "is_dynamic": is_dynamic,
+ "modules": [m.serialize() for m in self._modules],
+ })
+
+ def add_module(self, module: TracedModule):
+ self._modules.append(module)
+ self.emit_frame({
+ "type": "add_module",
+ "module": module.serialize(),
+ })
+
+ def start_call(self, function: _binding.VmFunction):
+ logging.info("Tracing call to %s.%s", function.module_name, function.name)
+
+ # Start assembling the call record.
+ record = {
+ "type": "call",
+ "module_name": function.module_name,
+ "function_name": function.name,
+ }
+ return CallTrace(self, record)
+
+ def emit_frame(self, frame: dict):
+ self._frame_count += 1
+ with open(self._file_path, "at") as f:
+ if self._frame_count != 1:
+ f.write("---\n")
+ contents = yaml.dump(frame, sort_keys=False)
+ f.write(contents)
+
+
+class CallTrace:
+
+ def __init__(self, parent: ContextTracer, record: dict):
+ self._parent = parent
+ self._record = record
+
+ def add_vm_list(self, vm_list: _binding.VmVariantList, key: str):
+ mapped = []
+ for i in range(len(vm_list)):
+ mapped.append(vm_list.get_serialized_trace_value(i))
+ self._record[key] = mapped
+
+ def end_call(self):
+ self._parent.emit_frame(self._record)
+
+
+def get_default_tracer() -> Optional[Tracer]:
+ """Gets a default call tracer based on environment variables."""
+ default_path = os.getenv(TRACE_PATH_ENV_KEY)
+ if not default_path:
+ return None
+ return Tracer(default_path)
diff --git a/bindings/python/iree/runtime/vm.cc b/bindings/python/iree/runtime/vm.cc
index bfcba19..1379669 100644
--- a/bindings/python/iree/runtime/vm.cc
+++ b/bindings/python/iree/runtime/vm.cc
@@ -120,7 +120,8 @@
// VmModule
//------------------------------------------------------------------------------
-VmModule VmModule::FromFlatbufferBlob(py::buffer flatbuffer_blob) {
+VmModule VmModule::FromFlatbufferBlob(py::object flatbuffer_blob_object) {
+ auto flatbuffer_blob = py::cast<py::buffer>(flatbuffer_blob_object);
auto buffer_info = flatbuffer_blob.request();
iree_vm_module_t* module;
@@ -143,7 +144,9 @@
}
CheckApiStatus(status, "Error creating vm module from flatbuffer");
- return VmModule::CreateRetained(module);
+ auto py_module = VmModule::CreateRetained(module);
+ py_module.stashed_flatbuffer_blob = flatbuffer_blob_object;
+ return py_module;
}
std::optional<iree_vm_function_t> VmModule::LookupFunction(
@@ -294,6 +297,109 @@
throw RaiseValueError("Unsupported VM to Python Type Conversion");
}
+py::object VmVariantList::GetAsSerializedTraceValue(int index) {
+ iree_vm_variant_t v = iree_vm_variant_empty();
+ CheckApiStatus(iree_vm_list_get_variant(raw_ptr(), index, &v),
+ "Could not access list element");
+ if (iree_vm_type_def_is_value(&v.type)) {
+ // Convert a value type.
+ py::dict record;
+ switch (v.type.value_type) {
+ case IREE_VM_VALUE_TYPE_I8:
+ record["i8"] = py::cast(v.i8);
+ break;
+ case IREE_VM_VALUE_TYPE_I16:
+ record["i16"] = py::cast(v.i16);
+ break;
+ case IREE_VM_VALUE_TYPE_I32:
+ record["i32"] = py::cast(v.i32);
+ break;
+ case IREE_VM_VALUE_TYPE_I64:
+ record["i64"] = py::cast(v.i64);
+ break;
+ case IREE_VM_VALUE_TYPE_F32:
+ record["f32"] = py::cast(v.f32);
+ break;
+ case IREE_VM_VALUE_TYPE_F64:
+ record["f64"] = py::cast(v.f64);
+ break;
+ default:
+ throw RaiseValueError("Unsupported VM value type conversion");
+ }
+ record["type"] = py::cast("value");
+ return std::move(record);
+ } else if (v.type.ref_type == IREE_VM_REF_TYPE_NULL) {
+ py::dict record;
+ record["type"] = "null";
+ return std::move(record);
+ } else if (iree_vm_type_def_is_ref(&v.type)) {
+ // Convert reference type.
+ if (iree_vm_list_isa(v.ref)) {
+ py::dict record;
+ record["type"] = "list";
+ py::list items;
+ iree_vm_list_t* sub_list = NULL;
+ CheckApiStatus(iree_vm_list_check_deref(v.ref, &sub_list),
+ "Could not deref list (wrong type?)");
+ iree_vm_list_retain(sub_list);
+ VmVariantList sub_list_object(sub_list);
+ for (int i = 0, e = sub_list_object.size(); i < e; ++i) {
+ items.append(sub_list_object.GetAsSerializedTraceValue(i));
+ }
+ record["items"] = std::move(items);
+ return std::move(record);
+ } else if (iree_hal_buffer_view_isa(v.ref)) {
+ py::dict record;
+ record["type"] = "buffer_view";
+ iree_hal_buffer_view_t* buffer_view = iree_hal_buffer_view_deref(v.ref);
+ if (!buffer_view) {
+ throw RaiseValueError(
+ "Could not deref result buffer view (wrong type?)");
+ }
+ iree_hal_buffer_t* raw_buffer = iree_hal_buffer_view_buffer(buffer_view);
+ if (!raw_buffer) {
+ throw RaiseValueError("Could not deref result buffer (wrong type?)");
+ }
+
+ // Extract dims from the buffer view.
+ size_t rank = 0;
+ std::vector<int32_t> dims(6);
+ iree_status_t status = iree_hal_buffer_view_shape(
+ buffer_view, dims.capacity(), dims.data(), &rank);
+ if (iree_status_is_out_of_range(status)) {
+ dims.resize(rank);
+ status = iree_hal_buffer_view_shape(buffer_view, dims.capacity(),
+ dims.data(), &rank);
+ }
+ CheckApiStatus(status, "Error extracting shape");
+ dims.resize(rank);
+ record["shape"] = py::cast(std::move(dims));
+
+ // Element type.
+ iree_hal_element_type_t element_type =
+ iree_hal_buffer_view_element_type(buffer_view);
+ // TODO: Would be nice to output as hex.
+ record["element_type"] = element_type;
+
+ // Map memory.
+ iree_device_size_t byte_length = iree_hal_buffer_byte_length(raw_buffer);
+ iree_hal_buffer_mapping_t mapped_memory;
+ CheckApiStatus(iree_hal_buffer_map_range(
+ raw_buffer, IREE_HAL_MEMORY_ACCESS_READ,
+ 0 /* element_offset */, byte_length, &mapped_memory),
+ "Could not map memory");
+ record["contents"] =
+ py::bytes(reinterpret_cast<const char*>(mapped_memory.contents.data),
+ mapped_memory.contents.data_length);
+ iree_hal_buffer_unmap_range(&mapped_memory);
+
+ return std::move(record);
+ }
+ }
+
+ throw RaiseValueError("Unsupported VM to Python Type Conversion");
+}
+
py::object VmVariantList::GetAsNdarray(int index) {
iree_vm_variant_t v = iree_vm_variant_empty();
CheckApiStatus(iree_vm_list_get_variant(raw_ptr(), index, &v),
@@ -492,6 +598,8 @@
.def("get_as_ndarray", &VmVariantList::GetAsNdarray)
.def("get_as_list", &VmVariantList::GetAsList)
.def("get_variant", &VmVariantList::GetVariant)
+ .def("get_serialized_trace_value",
+ &VmVariantList::GetAsSerializedTraceValue)
.def("push_float", &VmVariantList::PushFloat)
.def("push_int", &VmVariantList::PushInt)
.def("push_list", &VmVariantList::PushList)
@@ -501,6 +609,18 @@
py::class_<iree_vm_function_t>(m, "VmFunction")
.def_readonly("linkage", &iree_vm_function_t::linkage)
.def_readonly("ordinal", &iree_vm_function_t::ordinal)
+ .def_property_readonly("name",
+ [](iree_vm_function_t& self) {
+ iree_string_view_t name =
+ iree_vm_function_name(&self);
+ return py::str(name.data, name.size);
+ })
+ .def_property_readonly("module_name",
+ [](iree_vm_function_t& self) {
+ iree_string_view_t name =
+ iree_vm_module_name(self.module);
+ return py::str(name.data, name.size);
+ })
.def_property_readonly("reflection",
[](iree_vm_function_t& self) {
return GetFunctionReflectionDict(self);
@@ -534,6 +654,9 @@
.def_property_readonly("name", &VmModule::name)
.def("lookup_function", &VmModule::LookupFunction, py::arg("name"),
py::arg("linkage") = IREE_VM_FUNCTION_LINKAGE_EXPORT)
+ .def_property_readonly(
+ "stashed_flatbuffer_blob",
+ [](VmModule& self) { return self.get_stashed_flatbuffer_blob(); })
.def("__repr__", [](VmModule& self) {
std::string repr("<VmModule ");
iree_string_view_t name = iree_vm_module_name(self.raw_ptr());
diff --git a/bindings/python/iree/runtime/vm.h b/bindings/python/iree/runtime/vm.h
index 33da644..caf69db 100644
--- a/bindings/python/iree/runtime/vm.h
+++ b/bindings/python/iree/runtime/vm.h
@@ -99,6 +99,7 @@
py::object GetAsList(int index);
py::object GetAsNdarray(int index);
py::object GetVariant(int index);
+ py::object GetAsSerializedTraceValue(int index);
private:
VmVariantList(iree_vm_list_t* list) : list_(list) {}
@@ -116,7 +117,7 @@
class VmModule : public ApiRefCounted<VmModule, iree_vm_module_t> {
public:
- static VmModule FromFlatbufferBlob(py::buffer flatbuffer_blob);
+ static VmModule FromFlatbufferBlob(py::object flatbuffer_blob_object);
std::optional<iree_vm_function_t> LookupFunction(
const std::string& name, iree_vm_function_linkage_t linkage);
@@ -125,6 +126,12 @@
auto name_sv = iree_vm_module_name(raw_ptr());
return std::string(name_sv.data, name_sv.size);
}
+
+ py::object get_stashed_flatbuffer_blob() { return stashed_flatbuffer_blob; }
+
+ private:
+ // If the module was created from a flatbuffer blob, we stash it here.
+ py::object stashed_flatbuffer_blob = py::none();
};
class VmContext : public ApiRefCounted<VmContext, iree_vm_context_t> {
diff --git a/bindings/python/iree/runtime/vm_test.py b/bindings/python/iree/runtime/vm_test.py
index 81672c5..e29b2e1 100644
--- a/bindings/python/iree/runtime/vm_test.py
+++ b/bindings/python/iree/runtime/vm_test.py
@@ -140,38 +140,32 @@
logging.info("context: %s", context)
def test_add_scalar_new_abi(self):
- # TODO: Enable with new ABI.
- return
m = create_add_scalar_module()
instance = iree.runtime.VmInstance()
context = iree.runtime.VmContext(instance, modules=[self.hal_module, m])
f = m.lookup_function("add_scalar")
- finv = iree.runtime.FunctionInvoker(context, self.device, f)
+ finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None)
result = finv(5, 6)
logging.info("result: %s", result)
self.assertEqual(result, 11)
def test_synchronous_dynamic_shape_invoke_function_new_abi(self):
- # TODO: Enable with new ABI.
- return
m = create_simple_dynamic_abs_module()
instance = iree.runtime.VmInstance()
context = iree.runtime.VmContext(instance, modules=[self.hal_module, m])
f = m.lookup_function("simple_mul")
- finv = iree.runtime.FunctionInvoker(context, self.device, f)
+ finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None)
arg0 = np.array([[-1., 2.], [3., -4.]], dtype=np.float32)
result = finv(arg0)
logging.info("result: %s", result)
np.testing.assert_allclose(result, [[1., 2.], [3., 4.]])
def test_synchronous_invoke_function_new_abi(self):
- # TODO: Enable with new ABI.
- return
m = create_simple_static_mul_module()
instance = iree.runtime.VmInstance()
context = iree.runtime.VmContext(instance, modules=[self.hal_module, m])
f = m.lookup_function("simple_mul")
- finv = iree.runtime.FunctionInvoker(context, self.device, f)
+ finv = iree.runtime.FunctionInvoker(context, self.device, f, tracer=None)
arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)
result = finv(arg0, arg1)