[metal] Construct argument buffer at dispatch recording time

This is in preparation for upcoming changes that use a single
device buffer to hold all argument buffers.
diff --git a/experimental/metal/direct_command_buffer.m b/experimental/metal/direct_command_buffer.m
index 276a2f9..7117969 100644
--- a/experimental/metal/direct_command_buffer.m
+++ b/experimental/metal/direct_command_buffer.m
@@ -61,10 +61,13 @@
 } iree_hal_metal_barrier_segment_t;
 // + Additional inline allocation for holding all buffer barriers.
 
-typedef struct iree_hal_metal_buffer_and_usage_t {
-  id<MTLBuffer> buffer;
+typedef struct iree_hal_metal_descriptor_t {
+  uint32_t set;
+  uint32_t binding;
+  iree_hal_buffer_t* buffer;
+  iree_device_size_t offset;
   MTLResourceUsage usage;
-} iree_hal_metal_buffer_and_usage_t;
+} iree_hal_metal_descriptor_t;
 
 // API data for dispatch command segments.
 typedef struct iree_hal_metal_dispatch_segment_t {
@@ -77,23 +80,17 @@
   iree_device_size_t workgroups_offset;
   uint32_t workgroup_count[3];
 
-  // The number of argument buffers for descriptor sets.
-  iree_host_size_t argument_buffer_count;
-  // The list of argument buffers, pointing to the end of the segment allocation.
-  id<MTLBuffer>* argument_buffers;
-
-  // The number of buffer usage fields.
-  iree_host_size_t buffer_usage_count;
-  // The list of buffer usage fields, pointing to the end of the segment allocation.
-  iree_hal_metal_buffer_and_usage_t* buffer_usages;
+  // The number of descriptors bound for this dispatch.
+  iree_host_size_t descriptor_count;
+  // The list of bound descriptors, pointing to the end of the segment allocation.
+  iree_hal_metal_descriptor_t* descriptors;
 
   // The number of push constant values.
   iree_host_size_t push_constant_count;
   // The list of push constants, pointing to the end of the segment allocation.
   int32_t* push_constants;
 } iree_hal_metal_dispatch_segment_t;
-// + Additional inline allocation for holding all argument buffers.
-// + Additional inline allocation for holding all buffer usages.
+// + Additional inline allocation for holding all bound descriptors.
 // + Additional inline allocation for holding all push constants.
 
 // API data for fill buffer command segments.
@@ -159,13 +156,6 @@
 // iree_hal_metal_command_buffer_t
 //===------------------------------------------------------------------------------------------===//
 
-typedef struct iree_hal_metal_descriptor_t {
-  uint32_t set;
-  uint32_t binding;
-  iree_hal_buffer_t* buffer;
-  iree_host_size_t offset;
-} iree_hal_metal_descriptor_t;
-
 typedef struct iree_hal_metal_command_buffer_t {
   iree_hal_command_buffer_t base;
 
@@ -859,6 +849,13 @@
   return true;
 }
 
+static inline MTLResourceUsage iree_hal_metal_get_metal_resource_usage(
+    iree_hal_descriptor_set_layout_binding_t* binding) {
+  // Test only the READ_ONLY bit so other flag bits cannot make a read-only binding writable.
+  bool read_only = (binding->flags & IREE_HAL_DESCRIPTOR_FLAG_READ_ONLY) != 0;
+  return read_only ? MTLResourceUsageRead : (MTLResourceUsageRead | MTLResourceUsageWrite);
+}
+
 static iree_status_t iree_hal_metal_command_buffer_push_descriptor_set(
     iree_hal_command_buffer_t* base_command_buffer, iree_hal_pipeline_layout_t* pipeline_layout,
     uint32_t set, iree_host_size_t binding_count,
@@ -894,13 +891,21 @@
                             "exceeded available binding slots for push descriptor sets");
   }
 
+  iree_hal_descriptor_set_layout_t* set_layout =
+      iree_hal_metal_pipeline_layout_descriptor_set_layout(pipeline_layout, set);
+
   int start_index = command_buffer->current_total_binding_count;
   for (iree_host_size_t i = 0; i < binding_count; ++i) {
     iree_hal_metal_descriptor_t* descriptor = &descriptors[start_index + i];
+
     descriptor->set = set;
     descriptor->binding = bindings[i].binding;
     descriptor->buffer = bindings[i].buffer;
     descriptor->offset = bindings[i].offset;
+
+    iree_hal_descriptor_set_layout_binding_t* binding_params =
+        iree_hal_metal_descriptor_set_layout_binding(set_layout, descriptor->binding);
+    descriptor->usage = iree_hal_metal_get_metal_resource_usage(binding_params);
   }
   qsort(&descriptors[start_index], binding_count, sizeof(descriptors[0]), compare_descriptor);
 
