[metal] Construct argument buffer at dispatch recording time This is in preparation of upcoming changes to use one single device buffer for holding all argument buffers.
diff --git a/experimental/metal/direct_command_buffer.m b/experimental/metal/direct_command_buffer.m index 276a2f9..7117969 100644 --- a/experimental/metal/direct_command_buffer.m +++ b/experimental/metal/direct_command_buffer.m
@@ -61,10 +61,13 @@ } iree_hal_metal_barrier_segment_t; // + Additional inline allocation for holding all buffer barriers. -typedef struct iree_hal_metal_buffer_and_usage_t { - id<MTLBuffer> buffer; +typedef struct iree_hal_metal_descriptor_t { + uint32_t set; + uint32_t binding; + iree_hal_buffer_t* buffer; + iree_device_size_t offset; MTLResourceUsage usage; -} iree_hal_metal_buffer_and_usage_t; +} iree_hal_metal_descriptor_t; // API data for dispatch command segments. typedef struct iree_hal_metal_dispatch_segment_t { @@ -77,23 +80,17 @@ iree_device_size_t workgroups_offset; uint32_t workgroup_count[3]; - // The number of argument buffers for descriptor sets. - iree_host_size_t argument_buffer_count; - // The list of argument buffers, pointing to the end of the segment allocation. - id<MTLBuffer>* argument_buffers; - - // The number of buffer usage fields. - iree_host_size_t buffer_usage_count; - // The list of buffer usage fields, pointing to the end of the segment allocation. - iree_hal_metal_buffer_and_usage_t* buffer_usages; + // The number of descriptors bound for this dispatch. + iree_host_size_t descriptor_count; + // The list of bound descriptors, pointing to the end of the segment allocation. + iree_hal_metal_descriptor_t* descriptors; // The number of push constant values. iree_host_size_t push_constant_count; // The list of push constants, pointing to the end of the segment allocation. int32_t* push_constants; } iree_hal_metal_dispatch_segment_t; -// + Additional inline allocation for holding all argument buffers. -// + Additional inline allocation for holding all buffer usages. +// + Additional inline allocation for holding all bound descriptors. // + Additional inline allocation for holding all push constants. // API data for fill buffer command segments. @@ -159,13 +156,6 @@ // iree_hal_metal_command_buffer_t //===------------------------------------------------------------------------------------------===// -typedef struct iree_hal_metal_descriptor_t { - uint32_t set; - uint32_t binding; - iree_hal_buffer_t* buffer; - iree_host_size_t offset; -} iree_hal_metal_descriptor_t; - typedef struct iree_hal_metal_command_buffer_t { iree_hal_command_buffer_t base; @@ -859,6 +849,13 @@ return true; } +static inline MTLResourceUsage iree_hal_metal_get_metal_resource_usage( + iree_hal_descriptor_set_layout_binding_t* binding) { + MTLResourceUsage usage = MTLResourceUsageRead; + if (binding->flags != IREE_HAL_DESCRIPTOR_FLAG_READ_ONLY) usage |= MTLResourceUsageWrite; + return usage; +} + static iree_status_t iree_hal_metal_command_buffer_push_descriptor_set( iree_hal_command_buffer_t* base_command_buffer, iree_hal_pipeline_layout_t* pipeline_layout, uint32_t set, iree_host_size_t binding_count, @@ -894,13 +891,21 @@ "exceeded available binding slots for push descriptor sets"); } + iree_hal_descriptor_set_layout_t* set_layout = + iree_hal_metal_pipeline_layout_descriptor_set_layout(pipeline_layout, set); + int start_index = command_buffer->current_total_binding_count; for (iree_host_size_t i = 0; i < binding_count; ++i) { iree_hal_metal_descriptor_t* descriptor = &descriptors[start_index + i]; + descriptor->set = set; descriptor->binding = bindings[i].binding; descriptor->buffer = bindings[i].buffer; descriptor->offset = bindings[i].offset; + + iree_hal_descriptor_set_layout_binding_t* binding_params = + iree_hal_metal_descriptor_set_layout_binding(set_layout, descriptor->binding); + descriptor->usage = iree_hal_metal_get_metal_resource_usage(binding_params); } qsort(&descriptors[start_index], binding_count, sizeof(descriptors[0]), compare_descriptor); @@ -926,13 +931,6 @@ return iree_ok_status(); } -static inline MTLResourceUsage iree_hal_metal_get_metal_resource_usage( - iree_hal_descriptor_set_layout_binding_t* binding) { - MTLResourceUsage usage = MTLResourceUsageRead; - if (binding->flags != IREE_HAL_DESCRIPTOR_FLAG_READ_ONLY) usage |= MTLResourceUsageWrite; - return usage; -} - // Creates an argument encoder and its backing argument buffer for the given kernel |function|'s // |buffer_index|. The argument encoder will be set to encode into the newly created argument // buffer. Callers are expected to release both the argument encoder and buffer. @@ -983,15 +981,12 @@ // Allocate the command segment and keep track of all necessary API data. uint8_t* storage_base = NULL; iree_hal_metal_command_segment_t* segment = NULL; - iree_host_size_t descriptor_set_count = command_buffer->current_max_set_number + 1; - iree_host_size_t descriptor_set_length = descriptor_set_count * sizeof(id<MTLBuffer>); - iree_host_size_t buffer_usage_length = - command_buffer->current_total_binding_count * sizeof(iree_hal_metal_buffer_and_usage_t); + iree_host_size_t descriptor_count = command_buffer->current_total_binding_count; + iree_host_size_t descriptor_length = descriptor_count * sizeof(iree_hal_metal_descriptor_t); iree_host_size_t push_constant_count = iree_hal_metal_pipeline_layout_push_constant_count(command_buffer->current_pipeline_layout); iree_host_size_t push_constant_length = push_constant_count * sizeof(int32_t); - iree_host_size_t total_size = - sizeof(*segment) + descriptor_set_length + buffer_usage_length + push_constant_length; + iree_host_size_t total_size = sizeof(*segment) + descriptor_length + push_constant_length; IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_arena_allocate(&command_buffer->arena, total_size, (void**)&storage_base)); @@ -1003,62 +998,15 @@ segment->dispatch.kernel_params = kernel_params; - // Get pointers to the argument buffer inline allocation. - segment->dispatch.argument_buffer_count = descriptor_set_count; - id<MTLBuffer>* descriptor_set_ptr = (id<MTLBuffer>*)(storage_base + sizeof(*segment)); - memset(descriptor_set_ptr, 0, descriptor_set_length); - segment->dispatch.argument_buffers = descriptor_set_ptr; - - // Get pointers to the buffer usage inline allocation. - segment->dispatch.buffer_usage_count = command_buffer->current_total_binding_count; - iree_hal_metal_buffer_and_usage_t* buffer_usage_ptr = - (iree_hal_metal_buffer_and_usage_t*)(storage_base + sizeof(*segment) + descriptor_set_length); - memset(buffer_usage_ptr, 0, buffer_usage_length); - segment->dispatch.buffer_usages = buffer_usage_ptr; - - // Build argument buffers for all descriptor sets and keep at the end of the current segment for - // later access. Also do the same for buffer usage information. - iree_hal_metal_descriptor_t* descriptors = command_buffer->current_descriptors; - int binding_count = command_buffer->current_total_binding_count; - int i = 0; - while (i < binding_count) { - // Build argument encoder and argument buffer for the current descriptor set. - uint32_t current_set = descriptors[i].set; - - id<MTLArgumentEncoder> argument_encoder; - id<MTLBuffer> argument_buffer; - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_metal_create_argument_encoder( - command_buffer->command_buffer.device, command_buffer->command_buffer, - kernel_params.function, current_set, &argument_encoder, &argument_buffer)); - - iree_hal_descriptor_set_layout_t* set_layout = - iree_hal_metal_pipeline_layout_descriptor_set_layout( - command_buffer->current_pipeline_layout, current_set); - IREE_ASSERT(set_layout != NULL); - - // Now put all bound buffers belonging to the current descriptor set into the argument buffer. - for (; i < binding_count && descriptors[i].set == current_set; ++i) { - unsigned current_binding = descriptors[i].binding; - id<MTLBuffer> current_buffer = - iree_hal_metal_buffer_handle(iree_hal_buffer_allocated_buffer(descriptors[i].buffer)); - iree_host_size_t offset = - iree_hal_buffer_byte_offset(descriptors[i].buffer) + descriptors[i].offset; - [argument_encoder setBuffer:current_buffer offset:offset atIndex:current_binding]; - - iree_hal_descriptor_set_layout_binding_t* binding_params = - iree_hal_metal_descriptor_set_layout_binding(set_layout, current_binding); - IREE_ASSERT(binding_params != NULL); - buffer_usage_ptr[i].buffer = current_buffer; - buffer_usage_ptr[i].usage = iree_hal_metal_get_metal_resource_usage(binding_params); - } - descriptor_set_ptr[current_set] = argument_buffer; - } + // Copy descriptors to the end of the current segment for later access. + segment->dispatch.descriptor_count = descriptor_count; + uint8_t* descriptor_ptr = storage_base + sizeof(*segment); + memcpy(descriptor_ptr, command_buffer->current_descriptors, descriptor_length); + segment->dispatch.descriptors = (iree_hal_metal_descriptor_t*)descriptor_ptr; // Copy push constants to the end of the current segment for later access. segment->dispatch.push_constant_count = push_constant_count; - uint8_t* push_constant_ptr = - storage_base + sizeof(*segment) + descriptor_set_length + buffer_usage_length; + uint8_t* push_constant_ptr = storage_base + sizeof(*segment) + descriptor_length; memcpy(push_constant_ptr, (const uint8_t*)command_buffer->push_constants, push_constant_length); segment->dispatch.push_constants = (int32_t*)push_constant_ptr; @@ -1083,16 +1031,34 @@ atIndex:IREE_HAL_METAL_PUSH_CONSTANT_BUFFER_INDEX]; } - // Record buffer usages. - for (iree_host_size_t i = 0; i < segment->buffer_usage_count; ++i) { - [compute_encoder useResource:segment->buffer_usages[i].buffer - usage:segment->buffer_usages[i].usage]; - } + // Record argument buffers for all descriptors and record buffer usages. + iree_hal_metal_descriptor_t* descriptors = segment->descriptors; + iree_host_size_t i = 0; + while (i < segment->descriptor_count) { + uint32_t current_set = descriptors[i].set; - // Record all argument buffers. - for (iree_host_size_t i = 0; i < segment->argument_buffer_count; ++i) { - if (segment->argument_buffers[i] == nil) continue; - [compute_encoder setBuffer:segment->argument_buffers[i] offset:0 atIndex:i]; + // Build argument encoder and argument buffer for the current descriptor set. + id<MTLArgumentEncoder> argument_encoder; + id<MTLBuffer> argument_buffer; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_metal_create_argument_encoder( + command_buffer->command_buffer.device, command_buffer->command_buffer, + segment->kernel_params.function, current_set, &argument_encoder, &argument_buffer)); + + // Now record all bound buffers belonging to the current set into the argument buffer. + for (; i < segment->descriptor_count && descriptors[i].set == current_set; ++i) { + uint32_t current_binding = descriptors[i].binding; + id<MTLBuffer> current_buffer = + iree_hal_metal_buffer_handle(iree_hal_buffer_allocated_buffer(descriptors[i].buffer)); + iree_host_size_t offset = + iree_hal_buffer_byte_offset(descriptors[i].buffer) + descriptors[i].offset; + [argument_encoder setBuffer:current_buffer offset:offset atIndex:current_binding]; + + // Also record buffer usages. + [compute_encoder useResource:current_buffer usage:descriptors[i].usage]; + } + // Record the argument buffer. + [compute_encoder setBuffer:argument_buffer offset:0 atIndex:current_set]; } // Record the dispatch, either direct or indirect.