Adding iree_hal_device_profiling_flush. (#14829)

This allows HAL devices to be poked to flush profiling data. The Vulkan
HAL now uses this signal to flush tracing events and
iree-benchmark-module has been updated to send the signal.

Fixes #14827.
diff --git a/experimental/cuda2/cuda_device.c b/experimental/cuda2/cuda_device.c
index 91ef618..034feae 100644
--- a/experimental/cuda2/cuda_device.c
+++ b/experimental/cuda2/cuda_device.c
@@ -762,6 +762,12 @@
   return iree_ok_status();
 }
 
+static iree_status_t iree_hal_cuda2_device_profiling_flush(
+    iree_hal_device_t* base_device) {
+  // Unimplemented (and that's ok).
+  return iree_ok_status();
+}
+
 static iree_status_t iree_hal_cuda2_device_profiling_end(
     iree_hal_device_t* base_device) {
   // Unimplemented (and that's ok).
@@ -797,5 +803,6 @@
     .queue_flush = iree_hal_cuda2_device_queue_flush,
     .wait_semaphores = iree_hal_cuda2_device_wait_semaphores,
     .profiling_begin = iree_hal_cuda2_device_profiling_begin,
+    .profiling_flush = iree_hal_cuda2_device_profiling_flush,
     .profiling_end = iree_hal_cuda2_device_profiling_end,
 };
diff --git a/experimental/rocm/rocm_device.c b/experimental/rocm/rocm_device.c
index e8c8b45..68c1053 100644
--- a/experimental/rocm/rocm_device.c
+++ b/experimental/rocm/rocm_device.c
@@ -397,6 +397,12 @@
   return iree_ok_status();
 }
 
+static iree_status_t iree_hal_rocm_device_profiling_flush(
+    iree_hal_device_t* base_device) {
+  // Unimplemented (and that's ok).
+  return iree_ok_status();
+}
+
 static iree_status_t iree_hal_rocm_device_profiling_end(
     iree_hal_device_t* base_device) {
   // Unimplemented (and that's ok).
@@ -432,5 +438,6 @@
     .queue_flush = iree_hal_rocm_device_queue_flush,
     .wait_semaphores = iree_hal_rocm_device_wait_semaphores,
     .profiling_begin = iree_hal_rocm_device_profiling_begin,
+    .profiling_flush = iree_hal_rocm_device_profiling_flush,
     .profiling_end = iree_hal_rocm_device_profiling_end,
 };
diff --git a/experimental/webgpu/webgpu_device.c b/experimental/webgpu/webgpu_device.c
index 435aaa6..1002bba 100644
--- a/experimental/webgpu/webgpu_device.c
+++ b/experimental/webgpu/webgpu_device.c
@@ -430,14 +430,20 @@
 }
 
 static iree_status_t iree_hal_webgpu_device_profiling_begin(
-    iree_hal_device_t* device,
+    iree_hal_device_t* base_device,
     const iree_hal_device_profiling_options_t* options) {
   // Unimplemented (and that's ok).
   return iree_ok_status();
 }
 
+static iree_status_t iree_hal_webgpu_device_profiling_flush(
+    iree_hal_device_t* base_device) {
+  // Unimplemented (and that's ok).
+  return iree_ok_status();
+}
+
 static iree_status_t iree_hal_webgpu_device_profiling_end(
-    iree_hal_device_t* device) {
+    iree_hal_device_t* base_device) {
   // Unimplemented (and that's ok).
   return iree_ok_status();
 }
@@ -468,5 +474,6 @@
     .queue_flush = iree_hal_webgpu_device_queue_flush,
     .wait_semaphores = iree_hal_webgpu_device_wait_semaphores,
     .profiling_begin = iree_hal_webgpu_device_profiling_begin,
+    .profiling_flush = iree_hal_webgpu_device_profiling_flush,
     .profiling_end = iree_hal_webgpu_device_profiling_end,
 };
diff --git a/runtime/src/iree/hal/device.c b/runtime/src/iree/hal/device.c
index efd2252..6a61ebd 100644
--- a/runtime/src/iree/hal/device.c
+++ b/runtime/src/iree/hal/device.c
@@ -398,6 +398,15 @@
 }
 
 IREE_API_EXPORT iree_status_t
