Lei Zhang | df1e9a2 | 2023-02-12 12:08:00 -0800 | [diff] [blame] | 1 | // Copyright 2023 The IREE Authors |
| 2 | // |
| 3 | // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | |
| 7 | #include "experimental/metal/direct_command_buffer.h" |
| 8 | |
| 9 | #import <Metal/Metal.h> |
| 10 | |
Lei Zhang | 182db9d | 2023-02-20 13:38:12 -0800 | [diff] [blame] | 11 | #include "experimental/metal/builtin_executables.h" |
Lei Zhang | df1e9a2 | 2023-02-12 12:08:00 -0800 | [diff] [blame] | 12 | #include "experimental/metal/metal_buffer.h" |
| 13 | #include "experimental/metal/metal_device.h" |
| 14 | #include "experimental/metal/metal_kernel_library.h" |
| 15 | #include "experimental/metal/pipeline_layout.h" |
Lei Zhang | f598fd2 | 2023-05-08 07:48:57 -0700 | [diff] [blame] | 16 | #include "experimental/metal/staging_buffer.h" |
Lei Zhang | df1e9a2 | 2023-02-12 12:08:00 -0800 | [diff] [blame] | 17 | #include "iree/base/api.h" |
Lei Zhang | 063e1fa | 2023-02-12 19:27:54 -0800 | [diff] [blame] | 18 | #include "iree/base/target_platform.h" |
Lei Zhang | df1e9a2 | 2023-02-12 12:08:00 -0800 | [diff] [blame] | 19 | #include "iree/base/tracing.h" |
| 20 | #include "iree/hal/api.h" |
| 21 | #include "iree/hal/utils/resource_set.h" |
| 22 | |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 23 | //===------------------------------------------------------------------------------------------===// |
| 24 | // Segmented submission management |
| 25 | //===------------------------------------------------------------------------------------------===// |
| 26 | |
| 27 | // Unlike Vulkan, Metal adopts a multi-level command recording model--memory/dispatch commands are |
| 28 | // not directly recorded into a command buffer; rather, they must go through the additional level of |
| 29 | // blit/compute encoders. IREE's HAL follows the flat Vulkan command buffer recording model, so we |
| 30 | // have a mismatch here. Implementing IREE's HAL using Metal would require switching encoders for |
| 31 | // interleaved memory and dispatch commands. Additionally, certain IREE HAL API features do not have |
| 32 | // direct mapping in Metal APIs, e.g., various forms of IREE HAL execution/memory barriers. |
| 33 | // Translating them would require looking at both previous and next commands to decide the proper |
| 34 | // mapping. |
| 35 | // |
| 36 | // Due to these reasons, it's beneficial to have a complete view of the full command buffer and |
| 37 | // extra flexibility during recording, in order to fixup past commands, or inspect future commands. |
| 38 | // |
| 39 | // Therefore, to implement IREE HAL command buffers using Metal, we perform two steps using a linked |
| 40 | // list of command segments. First we create segments (iree_hal_metal_command_buffer_prepare_* and |
| 41 | // iree_hal_metal_command_segment_create_*) to keep track of all IREE HAL commands and the |
| 42 | // associated data, and then, when finalizing the command buffer, we iterate through all the |
| 43 | // segments and record their contents (iree_hal_metal_command_segment_record_*) into a proper Metal |
// command buffer. A linked list gives us the flexibility to organize the command sequence with
// low overhead; and deferred recording gives us the complete picture of the command buffer when
// recording really starts.
| 47 | |
// Command action kind of a command segment.
// Each value selects the corresponding member of the union in iree_hal_metal_command_segment_t.
typedef enum iree_hal_metal_command_segment_action_e {
  IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_BARRIER,      // Execution/memory barrier command
  IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_DISPATCH,     // Dispatch command
  IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_FILL_BUFFER,  // Fill buffer command
  IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_COPY_BUFFER,  // Copy buffer command
} iree_hal_metal_command_segment_action_t;
| 55 | |
// API data for execution/memory barrier command segments.
typedef struct iree_hal_metal_barrier_segment_t {
  iree_host_size_t memory_barrier_count;  // Total number of memory barriers
  iree_host_size_t buffer_barrier_count;  // Total number of buffer barriers
  // The list of buffer barriers, pointing to the end of the segment allocation.
  // Note that memory barriers carry no per-barrier payload here; only their count is tracked.
  const iree_hal_buffer_barrier_t* buffer_barriers;
} iree_hal_metal_barrier_segment_t;
// + Additional inline allocation for holding all buffer barriers.
| 64 | |
// A buffer binding (descriptor) cached for later recording, when the compute kernel (and thus its
// argument buffer layout) is known at dispatch time.
typedef struct iree_hal_metal_descriptor_t {
  uint32_t set;               // The descriptor set number this binding belongs to
  uint32_t binding;           // The binding number within the descriptor set
  iree_hal_buffer_t* buffer;  // The bound HAL buffer
  iree_device_size_t offset;  // Offset into |buffer| where the binding starts
  MTLResourceUsage usage;     // Metal resource usage flags for this binding in the dispatch
} iree_hal_metal_descriptor_t;
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 72 | |
// API data for dispatch command segments.
typedef struct iree_hal_metal_dispatch_segment_t {
  // Compute kernel information--kernel object, pipeline layout, threadgroup size, etc.
  iree_hal_metal_kernel_params_t kernel_params;

  // Workgroup count information--if |workgroups_buffer| is not nil, then indirect dispatch;
  // otherwise uses |workgroup_count| for direct dispatch.
  id<MTLBuffer> workgroups_buffer;
  iree_device_size_t workgroups_offset;
  uint32_t workgroup_count[3];

  // The number of descriptors bound for this dispatch.
  iree_host_size_t descriptor_count;
  // The list of bound descriptors, pointing to the end of the segment allocation.
  iree_hal_metal_descriptor_t* descriptors;

  // The number of push constant values.
  iree_host_size_t push_constant_count;
  // The list of push constants, pointing to the end of the segment allocation (after the
  // descriptor allocation).
  int32_t* push_constants;
} iree_hal_metal_dispatch_segment_t;
// + Additional inline allocation for holding all bound descriptors.
// + Additional inline allocation for holding all push constants.
| 96 | |
// API data for fill buffer command segments.
typedef struct iree_hal_metal_fill_buffer_segment_t {
  id<MTLBuffer> target_buffer;       // The Metal buffer to fill
  iree_device_size_t target_offset;  // Offset into |target_buffer| where the fill starts
  iree_device_size_t length;         // Total number of bytes to fill
  // The fill pattern, pointing to the end of the segment allocation.
  const void* pattern;
  iree_host_size_t pattern_length;   // Size of one pattern element in bytes
} iree_hal_metal_fill_buffer_segment_t;
// + Additional inline allocation for holding the fill pattern.
| 107 | |
// API data for copy buffer command segments.
typedef struct iree_hal_metal_copy_buffer_segment_t {
  id<MTLBuffer> source_buffer;       // The Metal buffer to copy from
  iree_device_size_t source_offset;  // Offset into |source_buffer| where the copy starts
  id<MTLBuffer> target_buffer;       // The Metal buffer to copy into
  iree_device_size_t target_offset;  // Offset into |target_buffer| where the copy starts
  iree_device_size_t length;         // Total number of bytes to copy
} iree_hal_metal_copy_buffer_segment_t;
| 116 | |
// A single recorded command in the deferred-recording linked list. |action| selects which union
// member holds the command's API data; extra per-command data (barriers, descriptors, patterns)
// is allocated inline right after this struct in the same arena allocation.
struct iree_hal_metal_command_segment_t;
typedef struct iree_hal_metal_command_segment_t {
  struct iree_hal_metal_command_segment_t* next_segment;  // Next segment in the list; NULL at tail
  iree_hal_metal_command_segment_action_t action;         // Discriminator for the union below
  union {
    iree_hal_metal_barrier_segment_t barrier;
    iree_hal_metal_dispatch_segment_t dispatch;
    iree_hal_metal_fill_buffer_segment_t fill_buffer;
    iree_hal_metal_copy_buffer_segment_t copy_buffer;
  };
} iree_hal_metal_command_segment_t;
| 128 | |
// A singly linked list of command segments with O(1) push at both ends.
// Both |head| and |tail| are NULL when the list is empty.
typedef struct iree_hal_metal_command_segment_list_t {
  iree_hal_metal_command_segment_t* head;
  iree_hal_metal_command_segment_t* tail;
} iree_hal_metal_command_segment_list_t;
| 133 | |
// Resets |list| to the empty state; does not free any segments (they live in the arena).
static void iree_hal_metal_command_segment_list_reset(iree_hal_metal_command_segment_list_t* list) {
  list->head = NULL;
  list->tail = NULL;
}
| 137 | |
// Prepends |segment| to |list|, making it the new head. Also becomes the tail when the list was
// previously empty.
static void iree_hal_metal_command_segment_list_push_front(
    iree_hal_metal_command_segment_list_t* list, iree_hal_metal_command_segment_t* segment) {
  iree_hal_metal_command_segment_t* old_head = list->head;
  segment->next_segment = old_head;
  list->head = segment;
  if (list->tail == NULL) {
    list->tail = segment;
  }
}
| 144 | |
// Appends |segment| to |list|, making it the new tail. Also becomes the head when the list was
// previously empty.
static void iree_hal_metal_command_segment_list_push_back(
    iree_hal_metal_command_segment_list_t* list, iree_hal_metal_command_segment_t* segment) {
  segment->next_segment = NULL;
  if (!list->tail) {
    // Empty list: the new segment is both head and tail.
    list->head = segment;
  } else {
    list->tail->next_segment = segment;
  }
  list->tail = segment;
}
| 155 | |
| 156 | //===------------------------------------------------------------------------------------------===// |
| 157 | // iree_hal_metal_command_buffer_t |
| 158 | //===------------------------------------------------------------------------------------------===// |
| 159 | |
// A command buffer implementation that records IREE HAL commands into a linked list of command
// segments first and then, at finalization, replays them into the underlying Metal command buffer
// (see the "Segmented submission management" overview above).
typedef struct iree_hal_metal_command_buffer_t {
  iree_hal_command_buffer_t base;

  // The Metal command queue owning this command buffer.
  id<MTLCommandQueue> queue;

  // For polyfilling fill/copy/update buffers that are not directly supported by Metal APIs.
  iree_hal_metal_builtin_executable_t* builtin_executable;

  // Arena used for all allocations; references the shared device block pool.
  iree_arena_allocator_t arena;

  // Per-queue shared uniform staging buffer for uploading parameters to the GPU, including argument
  // buffers and buffer update source buffers.
  iree_hal_metal_staging_buffer_t* staging_buffer;

  // Allocator that allocated this struct; used to free it on destroy.
  iree_allocator_t host_allocator;

  // Maintains a reference to all resources used within the command buffer. Resets on each begin.
  iree_hal_resource_set_t* resource_set;

  // Linked list of command segments to be recorded into a command buffer.
  iree_hal_metal_command_segment_list_t segments;

  // The underlying Metal command buffer that segments are eventually recorded into.
  id<MTLCommandBuffer> command_buffer;

  // Dispatch type (serial vs. concurrent) used when creating compute encoders.
  MTLDispatchType dispatch_type;

  // Mutable recording state, reset together with the command buffer.
  struct {
    // The current active compute/blit encoders for encoding compute or memory operations.
    // Metal commands are encoded into the command buffer with such encoders, and each encoder can
    // only encode the specific type of operations it supports.
    id<MTLComputeCommandEncoder> compute_encoder;
    id<MTLBlitCommandEncoder> blit_encoder;

    // MTLEvent used for synchronization when we switch between blit and compute encoders.
    // Normally we would use MTLFence objects, but the difference between IREE HAL and Metal API
    // means we may see many encoder switches. It would require creating a lot of GPU objects. In
    // order to avoid the cost, we just use one MTLEvent with different values for different
    // switches.
    id<MTLEvent> encoder_event;
    // The next available encoder event value to signal/wait to/on.
    uint64_t next_encoder_event_value;

    // Metal APIs mandate we create argument buffers (for descriptor sets) from compiled kernel
    // function. That means we need to bind the compute kernel first before setting descriptors and
    // binding buffers. However in IREE HAL API we see push descriptors before the dispatch command.
    // So we need to cache the descriptor information by ourselves and record them at dispatch time.
    struct {
      iree_hal_metal_descriptor_t bindings[IREE_HAL_METAL_MAX_DESCRIPTOR_SET_BINDING_COUNT];
    } descriptor_sets[IREE_HAL_METAL_PUSH_CONSTANT_BUFFER_INDEX];

    // All available push constants updated each time push_constants is called. Reset only with the
    // command buffer and otherwise will maintain its values during recording to allow for partial
    // push_constants updates.
    int32_t push_constants[IREE_HAL_METAL_MAX_PUSH_CONSTANT_COUNT];
  } state;
} iree_hal_metal_command_buffer_t;
| 217 | |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 218 | //===------------------------------------------------------------------------------------------===// |
| 219 | // iree_hal_metal_command_buffer_vtable APIs |
| 220 | //===------------------------------------------------------------------------------------------===// |
| 221 | |
Lei Zhang | df1e9a2 | 2023-02-12 12:08:00 -0800 | [diff] [blame] | 222 | static const iree_hal_command_buffer_vtable_t iree_hal_metal_command_buffer_vtable; |
| 223 | |
// Casts the given |base_value| to a mutable Metal command buffer, asserting (in debug builds) that
// it was created with the Metal command buffer vtable.
static iree_hal_metal_command_buffer_t* iree_hal_metal_command_buffer_cast(
    iree_hal_command_buffer_t* base_value) {
  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_metal_command_buffer_vtable);
  return (iree_hal_metal_command_buffer_t*)base_value;
}
| 229 | |
// Const-qualified variant of iree_hal_metal_command_buffer_cast for read-only access.
static const iree_hal_metal_command_buffer_t* iree_hal_metal_command_buffer_const_cast(
    const iree_hal_command_buffer_t* base_value) {
  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_metal_command_buffer_vtable);
  return (const iree_hal_metal_command_buffer_t*)base_value;
}
| 235 | |
// Returns the underlying Metal command buffer handle backing |base_command_buffer|.
// Does not transfer ownership; the caller must not release the returned handle.
id<MTLCommandBuffer> iree_hal_metal_direct_command_buffer_handle(
    const iree_hal_command_buffer_t* base_command_buffer) {
  return iree_hal_metal_command_buffer_const_cast(base_command_buffer)->command_buffer;
}
| 242 | |
// Ends and releases the currently active compute encoder, if any. Safe to call when no compute
// encoder is active.
static void iree_hal_metal_end_compute_encoder(iree_hal_metal_command_buffer_t* command_buffer) {
  id<MTLComputeCommandEncoder> encoder = command_buffer->state.compute_encoder;
  if (!encoder) return;
  [encoder endEncoding];
  [encoder release];  // -1
  command_buffer->state.compute_encoder = nil;
}
| 250 | |
// Ends and releases the currently active blit encoder, if any. Safe to call when no blit encoder
// is active.
static void iree_hal_metal_end_blit_encoder(iree_hal_metal_command_buffer_t* command_buffer) {
  id<MTLBlitCommandEncoder> encoder = command_buffer->state.blit_encoder;
  if (!encoder) return;
  [encoder endEncoding];
  [encoder release];  // -1
  command_buffer->state.blit_encoder = nil;
}
| 258 | |
// Resets the command buffer recording state: ends any active encoders, drops all recorded command
// segments, and reclaims all arena allocations backing them.
static void iree_hal_metal_command_buffer_reset(iree_hal_metal_command_buffer_t* command_buffer) {
  IREE_TRACE_ZONE_BEGIN(z0);
  // End encoders before dropping the segments/arena that may be referenced while encoding.
  iree_hal_metal_end_blit_encoder(command_buffer);
  iree_hal_metal_end_compute_encoder(command_buffer);
  iree_hal_metal_command_segment_list_reset(&command_buffer->segments);
  iree_arena_reset(&command_buffer->arena);
  IREE_TRACE_ZONE_END(z0);
}
| 267 | |
// Returns the active compute encoder, creating one if needed. If a blit encoder is currently
// active it is ended first, and an MTLEvent signal/wait pair is inserted to order the new compute
// pass after the completed blit pass.
static id<MTLComputeCommandEncoder> iree_hal_metal_get_or_begin_compute_encoder(
    iree_hal_metal_command_buffer_t* command_buffer) {
  id<MTLCommandBuffer> metal_handle = command_buffer->command_buffer;

  // If we are switching encoders, we need to synchronize "one or more resources across different
  // passes within a command buffer." We use a single MTLEvent with monotonically increasing
  // values instead of per-switch MTLFence objects; see |state.encoder_event|.
  // https://developer.apple.com/documentation/metal/resource_synchronization
  uint64_t encoder_event_value = 0;
  if (command_buffer->state.blit_encoder) {
    iree_hal_metal_end_blit_encoder(command_buffer);
    encoder_event_value = command_buffer->state.next_encoder_event_value++;
    [metal_handle encodeSignalEvent:command_buffer->state.encoder_event value:encoder_event_value];
  }

  if (!command_buffer->state.compute_encoder) {
    if (encoder_event_value != 0) {
      // Order the new compute pass after the just-ended blit pass.
      [metal_handle encodeWaitForEvent:command_buffer->state.encoder_event
                                 value:encoder_event_value];
    }
    @autoreleasepool {  // Use @autoreleasepool to trigger the autorelease within encoder creation.
      // We manage command dependencies and insert barriers explicitly in IREE, so the configured
      // dispatch type (concurrent unless the device requested serial) is safe to use here.
      command_buffer->state.compute_encoder = [[metal_handle
          computeCommandEncoderWithDispatchType:command_buffer->dispatch_type] retain];  // +1
    }
  }

  return command_buffer->state.compute_encoder;
}
| 297 | |
// Returns the active blit encoder, creating one if needed. If a compute encoder is currently
// active it is ended first, and an MTLEvent signal/wait pair is inserted to order the new blit
// pass after the completed compute pass.
static id<MTLBlitCommandEncoder> iree_hal_metal_get_or_begin_blit_encoder(
    iree_hal_metal_command_buffer_t* command_buffer) {
  id<MTLCommandBuffer> metal_handle = command_buffer->command_buffer;

  // If we are switching encoders, we need to synchronize "one or more resources across different
  // passes within a command buffer." We use a single MTLEvent with monotonically increasing
  // values instead of per-switch MTLFence objects; see |state.encoder_event|.
  // https://developer.apple.com/documentation/metal/resource_synchronization
  uint64_t encoder_event_value = 0;
  if (command_buffer->state.compute_encoder) {
    iree_hal_metal_end_compute_encoder(command_buffer);
    encoder_event_value = command_buffer->state.next_encoder_event_value++;
    [metal_handle encodeSignalEvent:command_buffer->state.encoder_event value:encoder_event_value];
  }

  if (!command_buffer->state.blit_encoder) {
    if (encoder_event_value != 0) {
      // Order the new blit pass after the just-ended compute pass.
      [metal_handle encodeWaitForEvent:command_buffer->state.encoder_event
                                 value:encoder_event_value];
    }
    @autoreleasepool {  // Use @autoreleasepool to trigger the autorelease within encoder creation.
      command_buffer->state.blit_encoder = [[metal_handle blitCommandEncoder] retain];  // +1
    }
  }

  return command_buffer->state.blit_encoder;
}
| 324 | |
// Destroys the given |base_command_buffer| itself, without decreasing the refcount in the shared
// staging buffer yet. Forward declared so the creation error path below can use it.
static void iree_hal_metal_command_buffer_destroy_internal(
    iree_hal_command_buffer_t* base_command_buffer);
