// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "experimental/metal/direct_command_buffer.h"
#import <Metal/Metal.h>
#include "experimental/metal/builtin_executables.h"
#include "experimental/metal/metal_buffer.h"
#include "experimental/metal/metal_device.h"
#include "experimental/metal/metal_kernel_library.h"
#include "experimental/metal/pipeline_layout.h"
#include "experimental/metal/staging_buffer.h"
#include "iree/base/api.h"
#include "iree/base/target_platform.h"
#include "iree/base/tracing.h"
#include "iree/hal/api.h"
#include "iree/hal/utils/resource_set.h"
//===------------------------------------------------------------------------------------------===//
// Segmented submission management
//===------------------------------------------------------------------------------------------===//
// Unlike Vulkan, Metal adopts a multi-level command recording model--memory/dispatch commands are
// not directly recorded into a command buffer; rather, they must go through the additional level of
// blit/compute encoders. IREE's HAL follows the flat Vulkan command buffer recording model, so we
// have a mismatch here. Implementing IREE's HAL using Metal would require switching encoders for
// interleaved memory and dispatch commands. Additionally, certain IREE HAL API features do not
// have a direct mapping in Metal APIs, e.g., various forms of IREE HAL execution/memory barriers.
// Translating them would require looking at both previous and next commands to decide the proper
// mapping.
//
// For these reasons, it's beneficial to have a complete view of the full command buffer and extra
// flexibility during recording, in order to fix up past commands or inspect future commands.
//
// Therefore, to implement IREE HAL command buffers using Metal, we perform two steps using a linked
// list of command segments. First we create segments (iree_hal_metal_command_buffer_prepare_* and
// iree_hal_metal_command_segment_create_*) to keep track of all IREE HAL commands and the
// associated data, and then, when finalizing the command buffer, we iterate through all the
// segments and record their contents (iree_hal_metal_command_segment_record_*) into a proper Metal
// command buffer. A linked list gives us the flexibility to organize the command sequence with
// low overhead, and deferred recording gives us the complete picture of the command buffer once
// recording actually starts.
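//
// Each segment is allocated from the command buffer arena as a single block: the
// iree_hal_metal_command_segment_t struct itself, immediately followed by an inline allocation
// holding that segment's variable-length data (buffer barriers, descriptors and push constants,
// or the fill pattern). For example, a barrier segment is laid out roughly as follows:
//
//   +----------------------------------+-------------------------------------------+
//   | iree_hal_metal_command_segment_t | iree_hal_buffer_barrier_t x barrier_count  |
//   +----------------------------------+-------------------------------------------+
//                                        ^ barrier.buffer_barriers points here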
// Command action kind of a command segment.
typedef enum iree_hal_metal_command_segment_action_e {
IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_BARRIER, // Execution/memory barrier command
IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_DISPATCH, // Dispatch command
IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_FILL_BUFFER, // Fill buffer command
IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_COPY_BUFFER, // Copy buffer command
} iree_hal_metal_command_segment_action_t;
// API data for execution/memory barrier command segments.
typedef struct iree_hal_metal_barrier_segment_t {
iree_host_size_t memory_barrier_count; // Total number of memory barriers
iree_host_size_t buffer_barrier_count; // Total number of buffer barriers
// The list of buffer barriers, pointing to the end of the segment allocation.
const iree_hal_buffer_barrier_t* buffer_barriers;
} iree_hal_metal_barrier_segment_t;
// + Additional inline allocation for holding all buffer barriers.
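// A cached descriptor binding: the (set, binding) coordinates, the bound IREE HAL buffer and its
// offset, plus the Metal resource usage derived from the descriptor set layout.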
typedef struct iree_hal_metal_descriptor_t {
uint32_t set;
uint32_t binding;
iree_hal_buffer_t* buffer;
iree_device_size_t offset;
MTLResourceUsage usage;
} iree_hal_metal_descriptor_t;
// API data for dispatch command segments.
typedef struct iree_hal_metal_dispatch_segment_t {
// Compute kernel information--kernel object, pipeline layout, threadgroup size, etc.
iree_hal_metal_kernel_params_t kernel_params;
// Workgroup count information--if |workgroups_buffer| is not nil, this is an indirect dispatch;
// otherwise |workgroup_count| is used for a direct dispatch.
id<MTLBuffer> workgroups_buffer;
iree_device_size_t workgroups_offset;
uint32_t workgroup_count[3];
// The number of descriptors bound for this dispatch.
iree_host_size_t descriptor_count;
// The list of bound descriptors, pointing to the end of the segment allocation.
iree_hal_metal_descriptor_t* descriptors;
// The number of push constant values.
iree_host_size_t push_constant_count;
// The list of push constants, pointing to the end of the segment allocation.
int32_t* push_constants;
} iree_hal_metal_dispatch_segment_t;
// + Additional inline allocation for holding all bound descriptors.
// + Additional inline allocation for holding all push constants.
// API data for fill buffer command segments.
typedef struct iree_hal_metal_fill_buffer_segment_t {
id<MTLBuffer> target_buffer;
iree_device_size_t target_offset;
iree_device_size_t length;
// The fill pattern, pointing to the end of the segment allocation.
const void* pattern;
iree_host_size_t pattern_length;
} iree_hal_metal_fill_buffer_segment_t;
// + Additional inline allocation for holding the fill pattern.
// API data for copy buffer command segments.
typedef struct iree_hal_metal_copy_buffer_segment_t {
id<MTLBuffer> source_buffer;
iree_device_size_t source_offset;
id<MTLBuffer> target_buffer;
iree_device_size_t target_offset;
iree_device_size_t length;
} iree_hal_metal_copy_buffer_segment_t;
struct iree_hal_metal_command_segment_t;
typedef struct iree_hal_metal_command_segment_t {
struct iree_hal_metal_command_segment_t* next_segment;
iree_hal_metal_command_segment_action_t action;
union {
iree_hal_metal_barrier_segment_t barrier;
iree_hal_metal_dispatch_segment_t dispatch;
iree_hal_metal_fill_buffer_segment_t fill_buffer;
iree_hal_metal_copy_buffer_segment_t copy_buffer;
};
} iree_hal_metal_command_segment_t;
typedef struct iree_hal_metal_command_segment_list_t {
iree_hal_metal_command_segment_t* head;
iree_hal_metal_command_segment_t* tail;
} iree_hal_metal_command_segment_list_t;
static void iree_hal_metal_command_segment_list_reset(iree_hal_metal_command_segment_list_t* list) {
memset(list, 0, sizeof(*list));
}
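// Pushes |segment| to the front of |list|.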
static void iree_hal_metal_command_segment_list_push_front(
iree_hal_metal_command_segment_list_t* list, iree_hal_metal_command_segment_t* segment) {
segment->next_segment = list->head;
list->head = segment;
if (!list->tail) list->tail = segment;
}
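// Appends |segment| to the end of |list|.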
static void iree_hal_metal_command_segment_list_push_back(
iree_hal_metal_command_segment_list_t* list, iree_hal_metal_command_segment_t* segment) {
segment->next_segment = NULL;
if (list->tail) {
list->tail->next_segment = segment;
list->tail = segment;
} else {
list->head = list->tail = segment;
}
}
//===------------------------------------------------------------------------------------------===//
// iree_hal_metal_command_buffer_t
//===------------------------------------------------------------------------------------------===//
typedef struct iree_hal_metal_command_buffer_t {
iree_hal_command_buffer_t base;
// The Metal command queue owning this command buffer.
id<MTLCommandQueue> queue;
// For polyfilling fill/copy/update buffer operations not directly supported by Metal APIs.
iree_hal_metal_builtin_executable_t* builtin_executable;
// Arena used for all allocations; references the shared device block pool.
iree_arena_allocator_t arena;
// Per-queue shared uniform staging buffer for uploading parameters to the GPU, including argument
// buffers and buffer update source buffers.
iree_hal_metal_staging_buffer_t* staging_buffer;
iree_allocator_t host_allocator;
// Maintains a reference to all resources used within the command buffer. Resets on each begin.
iree_hal_resource_set_t* resource_set;
// Linked list of command segments to be recorded into a command buffer.
iree_hal_metal_command_segment_list_t segments;
id<MTLCommandBuffer> command_buffer;
MTLDispatchType dispatch_type;
struct {
// The currently active compute/blit encoders for encoding compute or memory operations.
// Metal commands are encoded into the command buffer through these encoders, and each encoder
// can only encode the specific type of operations it supports.
id<MTLComputeCommandEncoder> compute_encoder;
id<MTLBlitCommandEncoder> blit_encoder;
// MTLEvent used for synchronization when we switch between blit and compute encoders.
// Normally we would use MTLFence objects, but the mismatch between the IREE HAL and the Metal
// API means we may see many encoder switches, which would require creating a lot of GPU objects.
// To avoid that cost, we just use one MTLEvent with different values for different switches.
id<MTLEvent> encoder_event;
// The next available encoder event value to signal/wait to/on.
uint64_t next_encoder_event_value;
// Metal APIs mandate that we create argument buffers (for descriptor sets) from the compiled
// kernel function. That means we would need to bind the compute kernel before setting
// descriptors and binding buffers. However, in the IREE HAL API we see push descriptors before
// the dispatch command, so we need to cache the descriptor information ourselves and record it
// at dispatch time.
struct {
iree_hal_metal_descriptor_t bindings[IREE_HAL_METAL_MAX_DESCRIPTOR_SET_BINDING_COUNT];
} descriptor_sets[IREE_HAL_METAL_PUSH_CONSTANT_BUFFER_INDEX];
// All available push constants, updated each time push_constants is called. Reset only with the
// command buffer; otherwise the values are maintained during recording to allow for partial
// push_constants updates.
int32_t push_constants[IREE_HAL_METAL_MAX_PUSH_CONSTANT_COUNT];
} state;
} iree_hal_metal_command_buffer_t;
//===------------------------------------------------------------------------------------------===//
// iree_hal_metal_command_buffer_vtable APIs
//===------------------------------------------------------------------------------------------===//
static const iree_hal_command_buffer_vtable_t iree_hal_metal_command_buffer_vtable;
static iree_hal_metal_command_buffer_t* iree_hal_metal_command_buffer_cast(
iree_hal_command_buffer_t* base_value) {
IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_metal_command_buffer_vtable);
return (iree_hal_metal_command_buffer_t*)base_value;
}
static const iree_hal_metal_command_buffer_t* iree_hal_metal_command_buffer_const_cast(
const iree_hal_command_buffer_t* base_value) {
IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_metal_command_buffer_vtable);
return (const iree_hal_metal_command_buffer_t*)base_value;
}
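// Returns the underlying MTLCommandBuffer handle backing the given |base_command_buffer|.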
id<MTLCommandBuffer> iree_hal_metal_direct_command_buffer_handle(
const iree_hal_command_buffer_t* base_command_buffer) {
const iree_hal_metal_command_buffer_t* command_buffer =
iree_hal_metal_command_buffer_const_cast(base_command_buffer);
return command_buffer->command_buffer;
}
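// Ends and releases the currently active compute encoder, if any.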
static void iree_hal_metal_end_compute_encoder(iree_hal_metal_command_buffer_t* command_buffer) {
if (command_buffer->state.compute_encoder) {
[command_buffer->state.compute_encoder endEncoding];
[command_buffer->state.compute_encoder release]; // -1
command_buffer->state.compute_encoder = nil;
}
}
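// Ends and releases the currently active blit encoder, if any.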
static void iree_hal_metal_end_blit_encoder(iree_hal_metal_command_buffer_t* command_buffer) {
if (command_buffer->state.blit_encoder) {
[command_buffer->state.blit_encoder endEncoding];
[command_buffer->state.blit_encoder release]; // -1
command_buffer->state.blit_encoder = nil;
}
}
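// Resets the command buffer to a clean state: ends any active encoders, clears the recorded
// command segments, and releases arena allocations back to the shared block pool.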
static void iree_hal_metal_command_buffer_reset(iree_hal_metal_command_buffer_t* command_buffer) {
IREE_TRACE_ZONE_BEGIN(z0);
iree_hal_metal_end_blit_encoder(command_buffer);
iree_hal_metal_end_compute_encoder(command_buffer);
iree_hal_metal_command_segment_list_reset(&command_buffer->segments);
iree_arena_reset(&command_buffer->arena);
IREE_TRACE_ZONE_END(z0);
}
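// Returns the active compute encoder, creating one if needed. If a blit encoder is active it is
// ended first and an encoder event signal/wait pair is encoded to order the two passes.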
static id<MTLComputeCommandEncoder> iree_hal_metal_get_or_begin_compute_encoder(
iree_hal_metal_command_buffer_t* command_buffer) {
id<MTLCommandBuffer> metal_handle = command_buffer->command_buffer;
// If we are switching encoders, we need to synchronize "one or more resources across different
// passes within a command buffer"; we do so with the shared |encoder_event| MTLEvent.
// https://developer.apple.com/documentation/metal/resource_synchronization
uint64_t encoder_event_value = 0;
if (command_buffer->state.blit_encoder) {
iree_hal_metal_end_blit_encoder(command_buffer);
encoder_event_value = command_buffer->state.next_encoder_event_value++;
[metal_handle encodeSignalEvent:command_buffer->state.encoder_event value:encoder_event_value];
}
if (!command_buffer->state.compute_encoder) {
if (encoder_event_value != 0) {
[metal_handle encodeWaitForEvent:command_buffer->state.encoder_event
value:encoder_event_value];
}
@autoreleasepool { // Use @autoreleasepool to trigger the autorelease within encoder creation.
// We manage command dependencies and insert barriers explicitly in IREE, so the concurrent
// dispatch type can be used safely for compute encoders (the actual type follows the device
// params).
command_buffer->state.compute_encoder = [[metal_handle
computeCommandEncoderWithDispatchType:command_buffer->dispatch_type] retain]; // +1
}
}
return command_buffer->state.compute_encoder;
}
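// Returns the active blit encoder, creating one if needed. If a compute encoder is active it is
// ended first and an encoder event signal/wait pair is encoded to order the two passes.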
static id<MTLBlitCommandEncoder> iree_hal_metal_get_or_begin_blit_encoder(
iree_hal_metal_command_buffer_t* command_buffer) {
id<MTLCommandBuffer> metal_handle = command_buffer->command_buffer;
// If we are switching encoders, we need to synchronize "one or more resources across different
// passes within a command buffer"; we do so with the shared |encoder_event| MTLEvent.
// https://developer.apple.com/documentation/metal/resource_synchronization
uint64_t encoder_event_value = 0;
if (command_buffer->state.compute_encoder) {
iree_hal_metal_end_compute_encoder(command_buffer);
encoder_event_value = command_buffer->state.next_encoder_event_value++;
[metal_handle encodeSignalEvent:command_buffer->state.encoder_event value:encoder_event_value];
}
if (!command_buffer->state.blit_encoder) {
if (encoder_event_value != 0) {
[metal_handle encodeWaitForEvent:command_buffer->state.encoder_event
value:encoder_event_value];
}
@autoreleasepool { // Use @autoreleasepool to trigger the autorelease within encoder creation.
command_buffer->state.blit_encoder = [[metal_handle blitCommandEncoder] retain]; // +1
}
}
return command_buffer->state.blit_encoder;
}
// Destroys the given |base_command_buffer| itself, without decreasing refcount in the shared
// staging buffer yet.
static void iree_hal_metal_command_buffer_destroy_internal(
iree_hal_command_buffer_t* base_command_buffer);
iree_status_t iree_hal_metal_direct_command_buffer_create(
iree_hal_device_t* device, iree_hal_command_buffer_mode_t mode,
iree_hal_command_category_t command_categories, iree_host_size_t binding_capacity,
iree_hal_metal_command_buffer_resource_reference_mode_t resource_reference_mode,
id<MTLCommandQueue> queue, iree_arena_block_pool_t* block_pool,
iree_hal_metal_staging_buffer_t* staging_buffer,
iree_hal_metal_builtin_executable_t* builtin_executable, iree_allocator_t host_allocator,
iree_hal_command_buffer_t** out_command_buffer) {
IREE_ASSERT_ARGUMENT(device);
IREE_ASSERT_ARGUMENT(out_command_buffer);
IREE_ASSERT_TRUE(iree_all_bits_set(mode, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT));
IREE_ASSERT_TRUE(!iree_any_bit_set(mode, IREE_HAL_COMMAND_BUFFER_MODE_NESTED));
*out_command_buffer = NULL;
if (binding_capacity > 0) {
// TODO(#10144): support indirect command buffers with binding tables.
return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "indirect command buffer not yet supported");
}
IREE_TRACE_ZONE_BEGIN(z0);
iree_hal_metal_command_buffer_t* command_buffer = NULL;
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_allocator_malloc(host_allocator, sizeof(*command_buffer), (void**)&command_buffer));
iree_hal_command_buffer_initialize(device, mode, command_categories, IREE_HAL_QUEUE_AFFINITY_ANY,
binding_capacity, &iree_hal_metal_command_buffer_vtable,
&command_buffer->base);
command_buffer->queue = [queue retain]; // +1
command_buffer->builtin_executable = builtin_executable;
iree_arena_initialize(block_pool, &command_buffer->arena);
command_buffer->staging_buffer = staging_buffer;
command_buffer->host_allocator = host_allocator;
iree_status_t status = iree_hal_resource_set_allocate(block_pool, &command_buffer->resource_set);
if (iree_status_is_ok(status)) {
iree_hal_metal_command_segment_list_reset(&command_buffer->segments);
@autoreleasepool { // Use @autoreleasepool to trigger the autorelease within encoder creation.
// We track resource lifetimes ourselves in IREE, so unretained references to resources in the
// Metal command buffer suffice unless retained mode is requested; this avoids overhead and gives
// better performance.
MTLCommandBufferDescriptor* descriptor = [MTLCommandBufferDescriptor new]; // +1
descriptor.retainedReferences =
resource_reference_mode == IREE_HAL_METAL_COMMAND_BUFFER_RESOURCE_REFERENCE_MODE_RETAINED;
descriptor.errorOptions = MTLCommandBufferErrorOptionNone;
command_buffer->command_buffer =
[[queue commandBufferWithDescriptor:descriptor] retain]; // +1
[descriptor release]; // -1
}
const iree_hal_metal_device_params_t* params = iree_hal_metal_device_params(device);
command_buffer->dispatch_type =
params->command_dispatch_type == IREE_HAL_METAL_COMMAND_DISPATCH_TYPE_CONCURRENT
? MTLDispatchTypeConcurrent
: MTLDispatchTypeSerial;
command_buffer->state.compute_encoder = nil;
command_buffer->state.blit_encoder = nil;
command_buffer->state.encoder_event = [queue.device newEvent]; // +1
command_buffer->state.next_encoder_event_value = 1;
}
if (iree_status_is_ok(status)) {
*out_command_buffer = &command_buffer->base;
// Increase the command buffer refcount in the shared staging buffer. We tie this to the command
// buffer's lifetime to avoid resource leaks.
iree_hal_metal_staging_buffer_increase_refcount(staging_buffer);
} else {
iree_hal_metal_command_buffer_destroy_internal(&command_buffer->base);
}
IREE_TRACE_ZONE_END(z0);
return status;
}
static void iree_hal_metal_command_buffer_destroy_internal(
iree_hal_command_buffer_t* base_command_buffer) {
iree_hal_metal_command_buffer_t* command_buffer =
iree_hal_metal_command_buffer_cast(base_command_buffer);
iree_hal_metal_command_buffer_reset(command_buffer);
[command_buffer->state.encoder_event release]; // -1
IREE_ASSERT_EQ(command_buffer->state.compute_encoder, nil);
IREE_ASSERT_EQ(command_buffer->state.blit_encoder, nil);
[command_buffer->command_buffer release]; // -1
[command_buffer->queue release]; // -1
iree_hal_resource_set_free(command_buffer->resource_set);
iree_arena_deinitialize(&command_buffer->arena);
iree_allocator_free(command_buffer->host_allocator, command_buffer);
}
static void iree_hal_metal_command_buffer_destroy(iree_hal_command_buffer_t* base_command_buffer) {
iree_hal_metal_command_buffer_t* command_buffer =
iree_hal_metal_command_buffer_cast(base_command_buffer);
IREE_TRACE_ZONE_BEGIN(z0);
// Decrease the command buffer refcount in the shared staging buffer, potentially reclaiming
// resources. We tie this to the command buffer's lifetime to avoid resource leaks.
if (command_buffer->staging_buffer) {
iree_hal_metal_staging_buffer_decrease_refcount(command_buffer->staging_buffer);
}
iree_hal_metal_command_buffer_destroy_internal(base_command_buffer);
IREE_TRACE_ZONE_END(z0);
}
bool iree_hal_metal_command_buffer_isa(iree_hal_command_buffer_t* command_buffer) {
return iree_hal_resource_is(&command_buffer->resource, &iree_hal_metal_command_buffer_vtable);
}
static void iree_hal_metal_command_buffer_begin_debug_group(
iree_hal_command_buffer_t* base_command_buffer, iree_string_view_t label,
iree_hal_label_color_t label_color, const iree_hal_label_location_t* location) {
// TODO(antiagainst): implement support for debug group
}
static void iree_hal_metal_command_buffer_end_debug_group(
iree_hal_command_buffer_t* base_command_buffer) {
// TODO(antiagainst): implement support for debug group
}
static iree_status_t iree_hal_metal_command_buffer_prepare_barrier(
iree_hal_command_buffer_t* base_command_buffer, iree_hal_execution_stage_t source_stage_mask,
iree_hal_execution_stage_t target_stage_mask, iree_hal_execution_barrier_flags_t flags,
iree_host_size_t memory_barrier_count, const iree_hal_memory_barrier_t* memory_barriers,
iree_host_size_t buffer_barrier_count, const iree_hal_buffer_barrier_t* buffer_barriers) {
if (iree_any_bit_set(source_stage_mask, IREE_HAL_EXECUTION_STAGE_HOST) ||
iree_any_bit_set(target_stage_mask, IREE_HAL_EXECUTION_STAGE_HOST)) {
return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "barrier involving host not yet supported");
}
if (flags != IREE_HAL_EXECUTION_BARRIER_FLAG_NONE) {
return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "non-zero barrier flag not yet supported");
}
iree_hal_metal_command_buffer_t* command_buffer =
iree_hal_metal_command_buffer_cast(base_command_buffer);
IREE_TRACE_ZONE_BEGIN(z0);
// Allocate the command segment and keep track of all necessary API data.
uint8_t* storage_base = NULL;
iree_hal_metal_command_segment_t* segment = NULL;
iree_host_size_t buffer_barrier_length = buffer_barrier_count * sizeof(iree_hal_buffer_barrier_t);
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_arena_allocate(&command_buffer->arena, sizeof(*segment) + buffer_barrier_length,
(void**)&storage_base));
// Copy the buffer barriers to the end of the current segment for later access. We don't copy
// memory barriers because Metal only has a coarse-grained full memory barrier affecting all
// buffers, regardless of the fine-grained details from IREE HAL barriers.
uint8_t* barrier_ptr = storage_base + sizeof(*segment);
memcpy(barrier_ptr, (const uint8_t*)buffer_barriers, buffer_barrier_length);
// Compose and push the barrier segment.
segment = (iree_hal_metal_command_segment_t*)storage_base;
segment->action = IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_BARRIER;
iree_hal_metal_command_segment_list_push_back(&command_buffer->segments, segment);
segment->barrier.memory_barrier_count = memory_barrier_count;
segment->barrier.buffer_barrier_count = buffer_barrier_count;
segment->barrier.buffer_barriers = (const iree_hal_buffer_barrier_t*)barrier_ptr;
IREE_TRACE_ZONE_END(z0);
return iree_ok_status();
}
static iree_status_t iree_hal_metal_command_segment_record_barrier(
iree_hal_metal_command_buffer_t* command_buffer, iree_hal_metal_barrier_segment_t* segment) {
// TODO(antiagainst): Analyze segments before and after to optimize barriers, e.g., switching
// encoders already requires its own synchronization, so we don't need extra barriers in the
// middle.
if (segment->memory_barrier_count == 0 && segment->buffer_barrier_count == 0) {
// There is no directly corresponding API for an execution-only barrier in Metal. We just signal
// and wait on the same value of an MTLEvent here.
iree_hal_metal_end_blit_encoder(command_buffer);
iree_hal_metal_end_compute_encoder(command_buffer);
id<MTLCommandBuffer> metal_handle = command_buffer->command_buffer;
uint64_t event_value = command_buffer->state.next_encoder_event_value++;
[metal_handle encodeSignalEvent:command_buffer->state.encoder_event value:event_value];
[metal_handle encodeWaitForEvent:command_buffer->state.encoder_event value:event_value];
return iree_ok_status();
}
id<MTLComputeCommandEncoder> encoder =
iree_hal_metal_get_or_begin_compute_encoder(command_buffer);
if (segment->memory_barrier_count != 0) {
// If a memory barrier is specified, we have to place a catch-all barrier for all buffers;
// Metal does not provide more fine-grained control here.
[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
return iree_ok_status();
}
if (segment->buffer_barrier_count != 0) {
// But we do have the option to specify a list of buffers to synchronize if only buffer barriers
// are specified.
id<MTLResource>* resources =
(id<MTLResource>*)iree_alloca(sizeof(id<MTLResource>) * segment->buffer_barrier_count);
for (iree_host_size_t i = 0; i < segment->buffer_barrier_count; ++i) {
resources[i] = iree_hal_metal_buffer_handle(
iree_hal_buffer_allocated_buffer(segment->buffer_barriers[i].buffer));
}
[encoder memoryBarrierWithResources:resources count:segment->buffer_barrier_count];
}
return iree_ok_status();
}
static iree_status_t iree_hal_metal_command_buffer_signal_event(
iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event,
iree_hal_execution_stage_t source_stage_mask) {
return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "event not yet supported");
}
static iree_status_t iree_hal_metal_command_buffer_reset_event(
iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event,
iree_hal_execution_stage_t source_stage_mask) {
return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "event not yet supported");
}
static iree_status_t iree_hal_metal_command_buffer_wait_events(
iree_hal_command_buffer_t* base_command_buffer, iree_host_size_t event_count,
const iree_hal_event_t** events, iree_hal_execution_stage_t source_stage_mask,
iree_hal_execution_stage_t target_stage_mask, iree_host_size_t memory_barrier_count,
const iree_hal_memory_barrier_t* memory_barriers, iree_host_size_t buffer_barrier_count,
const iree_hal_buffer_barrier_t* buffer_barriers) {
return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "event not yet supported");
}
static iree_status_t iree_hal_metal_command_buffer_discard_buffer(
iree_hal_command_buffer_t* base_command_buffer, iree_hal_buffer_t* buffer) {
// This is a hint to the device and we have nothing to do for Metal.
return iree_ok_status();
}
// Returns true and fills |value| with the repeated single-byte value if the given |pattern|
// repeats the same byte across all of its |pattern_length| bytes; returns false otherwise.
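// For example, a 2-byte pattern 0xABAB reduces to the single byte 0xAB, while 0xAB01 does not
// and the function returns false.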
static bool iree_hal_metal_get_duplicated_single_byte_value(const void* pattern,
size_t pattern_length, uint8_t* value) {
switch (pattern_length) {
case 1: {
*value = *(uint8_t*)pattern;
return true;
}
case 2: {
uint16_t two_bytes = *(uint16_t*)pattern;
uint16_t byte0 = two_bytes & 0xffu;
uint16_t byte1 = two_bytes >> 8u;
if (byte0 == byte1) {
*value = (uint8_t)byte0;
return true;
}
break;
}
case 4: {
uint32_t four_bytes = *(uint32_t*)pattern;
uint32_t byte0 = four_bytes & 0xffu;
uint32_t byte1 = (four_bytes >> 8u) & 0xffu;
uint32_t byte2 = (four_bytes >> 16u) & 0xffu;
uint32_t byte3 = four_bytes >> 24u;
if (byte0 == byte1 && byte0 == byte2 && byte0 == byte3) {
*value = (uint8_t)byte0;
return true;
}
break;
}
default:
break;
}
return false;
}
// Duplicates the given |pattern| into a 4-byte value and returns it.
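// For example, a 1-byte pattern 0xAB becomes 0xABABABAB and a 2-byte pattern 0x1234 becomes
// 0x12341234.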
static uint32_t iree_hal_metal_duplicate_to_four_byte_value(const void* pattern,
size_t pattern_length) {
if (pattern_length == 1) {
uint8_t single_byte = *(uint8_t*)pattern;
uint32_t value = (uint32_t)single_byte;
value |= (value << 8u);
value |= (value << 16u);
return value;
}
if (pattern_length == 2) {
uint16_t two_bytes = *(uint16_t*)pattern;
uint32_t value = (uint32_t)two_bytes;
value |= (value << 16u);
return value;
}
IREE_ASSERT(pattern_length == 4);
return *(uint32_t*)pattern;
}
static iree_status_t iree_hal_metal_command_buffer_prepare_fill_buffer(
iree_hal_command_buffer_t* base_command_buffer, iree_hal_buffer_t* target_buffer,
iree_device_size_t target_offset, iree_device_size_t length, const void* pattern,
iree_host_size_t pattern_length) {
iree_hal_metal_command_buffer_t* command_buffer =
iree_hal_metal_command_buffer_cast(base_command_buffer);
IREE_TRACE_ZONE_BEGIN(z0);
id<MTLBuffer> target_device_buffer =
iree_hal_metal_buffer_handle(iree_hal_buffer_allocated_buffer(target_buffer));
target_offset += iree_hal_buffer_byte_offset(target_buffer);
// Allocate the command segment and keep track of all necessary API data.
uint8_t* storage_base = NULL;
iree_hal_metal_command_segment_t* segment = NULL;
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_arena_allocate(&command_buffer->arena, sizeof(*segment) + pattern_length,
(void**)&storage_base));
// Copy the pattern to the end of the segment for later access.
uint8_t* pattern_ptr = storage_base + sizeof(*segment);
memcpy(pattern_ptr, (const uint8_t*)pattern, pattern_length);
// Compose and push the fill buffer segment.
segment = (iree_hal_metal_command_segment_t*)storage_base;
segment->action = IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_FILL_BUFFER;
iree_hal_metal_command_segment_list_push_back(&command_buffer->segments, segment);
segment->fill_buffer.target_buffer = target_device_buffer;
segment->fill_buffer.target_offset = target_offset;
segment->fill_buffer.length = length;
segment->fill_buffer.pattern = (const void*)pattern_ptr;
segment->fill_buffer.pattern_length = pattern_length;
iree_status_t status =
iree_hal_resource_set_insert(command_buffer->resource_set, 1, &target_buffer);
IREE_TRACE_ZONE_END(z0);
return status;
}
static iree_status_t iree_hal_metal_command_segment_record_fill_buffer(
iree_hal_metal_command_buffer_t* command_buffer,
iree_hal_metal_fill_buffer_segment_t* segment) {
IREE_TRACE_ZONE_BEGIN(z0);
// Note that fillBuffer:range:value: only accepts a single byte as the pattern, while the IREE
// HAL fill_buffer operation can accept 1/2/4-byte patterns. If the pattern is a repetition of a
// single byte, we can call into fillBuffer:range:value:; otherwise we need to emulate the
// support.
uint8_t pattern_1byte = 0u;
// Per the spec for fillBuffer:range:value: "The alignment and length of the range must both be a
// multiple of 4 bytes in macOS, and 1 byte in iOS and tvOS."
#if defined(IREE_PLATFORM_MACOS)
const bool can_use_metal_api = segment->target_offset % 4 == 0 && segment->length % 4 == 0 &&
iree_hal_metal_get_duplicated_single_byte_value(
segment->pattern, segment->pattern_length, &pattern_1byte);
#else
const bool can_use_metal_api = iree_hal_metal_get_duplicated_single_byte_value(
segment->pattern, segment->pattern_length, &pattern_1byte);
#endif
if (can_use_metal_api) {
id<MTLBlitCommandEncoder> encoder = iree_hal_metal_get_or_begin_blit_encoder(command_buffer);
[encoder fillBuffer:segment->target_buffer
range:NSMakeRange(segment->target_offset, segment->length)
value:pattern_1byte];
IREE_TRACE_ZONE_END(z0);
return iree_ok_status();
}
id<MTLComputeCommandEncoder> compute_encoder =
iree_hal_metal_get_or_begin_compute_encoder(command_buffer);
uint32_t pattern_4byte =
iree_hal_metal_duplicate_to_four_byte_value(segment->pattern, segment->pattern_length);
iree_status_t status = iree_hal_metal_builtin_executable_fill_buffer(
command_buffer->builtin_executable, compute_encoder, segment->target_buffer,
segment->target_offset, segment->length, pattern_4byte);
IREE_TRACE_ZONE_END(z0);
return status;
}
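// Creates a copy buffer command segment targeting the given Metal buffers/ranges and appends it
// to the command buffer's segment list.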
static iree_status_t iree_hal_metal_command_segment_create_copy_buffer(
iree_hal_metal_command_buffer_t* command_buffer, id<MTLBuffer> source_device_buffer,
iree_device_size_t source_offset, id<MTLBuffer> target_device_buffer,
iree_device_size_t target_offset, iree_device_size_t length) {
IREE_TRACE_ZONE_BEGIN(z0);
// Allocate the command segment and keep track of all necessary API data.
uint8_t* storage_base = NULL;
iree_hal_metal_command_segment_t* segment = NULL;
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_arena_allocate(&command_buffer->arena, sizeof(*segment), (void**)&storage_base));
// Compose and push the copy buffer segment.
segment = (iree_hal_metal_command_segment_t*)storage_base;
segment->action = IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_COPY_BUFFER;
iree_hal_metal_command_segment_list_push_back(&command_buffer->segments, segment);
segment->copy_buffer.source_buffer = source_device_buffer;
segment->copy_buffer.source_offset = source_offset;
segment->copy_buffer.target_buffer = target_device_buffer;
segment->copy_buffer.target_offset = target_offset;
segment->copy_buffer.length = length;
IREE_TRACE_ZONE_END(z0);
return iree_ok_status();
}
static iree_status_t iree_hal_metal_command_segment_record_copy_buffer(
iree_hal_metal_command_buffer_t* command_buffer,
iree_hal_metal_copy_buffer_segment_t* segment) {
IREE_TRACE_ZONE_BEGIN(z0);
// Per the spec for copyFromBuffer:sourceOffset:toBuffer:destinationOffset:size:, the
// source/target offsets and length must each be a multiple of 4 bytes in macOS, and 1 byte in
// iOS and tvOS.
#if defined(IREE_PLATFORM_MACOS)
bool can_use_metal_api = segment->source_offset % 4 == 0 && segment->target_offset % 4 == 0 &&
segment->length % 4 == 0;
#else
bool can_use_metal_api = true;
#endif
iree_status_t status = iree_ok_status();
if (can_use_metal_api) {
id<MTLBlitCommandEncoder> encoder = iree_hal_metal_get_or_begin_blit_encoder(command_buffer);
[encoder copyFromBuffer:segment->source_buffer
sourceOffset:segment->source_offset
toBuffer:segment->target_buffer
destinationOffset:segment->target_offset
size:segment->length];
} else {
id<MTLComputeCommandEncoder> encoder =
iree_hal_metal_get_or_begin_compute_encoder(command_buffer);
status = iree_hal_metal_builtin_executable_copy_buffer(
command_buffer->builtin_executable, encoder, segment->source_buffer, segment->source_offset,
segment->target_buffer, segment->target_offset, segment->length);
}
IREE_TRACE_ZONE_END(z0);
return status;
}
static iree_status_t iree_hal_metal_command_buffer_prepare_update_buffer(
iree_hal_command_buffer_t* base_command_buffer, const void* source_buffer,
iree_host_size_t source_offset, iree_hal_buffer_t* target_buffer,
iree_device_size_t target_offset, iree_device_size_t length) {
iree_hal_metal_command_buffer_t* command_buffer =
iree_hal_metal_command_buffer_cast(base_command_buffer);
IREE_TRACE_ZONE_BEGIN(z0);
// There is no directly corresponding API in Metal. We copy the source data into the staging
// buffer and then copy from there into the target buffer.
iree_const_byte_span_t source_data_span =
iree_make_const_byte_span((uint8_t*)source_buffer + source_offset, length);
uint32_t offset = 0;
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_metal_staging_buffer_append(command_buffer->staging_buffer, source_data_span,
/*alignment=*/4, &offset));
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1, &target_buffer));
id<MTLBuffer> target_device_buffer =
iree_hal_metal_buffer_handle(iree_hal_buffer_allocated_buffer(target_buffer));
target_offset += iree_hal_buffer_byte_offset(target_buffer);
iree_status_t status = iree_hal_metal_command_segment_create_copy_buffer(
command_buffer, command_buffer->staging_buffer->metal_buffer, offset, target_device_buffer,
target_offset, length);
IREE_TRACE_ZONE_END(z0);
return status;
}
static iree_status_t iree_hal_metal_command_buffer_prepare_copy_buffer(
iree_hal_command_buffer_t* base_command_buffer, iree_hal_buffer_t* source_buffer,
iree_device_size_t source_offset, iree_hal_buffer_t* target_buffer,
iree_device_size_t target_offset, iree_device_size_t length) {
iree_hal_metal_command_buffer_t* command_buffer =
iree_hal_metal_command_buffer_cast(base_command_buffer);
IREE_TRACE_ZONE_BEGIN(z0);
const iree_hal_buffer_t* buffers[2] = {source_buffer, target_buffer};
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_resource_set_insert(command_buffer->resource_set, 2, buffers));
id<MTLBuffer> source_device_buffer =
iree_hal_metal_buffer_handle(iree_hal_buffer_allocated_buffer(source_buffer));
id<MTLBuffer> target_device_buffer =
iree_hal_metal_buffer_handle(iree_hal_buffer_allocated_buffer(target_buffer));
source_offset += iree_hal_buffer_byte_offset(source_buffer);
target_offset += iree_hal_buffer_byte_offset(target_buffer);
iree_status_t status = iree_hal_metal_command_segment_create_copy_buffer(
command_buffer, source_device_buffer, source_offset, target_device_buffer, target_offset,
length);
IREE_TRACE_ZONE_END(z0);
return status;
}
static iree_status_t iree_hal_metal_command_buffer_collective(
iree_hal_command_buffer_t* base_command_buffer, iree_hal_channel_t* channel,
iree_hal_collective_op_t op, uint32_t param, iree_hal_buffer_binding_t send_binding,
iree_hal_buffer_binding_t recv_binding, iree_device_size_t element_count) {
return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "collectives not yet supported");
}
static iree_status_t iree_hal_metal_command_buffer_push_constants(
iree_hal_command_buffer_t* base_command_buffer, iree_hal_pipeline_layout_t* pipeline_layout,
iree_host_size_t offset, const void* values, iree_host_size_t values_length) {
iree_hal_metal_command_buffer_t* command_buffer =
iree_hal_metal_command_buffer_cast(base_command_buffer);
// "Binding a pipeline with a layout that is not compatible with the push constant layout does not
// disturb the push constant values." So we don't need to check whether the pipeline layout
// compatibility and invalidate existing values.
if (IREE_UNLIKELY(offset + values_length > sizeof(command_buffer->state.push_constants))) {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"push constant range [%zu, %zu) out of range", offset,
offset + values_length);
}
memcpy((uint8_t*)&command_buffer->state.push_constants + offset, values, values_length);
return iree_ok_status();
}
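// Derives the Metal resource usage for a descriptor binding: read-only bindings map to
// MTLResourceUsageRead; all others additionally get MTLResourceUsageWrite.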
static inline MTLResourceUsage iree_hal_metal_get_metal_resource_usage(
const iree_hal_descriptor_set_layout_binding_t* binding) {
MTLResourceUsage usage = MTLResourceUsageRead;
if (binding->flags != IREE_HAL_DESCRIPTOR_FLAG_READ_ONLY) usage |= MTLResourceUsageWrite;
return usage;
}
static iree_status_t iree_hal_metal_command_buffer_push_descriptor_set(
iree_hal_command_buffer_t* base_command_buffer, iree_hal_pipeline_layout_t* pipeline_layout,
uint32_t set, iree_host_size_t binding_count,
const iree_hal_descriptor_set_binding_t* bindings) {
iree_hal_metal_command_buffer_t* command_buffer =
iree_hal_metal_command_buffer_cast(base_command_buffer);
if (binding_count > IREE_HAL_METAL_MAX_DESCRIPTOR_SET_BINDING_COUNT) {
return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
"exceeded available binding slots for push descriptor set #%u; "
"requested %lu vs. maximal %d",
set, binding_count, IREE_HAL_METAL_MAX_DESCRIPTOR_SET_BINDING_COUNT);
}
IREE_TRACE_ZONE_BEGIN(z0);
IREE_ASSERT(set < IREE_HAL_METAL_PUSH_CONSTANT_BUFFER_INDEX);
const iree_hal_descriptor_set_layout_t* set_layout =
iree_hal_metal_pipeline_layout_descriptor_set_layout(pipeline_layout, set);
iree_hal_metal_descriptor_t* descriptors = command_buffer->state.descriptor_sets[set].bindings;
// Update descriptors in the current set.
for (iree_host_size_t i = 0; i < binding_count; ++i) {
iree_hal_metal_descriptor_t* descriptor = &descriptors[i];
descriptor->set = set;
descriptor->binding = bindings[i].binding;
descriptor->buffer = bindings[i].buffer;
descriptor->offset = bindings[i].offset;
const iree_hal_descriptor_set_layout_binding_t* binding_params =
iree_hal_metal_descriptor_set_layout_binding(set_layout, descriptor->binding);
descriptor->usage = iree_hal_metal_get_metal_resource_usage(binding_params);
}
// Retain all buffers bound in this descriptor set.
for (iree_host_size_t i = 0; i < binding_count; ++i) {
if (bindings[i].buffer) {
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1, &bindings[i].buffer));
}
}
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1, &pipeline_layout));
IREE_TRACE_ZONE_END(z0);
return iree_ok_status();
}
// Prepares kernels and argument buffers needed for kernel dispatches.
static iree_status_t iree_hal_metal_command_segment_create_dispatch(
iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable,
int32_t entry_point, iree_hal_metal_dispatch_segment_t** out_segment) {
iree_hal_metal_command_buffer_t* command_buffer =
iree_hal_metal_command_buffer_cast(base_command_buffer);
IREE_TRACE_ZONE_BEGIN(z0);
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1, &executable));
iree_hal_metal_kernel_params_t kernel_params;
IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, iree_hal_metal_kernel_library_entry_point_kernel_params(
executable, entry_point, &kernel_params));
// Allocate the command segment and keep track of all necessary API data.
uint8_t* storage_base = NULL;
iree_hal_metal_command_segment_t* segment = NULL;
const iree_host_size_t set_count =
iree_hal_metal_pipeline_layout_descriptor_set_count(kernel_params.layout);
iree_host_size_t descriptor_count = 0;
// Calculate the total number of bindings across all descriptor sets.
for (iree_host_size_t i = 0; i < set_count; ++i) {
const iree_hal_descriptor_set_layout_t* set_layout =
iree_hal_metal_pipeline_layout_descriptor_set_layout(kernel_params.layout, i);
descriptor_count += iree_hal_metal_descriptor_set_layout_binding_count(set_layout);
}
iree_host_size_t descriptor_length = descriptor_count * sizeof(iree_hal_metal_descriptor_t);
iree_host_size_t push_constant_count =
iree_hal_metal_pipeline_layout_push_constant_count(kernel_params.layout);
iree_host_size_t push_constant_length = push_constant_count * sizeof(int32_t);
iree_host_size_t total_size = sizeof(*segment) + descriptor_length + push_constant_length;
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_arena_allocate(&command_buffer->arena, total_size, (void**)&storage_base));
// Compose and push the dispatch segment.
segment = (iree_hal_metal_command_segment_t*)storage_base;
memset(segment, 0, sizeof(*segment));
segment->action = IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_DISPATCH;
iree_hal_metal_command_segment_list_push_back(&command_buffer->segments, segment);
segment->dispatch.kernel_params = kernel_params;
// Copy descriptors from all sets to the end of the current segment for later access.
segment->dispatch.descriptor_count = descriptor_count;
uint8_t* descriptor_ptr = storage_base + sizeof(*segment);
segment->dispatch.descriptors = (iree_hal_metal_descriptor_t*)descriptor_ptr;
for (iree_host_size_t i = 0; i < set_count; ++i) {
const iree_hal_descriptor_set_layout_t* set_layout =
iree_hal_metal_pipeline_layout_descriptor_set_layout(kernel_params.layout, i);
iree_host_size_t binding_count = iree_hal_metal_descriptor_set_layout_binding_count(set_layout);
iree_host_size_t current_size = binding_count * sizeof(iree_hal_metal_descriptor_t);
memcpy(descriptor_ptr, command_buffer->state.descriptor_sets[i].bindings, current_size);
descriptor_ptr += current_size;
}
// Copy push constants to the end of the current segment for later access.
segment->dispatch.push_constant_count = push_constant_count;
uint8_t* push_constant_ptr = storage_base + sizeof(*segment) + descriptor_length;
segment->dispatch.push_constants = (int32_t*)push_constant_ptr;
memcpy(push_constant_ptr, (const uint8_t*)command_buffer->state.push_constants,
push_constant_length);
*out_segment = &segment->dispatch;
IREE_TRACE_ZONE_END(z0);
return iree_ok_status();
}
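// Records a dispatch segment: sets the compute pipeline state, uploads push constants, encodes
// one argument buffer per descriptor set out of the shared staging buffer (also marking resource
// usages), and finally issues the direct or indirect threadgroup dispatch.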
static iree_status_t iree_hal_metal_command_segment_record_dispatch(
iree_hal_metal_command_buffer_t* command_buffer, iree_hal_metal_dispatch_segment_t* segment) {
IREE_TRACE_ZONE_BEGIN(z0);
// Set the compute kernel to dispatch.
id<MTLComputeCommandEncoder> compute_encoder =
iree_hal_metal_get_or_begin_compute_encoder(command_buffer);
[compute_encoder setComputePipelineState:segment->kernel_params.pso];
// Record push constants.
if (segment->push_constant_count != 0) {
[compute_encoder setBytes:(void*)segment->push_constants
length:segment->push_constant_count * sizeof(int32_t)
atIndex:IREE_HAL_METAL_PUSH_CONSTANT_BUFFER_INDEX];
}
// Record argument buffers for all descriptors and record buffer usages.
iree_hal_metal_descriptor_t* descriptors = segment->descriptors;
for (iree_host_size_t i = 0; i < segment->descriptor_count;) {
uint32_t current_set = descriptors[i].set;
// Build argument encoder and argument buffer for the current descriptor set.
// TODO(antiagainst): Use a cache layer to cache and reuse argument buffers with the same
// content, to avoid duplicating overhead.
id<MTLBuffer> argument_buffer = command_buffer->staging_buffer->metal_buffer;
id<MTLArgumentEncoder> argument_encoder =
[segment->kernel_params.function newArgumentEncoderWithBufferIndex:current_set]; // +1
IREE_ASSERT(argument_encoder != nil);
// Reserve space for the argument buffer from the shared staging buffer.
iree_byte_span_t reservation;
uint32_t argument_buffer_offset;
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_metal_staging_buffer_reserve(
command_buffer->staging_buffer, argument_encoder.encodedLength,
argument_encoder.alignment, &reservation, &argument_buffer_offset));
[argument_encoder setArgumentBuffer:argument_buffer offset:argument_buffer_offset];
// Now record all bound buffers belonging to the current set into the argument buffer.
for (; i < segment->descriptor_count && descriptors[i].set == current_set; ++i) {
uint32_t current_binding = descriptors[i].binding;
id<MTLBuffer> current_buffer =
iree_hal_metal_buffer_handle(iree_hal_buffer_allocated_buffer(descriptors[i].buffer));
iree_host_size_t offset =
iree_hal_buffer_byte_offset(descriptors[i].buffer) + descriptors[i].offset;
[argument_encoder setBuffer:current_buffer offset:offset atIndex:current_binding];
// Also record buffer usages.
[compute_encoder useResource:current_buffer usage:descriptors[i].usage];
}
// Record the argument buffer.
[compute_encoder setBuffer:argument_buffer offset:argument_buffer_offset atIndex:current_set];
[argument_encoder release]; // -1
}
// Record the dispatch, either direct or indirect.
uint32_t* workgroup_size = segment->kernel_params.threadgroup_size;
if (segment->workgroups_buffer == nil) {
// Direct dispatch of a fixed workgroup count.
uint32_t* workgroup_count = segment->workgroup_count;
[compute_encoder
dispatchThreadgroups:MTLSizeMake(workgroup_count[0], workgroup_count[1],
workgroup_count[2])
threadsPerThreadgroup:MTLSizeMake(workgroup_size[0], workgroup_size[1], workgroup_size[2])];
} else {
// Indirect dispatch with the workgroup count read from a buffer.
[compute_encoder
dispatchThreadgroupsWithIndirectBuffer:segment->workgroups_buffer
indirectBufferOffset:segment->workgroups_offset
threadsPerThreadgroup:MTLSizeMake(workgroup_size[0], workgroup_size[1],
workgroup_size[2])];
}
IREE_TRACE_ZONE_END(z0);
return iree_ok_status();
}
static iree_status_t iree_hal_metal_command_buffer_prepare_dispatch(
iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable,
int32_t entry_point, uint32_t workgroup_count_x, uint32_t workgroup_count_y,
uint32_t workgroup_count_z) {
IREE_TRACE_ZONE_BEGIN(z0);
iree_hal_metal_dispatch_segment_t* segment = NULL;
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_metal_command_segment_create_dispatch(base_command_buffer, executable,
entry_point, &segment));
segment->workgroup_count[0] = workgroup_count_x;
segment->workgroup_count[1] = workgroup_count_y;
segment->workgroup_count[2] = workgroup_count_z;
IREE_TRACE_ZONE_END(z0);
return iree_ok_status();
}
static iree_status_t iree_hal_metal_command_buffer_prepare_dispatch_indirect(
iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable,
int32_t entry_point, iree_hal_buffer_t* workgroups_buffer,
iree_device_size_t workgroups_offset) {
IREE_TRACE_ZONE_BEGIN(z0);
iree_hal_metal_dispatch_segment_t* segment = NULL;
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_metal_command_segment_create_dispatch(base_command_buffer, executable,
entry_point, &segment));
segment->workgroups_buffer =
iree_hal_metal_buffer_handle(iree_hal_buffer_allocated_buffer(workgroups_buffer));
segment->workgroups_offset = workgroups_offset;
IREE_TRACE_ZONE_END(z0);
return iree_ok_status();
}
static iree_status_t iree_hal_metal_command_buffer_execute_commands(
iree_hal_command_buffer_t* base_command_buffer, iree_hal_command_buffer_t* base_commands,
iree_hal_buffer_binding_table_t binding_table) {
return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "secondary command buffer not yet supported");
}
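// Records all previously prepared command segments, in order, into the underlying Metal command
// buffer.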
static iree_status_t iree_hal_metal_command_segment_record(
iree_hal_metal_command_buffer_t* command_buffer) {
IREE_ASSERT_ARGUMENT(command_buffer);
IREE_TRACE_ZONE_BEGIN(z0);
for (iree_hal_metal_command_segment_t* segment = command_buffer->segments.head; segment;
segment = segment->next_segment) {
switch (segment->action) {
case IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_BARRIER: {
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_metal_command_segment_record_barrier(command_buffer, &segment->barrier));
} break;
case IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_DISPATCH: {
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_metal_command_segment_record_dispatch(command_buffer, &segment->dispatch));
} break;
case IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_FILL_BUFFER: {
IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, iree_hal_metal_command_segment_record_fill_buffer(
command_buffer, &segment->fill_buffer));
} break;
case IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_COPY_BUFFER: {
IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, iree_hal_metal_command_segment_record_copy_buffer(
command_buffer, &segment->copy_buffer));
} break;
default:
IREE_ASSERT(false, "unhandled command segment kind");
break;
}
}
IREE_TRACE_ZONE_END(z0);
return iree_ok_status();
}
static iree_status_t iree_hal_metal_command_buffer_begin(
iree_hal_command_buffer_t* base_command_buffer) {
iree_hal_metal_command_buffer_t* command_buffer =
iree_hal_metal_command_buffer_cast(base_command_buffer);
iree_hal_metal_command_buffer_reset(command_buffer);
return iree_ok_status();
}
static iree_status_t iree_hal_metal_command_buffer_end(
iree_hal_command_buffer_t* base_command_buffer) {
iree_hal_metal_command_buffer_t* command_buffer =
iree_hal_metal_command_buffer_cast(base_command_buffer);
IREE_TRACE_ZONE_BEGIN(z0);
IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, iree_hal_metal_command_segment_record(command_buffer));
iree_hal_metal_end_blit_encoder(command_buffer);
iree_hal_metal_end_compute_encoder(command_buffer);
IREE_TRACE_ZONE_END(z0);
return iree_ok_status();
}
static const iree_hal_command_buffer_vtable_t iree_hal_metal_command_buffer_vtable = {
.destroy = iree_hal_metal_command_buffer_destroy,
.begin = iree_hal_metal_command_buffer_begin,
.end = iree_hal_metal_command_buffer_end,
.begin_debug_group = iree_hal_metal_command_buffer_begin_debug_group,
.end_debug_group = iree_hal_metal_command_buffer_end_debug_group,
.execution_barrier = iree_hal_metal_command_buffer_prepare_barrier,
.signal_event = iree_hal_metal_command_buffer_signal_event,
.reset_event = iree_hal_metal_command_buffer_reset_event,
.wait_events = iree_hal_metal_command_buffer_wait_events,
.discard_buffer = iree_hal_metal_command_buffer_discard_buffer,
.fill_buffer = iree_hal_metal_command_buffer_prepare_fill_buffer,
.update_buffer = iree_hal_metal_command_buffer_prepare_update_buffer,
.copy_buffer = iree_hal_metal_command_buffer_prepare_copy_buffer,
.collective = iree_hal_metal_command_buffer_collective,
.push_constants = iree_hal_metal_command_buffer_push_constants,
.push_descriptor_set = iree_hal_metal_command_buffer_push_descriptor_set,
.dispatch = iree_hal_metal_command_buffer_prepare_dispatch,
.dispatch_indirect = iree_hal_metal_command_buffer_prepare_dispatch_indirect,
.execute_commands = iree_hal_metal_command_buffer_execute_commands,
};