Rollup of minor runtime fixes/cleanup from the AMDGPU branch. (#19621)

diff --git a/runtime/bindings/python/CMakeLists.txt b/runtime/bindings/python/CMakeLists.txt
index 7171c2e..f0b7b37 100644
--- a/runtime/bindings/python/CMakeLists.txt
+++ b/runtime/bindings/python/CMakeLists.txt
@@ -118,6 +118,7 @@
     "iree/runtime/io.py"
     "iree/runtime/system_api.py"
     "iree/runtime/system_setup.py"
+    "iree/runtime/typing.py"
     "iree/runtime/version.py"
     "iree/_runtime/__init__.py"
     "iree/_runtime/libs.py"
@@ -231,13 +232,6 @@
 
 iree_py_test(
   NAME
-    hal_test
-  SRCS
-    "tests/hal_test.py"
-)
-
-iree_py_test(
-  NAME
     io_test
   SRCS
     "tests/io_test.py"
@@ -264,6 +258,13 @@
 
   iree_py_test(
     NAME
+      hal_test
+    SRCS
+      "tests/hal_test.py"
+  )
+
+  iree_py_test(
+    NAME
       io_runtime_test
     SRCS
       "tests/io_runtime_test.py"
diff --git a/runtime/bindings/python/binding.h b/runtime/bindings/python/binding.h
index 8a8ffc1..7de2c10 100644
--- a/runtime/bindings/python/binding.h
+++ b/runtime/bindings/python/binding.h
@@ -45,7 +45,9 @@
  public:
   using RawPtrType = T*;
   ApiRefCounted() : instance_(nullptr) {}
-  ApiRefCounted(ApiRefCounted& other) : instance_(other.instance_) { Retain(); }
+  ApiRefCounted(const ApiRefCounted& other) : instance_(other.instance_) {
+    Retain();
+  }
   ApiRefCounted(ApiRefCounted&& other) : instance_(other.instance_) {
     other.instance_ = nullptr;
   }
diff --git a/runtime/bindings/python/hal.cc b/runtime/bindings/python/hal.cc
index a25a4ad..1450be7 100644
--- a/runtime/bindings/python/hal.cc
+++ b/runtime/bindings/python/hal.cc
@@ -6,15 +6,20 @@
 
 #include "./hal.h"
 
+#include <nanobind/intrusive/ref.h>
 #include <nanobind/nanobind.h>
+#include <nanobind/stl/function.h>
 #include <nanobind/stl/vector.h>
 
+#include <algorithm>
+#include <iterator>
 #include <optional>
 
 #include "./local_dlpack.h"
 #include "./numpy_interop.h"
 #include "./vm.h"
 #include "iree/base/internal/path.h"
+#include "iree/base/status.h"
 #include "iree/hal/api.h"
 #include "iree/hal/utils/allocators.h"
 #include "iree/modules/hal/module.h"
@@ -1069,8 +1074,10 @@
 // HAL module
 //------------------------------------------------------------------------------
 
-VmModule CreateHalModule(VmInstance* instance, std::optional<HalDevice*> device,
-                         std::optional<py::list> devices) {
+VmModule CreateHalModule(
+    VmInstance* instance, std::optional<HalDevice*> device,
+    std::optional<py::list> devices,
+    std::optional<py::ref<HalModuleDebugSink>> debug_sink) {
   if (device && devices) {
     PyErr_SetString(
         PyExc_ValueError,
@@ -1095,13 +1102,114 @@
     devices_ptr = devices_vector.data();
     device_count = devices_vector.size();
   }
-  CheckApiStatus(
-      iree_hal_module_create(instance->raw_ptr(), device_count, devices_ptr,
-                             IREE_HAL_MODULE_FLAG_NONE,
-                             iree_hal_module_debug_sink_stdio(stderr),
-                             iree_allocator_system(), &module),
-      "Error creating hal module");
-  return VmModule::StealFromRawPtr(module);
+
+  iree_hal_module_debug_sink_t iree_hal_module_debug_sink =
+      iree_hal_module_debug_sink_stdio(stderr);
+  if (debug_sink) {
+    iree_hal_module_debug_sink = (*debug_sink)->AsIreeHalModuleDebugSink();
+  }
+
+  CheckApiStatus(iree_hal_module_create(instance->raw_ptr(), device_count,
+                                        devices_ptr, IREE_HAL_MODULE_FLAG_NONE,
+                                        iree_hal_module_debug_sink,
+                                        iree_allocator_system(), &module),
+                 "Error creating hal module");
+  VmModule vm_module = VmModule::StealFromRawPtr(module);
+  if (debug_sink) {
+    // Retain a reference. We want the callback to remain valid after the user
+    // has dropped their reference to the HAL module Python object, without
+    // burdening the user with lifetime management.
+    // The counter is decremented by DestroyCallback once the IREE runtime no
+    // longer uses the debug sink.
+    (*debug_sink)->inc_ref();
+  }
+  return vm_module;
+}
+
+HalModuleDebugSink::HalModuleDebugSink(
+    HalModuleBufferViewTraceCallback buffer_view_trace_callback)
+    : buffer_view_trace_callback_(buffer_view_trace_callback) {}
+
+iree_hal_module_debug_sink_t HalModuleDebugSink::AsIreeHalModuleDebugSink()
+    const {
+  iree_hal_module_debug_sink_t res;
+  memset(&res, 0, sizeof(res));
+  res.buffer_view_trace.fn = HalModuleDebugSink::IreeHalModuleBufferViewTrace;
+  res.buffer_view_trace.user_data = const_cast<HalModuleDebugSink*>(this);
+  res.destroy.fn = HalModuleDebugSink::DestroyCallback;
+  res.destroy.user_data = const_cast<HalModuleDebugSink*>(this);
+  return res;
+}
+
+HalModuleBufferViewTraceCallback&
+HalModuleDebugSink::GetHalModuleBufferViewTraceCallback() {
+  return this->buffer_view_trace_callback_;
+}
+
+static std::vector<HalBufferView> CreateHalBufferViewVector(
+    iree_host_size_t buffer_view_count, iree_hal_buffer_view_t** buffer_views) {
+  std::vector<HalBufferView> res;
+  res.reserve(buffer_view_count);
+  std::transform(buffer_views, buffer_views + buffer_view_count,
+                 std::back_inserter(res),
+                 [](iree_hal_buffer_view_t* buffer_view) {
+                   return HalBufferView::BorrowFromRawPtr(buffer_view);
+                 });
+  return res;
+}
+
+iree_status_t HalModuleDebugSink::DestroyCallback(void* user_data) {
+  HalModuleDebugSink* debug_sink =
+      reinterpret_cast<HalModuleDebugSink*>(user_data);
+  debug_sink->dec_ref();
+  return iree_ok_status();
+}
+
+iree_status_t HalModuleDebugSink::IreeHalModuleBufferViewTrace(
+    void* user_data, iree_string_view_t key, iree_host_size_t buffer_view_count,
+    iree_hal_buffer_view_t** buffer_views, iree_allocator_t host_allocator) {
+  auto debug_sink = reinterpret_cast<HalModuleDebugSink*>(user_data);
+  std::vector<HalBufferView> buffer_views_vec =
+      CreateHalBufferViewVector(buffer_view_count, buffer_views);
+  try {
+    debug_sink->buffer_view_trace_callback_(std::string(key.data, key.size),
+                                            buffer_views_vec);
+  } catch (const py::python_error& e) {
+    return iree_make_status(IREE_STATUS_UNKNOWN, "%s", e.what());
+  }
+
+  return iree_ok_status();
+}
+
+static int HalModuleDebugSinkTpTraverse(PyObject* self, visitproc visit,
+                                        void* arg) {
+  // Inform Python's garbage collector about the references we hold.
+
+  // Retrieve a pointer to the C++ instance associated with 'self'
+  // (never fails)
+  HalModuleDebugSink* debug_sink = py::inst_ptr<HalModuleDebugSink>(self);
+
+  // Although we are not tracking cycles involving the HAL module or VM context
+  // we still want to properly destroy the callback and let the GC know what
+  // references we hold. If debug_sink->GetHalModuleBufferViewTraceCallback()
+  // has an associated CPython object, py::find returns it. If not, the handle's
+  // ptr() will equal NULL, which is also fine for Py_VISIT.
+  py::handle buffer_view_trace_callback =
+      py::find(debug_sink->GetHalModuleBufferViewTraceCallback());
+
+  // Inform the Python GC about the callback object (if any).
+  Py_VISIT(buffer_view_trace_callback.ptr());
+
+  return 0;
+}
+
+int HalModuleDebugSinkTpClear(PyObject* self) {
+  // Retrieve a pointer to the C++ instance associated with 'self'
+  // (never fails)
+  HalModuleDebugSink* debug_sink = py::inst_ptr<HalModuleDebugSink>(self);
+  debug_sink->GetHalModuleBufferViewTraceCallback() = nullptr;
+
+  return 0;
 }
 
 //------------------------------------------------------------------------------
@@ -1113,7 +1221,8 @@
 
   // Built-in module creation.
   m.def("create_hal_module", &CreateHalModule, py::arg("instance"),
-        py::arg("device") = py::none(), py::arg("devices") = py::none());
+        py::arg("device") = py::none(), py::arg("devices") = py::none(),
+        py::arg("debug_sink") = py::none());
 
   // Enums.
   py::enum_<enum iree_hal_memory_type_bits_t>(m, "MemoryType")
@@ -1777,6 +1886,27 @@
           py::arg("target_buffer"), py::arg("pattern"),
           py::arg("target_offset") = 0, py::arg("length") = py::none(),
           py::arg("end") = false);
+
+  PyType_Slot debug_sink_slots[] = {
+      {Py_tp_traverse, (void*)HalModuleDebugSinkTpTraverse},
+      {Py_tp_clear, (void*)HalModuleDebugSinkTpClear},
+      {0, nullptr}};
+  py::class_<HalModuleDebugSink>(
+      m, "HalModuleDebugSink", py::type_slots(debug_sink_slots),
+      py::intrusive_ptr<HalModuleDebugSink>(
+          [](HalModuleDebugSink* debug_sink, PyObject* po) noexcept {
+            debug_sink->set_self_py(po);
+          }))
+      .def(
+          "__init__",
+          [](HalModuleDebugSink* self,
+             HalModuleBufferViewTraceCallback buffer_view_trace_callback) {
+            new (self) HalModuleDebugSink(buffer_view_trace_callback);
+          },
+          py::arg("buffer_view_trace_callback"))
+      .def_prop_ro("buffer_view_trace_callback", [](HalModuleDebugSink& self) {
+        return self.GetHalModuleBufferViewTraceCallback();
+      });
 }
 
 }  // namespace python
diff --git a/runtime/bindings/python/hal.h b/runtime/bindings/python/hal.h
index 7dbc108..5ea9e70 100644
--- a/runtime/bindings/python/hal.h
+++ b/runtime/bindings/python/hal.h
@@ -7,6 +7,9 @@
 #ifndef IREE_BINDINGS_PYTHON_IREE_RT_HAL_H_
 #define IREE_BINDINGS_PYTHON_IREE_RT_HAL_H_
 
+#include <nanobind/intrusive/counter.h>
+
+#include <functional>
 #include <vector>
 
 #include "./binding.h"
@@ -14,6 +17,7 @@
 #include "./vm.h"
 #include "iree/base/string_view.h"
 #include "iree/hal/api.h"
+#include "iree/modules/hal/debugging.h"
 
 namespace iree {
 namespace python {
@@ -289,6 +293,33 @@
 class HalCommandBuffer
     : public ApiRefCounted<HalCommandBuffer, iree_hal_command_buffer_t> {};
 
+using HalModuleBufferViewTraceCallback =
+    std::function<void(const std::string&, const std::vector<HalBufferView>&)>;
+
+// HAL debug sinks need to live as long as the HAL module. This means the
+// underlying native object, not just the HAL module Python object.
+// This is necessary since here we hold a reference to a callback to a Python
+// function. This function needs to live after the destruction of the HAL module
+// Python object if it is registered into the VM context.
+// The HAL module and VM context Python objects are owners of the debug sink.
+class HalModuleDebugSink : public py::intrusive_base {
+ public:
+  HalModuleDebugSink(
+      HalModuleBufferViewTraceCallback buffer_view_trace_callback);
+  iree_hal_module_debug_sink_t AsIreeHalModuleDebugSink() const;
+  HalModuleBufferViewTraceCallback& GetHalModuleBufferViewTraceCallback();
+
+ private:
+  HalModuleBufferViewTraceCallback buffer_view_trace_callback_;
+
+  static iree_status_t DestroyCallback(void* user_data);
+
+  static iree_status_t IreeHalModuleBufferViewTrace(
+      void* user_data, iree_string_view_t key,
+      iree_host_size_t buffer_view_count, iree_hal_buffer_view_t** buffer_views,
+      iree_allocator_t host_allocator);
+};
+
 void SetupHalBindings(nanobind::module_ m);
 
 }  // namespace python
diff --git a/runtime/bindings/python/initialize_module.cc b/runtime/bindings/python/initialize_module.cc
index c79da46..48d0e29 100644
--- a/runtime/bindings/python/initialize_module.cc
+++ b/runtime/bindings/python/initialize_module.cc
@@ -4,7 +4,10 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+#include <nanobind/intrusive/counter.h>
+
 #include <memory>
+#include <nanobind/intrusive/counter.inl>
 
 #include "./binding.h"
 #include "./hal.h"
@@ -78,6 +81,16 @@
   });
 
   m.def("disable_leak_checker", []() { py::set_leak_warnings(false); });