| 329 | |
// Creates a one-shot direct command buffer on |queue|. On success the returned command buffer
// holds +1 references to |queue|, a newly created Metal command buffer, and an MTLEvent used for
// encoder switching, and it bumps the refcount of the shared |staging_buffer|.
// Only one-shot, non-nested command buffers without binding tables are supported.
iree_status_t iree_hal_metal_direct_command_buffer_create(
    iree_hal_device_t* device, iree_hal_command_buffer_mode_t mode,
    iree_hal_command_category_t command_categories, iree_host_size_t binding_capacity,
    iree_hal_metal_command_buffer_resource_reference_mode_t resource_reference_mode,
    id<MTLCommandQueue> queue, iree_arena_block_pool_t* block_pool,
    iree_hal_metal_staging_buffer_t* staging_buffer,
    iree_hal_metal_builtin_executable_t* builtin_executable, iree_allocator_t host_allocator,
    iree_hal_command_buffer_t** out_command_buffer) {
  IREE_ASSERT_ARGUMENT(device);
  IREE_ASSERT_ARGUMENT(out_command_buffer);
  IREE_ASSERT_TRUE(iree_all_bits_set(mode, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT));
  IREE_ASSERT_TRUE(!iree_any_bit_set(mode, IREE_HAL_COMMAND_BUFFER_MODE_NESTED));
  *out_command_buffer = NULL;

  if (binding_capacity > 0) {
    // TODO(#10144): support indirect command buffers with binding tables.
    return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "indirect command buffer not yet supported");
  }

  IREE_TRACE_ZONE_BEGIN(z0);

  iree_hal_metal_command_buffer_t* command_buffer = NULL;
  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, iree_allocator_malloc(host_allocator, sizeof(*command_buffer), (void**)&command_buffer));

  iree_hal_command_buffer_initialize(device, mode, command_categories, IREE_HAL_QUEUE_AFFINITY_ANY,
                                     binding_capacity, &iree_hal_metal_command_buffer_vtable,
                                     &command_buffer->base);
  command_buffer->queue = [queue retain];  // +1
  command_buffer->builtin_executable = builtin_executable;
  iree_arena_initialize(block_pool, &command_buffer->arena);
  command_buffer->staging_buffer = staging_buffer;
  command_buffer->host_allocator = host_allocator;
  iree_status_t status = iree_hal_resource_set_allocate(block_pool, &command_buffer->resource_set);
  if (iree_status_is_ok(status)) {
    iree_hal_metal_command_segment_list_reset(&command_buffer->segments);
    @autoreleasepool {  // Use @autoreleasepool to trigger the autorelease within command buffer creation.
      // We track resource lifetime by ourselves in IREE; so just do unretained references to
      // resources in Metal command buffer, which avoids overhead and gives better performance.
      MTLCommandBufferDescriptor* descriptor = [MTLCommandBufferDescriptor new];  // +1
      descriptor.retainedReferences =
          resource_reference_mode == IREE_HAL_METAL_COMMAND_BUFFER_RESOURCE_REFERENCE_MODE_RETAINED;
      descriptor.errorOptions = MTLCommandBufferErrorOptionNone;
      command_buffer->command_buffer =
          [[queue commandBufferWithDescriptor:descriptor] retain];  // +1
      [descriptor release];  // -1
    }
    const iree_hal_metal_device_params_t* params = iree_hal_metal_device_params(device);
    command_buffer->dispatch_type =
        params->command_dispatch_type == IREE_HAL_METAL_COMMAND_DISPATCH_TYPE_CONCURRENT
            ? MTLDispatchTypeConcurrent
            : MTLDispatchTypeSerial;
    command_buffer->state.compute_encoder = nil;
    command_buffer->state.blit_encoder = nil;
    command_buffer->state.encoder_event = [queue.device newEvent];  // +1
    command_buffer->state.next_encoder_event_value = 1;
  }

  if (iree_status_is_ok(status)) {
    *out_command_buffer = &command_buffer->base;

    // Increase command buffer refcount in the shared staging buffer. We tie this to the command
    // buffer's lifetime to avoid resource leak.
    iree_hal_metal_staging_buffer_increase_refcount(staging_buffer);
  } else {
    // Tear down without touching the staging buffer refcount, which was never incremented.
    iree_hal_metal_command_buffer_destroy_internal(&command_buffer->base);
  }

  IREE_TRACE_ZONE_END(z0);
  return status;
}
| 401 | |
// Destroys |base_command_buffer|: resets recording state (ending any live encoders), releases all
// retained Metal objects, frees the resource set and arena, and then frees the struct itself.
// Deliberately does NOT decrement the shared staging buffer refcount--callers that incremented it
// must decrement it first (see iree_hal_metal_command_buffer_destroy).
static void iree_hal_metal_command_buffer_destroy_internal(
    iree_hal_command_buffer_t* base_command_buffer) {
  iree_hal_metal_command_buffer_t* command_buffer =
      iree_hal_metal_command_buffer_cast(base_command_buffer);

  // Reset ends any active encoders, so the asserts below must hold afterwards.
  iree_hal_metal_command_buffer_reset(command_buffer);
  [command_buffer->state.encoder_event release];  // -1
  IREE_ASSERT_EQ(command_buffer->state.compute_encoder, nil);
  IREE_ASSERT_EQ(command_buffer->state.blit_encoder, nil);
  [command_buffer->command_buffer release];  // -1
  [command_buffer->queue release];           // -1
  iree_hal_resource_set_free(command_buffer->resource_set);
  iree_arena_deinitialize(&command_buffer->arena);
  iree_allocator_free(command_buffer->host_allocator, command_buffer);
}
| 417 | |
// Vtable destroy entry point: releases the shared staging buffer reference taken at creation time
// and then tears down the command buffer itself.
static void iree_hal_metal_command_buffer_destroy(iree_hal_command_buffer_t* base_command_buffer) {
  IREE_TRACE_ZONE_BEGIN(z0);

  iree_hal_metal_command_buffer_t* command_buffer =
      iree_hal_metal_command_buffer_cast(base_command_buffer);

  // Decrease command buffer refcount in the shared staging buffer, and potentially reclaim
  // resources. We tie this to the command buffer's lifetime to avoid resource leak.
  iree_hal_metal_staging_buffer_t* staging_buffer = command_buffer->staging_buffer;
  if (staging_buffer) {
    iree_hal_metal_staging_buffer_decrease_refcount(staging_buffer);
  }

  iree_hal_metal_command_buffer_destroy_internal(base_command_buffer);

  IREE_TRACE_ZONE_END(z0);
}
| 433 | |
// Returns true if |command_buffer| is a Metal direct command buffer (i.e. uses this
// implementation's vtable).
bool iree_hal_metal_command_buffer_isa(iree_hal_command_buffer_t* command_buffer) {
  return iree_hal_resource_is(&command_buffer->resource, &iree_hal_metal_command_buffer_vtable);
}
| 437 | |
// Begins a labeled debug group. Currently a no-op stub.
static void iree_hal_metal_command_buffer_begin_debug_group(
    iree_hal_command_buffer_t* base_command_buffer, iree_string_view_t label,
    iree_hal_label_color_t label_color, const iree_hal_label_location_t* location) {
  // TODO(antiagainst): implement support for debug group
}
| 443 | |
// Ends the innermost debug group. Currently a no-op stub.
static void iree_hal_metal_command_buffer_end_debug_group(
    iree_hal_command_buffer_t* base_command_buffer) {
  // TODO(antiagainst): implement support for debug group
}
| 448 | |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 449 | static iree_status_t iree_hal_metal_command_buffer_prepare_barrier( |
Lei Zhang | df1e9a2 | 2023-02-12 12:08:00 -0800 | [diff] [blame] | 450 | iree_hal_command_buffer_t* base_command_buffer, iree_hal_execution_stage_t source_stage_mask, |
| 451 | iree_hal_execution_stage_t target_stage_mask, iree_hal_execution_barrier_flags_t flags, |
| 452 | iree_host_size_t memory_barrier_count, const iree_hal_memory_barrier_t* memory_barriers, |
| 453 | iree_host_size_t buffer_barrier_count, const iree_hal_buffer_barrier_t* buffer_barriers) { |
| 454 | if (iree_any_bit_set(source_stage_mask, IREE_HAL_EXECUTION_STAGE_HOST) || |
| 455 | iree_any_bit_set(target_stage_mask, IREE_HAL_EXECUTION_STAGE_HOST)) { |
| 456 | return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "barrier involving host not yet supported"); |
| 457 | } |
| 458 | |
| 459 | if (flags != IREE_HAL_EXECUTION_BARRIER_FLAG_NONE) { |
| 460 | return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "non-zero barrier flag not yet supported"); |
| 461 | } |
| 462 | |
| 463 | iree_hal_metal_command_buffer_t* command_buffer = |
| 464 | iree_hal_metal_command_buffer_cast(base_command_buffer); |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 465 | IREE_TRACE_ZONE_BEGIN(z0); |
Lei Zhang | ffb40b1 | 2023-04-27 08:18:51 -0700 | [diff] [blame] | 466 | |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 467 | // Allocate the command segment and keep track of all necessary API data. |
| 468 | uint8_t* storage_base = NULL; |
| 469 | iree_hal_metal_command_segment_t* segment = NULL; |
| 470 | iree_host_size_t buffer_barrier_length = buffer_barrier_count * sizeof(iree_hal_buffer_barrier_t); |
Lei Zhang | 52a8d0c | 2023-06-10 20:17:07 -0700 | [diff] [blame] | 471 | IREE_RETURN_AND_END_ZONE_IF_ERROR( |
| 472 | z0, iree_arena_allocate(&command_buffer->arena, sizeof(*segment) + buffer_barrier_length, |
| 473 | (void**)&storage_base)); |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 474 | |
Lei Zhang | 52a8d0c | 2023-06-10 20:17:07 -0700 | [diff] [blame] | 475 | // Copy the buffer barriers to the end of the current segments for later access. We don't copy |
| 476 | // memory barriers because in Metal there is only coarse-grained full memory barrier affecting |
| 477 | // all buffers, regardless of the fine-grained details from IREE HAL barriers. |
| 478 | uint8_t* barrier_ptr = storage_base + sizeof(*segment); |
| 479 | memcpy(barrier_ptr, (const uint8_t*)buffer_barriers, buffer_barrier_length); |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 480 | |
Lei Zhang | 52a8d0c | 2023-06-10 20:17:07 -0700 | [diff] [blame] | 481 | // Compose and push the barrier segment. |
| 482 | segment = (iree_hal_metal_command_segment_t*)storage_base; |
| 483 | segment->action = IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_BARRIER; |
| 484 | iree_hal_metal_command_segment_list_push_back(&command_buffer->segments, segment); |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 485 | |
Lei Zhang | 52a8d0c | 2023-06-10 20:17:07 -0700 | [diff] [blame] | 486 | segment->barrier.memory_barrier_count = memory_barrier_count; |
| 487 | segment->barrier.buffer_barrier_count = buffer_barrier_count; |
| 488 | segment->barrier.buffer_barriers = (const iree_hal_buffer_barrier_t*)barrier_ptr; |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 489 | |
| 490 | IREE_TRACE_ZONE_END(z0); |
Lei Zhang | 52a8d0c | 2023-06-10 20:17:07 -0700 | [diff] [blame] | 491 | return iree_ok_status(); |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 492 | } |
| 493 | |
// Records a previously-prepared barrier segment into the underlying Metal command buffer.
// Three cases are handled: execution-only barriers (no memory/buffer barriers), catch-all memory
// barriers, and per-buffer barriers.
static iree_status_t iree_hal_metal_command_segment_record_barrier(
    iree_hal_metal_command_buffer_t* command_buffer, iree_hal_metal_barrier_segment_t* segment) {
  // TODO(antiagainst): Analyze segments before and after to optimize barriers, e.g., switching
  // encoders would require its own synchronization; so we don't need extract barriers in the
  // middle.
  if (segment->memory_barrier_count == 0 && segment->buffer_barrier_count == 0) {
    // There is no direct corresponding APIs for execution only barrier in Metal. We just signal and
    // wait on the same value of a MTLEvent here.
    // Both encoders are ended first so the signal/wait pair sits between encoder boundaries.
    iree_hal_metal_end_blit_encoder(command_buffer);
    iree_hal_metal_end_compute_encoder(command_buffer);
    id<MTLCommandBuffer> metal_handle = command_buffer->command_buffer;
    // Each barrier consumes a fresh monotonically-increasing event value.
    uint64_t event_value = command_buffer->state.next_encoder_event_value++;
    [metal_handle encodeSignalEvent:command_buffer->state.encoder_event value:event_value];
    [metal_handle encodeWaitForEvent:command_buffer->state.encoder_event value:event_value];
    return iree_ok_status();
  }

  id<MTLComputeCommandEncoder> encoder =
      iree_hal_metal_get_or_begin_compute_encoder(command_buffer);

  if (segment->memory_barrier_count != 0) {
    // If there is a memory barrier specified, we have to place a catch-all barrier for all buffers.
    // Metal does not provide a more fine-grained control here.
    [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
    return iree_ok_status();
  }

  if (segment->buffer_barrier_count != 0) {
    // But we do have the option to specify a list of buffers to synchronize if only buffer barriers
    // are specified.
    // The temporary handle array lives on the stack only for the duration of this call.
    id<MTLResource>* resources =
        (id<MTLResource>*)iree_alloca(sizeof(id<MTLResource>) * segment->buffer_barrier_count);
    for (iree_host_size_t i = 0; i < segment->buffer_barrier_count; ++i) {
      resources[i] = iree_hal_metal_buffer_handle(
          iree_hal_buffer_allocated_buffer(segment->buffer_barriers[i].buffer));
    }
    [encoder memoryBarrierWithResources:resources count:segment->buffer_barrier_count];
  }
  return iree_ok_status();
}
| 534 | |
// Signals |event| from the device. HAL events are not implemented for Metal yet, so this always
// fails with UNIMPLEMENTED.
static iree_status_t iree_hal_metal_command_buffer_signal_event(
    iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event,
    iree_hal_execution_stage_t source_stage_mask) {
  return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "event not yet supported");
}
| 540 | |
// Resets |event| to the unsignaled state. HAL events are not implemented for Metal yet, so this
// always fails with UNIMPLEMENTED.
static iree_status_t iree_hal_metal_command_buffer_reset_event(
    iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event,
    iree_hal_execution_stage_t source_stage_mask) {
  return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "event not yet supported");
}
| 546 | |
// Waits on |events| and then applies the given barriers. HAL events are not implemented for Metal
// yet, so this always fails with UNIMPLEMENTED.
static iree_status_t iree_hal_metal_command_buffer_wait_events(
    iree_hal_command_buffer_t* base_command_buffer, iree_host_size_t event_count,
    const iree_hal_event_t** events, iree_hal_execution_stage_t source_stage_mask,
    iree_hal_execution_stage_t target_stage_mask, iree_host_size_t memory_barrier_count,
    const iree_hal_memory_barrier_t* memory_barriers, iree_host_size_t buffer_barrier_count,
    const iree_hal_buffer_barrier_t* buffer_barriers) {
  return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "event not yet supported");
}
| 555 | |
// Hints that |buffer| contents will not be needed again. Discard is advisory in the HAL; there is
// no corresponding Metal operation, so this intentionally succeeds without doing anything.
static iree_status_t iree_hal_metal_command_buffer_discard_buffer(
    iree_hal_command_buffer_t* base_command_buffer, iree_hal_buffer_t* buffer) {
  // This is a hint to the device and we have nothing to do for Metal.
  return iree_ok_status();
}
| 561 | |
Lei Zhang | 182db9d | 2023-02-20 13:38:12 -0800 | [diff] [blame] | 562 | // Fills |value| with the duplicated single byte value and return true if the given |pattern| has |
| 563 | // duplicated values for each of its |pattern_length| bytes. |
| 564 | static bool iree_hal_metal_get_duplicated_single_byte_value(const void* pattern, |
| 565 | size_t pattern_length, uint8_t* value) { |
Lei Zhang | 063e1fa | 2023-02-12 19:27:54 -0800 | [diff] [blame] | 566 | switch (pattern_length) { |
| 567 | case 1: { |
| 568 | *value = *(uint8_t*)pattern; |
Lei Zhang | 182db9d | 2023-02-20 13:38:12 -0800 | [diff] [blame] | 569 | return true; |
Lei Zhang | 063e1fa | 2023-02-12 19:27:54 -0800 | [diff] [blame] | 570 | } |
| 571 | case 2: { |
| 572 | uint16_t two_bytes = *(uint16_t*)pattern; |
| 573 | uint16_t byte0 = two_bytes & 0xffu; |
| 574 | uint16_t byte1 = two_bytes >> 8u; |
| 575 | if (byte0 == byte1) { |
| 576 | *value = (int8_t)byte0; |
Lei Zhang | 182db9d | 2023-02-20 13:38:12 -0800 | [diff] [blame] | 577 | return true; |
Lei Zhang | 063e1fa | 2023-02-12 19:27:54 -0800 | [diff] [blame] | 578 | } |
Lei Zhang | 182db9d | 2023-02-20 13:38:12 -0800 | [diff] [blame] | 579 | break; |
Lei Zhang | 063e1fa | 2023-02-12 19:27:54 -0800 | [diff] [blame] | 580 | } |
| 581 | case 4: { |
| 582 | uint32_t four_bytes = *(uint32_t*)pattern; |
| 583 | uint32_t byte0 = four_bytes & 0xffu; |
| 584 | uint32_t byte1 = (four_bytes >> 8u) & 0xffu; |
| 585 | uint32_t byte2 = (four_bytes >> 16u) & 0xffu; |
| 586 | uint32_t byte3 = four_bytes >> 24u; |
| 587 | if (byte0 == byte1 && byte0 == byte2 && byte0 == byte3) { |
| 588 | *value = (int8_t)byte0; |
Lei Zhang | 182db9d | 2023-02-20 13:38:12 -0800 | [diff] [blame] | 589 | return true; |
Lei Zhang | 063e1fa | 2023-02-12 19:27:54 -0800 | [diff] [blame] | 590 | } |
Lei Zhang | 182db9d | 2023-02-20 13:38:12 -0800 | [diff] [blame] | 591 | break; |
Lei Zhang | 063e1fa | 2023-02-12 19:27:54 -0800 | [diff] [blame] | 592 | } |
| 593 | default: |
| 594 | break; |
| 595 | } |
Lei Zhang | 182db9d | 2023-02-20 13:38:12 -0800 | [diff] [blame] | 596 | return false; |
| 597 | } |
| 598 | |
Lei Zhang | e8679ad | 2023-06-11 08:45:31 -0700 | [diff] [blame] | 599 | // Duplicates the given |pattern| into 4-bytes and returns the value. |
| 600 | static uint32_t iree_hal_metal_duplicate_to_four_byte_value(const void* pattern, |
| 601 | size_t pattern_length) { |
| 602 | if (pattern_length == 1) { |
| 603 | uint8_t single_byte = *(uint8_t*)pattern; |
| 604 | uint32_t value = (uint32_t)single_byte; |
| 605 | value |= (value << 8u); |
| 606 | value |= (value << 16u); |
| 607 | return value; |
Lei Zhang | 182db9d | 2023-02-20 13:38:12 -0800 | [diff] [blame] | 608 | } |
Lei Zhang | e8679ad | 2023-06-11 08:45:31 -0700 | [diff] [blame] | 609 | |
| 610 | if (pattern_length == 2) { |
| 611 | uint16_t two_bytes = *(uint16_t*)pattern; |
| 612 | uint32_t value = (uint32_t)two_bytes; |
| 613 | value |= (value << 16u); |
| 614 | return value; |
| 615 | } |
| 616 | |
| 617 | IREE_ASSERT(pattern_length == 4); |
| 618 | return *(uint32_t*)pattern; |
Lei Zhang | 063e1fa | 2023-02-12 19:27:54 -0800 | [diff] [blame] | 619 | } |
| 620 | |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 621 | static iree_status_t iree_hal_metal_command_buffer_prepare_fill_buffer( |
Lei Zhang | df1e9a2 | 2023-02-12 12:08:00 -0800 | [diff] [blame] | 622 | iree_hal_command_buffer_t* base_command_buffer, iree_hal_buffer_t* target_buffer, |
| 623 | iree_device_size_t target_offset, iree_device_size_t length, const void* pattern, |
| 624 | iree_host_size_t pattern_length) { |
Lei Zhang | 063e1fa | 2023-02-12 19:27:54 -0800 | [diff] [blame] | 625 | iree_hal_metal_command_buffer_t* command_buffer = |
| 626 | iree_hal_metal_command_buffer_cast(base_command_buffer); |
| 627 | IREE_TRACE_ZONE_BEGIN(z0); |
| 628 | |
Lei Zhang | 182db9d | 2023-02-20 13:38:12 -0800 | [diff] [blame] | 629 | id<MTLBuffer> target_device_buffer = |
| 630 | iree_hal_metal_buffer_handle(iree_hal_buffer_allocated_buffer(target_buffer)); |
Lei Zhang | 063e1fa | 2023-02-12 19:27:54 -0800 | [diff] [blame] | 631 | target_offset += iree_hal_buffer_byte_offset(target_buffer); |
| 632 | |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 633 | // Allocate the command segment and keep track of all necessary API data. |
| 634 | uint8_t* storage_base = NULL; |
| 635 | iree_hal_metal_command_segment_t* segment = NULL; |
Lei Zhang | 52a8d0c | 2023-06-10 20:17:07 -0700 | [diff] [blame] | 636 | IREE_RETURN_AND_END_ZONE_IF_ERROR( |
| 637 | z0, iree_arena_allocate(&command_buffer->arena, sizeof(*segment) + pattern_length, |
| 638 | (void**)&storage_base)); |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 639 | |
Lei Zhang | 52a8d0c | 2023-06-10 20:17:07 -0700 | [diff] [blame] | 640 | // Copy the patttern to the end of the segment for later access. |
| 641 | uint8_t* pattern_ptr = storage_base + sizeof(*segment); |
| 642 | memcpy(pattern_ptr, (const uint8_t*)pattern, pattern_length); |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 643 | |
Lei Zhang | 52a8d0c | 2023-06-10 20:17:07 -0700 | [diff] [blame] | 644 | // Compose and push the fill buffer segment. |
| 645 | segment = (iree_hal_metal_command_segment_t*)storage_base; |
| 646 | segment->action = IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_FILL_BUFFER; |
| 647 | iree_hal_metal_command_segment_list_push_back(&command_buffer->segments, segment); |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 648 | |
Lei Zhang | 52a8d0c | 2023-06-10 20:17:07 -0700 | [diff] [blame] | 649 | segment->fill_buffer.target_buffer = target_device_buffer; |
| 650 | segment->fill_buffer.target_offset = target_offset; |
| 651 | segment->fill_buffer.length = length; |
| 652 | segment->fill_buffer.pattern = (const void*)pattern_ptr; |
| 653 | segment->fill_buffer.pattern_length = pattern_length; |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 654 | |
Lei Zhang | 52a8d0c | 2023-06-10 20:17:07 -0700 | [diff] [blame] | 655 | iree_status_t status = |
| 656 | iree_hal_resource_set_insert(command_buffer->resource_set, 1, &target_buffer); |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 657 | |
| 658 | IREE_TRACE_ZONE_END(z0); |
| 659 | return status; |
| 660 | } |
| 661 | |
// Records a previously-prepared fill-buffer segment into the underlying Metal command buffer,
// using the native blit encoder when alignment/pattern constraints allow and falling back to a
// builtin compute kernel otherwise.
static iree_status_t iree_hal_metal_command_segment_record_fill_buffer(
    iree_hal_metal_command_buffer_t* command_buffer,
    iree_hal_metal_fill_buffer_segment_t* segment) {
  IREE_TRACE_ZONE_BEGIN(z0);

  // Note that fillBuffer:range:value: only accepts a single byte as the pattern but FillBuffer
  // can accept 1/2/4 bytes. If the pattern itself contains repeated bytes, we can call into
  // fillBuffer:range:value:. Otherwise we need to emulate the support.
  uint8_t pattern_1byte = 0u;

  // Per the spec for fillBuffer:range:value: "The alignment and length of the range must both be a
  // multiple of 4 bytes in macOS, and 1 byte in iOS and tvOS."
#if defined(IREE_PLATFORM_MACOS)
  const bool can_use_metal_api = segment->target_offset % 4 == 0 && segment->length % 4 == 0 &&
                                 iree_hal_metal_get_duplicated_single_byte_value(
                                     segment->pattern, segment->pattern_length, &pattern_1byte);
#else
  const bool can_use_metal_api = iree_hal_metal_get_duplicated_single_byte_value(
      segment->pattern, segment->pattern_length, &pattern_1byte);
#endif

  if (can_use_metal_api) {
    // Fast path: native blit encoder fill with the single repeated byte.
    id<MTLBlitCommandEncoder> encoder = iree_hal_metal_get_or_begin_blit_encoder(command_buffer);
    [encoder fillBuffer:segment->target_buffer
                  range:NSMakeRange(segment->target_offset, segment->length)
                  value:pattern_1byte];
    IREE_TRACE_ZONE_END(z0);
    return iree_ok_status();
  }

  // Slow path: dispatch a builtin compute kernel with the pattern widened to 4 bytes.
  id<MTLComputeCommandEncoder> compute_encoder =
      iree_hal_metal_get_or_begin_compute_encoder(command_buffer);
  uint32_t pattern_4byte =
      iree_hal_metal_duplicate_to_four_byte_value(segment->pattern, segment->pattern_length);
  iree_status_t status = iree_hal_metal_builtin_executable_fill_buffer(
      command_buffer->builtin_executable, compute_encoder, segment->target_buffer,
      segment->target_offset, segment->length, pattern_4byte);

  IREE_TRACE_ZONE_END(z0);
  return status;
}
| 703 | |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 704 | static iree_status_t iree_hal_metal_command_segment_create_copy_buffer( |
Lei Zhang | 59c4699 | 2023-02-25 17:05:29 -0800 | [diff] [blame] | 705 | iree_hal_metal_command_buffer_t* command_buffer, id<MTLBuffer> source_device_buffer, |
| 706 | iree_device_size_t source_offset, id<MTLBuffer> target_device_buffer, |
| 707 | iree_device_size_t target_offset, iree_device_size_t length) { |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 708 | IREE_TRACE_ZONE_BEGIN(z0); |
| 709 | |
| 710 | // Allocate the command segment and keep track of all necessary API data. |
| 711 | uint8_t* storage_base = NULL; |
| 712 | iree_hal_metal_command_segment_t* segment = NULL; |
Lei Zhang | 52a8d0c | 2023-06-10 20:17:07 -0700 | [diff] [blame] | 713 | IREE_RETURN_AND_END_ZONE_IF_ERROR( |
| 714 | z0, iree_arena_allocate(&command_buffer->arena, sizeof(*segment), (void**)&storage_base)); |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 715 | |
Lei Zhang | 52a8d0c | 2023-06-10 20:17:07 -0700 | [diff] [blame] | 716 | // Compose and push the barrier segment. |
| 717 | segment = (iree_hal_metal_command_segment_t*)storage_base; |
| 718 | segment->action = IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_COPY_BUFFER; |
| 719 | iree_hal_metal_command_segment_list_push_back(&command_buffer->segments, segment); |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 720 | |
Lei Zhang | 52a8d0c | 2023-06-10 20:17:07 -0700 | [diff] [blame] | 721 | segment->copy_buffer.source_buffer = source_device_buffer; |
| 722 | segment->copy_buffer.source_offset = source_offset; |
| 723 | segment->copy_buffer.target_buffer = target_device_buffer; |
| 724 | segment->copy_buffer.target_offset = target_offset; |
| 725 | segment->copy_buffer.length = length; |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 726 | |
| 727 | IREE_TRACE_ZONE_END(z0); |
Lei Zhang | 52a8d0c | 2023-06-10 20:17:07 -0700 | [diff] [blame] | 728 | return iree_ok_status(); |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 729 | } |
| 730 | |
// Records a previously-prepared copy-buffer segment into the underlying Metal command buffer,
// using the native blit encoder when alignment constraints allow and falling back to a builtin
// compute kernel otherwise.
static iree_status_t iree_hal_metal_command_segment_record_copy_buffer(
    iree_hal_metal_command_buffer_t* command_buffer,
    iree_hal_metal_copy_buffer_segment_t* segment) {
  IREE_TRACE_ZONE_BEGIN(z0);

  // Per the spec for copyFromBuffer:sourceOffset:toBuffer:destinationOffset:size, the source/target
  // offset and length must be a multiple of 4 bytes in macOS, and 1 byte in iOS and tvOS.
#if defined(IREE_PLATFORM_MACOS)
  bool can_use_metal_api = segment->source_offset % 4 == 0 && segment->target_offset % 4 == 0 &&
                           segment->length % 4 == 0;
#else
  bool can_use_metal_api = true;
#endif

  iree_status_t status = iree_ok_status();
  if (can_use_metal_api) {
    // Fast path: native blit encoder copy.
    id<MTLBlitCommandEncoder> encoder = iree_hal_metal_get_or_begin_blit_encoder(command_buffer);
    [encoder copyFromBuffer:segment->source_buffer
               sourceOffset:segment->source_offset
                   toBuffer:segment->target_buffer
          destinationOffset:segment->target_offset
                       size:segment->length];
  } else {
    // Slow path: emulate the unaligned copy with a builtin compute kernel.
    id<MTLComputeCommandEncoder> encoder =
        iree_hal_metal_get_or_begin_compute_encoder(command_buffer);
    status = iree_hal_metal_builtin_executable_copy_buffer(
        command_buffer->builtin_executable, encoder, segment->source_buffer, segment->source_offset,
        segment->target_buffer, segment->target_offset, segment->length);
  }

  IREE_TRACE_ZONE_END(z0);
  return status;
}
| 764 | |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 765 | static iree_status_t iree_hal_metal_command_buffer_prepare_update_buffer( |
Lei Zhang | df1e9a2 | 2023-02-12 12:08:00 -0800 | [diff] [blame] | 766 | iree_hal_command_buffer_t* base_command_buffer, const void* source_buffer, |
| 767 | iree_host_size_t source_offset, iree_hal_buffer_t* target_buffer, |
| 768 | iree_device_size_t target_offset, iree_device_size_t length) { |
Lei Zhang | 063e1fa | 2023-02-12 19:27:54 -0800 | [diff] [blame] | 769 | iree_hal_metal_command_buffer_t* command_buffer = |
| 770 | iree_hal_metal_command_buffer_cast(base_command_buffer); |
| 771 | IREE_TRACE_ZONE_BEGIN(z0); |
| 772 | |
Lei Zhang | f598fd2 | 2023-05-08 07:48:57 -0700 | [diff] [blame] | 773 | // There are no direct corresponding APIs in Metal. We update the source buffer data to the |
| 774 | // staging buffer and then copy over. |
| 775 | |
| 776 | iree_const_byte_span_t source_data_span = |
| 777 | iree_make_const_byte_span((uint8_t*)source_buffer + source_offset, length); |
| 778 | uint32_t offset = 0; |
| 779 | IREE_RETURN_AND_END_ZONE_IF_ERROR( |
| 780 | z0, iree_hal_metal_staging_buffer_append(command_buffer->staging_buffer, source_data_span, |
| 781 | /*alignment=*/4, &offset)); |
Lei Zhang | 063e1fa | 2023-02-12 19:27:54 -0800 | [diff] [blame] | 782 | |
| 783 | IREE_RETURN_AND_END_ZONE_IF_ERROR( |
| 784 | z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1, &target_buffer)); |
| 785 | |
| 786 | id<MTLBuffer> target_device_buffer = |
| 787 | iree_hal_metal_buffer_handle(iree_hal_buffer_allocated_buffer(target_buffer)); |
| 788 | target_offset += iree_hal_buffer_byte_offset(target_buffer); |
| 789 | |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 790 | iree_status_t status = iree_hal_metal_command_segment_create_copy_buffer( |
Lei Zhang | f598fd2 | 2023-05-08 07:48:57 -0700 | [diff] [blame] | 791 | command_buffer, command_buffer->staging_buffer->metal_buffer, offset, target_device_buffer, |
| 792 | target_offset, length); |
Lei Zhang | 063e1fa | 2023-02-12 19:27:54 -0800 | [diff] [blame] | 793 | |
| 794 | IREE_TRACE_ZONE_END(z0); |
Lei Zhang | 59c4699 | 2023-02-25 17:05:29 -0800 | [diff] [blame] | 795 | return status; |
Lei Zhang | df1e9a2 | 2023-02-12 12:08:00 -0800 | [diff] [blame] | 796 | } |
| 797 | |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 798 | static iree_status_t iree_hal_metal_command_buffer_prepare_copy_buffer( |
Lei Zhang | df1e9a2 | 2023-02-12 12:08:00 -0800 | [diff] [blame] | 799 | iree_hal_command_buffer_t* base_command_buffer, iree_hal_buffer_t* source_buffer, |
| 800 | iree_device_size_t source_offset, iree_hal_buffer_t* target_buffer, |
| 801 | iree_device_size_t target_offset, iree_device_size_t length) { |
Lei Zhang | 063e1fa | 2023-02-12 19:27:54 -0800 | [diff] [blame] | 802 | iree_hal_metal_command_buffer_t* command_buffer = |
| 803 | iree_hal_metal_command_buffer_cast(base_command_buffer); |
| 804 | IREE_TRACE_ZONE_BEGIN(z0); |
| 805 | |
Lei Zhang | 063e1fa | 2023-02-12 19:27:54 -0800 | [diff] [blame] | 806 | const iree_hal_buffer_t* buffers[2] = {source_buffer, target_buffer}; |
| 807 | IREE_RETURN_AND_END_ZONE_IF_ERROR( |
| 808 | z0, iree_hal_resource_set_insert(command_buffer->resource_set, 2, buffers)); |
| 809 | |
| 810 | id<MTLBuffer> source_device_buffer = |
| 811 | iree_hal_metal_buffer_handle(iree_hal_buffer_allocated_buffer(source_buffer)); |
| 812 | id<MTLBuffer> target_device_buffer = |
| 813 | iree_hal_metal_buffer_handle(iree_hal_buffer_allocated_buffer(target_buffer)); |
| 814 | |
| 815 | source_offset += iree_hal_buffer_byte_offset(source_buffer); |
| 816 | target_offset += iree_hal_buffer_byte_offset(target_buffer); |
| 817 | |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 818 | iree_status_t status = iree_hal_metal_command_segment_create_copy_buffer( |
Lei Zhang | 59c4699 | 2023-02-25 17:05:29 -0800 | [diff] [blame] | 819 | command_buffer, source_device_buffer, source_offset, target_device_buffer, target_offset, |
| 820 | length); |
Lei Zhang | 063e1fa | 2023-02-12 19:27:54 -0800 | [diff] [blame] | 821 | |
Lei Zhang | 59c4699 | 2023-02-25 17:05:29 -0800 | [diff] [blame] | 822 | IREE_TRACE_ZONE_END(z0); |
| 823 | return status; |
Lei Zhang | df1e9a2 | 2023-02-12 12:08:00 -0800 | [diff] [blame] | 824 | } |
| 825 | |
// Records a collective communication operation. Collectives are not implemented for the Metal
// backend, so this always fails with UNIMPLEMENTED.
static iree_status_t iree_hal_metal_command_buffer_collective(
    iree_hal_command_buffer_t* base_command_buffer, iree_hal_channel_t* channel,
    iree_hal_collective_op_t op, uint32_t param, iree_hal_buffer_binding_t send_binding,
    iree_hal_buffer_binding_t recv_binding, iree_device_size_t element_count) {
  return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "collectives not yet supported");
}
| 832 | |
// Stages |values_length| bytes of push constant data at byte |offset| into the command buffer's
// local state; the data is bound to the pipeline later at dispatch time.
static iree_status_t iree_hal_metal_command_buffer_push_constants(
    iree_hal_command_buffer_t* base_command_buffer, iree_hal_pipeline_layout_t* pipeline_layout,
    iree_host_size_t offset, const void* values, iree_host_size_t values_length) {
  iree_hal_metal_command_buffer_t* command_buffer =
      iree_hal_metal_command_buffer_cast(base_command_buffer);

  // "Binding a pipeline with a layout that is not compatible with the push constant layout does not
  // disturb the push constant values." So we don't need to check whether the pipeline layout
  // compatibility and invalidate existing values.

  // A range ending exactly at the end of the storage is valid, so only reject ranges that extend
  // past it (the previous `>=` check incorrectly rejected a full-range update).
  if (IREE_UNLIKELY(offset + values_length > sizeof(command_buffer->state.push_constants))) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "push constant range [%zu, %zu) out of range", offset,
                            offset + values_length);
  }

  memcpy((uint8_t*)&command_buffer->state.push_constants + offset, values, values_length);

  return iree_ok_status();
}
| 853 | |
Lei Zhang | 0ec791e | 2023-05-07 22:15:34 -0700 | [diff] [blame] | 854 | static inline MTLResourceUsage iree_hal_metal_get_metal_resource_usage( |
Lei Zhang | 54818a3 | 2023-06-10 16:01:14 -0700 | [diff] [blame] | 855 | const iree_hal_descriptor_set_layout_binding_t* binding) { |
Lei Zhang | 0ec791e | 2023-05-07 22:15:34 -0700 | [diff] [blame] | 856 | MTLResourceUsage usage = MTLResourceUsageRead; |
| 857 | if (binding->flags != IREE_HAL_DESCRIPTOR_FLAG_READ_ONLY) usage |= MTLResourceUsageWrite; |
| 858 | return usage; |
| 859 | } |
| 860 | |
Lei Zhang | df1e9a2 | 2023-02-12 12:08:00 -0800 | [diff] [blame] | 861 | static iree_status_t iree_hal_metal_command_buffer_push_descriptor_set( |
| 862 | iree_hal_command_buffer_t* base_command_buffer, iree_hal_pipeline_layout_t* pipeline_layout, |
| 863 | uint32_t set, iree_host_size_t binding_count, |
| 864 | const iree_hal_descriptor_set_binding_t* bindings) { |
| 865 | iree_hal_metal_command_buffer_t* command_buffer = |
| 866 | iree_hal_metal_command_buffer_cast(base_command_buffer); |
Lei Zhang | df1e9a2 | 2023-02-12 12:08:00 -0800 | [diff] [blame] | 867 | |
Lei Zhang | d90e80f | 2023-05-13 21:32:29 -0700 | [diff] [blame] | 868 | if (binding_count > IREE_HAL_METAL_MAX_DESCRIPTOR_SET_BINDING_COUNT) { |
| 869 | return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED, |
| 870 | "exceeded available binding slots for push descriptor set #%u; " |
| 871 | "requested %lu vs. maximal %d", |
| 872 | set, binding_count, IREE_HAL_METAL_MAX_DESCRIPTOR_SET_BINDING_COUNT); |
Lei Zhang | df1e9a2 | 2023-02-12 12:08:00 -0800 | [diff] [blame] | 873 | } |
| 874 | |
Lei Zhang | d90e80f | 2023-05-13 21:32:29 -0700 | [diff] [blame] | 875 | IREE_TRACE_ZONE_BEGIN(z0); |
| 876 | |
Lei Zhang | d90e80f | 2023-05-13 21:32:29 -0700 | [diff] [blame] | 877 | IREE_ASSERT(set < IREE_HAL_METAL_PUSH_CONSTANT_BUFFER_INDEX); |
Lei Zhang | 54818a3 | 2023-06-10 16:01:14 -0700 | [diff] [blame] | 878 | const iree_hal_descriptor_set_layout_t* set_layout = |
Lei Zhang | 0ec791e | 2023-05-07 22:15:34 -0700 | [diff] [blame] | 879 | iree_hal_metal_pipeline_layout_descriptor_set_layout(pipeline_layout, set); |
Lei Zhang | d90e80f | 2023-05-13 21:32:29 -0700 | [diff] [blame] | 880 | iree_hal_metal_descriptor_t* descriptors = command_buffer->state.descriptor_sets[set].bindings; |
Lei Zhang | 0ec791e | 2023-05-07 22:15:34 -0700 | [diff] [blame] | 881 | |
Lei Zhang | d90e80f | 2023-05-13 21:32:29 -0700 | [diff] [blame] | 882 | // Update descriptors in the current set. |
Lei Zhang | df1e9a2 | 2023-02-12 12:08:00 -0800 | [diff] [blame] | 883 | for (iree_host_size_t i = 0; i < binding_count; ++i) { |
Lei Zhang | d90e80f | 2023-05-13 21:32:29 -0700 | [diff] [blame] | 884 | iree_hal_metal_descriptor_t* descriptor = &descriptors[i]; |
Lei Zhang | 0ec791e | 2023-05-07 22:15:34 -0700 | [diff] [blame] | 885 | |
Lei Zhang | df1e9a2 | 2023-02-12 12:08:00 -0800 | [diff] [blame] | 886 | descriptor->set = set; |
| 887 | descriptor->binding = bindings[i].binding; |
| 888 | descriptor->buffer = bindings[i].buffer; |
| 889 | descriptor->offset = bindings[i].offset; |
Lei Zhang | 0ec791e | 2023-05-07 22:15:34 -0700 | [diff] [blame] | 890 | |
Lei Zhang | 54818a3 | 2023-06-10 16:01:14 -0700 | [diff] [blame] | 891 | const iree_hal_descriptor_set_layout_binding_t* binding_params = |
Lei Zhang | 0ec791e | 2023-05-07 22:15:34 -0700 | [diff] [blame] | 892 | iree_hal_metal_descriptor_set_layout_binding(set_layout, descriptor->binding); |
| 893 | descriptor->usage = iree_hal_metal_get_metal_resource_usage(binding_params); |
Lei Zhang | df1e9a2 | 2023-02-12 12:08:00 -0800 | [diff] [blame] | 894 | } |
Lei Zhang | df1e9a2 | 2023-02-12 12:08:00 -0800 | [diff] [blame] | 895 | |
| 896 | // Retain all buffers bound in this descriptor set. |
| 897 | for (iree_host_size_t i = 0; i < binding_count; ++i) { |
| 898 | if (bindings[i].buffer) { |
| 899 | IREE_RETURN_AND_END_ZONE_IF_ERROR( |
| 900 | z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1, &bindings[i].buffer)); |
| 901 | } |
| 902 | } |
| 903 | |
Lei Zhang | df1e9a2 | 2023-02-12 12:08:00 -0800 | [diff] [blame] | 904 | IREE_RETURN_AND_END_ZONE_IF_ERROR( |
| 905 | z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1, &pipeline_layout)); |
| 906 | |
| 907 | IREE_TRACE_ZONE_END(z0); |
| 908 | return iree_ok_status(); |
| 909 | } |
| 910 | |
Lei Zhang | df1e9a2 | 2023-02-12 12:08:00 -0800 | [diff] [blame] | 911 | // Prepares kernels and argument buffers needed for kernel dispatches. |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 912 | static iree_status_t iree_hal_metal_command_segment_create_dispatch( |
Lei Zhang | df1e9a2 | 2023-02-12 12:08:00 -0800 | [diff] [blame] | 913 | iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 914 | int32_t entry_point, iree_hal_metal_dispatch_segment_t** out_segment) { |
Lei Zhang | df1e9a2 | 2023-02-12 12:08:00 -0800 | [diff] [blame] | 915 | iree_hal_metal_command_buffer_t* command_buffer = |
| 916 | iree_hal_metal_command_buffer_cast(base_command_buffer); |
| 917 | IREE_TRACE_ZONE_BEGIN(z0); |
| 918 | |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 919 | IREE_RETURN_AND_END_ZONE_IF_ERROR( |
| 920 | z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1, &executable)); |
| 921 | |
| 922 | iree_hal_metal_kernel_params_t kernel_params; |
Lei Zhang | df1e9a2 | 2023-02-12 12:08:00 -0800 | [diff] [blame] | 923 | IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, iree_hal_metal_kernel_library_entry_point_kernel_params( |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 924 | executable, entry_point, &kernel_params)); |
Lei Zhang | df1e9a2 | 2023-02-12 12:08:00 -0800 | [diff] [blame] | 925 | |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 926 | // Allocate the command segment and keep track of all necessary API data. |
| 927 | uint8_t* storage_base = NULL; |
| 928 | iree_hal_metal_command_segment_t* segment = NULL; |
Lei Zhang | 3097b3a | 2023-06-11 13:32:44 -0700 | [diff] [blame] | 929 | const iree_host_size_t set_count = |
| 930 | iree_hal_metal_pipeline_layout_descriptor_set_count(kernel_params.layout); |
Lei Zhang | d90e80f | 2023-05-13 21:32:29 -0700 | [diff] [blame] | 931 | iree_host_size_t descriptor_count = 0; |
| 932 | // Calculate the total number of bindings across all descriptor sets. |
Lei Zhang | 3097b3a | 2023-06-11 13:32:44 -0700 | [diff] [blame] | 933 | for (iree_host_size_t i = 0; i < set_count; ++i) { |
| 934 | const iree_hal_descriptor_set_layout_t* set_layout = |
| 935 | iree_hal_metal_pipeline_layout_descriptor_set_layout(kernel_params.layout, i); |
| 936 | descriptor_count += iree_hal_metal_descriptor_set_layout_binding_count(set_layout); |
Lei Zhang | d90e80f | 2023-05-13 21:32:29 -0700 | [diff] [blame] | 937 | } |
Lei Zhang | 0ec791e | 2023-05-07 22:15:34 -0700 | [diff] [blame] | 938 | iree_host_size_t descriptor_length = descriptor_count * sizeof(iree_hal_metal_descriptor_t); |
Lei Zhang | d7fb981 | 2023-06-11 09:28:07 -0700 | [diff] [blame] | 939 | iree_host_size_t push_constant_count = |
| 940 | iree_hal_metal_pipeline_layout_push_constant_count(kernel_params.layout); |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 941 | iree_host_size_t push_constant_length = push_constant_count * sizeof(int32_t); |
Lei Zhang | 0ec791e | 2023-05-07 22:15:34 -0700 | [diff] [blame] | 942 | iree_host_size_t total_size = sizeof(*segment) + descriptor_length + push_constant_length; |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 943 | IREE_RETURN_AND_END_ZONE_IF_ERROR( |
| 944 | z0, iree_arena_allocate(&command_buffer->arena, total_size, (void**)&storage_base)); |
Lei Zhang | df1e9a2 | 2023-02-12 12:08:00 -0800 | [diff] [blame] | 945 | |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 946 | // Compose and push the dispatch segment. |
| 947 | segment = (iree_hal_metal_command_segment_t*)storage_base; |
| 948 | memset(segment, 0, sizeof(*segment)); |
| 949 | segment->action = IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_DISPATCH; |
| 950 | iree_hal_metal_command_segment_list_push_back(&command_buffer->segments, segment); |
| 951 | |
| 952 | segment->dispatch.kernel_params = kernel_params; |
| 953 | |
Lei Zhang | d90e80f | 2023-05-13 21:32:29 -0700 | [diff] [blame] | 954 | // Copy descriptors from all sets to the end of the current segment for later access. |
Lei Zhang | 0ec791e | 2023-05-07 22:15:34 -0700 | [diff] [blame] | 955 | segment->dispatch.descriptor_count = descriptor_count; |
| 956 | uint8_t* descriptor_ptr = storage_base + sizeof(*segment); |
Lei Zhang | 0ec791e | 2023-05-07 22:15:34 -0700 | [diff] [blame] | 957 | segment->dispatch.descriptors = (iree_hal_metal_descriptor_t*)descriptor_ptr; |
Lei Zhang | 3097b3a | 2023-06-11 13:32:44 -0700 | [diff] [blame] | 958 | for (iree_host_size_t i = 0; i < set_count; ++i) { |
| 959 | const iree_hal_descriptor_set_layout_t* set_layout = |
| 960 | iree_hal_metal_pipeline_layout_descriptor_set_layout(kernel_params.layout, i); |
| 961 | iree_host_size_t binding_count = iree_hal_metal_descriptor_set_layout_binding_count(set_layout); |
| 962 | iree_host_size_t current_size = binding_count * sizeof(iree_hal_metal_descriptor_t); |
Lei Zhang | d90e80f | 2023-05-13 21:32:29 -0700 | [diff] [blame] | 963 | memcpy(descriptor_ptr, command_buffer->state.descriptor_sets[i].bindings, current_size); |
| 964 | descriptor_ptr += current_size; |
| 965 | } |
Lei Zhang | df1e9a2 | 2023-02-12 12:08:00 -0800 | [diff] [blame] | 966 | |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 967 | // Copy push constants to the end of the current segment for later access. |
| 968 | segment->dispatch.push_constant_count = push_constant_count; |
Lei Zhang | 0ec791e | 2023-05-07 22:15:34 -0700 | [diff] [blame] | 969 | uint8_t* push_constant_ptr = storage_base + sizeof(*segment) + descriptor_length; |
Lei Zhang | d90e80f | 2023-05-13 21:32:29 -0700 | [diff] [blame] | 970 | segment->dispatch.push_constants = (int32_t*)push_constant_ptr; |
Lei Zhang | d4aef98 | 2023-05-08 13:48:42 -0700 | [diff] [blame] | 971 | memcpy(push_constant_ptr, (const uint8_t*)command_buffer->state.push_constants, |
| 972 | push_constant_length); |
Lei Zhang | 30f6a39 | 2023-02-12 16:42:27 -0800 | [diff] [blame] | 973 | |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 974 | *out_segment = &segment->dispatch; |
Lei Zhang | df1e9a2 | 2023-02-12 12:08:00 -0800 | [diff] [blame] | 975 | IREE_TRACE_ZONE_END(z0); |
Lei Zhang | 4307ba2 | 2023-05-07 10:40:23 -0700 | [diff] [blame] | 976 | return iree_ok_status(); |
Lei Zhang | df1e9a2 | 2023-02-12 12:08:00 -0800 | [diff] [blame] | 977 | } |
| 978 | |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 979 | static iree_status_t iree_hal_metal_command_segment_record_dispatch( |
| 980 | iree_hal_metal_command_buffer_t* command_buffer, iree_hal_metal_dispatch_segment_t* segment) { |
Lei Zhang | df1e9a2 | 2023-02-12 12:08:00 -0800 | [diff] [blame] | 981 | IREE_TRACE_ZONE_BEGIN(z0); |
| 982 | |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 983 | // Set the compute kernel to dispatch. |
Lei Zhang | df1e9a2 | 2023-02-12 12:08:00 -0800 | [diff] [blame] | 984 | id<MTLComputeCommandEncoder> compute_encoder = |
| 985 | iree_hal_metal_get_or_begin_compute_encoder(command_buffer); |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 986 | [compute_encoder setComputePipelineState:segment->kernel_params.pso]; |
| 987 | |
| 988 | // Record push constants. |
| 989 | if (segment->push_constant_count != 0) { |
| 990 | [compute_encoder setBytes:(void*)segment->push_constants |
| 991 | length:segment->push_constant_count * sizeof(int32_t) |
| 992 | atIndex:IREE_HAL_METAL_PUSH_CONSTANT_BUFFER_INDEX]; |
| 993 | } |
| 994 | |
Lei Zhang | 0ec791e | 2023-05-07 22:15:34 -0700 | [diff] [blame] | 995 | // Record argument buffers for all descriptors and record buffer usages. |
| 996 | iree_hal_metal_descriptor_t* descriptors = segment->descriptors; |
Lei Zhang | f598fd2 | 2023-05-08 07:48:57 -0700 | [diff] [blame] | 997 | for (iree_host_size_t i = 0; i < segment->descriptor_count;) { |
Lei Zhang | 0ec791e | 2023-05-07 22:15:34 -0700 | [diff] [blame] | 998 | uint32_t current_set = descriptors[i].set; |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 999 | |
Lei Zhang | 0ec791e | 2023-05-07 22:15:34 -0700 | [diff] [blame] | 1000 | // Build argument encoder and argument buffer for the current descriptor set. |
Lei Zhang | e32cfa6 | 2023-05-08 13:40:06 -0700 | [diff] [blame] | 1001 | // TODO(antiagainst): Use a cache layer to cache and reuse argument buffers with the same |
| 1002 | // content, to avoid duplicating overhead. |
Lei Zhang | f598fd2 | 2023-05-08 07:48:57 -0700 | [diff] [blame] | 1003 | id<MTLBuffer> argument_buffer = command_buffer->staging_buffer->metal_buffer; |
| 1004 | id<MTLArgumentEncoder> argument_encoder = |
| 1005 | [segment->kernel_params.function newArgumentEncoderWithBufferIndex:current_set]; // +1 |
| 1006 | IREE_ASSERT(argument_encoder != nil); |
| 1007 | |
| 1008 | // Reserve space for the argument buffer from shared staging buffer. |
| 1009 | iree_byte_span_t reservation; |
| 1010 | uint32_t argument_buffer_offset; |
Lei Zhang | 0ec791e | 2023-05-07 22:15:34 -0700 | [diff] [blame] | 1011 | IREE_RETURN_AND_END_ZONE_IF_ERROR( |
Lei Zhang | f598fd2 | 2023-05-08 07:48:57 -0700 | [diff] [blame] | 1012 | z0, iree_hal_metal_staging_buffer_reserve( |
| 1013 | command_buffer->staging_buffer, argument_encoder.encodedLength, |
| 1014 | argument_encoder.alignment, &reservation, &argument_buffer_offset)); |
| 1015 | [argument_encoder setArgumentBuffer:argument_buffer offset:argument_buffer_offset]; |
Lei Zhang | 0ec791e | 2023-05-07 22:15:34 -0700 | [diff] [blame] | 1016 | |
| 1017 | // Now record all bound buffers belonging to the current set into the argument buffer. |
| 1018 | for (; i < segment->descriptor_count && descriptors[i].set == current_set; ++i) { |
| 1019 | uint32_t current_binding = descriptors[i].binding; |
| 1020 | id<MTLBuffer> current_buffer = |
| 1021 | iree_hal_metal_buffer_handle(iree_hal_buffer_allocated_buffer(descriptors[i].buffer)); |
| 1022 | iree_host_size_t offset = |
| 1023 | iree_hal_buffer_byte_offset(descriptors[i].buffer) + descriptors[i].offset; |
| 1024 | [argument_encoder setBuffer:current_buffer offset:offset atIndex:current_binding]; |
| 1025 | |
| 1026 | // Also record buffer usages. |
| 1027 | [compute_encoder useResource:current_buffer usage:descriptors[i].usage]; |
| 1028 | } |
| 1029 | // Record the argument buffer. |
Lei Zhang | f598fd2 | 2023-05-08 07:48:57 -0700 | [diff] [blame] | 1030 | [compute_encoder setBuffer:argument_buffer offset:argument_buffer_offset atIndex:current_set]; |
| 1031 | |
| 1032 | [argument_encoder release]; // -1 |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 1033 | } |
| 1034 | |
| 1035 | // Record the dispatch, either direct or indirect. |
| 1036 | uint32_t* workgroup_size = segment->kernel_params.threadgroup_size; |
| 1037 | if (segment->workgroups_buffer == nil) { |
| 1038 | // Direct dispatch of a fixed workgroup count. |
| 1039 | uint32_t* workgroup_count = segment->workgroup_count; |
| 1040 | [compute_encoder |
| 1041 | dispatchThreadgroups:MTLSizeMake(workgroup_count[0], workgroup_count[1], |
| 1042 | workgroup_count[2]) |
| 1043 | threadsPerThreadgroup:MTLSizeMake(workgroup_size[0], workgroup_size[1], workgroup_size[2])]; |
| 1044 | } else { |
| 1045 | // Indirect dispatch using a workgroup count from buffers. |
| 1046 | [compute_encoder |
| 1047 | dispatchThreadgroupsWithIndirectBuffer:segment->workgroups_buffer |
| 1048 | indirectBufferOffset:segment->workgroups_offset |
| 1049 | threadsPerThreadgroup:MTLSizeMake(workgroup_size[0], workgroup_size[1], |
| 1050 | workgroup_size[2])]; |
| 1051 | } |
Lei Zhang | df1e9a2 | 2023-02-12 12:08:00 -0800 | [diff] [blame] | 1052 | |
| 1053 | IREE_TRACE_ZONE_END(z0); |
| 1054 | return iree_ok_status(); |
| 1055 | } |
| 1056 | |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 1057 | static iree_status_t iree_hal_metal_command_buffer_prepare_dispatch( |
| 1058 | iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, |
| 1059 | int32_t entry_point, uint32_t workgroup_count_x, uint32_t workgroup_count_y, |
| 1060 | uint32_t workgroup_count_z) { |
| 1061 | IREE_TRACE_ZONE_BEGIN(z0); |
| 1062 | |
| 1063 | iree_hal_metal_dispatch_segment_t* segment = NULL; |
| 1064 | IREE_RETURN_AND_END_ZONE_IF_ERROR( |
| 1065 | z0, iree_hal_metal_command_segment_create_dispatch(base_command_buffer, executable, |
| 1066 | entry_point, &segment)); |
| 1067 | segment->workgroup_count[0] = workgroup_count_x; |
| 1068 | segment->workgroup_count[1] = workgroup_count_y; |
| 1069 | segment->workgroup_count[2] = workgroup_count_z; |
| 1070 | |
| 1071 | IREE_TRACE_ZONE_END(z0); |
| 1072 | return iree_ok_status(); |
| 1073 | } |
| 1074 | |
// Records an indirect dispatch by creating a dispatch segment and attaching the Metal buffer
// holding the workgroup count; actual encoding is deferred until command buffer end.
static iree_status_t iree_hal_metal_command_buffer_prepare_dispatch_indirect(
    iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable,
    int32_t entry_point, iree_hal_buffer_t* workgroups_buffer,
    iree_device_size_t workgroups_offset) {
  IREE_TRACE_ZONE_BEGIN(z0);

  iree_hal_metal_dispatch_segment_t* dispatch_segment = NULL;
  iree_status_t status = iree_hal_metal_command_segment_create_dispatch(
      base_command_buffer, executable, entry_point, &dispatch_segment);
  if (iree_status_is_ok(status)) {
    id<MTLBuffer> metal_workgroups_buffer =
        iree_hal_metal_buffer_handle(iree_hal_buffer_allocated_buffer(workgroups_buffer));
    dispatch_segment->workgroups_buffer = metal_workgroups_buffer;
    dispatch_segment->workgroups_offset = workgroups_offset;
  }

  IREE_TRACE_ZONE_END(z0);
  return status;
}
| 1092 | |
// Executes a secondary command buffer with the given binding table.
// Not implemented for the Metal backend yet; always fails with UNIMPLEMENTED.
static iree_status_t iree_hal_metal_command_buffer_execute_commands(
    iree_hal_command_buffer_t* base_command_buffer, iree_hal_command_buffer_t* base_commands,
    iree_hal_buffer_binding_table_t binding_table) {
  return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "secondary command buffer not yet supported");
}
| 1098 | |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 1099 | static iree_status_t iree_hal_metal_command_segment_record( |
| 1100 | iree_hal_metal_command_buffer_t* command_buffer) { |
| 1101 | IREE_ASSERT_ARGUMENT(command_buffer); |
| 1102 | IREE_TRACE_ZONE_BEGIN(z0); |
| 1103 | |
| 1104 | for (iree_hal_metal_command_segment_t* segment = command_buffer->segments.head; segment; |
| 1105 | segment = segment->next_segment) { |
| 1106 | switch (segment->action) { |
| 1107 | case IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_BARRIER: { |
| 1108 | IREE_RETURN_AND_END_ZONE_IF_ERROR( |
| 1109 | z0, iree_hal_metal_command_segment_record_barrier(command_buffer, &segment->barrier)); |
| 1110 | } break; |
| 1111 | case IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_DISPATCH: { |
| 1112 | IREE_RETURN_AND_END_ZONE_IF_ERROR( |
| 1113 | z0, iree_hal_metal_command_segment_record_dispatch(command_buffer, &segment->dispatch)); |
| 1114 | } break; |
| 1115 | case IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_FILL_BUFFER: { |
| 1116 | IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, iree_hal_metal_command_segment_record_fill_buffer( |
| 1117 | command_buffer, &segment->fill_buffer)); |
| 1118 | } break; |
| 1119 | case IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_COPY_BUFFER: { |
| 1120 | IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, iree_hal_metal_command_segment_record_copy_buffer( |
| 1121 | command_buffer, &segment->copy_buffer)); |
| 1122 | } break; |
| 1123 | default: |
| 1124 | IREE_ASSERT(false, "unhandled command segment kind"); |
| 1125 | break; |
| 1126 | } |
| 1127 | } |
| 1128 | |
| 1129 | IREE_TRACE_ZONE_END(z0); |
| 1130 | return iree_ok_status(); |
| 1131 | } |
| 1132 | |
// Begins recording: discards any previously recorded segments and resets recording state so the
// command buffer can be reused.
static iree_status_t iree_hal_metal_command_buffer_begin(
    iree_hal_command_buffer_t* base_command_buffer) {
  iree_hal_metal_command_buffer_reset(iree_hal_metal_command_buffer_cast(base_command_buffer));
  return iree_ok_status();
}
| 1140 | |
// Ends recording: replays all deferred segments into real Metal commands, then closes any open
// blit/compute encoders so the underlying MTLCommandBuffer is ready for submission.
static iree_status_t iree_hal_metal_command_buffer_end(
    iree_hal_command_buffer_t* base_command_buffer) {
  iree_hal_metal_command_buffer_t* metal_command_buffer =
      iree_hal_metal_command_buffer_cast(base_command_buffer);
  IREE_TRACE_ZONE_BEGIN(z0);

  IREE_RETURN_AND_END_ZONE_IF_ERROR(z0,
                                    iree_hal_metal_command_segment_record(metal_command_buffer));
  iree_hal_metal_end_blit_encoder(metal_command_buffer);
  iree_hal_metal_end_compute_encoder(metal_command_buffer);

  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}
