Some python binding plumbing to enable the new VM2 module loading.
* Just adds enough of the new hal types to get a default device.
* Restricted to vulkan for the moment.
* Verifies that with the new hal module, the VM2 compiled simple_mul module loads.
* Disabled vm_test for TAP due to missing vulkan deps on test hosts.
PiperOrigin-RevId: 286590561
diff --git a/bindings/python/pyiree/BUILD b/bindings/python/pyiree/BUILD
index e3d710a..d95dd48 100644
--- a/bindings/python/pyiree/BUILD
+++ b/bindings/python/pyiree/BUILD
@@ -127,6 +127,7 @@
"@com_google_absl//absl/types:span",
"@com_google_absl//absl/types:variant",
"//iree/compiler/Utils",
+ "//iree/modules/hal",
"//iree/vm2",
"//iree/vm2:bytecode_module",
"@local_config_mlir//:IR",
@@ -210,6 +211,8 @@
name = "vm_test",
srcs = ["vm_test.py"],
python_version = "PY3",
+ # TODO(laurenzo): Enable once test does not depend on a real vulkan device.
+ tags = ["notap"],
deps = NUMPY_DEPS + [
"//bindings/python:pathsetup", # build_cleaner: keep
"@absl_py//absl/testing:absltest",
diff --git a/bindings/python/pyiree/hal.cc b/bindings/python/pyiree/hal.cc
index 413224a..a78f46b 100644
--- a/bindings/python/pyiree/hal.cc
+++ b/bindings/python/pyiree/hal.cc
@@ -86,6 +86,43 @@
} // namespace
+//------------------------------------------------------------------------------
+// HalDriver
+//------------------------------------------------------------------------------
+
+std::vector<std::string> HalDriver::Query() {
+ iree_string_view_t* driver_names;
+ iree_host_size_t driver_count;
+ CheckApiStatus(iree_hal_driver_registry_query_available_drivers(
+ IREE_ALLOCATOR_SYSTEM, &driver_names, &driver_count),
+ "Error querying drivers");
+
+ std::vector<std::string> drivers;
+ drivers.resize(driver_count);
+ for (iree_host_size_t i = 0; i < driver_count; ++i) {
+ drivers[i] = std::string(driver_names[i].data, driver_names[i].size);
+ }
+ free(driver_names);
+ return drivers;
+}
+
+HalDriver HalDriver::Create(const std::string& driver_name) {
+ iree_hal_driver_t* driver;
+ CheckApiStatus(iree_hal_driver_registry_create_driver(
+ {driver_name.data(), driver_name.size()},
+ IREE_ALLOCATOR_SYSTEM, &driver),
+ "Error creating driver");
+ return HalDriver::CreateRetained(driver);
+}
+
+HalDevice HalDriver::CreateDefaultDevice() {
+ iree_hal_device_t* device;
+ CheckApiStatus(iree_hal_driver_create_default_device(
+ raw_ptr(), IREE_ALLOCATOR_SYSTEM, &device),
+ "Error creating default device");
+ return HalDevice::CreateRetained(device);
+}
+
void SetupHalBindings(pybind11::module m) {
// Enums.
py::enum_<iree_hal_memory_type_t>(m, "MemoryType")
@@ -115,6 +152,12 @@
.value("ALL", IREE_HAL_MEMORY_ACCESS_ALL)
.export_values();
+ py::class_<HalDevice>(m, "HalDevice");
+ py::class_<HalDriver>(m, "HalDriver")
+ .def_static("query", &HalDriver::Query)
+ .def_static("create", &HalDriver::Create, py::arg("driver_name"))
+ .def("create_default_device", &HalDriver::CreateDefaultDevice);
+
py::class_<HalShape>(m, "Shape").def(py::init(&HalShape::FromIntVector));
py::class_<HalBufferView>(m, "BufferView")
.def("map", HalMappedMemory::Create);
diff --git a/bindings/python/pyiree/hal.h b/bindings/python/pyiree/hal.h
index 38e2c3b..2f7cfdb 100644
--- a/bindings/python/pyiree/hal.h
+++ b/bindings/python/pyiree/hal.h
@@ -22,6 +22,22 @@
namespace iree {
namespace python {
+//------------------------------------------------------------------------------
+// Retain/release bindings
+//------------------------------------------------------------------------------
+
+template <>
+struct ApiPtrAdapter<iree_hal_driver_t> {
+ static void Retain(iree_hal_driver_t* d) { iree_hal_driver_retain(d); }
+ static void Release(iree_hal_driver_t* d) { iree_hal_driver_release(d); }
+};
+
+template <>
+struct ApiPtrAdapter<iree_hal_device_t> {
+ static void Retain(iree_hal_device_t* d) { iree_hal_device_retain(d); }
+ static void Release(iree_hal_device_t* d) { iree_hal_device_release(d); }
+};
+
template <>
struct ApiPtrAdapter<iree_hal_buffer_t> {
static void Retain(iree_hal_buffer_t* b) { iree_hal_buffer_retain(b); }
@@ -38,6 +54,22 @@
}
};
+//------------------------------------------------------------------------------
+// ApiRefCounted types
+//------------------------------------------------------------------------------
+
+class HalDevice : public ApiRefCounted<HalDevice, iree_hal_device_t> {
+ public:
+};
+
+class HalDriver : public ApiRefCounted<HalDriver, iree_hal_driver_t> {
+ public:
+ static std::vector<std::string> Query();
+ static HalDriver Create(const std::string& driver_name);
+
+ HalDevice CreateDefaultDevice();
+};
+
struct HalShape {
public:
static HalShape FromIntVector(std::vector<int32_t> indices) {
diff --git a/bindings/python/pyiree/vm.cc b/bindings/python/pyiree/vm.cc
index 72af468..c2e59fc 100644
--- a/bindings/python/pyiree/vm.cc
+++ b/bindings/python/pyiree/vm.cc
@@ -17,10 +17,13 @@
#include "absl/types/optional.h"
#include "bindings/python/pyiree/status_utils.h"
#include "iree/base/api.h"
+#include "iree/modules/hal/hal_module.h"
namespace iree {
namespace python {
+namespace {
+
RtModule CreateModuleFromBlob(std::shared_ptr<OpaqueBlob> blob) {
iree_rt_module_t* module;
auto free_fn = OpaqueBlob::CreateFreeFn(blob);
@@ -31,6 +34,16 @@
return RtModule::CreateRetained(module);
}
+VmModule CreateHalModule(HalDevice* device) {
+ iree_vm_module_t* module;
+ CheckApiStatus(
+ iree_hal_module_create(device->raw_ptr(), IREE_ALLOCATOR_SYSTEM, &module),
+ "Error creating hal module");
+ return VmModule::CreateRetained(module);
+}
+
+} // namespace
+
//------------------------------------------------------------------------------
// VmInstance
//------------------------------------------------------------------------------
@@ -115,9 +128,15 @@
}
void SetupVmBindings(pybind11::module m) {
+ CHECK_EQ(IREE_STATUS_OK, iree_vm_register_builtin_types());
+ CHECK_EQ(IREE_STATUS_OK, iree_hal_module_register_types());
+
// Deprecated: VM1 module.
m.def("create_module_from_blob", CreateModuleFromBlob);
+ // Built-in module creation.
+ m.def("create_hal_module", &CreateHalModule);
+
py::enum_<iree_vm_function_linkage_t>(m, "Linkage")
.value("INTERNAL", IREE_VM_FUNCTION_LINKAGE_INTERNAL)
.value("IMPORT", IREE_VM_FUNCTION_LINKAGE_IMPORT)
diff --git a/bindings/python/pyiree/vm_test.py b/bindings/python/pyiree/vm_test.py
index 338d64d..3302907 100644
--- a/bindings/python/pyiree/vm_test.py
+++ b/bindings/python/pyiree/vm_test.py
@@ -14,7 +14,6 @@
# limitations under the License.
# pylint: disable=unused-variable
-# pylint: disable=g-unreachable-test-method
from absl.testing import absltest
import pyiree
@@ -36,6 +35,15 @@
class VmTest(absltest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ super().setUpClass()
+ driver_names = pyiree.binding.hal.HalDriver.query()
+ print("DRIVER_NAMES =", driver_names)
+ cls.driver = pyiree.binding.hal.HalDriver.create("vulkan")
+ cls.device = cls.driver.create_default_device()
+ cls.hal_module = pyiree.binding.vm.create_hal_module(cls.device)
+
def test_variant_list(self):
l = pyiree.binding.vm.VmVariantList(5)
print(l)
@@ -47,25 +55,26 @@
context2 = pyiree.binding.vm.VmContext(instance)
self.assertGreater(context2.context_id, context1.context_id)
- def disabled_test_module_basics(self):
+ def test_module_basics(self):
m = create_simple_mul_module()
f = m.lookup_function("simple_mul")
self.assertGreater(f.ordinal, 0)
notfound = m.lookup_function("notfound")
self.assertIs(notfound, None)
- def disabled_test_dynamic_module_context(self):
+ def test_dynamic_module_context(self):
instance = pyiree.binding.vm.VmInstance()
context = pyiree.binding.vm.VmContext(instance)
m = create_simple_mul_module()
- context.register_modules([m])
+ context.register_modules([self.hal_module, m])
- def disabled_test_static_module_context(self):
+ def test_static_module_context(self):
m = create_simple_mul_module()
print(m)
instance = pyiree.binding.vm.VmInstance()
print(instance)
- context = pyiree.binding.vm.VmContext(instance, modules=[m])
+ context = pyiree.binding.vm.VmContext(
+ instance, modules=[self.hal_module, m])
print(context)