+
+  py::intrusive_init(
+      [](PyObject *o) noexcept {
+        py::gil_scoped_acquire guard;
+        Py_INCREF(o);
+      },
+      [](PyObject *o) noexcept {
+        py::gil_scoped_acquire guard;
+        Py_DECREF(o);
+      });
 }
 
 }  // namespace python
diff --git a/runtime/bindings/python/iree/runtime/__init__.py b/runtime/bindings/python/iree/runtime/__init__.py
index 0163a6d..8cec280 100644
--- a/runtime/bindings/python/iree/runtime/__init__.py
+++ b/runtime/bindings/python/iree/runtime/__init__.py
@@ -56,6 +56,10 @@
     VmRef,
 )
 
+# Debug imports
+from ._binding import HalModuleDebugSink
+from .typing import HalModuleBufferViewTraceCallback
+
 from .array_interop import *
 from .benchmark import *
 from .system_api import *
diff --git a/runtime/bindings/python/iree/runtime/_binding.pyi b/runtime/bindings/python/iree/runtime/_binding.pyi
index b4ef2ba..499e0c6 100644
--- a/runtime/bindings/python/iree/runtime/_binding.pyi
+++ b/runtime/bindings/python/iree/runtime/_binding.pyi
@@ -1,13 +1,23 @@
-from typing import Any, Callable, ClassVar, List, Optional, Sequence, Tuple, Union
-
-from typing import overload
-
+from typing import (
+    Any,
+    Callable,
+    ClassVar,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+    overload,
+)
 import asyncio
 
