Rollup of minor runtime fixes/cleanup from the AMDGPU branch. (#19621)
diff --git a/runtime/bindings/python/CMakeLists.txt b/runtime/bindings/python/CMakeLists.txt
index 7171c2e..f0b7b37 100644
--- a/runtime/bindings/python/CMakeLists.txt
+++ b/runtime/bindings/python/CMakeLists.txt
@@ -118,6 +118,7 @@
"iree/runtime/io.py"
"iree/runtime/system_api.py"
"iree/runtime/system_setup.py"
+ "iree/runtime/typing.py"
"iree/runtime/version.py"
"iree/_runtime/__init__.py"
"iree/_runtime/libs.py"
@@ -231,13 +232,6 @@
iree_py_test(
NAME
- hal_test
- SRCS
- "tests/hal_test.py"
-)
-
-iree_py_test(
- NAME
io_test
SRCS
"tests/io_test.py"
@@ -264,6 +258,13 @@
iree_py_test(
NAME
+ hal_test
+ SRCS
+ "tests/hal_test.py"
+ )
+
+ iree_py_test(
+ NAME
io_runtime_test
SRCS
"tests/io_runtime_test.py"
diff --git a/runtime/bindings/python/binding.h b/runtime/bindings/python/binding.h
index 8a8ffc1..7de2c10 100644
--- a/runtime/bindings/python/binding.h
+++ b/runtime/bindings/python/binding.h
@@ -45,7 +45,9 @@
public:
using RawPtrType = T*;
ApiRefCounted() : instance_(nullptr) {}
- ApiRefCounted(ApiRefCounted& other) : instance_(other.instance_) { Retain(); }
+ ApiRefCounted(const ApiRefCounted& other) : instance_(other.instance_) {
+ Retain();
+ }
ApiRefCounted(ApiRefCounted&& other) : instance_(other.instance_) {
other.instance_ = nullptr;
}
diff --git a/runtime/bindings/python/hal.cc b/runtime/bindings/python/hal.cc
index a25a4ad..1450be7 100644
--- a/runtime/bindings/python/hal.cc
+++ b/runtime/bindings/python/hal.cc
@@ -6,15 +6,20 @@
#include "./hal.h"
+#include <nanobind/intrusive/ref.h>
#include <nanobind/nanobind.h>
+#include <nanobind/stl/function.h>
#include <nanobind/stl/vector.h>
+#include <algorithm>
+#include <exception>
+#include <iterator>
#include <optional>
+#include <utility>
#include "./local_dlpack.h"
#include "./numpy_interop.h"
#include "./vm.h"
#include "iree/base/internal/path.h"
+#include "iree/base/status.h"
#include "iree/hal/api.h"
#include "iree/hal/utils/allocators.h"
#include "iree/modules/hal/module.h"
@@ -1069,8 +1074,10 @@
// HAL module
//------------------------------------------------------------------------------
-VmModule CreateHalModule(VmInstance* instance, std::optional<HalDevice*> device,
- std::optional<py::list> devices) {
+VmModule CreateHalModule(
+ VmInstance* instance, std::optional<HalDevice*> device,
+ std::optional<py::list> devices,
+ std::optional<py::ref<HalModuleDebugSink>> debug_sink) {
if (device && devices) {
PyErr_SetString(
PyExc_ValueError,
@@ -1095,13 +1102,114 @@
devices_ptr = devices_vector.data();
device_count = devices_vector.size();
}
- CheckApiStatus(
- iree_hal_module_create(instance->raw_ptr(), device_count, devices_ptr,
- IREE_HAL_MODULE_FLAG_NONE,
- iree_hal_module_debug_sink_stdio(stderr),
- iree_allocator_system(), &module),
- "Error creating hal module");
- return VmModule::StealFromRawPtr(module);
+
+ iree_hal_module_debug_sink_t iree_hal_module_debug_sink =
+ iree_hal_module_debug_sink_stdio(stderr);
+ if (debug_sink) {
+ iree_hal_module_debug_sink = (*debug_sink)->AsIreeHalModuleDebugSink();
+ }
+
+ CheckApiStatus(iree_hal_module_create(instance->raw_ptr(), device_count,
+ devices_ptr, IREE_HAL_MODULE_FLAG_NONE,
+ iree_hal_module_debug_sink,
+ iree_allocator_system(), &module),
+ "Error creating hal module");
+ VmModule vm_module = VmModule::StealFromRawPtr(module);
+ if (debug_sink) {
+ // Retain a reference. We want the callback to be valid after
+ // the user has dropped its reference to the HAL module Python object and
+ // not burden the user with lifetime management.
+ // The counter will be decremented once the IREE runtime does not use the
+ // debug sink anymore.
+ (*debug_sink)->inc_ref();
+ }
+ return vm_module;
+}
+
+// Stores the trace callback. The parameter is taken by value and moved into
+// the member so the std::function (which may wrap a Python callable) is not
+// copied a second time.
+HalModuleDebugSink::HalModuleDebugSink(
+    HalModuleBufferViewTraceCallback buffer_view_trace_callback)
+    : buffer_view_trace_callback_(std::move(buffer_view_trace_callback)) {}
+
+// Builds the C debug sink struct handed to iree_hal_module_create. Both the
+// buffer view trace callback and the destroy callback receive this object as
+// their user_data; every other field is zeroed. The returned struct does not
+// itself retain this object.
+iree_hal_module_debug_sink_t HalModuleDebugSink::AsIreeHalModuleDebugSink()
+    const {
+  void* self = const_cast<HalModuleDebugSink*>(this);
+  iree_hal_module_debug_sink_t sink = {};
+  sink.buffer_view_trace.user_data = self;
+  sink.buffer_view_trace.fn =
+      &HalModuleDebugSink::IreeHalModuleBufferViewTrace;
+  sink.destroy.user_data = self;
+  sink.destroy.fn = &HalModuleDebugSink::DestroyCallback;
+  return sink;
+}
+
+// Returns a mutable reference to the stored callback so callers (e.g. the
+// tp_clear slot) can inspect or reset it in place. The callback may be empty.
+HalModuleBufferViewTraceCallback&
+HalModuleDebugSink::GetHalModuleBufferViewTraceCallback() {
+  return buffer_view_trace_callback_;
+}
+
+// Wraps each of the `buffer_view_count` raw buffer views with
+// HalBufferView::BorrowFromRawPtr (borrow, not steal), preserving input order.
+static std::vector<HalBufferView> CreateHalBufferViewVector(
+    iree_host_size_t buffer_view_count, iree_hal_buffer_view_t** buffer_views) {
+  std::vector<HalBufferView> wrapped;
+  wrapped.reserve(buffer_view_count);
+  for (iree_host_size_t i = 0; i < buffer_view_count; ++i) {
+    wrapped.push_back(HalBufferView::BorrowFromRawPtr(buffer_views[i]));
+  }
+  return wrapped;
+}
+
+// Installed as debug_sink.destroy; the HAL module calls it when it is
+// destroyed. Drops the reference CreateHalModule took on the sink with
+// inc_ref(), possibly releasing the sink.
+iree_status_t HalModuleDebugSink::DestroyCallback(void* user_data) {
+  static_cast<HalModuleDebugSink*>(user_data)->dec_ref();
+  return iree_ok_status();
+}
+
+// Installed as debug_sink.buffer_view_trace. Forwards the trace key and the
+// traced buffer views to the stored callback.
+//
+// This is invoked from C code inside the IREE runtime, so no C++ exception may
+// escape it. Errors raised by the Python callback (py::python_error) and any
+// other C++ exception (e.g. std::bad_alloc while building the argument vector)
+// are converted into an IREE_STATUS_UNKNOWN status. An empty callback (never
+// set, or reset by tp_clear) makes the trace a no-op instead of throwing
+// std::bad_function_call through the C frames.
+iree_status_t HalModuleDebugSink::IreeHalModuleBufferViewTrace(
+    void* user_data, iree_string_view_t key, iree_host_size_t buffer_view_count,
+    iree_hal_buffer_view_t** buffer_views, iree_allocator_t host_allocator) {
+  auto* debug_sink = static_cast<HalModuleDebugSink*>(user_data);
+  if (!debug_sink->buffer_view_trace_callback_) {
+    return iree_ok_status();
+  }
+  try {
+    // Built inside the try: allocation failures must not unwind into C.
+    std::vector<HalBufferView> buffer_views_vec =
+        CreateHalBufferViewVector(buffer_view_count, buffer_views);
+    debug_sink->buffer_view_trace_callback_(std::string(key.data, key.size),
+                                            buffer_views_vec);
+  } catch (const py::python_error& e) {
+    return iree_make_status(IREE_STATUS_UNKNOWN, "%s", e.what());
+  } catch (const std::exception& e) {
+    return iree_make_status(IREE_STATUS_UNKNOWN, "%s", e.what());
+  } catch (...) {
+    return iree_make_status(
+        IREE_STATUS_UNKNOWN,
+        "unknown C++ exception in buffer view trace callback");
+  }
+  return iree_ok_status();
+}
+
+// tp_traverse slot: reports the Python objects referenced by this sink to the
+// cyclic garbage collector.
+static int HalModuleDebugSinkTpTraverse(PyObject* self, visitproc visit,
+                                        void* arg) {
+  // Retrieve the C++ instance associated with 'self' (never fails).
+  HalModuleDebugSink* debug_sink = py::inst_ptr<HalModuleDebugSink>(self);
+
+  // Cycles through the native HAL module or VM context are not tracked, but we
+  // still let the GC see the callback so that a callback referencing this sink
+  // can be collected. py::find yields the Python callable wrapped by the
+  // std::function, or a null handle if there is none; Py_VISIT skips NULL.
+  py::handle buffer_view_trace_callback =
+      py::find(debug_sink->GetHalModuleBufferViewTraceCallback());
+  Py_VISIT(buffer_view_trace_callback.ptr());
+
+  // Since Python 3.9, instances of heap types must also visit their type in
+  // tp_traverse, otherwise the type's reference count is mis-accounted by GC.
+#if PY_VERSION_HEX >= 0x03090000
+  Py_VISIT(Py_TYPE(self));
+#endif
+
+  return 0;
+}
+
+// tp_clear slot: breaks GC-detected cycles by resetting the stored callback,
+// which drops its reference to the wrapped Python callable. Internal linkage,
+// like HalModuleDebugSinkTpTraverse, since it is only used by the type slots
+// registered in this file.
+static int HalModuleDebugSinkTpClear(PyObject* self) {
+  // Retrieve the C++ instance associated with 'self' (never fails).
+  HalModuleDebugSink* debug_sink = py::inst_ptr<HalModuleDebugSink>(self);
+  debug_sink->GetHalModuleBufferViewTraceCallback() = nullptr;
+
+  return 0;
}
//------------------------------------------------------------------------------
@@ -1113,7 +1221,8 @@
// Built-in module creation.
m.def("create_hal_module", &CreateHalModule, py::arg("instance"),
- py::arg("device") = py::none(), py::arg("devices") = py::none());
+ py::arg("device") = py::none(), py::arg("devices") = py::none(),
+ py::arg("debug_sink") = py::none());
// Enums.
py::enum_<enum iree_hal_memory_type_bits_t>(m, "MemoryType")
@@ -1777,6 +1886,27 @@
py::arg("target_buffer"), py::arg("pattern"),
py::arg("target_offset") = 0, py::arg("length") = py::none(),
py::arg("end") = false);
+
+ PyType_Slot debug_sink_slots[] = {
+ {Py_tp_traverse, (void*)HalModuleDebugSinkTpTraverse},
+ {Py_tp_clear, (void*)HalModuleDebugSinkTpClear},
+ {0, nullptr}};
+ py::class_<HalModuleDebugSink>(
+ m, "HalModuleDebugSink", py::type_slots(debug_sink_slots),
+ py::intrusive_ptr<HalModuleDebugSink>(
+ [](HalModuleDebugSink* debug_sink, PyObject* po) noexcept {
+ debug_sink->set_self_py(po);
+ }))
+ .def(
+ "__init__",
+ [](HalModuleDebugSink* self,
+ HalModuleBufferViewTraceCallback buffer_view_trace_callback) {
+ new (self) HalModuleDebugSink(buffer_view_trace_callback);
+ },
+ py::arg("buffer_view_trace_callback"))
+ .def_prop_ro("buffer_view_trace_callback", [](HalModuleDebugSink& self) {
+ return self.GetHalModuleBufferViewTraceCallback();
+ });
}
} // namespace python
diff --git a/runtime/bindings/python/hal.h b/runtime/bindings/python/hal.h
index 7dbc108..5ea9e70 100644
--- a/runtime/bindings/python/hal.h
+++ b/runtime/bindings/python/hal.h
@@ -7,6 +7,9 @@
#ifndef IREE_BINDINGS_PYTHON_IREE_RT_HAL_H_
#define IREE_BINDINGS_PYTHON_IREE_RT_HAL_H_
+#include <nanobind/intrusive/counter.h>
+
+#include <functional>
#include <vector>
#include "./binding.h"
@@ -14,6 +17,7 @@
#include "./vm.h"
#include "iree/base/string_view.h"
#include "iree/hal/api.h"
+#include "iree/modules/hal/debugging.h"
namespace iree {
namespace python {
@@ -289,6 +293,33 @@
class HalCommandBuffer
: public ApiRefCounted<HalCommandBuffer, iree_hal_command_buffer_t> {};
+// Python-facing tensor trace callback: receives the trace key and the traced
+// buffer views, in order.
+using HalModuleBufferViewTraceCallback =
+    std::function<void(const std::string&, const std::vector<HalBufferView>&)>;
+
+// Native state behind the Python `HalModuleDebugSink` object, adapted to the C
+// iree_hal_module_debug_sink_t interface by AsIreeHalModuleDebugSink().
+//
+// HAL debug sinks need to live as long as the HAL module. This means the
+// underlying native object, not just the HAL module Python object. This is
+// necessary since the sink holds a callback to a Python function, which must
+// outlive the HAL module Python object if the module is registered into a VM
+// context. Each native HAL module created with a sink holds a reference to it
+// (intrusive reference count) that is released through the destroy callback.
+class HalModuleDebugSink : public py::intrusive_base {
+ public:
+  explicit HalModuleDebugSink(
+      HalModuleBufferViewTraceCallback buffer_view_trace_callback);
+
+  // Returns a C debug sink whose callbacks receive `this` as user_data. The
+  // returned struct does not itself retain `this`.
+  iree_hal_module_debug_sink_t AsIreeHalModuleDebugSink() const;
+
+  // Mutable access to the stored callback (which may be empty).
+  HalModuleBufferViewTraceCallback& GetHalModuleBufferViewTraceCallback();
+
+ private:
+  HalModuleBufferViewTraceCallback buffer_view_trace_callback_;
+
+  // C destroy callback: releases the reference taken when the sink was
+  // attached to a HAL module.
+  static iree_status_t DestroyCallback(void* user_data);
+
+  // C trace callback: forwards hal.buffer_view.trace events to the stored
+  // callback.
+  static iree_status_t IreeHalModuleBufferViewTrace(
+      void* user_data, iree_string_view_t key,
+      iree_host_size_t buffer_view_count, iree_hal_buffer_view_t** buffer_views,
+      iree_allocator_t host_allocator);
+};
+
void SetupHalBindings(nanobind::module_ m);
} // namespace python
diff --git a/runtime/bindings/python/initialize_module.cc b/runtime/bindings/python/initialize_module.cc
index c79da46..48d0e29 100644
--- a/runtime/bindings/python/initialize_module.cc
+++ b/runtime/bindings/python/initialize_module.cc
@@ -4,7 +4,10 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#include <nanobind/intrusive/counter.h>
+
#include <memory>
+#include <nanobind/intrusive/counter.inl>
#include "./binding.h"
#include "./hal.h"
@@ -78,6 +81,16 @@
});
m.def("disable_leak_checker", []() { py::set_leak_warnings(false); });
+
+ py::intrusive_init(
+ [](PyObject *o) noexcept {
+ py::gil_scoped_acquire guard;
+ Py_INCREF(o);
+ },
+ [](PyObject *o) noexcept {
+ py::gil_scoped_acquire guard;
+ Py_DECREF(o);
+ });
}
} // namespace python
diff --git a/runtime/bindings/python/iree/runtime/__init__.py b/runtime/bindings/python/iree/runtime/__init__.py
index 0163a6d..8cec280 100644
--- a/runtime/bindings/python/iree/runtime/__init__.py
+++ b/runtime/bindings/python/iree/runtime/__init__.py
@@ -56,6 +56,10 @@
VmRef,
)
+# Debug imports
+from ._binding import HalModuleDebugSink
+from .typing import HalModuleBufferViewTraceCallback
+
from .array_interop import *
from .benchmark import *
from .system_api import *
diff --git a/runtime/bindings/python/iree/runtime/_binding.pyi b/runtime/bindings/python/iree/runtime/_binding.pyi
index b4ef2ba..499e0c6 100644
--- a/runtime/bindings/python/iree/runtime/_binding.pyi
+++ b/runtime/bindings/python/iree/runtime/_binding.pyi
@@ -1,13 +1,23 @@
-from typing import Any, Callable, ClassVar, List, Optional, Sequence, Tuple, Union
-
-from typing import overload
-
+from typing import (
+ Any,
+ Callable,
+ ClassVar,
+ List,
+ Optional,
+ Sequence,
+ Tuple,
+ Union,
+ overload,
+)
import asyncio
+from .typing import HalModuleBufferViewTraceCallback
+
def create_hal_module(
instance: VmInstance,
device: Optional[HalDevice] = None,
devices: Optional[List[HalDevice]] = None,
+ debug_sink: Optional[HalModuleDebugSink] = None,
) -> VmModule: ...
def create_io_parameters_module(
instance: VmInstance, *providers: ParameterProvider
@@ -310,6 +320,36 @@
deadline: Optional[int] = None,
) -> None: ...
+class HalModuleDebugSink:
+    def __init__(
+        self, buffer_view_trace_callback: Optional[HalModuleBufferViewTraceCallback]
+    ) -> None:
+        """Sink for HAL module debug events, passed to `create_hal_module`.
+
+        `buffer_view_trace_callback` is called with the trace key and the list
+        of traced `HalBufferView`s on each `hal.buffer_view.trace`.
+
+        The function object buffer_view_trace_callback must not include the
+        corresponding HAL VmModule, VmInstance or VmContext in its closure.
+        Native runtime objects are managed by reference counting and do not
+        track cyclic references, so such a closure creates a reference cycle
+        that cannot be collected.
+        E.g.
+
+        ```
+        vm_context = ...
+        def callback(key, buffer_views):
+            print(vm_context)
+        hal_module = iree.runtime.create_hal_module(
+            vm_instance,
+            device,
+            debug_sink=iree.runtime.HalModuleDebugSink(callback),
+        )
+        ```
+
+        This callback will cause the VM context to never be destroyed.
+        """
+        ...
+    @property
+    def buffer_view_trace_callback(
+        self,
+    ) -> Optional[HalModuleBufferViewTraceCallback]: ...
+
class Linkage(int):
EXPORT: ClassVar[Linkage] = ...
IMPORT: ClassVar[Linkage] = ...
diff --git a/runtime/bindings/python/iree/runtime/typing.py b/runtime/bindings/python/iree/runtime/typing.py
new file mode 100644
index 0000000..738184a
--- /dev/null
+++ b/runtime/bindings/python/iree/runtime/typing.py
@@ -0,0 +1,25 @@
+# Copyright 2025 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+"""Type hints for the IREE runtime Python API."""
+
+from typing import Callable, List, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from . import HalBufferView
+
+# Key identifying a trace point, e.g. the string given to `flow.tensor.trace`.
+TraceKey = str
+HalModuleBufferViewTraceCallback = Callable[[TraceKey, List["HalBufferView"]], None]
+"""Tracing function for buffers to pass to the runtime.
+
+This allows custom behavior when executing an IREE module with tensor tracing
+instructions. It is called with the trace key and the list of traced buffer
+views. For example, in MLIR:
+
+```
+flow.tensor.trace "MyTensors" = [
+  %tensor1 : tensor<1xf32>,
+  %tensor2 : tensor<2xf32>
+]
+```
+"""
diff --git a/runtime/bindings/python/tests/hal_test.py b/runtime/bindings/python/tests/hal_test.py
index 348a628..0a141a3 100644
--- a/runtime/bindings/python/tests/hal_test.py
+++ b/runtime/bindings/python/tests/hal_test.py
@@ -4,7 +4,9 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from typing import List
import iree.runtime
+import iree.compiler
import gc
import numpy as np
@@ -633,5 +635,147 @@
)
+class HalModuleDebugSinkTest(unittest.TestCase):
+    """Tests for HalModuleDebugSink and create_hal_module(debug_sink=...)."""
+
+    # Compiled program shared by the tests; populated lazily by
+    # compile_trace_tensor().
+    COMPILED_TRACE_TENSOR: bytes
+
+    @classmethod
+    def compile_trace_tensor(cls):
+        """Returns a compiled program whose `trace_args` function traces its
+        two tensor arguments under the key "debug_sink_test". The program is
+        compiled at most once per test class."""
+        if not hasattr(cls, "COMPILED_TRACE_TENSOR"):
+            cls.COMPILED_TRACE_TENSOR = iree.compiler.compile_str(
+                """
+                func.func @trace_args(%arg0: tensor<2xi32>, %arg1: tensor<3xi32>) {
+                  flow.tensor.trace "debug_sink_test" = [
+                    %arg0: tensor<2xi32>,
+                    %arg1: tensor<3xi32>
+                  ]
+                  return
+                }
+                """,
+                target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS,
+            )
+        return cls.COMPILED_TRACE_TENSOR
+
+    def testHalModuleBufferViewTraceCallback(self):
+        """Check that the trace tensor callback gets called with the expected
+        arguments."""
+        program_bytes = HalModuleDebugSinkTest.compile_trace_tensor()
+
+        arg0 = np.array([1, 2], dtype=np.int32)
+        arg1 = np.array([3, 4, 5], dtype=np.int32)
+
+        # Filled in by the callback; None until it runs.
+        callback_key = None
+        callback_buffer_views = None
+
+        def callback(key: str, buffer_views: List[iree.runtime.HalBufferView]):
+            nonlocal callback_key, callback_buffer_views
+            callback_key = key
+            callback_buffer_views = buffer_views
+
+        instance = iree.runtime.VmInstance()
+        device = iree.runtime.get_device(iree.compiler.core.DEFAULT_TESTING_DRIVER)
+        hal_module = iree.runtime.create_hal_module(
+            instance, device, debug_sink=iree.runtime.HalModuleDebugSink(callback)
+        )
+        program_module = iree.runtime.VmModule.copy_buffer(instance, program_bytes)
+        context = iree.runtime.VmContext(instance)
+        context.register_modules([hal_module, program_module])
+        fn = program_module.lookup_function("trace_args")
+        fn_invoker = iree.runtime.FunctionInvoker(context, device, fn)
+        fn_invoker(arg0, arg1)
+
+        # unittest assertions (unlike bare `assert`) still run under `python -O`.
+        self.assertEqual(callback_key, "debug_sink_test")
+        self.assertIsNotNone(callback_buffer_views)
+        self.assertEqual(len(callback_buffer_views), 2)
+        actual_arg0 = iree.runtime.DeviceArray(
+            device, callback_buffer_views[0]
+        ).to_host()
+        actual_arg1 = iree.runtime.DeviceArray(
+            device, callback_buffer_views[1]
+        ).to_host()
+        np.testing.assert_equal(actual_arg0, arg0)
+        np.testing.assert_equal(actual_arg1, arg1)
+
+    def testNoneHalModuleDebugSink(self):
+        """Passing debug_sink=None uses the default sink."""
+        device = iree.runtime.get_device(iree.compiler.core.DEFAULT_TESTING_DRIVER)
+        instance = iree.runtime.VmInstance()
+        hal_module = iree.runtime.create_hal_module(
+            instance,
+            device,
+            debug_sink=None,
+        )
+        self.assertIsNotNone(hal_module)
+
+    def testExceptionInHalModuleBufferViewTraceCallback(self):
+        """When an exception occurs in the callback check that it properly propagates
+        through the bindings and results in a IREE module function failed invocation.
+        """
+        program_bytes = HalModuleDebugSinkTest.compile_trace_tensor()
+
+        arg0 = np.array([1, 2], dtype=np.int32)
+        arg1 = np.array([3, 4, 5], dtype=np.int32)
+
+        device = iree.runtime.get_device(iree.compiler.core.DEFAULT_TESTING_DRIVER)
+
+        class TestException(Exception):
+            pass
+
+        def callback(key: str, buffer_views: List[iree.runtime.HalBufferView]):
+            raise TestException("This is a test exception")
+
+        instance = iree.runtime.VmInstance()
+        hal_module = iree.runtime.create_hal_module(
+            instance, device, debug_sink=iree.runtime.HalModuleDebugSink(callback)
+        )
+        program_module = iree.runtime.VmModule.copy_buffer(instance, program_bytes)
+        context = iree.runtime.VmContext(instance)
+        context.register_modules([hal_module, program_module])
+        fn = program_module.lookup_function("trace_args")
+        fn_invoker = iree.runtime.FunctionInvoker(context, device, fn)
+        # TODO: once IREE status chains messages test for the actual message we raise
+        # within the callback.
+        with self.assertRaisesRegex(RuntimeError, "UNKNOWN"):
+            fn_invoker(arg0, arg1)
+
+    def testHalModuleBufferViewTraceCallbackReferencingItselfDoesNotLeak(self):
+        """Check that if we do not hold reference to the HAL module or VM context,
+        but we hold a reference to the debug sink in the callback, the callback object
+        does not leak.
+        """
+        is_callback_destroyed = False
+
+        class Callback:
+            def __del__(self):
+                nonlocal is_callback_destroyed
+                is_callback_destroyed = True
+
+            def __call__(
+                self, key: str, buffer_views: List[iree.runtime.HalBufferView]
+            ):
+                pass
+
+        callback = Callback()
+        debug_sink = iree.runtime.HalModuleDebugSink(callback)
+        # Create the callback -> debug sink -> callback reference cycle.
+        callback.debug_sink = debug_sink
+
+        device = iree.runtime.get_device(iree.compiler.core.DEFAULT_TESTING_DRIVER)
+
+        vm_instance = iree.runtime.VmInstance()
+        hal_module = iree.runtime.create_hal_module(
+            vm_instance,
+            device,
+            debug_sink=debug_sink,
+        )
+        vm_context = iree.runtime.VmContext(vm_instance)
+        vm_context.register_modules([hal_module])
+        self.assertFalse(is_callback_destroyed)
+
+        del callback
+        del debug_sink
+        del hal_module
+        del vm_instance
+        del vm_context
+        gc.collect()
+        self.assertTrue(is_callback_destroyed)
+
+
if __name__ == "__main__":
unittest.main()
diff --git a/runtime/bindings/python/vm.cc b/runtime/bindings/python/vm.cc
index 0be8912..072e4fc 100644
--- a/runtime/bindings/python/vm.cc
+++ b/runtime/bindings/python/vm.cc
@@ -20,6 +20,7 @@
#include "iree/modules/hal/module.h"
#include "iree/tooling/modules/resolver.h"
#include "iree/vm/api.h"
+#include "nanobind/nanobind.h"
using namespace nanobind::literals;
diff --git a/runtime/src/iree/modules/check/check_test.cc b/runtime/src/iree/modules/check/check_test.cc
index d5e1c7b..ea1dab0 100644
--- a/runtime/src/iree/modules/check/check_test.cc
+++ b/runtime/src/iree/modules/check/check_test.cc
@@ -208,6 +208,9 @@
return Invoke(function_name);
}
+ static iree_hal_device_t*& device() { return CheckTest::device_; }
+ static iree_vm_instance_t*& instance() { return CheckTest::instance_; }
+
private:
static iree_hal_device_t* device_;
static iree_vm_instance_t* instance_;
@@ -223,6 +226,27 @@
iree_vm_module_t* CheckTest::check_module_ = nullptr;
iree_vm_module_t* CheckTest::hal_module_ = nullptr;
+// Verifies that the HAL module invokes the debug sink's destroy callback when
+// the module is released, and not earlier at creation time.
+TEST_F(CheckTest, HalModuleDebugSinkDestroyCallbackIsCalled) {
+  struct UserData {
+    bool is_callback_called = false;
+  };
+
+  iree_hal_module_debug_sink_t sink = {};
+  sink.destroy.fn = [](void* user_data) {
+    static_cast<UserData*>(user_data)->is_callback_called = true;
+    return iree_ok_status();
+  };
+  UserData user_data;
+  sink.destroy.user_data = &user_data;
+  // Initialized so the out-param never holds an indeterminate value.
+  iree_vm_module_t* hal_module = nullptr;
+  IREE_ASSERT_OK(iree_hal_module_create(
+      instance(), /*device_count=*/1, &device(), IREE_HAL_MODULE_FLAG_NONE,
+      sink, iree_allocator_system(), &hal_module));
+  IREE_ASSERT_FALSE(user_data.is_callback_called);
+  iree_vm_module_release(hal_module);
+  IREE_ASSERT_TRUE(user_data.is_callback_called);
+}
+
TEST_F(CheckTest, ExpectTrueSuccess) {
IREE_ASSERT_OK(InvokeValue("expect_true", {iree_vm_value_make_i32(1)}));
}
diff --git a/runtime/src/iree/modules/hal/debugging.h b/runtime/src/iree/modules/hal/debugging.h
index 7500f2c..37a2857 100644
--- a/runtime/src/iree/modules/hal/debugging.h
+++ b/runtime/src/iree/modules/hal/debugging.h
@@ -30,12 +30,23 @@
void* user_data;
} iree_hal_module_buffer_view_trace_callback_t;
+// Function invoked when a HAL module releases its debug sink. It is called
+// from the HAL module's destroy routine with the destroy callback's user_data.
+// NOTE(review): iree_hal_module_destroy currently discards the returned status
+// without iree_status_ignore(), so a non-OK status would leak; implementations
+// should return iree_ok_status().
+typedef iree_status_t(
+ IREE_API_PTR* iree_hal_module_debug_sink_destroy_callback_fn_t)(
+ void* user_data);
+
+// Called by the runtime when the HAL module no longer needs the debug sink.
+// A NULL fn is skipped.
+typedef struct iree_hal_module_debug_sink_destroy_callback_t {
+ iree_hal_module_debug_sink_destroy_callback_fn_t fn;
+ void* user_data;
+} iree_hal_module_debug_sink_destroy_callback_t;
+
// Interface for a HAL module debug event sink.
// Any referenced user data must remain live for the lifetime of the HAL module
// the sink is provided to.
+// The destroy callback (if set) is invoked when that lifetime ends.
typedef struct iree_hal_module_debug_sink_t {
// Called on each hal.buffer_view.trace.
iree_hal_module_buffer_view_trace_callback_t buffer_view_trace;
+ // Called when the HAL module is destroyed; may be left zeroed.
+ iree_hal_module_debug_sink_destroy_callback_t destroy;
} iree_hal_module_debug_sink_t;
// Returns a default debug sink that outputs nothing.
diff --git a/runtime/src/iree/modules/hal/module.c b/runtime/src/iree/modules/hal/module.c
index 24a11db..ae5f9e1 100644
--- a/runtime/src/iree/modules/hal/module.c
+++ b/runtime/src/iree/modules/hal/module.c
@@ -48,6 +48,11 @@
static void IREE_API_PTR iree_hal_module_destroy(void* base_module) {
iree_hal_module_t* module = IREE_HAL_MODULE_CAST(base_module);
+
+ if (module->debug_sink.destroy.fn) {
+ module->debug_sink.destroy.fn(module->debug_sink.destroy.user_data);
+ }
+
for (iree_host_size_t i = 0; i < module->device_count; ++i) {
iree_hal_device_release(module->devices[i]);
}