[vulkan] Add flag to enable robust buffer access (#12547)
This is useful when debugging suspected out-of-bounds memory access
issues.
Issue: https://github.com/openxla/iree/issues/12415
diff --git a/runtime/src/iree/hal/drivers/vulkan/api.h b/runtime/src/iree/hal/drivers/vulkan/api.h
index 5638879..d6d8e6d 100644
--- a/runtime/src/iree/hal/drivers/vulkan/api.h
+++ b/runtime/src/iree/hal/drivers/vulkan/api.h
@@ -47,6 +47,14 @@
// identify slow dispatches and refine from there; be wary of whole-program
// tracing with this enabled.
IREE_HAL_VULKAN_FEATURE_ENABLE_TRACING = 1u << 2,
+
+ // Enables the `robustBufferAccess` physical device feature. This adds bounds
+ // checks to GPU memory accesses to make all accesses be in-bounds. This is
+ // only recommended for debugging purposes.
+ //
+ // NOTE: This affects the pipeline state and in turn may change the code
+ // generated by the Vulkan device compiler.
+ IREE_HAL_VULKAN_FEATURE_ENABLE_ROBUST_BUFFER_ACCESS = 1u << 3,
};
typedef uint32_t iree_hal_vulkan_features_t;
diff --git a/runtime/src/iree/hal/drivers/vulkan/registration/driver_module.cc b/runtime/src/iree/hal/drivers/vulkan/registration/driver_module.cc
index 537e79c..5427fac 100644
--- a/runtime/src/iree/hal/drivers/vulkan/registration/driver_module.cc
+++ b/runtime/src/iree/hal/drivers/vulkan/registration/driver_module.cc
@@ -31,6 +31,9 @@
IREE_FLAG(bool, vulkan_tracing, true,
"Enables Vulkan tracing (if IREE tracing is enabled).");
+IREE_FLAG(bool, vulkan_robust_buffer_access, false,
+ "Enables the Vulkan 'robustBufferAccess' feature.");
+
IREE_FLAG(
bool, vulkan_dedicated_compute_queue, false,
"Use a dedicated queue with VK_QUEUE_COMPUTE_BIT for dispatch workloads.");
@@ -71,6 +74,10 @@
if (FLAG_vulkan_tracing) {
driver_options.requested_features |= IREE_HAL_VULKAN_FEATURE_ENABLE_TRACING;
}
+ if (FLAG_vulkan_robust_buffer_access) {
+ driver_options.requested_features |=
+ IREE_HAL_VULKAN_FEATURE_ENABLE_ROBUST_BUFFER_ACCESS;
+ }
if (FLAG_vulkan_dedicated_compute_queue) {
driver_options.device_options.flags |=
diff --git a/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc b/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc
index 32b21ed..a347cd1 100644
--- a/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc
+++ b/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc
@@ -904,6 +904,16 @@
features2.features.shaderInt64 = VK_TRUE;
}
+ if (iree_all_bits_set(enabled_features,
+ IREE_HAL_VULKAN_FEATURE_ENABLE_ROBUST_BUFFER_ACCESS)) {
+ if (physical_device_features.robustBufferAccess != VK_TRUE) {
+ return iree_make_status(
+ IREE_STATUS_UNAVAILABLE,
+ "Robust buffer access not supported by physical device");
+ }
+ features2.features.robustBufferAccess = VK_TRUE;
+ }
+
VkPhysicalDeviceTimelineSemaphoreFeatures semaphore_features;
memset(&semaphore_features, 0, sizeof(semaphore_features));
semaphore_features.sType =