+from .typing import HalModuleBufferViewTraceCallback
+
 def create_hal_module(
     instance: VmInstance,
     device: Optional[HalDevice] = None,
     devices: Optional[List[HalDevice]] = None,
+    debug_sink: Optional[HalModuleDebugSink] = None,
 ) -> VmModule: ...
 def create_io_parameters_module(
     instance: VmInstance, *providers: ParameterProvider
@@ -310,6 +320,36 @@
         deadline: Optional[int] = None,
     ) -> None: ...
 
+class HalModuleDebugSink:
+    def __init__(
+        self, buffer_view_trace_callback: Optional[HalModuleBufferViewTraceCallback]
+    ):
+        """The function object buffer_view_trace_callback must not include the
+        corresponding HAL VmModule, VmInstance or VmContext in its closure.
+        Native runtime objects are managed by reference counting and do not track
+        cyclic references. Capturing one of them creates an uncollectible cycle.
+        E.g.
+
+        ```
+        vm_context = ...
+        def callback(key, hal_buffer_view):
+            print(vm_context)
+        hal_module = iree.runtime.create_hal_module(
+            vm_instance,
+            device,
+            debug_sink=iree.runtime.HalModuleDebugSink(callback),
+        )
+        ```
+
+        This callback will cause the VM context to never be destroyed.
+        """
+
+        ...
+    @property
+    def buffer_view_trace_callback(
+        self,
+    ) -> Optional[HalModuleBufferViewTraceCallback]: ...
+
 class Linkage(int):
     EXPORT: ClassVar[Linkage] = ...
     IMPORT: ClassVar[Linkage] = ...
