Fix Binding in ROCM Backend to match CUDA HAL (#6134)

diff --git a/experimental/rocm/direct_command_buffer.c b/experimental/rocm/direct_command_buffer.c
index 7dca043..9d56001 100644
--- a/experimental/rocm/direct_command_buffer.c
+++ b/experimental/rocm/direct_command_buffer.c
@@ -27,7 +27,7 @@
   void *current_descriptor[];
 } iree_hal_rocm_direct_command_buffer_t;
 
-static const size_t max_binding_count = 64;
+#define IREE_HAL_ROCM_MAX_BINDING_COUNT 64
 
 extern const iree_hal_command_buffer_vtable_t
     iree_hal_rocm_direct_command_buffer_vtable;
@@ -51,8 +51,8 @@
 
   iree_hal_rocm_direct_command_buffer_t *command_buffer = NULL;
   size_t total_size = sizeof(*command_buffer) +
-                      max_binding_count * sizeof(void *) +
-                      max_binding_count * sizeof(hipDeviceptr_t);
+                      IREE_HAL_ROCM_MAX_BINDING_COUNT * sizeof(void *) +
+                      IREE_HAL_ROCM_MAX_BINDING_COUNT * sizeof(hipDeviceptr_t);
   iree_status_t status = iree_allocator_malloc(
       context->host_allocator, total_size, (void **)&command_buffer);
   if (iree_status_is_ok(status)) {
@@ -64,8 +64,8 @@
     command_buffer->queue_affinity = queue_affinity;
     hipDeviceptr_t *device_ptrs =
         (hipDeviceptr_t *)(command_buffer->current_descriptor +
-                           max_binding_count);
-    for (size_t i = 0; i < max_binding_count; i++) {
+                           IREE_HAL_ROCM_MAX_BINDING_COUNT);
+    for (size_t i = 0; i < IREE_HAL_ROCM_MAX_BINDING_COUNT; i++) {
       command_buffer->current_descriptor[i] = &device_ptrs[i];
     }
     command_buffer->total_size = total_size;
@@ -244,6 +244,21 @@
                           "need rocm implementation");
 }
 
+// Pairs a descriptor's binding index with its position in the |bindings| array.
+typedef struct {
+  uint32_t index;    // Position of the entry within the |bindings| array.
+  uint32_t binding;  // The entry's `binding` value; used as the sort key.
+} iree_hal_rocm_binding_mapping_t;
+
+// qsort() comparator ordering iree_hal_rocm_binding_mapping_t entries by
+// ascending |binding|. Returns 0 for equal keys so the ordering stays
+// consistent, as qsort() requires.
+static int compare_binding_index(const void *a, const void *b) {
+  const iree_hal_rocm_binding_mapping_t *lhs = a;
+  const iree_hal_rocm_binding_mapping_t *rhs = b;
+  return (lhs->binding > rhs->binding) - (lhs->binding < rhs->binding);
+}
+
 static iree_status_t iree_hal_rocm_direct_command_buffer_push_descriptor_set(
     iree_hal_command_buffer_t *base_command_buffer,
     iree_hal_executable_layout_t *executable_layout, uint32_t set,
@@ -251,16 +266,27 @@
     const iree_hal_descriptor_set_binding_t *bindings) {
   iree_hal_rocm_direct_command_buffer_t *command_buffer =
       iree_hal_rocm_direct_command_buffer_cast(base_command_buffer);
+  // Convention with the compiler side: we map bindings to kernel arguments.
+  // We compact the bindings to get a dense set of arguments and keep them
+  // ordered by binding index.
+  // Sort the bindings based on the binding index and map the array index to
+  // the argument index.
+  iree_hal_rocm_binding_mapping_t binding_used[IREE_HAL_ROCM_MAX_BINDING_COUNT];
   for (iree_host_size_t i = 0; i < binding_count; i++) {
-    uint32_t arg_index = bindings[i].binding;
-    assert(arg_index < max_binding_count &&
-           "binding index larger than the max expected.");
+    iree_hal_rocm_binding_mapping_t buffer = {i, bindings[i].binding};
+    binding_used[i] = buffer;
+  }
+  qsort(binding_used, binding_count, sizeof(iree_hal_rocm_binding_mapping_t),
+        compare_binding_index);
+  assert(binding_count < IREE_HAL_ROCM_MAX_BINDING_COUNT &&
+         "binding count larger than the max expected.");
+  for (iree_host_size_t i = 0; i < binding_count; i++) {
+    iree_hal_descriptor_set_binding_t binding = bindings[binding_used[i].index];
     hipDeviceptr_t device_ptr =
         iree_hal_rocm_buffer_device_pointer(
-            iree_hal_buffer_allocated_buffer(bindings[i].buffer)) +
-        iree_hal_buffer_byte_offset(bindings[i].buffer) + bindings[i].offset;
-    *((hipDeviceptr_t *)command_buffer->current_descriptor[arg_index]) =
-        device_ptr;
+            iree_hal_buffer_allocated_buffer(binding.buffer)) +
+        iree_hal_buffer_byte_offset(binding.buffer) + binding.offset;
+    *((hipDeviceptr_t *)command_buffer->current_descriptor[i]) = device_ptr;
   }
   return iree_ok_status();
 }