[runtime][python] Add debug sink to bindings (#19013)
We don't support custom debug sinks in the Runtime Python bindings.
In particular the ability to register a custom callback when tracing
tensors.
This change makes it possible to create a HAL module with a Python
function as a callback.
This implementation does not handle the case of referencing directly or
indirectly the HAL module, VM context or VM instance in the callback
function object. In such a scenario the circular reference will not be
collected by the garbage collector and will leak. No no check is done
to guard against this. It is possible to traverse the Python object
structure to detect a reference to VM objects but it would require more
effort.
Here is added a callback to the debug sink in the IREE native runtime
API that signals when the runtime is done using the debug sink.
We need this since the Python objects corresponding to native runtime
objects are ephemeral and can not be used to hold the reference to the
debug sink.
---------
Signed-off-by: Boian Petkantchin <boian.petkantchin@amd.com>
diff --git a/runtime/bindings/python/CMakeLists.txt b/runtime/bindings/python/CMakeLists.txt
index 7171c2e..f0b7b37 100644
--- a/runtime/bindings/python/CMakeLists.txt
+++ b/runtime/bindings/python/CMakeLists.txt
@@ -118,6 +118,7 @@
"iree/runtime/io.py"
"iree/runtime/system_api.py"
"iree/runtime/system_setup.py"
+ "iree/runtime/typing.py"
"iree/runtime/version.py"
"iree/_runtime/__init__.py"
"iree/_runtime/libs.py"
@@ -231,13 +232,6 @@
iree_py_test(
NAME
- hal_test
- SRCS
- "tests/hal_test.py"
-)
-
-iree_py_test(
- NAME
io_test
SRCS
"tests/io_test.py"
@@ -264,6 +258,13 @@
iree_py_test(
NAME
+ hal_test
+ SRCS
+ "tests/hal_test.py"
+ )
+
+ iree_py_test(
+ NAME
io_runtime_test
SRCS
"tests/io_runtime_test.py"
diff --git a/runtime/bindings/python/binding.h b/runtime/bindings/python/binding.h
index 8a8ffc1..7de2c10 100644
--- a/runtime/bindings/python/binding.h
+++ b/runtime/bindings/python/binding.h
@@ -45,7 +45,9 @@
public:
using RawPtrType = T*;
ApiRefCounted() : instance_(nullptr) {}
- ApiRefCounted(ApiRefCounted& other) : instance_(other.instance_) { Retain(); }
+ ApiRefCounted(const ApiRefCounted& other) : instance_(other.instance_) {
+ Retain();
+ }
ApiRefCounted(ApiRefCounted&& other) : instance_(other.instance_) {
other.instance_ = nullptr;
}
diff --git a/runtime/bindings/python/hal.cc b/runtime/bindings/python/hal.cc
index a25a4ad..1450be7 100644
--- a/runtime/bindings/python/hal.cc
+++ b/runtime/bindings/python/hal.cc
@@ -6,15 +6,20 @@
#include "./hal.h"
+#include <nanobind/intrusive/ref.h>
#include <nanobind/nanobind.h>
+#include <nanobind/stl/function.h>
#include <nanobind/stl/vector.h>
+#include <algorithm>
+#include <iterator>
#include <optional>
#include "./local_dlpack.h"
#include "./numpy_interop.h"
#include "./vm.h"
#include "iree/base/internal/path.h"
+#include "iree/base/status.h"
#include "iree/hal/api.h"
#include "iree/hal/utils/allocators.h"
#include "iree/modules/hal/module.h"
@@ -1069,8 +1074,10 @@
// HAL module
//------------------------------------------------------------------------------
-VmModule CreateHalModule(VmInstance* instance, std::optional<HalDevice*> device,
- std::optional<py::list> devices) {
+VmModule CreateHalModule(
+ VmInstance* instance, std::optional<HalDevice*> device,
+ std::optional<py::list> devices,
+ std::optional<py::ref<HalModuleDebugSink>> debug_sink) {
if (device && devices) {
PyErr_SetString(
PyExc_ValueError,
@@ -1095,13 +1102,114 @@
devices_ptr = devices_vector.data();
device_count = devices_vector.size();
}
- CheckApiStatus(
- iree_hal_module_create(instance->raw_ptr(), device_count, devices_ptr,
- IREE_HAL_MODULE_FLAG_NONE,
- iree_hal_module_debug_sink_stdio(stderr),
- iree_allocator_system(), &module),
- "Error creating hal module");
- return VmModule::StealFromRawPtr(module);
+
+ iree_hal_module_debug_sink_t iree_hal_module_debug_sink =
+ iree_hal_module_debug_sink_stdio(stderr);
+ if (debug_sink) {
+ iree_hal_module_debug_sink = (*debug_sink)->AsIreeHalModuleDebugSink();
+ }
+
+ CheckApiStatus(iree_hal_module_create(instance->raw_ptr(), device_count,
+ devices_ptr, IREE_HAL_MODULE_FLAG_NONE,
+ iree_hal_module_debug_sink,
+ iree_allocator_system(), &module),
+ "Error creating hal module");
+ VmModule vm_module = VmModule::StealFromRawPtr(module);
+ if (debug_sink) {
+ // Retain a reference. We want the callback to be valid after
+ // the user has dropped its reference to the HAL module Python object and
+ // not burden the user with lifetime management.
+ // The counter will be decremented once the IREE runtime does not use the
+ // debug sink anymore.
+ (*debug_sink)->inc_ref();
+ }
+ return vm_module;
+}
+
+HalModuleDebugSink::HalModuleDebugSink(
+ HalModuleBufferViewTraceCallback buffer_view_trace_callback)
+ : buffer_view_trace_callback_(buffer_view_trace_callback) {}
+
+iree_hal_module_debug_sink_t HalModuleDebugSink::AsIreeHalModuleDebugSink()
+ const {
+ iree_hal_module_debug_sink_t res;
+ memset(&res, 0, sizeof(res));
+ res.buffer_view_trace.fn = HalModuleDebugSink::IreeHalModuleBufferViewTrace;
+ res.buffer_view_trace.user_data = const_cast<HalModuleDebugSink*>(this);
+ res.destroy.fn = HalModuleDebugSink::DestroyCallback;
+ res.destroy.user_data = const_cast<HalModuleDebugSink*>(this);
+ return res;
+}
+
+HalModuleBufferViewTraceCallback&
+HalModuleDebugSink::GetHalModuleBufferViewTraceCallback() {
+ return this->buffer_view_trace_callback_;
+}
+
+static std::vector<HalBufferView> CreateHalBufferViewVector(
+ iree_host_size_t buffer_view_count, iree_hal_buffer_view_t** buffer_views) {
+ std::vector<HalBufferView> res;
+ res.reserve(buffer_view_count);
+ std::transform(buffer_views, buffer_views + buffer_view_count,
+ std::back_inserter(res),
+ [](iree_hal_buffer_view_t* buffer_view) {
+ return HalBufferView::BorrowFromRawPtr(buffer_view);
+ });
+ return res;
+}
+
+iree_status_t HalModuleDebugSink::DestroyCallback(void* user_data) {
+ HalModuleDebugSink* debug_sink =
+ reinterpret_cast<HalModuleDebugSink*>(user_data);
+ debug_sink->dec_ref();
+ return iree_ok_status();
+}
+
+iree_status_t HalModuleDebugSink::IreeHalModuleBufferViewTrace(
+ void* user_data, iree_string_view_t key, iree_host_size_t buffer_view_count,
+ iree_hal_buffer_view_t** buffer_views, iree_allocator_t host_allocator) {
+ auto debug_sink = reinterpret_cast<HalModuleDebugSink*>(user_data);
+ std::vector<HalBufferView> buffer_views_vec =
+ CreateHalBufferViewVector(buffer_view_count, buffer_views);
+ try {
+ debug_sink->buffer_view_trace_callback_(std::string(key.data, key.size),
+ buffer_views_vec);
+ } catch (const py::python_error& e) {
+ return iree_make_status(IREE_STATUS_UNKNOWN, "%s", e.what());
+ }
+
+ return iree_ok_status();
+}
+
+static int HalModuleDebugSinkTpTraverse(PyObject* self, visitproc visit,
+ void* arg) {
+ // Inform Python's garbage collector about the references we hold.
+
+ // Retrieve a pointer to the C++ instance associated with 'self'
+ // (never fails)
+ HalModuleDebugSink* debug_sink = py::inst_ptr<HalModuleDebugSink>(self);
+
+ // Although we are not tracking cycles involving the HAL module or VM context
+ // we still want to properly destroy the callback and let the GC know what
+ // references we hold. If debug_sink->GetHalModuleBufferViewTraceCallback()
+ // has an associated CPython object, return it. If not, value.ptr() will equal
+ // NULL, which is also fine.
+ py::handle buffer_view_trace_callback =
+ py::find(debug_sink->GetHalModuleBufferViewTraceCallback());
+
+ // Inform the Python GC about the instance.
+ Py_VISIT(buffer_view_trace_callback.ptr());
+
+ return 0;
+}
+
+int HalModuleDebugSinkTpClear(PyObject* self) {
+ // Retrieve a pointer to the C++ instance associated with 'self'
+ // (never fails)
+ HalModuleDebugSink* debug_sink = py::inst_ptr<HalModuleDebugSink>(self);
+ debug_sink->GetHalModuleBufferViewTraceCallback() = nullptr;
+
+ return 0;
}
//------------------------------------------------------------------------------
@@ -1113,7 +1221,8 @@
// Built-in module creation.
m.def("create_hal_module", &CreateHalModule, py::arg("instance"),
- py::arg("device") = py::none(), py::arg("devices") = py::none());
+ py::arg("device") = py::none(), py::arg("devices") = py::none(),
+ py::arg("debug_sink") = py::none());
// Enums.
py::enum_<enum iree_hal_memory_type_bits_t>(m, "MemoryType")
@@ -1777,6 +1886,27 @@
py::arg("target_buffer"), py::arg("pattern"),
py::arg("target_offset") = 0, py::arg("length") = py::none(),
py::arg("end") = false);
+
+ PyType_Slot debug_sink_slots[] = {
+ {Py_tp_traverse, (void*)HalModuleDebugSinkTpTraverse},
+ {Py_tp_clear, (void*)HalModuleDebugSinkTpClear},
+ {0, nullptr}};
+ py::class_<HalModuleDebugSink>(
+ m, "HalModuleDebugSink", py::type_slots(debug_sink_slots),
+ py::intrusive_ptr<HalModuleDebugSink>(
+ [](HalModuleDebugSink* debug_sink, PyObject* po) noexcept {
+ debug_sink->set_self_py(po);
+ }))
+ .def(
+ "__init__",
+ [](HalModuleDebugSink* self,
+ HalModuleBufferViewTraceCallback buffer_view_trace_callback) {
+ new (self) HalModuleDebugSink(buffer_view_trace_callback);
+ },
+ py::arg("buffer_view_trace_callback"))
+ .def_prop_ro("buffer_view_trace_callback", [](HalModuleDebugSink& self) {
+ return self.GetHalModuleBufferViewTraceCallback();
+ });
}
} // namespace python
diff --git a/runtime/bindings/python/hal.h b/runtime/bindings/python/hal.h
index 7dbc108..5ea9e70 100644
--- a/runtime/bindings/python/hal.h
+++ b/runtime/bindings/python/hal.h
@@ -7,6 +7,9 @@
#ifndef IREE_BINDINGS_PYTHON_IREE_RT_HAL_H_
#define IREE_BINDINGS_PYTHON_IREE_RT_HAL_H_
+#include <nanobind/intrusive/counter.h>
+
+#include <functional>
#include <vector>
#include "./binding.h"
@@ -14,6 +17,7 @@
#include "./vm.h"
#include "iree/base/string_view.h"
#include "iree/hal/api.h"
+#include "iree/modules/hal/debugging.h"
namespace iree {
namespace python {
@@ -289,6 +293,33 @@
class HalCommandBuffer
: public ApiRefCounted<HalCommandBuffer, iree_hal_command_buffer_t> {};
+using HalModuleBufferViewTraceCallback =
+ std::function<void(const std::string&, const std::vector<HalBufferView>&)>;
+
+// HAL debug sinks need ot live as long as the HAL module. This means the
+// underlying native object, not just the HAL module Python object.
+// This is necessary since here we hold a reference to a callback to a Python
+// function. This function needs to live after the destruction of the HAL module
+// Python object if it is registered into the VM context.
+// The HAL module and VM context Python objects are owners of the debug sink.
+class HalModuleDebugSink : public py::intrusive_base {
+ public:
+ HalModuleDebugSink(
+ HalModuleBufferViewTraceCallback buffer_view_trace_callback);
+ iree_hal_module_debug_sink_t AsIreeHalModuleDebugSink() const;
+ HalModuleBufferViewTraceCallback& GetHalModuleBufferViewTraceCallback();
+
+ private:
+ HalModuleBufferViewTraceCallback buffer_view_trace_callback_;
+
+ static iree_status_t DestroyCallback(void* user_data);
+
+ static iree_status_t IreeHalModuleBufferViewTrace(
+ void* user_data, iree_string_view_t key,
+ iree_host_size_t buffer_view_count, iree_hal_buffer_view_t** buffer_views,
+ iree_allocator_t host_allocator);
+};
+
void SetupHalBindings(nanobind::module_ m);
} // namespace python
diff --git a/runtime/bindings/python/initialize_module.cc b/runtime/bindings/python/initialize_module.cc
index c79da46..48d0e29 100644
--- a/runtime/bindings/python/initialize_module.cc
+++ b/runtime/bindings/python/initialize_module.cc
@@ -4,7 +4,10 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#include <nanobind/intrusive/counter.h>
+
#include <memory>
+#include <nanobind/intrusive/counter.inl>
#include "./binding.h"
#include "./hal.h"
@@ -78,6 +81,16 @@
});
m.def("disable_leak_checker", []() { py::set_leak_warnings(false); });
+
+ py::intrusive_init(
+ [](PyObject *o) noexcept {
+ py::gil_scoped_acquire guard;
+ Py_INCREF(o);
+ },
+ [](PyObject *o) noexcept {
+ py::gil_scoped_acquire guard;
+ Py_DECREF(o);
+ });
}
} // namespace python
diff --git a/runtime/bindings/python/iree/runtime/__init__.py b/runtime/bindings/python/iree/runtime/__init__.py
index 0163a6d..8cec280 100644
--- a/runtime/bindings/python/iree/runtime/__init__.py
+++ b/runtime/bindings/python/iree/runtime/__init__.py
@@ -56,6 +56,10 @@
VmRef,
)
+# Debug imports
+from ._binding import HalModuleDebugSink
+from .typing import HalModuleBufferViewTraceCallback
+
from .array_interop import *
from .benchmark import *
from .system_api import *
diff --git a/runtime/bindings/python/iree/runtime/_binding.pyi b/runtime/bindings/python/iree/runtime/_binding.pyi
index b4ef2ba..499e0c6 100644
--- a/runtime/bindings/python/iree/runtime/_binding.pyi
+++ b/runtime/bindings/python/iree/runtime/_binding.pyi
@@ -1,13 +1,23 @@
-from typing import Any, Callable, ClassVar, List, Optional, Sequence, Tuple, Union
-
-from typing import overload
-
+from typing import (
+ Any,
+ Callable,
+ ClassVar,
+ List,
+ Optional,
+ Sequence,
+ Tuple,
+ Union,
+ overload,
+)
import asyncio
+from .typing import HalModuleBufferViewTraceCallback
+
def create_hal_module(
instance: VmInstance,
device: Optional[HalDevice] = None,
devices: Optional[List[HalDevice]] = None,
+ debug_sink: Optional[HalModuleDebugSink] = None,
) -> VmModule: ...
def create_io_parameters_module(
instance: VmInstance, *providers: ParameterProvider
@@ -310,6 +320,36 @@
deadline: Optional[int] = None,
) -> None: ...
+class HalModuleDebugSink:
+ def __init__(
+ self, buffer_view_trace_callback: Optional[HalModuleBufferViewTraceCallback]
+ ):
+ """The function object buffer_view_trace_callback must not include the
+ corresponding HAL VmModule, VmInstance or VmContext in its closure.
+ Native runtime objects are managed by reference counting and do not track
+ cyclic references. This will create an uncollectible cycle.
+ E.g.
+
+ ```
+ vm_context = ...
+ def callback(key, hal_buffer_view):
+ print(vm_context)
+ hal_module = iree.runtime.create_hal_module(
+ vm_instance,
+ device,
+ debug_sink=iree.runtime.HalModuleDebugSink(callback),
+ )
+ ```
+
+ This callback will cause the VM context to never be destroyed.
+ """
+
+ ...
+ @property
+ def buffer_view_trace_callback(
+ self,
+ ) -> Optional[HalModuleBufferViewTraceCallback]: ...
+
class Linkage(int):
EXPORT: ClassVar[Linkage] = ...
IMPORT: ClassVar[Linkage] = ...
diff --git a/runtime/bindings/python/iree/runtime/typing.py b/runtime/bindings/python/iree/runtime/typing.py
new file mode 100644
index 0000000..738184a
--- /dev/null
+++ b/runtime/bindings/python/iree/runtime/typing.py
@@ -0,0 +1,25 @@
+# Copyright 2025 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+"""Type hints."""
+
+from typing import Callable, List, TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from . import HalBufferView
+
+TraceKey = str
+HalModuleBufferViewTraceCallback = Callable[[TraceKey, List["HalBufferView"]], None]
+"""Tracing function for buffers to pass to the runtime.
+This allows custom behavior when executing an IREE module with tensor tracing
+instructions. MLIR e.g.
+
+```
+flow.tensor.trace "MyTensors" [
+ %tensor1 : tensor<1xf32>,
+ %tensor2 : tensor<2xf32>
+]
+```
+"""
diff --git a/runtime/bindings/python/tests/hal_test.py b/runtime/bindings/python/tests/hal_test.py
index 348a628..0a141a3 100644
--- a/runtime/bindings/python/tests/hal_test.py
+++ b/runtime/bindings/python/tests/hal_test.py
@@ -4,7 +4,9 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from typing import List
import iree.runtime
+import iree.compiler
import gc
import numpy as np
@@ -633,5 +635,147 @@
)
+class HalModuleDebugSinkTest(unittest.TestCase):
+ COMPILED_TRACE_TENSOR: bytes
+
+ @classmethod
+ def compile_trace_tensor(cls):
+ if not hasattr(cls, "COMPILED_TRACE_TENSOR"):
+ cls.COMPILED_TRACE_TENSOR = iree.compiler.compile_str(
+ """
+ func.func @trace_args(%arg0: tensor<2xi32>, %arg1: tensor<3xi32>) {
+ flow.tensor.trace "debug_sink_test" = [
+ %arg0: tensor<2xi32>,
+ %arg1: tensor<3xi32>
+ ]
+ return
+ }
+ """,
+ target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS,
+ )
+ return cls.COMPILED_TRACE_TENSOR
+
+ def testHalModuleBufferViewTraceCallback(self):
+ """Check that the trace tensor callback gets called with the expected
+ arguments."""
+ program_bytes = HalModuleDebugSinkTest.compile_trace_tensor()
+
+ arg0 = np.array([1, 2], dtype=np.int32)
+ arg1 = np.array([3, 4, 5], dtype=np.int32)
+
+ callback_key: str = None
+ callback_buffer_views = None
+
+ def callback(key: str, buffer_views: List[iree.runtime.HalBufferView]):
+ nonlocal callback_key
+ callback_key = key
+ nonlocal callback_buffer_views
+ callback_buffer_views = buffer_views
+
+ instance = iree.runtime.VmInstance()
+ device = iree.runtime.get_device(iree.compiler.core.DEFAULT_TESTING_DRIVER)
+ hal_module = iree.runtime.create_hal_module(
+ instance, device, debug_sink=iree.runtime.HalModuleDebugSink(callback)
+ )
+ program_module = iree.runtime.VmModule.copy_buffer(instance, program_bytes)
+ context = iree.runtime.VmContext(instance)
+ context.register_modules([hal_module, program_module])
+ fn = program_module.lookup_function("trace_args")
+ fn_invoker = iree.runtime.FunctionInvoker(context, device, fn)
+ fn_invoker(arg0, arg1)
+
+ assert callback_key == "debug_sink_test"
+ assert len(callback_buffer_views) == 2
+ actual_arg0 = iree.runtime.DeviceArray(
+ device, callback_buffer_views[0]
+ ).to_host()
+ actual_arg1 = iree.runtime.DeviceArray(
+ device, callback_buffer_views[1]
+ ).to_host()
+ np.testing.assert_equal(actual_arg0, arg0)
+ np.testing.assert_equal(actual_arg1, arg1)
+
+ def testNoneHalModuleDebugSink(self):
+ device = iree.runtime.get_device(iree.compiler.core.DEFAULT_TESTING_DRIVER)
+ instance = iree.runtime.VmInstance()
+ hal_module = iree.runtime.create_hal_module(
+ instance,
+ device,
+ debug_sink=None,
+ )
+
+ def testExceptionInHalModuleBufferViewTraceCallback(self):
+ """When an exception occurs in the callback check that it properly propagates
+ through the bindings and results in a IREE module function failed invocation.
+ """
+ program_bytes = HalModuleDebugSinkTest.compile_trace_tensor()
+
+ arg0 = np.array([1, 2], dtype=np.int32)
+ arg1 = np.array([3, 4, 5], dtype=np.int32)
+
+ device = iree.runtime.get_device(iree.compiler.core.DEFAULT_TESTING_DRIVER)
+
+ class TestException(Exception):
+ def __init__(self, msg: str):
+ super().__init__(msg)
+
+ def callback(key: str, buffer_views: List[iree.runtime.HalBufferView]):
+ raise TestException("This is a test exception")
+
+ instance = iree.runtime.VmInstance()
+ hal_module = iree.runtime.create_hal_module(
+ instance, device, debug_sink=iree.runtime.HalModuleDebugSink(callback)
+ )
+ program_module = iree.runtime.VmModule.copy_buffer(instance, program_bytes)
+ context = iree.runtime.VmContext(instance)
+ context.register_modules([hal_module, program_module])
+ fn = program_module.lookup_function("trace_args")
+ fn_invoker = iree.runtime.FunctionInvoker(context, device, fn)
+ # TODO: once IREE status chains messages test for the actual message we raise
+ # within the callback.
+ self.assertRaisesRegex(RuntimeError, "UNKNOWN", fn_invoker, arg0, arg1)
+
+ def testHalModuleBufferViewTraceCallbackReferencingItselfDoesNotLeak(self):
+ """Check that if we do not hold reference to the HAL module or VM context,
+ but we hold a reference to the debug sink in the callback, the callback object
+ does not leak.
+ """
+ is_callback_destroyed: bool = False
+
+ class Callback:
+ def __del__(self):
+ nonlocal is_callback_destroyed
+ is_callback_destroyed = True
+
+ def __call__(
+ self, key: str, buffer_views: List[iree.runtime.HalBufferView]
+ ):
+ pass
+
+ callback = Callback()
+ debug_sink = iree.runtime.HalModuleDebugSink(callback)
+ setattr(callback, "debug_sink", debug_sink)
+
+ device = iree.runtime.get_device(iree.compiler.core.DEFAULT_TESTING_DRIVER)
+
+ vm_instance = iree.runtime.VmInstance()
+ hal_module = iree.runtime.create_hal_module(
+ vm_instance,
+ device,
+ debug_sink=debug_sink,
+ )
+ vm_context = iree.runtime.VmContext(vm_instance)
+ vm_context.register_modules([hal_module])
+ assert not is_callback_destroyed
+
+ del callback
+ del debug_sink
+ del hal_module
+ del vm_instance
+ del vm_context
+ gc.collect()
+ assert is_callback_destroyed
+
+
if __name__ == "__main__":
unittest.main()
diff --git a/runtime/bindings/python/vm.cc b/runtime/bindings/python/vm.cc
index 0be8912..072e4fc 100644
--- a/runtime/bindings/python/vm.cc
+++ b/runtime/bindings/python/vm.cc
@@ -20,6 +20,7 @@
#include "iree/modules/hal/module.h"
#include "iree/tooling/modules/resolver.h"
#include "iree/vm/api.h"
+#include "nanobind/nanobind.h"
using namespace nanobind::literals;
diff --git a/runtime/src/iree/modules/check/check_test.cc b/runtime/src/iree/modules/check/check_test.cc
index d5e1c7b..ea1dab0 100644
--- a/runtime/src/iree/modules/check/check_test.cc
+++ b/runtime/src/iree/modules/check/check_test.cc
@@ -208,6 +208,9 @@
return Invoke(function_name);
}
+ static iree_hal_device_t*& device() { return CheckTest::device_; }
+ static iree_vm_instance_t*& instance() { return CheckTest::instance_; }
+
private:
static iree_hal_device_t* device_;
static iree_vm_instance_t* instance_;
@@ -223,6 +226,27 @@
iree_vm_module_t* CheckTest::check_module_ = nullptr;
iree_vm_module_t* CheckTest::hal_module_ = nullptr;
+TEST_F(CheckTest, HalModuleDebugSinkDestroyCallbackIsCalled) {
+ struct UserData {
+ bool is_callback_called = false;
+ };
+
+ iree_hal_module_debug_sink_t sink = {};
+ sink.destroy.fn = [](void* user_data) {
+ reinterpret_cast<UserData*>(user_data)->is_callback_called = true;
+ return iree_ok_status();
+ };
+ UserData user_data;
+ sink.destroy.user_data = &user_data;
+ iree_vm_module_t* hal_module;
+ IREE_ASSERT_OK(iree_hal_module_create(
+ instance(), /*device_count=*/1, &device(), IREE_HAL_MODULE_FLAG_NONE,
+ sink, iree_allocator_system(), &hal_module));
+ IREE_ASSERT_FALSE(user_data.is_callback_called);
+ iree_vm_module_release(hal_module);
+ IREE_ASSERT_TRUE(user_data.is_callback_called);
+}
+
TEST_F(CheckTest, ExpectTrueSuccess) {
IREE_ASSERT_OK(InvokeValue("expect_true", {iree_vm_value_make_i32(1)}));
}
diff --git a/runtime/src/iree/modules/hal/debugging.h b/runtime/src/iree/modules/hal/debugging.h
index 7500f2c..37a2857 100644
--- a/runtime/src/iree/modules/hal/debugging.h
+++ b/runtime/src/iree/modules/hal/debugging.h
@@ -30,12 +30,23 @@
void* user_data;
} iree_hal_module_buffer_view_trace_callback_t;
+typedef iree_status_t(
+ IREE_API_PTR* iree_hal_module_debug_sink_destroy_callback_fn_t)(
+ void* user_data);
+
+// Called by the runtime when the HAL module no longer needs the debug sink.
+typedef struct iree_hal_module_debug_sink_destroy_callback_t {
+ iree_hal_module_debug_sink_destroy_callback_fn_t fn;
+ void* user_data;
+} iree_hal_module_debug_sink_destroy_callback_t;
+
// Interface for a HAL module debug event sink.
// Any referenced user data must remain live for the lifetime of the HAL module
// the sink is provided to.
typedef struct iree_hal_module_debug_sink_t {
// Called on each hal.buffer_view.trace.
iree_hal_module_buffer_view_trace_callback_t buffer_view_trace;
+ iree_hal_module_debug_sink_destroy_callback_t destroy;
} iree_hal_module_debug_sink_t;
// Returns a default debug sink that outputs nothing.
diff --git a/runtime/src/iree/modules/hal/module.c b/runtime/src/iree/modules/hal/module.c
index 24a11db..ae5f9e1 100644
--- a/runtime/src/iree/modules/hal/module.c
+++ b/runtime/src/iree/modules/hal/module.c
@@ -48,6 +48,11 @@
static void IREE_API_PTR iree_hal_module_destroy(void* base_module) {
iree_hal_module_t* module = IREE_HAL_MODULE_CAST(base_module);
+
+ if (module->debug_sink.destroy.fn) {
+ module->debug_sink.destroy.fn(module->debug_sink.destroy.user_data);
+ }
+
for (iree_host_size_t i = 0; i < module->device_count; ++i) {
iree_hal_device_release(module->devices[i]);
}