diff --git a/runtime/bindings/python/iree/runtime/typing.py b/runtime/bindings/python/iree/runtime/typing.py
new file mode 100644
index 0000000..738184a
--- /dev/null
+++ b/runtime/bindings/python/iree/runtime/typing.py
@@ -0,0 +1,25 @@
+# Copyright 2025 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+"""Type hints."""
+
+from typing import Callable, List, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from . import HalBufferView
+
+TraceKey = str
+HalModuleBufferViewTraceCallback = Callable[[TraceKey, List["HalBufferView"]], None]
+"""Tracing function for buffers to pass to the runtime.
+This allows custom behavior when executing an IREE module with tensor tracing
+instructions. MLIR e.g.
+
+```
+flow.tensor.trace "MyTensors" [
+    %tensor1 : tensor<1xf32>,
+    %tensor2 : tensor<2xf32>
+]
+```
+"""
diff --git a/runtime/bindings/python/tests/hal_test.py b/runtime/bindings/python/tests/hal_test.py
index 348a628..0a141a3 100644
--- a/runtime/bindings/python/tests/hal_test.py
+++ b/runtime/bindings/python/tests/hal_test.py
@@ -4,7 +4,9 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+from typing import List
 import iree.runtime
+import iree.compiler
 
 import gc
 import numpy as np