@@ -926,13 +931,6 @@
   return iree_ok_status();
 }
 
-static inline MTLResourceUsage iree_hal_metal_get_metal_resource_usage(
-    iree_hal_descriptor_set_layout_binding_t* binding) {
-  MTLResourceUsage usage = MTLResourceUsageRead;
-  if (binding->flags != IREE_HAL_DESCRIPTOR_FLAG_READ_ONLY) usage |= MTLResourceUsageWrite;
-  return usage;
-}
-
 // Creates an argument encoder and its backing argument buffer for the given kernel |function|'s
 // |buffer_index|. The argument encoder will be set to encode into the newly created argument
 // buffer. Callers are expected to release both the argument encoder and buffer.
@@ -983,15 +981,12 @@
   // Allocate the command segment and keep track of all necessary API data.
   uint8_t* storage_base = NULL;
   iree_hal_metal_command_segment_t* segment = NULL;
-  iree_host_size_t descriptor_set_count = command_buffer->current_max_set_number + 1;
-  iree_host_size_t descriptor_set_length = descriptor_set_count * sizeof(id<MTLBuffer>);
-  iree_host_size_t buffer_usage_length =
-      command_buffer->current_total_binding_count * sizeof(iree_hal_metal_buffer_and_usage_t);
+  iree_host_size_t descriptor_count = command_buffer->current_total_binding_count;
+  iree_host_size_t descriptor_length = descriptor_count * sizeof(iree_hal_metal_descriptor_t);
   iree_host_size_t push_constant_count =
       iree_hal_metal_pipeline_layout_push_constant_count(command_buffer->current_pipeline_layout);
   iree_host_size_t push_constant_length = push_constant_count * sizeof(int32_t);
-  iree_host_size_t total_size =
-      sizeof(*segment) + descriptor_set_length + buffer_usage_length + push_constant_length;
+  iree_host_size_t total_size = sizeof(*segment) + descriptor_length + push_constant_length;
   IREE_RETURN_AND_END_ZONE_IF_ERROR(
       z0, iree_arena_allocate(&command_buffer->arena, total_size, (void**)&storage_base));
 
@@ -1003,62 +998,15 @@
 
   segment->dispatch.kernel_params = kernel_params;
 
-  // Get pointers to the argument buffer inline allocation.
-  segment->dispatch.argument_buffer_count = descriptor_set_count;
-  id<MTLBuffer>* descriptor_set_ptr = (id<MTLBuffer>*)(storage_base + sizeof(*segment));
-  memset(descriptor_set_ptr, 0, descriptor_set_length);
-  segment->dispatch.argument_buffers = descriptor_set_ptr;
-
-  // Get pointers to the buffer usage inline allocation.
-  segment->dispatch.buffer_usage_count = command_buffer->current_total_binding_count;
-  iree_hal_metal_buffer_and_usage_t* buffer_usage_ptr =
-      (iree_hal_metal_buffer_and_usage_t*)(storage_base + sizeof(*segment) + descriptor_set_length);
-  memset(buffer_usage_ptr, 0, buffer_usage_length);
-  segment->dispatch.buffer_usages = buffer_usage_ptr;
-
-  // Build argument buffers for all descriptor sets and keep at the end of the current segment for
-  // later access. Also do the same for buffer usage information.
-  iree_hal_metal_descriptor_t* descriptors = command_buffer->current_descriptors;
-  int binding_count = command_buffer->current_total_binding_count;
-  int i = 0;
-  while (i < binding_count) {
-    // Build argument encoder and argument buffer for the current descriptor set.
-    uint32_t current_set = descriptors[i].set;
-
-    id<MTLArgumentEncoder> argument_encoder;
-    id<MTLBuffer> argument_buffer;
-    IREE_RETURN_AND_END_ZONE_IF_ERROR(
-        z0, iree_hal_metal_create_argument_encoder(
-                command_buffer->command_buffer.device, command_buffer->command_buffer,
-                kernel_params.function, current_set, &argument_encoder, &argument_buffer));
-
-    iree_hal_descriptor_set_layout_t* set_layout =
-        iree_hal_metal_pipeline_layout_descriptor_set_layout(
-            command_buffer->current_pipeline_layout, current_set);
-    IREE_ASSERT(set_layout != NULL);
-
-    // Now put all bound buffers belonging to the current descriptor set into the argument buffer.
-    for (; i < binding_count && descriptors[i].set == current_set; ++i) {
-      unsigned current_binding = descriptors[i].binding;
-      id<MTLBuffer> current_buffer =
-          iree_hal_metal_buffer_handle(iree_hal_buffer_allocated_buffer(descriptors[i].buffer));
-      iree_host_size_t offset =
-          iree_hal_buffer_byte_offset(descriptors[i].buffer) + descriptors[i].offset;
-      [argument_encoder setBuffer:current_buffer offset:offset atIndex:current_binding];
-
-      iree_hal_descriptor_set_layout_binding_t* binding_params =
-          iree_hal_metal_descriptor_set_layout_binding(set_layout, current_binding);
-      IREE_ASSERT(binding_params != NULL);
-      buffer_usage_ptr[i].buffer = current_buffer;
-      buffer_usage_ptr[i].usage = iree_hal_metal_get_metal_resource_usage(binding_params);
-    }
-    descriptor_set_ptr[current_set] = argument_buffer;
-  }
+  // Copy descriptors to the end of the current segment for later access.
+  segment->dispatch.descriptor_count = descriptor_count;
+  uint8_t* descriptor_ptr = storage_base + sizeof(*segment);
+  memcpy(descriptor_ptr, command_buffer->current_descriptors, descriptor_length);
+  segment->dispatch.descriptors = (iree_hal_metal_descriptor_t*)descriptor_ptr;
 
   // Copy push constants to the end of the current segment for later access.
   segment->dispatch.push_constant_count = push_constant_count;
