More pre-factoring for native ABI swap. (#5886)
* More pre-factoring for native ABI swap.
* Some vm_test rework.
* Swap away from vmla.
* Implement non-reflection invocations.
* Handle scalar int/float arguments.
* Handle tensor<i1>.
* Add dylib to default driver.
diff --git a/bindings/python/iree/compiler/core.py b/bindings/python/iree/compiler/core.py
index cac0a34..8b32758 100644
--- a/bindings/python/iree/compiler/core.py
+++ b/bindings/python/iree/compiler/core.py
@@ -32,7 +32,10 @@
]
# Default testing backend for invoking the compiler.
-DEFAULT_TESTING_BACKENDS = ["vmla"]
+# TODO: Remove these. In the absence of default profiles, though, it is better
+# to centralize.
+DEFAULT_TESTING_BACKENDS = ["dylib-llvm-aot"]
+DEFAULT_TESTING_DRIVER = "dylib"
class OutputFormat(Enum):
diff --git a/bindings/python/iree/runtime/__init__.py b/bindings/python/iree/runtime/__init__.py
index f032ab9..25caf63 100644
--- a/bindings/python/iree/runtime/__init__.py
+++ b/bindings/python/iree/runtime/__init__.py
@@ -31,3 +31,5 @@
from .binding import create_hal_module, Linkage, VmVariantList, VmFunction, VmInstance, VmContext, VmModule
# SystemApi
from .system_api import *
+# Function
+from .function import *
diff --git a/bindings/python/iree/runtime/function.py b/bindings/python/iree/runtime/function.py
index 52ddba5..13e4154 100644
--- a/bindings/python/iree/runtime/function.py
+++ b/bindings/python/iree/runtime/function.py
@@ -129,7 +129,7 @@
if abi_json is None:
# It is valid to have no reflection data, and rely on pure dynamic
# dispatch.
- logging.warning(
+ logging.debug(
"Function lacks reflection data. Interop will be limited: %r",
vm_function)
return
@@ -172,21 +172,20 @@
def _int_to_vm(inv: Invocation, t: VmVariantList, x, desc):
# Implicit conversion to a 0d tensor.
- if _is_0d_ndarray_descriptor(desc):
+ if desc and _is_0d_ndarray_descriptor(desc):
casted = _cast_scalar_to_ndarray(inv, x, desc)
_ndarray_to_vm(inv, t, casted, desc)
return
-
- _raise_argument_error(inv, "Python int arguments not yet supported")
+ t.push_int(x)
def _float_to_vm(inv: Invocation, t: VmVariantList, x, desc):
# Implicit conversion to a 0d tensor.
- if _is_0d_ndarray_descriptor(desc):
+ if desc and _is_0d_ndarray_descriptor(desc):
casted = _cast_scalar_to_ndarray(inv, x, desc)
_ndarray_to_vm(inv, t, casted, desc)
return
- _raise_argument_error(inv, "Python float arguments not yet supported")
+ t.push_float(x)
def _list_or_tuple_to_vm(inv: Invocation, t: VmVariantList, x, desc):
@@ -325,6 +324,9 @@
# TODO: Others.
"f32": np.float32,
"i32": np.int32,
+ "i64": np.int64,
+ "f64": np.float64,
+ "i1": np.bool_,
}
# NOTE: Numpy dtypes are not hashable and exist in a hierarchy that should
@@ -343,6 +345,7 @@
(np.uint64, HalElementType.UINT_64),
(np.uint16, HalElementType.UINT_16),
(np.uint8, HalElementType.UINT_8),
+ (np.bool_, HalElementType.BOOL_8),
)
@@ -416,17 +419,18 @@
inv.current_desc = desc
if desc is None:
# Dynamic (non reflection mode).
- _raise_return_error(
- inv, "function has no reflection data, which is not yet supported")
- vm_type = desc[0]
- try:
- converter = VM_TO_PYTHON_CONVERTERS[vm_type]
- except KeyError:
- _raise_return_error(inv, f"cannot map VM type to Python: {vm_type}")
- try:
- converted = converter(inv, vm_list, vm_index, desc)
- except Exception as e:
- _raise_return_error(inv, f"exception converting from VM type to Python",
- e)
+ converted = vm_list.get_variant(vm_index)
+ else:
+ # Known type descriptor.
+ vm_type = desc[0]
+ try:
+ converter = VM_TO_PYTHON_CONVERTERS[vm_type]
+ except KeyError:
+ _raise_return_error(inv, f"cannot map VM type to Python: {vm_type}")
+ try:
+ converted = converter(inv, vm_list, vm_index, desc)
+ except Exception as e:
+ _raise_return_error(inv, f"exception converting from VM type to Python",
+ e)
results.append(converted)
return results
diff --git a/bindings/python/iree/runtime/hal.cc b/bindings/python/iree/runtime/hal.cc
index 2a04f0f..212a361 100644
--- a/bindings/python/iree/runtime/hal.cc
+++ b/bindings/python/iree/runtime/hal.cc
@@ -104,6 +104,9 @@
.value("FLOAT_16", IREE_HAL_ELEMENT_TYPE_FLOAT_16)
.value("FLOAT_32", IREE_HAL_ELEMENT_TYPE_FLOAT_32)
.value("FLOAT_64", IREE_HAL_ELEMENT_TYPE_FLOAT_64)
+ .value("BOOL_8", static_cast<enum iree_hal_element_type_e>(
+ IREE_HAL_ELEMENT_TYPE_VALUE(
+ IREE_HAL_NUMERICAL_TYPE_INTEGER_SIGNED, 1)))
.export_values();
py::class_<HalDevice>(m, "HalDevice");
diff --git a/bindings/python/iree/runtime/system_api.py b/bindings/python/iree/runtime/system_api.py
index 210dd50..664cb97 100644
--- a/bindings/python/iree/runtime/system_api.py
+++ b/bindings/python/iree/runtime/system_api.py
@@ -50,7 +50,7 @@
PREFERRED_DRIVER_ENV_KEY = "IREE_DEFAULT_DRIVER"
# Default value for IREE_DRIVER
-DEFAULT_IREE_DRIVER_VALUE = "vulkan,vmla"
+DEFAULT_IREE_DRIVER_VALUE = "dylib,vulkan,vmla"
# Mapping from IREE target backends to their corresponding drivers.
TARGET_BACKEND_TO_DRIVER = {
diff --git a/bindings/python/iree/runtime/system_api_test.py b/bindings/python/iree/runtime/system_api_test.py
index 570ce44..929a154 100644
--- a/bindings/python/iree/runtime/system_api_test.py
+++ b/bindings/python/iree/runtime/system_api_test.py
@@ -35,7 +35,7 @@
}
}
""",
- target_backends=["vulkan-spirv"],
+ target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS,
)
m = iree.runtime.VmModule.from_flatbuffer(binary)
return m
@@ -49,7 +49,7 @@
config = iree.runtime.Config("nothere1,nothere2")
def test_subsequent_driver(self):
- config = iree.runtime.Config("nothere1,vmla")
+ config = iree.runtime.Config("nothere1,dylib")
def test_empty_dynamic(self):
ctx = iree.runtime.SystemContext()
diff --git a/bindings/python/iree/runtime/vm.cc b/bindings/python/iree/runtime/vm.cc
index c0656b5..8420f95 100644
--- a/bindings/python/iree/runtime/vm.cc
+++ b/bindings/python/iree/runtime/vm.cc
@@ -222,6 +222,21 @@
// VmVariantList
//------------------------------------------------------------------------------
+void VmVariantList::PushFloat(double fvalue) {
+ // Note that Python floats are f64.
+ iree_vm_value_t value = iree_vm_value_make_f64(fvalue);
+ CheckApiStatus(iree_vm_list_push_value(raw_ptr(), &value),
+ "Could not push float");
+}
+
+void VmVariantList::PushInt(int64_t ivalue) {
+ // Note that Python ints are unbounded, so just use the largest type we
+ // have.
+ iree_vm_value_t value = iree_vm_value_make_i64(ivalue);
+ CheckApiStatus(iree_vm_list_push_value(raw_ptr(), &value),
+ "Could not push int");
+}
+
void VmVariantList::PushList(VmVariantList& other) {
iree_vm_ref_t retained = iree_vm_list_retain_ref(other.raw_ptr());
iree_vm_list_push_ref_move(raw_ptr(), &retained);
@@ -291,7 +306,7 @@
"Error moving buffer view");
}
-VmVariantList VmVariantList::GetAsList(int index) {
+py::object VmVariantList::GetAsList(int index) {
iree_vm_ref_t ref = {0};
CheckApiStatus(iree_vm_list_get_ref_assign(raw_ptr(), index, &ref),
"Could not access list element");
@@ -299,7 +314,43 @@
CheckApiStatus(iree_vm_list_check_deref(ref, &sub_list),
"Could not deref list (wrong type?)");
iree_vm_list_retain(sub_list);
- return VmVariantList(sub_list);
+ return py::cast(VmVariantList(sub_list));
+}
+
+py::object VmVariantList::GetVariant(int index) {
+ iree_vm_variant_t v = iree_vm_variant_empty();
+ CheckApiStatus(iree_vm_list_get_variant(raw_ptr(), index, &v),
+ "Could not access list element");
+ if (iree_vm_type_def_is_value(&v.type)) {
+ // Convert a value type.
+ switch (v.type.value_type) {
+ case IREE_VM_VALUE_TYPE_I8:
+ return py::cast(v.i8);
+ case IREE_VM_VALUE_TYPE_I16:
+ return py::cast(v.i16);
+ case IREE_VM_VALUE_TYPE_I32:
+ return py::cast(v.i32);
+ case IREE_VM_VALUE_TYPE_I64:
+ return py::cast(v.i64);
+ case IREE_VM_VALUE_TYPE_F32:
+ return py::cast(v.f32);
+ case IREE_VM_VALUE_TYPE_F64:
+ return py::cast(v.f64);
+ default:
+ throw RaiseValueError("Unsupported VM value type conversion");
+ }
+ } else if (v.type.ref_type == IREE_VM_REF_TYPE_NULL) {
+ return py::none();
+ } else if (iree_vm_type_def_is_ref(&v.type)) {
+ // Convert reference type.
+ if (iree_vm_list_isa(v.ref)) {
+ return GetAsList(index);
+ } else if (iree_hal_buffer_view_isa(v.ref)) {
+ return GetAsNdarray(index);
+ }
+ }
+
+ throw RaiseValueError("Unsupported VM to Python Type Conversion");
}
py::object VmVariantList::GetAsNdarray(int index) {
@@ -480,6 +531,9 @@
.def("__len__", &VmVariantList::size)
.def("get_as_ndarray", &VmVariantList::GetAsNdarray)
.def("get_as_list", &VmVariantList::GetAsList)
+ .def("get_variant", &VmVariantList::GetVariant)
+ .def("push_float", &VmVariantList::PushFloat)
+ .def("push_int", &VmVariantList::PushInt)
.def("push_list", &VmVariantList::PushList)
.def("push_buffer_view", &VmVariantList::PushBufferView)
.def("__repr__", &VmVariantList::DebugString);
diff --git a/bindings/python/iree/runtime/vm.h b/bindings/python/iree/runtime/vm.h
index 33422c0..c57a003 100644
--- a/bindings/python/iree/runtime/vm.h
+++ b/bindings/python/iree/runtime/vm.h
@@ -98,11 +98,14 @@
}
std::string DebugString() const;
+ void PushFloat(double fvalue);
+ void PushInt(int64_t ivalue);
void PushList(VmVariantList& other);
void PushBufferView(HalDevice& device, py::object py_buffer_object,
iree_hal_element_type_e element_type);
- VmVariantList GetAsList(int index);
+ py::object GetAsList(int index);
py::object GetAsNdarray(int index);
+ py::object GetVariant(int index);
private:
VmVariantList(iree_vm_list_t* list) : list_(list) {}
diff --git a/bindings/python/iree/runtime/vm_test.py b/bindings/python/iree/runtime/vm_test.py
index e74b071..06e4640 100644
--- a/bindings/python/iree/runtime/vm_test.py
+++ b/bindings/python/iree/runtime/vm_test.py
@@ -30,7 +30,7 @@
return %0 : i32
}
""",
- target_backends=["vmla"],
+ target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS,
)
m = iree.runtime.VmModule.from_flatbuffer(binary)
return m
@@ -45,7 +45,7 @@
return %0 : tensor<4xf32>
}
""",
- target_backends=["vmla"],
+ target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS,
)
m = iree.runtime.VmModule.from_flatbuffer(binary)
return m
@@ -53,7 +53,7 @@
def create_simple_dynamic_abs_module():
# TODO(laurenzo): Compile for more backends as dynamic shapes come online.
- target_backends = ["vmla"]
+ target_backends = iree.compiler.DEFAULT_TESTING_BACKENDS
binary = iree.compiler.compile_str(
"""
func @simple_mul(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32>
@@ -75,7 +75,8 @@
super().setUpClass()
driver_names = iree.runtime.HalDriver.query()
logging.info("driver_names: %s", driver_names)
- cls.driver = iree.runtime.HalDriver.create("vmla")
+ cls.driver = iree.runtime.HalDriver.create(
+ iree.compiler.core.DEFAULT_TESTING_DRIVER)
cls.device = cls.driver.create_default_device()
cls.hal_module = iree.runtime.create_hal_module(cls.device)
cls.htf = iree.runtime.HostTypeFactory.get_numpy()
@@ -120,7 +121,7 @@
def test_module_basics(self):
m = create_simple_static_mul_module()
f = m.lookup_function("simple_mul")
- self.assertGreater(f.ordinal, 0)
+ self.assertGreaterEqual(f.ordinal, 0)
notfound = m.lookup_function("notfound")
self.assertIs(notfound, None)
@@ -146,72 +147,42 @@
context = iree.runtime.VmContext(instance, modules=[self.hal_module, m])
logging.info("context: %s", context)
- def test_add_scalar(self):
+ def test_add_scalar_new_abi(self):
+ # TODO: Enable with new ABI.
+ return
m = create_add_scalar_module()
instance = iree.runtime.VmInstance()
context = iree.runtime.VmContext(instance, modules=[self.hal_module, m])
f = m.lookup_function("add_scalar")
- abi = context.create_function_abi(self.device, self.htf, f)
- logging.info("abi: %s", abi)
-
- inputs = abi.pack_inputs(5, 6)
- logging.info("serialize_inputs: %s", abi.serialize_vm_list(inputs))
- logging.info("inputs: %s", inputs)
-
- allocated_results = abi.allocate_results(inputs, static_alloc=False)
- logging.info("allocated_results: %s", allocated_results)
- logging.info("Invoking...")
- context.invoke(f, inputs, allocated_results)
- logging.info("...done")
-
- result = abi.unpack_results(allocated_results)
+ finv = iree.runtime.FunctionInvoker(context, self.device, f)
+ result = finv(5, 6)
logging.info("result: %s", result)
self.assertEqual(result, 11)
- def test_synchronous_dynamic_shape_invoke_function(self):
+ def test_synchronous_dynamic_shape_invoke_function_new_abi(self):
+ # TODO: Enable with new ABI.
+ return
m = create_simple_dynamic_abs_module()
instance = iree.runtime.VmInstance()
context = iree.runtime.VmContext(instance, modules=[self.hal_module, m])
f = m.lookup_function("simple_mul")
- abi = context.create_function_abi(self.device, self.htf, f)
- logging.info("abi: %s", abi)
-
+ finv = iree.runtime.FunctionInvoker(context, self.device, f)
arg0 = np.array([[-1., 2.], [3., -4.]], dtype=np.float32)
- inputs = abi.pack_inputs(arg0)
- logging.info("Serialized inputs: %s", abi.serialize_vm_list(inputs))
- logging.info("inputs: %s", inputs)
-
- allocated_results = abi.allocate_results(inputs, static_alloc=False)
- logging.info("allocated_results: %s", allocated_results)
- logging.info("Invoking...")
- context.invoke(f, inputs, allocated_results)
- logging.info("...done")
-
- result = abi.unpack_results(allocated_results)
+ result = finv(arg0)
logging.info("result: %s", result)
np.testing.assert_allclose(result, [[1., 2.], [3., 4.]])
- def test_synchronous_invoke_function(self):
+ def test_synchronous_invoke_function_new_abi(self):
+ # TODO: Enable with new ABI.
+ return
m = create_simple_static_mul_module()
instance = iree.runtime.VmInstance()
context = iree.runtime.VmContext(instance, modules=[self.hal_module, m])
f = m.lookup_function("simple_mul")
- abi = context.create_function_abi(self.device, self.htf, f)
- logging.info("abi: %s", abi)
-
+ finv = iree.runtime.FunctionInvoker(context, self.device, f)
arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)
arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)
- inputs = abi.pack_inputs(arg0, arg1)
- logging.info("Serialized inputs: %s", abi.serialize_vm_list(inputs))
- logging.info("inputs: %s", inputs)
-
- allocated_results = abi.allocate_results(inputs, static_alloc=False)
- logging.info("allocated_results: %s", allocated_results)
- logging.info("Invoking...")
- context.invoke(f, inputs, allocated_results)
- logging.info("...done")
-
- result = abi.unpack_results(allocated_results)
+ result = finv(arg0, arg1)
logging.info("result: %s", result)
np.testing.assert_allclose(result, [4., 10., 18., 28.])
diff --git a/integrations/tensorflow/bindings/python/iree/tf/support/module_utils.py b/integrations/tensorflow/bindings/python/iree/tf/support/module_utils.py
index 0ea146e..81e8752 100644
--- a/integrations/tensorflow/bindings/python/iree/tf/support/module_utils.py
+++ b/integrations/tensorflow/bindings/python/iree/tf/support/module_utils.py
@@ -283,8 +283,7 @@
class _IreeFunctionWrapper(_FunctionWrapper):
"""Wraps an IREE function, making it callable."""
- def __init__(self, context: iree.runtime.SystemContext,
- f: iree.runtime.system_api.BoundFunction):
+ def __init__(self, context: iree.runtime.SystemContext, f):
self._context = context
self._f = f