[metal] Use staging buffer for argument buffers and update sources

This commit switches the Metal HAL driver to use a per-queue staging buffer for recording argument buffers and uploading buffer update source data. This avoids creating many small buffers as we did previously, and avoids registering a command buffer completion callback per buffer just to release it; instead, the staging buffer is reset once all command buffers using it have completed.
diff --git a/experimental/metal/CMakeLists.txt b/experimental/metal/CMakeLists.txt index 0b0fc22..9552b58 100644 --- a/experimental/metal/CMakeLists.txt +++ b/experimental/metal/CMakeLists.txt
@@ -35,6 +35,8 @@ "nop_executable_cache.m" "pipeline_layout.h" "pipeline_layout.m" + "staging_buffer.h" + "staging_buffer.m" DEPS iree::base iree::base::core_headers
diff --git a/experimental/metal/api.h b/experimental/metal/api.h index 0a5112b..b2c9728 100644 --- a/experimental/metal/api.h +++ b/experimental/metal/api.h
@@ -46,11 +46,16 @@ // Must be initialized with iree_hal_metal_device_params_initialize prior to // use. typedef struct iree_hal_metal_device_params_t { - // Total size of each block in the device shared block pool. + // Total size in bytes of each block in the device shared block pool. // Larger sizes will lower overhead and ensure the heap isn't hit for // transient allocations while also increasing memory consumption. iree_host_size_t arena_block_size; + // Total size in bytes of per-queue uniform buffers for uploading parameters + // to the GPU (including argument buffers and update source buffers). + // Larger sizes better support more concurrent/complex command buffers. + iree_host_size_t queue_uniform_buffer_size; + // Command dispatch type in command buffers. // Normally we want to dispatch commands in command buffers in parallel, given // that IREE performs explicit dependency tracking and synchronization by
diff --git a/experimental/metal/direct_command_buffer.h b/experimental/metal/direct_command_buffer.h index 09c1478..3c929e5 100644 --- a/experimental/metal/direct_command_buffer.h +++ b/experimental/metal/direct_command_buffer.h
@@ -11,6 +11,7 @@ #include "experimental/metal/api.h" #include "experimental/metal/builtin_executables.h" +#include "experimental/metal/staging_buffer.h" #include "iree/base/internal/arena.h" #include "iree/hal/api.h" @@ -43,6 +44,7 @@ resource_reference_mode, id<MTLCommandQueue> queue, iree_allocator_t host_allocator, iree_arena_block_pool_t* block_pool, + iree_hal_metal_staging_buffer_t* staging_buffer, iree_hal_metal_builtin_executable_t* builtin_executable, iree_hal_command_buffer_t** out_command_buffer);
diff --git a/experimental/metal/direct_command_buffer.m b/experimental/metal/direct_command_buffer.m index 7117969..a6f3fbf 100644 --- a/experimental/metal/direct_command_buffer.m +++ b/experimental/metal/direct_command_buffer.m
@@ -13,6 +13,7 @@ #include "experimental/metal/metal_device.h" #include "experimental/metal/metal_kernel_library.h" #include "experimental/metal/pipeline_layout.h" +#include "experimental/metal/staging_buffer.h" #include "iree/base/api.h" #include "iree/base/target_platform.h" #include "iree/base/tracing.h" @@ -168,6 +169,10 @@ // Arena used for all allocations; references the shared device block pool. iree_arena_allocator_t arena; + // Per-queue shared uniform staging buffer for uploading parameters to the GPU, including argument + // buffers and buffer update source buffers. + iree_hal_metal_staging_buffer_t* staging_buffer; + // Linked list of command segments to be recorded into a command buffer. iree_hal_metal_command_segment_list_t segments; @@ -319,6 +324,7 @@ iree_hal_command_category_t command_categories, iree_host_size_t binding_capacity, iree_hal_metal_command_buffer_resource_reference_mode_t resource_reference_mode, id<MTLCommandQueue> queue, iree_allocator_t host_allocator, iree_arena_block_pool_t* block_pool, + iree_hal_metal_staging_buffer_t* staging_buffer, iree_hal_metal_builtin_executable_t* builtin_executable, iree_hal_command_buffer_t** out_command_buffer) { IREE_ASSERT_ARGUMENT(device); @@ -344,6 +350,7 @@ command_buffer->queue = [queue retain]; // +1 command_buffer->builtin_executable = builtin_executable; iree_arena_initialize(block_pool, &command_buffer->arena); + command_buffer->staging_buffer = staging_buffer; iree_hal_metal_command_segment_list_reset(&command_buffer->segments); @autoreleasepool { // Use @autoreleasepool to trigger the autorelease within encoder creation. 
// We track resource lifetime by ourselves in IREE; so just do unretained references to @@ -737,6 +744,7 @@ segment->target_buffer, segment->target_offset, segment->length); } + IREE_TRACE_ZONE_END(z0); return status; } @@ -744,20 +752,19 @@ iree_hal_command_buffer_t* base_command_buffer, const void* source_buffer, iree_host_size_t source_offset, iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, iree_device_size_t length) { - // There are no direct corresponding APIs in Metal. We emulate it by creating a buffer with the - // content and then copy it over. iree_hal_metal_command_buffer_t* command_buffer = iree_hal_metal_command_buffer_cast(base_command_buffer); IREE_TRACE_ZONE_BEGIN(z0); - id<MTLDevice> device = command_buffer->command_buffer.device; - MTLResourceOptions options = MTLResourceStorageModeShared | MTLResourceCPUCacheModeWriteCombined; - id<MTLBuffer> data_buffer = [device newBufferWithBytes:((uint8_t*)source_buffer + source_offset) - length:length - options:options]; // +1 - [command_buffer->command_buffer addCompletedHandler:^(id<MTLCommandBuffer> cmdbuf) { - [data_buffer release]; // -1 - }]; + // There are no direct corresponding APIs in Metal. We update the source buffer data to the + // staging buffer and then copy over. 
+ + iree_const_byte_span_t source_data_span = + iree_make_const_byte_span((uint8_t*)source_buffer + source_offset, length); + uint32_t offset = 0; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_metal_staging_buffer_append(command_buffer->staging_buffer, source_data_span, + /*alignment=*/4, &offset)); IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1, &target_buffer)); @@ -767,8 +774,8 @@ target_offset += iree_hal_buffer_byte_offset(target_buffer); iree_status_t status = iree_hal_metal_command_segment_create_copy_buffer( - command_buffer, data_buffer, /*source_offset=*/0, target_device_buffer, target_offset, - length); + command_buffer, command_buffer->staging_buffer->metal_buffer, offset, target_device_buffer, + target_offset, length); IREE_TRACE_ZONE_END(z0); return status; @@ -931,37 +938,6 @@ return iree_ok_status(); } -// Creates an argument encoder and its backing argument buffer for the given kernel |function|'s -// |buffer_index|. The argument encoder will be set to encode into the newly created argument -// buffer. Callers are expected to release both the argument encoder and buffer. -static iree_status_t iree_hal_metal_create_argument_encoder( - id<MTLDevice> device, id<MTLCommandBuffer> command_buffer, id<MTLFunction> function, - uint32_t buffer_index, id<MTLArgumentEncoder>* out_encoder, id<MTLBuffer>* out_buffer) { - id<MTLArgumentEncoder> argument_encoder = - [function newArgumentEncoderWithBufferIndex:buffer_index]; // +1 - IREE_ASSERT(argument_encoder != nil); - - __block id<MTLBuffer> argument_buffer = - [device newBufferWithLength:argument_encoder.encodedLength - options:MTLResourceStorageModeShared]; // +1 - if (!argument_buffer) { - return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED, - "failed to create argument buffer with size = %ld bytes", - argument_encoder.encodedLength); - } - - // The arugment encoder and buffer can be deleted once the command buffer completes. 
- [command_buffer addCompletedHandler:^(id<MTLCommandBuffer> cmdbuf) { - [argument_buffer release]; // -1 - [argument_encoder release]; // -1 - }]; - - [argument_encoder setArgumentBuffer:argument_buffer offset:0]; - *out_encoder = argument_encoder; - *out_buffer = argument_buffer; - return iree_ok_status(); -} - // Prepares kernels and argument buffers needed for kernel dispatches. static iree_status_t iree_hal_metal_command_segment_create_dispatch( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, @@ -1033,17 +1009,23 @@ // Record argument buffers for all descriptors and record buffer usages. iree_hal_metal_descriptor_t* descriptors = segment->descriptors; - iree_host_size_t i = 0; - while (i < segment->descriptor_count) { + for (iree_host_size_t i = 0; i < segment->descriptor_count;) { uint32_t current_set = descriptors[i].set; // Build argument encoder and argument buffer for the current descriptor set. - id<MTLArgumentEncoder> argument_encoder; - id<MTLBuffer> argument_buffer; + id<MTLBuffer> argument_buffer = command_buffer->staging_buffer->metal_buffer; + id<MTLArgumentEncoder> argument_encoder = + [segment->kernel_params.function newArgumentEncoderWithBufferIndex:current_set]; // +1 + IREE_ASSERT(argument_encoder != nil); + + // Reserve space for the argument buffer from shared staging buffer. 
+ iree_byte_span_t reservation; + uint32_t argument_buffer_offset; IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_metal_create_argument_encoder( - command_buffer->command_buffer.device, command_buffer->command_buffer, - segment->kernel_params.function, current_set, &argument_encoder, &argument_buffer)); + z0, iree_hal_metal_staging_buffer_reserve( + command_buffer->staging_buffer, argument_encoder.encodedLength, + argument_encoder.alignment, &reservation, &argument_buffer_offset)); + [argument_encoder setArgumentBuffer:argument_buffer offset:argument_buffer_offset]; // Now record all bound buffers belonging to the current set into the argument buffer. for (; i < segment->descriptor_count && descriptors[i].set == current_set; ++i) { @@ -1058,7 +1040,9 @@ [compute_encoder useResource:current_buffer usage:descriptors[i].usage]; } // Record the argument buffer. - [compute_encoder setBuffer:argument_buffer offset:0 atIndex:current_set]; + [compute_encoder setBuffer:argument_buffer offset:argument_buffer_offset atIndex:current_set]; + + [argument_encoder release]; // -1 } // Record the dispatch, either direct or indirect.
diff --git a/experimental/metal/metal_device.m b/experimental/metal/metal_device.m index 3ce0f33..70a775f 100644 --- a/experimental/metal/metal_device.m +++ b/experimental/metal/metal_device.m
@@ -13,6 +13,7 @@ #include "experimental/metal/metal_shared_event.h" #include "experimental/metal/nop_executable_cache.h" #include "experimental/metal/pipeline_layout.h" +#include "experimental/metal/staging_buffer.h" #include "iree/base/api.h" #include "iree/base/tracing.h" #include "iree/hal/api.h" @@ -29,8 +30,10 @@ // contain inlined data uploads). iree_arena_block_pool_t block_pool; - // Original driver that owns this device. - iree_hal_driver_t* driver; + // Per-queue staging buffer for parameter uploads. + iree_hal_metal_staging_buffer_t staging_buffer; + + iree_hal_metal_device_params_t params; iree_allocator_t host_allocator; iree_hal_allocator_t* device_allocator; @@ -69,6 +72,7 @@ void iree_hal_metal_device_params_initialize(iree_hal_metal_device_params_t* out_params) { memset(out_params, 0, sizeof(*out_params)); out_params->arena_block_size = 32 * 1024; + out_params->queue_uniform_buffer_size = IREE_HAL_METAL_STAGING_BUFFER_DEFAULT_CAPACITY; out_params->command_dispatch_type = IREE_HAL_METAL_COMMAND_DISPATCH_TYPE_CONCURRENT; out_params->command_buffer_resource_reference_mode = IREE_HAL_METAL_COMMAND_BUFFER_RESOURCE_REFERENCE_MODE_UNRETAINED; @@ -98,6 +102,7 @@ iree_status_t status = iree_hal_metal_allocator_create((iree_hal_device_t*)device, metal_device, params->resource_hazard_tracking_mode, host_allocator, &device->device_allocator); + iree_hal_metal_builtin_executable_t* builtin_executable = NULL; if (iree_status_is_ok(status)) { status = @@ -105,13 +110,20 @@ } else { iree_hal_device_release((iree_hal_device_t*)device); } + + if (iree_status_is_ok(status)) { + status = iree_hal_metal_staging_buffer_initialize( + metal_device, params->queue_uniform_buffer_size, &device->staging_buffer); + } else { + iree_hal_device_release((iree_hal_device_t*)device); + } + if (iree_status_is_ok(status)) { iree_hal_resource_initialize(&iree_hal_metal_device_vtable, &device->resource); iree_string_view_append_to_buffer(identifier, &device->identifier, (char*)device + 
iree_sizeof_struct(*device)); iree_arena_block_pool_initialize(params->arena_block_size, host_allocator, &device->block_pool); - device->driver = driver; - iree_hal_driver_retain(device->driver); + device->params = *params; device->host_allocator = host_allocator; device->device = [metal_device retain]; // +1 device->queue = [metal_device newCommandQueue]; // +1 @@ -124,7 +136,6 @@ device->event_listener = [[MTLSharedEventListener alloc] initWithDispatchQueue:device->semaphore_notification_queue]; // +1 device->capture_manager = NULL; - *out_device = (iree_hal_device_t*)device; } return status; @@ -158,6 +169,7 @@ [device->queue release]; // -1 [device->device release]; // -1 + iree_hal_metal_staging_buffer_deinitialize(&device->staging_buffer); iree_arena_block_pool_deinitialize(&device->block_pool); iree_allocator_free(host_allocator, device); @@ -221,14 +233,21 @@ iree_hal_command_category_t command_categories, iree_hal_queue_affinity_t queue_affinity, iree_host_size_t binding_capacity, iree_hal_command_buffer_t** out_command_buffer) { iree_hal_metal_device_t* device = iree_hal_metal_device_cast(base_device); + if (iree_any_bit_set(mode, IREE_HAL_COMMAND_BUFFER_MODE_NESTED)) return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "nested command buffer not yet supported"); if (!iree_all_bits_set(mode, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT)) - return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "unimplmented multi-shot command buffer"); - return iree_hal_metal_direct_command_buffer_create( + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "multi-shot command buffer not yet supported"); + + iree_status_t status = iree_hal_metal_direct_command_buffer_create( base_device, mode, command_categories, binding_capacity, device->command_buffer_resource_reference_mode, device->queue, device->host_allocator, - &device->block_pool, device->builtin_executable, out_command_buffer); + &device->block_pool, &device->staging_buffer, device->builtin_executable, out_command_buffer); 
+ if (iree_status_is_ok(status)) { + iree_hal_metal_staging_buffer_increase_refcount(&device->staging_buffer); + } + return status; } static iree_status_t iree_hal_metal_device_create_descriptor_set_layout( @@ -354,6 +373,9 @@ id<MTLCommandBuffer> handle = iree_hal_metal_direct_command_buffer_handle(command_buffer); [handle addCompletedHandler:^(id<MTLCommandBuffer> cb) { iree_hal_command_buffer_release(command_buffer); // -1 + // Decrease command buffer refcount in the shared staging buffer, and potentially reclaim + // resources. This is fine right now given we only support one-shot command buffers. + iree_hal_metal_staging_buffer_decrease_refcount(&device->staging_buffer); }]; [handle commit]; }
diff --git a/experimental/metal/staging_buffer.h b/experimental/metal/staging_buffer.h new file mode 100644 index 0000000..c46df94 --- /dev/null +++ b/experimental/metal/staging_buffer.h
@@ -0,0 +1,117 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_EXPERIMENTAL_METAL_STAGING_BUFFER_H_
+#define IREE_EXPERIMENTAL_METAL_STAGING_BUFFER_H_
+
+#import <Metal/Metal.h>
+
+#include "iree/base/api.h"
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Size, in bytes, of the shared storage mode staging buffer.
+// The given amount of system memory will be allocated and is accessible to both
+// the CPU and the GPU.
+//
+// Larger values here will use more memory but allow more concurrent/complex
+// command buffers. As most models that run in these environments are only a few
+// hundred dispatches per command buffer we can approximate an average
+// consumption of 500 dispatches x worst-case 256B per dispatch of parameters
+// and get 128KB.
+#define IREE_HAL_METAL_STAGING_BUFFER_DEFAULT_CAPACITY (128 * 1024)
+
+// A staging uniform buffer used for uploading parameters to the device.
+// This allows for high-frequency writes of parameters at appropriate alignment.
+//
+// Intended usage is to retain one of these per device queue and use them during
+// command buffer recording targeting that particular queue. This avoids
+// allocating a lot of small buffers. The underlying buffer has shared storage
+// mode, so it resides in system memory and is accessible to both the CPU and
+// the GPU.
+//
+// Parameters handled by this buffer include:
+// * Argument buffers for descriptor sets
+// * Source buffer for buffer update commands
+//
+// Space is handed out linearly and is only reclaimed (by resetting the write
+// offset to zero) once no command buffer using the buffer is pending.
+//
+// TODO(#14049): Use proper atomics/mutexes for concurrent command buffer
+// recording and execution.
+typedef struct iree_hal_metal_staging_buffer_t {
+  // Maximum number of bytes in the buffer.
+  uint32_t capacity;
+
+  // Device handle to the buffer.
+  id<MTLBuffer> metal_buffer;
+  // Host pointer to the buffer.
+  uint8_t* host_buffer;
+
+  // Current write offset into the buffer, in bytes.
+  uint32_t offset;
+
+  // The number of command buffers that are being recorded or executed on
+  // device. If this reaches zero, we know that there are no users of the
+  // staging buffer so we can discard the contents and reset the offset to
+  // zero.
+  uint32_t pending_command_buffers;
+} iree_hal_metal_staging_buffer_t;
+
+// Initializes |out_staging_buffer| by allocating a shared storage mode buffer
+// of |buffer_capacity| bytes on |device|. Must be paired with
+// iree_hal_metal_staging_buffer_deinitialize.
+iree_status_t iree_hal_metal_staging_buffer_initialize(
+    id<MTLDevice> device, iree_host_size_t buffer_capacity,
+    iree_hal_metal_staging_buffer_t* out_staging_buffer);
+
+// Releases the Metal buffer owned by |staging_buffer|.
+void iree_hal_metal_staging_buffer_deinitialize(
+    iree_hal_metal_staging_buffer_t* staging_buffer);
+
+// Reserves |length| bytes from the staging buffer with the given |alignment|,
+// which must be a non-zero power of two. On success |out_reservation| receives
+// the host memory to write into and |out_offset| the byte offset of that memory
+// within |metal_buffer|, for binding it on the GPU.
+//
+// Returns IREE_STATUS_OUT_OF_RANGE if the request can never fit in the buffer
+// and IREE_STATUS_RESOURCE_EXHAUSTED if there is not enough space left.
+iree_status_t iree_hal_metal_staging_buffer_reserve(
+    iree_hal_metal_staging_buffer_t* staging_buffer, iree_host_size_t length,
+    iree_host_size_t alignment, iree_byte_span_t* out_reservation,
+    uint32_t* out_offset);
+
+// Copies the bytes of |source| into newly reserved space in the staging buffer
+// (see iree_hal_metal_staging_buffer_reserve for |alignment| and errors) and
+// returns the byte offset of the copy within |metal_buffer| in |out_offset|.
+iree_status_t iree_hal_metal_staging_buffer_append(
+    iree_hal_metal_staging_buffer_t* staging_buffer,
+    iree_const_byte_span_t source, iree_host_size_t alignment,
+    uint32_t* out_offset);
+
+// Resets the staging buffer to discard all its contents. Callers must ensure no
+// command buffer being recorded or executed still uses them.
+void iree_hal_metal_staging_buffer_reset(
+    iree_hal_metal_staging_buffer_t* staging_buffer);
+
+// Increases the number of command buffers using this staging buffer by one.
+void iree_hal_metal_staging_buffer_increase_refcount(
+    iree_hal_metal_staging_buffer_t* staging_buffer);
+
+// Decreases the number of command buffers using this staging buffer by one.
+// Once no users remain the buffer is reset, making its full capacity available
+// again.
+void iree_hal_metal_staging_buffer_decrease_refcount(
+    iree_hal_metal_staging_buffer_t* staging_buffer);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_EXPERIMENTAL_METAL_STAGING_BUFFER_H_
diff --git a/experimental/metal/staging_buffer.m b/experimental/metal/staging_buffer.m new file mode 100644 index 0000000..08791b4 --- /dev/null +++ b/experimental/metal/staging_buffer.m
@@ -0,0 +1,124 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "experimental/metal/staging_buffer.h"
+
+#include <stdint.h>
+
+#include "iree/base/api.h"
+#include "iree/base/tracing.h"
+
+iree_status_t iree_hal_metal_staging_buffer_initialize(
+    id<MTLDevice> device, iree_host_size_t buffer_capacity,
+    iree_hal_metal_staging_buffer_t* out_staging_buffer) {
+  IREE_ASSERT_ARGUMENT(device);
+  IREE_ASSERT_ARGUMENT(out_staging_buffer);
+  IREE_TRACE_ZONE_BEGIN(z0);
+  memset(out_staging_buffer, 0, sizeof(*out_staging_buffer));
+
+  // Capacity and offsets are stored as uint32_t; reject sizes that would be silently truncated.
+  if (buffer_capacity > UINT32_MAX) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "staging buffer capacity (%" PRIhsz " bytes) exceeds the maximum "
+                            "supported size (%" PRIu32 " bytes)",
+                            buffer_capacity, (uint32_t)UINT32_MAX);
+  }
+
+  // From Metal Best Practices Guide:
+  // "For small-sized data that changes frequently, choose the Shared mode. The overhead of copying
+  // data to video memory may be more expensive than the overhead of the GPU accessing system memory
+  // directly."
+  MTLResourceOptions options = MTLResourceStorageModeShared | MTLResourceCPUCacheModeWriteCombined;
+  id<MTLBuffer> metal_buffer = [device newBufferWithLength:buffer_capacity options:options];  // +1
+  if (!metal_buffer) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
+                            "failed to allocate staging buffer with size = %" PRIhsz " bytes",
+                            buffer_capacity);
+  }
+
+  out_staging_buffer->capacity = (uint32_t)buffer_capacity;
+  out_staging_buffer->metal_buffer = metal_buffer;
+  out_staging_buffer->host_buffer = metal_buffer.contents;
+  out_staging_buffer->offset = 0;
+  out_staging_buffer->pending_command_buffers = 0;
+
+  IREE_TRACE_ZONE_END(z0);
+  return iree_ok_status();
+}
+
+void iree_hal_metal_staging_buffer_deinitialize(iree_hal_metal_staging_buffer_t* staging_buffer) {
+  [staging_buffer->metal_buffer release];  // -1
+  // Clear the released handle and host pointer so stale uses are easier to catch.
+  memset(staging_buffer, 0, sizeof(*staging_buffer));
+}
+
+iree_status_t iree_hal_metal_staging_buffer_reserve(iree_hal_metal_staging_buffer_t* staging_buffer,
+                                                    iree_host_size_t length,
+                                                    iree_host_size_t alignment,
+                                                    iree_byte_span_t* out_reservation,
+                                                    uint32_t* out_offset) {
+  // iree_host_align() only works for non-zero power-of-two alignments.
+  IREE_ASSERT(alignment != 0 && (alignment & (alignment - 1)) == 0);
+  if (length > staging_buffer->capacity) {
+    // This will never fit in the staging buffer.
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "reservation (%" PRIhsz " bytes) exceeds the maximum capacity of "
+                            "the staging buffer (%" PRIu32 " bytes)",
+                            length, staging_buffer->capacity);
+  }
+
+  // Align the start of the reservation rather than its length: the write offset may have been
+  // left at a smaller alignment by an earlier reservation (e.g. a 4-byte aligned update source
+  // followed by an argument buffer that requires a larger alignment).
+  iree_host_size_t aligned_offset = iree_host_align(staging_buffer->offset, alignment);
+  if (aligned_offset > staging_buffer->capacity ||
+      length > staging_buffer->capacity - aligned_offset) {
+    return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
+                            "failed to reserve %" PRIhsz " bytes in staging buffer", length);
+  }
+
+  *out_reservation = iree_make_byte_span(staging_buffer->host_buffer + aligned_offset, length);
+  *out_offset = (uint32_t)aligned_offset;
+  staging_buffer->offset = (uint32_t)(aligned_offset + length);
+  return iree_ok_status();
+}
+
+iree_status_t iree_hal_metal_staging_buffer_append(iree_hal_metal_staging_buffer_t* staging_buffer,
+                                                   iree_const_byte_span_t source,
+                                                   iree_host_size_t alignment,
+                                                   uint32_t* out_offset) {
+  iree_byte_span_t reservation;
+  IREE_RETURN_IF_ERROR(iree_hal_metal_staging_buffer_reserve(staging_buffer, source.data_length,
+                                                             alignment, &reservation, out_offset));
+  // memcpy() requires valid pointers even for zero-length copies.
+  if (source.data_length > 0) {
+    memcpy(reservation.data, source.data, source.data_length);
+  }
+  return iree_ok_status();
+}
+
+void iree_hal_metal_staging_buffer_reset(iree_hal_metal_staging_buffer_t* staging_buffer) {
+  staging_buffer->offset = 0;
+}
+
+void iree_hal_metal_staging_buffer_increase_refcount(
+    iree_hal_metal_staging_buffer_t* staging_buffer) {
+  ++staging_buffer->pending_command_buffers;
+}
+
+void iree_hal_metal_staging_buffer_decrease_refcount(
+    iree_hal_metal_staging_buffer_t* staging_buffer) {
+  IREE_ASSERT(staging_buffer->pending_command_buffers > 0);
+  // Ignore unbalanced calls when asserts are compiled out; wrapping the counter
+  // around would keep the buffer from ever being reclaimed.
+  if (staging_buffer->pending_command_buffers == 0) return;
+  if (--staging_buffer->pending_command_buffers == 0) {
+    // No command buffer being recorded or executed uses the contents anymore; reclaim all space.
+    iree_hal_metal_staging_buffer_reset(staging_buffer);
+  }
+}