Fix Binding in ROCM Backend to match CUDA HAL (#6134)
diff --git a/experimental/rocm/direct_command_buffer.c b/experimental/rocm/direct_command_buffer.c
index 7dca043..9d56001 100644
--- a/experimental/rocm/direct_command_buffer.c
+++ b/experimental/rocm/direct_command_buffer.c
@@ -27,7 +27,7 @@
void *current_descriptor[];
} iree_hal_rocm_direct_command_buffer_t;
-static const size_t max_binding_count = 64;
+#define IREE_HAL_ROCM_MAX_BINDING_COUNT 64  // Binding slots allocated per command buffer.
extern const iree_hal_command_buffer_vtable_t
iree_hal_rocm_direct_command_buffer_vtable;
@@ -51,8 +51,8 @@
iree_hal_rocm_direct_command_buffer_t *command_buffer = NULL;
size_t total_size = sizeof(*command_buffer) +
- max_binding_count * sizeof(void *) +
- max_binding_count * sizeof(hipDeviceptr_t);
+ IREE_HAL_ROCM_MAX_BINDING_COUNT * sizeof(void *) +
+ IREE_HAL_ROCM_MAX_BINDING_COUNT * sizeof(hipDeviceptr_t);
iree_status_t status = iree_allocator_malloc(
context->host_allocator, total_size, (void **)&command_buffer);
if (iree_status_is_ok(status)) {
@@ -64,8 +64,8 @@
command_buffer->queue_affinity = queue_affinity;
hipDeviceptr_t *device_ptrs =
(hipDeviceptr_t *)(command_buffer->current_descriptor +
- max_binding_count);
- for (size_t i = 0; i < max_binding_count; i++) {
+ IREE_HAL_ROCM_MAX_BINDING_COUNT);
+ for (size_t i = 0; i < IREE_HAL_ROCM_MAX_BINDING_COUNT; i++) {
command_buffer->current_descriptor[i] = &device_ptrs[i];
}
command_buffer->total_size = total_size;
@@ -244,6 +244,21 @@
"need rocm implementation");
}
+// Pairs an entry's position in the |bindings| array with its binding index.
+typedef struct {
+ uint32_t index;    // Position of the entry in the |bindings| array.
+ uint32_t binding;  // Binding index of that entry; used as the sort key.
+} iree_hal_rocm_binding_mapping_t;
+
+// qsort() comparator ordering mappings by ascending binding index. Returns 0
+// for equal indices so that the ordering is consistent, as qsort() requires.
+static int compare_binding_index(const void *a, const void *b) {
+ const iree_hal_rocm_binding_mapping_t *lhs = a;
+ const iree_hal_rocm_binding_mapping_t *rhs = b;
+ // Three-way compare of the unsigned keys without subtracting them.
+ return (lhs->binding > rhs->binding) - (lhs->binding < rhs->binding);
+}
+
static iree_status_t iree_hal_rocm_direct_command_buffer_push_descriptor_set(
iree_hal_command_buffer_t *base_command_buffer,
iree_hal_executable_layout_t *executable_layout, uint32_t set,
@@ -251,16 +266,27 @@
const iree_hal_descriptor_set_binding_t *bindings) {
iree_hal_rocm_direct_command_buffer_t *command_buffer =
iree_hal_rocm_direct_command_buffer_cast(base_command_buffer);
+ // Convention shared with the compiler: bindings map to kernel arguments.
+ // The bindings are compacted into a dense argument list ordered by binding
+ // index, so sort (array index, binding index) pairs by binding index first.
+ // The count is checked before |binding_used| is filled because the array
+ // only has IREE_HAL_ROCM_MAX_BINDING_COUNT entries.
+ iree_hal_rocm_binding_mapping_t binding_used[IREE_HAL_ROCM_MAX_BINDING_COUNT];
+ assert(binding_count <= IREE_HAL_ROCM_MAX_BINDING_COUNT &&
+ "binding count larger than the max expected.");
for (iree_host_size_t i = 0; i < binding_count; i++) {
- uint32_t arg_index = bindings[i].binding;
- assert(arg_index < max_binding_count &&
- "binding index larger than the max expected.");
+ iree_hal_rocm_binding_mapping_t buffer = {i, bindings[i].binding};
+ binding_used[i] = buffer;
+ }
+ qsort(binding_used, binding_count, sizeof(iree_hal_rocm_binding_mapping_t),
+ compare_binding_index);
+ for (iree_host_size_t i = 0; i < binding_count; i++) {
+ iree_hal_descriptor_set_binding_t binding = bindings[binding_used[i].index];
hipDeviceptr_t device_ptr =
iree_hal_rocm_buffer_device_pointer(
- iree_hal_buffer_allocated_buffer(bindings[i].buffer)) +
- iree_hal_buffer_byte_offset(bindings[i].buffer) + bindings[i].offset;
- *((hipDeviceptr_t *)command_buffer->current_descriptor[arg_index]) =
- device_ptr;
+ iree_hal_buffer_allocated_buffer(binding.buffer)) +
+ iree_hal_buffer_byte_offset(binding.buffer) + binding.offset;
+ *((hipDeviceptr_t *)command_buffer->current_descriptor[i]) = device_ptr;
}
return iree_ok_status();
}