+iree_hal_device_profiling_flush(iree_hal_device_t* device) {
+  IREE_ASSERT_ARGUMENT(device);
+  IREE_TRACE_ZONE_BEGIN(z0);
+  iree_status_t status = _VTABLE_DISPATCH(device, profiling_flush)(device);
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+IREE_API_EXPORT iree_status_t
 iree_hal_device_profiling_end(iree_hal_device_t* device) {
   IREE_ASSERT_ARGUMENT(device);
   IREE_TRACE_ZONE_BEGIN(z0);
diff --git a/runtime/src/iree/hal/device.h b/runtime/src/iree/hal/device.h
index 0a6f790..679a113 100644
--- a/runtime/src/iree/hal/device.h
+++ b/runtime/src/iree/hal/device.h
@@ -535,6 +535,10 @@
     iree_hal_device_t* device,
     const iree_hal_device_profiling_options_t* options);
 
+// Flushes any pending profiling data. May be a no-op.
+IREE_API_EXPORT iree_status_t
+iree_hal_device_profiling_flush(iree_hal_device_t* device);
+
 // Ends a profile previous started with iree_hal_device_profiling_begin.
 // The device must be idle before calling this method.
 IREE_API_EXPORT iree_status_t
@@ -662,6 +666,7 @@
   iree_status_t(IREE_API_PTR* profiling_begin)(
       iree_hal_device_t* device,
       const iree_hal_device_profiling_options_t* options);
+  iree_status_t(IREE_API_PTR* profiling_flush)(iree_hal_device_t* device);
   iree_status_t(IREE_API_PTR* profiling_end)(iree_hal_device_t* device);
 } iree_hal_device_vtable_t;
 IREE_HAL_ASSERT_VTABLE_LAYOUT(iree_hal_device_vtable_t);
diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_device.c b/runtime/src/iree/hal/drivers/cuda/cuda_device.c
index dc8fc25..136d729 100644
--- a/runtime/src/iree/hal/drivers/cuda/cuda_device.c
+++ b/runtime/src/iree/hal/drivers/cuda/cuda_device.c
@@ -720,6 +720,12 @@
   return iree_ok_status();
 }
 
+static iree_status_t iree_hal_cuda_device_profiling_flush(
+    iree_hal_device_t* base_device) {
+  // Unimplemented (and that's ok).
+  return iree_ok_status();
+}
+
 static iree_status_t iree_hal_cuda_device_profiling_end(
     iree_hal_device_t* base_device) {
   // Unimplemented (and that's ok).
@@ -755,5 +761,6 @@
     .queue_flush = iree_hal_cuda_device_queue_flush,
     .wait_semaphores = iree_hal_cuda_device_wait_semaphores,
     .profiling_begin = iree_hal_cuda_device_profiling_begin,
+    .profiling_flush = iree_hal_cuda_device_profiling_flush,
     .profiling_end = iree_hal_cuda_device_profiling_end,
 };
diff --git a/runtime/src/iree/hal/drivers/local_sync/sync_device.c b/runtime/src/iree/hal/drivers/local_sync/sync_device.c
index 7b4a823..b6007a1 100644
--- a/runtime/src/iree/hal/drivers/local_sync/sync_device.c
+++ b/runtime/src/iree/hal/drivers/local_sync/sync_device.c
@@ -464,7 +464,7 @@
 }
 
 static iree_status_t iree_hal_sync_device_profiling_begin(
-    iree_hal_device_t* device,
+    iree_hal_device_t* base_device,
     const iree_hal_device_profiling_options_t* options) {
   // Unimplemented (and that's ok).
   // We could hook in to vendor APIs (Intel/ARM/etc) or generic perf infra:
@@ -477,8 +477,14 @@
   return iree_ok_status();
 }
 
+static iree_status_t iree_hal_sync_device_profiling_flush(
+    iree_hal_device_t* base_device) {
+  // Unimplemented (and that's ok).
+  return iree_ok_status();
+}
+
 static iree_status_t iree_hal_sync_device_profiling_end(
-    iree_hal_device_t* device) {
+    iree_hal_device_t* base_device) {
   // Unimplemented (and that's ok).
   return iree_ok_status();
 }
