// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "experimental/metal/direct_command_buffer.h"

#import <Metal/Metal.h>

#include "experimental/metal/builtin_executables.h"
#include "experimental/metal/metal_buffer.h"
#include "experimental/metal/metal_device.h"
#include "experimental/metal/metal_kernel_library.h"
#include "experimental/metal/pipeline_layout.h"
#include "experimental/metal/staging_buffer.h"
#include "iree/base/api.h"
#include "iree/base/target_platform.h"
#include "iree/base/tracing.h"
#include "iree/hal/api.h"
#include "iree/hal/utils/resource_set.h"

//===------------------------------------------------------------------------------------------===//
// Segmented submission management
//===------------------------------------------------------------------------------------------===//

// Unlike Vulkan, Metal adopts a multi-level command recording model--memory/dispatch commands are
// not directly recorded into a command buffer; rather, they must go through the additional level of
// blit/compute encoders. IREE's HAL follows the flat Vulkan command buffer recording model, so we
// have a mismatch here. Implementing IREE's HAL using Metal requires switching encoders for
// interleaved memory and dispatch commands. Additionally, certain IREE HAL API features have no
// direct mapping in Metal APIs, e.g., various forms of IREE HAL execution/memory barriers.
// Translating them requires looking at both previous and next commands to decide the proper
// mapping.
//
// For these reasons, it's beneficial to have a complete view of the full command buffer and
// extra flexibility during recording, in order to fix up past commands or inspect future commands.
//
// Therefore, to implement IREE HAL command buffers using Metal, we perform two steps using a linked
// list of command segments. First we create segments (iree_hal_metal_command_buffer_prepare_* and
// iree_hal_metal_command_segment_create_*) to keep track of all IREE HAL commands and the
// associated data; then, when finalizing the command buffer, we iterate through all the segments
// and record their contents (iree_hal_metal_command_segment_record_*) into a proper Metal command
// buffer. A linked list gives us the flexibility to organize the command sequence with low
// overhead, and deferred recording gives us the complete picture of the command buffer when we
// actually start recording.
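//
// As an illustrative sketch (not actual call sites; the function names follow the patterns named
// above), recording a fill followed by a dispatch proceeds in two phases:
//
//   // Phase 1: prepare--append one segment per IREE HAL command to the linked list.
//   iree_hal_metal_command_buffer_prepare_fill_buffer(...);  // segments: [FILL_BUFFER]
//   iree_hal_metal_command_segment_create_dispatch(...);     // segments: [FILL_BUFFER, DISPATCH]
//   // Phase 2: record--walk the list once, with full knowledge of neighboring segments.
//   for (segment = segments.head; segment; segment = segment->next_segment) {
//     switch (segment->action) { /* iree_hal_metal_command_segment_record_*(segment); */ }
//   }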

// Command action kind of a command segment.
typedef enum iree_hal_metal_command_segment_action_e {
  IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_BARRIER,      // Execution/memory barrier command
  IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_DISPATCH,     // Dispatch command
  IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_FILL_BUFFER,  // Fill buffer command
  IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_COPY_BUFFER,  // Copy buffer command
} iree_hal_metal_command_segment_action_t;

// API data for execution/memory barrier command segments.
typedef struct iree_hal_metal_barrier_segment_t {
  iree_host_size_t memory_barrier_count;  // Total number of memory barriers
  iree_host_size_t buffer_barrier_count;  // Total number of buffer barriers
  // The list of buffer barriers, pointing to the end of the segment allocation.
  const iree_hal_buffer_barrier_t* buffer_barriers;
} iree_hal_metal_barrier_segment_t;
// + Additional inline allocation for holding all buffer barriers.
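//
// For example (an illustrative layout sketch, assuming N buffer barriers), a barrier segment's
// single arena allocation looks like:
//
//   [iree_hal_metal_command_segment_t][iree_hal_buffer_barrier_t x N]
//   ^ segment                         ^ segment->barrier.buffer_barriers
//
// The dispatch and fill buffer segments below use the same trailing inline allocation pattern for
// their bound descriptors/push constants and fill patterns, respectively.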

typedef struct iree_hal_metal_descriptor_t {
  uint32_t set;
  uint32_t binding;
  iree_hal_buffer_t* buffer;
  iree_device_size_t offset;
  MTLResourceUsage usage;
} iree_hal_metal_descriptor_t;

// API data for dispatch command segments.
typedef struct iree_hal_metal_dispatch_segment_t {
  // Compute kernel information--kernel object, pipeline layout, threadgroup size, etc.
  iree_hal_metal_kernel_params_t kernel_params;

  // Workgroup count information--if |workgroups_buffer| is not nil, this is an indirect dispatch;
  // otherwise |workgroup_count| is used for a direct dispatch.
  id<MTLBuffer> workgroups_buffer;
  iree_device_size_t workgroups_offset;
  uint32_t workgroup_count[3];

  // The number of descriptors bound for this dispatch.
  iree_host_size_t descriptor_count;
  // The list of bound descriptors, pointing to the end of the segment allocation.
  iree_hal_metal_descriptor_t* descriptors;

  // The number of push constant values.
  iree_host_size_t push_constant_count;
  // The list of push constants, pointing to the end of the segment allocation.
  int32_t* push_constants;
} iree_hal_metal_dispatch_segment_t;
// + Additional inline allocation for holding all bound descriptors.
// + Additional inline allocation for holding all push constants.

// API data for fill buffer command segments.
typedef struct iree_hal_metal_fill_buffer_segment_t {
  id<MTLBuffer> target_buffer;
  iree_device_size_t target_offset;
  iree_device_size_t length;
  // The fill pattern, pointing to the end of the segment allocation.
  const void* pattern;
  iree_host_size_t pattern_length;
} iree_hal_metal_fill_buffer_segment_t;
// + Additional inline allocation for holding the fill pattern.

// API data for copy buffer command segments.
typedef struct iree_hal_metal_copy_buffer_segment_t {
  id<MTLBuffer> source_buffer;
  iree_device_size_t source_offset;
  id<MTLBuffer> target_buffer;
  iree_device_size_t target_offset;
  iree_device_size_t length;
} iree_hal_metal_copy_buffer_segment_t;

struct iree_hal_metal_command_segment_t;
typedef struct iree_hal_metal_command_segment_t {
  struct iree_hal_metal_command_segment_t* next_segment;
  iree_hal_metal_command_segment_action_t action;
  union {
    iree_hal_metal_barrier_segment_t barrier;
    iree_hal_metal_dispatch_segment_t dispatch;
    iree_hal_metal_fill_buffer_segment_t fill_buffer;
    iree_hal_metal_copy_buffer_segment_t copy_buffer;
  };
} iree_hal_metal_command_segment_t;

typedef struct iree_hal_metal_command_segment_list_t {
  iree_hal_metal_command_segment_t* head;
  iree_hal_metal_command_segment_t* tail;
} iree_hal_metal_command_segment_list_t;

static void iree_hal_metal_command_segment_list_reset(iree_hal_metal_command_segment_list_t* list) {
  memset(list, 0, sizeof(*list));
}

static void iree_hal_metal_command_segment_list_push_front(
    iree_hal_metal_command_segment_list_t* list, iree_hal_metal_command_segment_t* segment) {
  segment->next_segment = list->head;
  list->head = segment;
  if (!list->tail) list->tail = segment;
}

static void iree_hal_metal_command_segment_list_push_back(
    iree_hal_metal_command_segment_list_t* list, iree_hal_metal_command_segment_t* segment) {
  segment->next_segment = NULL;
  if (list->tail) {
    list->tail->next_segment = segment;
    list->tail = segment;
  } else {
    list->head = list->tail = segment;
  }
}

//===------------------------------------------------------------------------------------------===//
// iree_hal_metal_command_buffer_t
//===------------------------------------------------------------------------------------------===//

typedef struct iree_hal_metal_command_buffer_t {
  iree_hal_command_buffer_t base;

  // The Metal command queue owning this command buffer.
  id<MTLCommandQueue> queue;

  // For polyfilling fill/copy/update buffers that are not directly supported by Metal APIs.
  iree_hal_metal_builtin_executable_t* builtin_executable;

  // Arena used for all allocations; references the shared device block pool.
  iree_arena_allocator_t arena;

  // Per-queue shared uniform staging buffer for uploading parameters to the GPU, including
  // argument buffers and buffer update source buffers.
  iree_hal_metal_staging_buffer_t* staging_buffer;

  iree_allocator_t host_allocator;

  // Maintains a reference to all resources used within the command buffer. Resets on each begin.
  iree_hal_resource_set_t* resource_set;

  // Linked list of command segments to be recorded into a command buffer.
  iree_hal_metal_command_segment_list_t segments;

  id<MTLCommandBuffer> command_buffer;

  MTLDispatchType dispatch_type;

  struct {
    // The currently active compute/blit encoders for encoding compute or memory operations.
    // Metal commands are encoded into the command buffer with such encoders, and each encoder can
    // only encode the specific type of operations it supports.
    id<MTLComputeCommandEncoder> compute_encoder;
    id<MTLBlitCommandEncoder> blit_encoder;

    // MTLEvent used for synchronization when we switch between blit and compute encoders.
    // Normally we would use MTLFence objects, but the difference between the IREE HAL and Metal
    // APIs means we may see many encoder switches, which would require creating a lot of GPU
    // objects. To avoid that cost, we just use one MTLEvent with different values for different
    // switches.
    id<MTLEvent> encoder_event;
    // The next available encoder event value to signal/wait to/on.
    uint64_t next_encoder_event_value;

    // Metal APIs mandate we create argument buffers (for descriptor sets) from compiled kernel
    // functions. That means we need to bind the compute kernel first before setting descriptors
    // and binding buffers. However, in the IREE HAL API we see push descriptors before the
    // dispatch command. So we need to cache the descriptor information ourselves and record it at
    // dispatch time.
    struct {
      iree_hal_metal_descriptor_t bindings[IREE_HAL_METAL_MAX_DESCRIPTOR_SET_BINDING_COUNT];
    } descriptor_sets[IREE_HAL_METAL_PUSH_CONSTANT_BUFFER_INDEX];

    // All available push constants, updated each time push_constants is called. Reset only with
    // the command buffer; otherwise maintains its values during recording to allow for partial
    // push_constants updates.
    int32_t push_constants[IREE_HAL_METAL_MAX_PUSH_CONSTANT_COUNT];
  } state;
} iree_hal_metal_command_buffer_t;
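
// An illustrative sketch of the encoder switch protocol described in the struct above (the real
// logic lives in iree_hal_metal_get_or_begin_{compute|blit}_encoder below):
//
//   uint64_t value = command_buffer->state.next_encoder_event_value++;
//   // After ending the previously active encoder:
//   [command_buffer->command_buffer encodeSignalEvent:command_buffer->state.encoder_event
//                                               value:value];
//   // Before recording any command with the newly begun encoder:
//   [command_buffer->command_buffer encodeWaitForEvent:command_buffer->state.encoder_event
//                                                value:value];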

//===------------------------------------------------------------------------------------------===//
// iree_hal_metal_command_buffer_vtable APIs
//===------------------------------------------------------------------------------------------===//

static const iree_hal_command_buffer_vtable_t iree_hal_metal_command_buffer_vtable;

static iree_hal_metal_command_buffer_t* iree_hal_metal_command_buffer_cast(
    iree_hal_command_buffer_t* base_value) {
  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_metal_command_buffer_vtable);
  return (iree_hal_metal_command_buffer_t*)base_value;
}

static const iree_hal_metal_command_buffer_t* iree_hal_metal_command_buffer_const_cast(
    const iree_hal_command_buffer_t* base_value) {
  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_metal_command_buffer_vtable);
  return (const iree_hal_metal_command_buffer_t*)base_value;
}

id<MTLCommandBuffer> iree_hal_metal_direct_command_buffer_handle(
    const iree_hal_command_buffer_t* base_command_buffer) {
  const iree_hal_metal_command_buffer_t* command_buffer =
      iree_hal_metal_command_buffer_const_cast(base_command_buffer);
  return command_buffer->command_buffer;
}

static void iree_hal_metal_end_compute_encoder(iree_hal_metal_command_buffer_t* command_buffer) {
  if (command_buffer->state.compute_encoder) {
    [command_buffer->state.compute_encoder endEncoding];
    [command_buffer->state.compute_encoder release];  // -1
    command_buffer->state.compute_encoder = nil;
  }
}

static void iree_hal_metal_end_blit_encoder(iree_hal_metal_command_buffer_t* command_buffer) {
  if (command_buffer->state.blit_encoder) {
    [command_buffer->state.blit_encoder endEncoding];
    [command_buffer->state.blit_encoder release];  // -1
    command_buffer->state.blit_encoder = nil;
  }
}

static void iree_hal_metal_command_buffer_reset(iree_hal_metal_command_buffer_t* command_buffer) {
  IREE_TRACE_ZONE_BEGIN(z0);
  iree_hal_metal_end_blit_encoder(command_buffer);
  iree_hal_metal_end_compute_encoder(command_buffer);
  iree_hal_metal_command_segment_list_reset(&command_buffer->segments);
  iree_arena_reset(&command_buffer->arena);
  IREE_TRACE_ZONE_END(z0);
}

static id<MTLComputeCommandEncoder> iree_hal_metal_get_or_begin_compute_encoder(
    iree_hal_metal_command_buffer_t* command_buffer) {
  id<MTLCommandBuffer> metal_handle = command_buffer->command_buffer;

  // If we are switching encoders, we need to synchronize "one or more resources across different
  // passes within a command buffer":
  // https://developer.apple.com/documentation/metal/resource_synchronization
  uint64_t encoder_event_value = 0;
  if (command_buffer->state.blit_encoder) {
    iree_hal_metal_end_blit_encoder(command_buffer);
    encoder_event_value = command_buffer->state.next_encoder_event_value++;
    [metal_handle encodeSignalEvent:command_buffer->state.encoder_event value:encoder_event_value];
  }

  if (!command_buffer->state.compute_encoder) {
    if (encoder_event_value != 0) {
      [metal_handle encodeWaitForEvent:command_buffer->state.encoder_event
                                 value:encoder_event_value];
    }
    @autoreleasepool {  // Use @autoreleasepool to trigger the autorelease within encoder creation.
      // We manage command dependencies and insert barriers explicitly in IREE; so use the
      // concurrent dispatch type for compute encoders.
      command_buffer->state.compute_encoder = [[metal_handle
          computeCommandEncoderWithDispatchType:command_buffer->dispatch_type] retain];  // +1
    }
  }

  return command_buffer->state.compute_encoder;
}

static id<MTLBlitCommandEncoder> iree_hal_metal_get_or_begin_blit_encoder(
    iree_hal_metal_command_buffer_t* command_buffer) {
  id<MTLCommandBuffer> metal_handle = command_buffer->command_buffer;

  // If we are switching encoders, we need to synchronize "one or more resources across different
  // passes within a command buffer":
  // https://developer.apple.com/documentation/metal/resource_synchronization
  uint64_t encoder_event_value = 0;
  if (command_buffer->state.compute_encoder) {
    iree_hal_metal_end_compute_encoder(command_buffer);
    encoder_event_value = command_buffer->state.next_encoder_event_value++;
    [metal_handle encodeSignalEvent:command_buffer->state.encoder_event value:encoder_event_value];
  }

  if (!command_buffer->state.blit_encoder) {
    if (encoder_event_value != 0) {
      [metal_handle encodeWaitForEvent:command_buffer->state.encoder_event
                                 value:encoder_event_value];
    }
    @autoreleasepool {  // Use @autoreleasepool to trigger the autorelease within encoder creation.
      command_buffer->state.blit_encoder = [[metal_handle blitCommandEncoder] retain];  // +1
    }
  }

  return command_buffer->state.blit_encoder;
}

// Destroys the given |base_command_buffer| itself, without decreasing the refcount in the shared
// staging buffer yet.
static void iree_hal_metal_command_buffer_destroy_internal(
    iree_hal_command_buffer_t* base_command_buffer);

iree_status_t iree_hal_metal_direct_command_buffer_create(
    iree_hal_device_t* device, iree_hal_command_buffer_mode_t mode,
    iree_hal_command_category_t command_categories, iree_host_size_t binding_capacity,
    iree_hal_metal_command_buffer_resource_reference_mode_t resource_reference_mode,
    id<MTLCommandQueue> queue, iree_arena_block_pool_t* block_pool,
    iree_hal_metal_staging_buffer_t* staging_buffer,
    iree_hal_metal_builtin_executable_t* builtin_executable, iree_allocator_t host_allocator,
    iree_hal_command_buffer_t** out_command_buffer) {
  IREE_ASSERT_ARGUMENT(device);
  IREE_ASSERT_ARGUMENT(out_command_buffer);
  IREE_ASSERT_TRUE(iree_all_bits_set(mode, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT));
  IREE_ASSERT_TRUE(!iree_any_bit_set(mode, IREE_HAL_COMMAND_BUFFER_MODE_NESTED));
  *out_command_buffer = NULL;

  if (binding_capacity > 0) {
    // TODO(#10144): support indirect command buffers with binding tables.
    return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "indirect command buffer not yet supported");
  }

  IREE_TRACE_ZONE_BEGIN(z0);

  iree_hal_metal_command_buffer_t* command_buffer = NULL;
  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, iree_allocator_malloc(host_allocator, sizeof(*command_buffer), (void**)&command_buffer));

  iree_hal_command_buffer_initialize(device, mode, command_categories, IREE_HAL_QUEUE_AFFINITY_ANY,
                                     binding_capacity, &iree_hal_metal_command_buffer_vtable,
                                     &command_buffer->base);
  command_buffer->queue = [queue retain];  // +1
  command_buffer->builtin_executable = builtin_executable;
  iree_arena_initialize(block_pool, &command_buffer->arena);
  command_buffer->staging_buffer = staging_buffer;
  command_buffer->host_allocator = host_allocator;
  iree_status_t status = iree_hal_resource_set_allocate(block_pool, &command_buffer->resource_set);
  if (iree_status_is_ok(status)) {
    iree_hal_metal_command_segment_list_reset(&command_buffer->segments);
    @autoreleasepool {  // Use @autoreleasepool to trigger the autorelease within encoder creation.
      // We track resource lifetime ourselves in IREE; so just keep unretained references to
      // resources in the Metal command buffer, which avoids overhead and gives better performance.
      MTLCommandBufferDescriptor* descriptor = [MTLCommandBufferDescriptor new];  // +1
      descriptor.retainedReferences =
          resource_reference_mode == IREE_HAL_METAL_COMMAND_BUFFER_RESOURCE_REFERENCE_MODE_RETAINED;
      descriptor.errorOptions = MTLCommandBufferErrorOptionNone;
      command_buffer->command_buffer =
          [[queue commandBufferWithDescriptor:descriptor] retain];  // +1
      [descriptor release];  // -1
    }
    const iree_hal_metal_device_params_t* params = iree_hal_metal_device_params(device);
    command_buffer->dispatch_type =
        params->command_dispatch_type == IREE_HAL_METAL_COMMAND_DISPATCH_TYPE_CONCURRENT
            ? MTLDispatchTypeConcurrent
            : MTLDispatchTypeSerial;
    command_buffer->state.compute_encoder = nil;
    command_buffer->state.blit_encoder = nil;
    command_buffer->state.encoder_event = [queue.device newEvent];  // +1
    command_buffer->state.next_encoder_event_value = 1;
  }

  if (iree_status_is_ok(status)) {
    *out_command_buffer = &command_buffer->base;

    // Increase the command buffer refcount in the shared staging buffer. We tie this to the
    // command buffer's lifetime to avoid resource leaks.
    iree_hal_metal_staging_buffer_increase_refcount(staging_buffer);
  } else {
    iree_hal_metal_command_buffer_destroy_internal(&command_buffer->base);
  }

  IREE_TRACE_ZONE_END(z0);
  return status;
}

static void iree_hal_metal_command_buffer_destroy_internal(
    iree_hal_command_buffer_t* base_command_buffer) {
  iree_hal_metal_command_buffer_t* command_buffer =
      iree_hal_metal_command_buffer_cast(base_command_buffer);

  iree_hal_metal_command_buffer_reset(command_buffer);
  [command_buffer->state.encoder_event release];  // -1
  IREE_ASSERT_EQ(command_buffer->state.compute_encoder, nil);
  IREE_ASSERT_EQ(command_buffer->state.blit_encoder, nil);
  [command_buffer->command_buffer release];  // -1
  [command_buffer->queue release];  // -1
  iree_hal_resource_set_free(command_buffer->resource_set);
  iree_arena_deinitialize(&command_buffer->arena);
  iree_allocator_free(command_buffer->host_allocator, command_buffer);
}

static void iree_hal_metal_command_buffer_destroy(iree_hal_command_buffer_t* base_command_buffer) {
  iree_hal_metal_command_buffer_t* command_buffer =
      iree_hal_metal_command_buffer_cast(base_command_buffer);
  IREE_TRACE_ZONE_BEGIN(z0);

  // Decrease the command buffer refcount in the shared staging buffer, and potentially reclaim
  // resources. We tie this to the command buffer's lifetime to avoid resource leaks.
  if (command_buffer->staging_buffer) {
    iree_hal_metal_staging_buffer_decrease_refcount(command_buffer->staging_buffer);
  }

  iree_hal_metal_command_buffer_destroy_internal(base_command_buffer);

  IREE_TRACE_ZONE_END(z0);
}

bool iree_hal_metal_command_buffer_isa(iree_hal_command_buffer_t* command_buffer) {
  return iree_hal_resource_is(&command_buffer->resource, &iree_hal_metal_command_buffer_vtable);
}

static void iree_hal_metal_command_buffer_begin_debug_group(
    iree_hal_command_buffer_t* base_command_buffer, iree_string_view_t label,
    iree_hal_label_color_t label_color, const iree_hal_label_location_t* location) {
  // TODO(antiagainst): implement support for debug groups
}

static void iree_hal_metal_command_buffer_end_debug_group(
    iree_hal_command_buffer_t* base_command_buffer) {
  // TODO(antiagainst): implement support for debug groups
}

448
Lei Zhangc0ad0ea2023-05-06 18:02:02 -0700449static iree_status_t iree_hal_metal_command_buffer_prepare_barrier(
Lei Zhangdf1e9a22023-02-12 12:08:00 -0800450 iree_hal_command_buffer_t* base_command_buffer, iree_hal_execution_stage_t source_stage_mask,
451 iree_hal_execution_stage_t target_stage_mask, iree_hal_execution_barrier_flags_t flags,
452 iree_host_size_t memory_barrier_count, const iree_hal_memory_barrier_t* memory_barriers,
453 iree_host_size_t buffer_barrier_count, const iree_hal_buffer_barrier_t* buffer_barriers) {
454 if (iree_any_bit_set(source_stage_mask, IREE_HAL_EXECUTION_STAGE_HOST) ||
455 iree_any_bit_set(target_stage_mask, IREE_HAL_EXECUTION_STAGE_HOST)) {
456 return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "barrier involving host not yet supported");
457 }
458
459 if (flags != IREE_HAL_EXECUTION_BARRIER_FLAG_NONE) {
460 return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "non-zero barrier flag not yet supported");
461 }
462
463 iree_hal_metal_command_buffer_t* command_buffer =
464 iree_hal_metal_command_buffer_cast(base_command_buffer);
Lei Zhangc0ad0ea2023-05-06 18:02:02 -0700465 IREE_TRACE_ZONE_BEGIN(z0);
Lei Zhangffb40b12023-04-27 08:18:51 -0700466
Lei Zhangc0ad0ea2023-05-06 18:02:02 -0700467 // Allocate the command segment and keep track of all necessary API data.
468 uint8_t* storage_base = NULL;
469 iree_hal_metal_command_segment_t* segment = NULL;
470 iree_host_size_t buffer_barrier_length = buffer_barrier_count * sizeof(iree_hal_buffer_barrier_t);
Lei Zhang52a8d0c2023-06-10 20:17:07 -0700471 IREE_RETURN_AND_END_ZONE_IF_ERROR(
472 z0, iree_arena_allocate(&command_buffer->arena, sizeof(*segment) + buffer_barrier_length,
473 (void**)&storage_base));
Lei Zhangc0ad0ea2023-05-06 18:02:02 -0700474
Lei Zhang52a8d0c2023-06-10 20:17:07 -0700475 // Copy the buffer barriers to the end of the current segments for later access. We don't copy
476 // memory barriers because in Metal there is only coarse-grained full memory barrier affecting
477 // all buffers, regardless of the fine-grained details from IREE HAL barriers.
478 uint8_t* barrier_ptr = storage_base + sizeof(*segment);
479 memcpy(barrier_ptr, (const uint8_t*)buffer_barriers, buffer_barrier_length);
Lei Zhangc0ad0ea2023-05-06 18:02:02 -0700480
Lei Zhang52a8d0c2023-06-10 20:17:07 -0700481 // Compose and push the barrier segment.
482 segment = (iree_hal_metal_command_segment_t*)storage_base;
483 segment->action = IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_BARRIER;
484 iree_hal_metal_command_segment_list_push_back(&command_buffer->segments, segment);
Lei Zhangc0ad0ea2023-05-06 18:02:02 -0700485
Lei Zhang52a8d0c2023-06-10 20:17:07 -0700486 segment->barrier.memory_barrier_count = memory_barrier_count;
487 segment->barrier.buffer_barrier_count = buffer_barrier_count;
488 segment->barrier.buffer_barriers = (const iree_hal_buffer_barrier_t*)barrier_ptr;
Lei Zhangc0ad0ea2023-05-06 18:02:02 -0700489
490 IREE_TRACE_ZONE_END(z0);
Lei Zhang52a8d0c2023-06-10 20:17:07 -0700491 return iree_ok_status();
Lei Zhangc0ad0ea2023-05-06 18:02:02 -0700492}

static iree_status_t iree_hal_metal_command_segment_record_barrier(
    iree_hal_metal_command_buffer_t* command_buffer, iree_hal_metal_barrier_segment_t* segment) {
  // TODO(antiagainst): Analyze segments before and after to optimize barriers, e.g., switching
  // encoders requires its own synchronization, so we don't need extra barriers in between.
  if (segment->memory_barrier_count == 0 && segment->buffer_barrier_count == 0) {
    // There are no directly corresponding APIs for an execution-only barrier in Metal. We just
    // signal and wait on the same value of a MTLEvent here.
    iree_hal_metal_end_blit_encoder(command_buffer);
    iree_hal_metal_end_compute_encoder(command_buffer);
    id<MTLCommandBuffer> metal_handle = command_buffer->command_buffer;
    uint64_t event_value = command_buffer->state.next_encoder_event_value++;
    [metal_handle encodeSignalEvent:command_buffer->state.encoder_event value:event_value];
    [metal_handle encodeWaitForEvent:command_buffer->state.encoder_event value:event_value];
    return iree_ok_status();
  }

  id<MTLComputeCommandEncoder> encoder =
      iree_hal_metal_get_or_begin_compute_encoder(command_buffer);

  if (segment->memory_barrier_count != 0) {
    // If there is a memory barrier specified, we have to place a catch-all barrier for all buffers.
    // Metal does not provide more fine-grained control here.
    [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
    return iree_ok_status();
  }

  if (segment->buffer_barrier_count != 0) {
    // But we do have the option to specify a list of buffers to synchronize if only buffer barriers
    // are specified.
    id<MTLResource>* resources =
        (id<MTLResource>*)iree_alloca(sizeof(id<MTLResource>) * segment->buffer_barrier_count);
    for (iree_host_size_t i = 0; i < segment->buffer_barrier_count; ++i) {
      resources[i] = iree_hal_metal_buffer_handle(
          iree_hal_buffer_allocated_buffer(segment->buffer_barriers[i].buffer));
    }
    [encoder memoryBarrierWithResources:resources count:segment->buffer_barrier_count];
  }
  return iree_ok_status();
}

static iree_status_t iree_hal_metal_command_buffer_signal_event(
    iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event,
    iree_hal_execution_stage_t source_stage_mask) {
  return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "event not yet supported");
}

static iree_status_t iree_hal_metal_command_buffer_reset_event(
    iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event,
    iree_hal_execution_stage_t source_stage_mask) {
  return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "event not yet supported");
}

static iree_status_t iree_hal_metal_command_buffer_wait_events(
    iree_hal_command_buffer_t* base_command_buffer, iree_host_size_t event_count,
    const iree_hal_event_t** events, iree_hal_execution_stage_t source_stage_mask,
    iree_hal_execution_stage_t target_stage_mask, iree_host_size_t memory_barrier_count,
    const iree_hal_memory_barrier_t* memory_barriers, iree_host_size_t buffer_barrier_count,
    const iree_hal_buffer_barrier_t* buffer_barriers) {
  return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "event not yet supported");
}

static iree_status_t iree_hal_metal_command_buffer_discard_buffer(
    iree_hal_command_buffer_t* base_command_buffer, iree_hal_buffer_t* buffer) {
  // This is a hint to the device and we have nothing to do for Metal.
  return iree_ok_status();
}

// Fills |value| with the duplicated single-byte value and returns true if the given |pattern|
// repeats the same byte value across all of its |pattern_length| bytes.
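// For example (illustrative values):
//   pattern=0x42,       pattern_length=1 -> *value=0x42, returns true
//   pattern=0x4242,     pattern_length=2 -> *value=0x42, returns true
//   pattern=0x12345678, pattern_length=4 -> returns false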
static bool iree_hal_metal_get_duplicated_single_byte_value(const void* pattern,
                                                            size_t pattern_length, uint8_t* value) {
  switch (pattern_length) {
    case 1: {
      *value = *(uint8_t*)pattern;
      return true;
    }
    case 2: {
      uint16_t two_bytes = *(uint16_t*)pattern;
      uint16_t byte0 = two_bytes & 0xffu;
      uint16_t byte1 = two_bytes >> 8u;
      if (byte0 == byte1) {
        *value = (uint8_t)byte0;
        return true;
      }
      break;
    }
    case 4: {
      uint32_t four_bytes = *(uint32_t*)pattern;
      uint32_t byte0 = four_bytes & 0xffu;
      uint32_t byte1 = (four_bytes >> 8u) & 0xffu;
      uint32_t byte2 = (four_bytes >> 16u) & 0xffu;
      uint32_t byte3 = four_bytes >> 24u;
      if (byte0 == byte1 && byte0 == byte2 && byte0 == byte3) {
        *value = (uint8_t)byte0;
        return true;
      }
      break;
    }
    default:
      break;
  }
  return false;
}

// Duplicates the given |pattern| into a full 4-byte value and returns it.
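// For example (illustrative values): pattern=0xAB with pattern_length=1 becomes 0xABABABAB;
// pattern=0xBEEF with pattern_length=2 becomes 0xBEEFBEEF.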
static uint32_t iree_hal_metal_duplicate_to_four_byte_value(const void* pattern,
                                                            size_t pattern_length) {
  if (pattern_length == 1) {
    uint8_t single_byte = *(uint8_t*)pattern;
    uint32_t value = (uint32_t)single_byte;
    value |= (value << 8u);
    value |= (value << 16u);
    return value;
  }

  if (pattern_length == 2) {
    uint16_t two_bytes = *(uint16_t*)pattern;
    uint32_t value = (uint32_t)two_bytes;
    value |= (value << 16u);
    return value;
  }

  IREE_ASSERT(pattern_length == 4);
  return *(uint32_t*)pattern;
}

static iree_status_t iree_hal_metal_command_buffer_prepare_fill_buffer(
    iree_hal_command_buffer_t* base_command_buffer, iree_hal_buffer_t* target_buffer,
    iree_device_size_t target_offset, iree_device_size_t length, const void* pattern,
    iree_host_size_t pattern_length) {
  iree_hal_metal_command_buffer_t* command_buffer =
      iree_hal_metal_command_buffer_cast(base_command_buffer);
  IREE_TRACE_ZONE_BEGIN(z0);

  id<MTLBuffer> target_device_buffer =
      iree_hal_metal_buffer_handle(iree_hal_buffer_allocated_buffer(target_buffer));
  target_offset += iree_hal_buffer_byte_offset(target_buffer);

  // Allocate the command segment and keep track of all necessary API data.
  uint8_t* storage_base = NULL;
  iree_hal_metal_command_segment_t* segment = NULL;
  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, iree_arena_allocate(&command_buffer->arena, sizeof(*segment) + pattern_length,
                              (void**)&storage_base));

  // Copy the pattern to the end of the segment for later access.
  uint8_t* pattern_ptr = storage_base + sizeof(*segment);
  memcpy(pattern_ptr, (const uint8_t*)pattern, pattern_length);

  // Compose and push the fill buffer segment.
  segment = (iree_hal_metal_command_segment_t*)storage_base;
  segment->action = IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_FILL_BUFFER;
  iree_hal_metal_command_segment_list_push_back(&command_buffer->segments, segment);

  segment->fill_buffer.target_buffer = target_device_buffer;
  segment->fill_buffer.target_offset = target_offset;
  segment->fill_buffer.length = length;
  segment->fill_buffer.pattern = (const void*)pattern_ptr;
  segment->fill_buffer.pattern_length = pattern_length;

  iree_status_t status =
      iree_hal_resource_set_insert(command_buffer->resource_set, 1, &target_buffer);

  IREE_TRACE_ZONE_END(z0);
  return status;
}

static iree_status_t iree_hal_metal_command_segment_record_fill_buffer(
    iree_hal_metal_command_buffer_t* command_buffer,
    iree_hal_metal_fill_buffer_segment_t* segment) {
  IREE_TRACE_ZONE_BEGIN(z0);

  // Note that fillBuffer:range:value: only accepts a single byte as the pattern, but IREE HAL fill
  // buffer commands can accept 1/2/4-byte patterns. If the pattern itself contains repeated bytes,
  // we can call into fillBuffer:range:value:; otherwise we need to emulate the support.
  uint8_t pattern_1byte = 0u;

  // Per the spec for fillBuffer:range:value: "The alignment and length of the range must both be a
  // multiple of 4 bytes in macOS, and 1 byte in iOS and tvOS."
#if defined(IREE_PLATFORM_MACOS)
  const bool can_use_metal_api = segment->target_offset % 4 == 0 && segment->length % 4 == 0 &&
                                 iree_hal_metal_get_duplicated_single_byte_value(
                                     segment->pattern, segment->pattern_length, &pattern_1byte);
#else
  const bool can_use_metal_api = iree_hal_metal_get_duplicated_single_byte_value(
      segment->pattern, segment->pattern_length, &pattern_1byte);
#endif

  if (can_use_metal_api) {
    id<MTLBlitCommandEncoder> encoder = iree_hal_metal_get_or_begin_blit_encoder(command_buffer);
    [encoder fillBuffer:segment->target_buffer
                  range:NSMakeRange(segment->target_offset, segment->length)
                  value:pattern_1byte];
    IREE_TRACE_ZONE_END(z0);
    return iree_ok_status();
  }

  id<MTLComputeCommandEncoder> compute_encoder =
      iree_hal_metal_get_or_begin_compute_encoder(command_buffer);
  uint32_t pattern_4byte =
      iree_hal_metal_duplicate_to_four_byte_value(segment->pattern, segment->pattern_length);
  iree_status_t status = iree_hal_metal_builtin_executable_fill_buffer(
      command_buffer->builtin_executable, compute_encoder, segment->target_buffer,
      segment->target_offset, segment->length, pattern_4byte);

  IREE_TRACE_ZONE_END(z0);
  return status;
}

static iree_status_t iree_hal_metal_command_segment_create_copy_buffer(
    iree_hal_metal_command_buffer_t* command_buffer, id<MTLBuffer> source_device_buffer,
    iree_device_size_t source_offset, id<MTLBuffer> target_device_buffer,
    iree_device_size_t target_offset, iree_device_size_t length) {
  IREE_TRACE_ZONE_BEGIN(z0);

  // Allocate the command segment and keep track of all necessary API data.
  uint8_t* storage_base = NULL;
  iree_hal_metal_command_segment_t* segment = NULL;
  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, iree_arena_allocate(&command_buffer->arena, sizeof(*segment), (void**)&storage_base));

  // Compose and push the copy buffer segment.
  segment = (iree_hal_metal_command_segment_t*)storage_base;
  segment->action = IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_COPY_BUFFER;
  iree_hal_metal_command_segment_list_push_back(&command_buffer->segments, segment);

  segment->copy_buffer.source_buffer = source_device_buffer;
  segment->copy_buffer.source_offset = source_offset;
  segment->copy_buffer.target_buffer = target_device_buffer;
  segment->copy_buffer.target_offset = target_offset;
  segment->copy_buffer.length = length;

  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}

static iree_status_t iree_hal_metal_command_segment_record_copy_buffer(
    iree_hal_metal_command_buffer_t* command_buffer,
    iree_hal_metal_copy_buffer_segment_t* segment) {
  IREE_TRACE_ZONE_BEGIN(z0);

  // Per the spec for copyFromBuffer:sourceOffset:toBuffer:destinationOffset:size, the source/target
  // offset and length must be a multiple of 4 bytes in macOS, and 1 byte in iOS and tvOS.
#if defined(IREE_PLATFORM_MACOS)
  bool can_use_metal_api = segment->source_offset % 4 == 0 && segment->target_offset % 4 == 0 &&
                           segment->length % 4 == 0;
#else
  bool can_use_metal_api = true;
#endif

  iree_status_t status = iree_ok_status();
  if (can_use_metal_api) {
    id<MTLBlitCommandEncoder> encoder = iree_hal_metal_get_or_begin_blit_encoder(command_buffer);
    [encoder copyFromBuffer:segment->source_buffer
               sourceOffset:segment->source_offset
                   toBuffer:segment->target_buffer
          destinationOffset:segment->target_offset
                       size:segment->length];
  } else {
    id<MTLComputeCommandEncoder> encoder =
        iree_hal_metal_get_or_begin_compute_encoder(command_buffer);
    status = iree_hal_metal_builtin_executable_copy_buffer(
        command_buffer->builtin_executable, encoder, segment->source_buffer, segment->source_offset,
        segment->target_buffer, segment->target_offset, segment->length);
  }

  IREE_TRACE_ZONE_END(z0);
  return status;
}

static iree_status_t iree_hal_metal_command_buffer_prepare_update_buffer(
    iree_hal_command_buffer_t* base_command_buffer, const void* source_buffer,
    iree_host_size_t source_offset, iree_hal_buffer_t* target_buffer,
    iree_device_size_t target_offset, iree_device_size_t length) {
  iree_hal_metal_command_buffer_t* command_buffer =
      iree_hal_metal_command_buffer_cast(base_command_buffer);
  IREE_TRACE_ZONE_BEGIN(z0);

  // There are no directly corresponding APIs in Metal. We write the source buffer data into the
  // staging buffer and then copy it over to the target buffer.

  iree_const_byte_span_t source_data_span =
      iree_make_const_byte_span((uint8_t*)source_buffer + source_offset, length);
  uint32_t offset = 0;
  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, iree_hal_metal_staging_buffer_append(command_buffer->staging_buffer, source_data_span,
                                               /*alignment=*/4, &offset));

  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1, &target_buffer));

  id<MTLBuffer> target_device_buffer =
      iree_hal_metal_buffer_handle(iree_hal_buffer_allocated_buffer(target_buffer));
  target_offset += iree_hal_buffer_byte_offset(target_buffer);

  iree_status_t status = iree_hal_metal_command_segment_create_copy_buffer(
      command_buffer, command_buffer->staging_buffer->metal_buffer, offset, target_device_buffer,
      target_offset, length);

  IREE_TRACE_ZONE_END(z0);
  return status;
}

static iree_status_t iree_hal_metal_command_buffer_prepare_copy_buffer(
    iree_hal_command_buffer_t* base_command_buffer, iree_hal_buffer_t* source_buffer,
    iree_device_size_t source_offset, iree_hal_buffer_t* target_buffer,
    iree_device_size_t target_offset, iree_device_size_t length) {
  iree_hal_metal_command_buffer_t* command_buffer =
      iree_hal_metal_command_buffer_cast(base_command_buffer);
  IREE_TRACE_ZONE_BEGIN(z0);

  const iree_hal_buffer_t* buffers[2] = {source_buffer, target_buffer};
  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, iree_hal_resource_set_insert(command_buffer->resource_set, 2, buffers));

  id<MTLBuffer> source_device_buffer =
      iree_hal_metal_buffer_handle(iree_hal_buffer_allocated_buffer(source_buffer));
  id<MTLBuffer> target_device_buffer =
      iree_hal_metal_buffer_handle(iree_hal_buffer_allocated_buffer(target_buffer));

  source_offset += iree_hal_buffer_byte_offset(source_buffer);
  target_offset += iree_hal_buffer_byte_offset(target_buffer);

  iree_status_t status = iree_hal_metal_command_segment_create_copy_buffer(
      command_buffer, source_device_buffer, source_offset, target_device_buffer, target_offset,
      length);

  IREE_TRACE_ZONE_END(z0);
  return status;
}

static iree_status_t iree_hal_metal_command_buffer_collective(
    iree_hal_command_buffer_t* base_command_buffer, iree_hal_channel_t* channel,
    iree_hal_collective_op_t op, uint32_t param, iree_hal_buffer_binding_t send_binding,
    iree_hal_buffer_binding_t recv_binding, iree_device_size_t element_count) {
  return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "collectives not yet supported");
}

834 iree_hal_command_buffer_t* base_command_buffer, iree_hal_pipeline_layout_t* pipeline_layout,
835 iree_host_size_t offset, const void* values, iree_host_size_t values_length) {
Lei Zhang30f6a392023-02-12 16:42:27 -0800836 iree_hal_metal_command_buffer_t* command_buffer =
837 iree_hal_metal_command_buffer_cast(base_command_buffer);
838
839 // "Binding a pipeline with a layout that is not compatible with the push constant layout does not
840 // disturb the push constant values." So we don't need to check whether the pipeline layout
841 // compatibility and invalidate existing values.
Lei Zhang30f6a392023-02-12 16:42:27 -0800842
Lei Zhangd4aef982023-05-08 13:48:42 -0700843 if (IREE_UNLIKELY(offset + values_length >= sizeof(command_buffer->state.push_constants))) {
Lei Zhang30f6a392023-02-12 16:42:27 -0800844 return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
845 "push constant range [%zu, %zu) out of range", offset,
846 offset + values_length);
847 }
848
Lei Zhangd4aef982023-05-08 13:48:42 -0700849 memcpy((uint8_t*)&command_buffer->state.push_constants + offset, values, values_length);
Lei Zhang30f6a392023-02-12 16:42:27 -0800850
851 return iree_ok_status();
Lei Zhangdf1e9a22023-02-12 12:08:00 -0800852}

static inline MTLResourceUsage iree_hal_metal_get_metal_resource_usage(
    const iree_hal_descriptor_set_layout_binding_t* binding) {
  MTLResourceUsage usage = MTLResourceUsageRead;
  if (binding->flags != IREE_HAL_DESCRIPTOR_FLAG_READ_ONLY) usage |= MTLResourceUsageWrite;
  return usage;
}

static iree_status_t iree_hal_metal_command_buffer_push_descriptor_set(
    iree_hal_command_buffer_t* base_command_buffer, iree_hal_pipeline_layout_t* pipeline_layout,
    uint32_t set, iree_host_size_t binding_count,
    const iree_hal_descriptor_set_binding_t* bindings) {
  iree_hal_metal_command_buffer_t* command_buffer =
      iree_hal_metal_command_buffer_cast(base_command_buffer);

  if (binding_count > IREE_HAL_METAL_MAX_DESCRIPTOR_SET_BINDING_COUNT) {
    return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
                            "exceeded available binding slots for push descriptor set #%u; "
                            "requested %lu vs. maximal %d",
                            set, binding_count, IREE_HAL_METAL_MAX_DESCRIPTOR_SET_BINDING_COUNT);
  }

  IREE_TRACE_ZONE_BEGIN(z0);

  IREE_ASSERT(set < IREE_HAL_METAL_PUSH_CONSTANT_BUFFER_INDEX);
  const iree_hal_descriptor_set_layout_t* set_layout =
      iree_hal_metal_pipeline_layout_descriptor_set_layout(pipeline_layout, set);
  iree_hal_metal_descriptor_t* descriptors = command_buffer->state.descriptor_sets[set].bindings;

  // Update descriptors in the current set.
  for (iree_host_size_t i = 0; i < binding_count; ++i) {
    iree_hal_metal_descriptor_t* descriptor = &descriptors[i];

    descriptor->set = set;
    descriptor->binding = bindings[i].binding;
    descriptor->buffer = bindings[i].buffer;
    descriptor->offset = bindings[i].offset;

    const iree_hal_descriptor_set_layout_binding_t* binding_params =
        iree_hal_metal_descriptor_set_layout_binding(set_layout, descriptor->binding);
    descriptor->usage = iree_hal_metal_get_metal_resource_usage(binding_params);
  }

  // Retain all buffers bound in this descriptor set.
  for (iree_host_size_t i = 0; i < binding_count; ++i) {
    if (bindings[i].buffer) {
      IREE_RETURN_AND_END_ZONE_IF_ERROR(
          z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1, &bindings[i].buffer));
    }
  }

  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1, &pipeline_layout));

  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}
910
Lei Zhangdf1e9a22023-02-12 12:08:00 -0800911// Prepares kernels and argument buffers needed for kernel dispatches.
Lei Zhangc0ad0ea2023-05-06 18:02:02 -0700912static iree_status_t iree_hal_metal_command_segment_create_dispatch(
Lei Zhangdf1e9a22023-02-12 12:08:00 -0800913 iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable,
Lei Zhangc0ad0ea2023-05-06 18:02:02 -0700914 int32_t entry_point, iree_hal_metal_dispatch_segment_t** out_segment) {
Lei Zhangdf1e9a22023-02-12 12:08:00 -0800915 iree_hal_metal_command_buffer_t* command_buffer =
916 iree_hal_metal_command_buffer_cast(base_command_buffer);
917 IREE_TRACE_ZONE_BEGIN(z0);
918
Lei Zhangc0ad0ea2023-05-06 18:02:02 -0700919 IREE_RETURN_AND_END_ZONE_IF_ERROR(
920 z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1, &executable));
921
922 iree_hal_metal_kernel_params_t kernel_params;
Lei Zhangdf1e9a22023-02-12 12:08:00 -0800923 IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, iree_hal_metal_kernel_library_entry_point_kernel_params(
Lei Zhangc0ad0ea2023-05-06 18:02:02 -0700924 executable, entry_point, &kernel_params));
Lei Zhangdf1e9a22023-02-12 12:08:00 -0800925
Lei Zhangc0ad0ea2023-05-06 18:02:02 -0700926 // Allocate the command segment and keep track of all necessary API data.
927 uint8_t* storage_base = NULL;
928 iree_hal_metal_command_segment_t* segment = NULL;
Lei Zhang3097b3a2023-06-11 13:32:44 -0700929 const iree_host_size_t set_count =
930 iree_hal_metal_pipeline_layout_descriptor_set_count(kernel_params.layout);
Lei Zhangd90e80f2023-05-13 21:32:29 -0700931 iree_host_size_t descriptor_count = 0;
932 // Calculate the total number of bindings across all descriptor sets.
Lei Zhang3097b3a2023-06-11 13:32:44 -0700933 for (iree_host_size_t i = 0; i < set_count; ++i) {
934 const iree_hal_descriptor_set_layout_t* set_layout =
935 iree_hal_metal_pipeline_layout_descriptor_set_layout(kernel_params.layout, i);
936 descriptor_count += iree_hal_metal_descriptor_set_layout_binding_count(set_layout);
Lei Zhangd90e80f2023-05-13 21:32:29 -0700937 }
Lei Zhang0ec791e2023-05-07 22:15:34 -0700938 iree_host_size_t descriptor_length = descriptor_count * sizeof(iree_hal_metal_descriptor_t);
Lei Zhangd7fb9812023-06-11 09:28:07 -0700939 iree_host_size_t push_constant_count =
940 iree_hal_metal_pipeline_layout_push_constant_count(kernel_params.layout);
Lei Zhangc0ad0ea2023-05-06 18:02:02 -0700941 iree_host_size_t push_constant_length = push_constant_count * sizeof(int32_t);
Lei Zhang0ec791e2023-05-07 22:15:34 -0700942 iree_host_size_t total_size = sizeof(*segment) + descriptor_length + push_constant_length;
Lei Zhangc0ad0ea2023-05-06 18:02:02 -0700943 IREE_RETURN_AND_END_ZONE_IF_ERROR(
944 z0, iree_arena_allocate(&command_buffer->arena, total_size, (void**)&storage_base));

  // Compose and push the dispatch segment.
  segment = (iree_hal_metal_command_segment_t*)storage_base;
  memset(segment, 0, sizeof(*segment));
  segment->action = IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_DISPATCH;
  iree_hal_metal_command_segment_list_push_back(&command_buffer->segments, segment);

  segment->dispatch.kernel_params = kernel_params;

  // Copy descriptors from all sets to the end of the current segment for later access.
  segment->dispatch.descriptor_count = descriptor_count;
  uint8_t* descriptor_ptr = storage_base + sizeof(*segment);
  segment->dispatch.descriptors = (iree_hal_metal_descriptor_t*)descriptor_ptr;
  for (iree_host_size_t i = 0; i < set_count; ++i) {
    const iree_hal_descriptor_set_layout_t* set_layout =
        iree_hal_metal_pipeline_layout_descriptor_set_layout(kernel_params.layout, i);
    iree_host_size_t binding_count = iree_hal_metal_descriptor_set_layout_binding_count(set_layout);
    iree_host_size_t current_size = binding_count * sizeof(iree_hal_metal_descriptor_t);
    memcpy(descriptor_ptr, command_buffer->state.descriptor_sets[i].bindings, current_size);
    descriptor_ptr += current_size;
  }

  // Copy push constants to the end of the current segment for later access.
  segment->dispatch.push_constant_count = push_constant_count;
  uint8_t* push_constant_ptr = storage_base + sizeof(*segment) + descriptor_length;
  segment->dispatch.push_constants = (int32_t*)push_constant_ptr;
  memcpy(push_constant_ptr, (const uint8_t*)command_buffer->state.push_constants,
         push_constant_length);
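  // Note the snapshot semantics: the descriptor and push constant state is copied into the
  // segment now, so later state changes on the command buffer cannot retroactively affect this
  // dispatch.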

  *out_segment = &segment->dispatch;
  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}

static iree_status_t iree_hal_metal_command_segment_record_dispatch(
    iree_hal_metal_command_buffer_t* command_buffer, iree_hal_metal_dispatch_segment_t* segment) {
  IREE_TRACE_ZONE_BEGIN(z0);

  // Set the compute kernel to dispatch.
  id<MTLComputeCommandEncoder> compute_encoder =
      iree_hal_metal_get_or_begin_compute_encoder(command_buffer);
  [compute_encoder setComputePipelineState:segment->kernel_params.pso];

  // Record push constants.
  if (segment->push_constant_count != 0) {
    [compute_encoder setBytes:(void*)segment->push_constants
                       length:segment->push_constant_count * sizeof(int32_t)
                      atIndex:IREE_HAL_METAL_PUSH_CONSTANT_BUFFER_INDEX];
  }
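  // Note: Metal has no direct push constant equivalent, so the constant data is copied inline
  // into the command stream via setBytes:length:atIndex: (intended for small single-use data,
  // under 4KB) and bound at a buffer index reserved for this purpose.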

  // Record argument buffers for all descriptors and record buffer usages.
  iree_hal_metal_descriptor_t* descriptors = segment->descriptors;
  for (iree_host_size_t i = 0; i < segment->descriptor_count;) {
    uint32_t current_set = descriptors[i].set;

    // Build the argument encoder and argument buffer for the current descriptor set.
    // TODO(antiagainst): Use a cache layer to cache and reuse argument buffers with the same
    // content, to avoid redundant encoding overhead.
    id<MTLBuffer> argument_buffer = command_buffer->staging_buffer->metal_buffer;
    id<MTLArgumentEncoder> argument_encoder =
        [segment->kernel_params.function newArgumentEncoderWithBufferIndex:current_set];  // +1
    IREE_ASSERT(argument_encoder != nil);

    // Reserve space for the argument buffer from the shared staging buffer.
    iree_byte_span_t reservation;
    uint32_t argument_buffer_offset;
    IREE_RETURN_AND_END_ZONE_IF_ERROR(
        z0, iree_hal_metal_staging_buffer_reserve(
                command_buffer->staging_buffer, argument_encoder.encodedLength,
                argument_encoder.alignment, &reservation, &argument_buffer_offset));
    [argument_encoder setArgumentBuffer:argument_buffer offset:argument_buffer_offset];
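    // From here on the argument encoder writes GPU addresses into the reserved staging buffer
    // slice at argument_buffer_offset.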

    // Now record all bound buffers belonging to the current set into the argument buffer.
    for (; i < segment->descriptor_count && descriptors[i].set == current_set; ++i) {
      uint32_t current_binding = descriptors[i].binding;
      id<MTLBuffer> current_buffer =
          iree_hal_metal_buffer_handle(iree_hal_buffer_allocated_buffer(descriptors[i].buffer));
      iree_host_size_t offset =
          iree_hal_buffer_byte_offset(descriptors[i].buffer) + descriptors[i].offset;
      [argument_encoder setBuffer:current_buffer offset:offset atIndex:current_binding];

      // Also record buffer usages--resources referenced indirectly through an argument buffer
      // must be declared via useResource to guarantee residency.
      [compute_encoder useResource:current_buffer usage:descriptors[i].usage];
    }
    // Bind the argument buffer itself to the encoder at the set index.
    [compute_encoder setBuffer:argument_buffer offset:argument_buffer_offset atIndex:current_set];

    [argument_encoder release];  // -1 (balances newArgumentEncoderWithBufferIndex above)
  }

  // Record the dispatch, either direct or indirect.
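  // For illustration: with threadgroup_size = {32, 1, 1} and workgroup_count = {4, 2, 1}, the
  // direct path below launches 4*2*1 = 8 threadgroups of 32 threads each (256 threads total).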
  uint32_t* workgroup_size = segment->kernel_params.threadgroup_size;
  if (segment->workgroups_buffer == nil) {
    // Direct dispatch of a fixed workgroup count.
    uint32_t* workgroup_count = segment->workgroup_count;
    [compute_encoder
         dispatchThreadgroups:MTLSizeMake(workgroup_count[0], workgroup_count[1],
                                          workgroup_count[2])
        threadsPerThreadgroup:MTLSizeMake(workgroup_size[0], workgroup_size[1], workgroup_size[2])];
  } else {
    // Indirect dispatch using a workgroup count read from a buffer at dispatch time.
    [compute_encoder
        dispatchThreadgroupsWithIndirectBuffer:segment->workgroups_buffer
                          indirectBufferOffset:segment->workgroups_offset
                         threadsPerThreadgroup:MTLSizeMake(workgroup_size[0], workgroup_size[1],
                                                           workgroup_size[2])];
  }

  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}

static iree_status_t iree_hal_metal_command_buffer_prepare_dispatch(
    iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable,
    int32_t entry_point, uint32_t workgroup_count_x, uint32_t workgroup_count_y,
    uint32_t workgroup_count_z) {
  IREE_TRACE_ZONE_BEGIN(z0);

  iree_hal_metal_dispatch_segment_t* segment = NULL;
  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, iree_hal_metal_command_segment_create_dispatch(base_command_buffer, executable,
                                                         entry_point, &segment));
  segment->workgroup_count[0] = workgroup_count_x;
  segment->workgroup_count[1] = workgroup_count_y;
  segment->workgroup_count[2] = workgroup_count_z;

  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}

static iree_status_t iree_hal_metal_command_buffer_prepare_dispatch_indirect(
    iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable,
    int32_t entry_point, iree_hal_buffer_t* workgroups_buffer,
    iree_device_size_t workgroups_offset) {
  IREE_TRACE_ZONE_BEGIN(z0);

  iree_hal_metal_dispatch_segment_t* segment = NULL;
  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, iree_hal_metal_command_segment_create_dispatch(base_command_buffer, executable,
                                                         entry_point, &segment));
  segment->workgroups_buffer =
      iree_hal_metal_buffer_handle(iree_hal_buffer_allocated_buffer(workgroups_buffer));
  segment->workgroups_offset = workgroups_offset;

  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}

static iree_status_t iree_hal_metal_command_buffer_execute_commands(
    iree_hal_command_buffer_t* base_command_buffer, iree_hal_command_buffer_t* base_commands,
    iree_hal_buffer_binding_table_t binding_table) {
  return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "secondary command buffer not yet supported");
}

static iree_status_t iree_hal_metal_command_segment_record(
    iree_hal_metal_command_buffer_t* command_buffer) {
  IREE_ASSERT_ARGUMENT(command_buffer);
  IREE_TRACE_ZONE_BEGIN(z0);

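  // Walk the recorded segments in order and encode each into the backing Metal command buffer.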
  for (iree_hal_metal_command_segment_t* segment = command_buffer->segments.head; segment;
       segment = segment->next_segment) {
    switch (segment->action) {
      case IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_BARRIER: {
        IREE_RETURN_AND_END_ZONE_IF_ERROR(
            z0, iree_hal_metal_command_segment_record_barrier(command_buffer, &segment->barrier));
      } break;
      case IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_DISPATCH: {
        IREE_RETURN_AND_END_ZONE_IF_ERROR(
            z0, iree_hal_metal_command_segment_record_dispatch(command_buffer, &segment->dispatch));
      } break;
      case IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_FILL_BUFFER: {
        IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, iree_hal_metal_command_segment_record_fill_buffer(
                                                  command_buffer, &segment->fill_buffer));
      } break;
      case IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_COPY_BUFFER: {
        IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, iree_hal_metal_command_segment_record_copy_buffer(
                                                  command_buffer, &segment->copy_buffer));
      } break;
      default:
        IREE_ASSERT(false, "unhandled command segment kind");
        break;
    }
  }

  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}

static iree_status_t iree_hal_metal_command_buffer_begin(
    iree_hal_command_buffer_t* base_command_buffer) {
  iree_hal_metal_command_buffer_t* command_buffer =
      iree_hal_metal_command_buffer_cast(base_command_buffer);
  iree_hal_metal_command_buffer_reset(command_buffer);
  return iree_ok_status();
}

static iree_status_t iree_hal_metal_command_buffer_end(
    iree_hal_command_buffer_t* base_command_buffer) {
  iree_hal_metal_command_buffer_t* command_buffer =
      iree_hal_metal_command_buffer_cast(base_command_buffer);
  IREE_TRACE_ZONE_BEGIN(z0);

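  // Deferred recording happens here: all previously prepared segments are encoded into the Metal
  // command buffer now, and whichever encoder the last segment left open is ended before the
  // command buffer can be committed.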
  IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, iree_hal_metal_command_segment_record(command_buffer));
  iree_hal_metal_end_blit_encoder(command_buffer);
  iree_hal_metal_end_compute_encoder(command_buffer);

  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}

static const iree_hal_command_buffer_vtable_t iree_hal_metal_command_buffer_vtable = {
    .destroy = iree_hal_metal_command_buffer_destroy,
    .begin = iree_hal_metal_command_buffer_begin,
    .end = iree_hal_metal_command_buffer_end,
    .begin_debug_group = iree_hal_metal_command_buffer_begin_debug_group,
    .end_debug_group = iree_hal_metal_command_buffer_end_debug_group,
    .execution_barrier = iree_hal_metal_command_buffer_prepare_barrier,
    .signal_event = iree_hal_metal_command_buffer_signal_event,
    .reset_event = iree_hal_metal_command_buffer_reset_event,
    .wait_events = iree_hal_metal_command_buffer_wait_events,
    .discard_buffer = iree_hal_metal_command_buffer_discard_buffer,
    .fill_buffer = iree_hal_metal_command_buffer_prepare_fill_buffer,
    .update_buffer = iree_hal_metal_command_buffer_prepare_update_buffer,
    .copy_buffer = iree_hal_metal_command_buffer_prepare_copy_buffer,
    .collective = iree_hal_metal_command_buffer_collective,
    .push_constants = iree_hal_metal_command_buffer_push_constants,
    .push_descriptor_set = iree_hal_metal_command_buffer_push_descriptor_set,
    .dispatch = iree_hal_metal_command_buffer_prepare_dispatch,
    .dispatch_indirect = iree_hal_metal_command_buffer_prepare_dispatch_indirect,
    .execute_commands = iree_hal_metal_command_buffer_execute_commands,
};