| 1154 | |
Lei Zhang | df1e9a2 | 2023-02-12 12:08:00 -0800 | [diff] [blame] | 1155 | static const iree_hal_command_buffer_vtable_t iree_hal_metal_command_buffer_vtable = { |
| 1156 | .destroy = iree_hal_metal_command_buffer_destroy, |
| 1157 | .begin = iree_hal_metal_command_buffer_begin, |
| 1158 | .end = iree_hal_metal_command_buffer_end, |
| 1159 | .begin_debug_group = iree_hal_metal_command_buffer_begin_debug_group, |
| 1160 | .end_debug_group = iree_hal_metal_command_buffer_end_debug_group, |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 1161 | .execution_barrier = iree_hal_metal_command_buffer_prepare_barrier, |
Lei Zhang | df1e9a2 | 2023-02-12 12:08:00 -0800 | [diff] [blame] | 1162 | .signal_event = iree_hal_metal_command_buffer_signal_event, |
| 1163 | .reset_event = iree_hal_metal_command_buffer_reset_event, |
| 1164 | .wait_events = iree_hal_metal_command_buffer_wait_events, |
| 1165 | .discard_buffer = iree_hal_metal_command_buffer_discard_buffer, |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 1166 | .fill_buffer = iree_hal_metal_command_buffer_prepare_fill_buffer, |
| 1167 | .update_buffer = iree_hal_metal_command_buffer_prepare_update_buffer, |
| 1168 | .copy_buffer = iree_hal_metal_command_buffer_prepare_copy_buffer, |
Lei Zhang | df1e9a2 | 2023-02-12 12:08:00 -0800 | [diff] [blame] | 1169 | .collective = iree_hal_metal_command_buffer_collective, |
| 1170 | .push_constants = iree_hal_metal_command_buffer_push_constants, |
| 1171 | .push_descriptor_set = iree_hal_metal_command_buffer_push_descriptor_set, |
Lei Zhang | c0ad0ea | 2023-05-06 18:02:02 -0700 | [diff] [blame] | 1172 | .dispatch = iree_hal_metal_command_buffer_prepare_dispatch, |
| 1173 | .dispatch_indirect = iree_hal_metal_command_buffer_prepare_dispatch_indirect, |
Lei Zhang | df1e9a2 | 2023-02-12 12:08:00 -0800 | [diff] [blame] | 1174 | .execute_commands = iree_hal_metal_command_buffer_execute_commands, |
| 1175 | }; |