@@ -512,5 +518,6 @@
     .queue_flush = iree_hal_sync_device_queue_flush,
     .wait_semaphores = iree_hal_sync_device_wait_semaphores,
     .profiling_begin = iree_hal_sync_device_profiling_begin,
+    .profiling_flush = iree_hal_sync_device_profiling_flush,
     .profiling_end = iree_hal_sync_device_profiling_end,
 };
diff --git a/runtime/src/iree/hal/drivers/local_task/task_device.c b/runtime/src/iree/hal/drivers/local_task/task_device.c
index 99f158b..3b629f4 100644
--- a/runtime/src/iree/hal/drivers/local_task/task_device.c
+++ b/runtime/src/iree/hal/drivers/local_task/task_device.c
@@ -484,7 +484,7 @@
 }
 
 static iree_status_t iree_hal_task_device_profiling_begin(
-    iree_hal_device_t* device,
+    iree_hal_device_t* base_device,
     const iree_hal_device_profiling_options_t* options) {
   // Unimplemented (and that's ok).
   // We could hook in to vendor APIs (Intel/ARM/etc) or generic perf infra:
@@ -497,8 +497,14 @@
   return iree_ok_status();
 }
 
+static iree_status_t iree_hal_task_device_profiling_flush(
+    iree_hal_device_t* base_device) {
+  // Unimplemented (and that's ok).
+  return iree_ok_status();
+}
+
 static iree_status_t iree_hal_task_device_profiling_end(
-    iree_hal_device_t* device) {
+    iree_hal_device_t* base_device) {
   // Unimplemented (and that's ok).
   return iree_ok_status();
 }
@@ -532,5 +538,6 @@
     .queue_flush = iree_hal_task_device_queue_flush,
     .wait_semaphores = iree_hal_task_device_wait_semaphores,
     .profiling_begin = iree_hal_task_device_profiling_begin,
+    .profiling_flush = iree_hal_task_device_profiling_flush,
     .profiling_end = iree_hal_task_device_profiling_end,
 };
diff --git a/runtime/src/iree/hal/drivers/metal/metal_device.m b/runtime/src/iree/hal/drivers/metal/metal_device.m
index 9d47a79..8cc963f 100644
--- a/runtime/src/iree/hal/drivers/metal/metal_device.m
+++ b/runtime/src/iree/hal/drivers/metal/metal_device.m
@@ -523,6 +523,10 @@
   return iree_ok_status();
 }
 
+static iree_status_t iree_hal_metal_device_profiling_flush(iree_hal_device_t* base_device) {
+  return iree_ok_status();
+}
+
 static iree_status_t iree_hal_metal_device_profiling_end(iree_hal_device_t* base_device) {
   iree_hal_metal_device_t* device = iree_hal_metal_device_cast(base_device);
   if (device->capture_manager) {
@@ -559,5 +563,6 @@
     .queue_flush = iree_hal_metal_device_queue_flush,
     .wait_semaphores = iree_hal_metal_device_wait_semaphores,
     .profiling_begin = iree_hal_metal_device_profiling_begin,
+    .profiling_flush = iree_hal_metal_device_profiling_flush,
     .profiling_end = iree_hal_metal_device_profiling_end,
 };
diff --git a/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc b/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc
index 77cf830..d93b259 100644
--- a/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc
+++ b/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc
@@ -1489,6 +1489,28 @@
   return iree_ok_status();
 }
 