@@ -633,5 +635,147 @@
         )
 
 
+class HalModuleDebugSinkTest(unittest.TestCase):
+    COMPILED_TRACE_TENSOR: bytes
+
+    @classmethod
+    def compile_trace_tensor(cls):
+        if not hasattr(cls, "COMPILED_TRACE_TENSOR"):
+            cls.COMPILED_TRACE_TENSOR = iree.compiler.compile_str(
+                """
+                func.func @trace_args(%arg0: tensor<2xi32>, %arg1: tensor<3xi32>) {
+                    flow.tensor.trace "debug_sink_test" = [
+                        %arg0: tensor<2xi32>,
+                        %arg1: tensor<3xi32>
+                    ]
+                    return
+                }
+                """,
+                target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS,
+            )
+        return cls.COMPILED_TRACE_TENSOR
+
+    def testHalModuleBufferViewTraceCallback(self):
+        """Check that the trace tensor callback gets called with the expected
+        arguments."""
+        program_bytes = HalModuleDebugSinkTest.compile_trace_tensor()
+
+        arg0 = np.array([1, 2], dtype=np.int32)
+        arg1 = np.array([3, 4, 5], dtype=np.int32)
+
+        callback_key: str = None
+        callback_buffer_views = None
+
+        def callback(key: str, buffer_views: List[iree.runtime.HalBufferView]):
+            nonlocal callback_key
+            callback_key = key
+            nonlocal callback_buffer_views
+            callback_buffer_views = buffer_views
+
+        instance = iree.runtime.VmInstance()
+        device = iree.runtime.get_device(iree.compiler.core.DEFAULT_TESTING_DRIVER)
+        hal_module = iree.runtime.create_hal_module(
+            instance, device, debug_sink=iree.runtime.HalModuleDebugSink(callback)
+        )
+        program_module = iree.runtime.VmModule.copy_buffer(instance, program_bytes)
+        context = iree.runtime.VmContext(instance)
+        context.register_modules([hal_module, program_module])
+        fn = program_module.lookup_function("trace_args")
+        fn_invoker = iree.runtime.FunctionInvoker(context, device, fn)
+        fn_invoker(arg0, arg1)
+
+        assert callback_key == "debug_sink_test"
+        assert len(callback_buffer_views) == 2
+        actual_arg0 = iree.runtime.DeviceArray(
+            device, callback_buffer_views[0]
+        ).to_host()
+        actual_arg1 = iree.runtime.DeviceArray(
+            device, callback_buffer_views[1]
+        ).to_host()
+        np.testing.assert_equal(actual_arg0, arg0)
+        np.testing.assert_equal(actual_arg1, arg1)
+
+    def testNoneHalModuleDebugSink(self):
+        device = iree.runtime.get_device(iree.compiler.core.DEFAULT_TESTING_DRIVER)
+        instance = iree.runtime.VmInstance()
+        hal_module = iree.runtime.create_hal_module(
+            instance,
+            device,
+            debug_sink=None,
+        )
+
+    def testExceptionInHalModuleBufferViewTraceCallback(self):
+        """When an exception occurs in the callback check that it properly propagates
+        through the bindings and results in a failed IREE module function invocation.
+        """
+        program_bytes = HalModuleDebugSinkTest.compile_trace_tensor()
+
+        arg0 = np.array([1, 2], dtype=np.int32)
+        arg1 = np.array([3, 4, 5], dtype=np.int32)
+
+        device = iree.runtime.get_device(iree.compiler.core.DEFAULT_TESTING_DRIVER)
+
+        class TestException(Exception):
+            def __init__(self, msg: str):
+                super().__init__(msg)
+
+        def callback(key: str, buffer_views: List[iree.runtime.HalBufferView]):
+            raise TestException("This is a test exception")
+
+        instance = iree.runtime.VmInstance()
+        hal_module = iree.runtime.create_hal_module(
+            instance, device, debug_sink=iree.runtime.HalModuleDebugSink(callback)
+        )
+        program_module = iree.runtime.VmModule.copy_buffer(instance, program_bytes)
+        context = iree.runtime.VmContext(instance)
+        context.register_modules([hal_module, program_module])
+        fn = program_module.lookup_function("trace_args")
+        fn_invoker = iree.runtime.FunctionInvoker(context, device, fn)
+        # TODO: once IREE status chains messages test for the actual message we raise
+        # within the callback.
+        self.assertRaisesRegex(RuntimeError, "UNKNOWN", fn_invoker, arg0, arg1)
+
+    def testHalModuleBufferViewTraceCallbackReferencingItselfDoesNotLeak(self):
+        """Check that if we do not hold reference to the HAL module or VM context,
+        but we hold a reference to the debug sink in the callback, the callback object
+        does not leak.
+        """
+        is_callback_destroyed: bool = False
+
+        class Callback:
+            def __del__(self):
+                nonlocal is_callback_destroyed
+                is_callback_destroyed = True
+
+            def __call__(
+                self, key: str, buffer_views: List[iree.runtime.HalBufferView]
+            ):
+                pass
+
+        callback = Callback()
+        debug_sink = iree.runtime.HalModuleDebugSink(callback)
+        setattr(callback, "debug_sink", debug_sink)
+
+        device = iree.runtime.get_device(iree.compiler.core.DEFAULT_TESTING_DRIVER)
+
+        vm_instance = iree.runtime.VmInstance()
+        hal_module = iree.runtime.create_hal_module(
+            vm_instance,
+            device,
+            debug_sink=debug_sink,
+        )
+        vm_context = iree.runtime.VmContext(vm_instance)
+        vm_context.register_modules([hal_module])
+        assert not is_callback_destroyed
+
+        del callback
+        del debug_sink
+        del hal_module
+        del vm_instance
+        del vm_context
+        gc.collect()
+        assert is_callback_destroyed
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/runtime/bindings/python/vm.cc b/runtime/bindings/python/vm.cc
index 0be8912..072e4fc 100644
--- a/runtime/bindings/python/vm.cc
+++ b/runtime/bindings/python/vm.cc
@@ -20,6 +20,7 @@
 #include "iree/modules/hal/module.h"
 #include "iree/tooling/modules/resolver.h"
 #include "iree/vm/api.h"
