| // Copyright 2023 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "experimental/metal/direct_command_buffer.h" |
| |
| #import <Metal/Metal.h> |
| |
| #include "experimental/metal/builtin_executables.h" |
| #include "experimental/metal/metal_buffer.h" |
| #include "experimental/metal/metal_device.h" |
| #include "experimental/metal/metal_kernel_library.h" |
| #include "experimental/metal/pipeline_layout.h" |
| #include "iree/base/api.h" |
| #include "iree/base/target_platform.h" |
| #include "iree/base/tracing.h" |
| #include "iree/hal/api.h" |
| #include "iree/hal/utils/resource_set.h" |
| |
| typedef struct iree_hal_metal_descriptor_t { |
| uint32_t set; |
| uint32_t binding; |
| iree_hal_buffer_t* buffer; |
| iree_host_size_t offset; |
} iree_hal_metal_descriptor_t;
| |
| typedef struct iree_hal_metal_command_buffer_t { |
| iree_hal_command_buffer_t base; |
| |
| // The Metal command queue owning this command buffer. |
| id<MTLCommandQueue> queue; |
| |
  // For polyfilling fill/copy/update buffer operations that are not directly supported by Metal APIs.
| iree_hal_metal_builtin_executable_t* builtin_executable; |
| |
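  // The underlying Metal command buffer under recording; commands are encoded into it through the
  // encoders below.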
| id<MTLCommandBuffer> command_buffer; |
| |
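  // The dispatch type (serial or concurrent) to use when creating compute encoders; chosen from
  // the device params at creation time.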
| MTLDispatchType dispatch_type; |
| |
  // The currently active compute/blit encoders for encoding compute or memory operations.
  // Metal commands are encoded into the command buffer through such encoders, and each encoder
  // can only encode the specific type of operations it supports.
| id<MTLComputeCommandEncoder> compute_encoder; |
| id<MTLBlitCommandEncoder> blit_encoder; |
| |
  // MTLEvent used for synchronization when we switch between blit and compute encoders.
  // Normally we would use MTLFence objects, but the mismatch between the IREE HAL and the Metal
  // API means we may see many encoder switches, which would require creating a lot of GPU
  // objects. To avoid that cost, we use a single MTLEvent signaled with a different value for
  // each switch.
| id<MTLEvent> encoder_event; |
| // The next available encoder event value to signal/wait to/on. |
| uint64_t next_encoder_event_value; |
| |
  // Metal APIs mandate we create argument buffers (for descriptor sets) from a compiled kernel
  // function. That means we need to bind the compute kernel first before setting descriptors and
  // binding buffers. So we need to cache the descriptor information ourselves and apply it in a
  // delayed manner.
| |
| // A sorted flat list of descriptors from all pushed descriptor sets. |
| iree_hal_metal_descriptor_t current_descriptors[IREE_HAL_METAL_MAX_BINDING_COUNT]; |
| // The total used slot count / next unused slot index in |current_descriptors|. |
| int current_total_binding_count; |
| // The max descriptor set number we have seen thus far. |
| int current_max_set_number; |
| |
| // All available push constants updated each time push_constants is called. Reset only with the |
| // command buffer and otherwise will maintain its values during recording to allow for partial |
| // push_constants updates. |
| int32_t push_constants[IREE_HAL_METAL_MAX_PUSH_CONSTANT_COUNT]; |
| |
| // The current pipeline layout used for push descriptors. |
| iree_hal_pipeline_layout_t* current_pipeline_layout; |
| |
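  // The host allocator used to allocate this command buffer struct; used again to free it.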
| iree_allocator_t host_allocator; |
| |
  // Maintains references to all resources used within the command buffer; released on destroy.
| iree_hal_resource_set_t* resource_set; |
| } iree_hal_metal_command_buffer_t; |
| |
| static const iree_hal_command_buffer_vtable_t iree_hal_metal_command_buffer_vtable; |
| |
| static iree_hal_metal_command_buffer_t* iree_hal_metal_command_buffer_cast( |
| iree_hal_command_buffer_t* base_value) { |
| IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_metal_command_buffer_vtable); |
| return (iree_hal_metal_command_buffer_t*)base_value; |
| } |
| |
| id<MTLCommandBuffer> iree_hal_metal_direct_command_buffer_handle( |
| iree_hal_command_buffer_t* base_command_buffer) { |
| iree_hal_metal_command_buffer_t* command_buffer = |
| iree_hal_metal_command_buffer_cast(base_command_buffer); |
| return command_buffer->command_buffer; |
| } |
| |
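// Ends the current compute encoder, if any, and releases our retained reference to it.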
| static void iree_hal_metal_end_compute_encoder(iree_hal_metal_command_buffer_t* command_buffer) { |
| if (command_buffer->compute_encoder) { |
| [command_buffer->compute_encoder endEncoding]; |
| [command_buffer->compute_encoder release]; // -1 |
| command_buffer->compute_encoder = nil; |
| } |
| } |
| |
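// Ends the current blit encoder, if any, and releases our retained reference to it.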
| static void iree_hal_metal_end_blit_encoder(iree_hal_metal_command_buffer_t* command_buffer) { |
| if (command_buffer->blit_encoder) { |
| [command_buffer->blit_encoder endEncoding]; |
| [command_buffer->blit_encoder release]; // -1 |
| command_buffer->blit_encoder = nil; |
| } |
| } |
| |
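// Returns the current compute encoder, beginning a new one if needed; any active blit encoder is
// ended first, with an MTLEvent signal/wait pair synchronizing the two encoders.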
| static id<MTLComputeCommandEncoder> iree_hal_metal_get_or_begin_compute_encoder( |
| iree_hal_metal_command_buffer_t* command_buffer) { |
| id<MTLCommandBuffer> metal_handle = command_buffer->command_buffer; |
| |
  // If we are switching encoders, we need to synchronize "one or more resources across
  // different passes within a command buffer"; we do that with a signal/wait pair on the
  // command buffer's MTLEvent.
  // https://developer.apple.com/documentation/metal/resource_synchronization
| uint64_t encoder_event_value = 0; |
| if (command_buffer->blit_encoder) { |
| iree_hal_metal_end_blit_encoder(command_buffer); |
| encoder_event_value = command_buffer->next_encoder_event_value++; |
| [metal_handle encodeSignalEvent:command_buffer->encoder_event value:encoder_event_value]; |
| } |
| |
| if (!command_buffer->compute_encoder) { |
| if (encoder_event_value != 0) { |
| [metal_handle encodeWaitForEvent:command_buffer->encoder_event value:encoder_event_value]; |
| } |
    @autoreleasepool {  // Use @autoreleasepool to trigger the autorelease within encoder creation.
      // We manage command dependencies and insert barriers explicitly in IREE, so concurrent
      // dispatch is safe; the actual dispatch type is taken from the device params.
| command_buffer->compute_encoder = [[metal_handle |
| computeCommandEncoderWithDispatchType:command_buffer->dispatch_type] retain]; // +1 |
| } |
| } |
| |
| return command_buffer->compute_encoder; |
| } |
| |
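// Returns the current blit encoder, beginning a new one if needed; any active compute encoder is
// ended first, with an MTLEvent signal/wait pair synchronizing the two encoders.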
| static id<MTLBlitCommandEncoder> iree_hal_metal_get_or_begin_blit_encoder( |
| iree_hal_metal_command_buffer_t* command_buffer) { |
| id<MTLCommandBuffer> metal_handle = command_buffer->command_buffer; |
| |
  // If we are switching encoders, we need to synchronize "one or more resources across
  // different passes within a command buffer"; we do that with a signal/wait pair on the
  // command buffer's MTLEvent.
  // https://developer.apple.com/documentation/metal/resource_synchronization
| uint64_t encoder_event_value = 0; |
| if (command_buffer->compute_encoder) { |
| iree_hal_metal_end_compute_encoder(command_buffer); |
| encoder_event_value = command_buffer->next_encoder_event_value++; |
| [metal_handle encodeSignalEvent:command_buffer->encoder_event value:encoder_event_value]; |
| } |
| |
| if (!command_buffer->blit_encoder) { |
| if (encoder_event_value != 0) { |
| [metal_handle encodeWaitForEvent:command_buffer->encoder_event value:encoder_event_value]; |
| } |
| @autoreleasepool { // Use @autoreleasepool to trigger the autorelease within encoder creation. |
| command_buffer->blit_encoder = [[metal_handle blitCommandEncoder] retain]; // +1 |
| } |
| } |
| |
| return command_buffer->blit_encoder; |
| } |
| |
| iree_status_t iree_hal_metal_direct_command_buffer_create( |
| iree_hal_device_t* device, iree_hal_command_buffer_mode_t mode, |
| iree_hal_command_category_t command_categories, iree_host_size_t binding_capacity, |
| iree_hal_metal_command_buffer_resource_reference_mode_t resource_reference_mode, |
| id<MTLCommandQueue> queue, iree_allocator_t host_allocator, iree_arena_block_pool_t* block_pool, |
| iree_hal_metal_builtin_executable_t* builtin_executable, |
| iree_hal_command_buffer_t** out_command_buffer) { |
| IREE_ASSERT_ARGUMENT(device); |
| IREE_ASSERT_ARGUMENT(out_command_buffer); |
| IREE_ASSERT_TRUE(iree_all_bits_set(mode, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT)); |
| IREE_ASSERT_TRUE(!iree_any_bit_set(mode, IREE_HAL_COMMAND_BUFFER_MODE_NESTED)); |
| *out_command_buffer = NULL; |
| |
| if (binding_capacity > 0) { |
| // TODO(#10144): support indirect command buffers with binding tables. |
| return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "unimplemented indirect command buffers"); |
| } |
| |
| IREE_TRACE_ZONE_BEGIN(z0); |
| |
| iree_hal_metal_command_buffer_t* command_buffer = NULL; |
| iree_status_t status = |
| iree_allocator_malloc(host_allocator, sizeof(*command_buffer), (void**)&command_buffer); |
| if (iree_status_is_ok(status)) { |
| iree_hal_command_buffer_initialize( |
| device, mode, command_categories, IREE_HAL_QUEUE_AFFINITY_ANY, binding_capacity, |
| &iree_hal_metal_command_buffer_vtable, &command_buffer->base); |
| command_buffer->queue = [queue retain]; // +1 |
| command_buffer->builtin_executable = builtin_executable; |
    @autoreleasepool {  // Use @autoreleasepool to trigger the autorelease within command buffer creation.
      // We track resource lifetimes ourselves in IREE, so by default we use unretained references
      // to resources in the Metal command buffer, which avoids overhead and gives better
      // performance. |resource_reference_mode| controls whether references are retained.
| MTLCommandBufferDescriptor* descriptor = [MTLCommandBufferDescriptor new]; // +1 |
| descriptor.retainedReferences = |
| resource_reference_mode == IREE_HAL_METAL_COMMAND_BUFFER_RESOURCE_REFERENCE_MODE_RETAINED; |
| descriptor.errorOptions = MTLCommandBufferErrorOptionNone; |
| command_buffer->command_buffer = |
| [[queue commandBufferWithDescriptor:descriptor] retain]; // +1 |
| [descriptor release]; // -1 |
| } |
| const iree_hal_metal_device_params_t* params = iree_hal_metal_device_params(device); |
| command_buffer->dispatch_type = |
| params->command_dispatch_type == IREE_HAL_METAL_COMMAND_DISPATCH_TYPE_CONCURRENT |
| ? MTLDispatchTypeConcurrent |
| : MTLDispatchTypeSerial; |
| command_buffer->compute_encoder = nil; |
| command_buffer->blit_encoder = nil; |
| command_buffer->encoder_event = [queue.device newEvent]; // +1 |
| command_buffer->next_encoder_event_value = 1; |
| memset(command_buffer->current_descriptors, 0, |
| IREE_HAL_METAL_MAX_BINDING_COUNT * sizeof(command_buffer->current_descriptors[0])); |
| command_buffer->current_total_binding_count = 0; |
| command_buffer->current_max_set_number = -1; |
| memset(command_buffer->push_constants, 0, sizeof(command_buffer->push_constants)); |
| command_buffer->current_pipeline_layout = NULL; |
| command_buffer->host_allocator = host_allocator; |
| status = iree_hal_resource_set_allocate(block_pool, &command_buffer->resource_set); |
| } |
| |
| IREE_TRACE_ZONE_END(z0); |
| return status; |
| } |
| |
| static void iree_hal_metal_command_buffer_destroy(iree_hal_command_buffer_t* base_command_buffer) { |
| iree_hal_metal_command_buffer_t* command_buffer = |
| iree_hal_metal_command_buffer_cast(base_command_buffer); |
| IREE_TRACE_ZONE_BEGIN(z0); |
| |
| [command_buffer->encoder_event release]; // -1 |
| IREE_ASSERT_EQ(command_buffer->compute_encoder, nil); |
| IREE_ASSERT_EQ(command_buffer->blit_encoder, nil); |
| [command_buffer->command_buffer release]; // -1 |
| [command_buffer->queue release]; // -1 |
| iree_hal_resource_set_free(command_buffer->resource_set); |
| iree_allocator_free(command_buffer->host_allocator, command_buffer); |
| |
| IREE_TRACE_ZONE_END(z0); |
| } |
| |
| bool iree_hal_metal_command_buffer_isa(iree_hal_command_buffer_t* command_buffer) { |
| return iree_hal_resource_is(&command_buffer->resource, &iree_hal_metal_command_buffer_vtable); |
| } |
| |
| static iree_status_t iree_hal_metal_command_buffer_begin( |
| iree_hal_command_buffer_t* base_command_buffer) { |
| // Nothing to do. |
| return iree_ok_status(); |
| } |
| |
| static iree_status_t iree_hal_metal_command_buffer_end( |
| iree_hal_command_buffer_t* base_command_buffer) { |
| iree_hal_metal_command_buffer_t* command_buffer = |
| iree_hal_metal_command_buffer_cast(base_command_buffer); |
| iree_hal_metal_end_blit_encoder(command_buffer); |
| iree_hal_metal_end_compute_encoder(command_buffer); |
| return iree_ok_status(); |
| } |
| |
| static void iree_hal_metal_command_buffer_begin_debug_group( |
| iree_hal_command_buffer_t* base_command_buffer, iree_string_view_t label, |
| iree_hal_label_color_t label_color, const iree_hal_label_location_t* location) { |
  // TODO(antiagainst): implement support for debug groups
| } |
| |
| static void iree_hal_metal_command_buffer_end_debug_group( |
| iree_hal_command_buffer_t* base_command_buffer) { |
  // TODO(antiagainst): implement support for debug groups
| } |
| |
| static iree_status_t iree_hal_metal_command_buffer_execution_barrier( |
| iree_hal_command_buffer_t* base_command_buffer, iree_hal_execution_stage_t source_stage_mask, |
| iree_hal_execution_stage_t target_stage_mask, iree_hal_execution_barrier_flags_t flags, |
| iree_host_size_t memory_barrier_count, const iree_hal_memory_barrier_t* memory_barriers, |
| iree_host_size_t buffer_barrier_count, const iree_hal_buffer_barrier_t* buffer_barriers) { |
| if (iree_any_bit_set(source_stage_mask, IREE_HAL_EXECUTION_STAGE_HOST) || |
| iree_any_bit_set(target_stage_mask, IREE_HAL_EXECUTION_STAGE_HOST)) { |
| return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "barrier involving host not yet supported"); |
| } |
| |
| if (flags != IREE_HAL_EXECUTION_BARRIER_FLAG_NONE) { |
| return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "non-zero barrier flag not yet supported"); |
| } |
| |
| iree_hal_metal_command_buffer_t* command_buffer = |
| iree_hal_metal_command_buffer_cast(base_command_buffer); |
| id<MTLComputeCommandEncoder> encoder = |
| iree_hal_metal_get_or_begin_compute_encoder(command_buffer); |
| |
| if (memory_barrier_count != 0) { |
    // If a memory barrier is specified, we have to place a catch-all barrier for all buffers;
    // Metal does not provide more fine-grained control here.
| [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]; |
| return iree_ok_status(); |
| } |
| |
| if (buffer_barrier_count != 0) { |
    // When only buffer barriers are specified, we can scope the barrier to just the listed
    // buffers.
| id<MTLResource>* resources = |
| (id<MTLResource>*)iree_alloca(sizeof(id<MTLResource>) * buffer_barrier_count); |
| for (unsigned i = 0; i < buffer_barrier_count; ++i) { |
| resources[i] = |
| iree_hal_metal_buffer_handle(iree_hal_buffer_allocated_buffer(buffer_barriers[i].buffer)); |
| } |
| [encoder memoryBarrierWithResources:resources count:buffer_barrier_count]; |
| } |
| return iree_ok_status(); |
| } |
| |
| static iree_status_t iree_hal_metal_command_buffer_signal_event( |
| iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event, |
| iree_hal_execution_stage_t source_stage_mask) { |
| return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "event not yet supported"); |
| } |
| |
| static iree_status_t iree_hal_metal_command_buffer_reset_event( |
| iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event, |
| iree_hal_execution_stage_t source_stage_mask) { |
| return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "event not yet supported"); |
| } |
| |
| static iree_status_t iree_hal_metal_command_buffer_wait_events( |
| iree_hal_command_buffer_t* base_command_buffer, iree_host_size_t event_count, |
| const iree_hal_event_t** events, iree_hal_execution_stage_t source_stage_mask, |
| iree_hal_execution_stage_t target_stage_mask, iree_host_size_t memory_barrier_count, |
| const iree_hal_memory_barrier_t* memory_barriers, iree_host_size_t buffer_barrier_count, |
| const iree_hal_buffer_barrier_t* buffer_barriers) { |
| return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "event not yet supported"); |
| } |
| |
| static iree_status_t iree_hal_metal_command_buffer_discard_buffer( |
| iree_hal_command_buffer_t* base_command_buffer, iree_hal_buffer_t* buffer) { |
| // This is a hint to the device and we have nothing to do for Metal. |
| return iree_ok_status(); |
| } |
| |
// Fills |value| with the duplicated single-byte value and returns true if the given |pattern|
// repeats one byte value across all of its |pattern_length| bytes.
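// For example, a 2-byte pattern 0xCDCD yields 0xCD and returns true, while 0xABCD returns false.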
| static bool iree_hal_metal_get_duplicated_single_byte_value(const void* pattern, |
| size_t pattern_length, uint8_t* value) { |
| switch (pattern_length) { |
| case 1: { |
| *value = *(uint8_t*)pattern; |
| return true; |
| } |
| case 2: { |
| uint16_t two_bytes = *(uint16_t*)pattern; |
| uint16_t byte0 = two_bytes & 0xffu; |
| uint16_t byte1 = two_bytes >> 8u; |
| if (byte0 == byte1) { |
        *value = (uint8_t)byte0;
| return true; |
| } |
| break; |
| } |
| case 4: { |
| uint32_t four_bytes = *(uint32_t*)pattern; |
| uint32_t byte0 = four_bytes & 0xffu; |
| uint32_t byte1 = (four_bytes >> 8u) & 0xffu; |
| uint32_t byte2 = (four_bytes >> 16u) & 0xffu; |
| uint32_t byte3 = four_bytes >> 24u; |
| if (byte0 == byte1 && byte0 == byte2 && byte0 == byte3) { |
        *value = (uint8_t)byte0;
| return true; |
| } |
| break; |
| } |
| default: |
| break; |
| } |
| return false; |
| } |
| |
// Fills |value| by duplicating the given |pattern| into a 4-byte value.
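// For example, a 1-byte pattern 0xAB produces 0xABABABAB.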
| static iree_status_t iree_hal_metal_duplicate_to_four_byte_value(const void* pattern, |
| size_t pattern_length, |
| uint32_t* value) { |
| switch (pattern_length) { |
| case 1: { |
| uint8_t single_byte = *(uint8_t*)pattern; |
| *value = (uint32_t)single_byte; |
| *value |= (*value << 8u); |
| *value |= (*value << 16u); |
| return iree_ok_status(); |
| } |
| case 2: { |
| uint16_t two_bytes = *(uint16_t*)pattern; |
| *value = (uint32_t)two_bytes; |
| *value |= (*value << 16u); |
| return iree_ok_status(); |
| } |
| case 4: { |
| *value = *(uint32_t*)pattern; |
| return iree_ok_status(); |
| } |
    default:
| break; |
| } |
| return iree_make_status(IREE_STATUS_INTERNAL, "fill pattern should have 1/2/4 bytes"); |
| } |
| |
| static iree_status_t iree_hal_metal_command_buffer_fill_buffer( |
| iree_hal_command_buffer_t* base_command_buffer, iree_hal_buffer_t* target_buffer, |
| iree_device_size_t target_offset, iree_device_size_t length, const void* pattern, |
| iree_host_size_t pattern_length) { |
| iree_hal_metal_command_buffer_t* command_buffer = |
| iree_hal_metal_command_buffer_cast(base_command_buffer); |
| IREE_TRACE_ZONE_BEGIN(z0); |
| |
| id<MTLBuffer> target_device_buffer = |
| iree_hal_metal_buffer_handle(iree_hal_buffer_allocated_buffer(target_buffer)); |
| target_offset += iree_hal_buffer_byte_offset(target_buffer); |
| |
| // Per the spec for fillBuffer:range:value: "The alignment and length of the range must both be a |
| // multiple of 4 bytes in macOS, and 1 byte in iOS and tvOS." |
| #if defined(IREE_PLATFORM_MACOS) |
| bool can_use_metal_api = target_offset % 4 == 0 && length % 4 == 0; |
| #else |
| bool can_use_metal_api = true; |
| #endif |
| |
| // Note that fillBuffer:range:value: only accepts a single byte as the pattern but FillBuffer |
| // can accept 1/2/4 bytes. If the pattern itself contains repeated bytes, we can call into |
| // fillBuffer:range:value:. Otherwise we need to emulate the support. |
| uint8_t single_byte_value = 0u; |
| if (can_use_metal_api) { |
| can_use_metal_api &= iree_hal_metal_get_duplicated_single_byte_value(pattern, pattern_length, |
| &single_byte_value); |
| } |
| |
| IREE_RETURN_IF_ERROR( |
| iree_hal_resource_set_insert(command_buffer->resource_set, 1, &target_buffer)); |
| |
| iree_status_t status = iree_ok_status(); |
| if (can_use_metal_api) { |
| id<MTLBlitCommandEncoder> encoder = iree_hal_metal_get_or_begin_blit_encoder(command_buffer); |
| [encoder fillBuffer:target_device_buffer |
| range:NSMakeRange(target_offset, length) |
| value:single_byte_value]; |
| } else { |
| id<MTLComputeCommandEncoder> compute_encoder = |
| iree_hal_metal_get_or_begin_compute_encoder(command_buffer); |
| uint32_t pattern_4byte = 0; |
| status = iree_hal_metal_duplicate_to_four_byte_value(pattern, pattern_length, &pattern_4byte); |
| if (iree_status_is_ok(status)) { |
| status = iree_hal_metal_builtin_executable_fill_buffer(command_buffer->builtin_executable, |
| compute_encoder, target_device_buffer, |
| target_offset, length, pattern_4byte); |
| } |
| } |
| |
| IREE_TRACE_ZONE_END(z0); |
| return status; |
| } |
| |
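// Encodes a copy between Metal buffers, using the blit encoder when Metal's offset/length
// alignment requirements are met and falling back to the builtin copy executable otherwise.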
| static iree_status_t iree_hal_metal_command_buffer_copy_buffer_internal( |
| iree_hal_metal_command_buffer_t* command_buffer, id<MTLBuffer> source_device_buffer, |
| iree_device_size_t source_offset, id<MTLBuffer> target_device_buffer, |
| iree_device_size_t target_offset, iree_device_size_t length) { |
| // Per the spec for copyFromBuffer:sourceOffset:toBuffer:destinationOffset:size, the source/target |
| // offset and length must be a multiple of 4 bytes in macOS, and 1 byte in iOS and tvOS. |
| #if defined(IREE_PLATFORM_MACOS) |
| bool can_use_metal_api = source_offset % 4 == 0 && target_offset % 4 == 0 && length % 4 == 0; |
| #else |
| bool can_use_metal_api = true; |
| #endif |
| |
| iree_status_t status = iree_ok_status(); |
| if (can_use_metal_api) { |
| id<MTLBlitCommandEncoder> encoder = iree_hal_metal_get_or_begin_blit_encoder(command_buffer); |
| [encoder copyFromBuffer:source_device_buffer |
| sourceOffset:source_offset |
| toBuffer:target_device_buffer |
| destinationOffset:target_offset |
| size:length]; |
| } else { |
| id<MTLComputeCommandEncoder> encoder = |
| iree_hal_metal_get_or_begin_compute_encoder(command_buffer); |
| status = iree_hal_metal_builtin_executable_copy_buffer( |
| command_buffer->builtin_executable, encoder, source_device_buffer, source_offset, |
| target_device_buffer, target_offset, length); |
| } |
| |
| return status; |
| } |
| |
| static iree_status_t iree_hal_metal_command_buffer_update_buffer( |
| iree_hal_command_buffer_t* base_command_buffer, const void* source_buffer, |
| iree_host_size_t source_offset, iree_hal_buffer_t* target_buffer, |
| iree_device_size_t target_offset, iree_device_size_t length) { |
  // There is no directly corresponding API in Metal. We emulate it by creating a staging buffer
  // with the source content and then copying that buffer over to the target.
| iree_hal_metal_command_buffer_t* command_buffer = |
| iree_hal_metal_command_buffer_cast(base_command_buffer); |
| IREE_TRACE_ZONE_BEGIN(z0); |
| |
| id<MTLDevice> device = command_buffer->command_buffer.device; |
| MTLResourceOptions options = MTLResourceStorageModeShared | MTLResourceCPUCacheModeWriteCombined; |
| id<MTLBuffer> data_buffer = [device newBufferWithBytes:((uint8_t*)source_buffer + source_offset) |
| length:length |
| options:options]; // +1 |
| [command_buffer->command_buffer addCompletedHandler:^(id<MTLCommandBuffer> cmdbuf) { |
| [data_buffer release]; // -1 |
| }]; |
| |
| IREE_RETURN_AND_END_ZONE_IF_ERROR( |
| z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1, &target_buffer)); |
| |
| id<MTLBuffer> target_device_buffer = |
| iree_hal_metal_buffer_handle(iree_hal_buffer_allocated_buffer(target_buffer)); |
| target_offset += iree_hal_buffer_byte_offset(target_buffer); |
| |
| iree_status_t status = iree_hal_metal_command_buffer_copy_buffer_internal( |
| command_buffer, data_buffer, /*source_offset=*/0, target_device_buffer, target_offset, |
| length); |
| |
| IREE_TRACE_ZONE_END(z0); |
| return status; |
| } |
| |
| static iree_status_t iree_hal_metal_command_buffer_copy_buffer( |
| iree_hal_command_buffer_t* base_command_buffer, iree_hal_buffer_t* source_buffer, |
| iree_device_size_t source_offset, iree_hal_buffer_t* target_buffer, |
| iree_device_size_t target_offset, iree_device_size_t length) { |
| iree_hal_metal_command_buffer_t* command_buffer = |
| iree_hal_metal_command_buffer_cast(base_command_buffer); |
| IREE_TRACE_ZONE_BEGIN(z0); |
| |
| const iree_hal_buffer_t* buffers[2] = {source_buffer, target_buffer}; |
| IREE_RETURN_AND_END_ZONE_IF_ERROR( |
| z0, iree_hal_resource_set_insert(command_buffer->resource_set, 2, buffers)); |
| |
| id<MTLBuffer> source_device_buffer = |
| iree_hal_metal_buffer_handle(iree_hal_buffer_allocated_buffer(source_buffer)); |
| id<MTLBuffer> target_device_buffer = |
| iree_hal_metal_buffer_handle(iree_hal_buffer_allocated_buffer(target_buffer)); |
| |
| source_offset += iree_hal_buffer_byte_offset(source_buffer); |
| target_offset += iree_hal_buffer_byte_offset(target_buffer); |
| |
| iree_status_t status = iree_hal_metal_command_buffer_copy_buffer_internal( |
| command_buffer, source_device_buffer, source_offset, target_device_buffer, target_offset, |
| length); |
| |
| IREE_TRACE_ZONE_END(z0); |
| return status; |
| } |
| |
| static iree_status_t iree_hal_metal_command_buffer_collective( |
| iree_hal_command_buffer_t* base_command_buffer, iree_hal_channel_t* channel, |
| iree_hal_collective_op_t op, uint32_t param, iree_hal_buffer_binding_t send_binding, |
| iree_hal_buffer_binding_t recv_binding, iree_device_size_t element_count) { |
| return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "collectives not yet supported"); |
| } |
| |
| static iree_status_t iree_hal_metal_command_buffer_push_constants( |
| iree_hal_command_buffer_t* base_command_buffer, iree_hal_pipeline_layout_t* pipeline_layout, |
| iree_host_size_t offset, const void* values, iree_host_size_t values_length) { |
| iree_hal_metal_command_buffer_t* command_buffer = |
| iree_hal_metal_command_buffer_cast(base_command_buffer); |
| |
| // "Binding a pipeline with a layout that is not compatible with the push constant layout does not |
  // disturb the push constant values." So we don't need to check pipeline layout compatibility
  // and invalidate existing values here.
| // See https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/vkCmdPushConstants.html |
| |
  if (IREE_UNLIKELY(offset + values_length > sizeof(command_buffer->push_constants))) {
| return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, |
| "push constant range [%zu, %zu) out of range", offset, |
| offset + values_length); |
| } |
| |
| memcpy((uint8_t*)&command_buffer->push_constants + offset, values, values_length); |
| command_buffer->current_pipeline_layout = pipeline_layout; |
| |
| return iree_ok_status(); |
| } |
| |
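// qsort comparator ordering descriptors by set number and then binding number, both ascending.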
| static int compare_descriptor(const void* a, const void* b) { |
| const iree_hal_metal_descriptor_t* buffer_a = (const iree_hal_metal_descriptor_t*)a; |
| const iree_hal_metal_descriptor_t* buffer_b = (const iree_hal_metal_descriptor_t*)b; |
| if (buffer_a->set != buffer_b->set) return buffer_a->set - buffer_b->set; |
| return buffer_a->binding - buffer_b->binding; |
| } |
| |
// Returns true if the given |descriptors| array contains descriptors in ascending (set, binding)
// order with no duplicated binding slots.
| static bool iree_hal_metal_is_sorted_unique_descriptors(iree_hal_metal_descriptor_t* descriptors, |
| int descriptor_count) { |
| for (int i = 1; i < descriptor_count; ++i) { |
| if (compare_descriptor(&descriptors[i - 1], &descriptors[i]) >= 0) return false; |
| } |
| return true; |
| } |
| |
| static iree_status_t iree_hal_metal_command_buffer_push_descriptor_set( |
| iree_hal_command_buffer_t* base_command_buffer, iree_hal_pipeline_layout_t* pipeline_layout, |
| uint32_t set, iree_host_size_t binding_count, |
| const iree_hal_descriptor_set_binding_t* bindings) { |
| iree_hal_metal_command_buffer_t* command_buffer = |
| iree_hal_metal_command_buffer_cast(base_command_buffer); |
| IREE_TRACE_ZONE_BEGIN(z0); |
| |
| if (set == IREE_HAL_METAL_PUSH_CONSTANT_BUFFER_INDEX) { |
| IREE_TRACE_ZONE_END(z0); |
| return iree_make_status(IREE_STATUS_UNIMPLEMENTED, |
| "descriptor set #%d reserved for push constant emulation", |
| IREE_HAL_METAL_PUSH_CONSTANT_BUFFER_INDEX); |
| } |
| |
| for (iree_host_size_t i = 0; i < binding_count; ++i) { |
| if (bindings[i].buffer) continue; |
| IREE_TRACE_ZONE_END(z0); |
| return iree_make_status(IREE_STATUS_UNIMPLEMENTED, |
| "unimplemented null buffer in push descriptor set"); |
| } |
| |
| iree_hal_metal_descriptor_t* descriptors = command_buffer->current_descriptors; |
| IREE_ASSERT(iree_hal_metal_is_sorted_unique_descriptors( |
| descriptors, command_buffer->current_total_binding_count)); |
| |
| if (command_buffer->current_max_set_number >= (int)set) { |
    // We are pushing an already-seen set. This invalidates the set with the given number and all
    // higher-numbered sets, so clear all affected bindings.
    // TODO(antiagainst): We should actually check the current pipeline layout's compatibility
    // with the previous one and decide whether we should invalidate lower-numbered sets too. For
    // now we assume the compiler side is doing a proper job of guaranteeing that.
| // https://registry.khronos.org/vulkan/specs/1.3-extensions/html/chap14.html#descriptorsets-compatibility |
| int* count = &command_buffer->current_total_binding_count; |
| while (*count > 0 && descriptors[*count - 1].set >= (int)set) --(*count); |
| command_buffer->current_max_set_number = (*count == 0) ? -1 : descriptors[*count - 1].set; |
| } |
| |
  // Pushing a new set with a larger number; all sets with smaller numbers remain active. Just
  // sort the current one and copy over the data. This is the expected usage pattern in IREE,
  // where the compiler sorts/deduplicates descriptor sets and pushes them in ascending order.
| if (binding_count + command_buffer->current_total_binding_count > |
| IREE_HAL_METAL_MAX_BINDING_COUNT) { |
| IREE_TRACE_ZONE_END(z0); |
| return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED, |
| "exceeded available binding slots for push descriptor sets"); |
| } |
| |
| int start_index = command_buffer->current_total_binding_count; |
| for (iree_host_size_t i = 0; i < binding_count; ++i) { |
| iree_hal_metal_descriptor_t* descriptor = &descriptors[start_index + i]; |
| descriptor->set = set; |
| descriptor->binding = bindings[i].binding; |
| descriptor->buffer = bindings[i].buffer; |
| descriptor->offset = bindings[i].offset; |
| } |
| qsort(&descriptors[start_index], binding_count, sizeof(descriptors[0]), compare_descriptor); |
| |
| command_buffer->current_max_set_number = set; |
| command_buffer->current_total_binding_count += binding_count; |
| |
| IREE_ASSERT(iree_hal_metal_is_sorted_unique_descriptors( |
| descriptors, command_buffer->current_total_binding_count)); |
| |
| // Retain all buffers bound in this descriptor set. |
| for (iree_host_size_t i = 0; i < binding_count; ++i) { |
| if (bindings[i].buffer) { |
| IREE_RETURN_AND_END_ZONE_IF_ERROR( |
| z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1, &bindings[i].buffer)); |
| } |
| } |
| |
| command_buffer->current_pipeline_layout = pipeline_layout; |
| IREE_RETURN_AND_END_ZONE_IF_ERROR( |
| z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1, &pipeline_layout)); |
| |
| IREE_TRACE_ZONE_END(z0); |
| return iree_ok_status(); |
| } |
| |
| static inline MTLResourceUsage iree_hal_metal_get_metal_resource_usage( |
| iree_hal_descriptor_set_layout_binding_t* binding) { |
| MTLResourceUsage usage = MTLResourceUsageRead; |
| if (binding->flags != IREE_HAL_DESCRIPTOR_FLAG_READ_ONLY) usage |= MTLResourceUsageWrite; |
| return usage; |
| } |
| |
// Creates an argument encoder and its backing argument buffer for the given kernel |function|'s
// |buffer_index|. The argument encoder will be set to encode into the newly created argument
// buffer. Both are released automatically when the command buffer completes execution.
| static iree_status_t iree_hal_metal_create_argument_encoder( |
| id<MTLDevice> device, id<MTLCommandBuffer> command_buffer, id<MTLFunction> function, |
| uint32_t buffer_index, id<MTLArgumentEncoder>* out_encoder, id<MTLBuffer>* out_buffer) { |
| id<MTLArgumentEncoder> argument_encoder = |
| [function newArgumentEncoderWithBufferIndex:buffer_index]; // +1 |
| if (!argument_encoder) { |
| return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "invalid argument buffer index #%u", |
| buffer_index); |
| } |
| |
| __block id<MTLBuffer> argument_buffer = |
| [device newBufferWithLength:argument_encoder.encodedLength |
| options:MTLResourceStorageModeShared]; // +1 |
  if (!argument_buffer) {
    iree_host_size_t encoded_length = argument_encoder.encodedLength;
    [argument_encoder release];  // -1
    return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
                            "failed to create argument buffer with size = %zu bytes",
                            encoded_length);
  }
| |
  // The argument encoder and buffer can be released once the command buffer completes.
| [command_buffer addCompletedHandler:^(id<MTLCommandBuffer> cmdbuf) { |
| [argument_buffer release]; // -1 |
| [argument_encoder release]; // -1 |
| }]; |
| |
| [argument_encoder setArgumentBuffer:argument_buffer offset:0]; |
| *out_encoder = argument_encoder; |
| *out_buffer = argument_buffer; |
| return iree_ok_status(); |
| } |
| |
| // Prepares kernels and argument buffers needed for kernel dispatches. |
| static iree_status_t iree_hal_metal_command_buffer_prepare_dispatch( |
| iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, |
| int32_t entry_point, iree_hal_metal_kernel_params_t* kernel_params) { |
| iree_hal_metal_command_buffer_t* command_buffer = |
| iree_hal_metal_command_buffer_cast(base_command_buffer); |
| IREE_TRACE_ZONE_BEGIN(z0); |
| |
| IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, iree_hal_metal_kernel_library_entry_point_kernel_params( |
| executable, entry_point, kernel_params)); |
| if (!command_buffer->current_pipeline_layout) { |
| IREE_TRACE_ZONE_END(z0); |
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "missing pipeline layout for dispatch");
| } |
| |
| // Set the compute kernel to dispatch. |
| id<MTLComputeCommandEncoder> compute_encoder = |
| iree_hal_metal_get_or_begin_compute_encoder(command_buffer); |
| [compute_encoder setComputePipelineState:kernel_params->pso]; |
| |
| iree_status_t status = iree_ok_status(); |
| |
| // Bind all buffers in all descriptor sets. |
| iree_hal_metal_descriptor_t* descriptors = command_buffer->current_descriptors; |
| int binding_count = command_buffer->current_total_binding_count; |
| int i = 0; |
| while (i < binding_count) { |
| // Build argument encoder and argument buffer for the current descriptor set. |
| uint32_t current_set = descriptors[i].set; |
| |
| id<MTLArgumentEncoder> argument_encoder; |
| id<MTLBuffer> argument_buffer; |
| IREE_RETURN_AND_END_ZONE_IF_ERROR( |
| z0, iree_hal_metal_create_argument_encoder( |
| command_buffer->command_buffer.device, command_buffer->command_buffer, |
| kernel_params->function, current_set, &argument_encoder, &argument_buffer)); |
| |
| iree_hal_descriptor_set_layout_t* set_layout = |
| iree_hal_metal_pipeline_layout_descriptor_set_layout( |
| command_buffer->current_pipeline_layout, current_set); |
| if (!set_layout) { |
| status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT, |
| "cannot find descriptor set layout for set #%u", current_set); |
| break; |
| } |
| |
| // Now put all bound buffers belonging to the current descriptor set into the argument buffer. |
| for (; i < binding_count && descriptors[i].set == current_set; ++i) { |
| unsigned current_binding = descriptors[i].binding; |
| id<MTLBuffer> current_buffer = |
| iree_hal_metal_buffer_handle(iree_hal_buffer_allocated_buffer(descriptors[i].buffer)); |
| iree_host_size_t offset = |
| iree_hal_buffer_byte_offset(descriptors[i].buffer) + descriptors[i].offset; |
| [argument_encoder setBuffer:current_buffer offset:offset atIndex:current_binding]; |
| |
| iree_hal_descriptor_set_layout_binding_t* binding_params = |
| iree_hal_metal_descriptor_set_layout_binding(set_layout, current_binding); |
| if (!binding_params) { |
| status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT, |
| "cannot find information for binding #%u in set #%u", |
| current_binding, current_set); |
| break; |
| } |
| [compute_encoder useResource:current_buffer |
| usage:iree_hal_metal_get_metal_resource_usage(binding_params)]; |
| } |
| if (!iree_status_is_ok(status)) break; |
| |
| [compute_encoder setBuffer:argument_buffer offset:0 atIndex:current_set]; |
| } |
| |
| if (iree_hal_metal_pipeline_layout_push_constant_count(command_buffer->current_pipeline_layout)) { |
| [compute_encoder setBytes:(void*)command_buffer->push_constants |
| length:sizeof(command_buffer->push_constants) |
| atIndex:IREE_HAL_METAL_PUSH_CONSTANT_BUFFER_INDEX]; |
| } |
| |
| if (iree_status_is_ok(status)) { |
| IREE_RETURN_AND_END_ZONE_IF_ERROR( |
| z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1, &executable)); |
| } |
| |
| IREE_TRACE_ZONE_END(z0); |
| return status; |
| } |
| |
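// Records a dispatch with a direct 3-D workgroup count after preparing the kernel and bindings.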
| static iree_status_t iree_hal_metal_command_buffer_dispatch( |
| iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, |
| int32_t entry_point, uint32_t workgroup_count_x, uint32_t workgroup_count_y, |
| uint32_t workgroup_count_z) { |
| iree_hal_metal_command_buffer_t* command_buffer = |
| iree_hal_metal_command_buffer_cast(base_command_buffer); |
| IREE_TRACE_ZONE_BEGIN(z0); |
| |
| iree_hal_metal_kernel_params_t kernel_params; |
| IREE_RETURN_AND_END_ZONE_IF_ERROR( |
| z0, iree_hal_metal_command_buffer_prepare_dispatch(base_command_buffer, executable, |
| entry_point, &kernel_params)); |
| |
| id<MTLComputeCommandEncoder> compute_encoder = |
| iree_hal_metal_get_or_begin_compute_encoder(command_buffer); |
| uint32_t* workgroup_size = kernel_params.threadgroup_size; |
| [compute_encoder |
| dispatchThreadgroups:MTLSizeMake(workgroup_count_x, workgroup_count_y, workgroup_count_z) |
| threadsPerThreadgroup:MTLSizeMake(workgroup_size[0], workgroup_size[1], workgroup_size[2])]; |
| |
| IREE_TRACE_ZONE_END(z0); |
| return iree_ok_status(); |
| } |
| |
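// Records a dispatch whose workgroup count is read from |workgroups_buffer| at execution time.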
| static iree_status_t iree_hal_metal_command_buffer_dispatch_indirect( |
| iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, |
| int32_t entry_point, iree_hal_buffer_t* workgroups_buffer, |
| iree_device_size_t workgroups_offset) { |
| iree_hal_metal_command_buffer_t* command_buffer = |
| iree_hal_metal_command_buffer_cast(base_command_buffer); |
| IREE_TRACE_ZONE_BEGIN(z0); |
| |
| iree_hal_metal_kernel_params_t kernel_params; |
| IREE_RETURN_AND_END_ZONE_IF_ERROR( |
| z0, iree_hal_metal_command_buffer_prepare_dispatch(base_command_buffer, executable, |
| entry_point, &kernel_params)); |
| |
| id<MTLComputeCommandEncoder> compute_encoder = |
| iree_hal_metal_get_or_begin_compute_encoder(command_buffer); |
| uint32_t* workgroup_size = kernel_params.threadgroup_size; |
| id<MTLBuffer> metal_buffer = |
| iree_hal_metal_buffer_handle(iree_hal_buffer_allocated_buffer(workgroups_buffer)); |
| [compute_encoder |
| dispatchThreadgroupsWithIndirectBuffer:metal_buffer |
| indirectBufferOffset:workgroups_offset |
| threadsPerThreadgroup:MTLSizeMake(workgroup_size[0], workgroup_size[1], |
| workgroup_size[2])]; |
| |
| IREE_TRACE_ZONE_END(z0); |
| return iree_ok_status(); |
| } |
| |
| static iree_status_t iree_hal_metal_command_buffer_execute_commands( |
| iree_hal_command_buffer_t* base_command_buffer, iree_hal_command_buffer_t* base_commands, |
| iree_hal_buffer_binding_table_t binding_table) { |
| return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "secondary command buffer not yet supported"); |
| } |
| |
| static const iree_hal_command_buffer_vtable_t iree_hal_metal_command_buffer_vtable = { |
| .destroy = iree_hal_metal_command_buffer_destroy, |
| .begin = iree_hal_metal_command_buffer_begin, |
| .end = iree_hal_metal_command_buffer_end, |
| .begin_debug_group = iree_hal_metal_command_buffer_begin_debug_group, |
| .end_debug_group = iree_hal_metal_command_buffer_end_debug_group, |
| .execution_barrier = iree_hal_metal_command_buffer_execution_barrier, |
| .signal_event = iree_hal_metal_command_buffer_signal_event, |
| .reset_event = iree_hal_metal_command_buffer_reset_event, |
| .wait_events = iree_hal_metal_command_buffer_wait_events, |
| .discard_buffer = iree_hal_metal_command_buffer_discard_buffer, |
| .fill_buffer = iree_hal_metal_command_buffer_fill_buffer, |
| .update_buffer = iree_hal_metal_command_buffer_update_buffer, |
| .copy_buffer = iree_hal_metal_command_buffer_copy_buffer, |
| .collective = iree_hal_metal_command_buffer_collective, |
| .push_constants = iree_hal_metal_command_buffer_push_constants, |
| .push_descriptor_set = iree_hal_metal_command_buffer_push_descriptor_set, |
| .dispatch = iree_hal_metal_command_buffer_dispatch, |
| .dispatch_indirect = iree_hal_metal_command_buffer_dispatch_indirect, |
| .execute_commands = iree_hal_metal_command_buffer_execute_commands, |
| }; |