Add Python binding plumbing to enable module loading with the new VM2. * Adds just enough of the new HAL types to create a default device. * Currently restricted to Vulkan. * Verifies that, with the new HAL module registered, the VM2-compiled simple_mul module loads. * Disables vm_test on TAP because the test hosts lack the Vulkan dependencies. PiperOrigin-RevId: 286590561
diff --git a/bindings/python/pyiree/BUILD b/bindings/python/pyiree/BUILD index e3d710a..d95dd48 100644 --- a/bindings/python/pyiree/BUILD +++ b/bindings/python/pyiree/BUILD
@@ -127,6 +127,7 @@ "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", "//iree/compiler/Utils", + "//iree/modules/hal", "//iree/vm2", "//iree/vm2:bytecode_module", "@local_config_mlir//:IR", @@ -210,6 +211,8 @@ name = "vm_test", srcs = ["vm_test.py"], python_version = "PY3", + # TODO(laurenzo): Enable once test does not depend on a real vulkan device. + tags = ["notap"], deps = NUMPY_DEPS + [ "//bindings/python:pathsetup", # build_cleaner: keep "@absl_py//absl/testing:absltest",
diff --git a/bindings/python/pyiree/hal.cc b/bindings/python/pyiree/hal.cc index 413224a..a78f46b 100644 --- a/bindings/python/pyiree/hal.cc +++ b/bindings/python/pyiree/hal.cc
@@ -86,6 +86,43 @@ } // namespace +//------------------------------------------------------------------------------ +// HalDriver +//------------------------------------------------------------------------------ + +std::vector<std::string> HalDriver::Query() { + iree_string_view_t* driver_names; + iree_host_size_t driver_count; + CheckApiStatus(iree_hal_driver_registry_query_available_drivers( + IREE_ALLOCATOR_SYSTEM, &driver_names, &driver_count), + "Error querying drivers"); + + std::vector<std::string> drivers; + drivers.resize(driver_count); + for (iree_host_size_t i = 0; i < driver_count; ++i) { + drivers[i] = std::string(driver_names[i].data, driver_names[i].size); + } + free(driver_names); + return drivers; +} + +HalDriver HalDriver::Create(const std::string& driver_name) { + iree_hal_driver_t* driver; + CheckApiStatus(iree_hal_driver_registry_create_driver( + {driver_name.data(), driver_name.size()}, + IREE_ALLOCATOR_SYSTEM, &driver), + "Error creating driver"); + return HalDriver::CreateRetained(driver); +} + +HalDevice HalDriver::CreateDefaultDevice() { + iree_hal_device_t* device; + CheckApiStatus(iree_hal_driver_create_default_device( + raw_ptr(), IREE_ALLOCATOR_SYSTEM, &device), + "Error creating default device"); + return HalDevice::CreateRetained(device); +} + void SetupHalBindings(pybind11::module m) { // Enums. py::enum_<iree_hal_memory_type_t>(m, "MemoryType") @@ -115,6 +152,12 @@ .value("ALL", IREE_HAL_MEMORY_ACCESS_ALL) .export_values(); + py::class_<HalDevice>(m, "HalDevice"); + py::class_<HalDriver>(m, "HalDriver") + .def_static("query", &HalDriver::Query) + .def_static("create", &HalDriver::Create, py::arg("driver_name")) + .def("create_default_device", &HalDriver::CreateDefaultDevice); + py::class_<HalShape>(m, "Shape").def(py::init(&HalShape::FromIntVector)); py::class_<HalBufferView>(m, "BufferView") .def("map", HalMappedMemory::Create);
diff --git a/bindings/python/pyiree/hal.h b/bindings/python/pyiree/hal.h index 38e2c3b..2f7cfdb 100644 --- a/bindings/python/pyiree/hal.h +++ b/bindings/python/pyiree/hal.h
@@ -22,6 +22,22 @@ namespace iree { namespace python { +//------------------------------------------------------------------------------ +// Retain/release bindings +//------------------------------------------------------------------------------ + +template <> +struct ApiPtrAdapter<iree_hal_driver_t> { + static void Retain(iree_hal_driver_t* d) { iree_hal_driver_retain(d); } + static void Release(iree_hal_driver_t* d) { iree_hal_driver_release(d); } +}; + +template <> +struct ApiPtrAdapter<iree_hal_device_t> { + static void Retain(iree_hal_device_t* d) { iree_hal_device_retain(d); } + static void Release(iree_hal_device_t* d) { iree_hal_device_release(d); } +}; + template <> struct ApiPtrAdapter<iree_hal_buffer_t> { static void Retain(iree_hal_buffer_t* b) { iree_hal_buffer_retain(b); } @@ -38,6 +54,22 @@ } }; +//------------------------------------------------------------------------------ +// ApiRefCounted types +//------------------------------------------------------------------------------ + +class HalDevice : public ApiRefCounted<HalDevice, iree_hal_device_t> { + public: +}; + +class HalDriver : public ApiRefCounted<HalDriver, iree_hal_driver_t> { + public: + static std::vector<std::string> Query(); + static HalDriver Create(const std::string& driver_name); + + HalDevice CreateDefaultDevice(); +}; + struct HalShape { public: static HalShape FromIntVector(std::vector<int32_t> indices) {
diff --git a/bindings/python/pyiree/vm.cc b/bindings/python/pyiree/vm.cc index 72af468..c2e59fc 100644 --- a/bindings/python/pyiree/vm.cc +++ b/bindings/python/pyiree/vm.cc
@@ -17,10 +17,13 @@ #include "absl/types/optional.h" #include "bindings/python/pyiree/status_utils.h" #include "iree/base/api.h" +#include "iree/modules/hal/hal_module.h" namespace iree { namespace python { +namespace { + RtModule CreateModuleFromBlob(std::shared_ptr<OpaqueBlob> blob) { iree_rt_module_t* module; auto free_fn = OpaqueBlob::CreateFreeFn(blob); @@ -31,6 +34,16 @@ return RtModule::CreateRetained(module); } +VmModule CreateHalModule(HalDevice* device) { + iree_vm_module_t* module; + CheckApiStatus( + iree_hal_module_create(device->raw_ptr(), IREE_ALLOCATOR_SYSTEM, &module), + "Error creating hal module"); + return VmModule::CreateRetained(module); +} + +} // namespace + //------------------------------------------------------------------------------ // VmInstance //------------------------------------------------------------------------------ @@ -115,9 +128,15 @@ } void SetupVmBindings(pybind11::module m) { + CHECK_EQ(IREE_STATUS_OK, iree_vm_register_builtin_types()); + CHECK_EQ(IREE_STATUS_OK, iree_hal_module_register_types()); + // Deprecated: VM1 module. m.def("create_module_from_blob", CreateModuleFromBlob); + // Built-in module creation. + m.def("create_hal_module", &CreateHalModule); + py::enum_<iree_vm_function_linkage_t>(m, "Linkage") .value("INTERNAL", IREE_VM_FUNCTION_LINKAGE_INTERNAL) .value("IMPORT", IREE_VM_FUNCTION_LINKAGE_IMPORT)
diff --git a/bindings/python/pyiree/vm_test.py b/bindings/python/pyiree/vm_test.py index 338d64d..3302907 100644 --- a/bindings/python/pyiree/vm_test.py +++ b/bindings/python/pyiree/vm_test.py
@@ -14,7 +14,6 @@ # limitations under the License. # pylint: disable=unused-variable -# pylint: disable=g-unreachable-test-method from absl.testing import absltest import pyiree @@ -36,6 +35,15 @@ class VmTest(absltest.TestCase): + @classmethod + def setUpClass(cls): + super().setUpClass() + driver_names = pyiree.binding.hal.HalDriver.query() + print("DRIVER_NAMES =", driver_names) + cls.driver = pyiree.binding.hal.HalDriver.create("vulkan") + cls.device = cls.driver.create_default_device() + cls.hal_module = pyiree.binding.vm.create_hal_module(cls.device) + def test_variant_list(self): l = pyiree.binding.vm.VmVariantList(5) print(l) @@ -47,25 +55,26 @@ context2 = pyiree.binding.vm.VmContext(instance) self.assertGreater(context2.context_id, context1.context_id) - def disabled_test_module_basics(self): + def test_module_basics(self): m = create_simple_mul_module() f = m.lookup_function("simple_mul") self.assertGreater(f.ordinal, 0) notfound = m.lookup_function("notfound") self.assertIs(notfound, None) - def disabled_test_dynamic_module_context(self): + def test_dynamic_module_context(self): instance = pyiree.binding.vm.VmInstance() context = pyiree.binding.vm.VmContext(instance) m = create_simple_mul_module() - context.register_modules([m]) + context.register_modules([self.hal_module, m]) - def disabled_test_static_module_context(self): + def test_static_module_context(self): m = create_simple_mul_module() print(m) instance = pyiree.binding.vm.VmInstance() print(instance) - context = pyiree.binding.vm.VmContext(instance, modules=[m]) + context = pyiree.binding.vm.VmContext( + instance, modules=[self.hal_module, m]) print(context)