+#include "nanobind/nanobind.h"
 
 using namespace nanobind::literals;
 
diff --git a/runtime/src/iree/modules/check/check_test.cc b/runtime/src/iree/modules/check/check_test.cc
index d5e1c7b..ea1dab0 100644
--- a/runtime/src/iree/modules/check/check_test.cc
+++ b/runtime/src/iree/modules/check/check_test.cc
@@ -208,6 +208,9 @@
     return Invoke(function_name);
   }
 
+  static iree_hal_device_t*& device() { return CheckTest::device_; }
+  static iree_vm_instance_t*& instance() { return CheckTest::instance_; }
+
  private:
   static iree_hal_device_t* device_;
   static iree_vm_instance_t* instance_;
@@ -223,6 +226,27 @@
 iree_vm_module_t* CheckTest::check_module_ = nullptr;
 iree_vm_module_t* CheckTest::hal_module_ = nullptr;
 
+TEST_F(CheckTest, HalModuleDebugSinkDestroyCallbackIsCalled) {
+  struct UserData {
+    bool is_callback_called = false;
+  };
+
+  iree_hal_module_debug_sink_t sink = {};
+  sink.destroy.fn = [](void* user_data) {
+    reinterpret_cast<UserData*>(user_data)->is_callback_called = true;
+    return iree_ok_status();
+  };
+  UserData user_data;
+  sink.destroy.user_data = &user_data;
+  iree_vm_module_t* hal_module;
+  IREE_ASSERT_OK(iree_hal_module_create(
+      instance(), /*device_count=*/1, &device(), IREE_HAL_MODULE_FLAG_NONE,
+      sink, iree_allocator_system(), &hal_module));
+  IREE_ASSERT_FALSE(user_data.is_callback_called);
+  iree_vm_module_release(hal_module);
+  IREE_ASSERT_TRUE(user_data.is_callback_called);
+}
+
 TEST_F(CheckTest, ExpectTrueSuccess) {
   IREE_ASSERT_OK(InvokeValue("expect_true", {iree_vm_value_make_i32(1)}));
 }
