Adding iree_hal_device_profiling_flush. (#14829)
This allows HAL devices to be poked to flush profiling data. The Vulkan
HAL now uses this signal to flush tracing events and
iree-benchmark-module has been updated to send the signal.
Fixes #14827.
diff --git a/experimental/cuda2/cuda_device.c b/experimental/cuda2/cuda_device.c
index 91ef618..034feae 100644
--- a/experimental/cuda2/cuda_device.c
+++ b/experimental/cuda2/cuda_device.c
@@ -762,6 +762,12 @@
return iree_ok_status();
}
+static iree_status_t iree_hal_cuda2_device_profiling_flush(
+ iree_hal_device_t* base_device) {
+ // Unimplemented (and that's ok).
+ return iree_ok_status();
+}
+
static iree_status_t iree_hal_cuda2_device_profiling_end(
iree_hal_device_t* base_device) {
// Unimplemented (and that's ok).
@@ -797,5 +803,6 @@
.queue_flush = iree_hal_cuda2_device_queue_flush,
.wait_semaphores = iree_hal_cuda2_device_wait_semaphores,
.profiling_begin = iree_hal_cuda2_device_profiling_begin,
+ .profiling_flush = iree_hal_cuda2_device_profiling_flush,
.profiling_end = iree_hal_cuda2_device_profiling_end,
};
diff --git a/experimental/rocm/rocm_device.c b/experimental/rocm/rocm_device.c
index e8c8b45..68c1053 100644
--- a/experimental/rocm/rocm_device.c
+++ b/experimental/rocm/rocm_device.c
@@ -397,6 +397,12 @@
return iree_ok_status();
}
+static iree_status_t iree_hal_rocm_device_profiling_flush(
+ iree_hal_device_t* base_device) {
+ // Unimplemented (and that's ok).
+ return iree_ok_status();
+}
+
static iree_status_t iree_hal_rocm_device_profiling_end(
iree_hal_device_t* base_device) {
// Unimplemented (and that's ok).
@@ -432,5 +438,6 @@
.queue_flush = iree_hal_rocm_device_queue_flush,
.wait_semaphores = iree_hal_rocm_device_wait_semaphores,
.profiling_begin = iree_hal_rocm_device_profiling_begin,
+ .profiling_flush = iree_hal_rocm_device_profiling_flush,
.profiling_end = iree_hal_rocm_device_profiling_end,
};
diff --git a/experimental/webgpu/webgpu_device.c b/experimental/webgpu/webgpu_device.c
index 435aaa6..1002bba 100644
--- a/experimental/webgpu/webgpu_device.c
+++ b/experimental/webgpu/webgpu_device.c
@@ -430,14 +430,20 @@
}
static iree_status_t iree_hal_webgpu_device_profiling_begin(
- iree_hal_device_t* device,
+ iree_hal_device_t* base_device,
const iree_hal_device_profiling_options_t* options) {
// Unimplemented (and that's ok).
return iree_ok_status();
}
+static iree_status_t iree_hal_webgpu_device_profiling_flush(
+ iree_hal_device_t* base_device) {
+ // Unimplemented (and that's ok).
+ return iree_ok_status();
+}
+
static iree_status_t iree_hal_webgpu_device_profiling_end(
- iree_hal_device_t* device) {
+ iree_hal_device_t* base_device) {
// Unimplemented (and that's ok).
return iree_ok_status();
}
@@ -468,5 +474,6 @@
.queue_flush = iree_hal_webgpu_device_queue_flush,
.wait_semaphores = iree_hal_webgpu_device_wait_semaphores,
.profiling_begin = iree_hal_webgpu_device_profiling_begin,
+ .profiling_flush = iree_hal_webgpu_device_profiling_flush,
.profiling_end = iree_hal_webgpu_device_profiling_end,
};
diff --git a/runtime/src/iree/hal/device.c b/runtime/src/iree/hal/device.c
index efd2252..6a61ebd 100644
--- a/runtime/src/iree/hal/device.c
+++ b/runtime/src/iree/hal/device.c
@@ -398,6 +398,15 @@
}
IREE_API_EXPORT iree_status_t
+iree_hal_device_profiling_flush(iree_hal_device_t* device) {
+ IREE_ASSERT_ARGUMENT(device);
+ IREE_TRACE_ZONE_BEGIN(z0);
+ iree_status_t status = _VTABLE_DISPATCH(device, profiling_flush)(device);
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+IREE_API_EXPORT iree_status_t
iree_hal_device_profiling_end(iree_hal_device_t* device) {
IREE_ASSERT_ARGUMENT(device);
IREE_TRACE_ZONE_BEGIN(z0);
diff --git a/runtime/src/iree/hal/device.h b/runtime/src/iree/hal/device.h
index 0a6f790..679a113 100644
--- a/runtime/src/iree/hal/device.h
+++ b/runtime/src/iree/hal/device.h
@@ -535,6 +535,10 @@
iree_hal_device_t* device,
const iree_hal_device_profiling_options_t* options);
+// Flushes any pending profiling data. May be a no-op.
+IREE_API_EXPORT iree_status_t
+iree_hal_device_profiling_flush(iree_hal_device_t* device);
+
// Ends a profile previous started with iree_hal_device_profiling_begin.
// The device must be idle before calling this method.
IREE_API_EXPORT iree_status_t
@@ -662,6 +666,7 @@
iree_status_t(IREE_API_PTR* profiling_begin)(
iree_hal_device_t* device,
const iree_hal_device_profiling_options_t* options);
+ iree_status_t(IREE_API_PTR* profiling_flush)(iree_hal_device_t* device);
iree_status_t(IREE_API_PTR* profiling_end)(iree_hal_device_t* device);
} iree_hal_device_vtable_t;
IREE_HAL_ASSERT_VTABLE_LAYOUT(iree_hal_device_vtable_t);
diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_device.c b/runtime/src/iree/hal/drivers/cuda/cuda_device.c
index dc8fc25..136d729 100644
--- a/runtime/src/iree/hal/drivers/cuda/cuda_device.c
+++ b/runtime/src/iree/hal/drivers/cuda/cuda_device.c
@@ -720,6 +720,12 @@
return iree_ok_status();
}
+static iree_status_t iree_hal_cuda_device_profiling_flush(
+ iree_hal_device_t* base_device) {
+ // Unimplemented (and that's ok).
+ return iree_ok_status();
+}
+
static iree_status_t iree_hal_cuda_device_profiling_end(
iree_hal_device_t* base_device) {
// Unimplemented (and that's ok).
@@ -755,5 +761,6 @@
.queue_flush = iree_hal_cuda_device_queue_flush,
.wait_semaphores = iree_hal_cuda_device_wait_semaphores,
.profiling_begin = iree_hal_cuda_device_profiling_begin,
+ .profiling_flush = iree_hal_cuda_device_profiling_flush,
.profiling_end = iree_hal_cuda_device_profiling_end,
};
diff --git a/runtime/src/iree/hal/drivers/local_sync/sync_device.c b/runtime/src/iree/hal/drivers/local_sync/sync_device.c
index 7b4a823..b6007a1 100644
--- a/runtime/src/iree/hal/drivers/local_sync/sync_device.c
+++ b/runtime/src/iree/hal/drivers/local_sync/sync_device.c
@@ -464,7 +464,7 @@
}
static iree_status_t iree_hal_sync_device_profiling_begin(
- iree_hal_device_t* device,
+ iree_hal_device_t* base_device,
const iree_hal_device_profiling_options_t* options) {
// Unimplemented (and that's ok).
// We could hook in to vendor APIs (Intel/ARM/etc) or generic perf infra:
@@ -477,8 +477,14 @@
return iree_ok_status();
}
+static iree_status_t iree_hal_sync_device_profiling_flush(
+ iree_hal_device_t* base_device) {
+ // Unimplemented (and that's ok).
+ return iree_ok_status();
+}
+
static iree_status_t iree_hal_sync_device_profiling_end(
- iree_hal_device_t* device) {
+ iree_hal_device_t* base_device) {
// Unimplemented (and that's ok).
return iree_ok_status();
}
@@ -512,5 +518,6 @@
.queue_flush = iree_hal_sync_device_queue_flush,
.wait_semaphores = iree_hal_sync_device_wait_semaphores,
.profiling_begin = iree_hal_sync_device_profiling_begin,
+ .profiling_flush = iree_hal_sync_device_profiling_flush,
.profiling_end = iree_hal_sync_device_profiling_end,
};
diff --git a/runtime/src/iree/hal/drivers/local_task/task_device.c b/runtime/src/iree/hal/drivers/local_task/task_device.c
index 99f158b..3b629f4 100644
--- a/runtime/src/iree/hal/drivers/local_task/task_device.c
+++ b/runtime/src/iree/hal/drivers/local_task/task_device.c
@@ -484,7 +484,7 @@
}
static iree_status_t iree_hal_task_device_profiling_begin(
- iree_hal_device_t* device,
+ iree_hal_device_t* base_device,
const iree_hal_device_profiling_options_t* options) {
// Unimplemented (and that's ok).
// We could hook in to vendor APIs (Intel/ARM/etc) or generic perf infra:
@@ -497,8 +497,14 @@
return iree_ok_status();
}
+static iree_status_t iree_hal_task_device_profiling_flush(
+ iree_hal_device_t* base_device) {
+ // Unimplemented (and that's ok).
+ return iree_ok_status();
+}
+
static iree_status_t iree_hal_task_device_profiling_end(
- iree_hal_device_t* device) {
+ iree_hal_device_t* base_device) {
// Unimplemented (and that's ok).
return iree_ok_status();
}
@@ -532,5 +538,6 @@
.queue_flush = iree_hal_task_device_queue_flush,
.wait_semaphores = iree_hal_task_device_wait_semaphores,
.profiling_begin = iree_hal_task_device_profiling_begin,
+ .profiling_flush = iree_hal_task_device_profiling_flush,
.profiling_end = iree_hal_task_device_profiling_end,
};
diff --git a/runtime/src/iree/hal/drivers/metal/metal_device.m b/runtime/src/iree/hal/drivers/metal/metal_device.m
index 9d47a79..8cc963f 100644
--- a/runtime/src/iree/hal/drivers/metal/metal_device.m
+++ b/runtime/src/iree/hal/drivers/metal/metal_device.m
@@ -523,6 +523,10 @@
return iree_ok_status();
}
+static iree_status_t iree_hal_metal_device_profiling_flush(iree_hal_device_t* base_device) {
+ return iree_ok_status();
+}
+
static iree_status_t iree_hal_metal_device_profiling_end(iree_hal_device_t* base_device) {
iree_hal_metal_device_t* device = iree_hal_metal_device_cast(base_device);
if (device->capture_manager) {
@@ -559,5 +563,6 @@
.queue_flush = iree_hal_metal_device_queue_flush,
.wait_semaphores = iree_hal_metal_device_wait_semaphores,
.profiling_begin = iree_hal_metal_device_profiling_begin,
+ .profiling_flush = iree_hal_metal_device_profiling_flush,
.profiling_end = iree_hal_metal_device_profiling_end,
};
diff --git a/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc b/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc
index 77cf830..d93b259 100644
--- a/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc
+++ b/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc
@@ -1489,6 +1489,28 @@
return iree_ok_status();
}
+static iree_status_t iree_hal_vulkan_device_profiling_flush(
+ iree_hal_device_t* base_device) {
+ iree_hal_vulkan_device_t* device = iree_hal_vulkan_device_cast(base_device);
+ (void)device;
+
+#if IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION_DEVICE
+ if (iree_all_bits_set(device->logical_device->enabled_features(),
+ IREE_HAL_VULKAN_FEATURE_ENABLE_TRACING)) {
+ for (iree_host_size_t i = 0; i < device->queue_count; ++i) {
+ iree_hal_vulkan_tracing_context_t* tracing_context =
+ device->queues[i]->tracing_context();
+ if (tracing_context) {
+ iree_hal_vulkan_tracing_context_collect(tracing_context,
+ VK_NULL_HANDLE);
+ }
+ }
+ }
+#endif // IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION_DEVICE
+
+ return iree_ok_status();
+}
+
static iree_status_t iree_hal_vulkan_device_profiling_end(
iree_hal_device_t* base_device) {
iree_hal_vulkan_device_t* device = iree_hal_vulkan_device_cast(base_device);
@@ -1545,6 +1567,7 @@
/*.queue_flush=*/iree_hal_vulkan_device_queue_flush,
/*.wait_semaphores=*/iree_hal_vulkan_device_wait_semaphores,
/*.profiling_begin=*/iree_hal_vulkan_device_profiling_begin,
+ /*.profiling_flush=*/iree_hal_vulkan_device_profiling_flush,
/*.profiling_end=*/iree_hal_vulkan_device_profiling_end,
};
} // namespace
diff --git a/tools/iree-benchmark-module-main.cc b/tools/iree-benchmark-module-main.cc
index c658ce0..f69ef49 100644
--- a/tools/iree-benchmark-module-main.cc
+++ b/tools/iree-benchmark-module-main.cc
@@ -158,6 +158,7 @@
static void BenchmarkGenericFunction(const std::string& benchmark_name,
int32_t batch_size,
+ iree_hal_device_t* device,
iree_vm_context_t* context,
iree_vm_function_t function,
iree_vm_list_t* inputs,
@@ -179,6 +180,11 @@
inputs, outputs.get(), iree_allocator_system()));
IREE_CHECK_OK(iree_vm_list_resize(outputs.get(), 0));
IREE_TRACE_ZONE_END(z1);
+ if (device) {
+ state.PauseTiming();
+ IREE_CHECK_OK(iree_hal_device_profiling_flush(device));
+ state.ResumeTiming();
+ }
}
state.SetItemsProcessed(state.iterations());
@@ -186,6 +192,7 @@
}
void RegisterGenericBenchmark(const std::string& function_name,
+ iree_hal_device_t* device,
iree_vm_context_t* context,
iree_vm_function_t function,
iree_vm_list_t* inputs) {
@@ -194,8 +201,8 @@
benchmark::RegisterBenchmark(benchmark_name.c_str(),
[=](benchmark::State& state) -> void {
BenchmarkGenericFunction(
- benchmark_name, batch_size, context,
- function, inputs, state);
+ benchmark_name, batch_size, device,
+ context, function, inputs, state);
})
// By default only the main thread is included in CPU time. Include all
// the threads instead.
@@ -328,6 +335,9 @@
IREE_TRACE_ZONE_END(z_end);
IREE_TRACE_ZONE_END(z1);
+ if (device) {
+ IREE_CHECK_OK(iree_hal_device_profiling_flush(device));
+ }
state.ResumeTiming();
}
state.SetItemsProcessed(state.iterations());
@@ -502,8 +512,8 @@
function, inputs_.get());
} else {
// Synchronous invocation.
- iree::RegisterGenericBenchmark(function_name, context_.get(), function,
- inputs_.get());
+ iree::RegisterGenericBenchmark(function_name, device_.get(),
+ context_.get(), function, inputs_.get());
}
return iree_ok_status();
}
@@ -530,8 +540,8 @@
function);
} else if (iree_string_view_equal(benchmark_type, IREE_SV("entry"))) {
iree::RegisterGenericBenchmark(
- std::string(function_name.data, function_name.size), context_.get(),
- function,
+ std::string(function_name.data, function_name.size), device_.get(),
+ context_.get(), function,
/*inputs=*/nullptr);
} else {
// Pick up generic () -> () functions.
@@ -570,7 +580,7 @@
// anything).
iree::RegisterGenericBenchmark(
std::string(function_name.data, function_name.size),
- context_.get(), function,
+ device_.get(), context_.get(), function,
/*inputs=*/nullptr);
}
}