+static iree_status_t iree_hal_vulkan_device_profiling_flush(
+    iree_hal_device_t* base_device) {
+  iree_hal_vulkan_device_t* device = iree_hal_vulkan_device_cast(base_device);
+  (void)device;
+
+#if IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION_DEVICE
+  if (iree_all_bits_set(device->logical_device->enabled_features(),
+                        IREE_HAL_VULKAN_FEATURE_ENABLE_TRACING)) {
+    for (iree_host_size_t i = 0; i < device->queue_count; ++i) {
+      iree_hal_vulkan_tracing_context_t* tracing_context =
+          device->queues[i]->tracing_context();
+      if (tracing_context) {
+        iree_hal_vulkan_tracing_context_collect(tracing_context,
+                                                VK_NULL_HANDLE);
+      }
+    }
+  }
+#endif  // IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION_DEVICE
+
+  return iree_ok_status();
+}
+
 static iree_status_t iree_hal_vulkan_device_profiling_end(
     iree_hal_device_t* base_device) {
   iree_hal_vulkan_device_t* device = iree_hal_vulkan_device_cast(base_device);
@@ -1545,6 +1567,7 @@
     /*.queue_flush=*/iree_hal_vulkan_device_queue_flush,
     /*.wait_semaphores=*/iree_hal_vulkan_device_wait_semaphores,
     /*.profiling_begin=*/iree_hal_vulkan_device_profiling_begin,
+    /*.profiling_flush=*/iree_hal_vulkan_device_profiling_flush,
     /*.profiling_end=*/iree_hal_vulkan_device_profiling_end,
 };
 }  // namespace
diff --git a/tools/iree-benchmark-module-main.cc b/tools/iree-benchmark-module-main.cc
index c658ce0..f69ef49 100644
--- a/tools/iree-benchmark-module-main.cc
+++ b/tools/iree-benchmark-module-main.cc
@@ -158,6 +158,7 @@
 
 static void BenchmarkGenericFunction(const std::string& benchmark_name,
                                      int32_t batch_size,
+                                     iree_hal_device_t* device,
                                      iree_vm_context_t* context,
                                      iree_vm_function_t function,
                                      iree_vm_list_t* inputs,
@@ -179,6 +180,11 @@
         inputs, outputs.get(), iree_allocator_system()));
     IREE_CHECK_OK(iree_vm_list_resize(outputs.get(), 0));
     IREE_TRACE_ZONE_END(z1);
+    if (device) {
+      state.PauseTiming();
+      IREE_CHECK_OK(iree_hal_device_profiling_flush(device));
+      state.ResumeTiming();
+    }
   }
   state.SetItemsProcessed(state.iterations());
 
@@ -186,6 +192,7 @@
 }
 
 void RegisterGenericBenchmark(const std::string& function_name,
+                              iree_hal_device_t* device,
                               iree_vm_context_t* context,
                               iree_vm_function_t function,
                               iree_vm_list_t* inputs) {
@@ -194,8 +201,8 @@
   benchmark::RegisterBenchmark(benchmark_name.c_str(),
                                [=](benchmark::State& state) -> void {
                                  BenchmarkGenericFunction(
-                                     benchmark_name, batch_size, context,
-                                     function, inputs, state);
+                                     benchmark_name, batch_size, device,
+                                     context, function, inputs, state);
                                })
       // By default only the main thread is included in CPU time. Include all
       // the threads instead.
@@ -328,6 +335,9 @@
     IREE_TRACE_ZONE_END(z_end);
 
     IREE_TRACE_ZONE_END(z1);
+    if (device) {
+      IREE_CHECK_OK(iree_hal_device_profiling_flush(device));
+    }
     state.ResumeTiming();
   }
   state.SetItemsProcessed(state.iterations());
@@ -502,8 +512,8 @@
                                    function, inputs_.get());
     } else {
       // Synchronous invocation.
-      iree::RegisterGenericBenchmark(function_name, context_.get(), function,
-                                     inputs_.get());
+      iree::RegisterGenericBenchmark(function_name, device_.get(),
+                                     context_.get(), function, inputs_.get());
     }
     return iree_ok_status();
   }
@@ -530,8 +540,8 @@
             function);
       } else if (iree_string_view_equal(benchmark_type, IREE_SV("entry"))) {
         iree::RegisterGenericBenchmark(
-            std::string(function_name.data, function_name.size), context_.get(),
-            function,
+            std::string(function_name.data, function_name.size), device_.get(),
+            context_.get(), function,
             /*inputs=*/nullptr);
       } else {
         // Pick up generic () -> () functions.
@@ -570,7 +580,7 @@
             // anything).
             iree::RegisterGenericBenchmark(
                 std::string(function_name.data, function_name.size),
-                context_.get(), function,
+                device_.get(), context_.get(), function,
                 /*inputs=*/nullptr);
           }
         }