diff --git a/runtime/src/iree/modules/hal/debugging.h b/runtime/src/iree/modules/hal/debugging.h
index 7500f2c..37a2857 100644
--- a/runtime/src/iree/modules/hal/debugging.h
+++ b/runtime/src/iree/modules/hal/debugging.h
@@ -30,12 +30,23 @@
   void* user_data;
 } iree_hal_module_buffer_view_trace_callback_t;
 
+typedef iree_status_t(
+    IREE_API_PTR* iree_hal_module_debug_sink_destroy_callback_fn_t)(
+    void* user_data);
+
+// Called by the runtime when the HAL module no longer needs the debug sink.
+typedef struct iree_hal_module_debug_sink_destroy_callback_t {
+  iree_hal_module_debug_sink_destroy_callback_fn_t fn;
+  void* user_data;
+} iree_hal_module_debug_sink_destroy_callback_t;
+
 // Interface for a HAL module debug event sink.
 // Any referenced user data must remain live for the lifetime of the HAL module
 // the sink is provided to.
 typedef struct iree_hal_module_debug_sink_t {
   // Called on each hal.buffer_view.trace.
   iree_hal_module_buffer_view_trace_callback_t buffer_view_trace;
+  iree_hal_module_debug_sink_destroy_callback_t destroy;
 } iree_hal_module_debug_sink_t;
 
 // Returns a default debug sink that outputs nothing.
diff --git a/runtime/src/iree/modules/hal/module.c b/runtime/src/iree/modules/hal/module.c
index 24a11db..ae5f9e1 100644
--- a/runtime/src/iree/modules/hal/module.c
+++ b/runtime/src/iree/modules/hal/module.c
@@ -48,6 +48,11 @@
 
 static void IREE_API_PTR iree_hal_module_destroy(void* base_module) {
   iree_hal_module_t* module = IREE_HAL_MODULE_CAST(base_module);
+
+  if (module->debug_sink.destroy.fn) {
+    module->debug_sink.destroy.fn(module->debug_sink.destroy.user_data);
+  }
+
   for (iree_host_size_t i = 0; i < module->device_count; ++i) {
     iree_hal_device_release(module->devices[i]);
   }