-  uint8_t* push_constant_ptr =
-      storage_base + sizeof(*segment) + descriptor_set_length + buffer_usage_length;
+  uint8_t* push_constant_ptr = storage_base + sizeof(*segment) + descriptor_length;
   memcpy(push_constant_ptr, (const uint8_t*)command_buffer->push_constants, push_constant_length);
   segment->dispatch.push_constants = (int32_t*)push_constant_ptr;
 
@@ -1083,16 +1031,34 @@
                       atIndex:IREE_HAL_METAL_PUSH_CONSTANT_BUFFER_INDEX];
   }
 
-  // Record buffer usages.
-  for (iree_host_size_t i = 0; i < segment->buffer_usage_count; ++i) {
-    [compute_encoder useResource:segment->buffer_usages[i].buffer
-                           usage:segment->buffer_usages[i].usage];
-  }
+  // Record argument buffers for all descriptors and record buffer usages.
+  iree_hal_metal_descriptor_t* descriptors = segment->descriptors;
+  iree_host_size_t i = 0;
+  while (i < segment->descriptor_count) {
+    uint32_t current_set = descriptors[i].set;
 
-  // Record all argument buffers.
-  for (iree_host_size_t i = 0; i < segment->argument_buffer_count; ++i) {
-    if (segment->argument_buffers[i] == nil) continue;
-    [compute_encoder setBuffer:segment->argument_buffers[i] offset:0 atIndex:i];
+    // Build argument encoder and argument buffer for the current descriptor set.
+    id<MTLArgumentEncoder> argument_encoder;
+    id<MTLBuffer> argument_buffer;
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_hal_metal_create_argument_encoder(
+                command_buffer->command_buffer.device, command_buffer->command_buffer,
+                segment->kernel_params.function, current_set, &argument_encoder, &argument_buffer));
+
+    // Now record all bound buffers belonging to the current set into the argument buffer.
+    for (; i < segment->descriptor_count && descriptors[i].set == current_set; ++i) {
+      uint32_t current_binding = descriptors[i].binding;
+      id<MTLBuffer> current_buffer =
+          iree_hal_metal_buffer_handle(iree_hal_buffer_allocated_buffer(descriptors[i].buffer));
+      iree_host_size_t offset =
+          iree_hal_buffer_byte_offset(descriptors[i].buffer) + descriptors[i].offset;
+      [argument_encoder setBuffer:current_buffer offset:offset atIndex:current_binding];
+
+      // Also record buffer usages.
+      [compute_encoder useResource:current_buffer usage:descriptors[i].usage];
+    }
+    // Record the argument buffer.
+    [compute_encoder setBuffer:argument_buffer offset:0 atIndex:current_set];
   }
 
   // Record the dispatch, either direct or indirect.