| // Copyright 2023 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #import <Metal/Metal.h> |
| |
| #include "experimental/metal/api.h" |
| #include "experimental/metal/metal_device.h" |
| #include "iree/base/api.h" |
| #include "iree/base/target_platform.h" |
| #include "iree/base/tracing.h" |
| #include "iree/hal/api.h" |
| |
// Upper bound on the storage reserved per device for its path string. The path itself is always
// a 16-character hex string.
#define IREE_HAL_METAL_MAX_DEVICE_PATH_LENGTH 32
// Upper bound on the storage reserved per device for its name. Example names: "Apple M1 Pro".
#define IREE_HAL_METAL_MAX_DEVICE_NAME_LENGTH 64

// Cast utilities between Metal id<MTLDevice> and IREE opaque iree_hal_device_id_t.
// Both the macro argument and the whole expansion are parenthesized so that compound arguments
// (e.g. `a ? b : c`) and trailing operators (e.g. `.name`) bind to the intended expression.
#define METAL_DEVICE_TO_DEVICE_ID(device) ((iree_hal_device_id_t)((__bridge void*)(device)))
#define DEVICE_ID_TO_METAL_DEVICE(device_id) ((__bridge id<MTLDevice>)(device_id))
| |
// State owned by one Metal HAL driver instance.
typedef struct iree_hal_metal_driver_t {
  // Base resource carrying the reference count and vtable pointer. Must stay the first member
  // (offset 0) so the struct can be cast to/from the base driver type.
  iree_hal_resource_t resource;

  // Allocator that the driver itself was allocated from; used again to free it.
  iree_allocator_t host_allocator;

  // Name under which the driver is exposed in the IREE driver registry. Overridable so that
  // multiple Metal versions can be exposed in the same process.
  iree_string_view_t identifier;

  // Device parameters handed to every device created through this driver.
  iree_hal_metal_device_params_t device_params;

  // GPUs that were available at driver creation time. The array is retained for the lifetime of
  // the driver so that the id<MTLDevice> handles used when creating devices remain valid.
  NSArray<id<MTLDevice>>* devices;
} iree_hal_metal_driver_t;
| |
// Tentative declaration; the table is populated further down in this file.
static const iree_hal_driver_vtable_t iree_hal_metal_driver_vtable;

// Converts a base driver pointer into the Metal driver struct. IREE_HAL_ASSERT_TYPE is handed the
// Metal driver vtable first (presumably to verify the dynamic type -- confirm against the macro).
static iree_hal_metal_driver_t* iree_hal_metal_driver_cast(iree_hal_driver_t* base_value) {
  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_metal_driver_vtable);
  iree_hal_metal_driver_t* metal_driver = (iree_hal_metal_driver_t*)base_value;
  return metal_driver;
}
| |
// Const-preserving counterpart of iree_hal_metal_driver_cast: converts a const base driver
// pointer into a const Metal driver struct pointer after the same IREE_HAL_ASSERT_TYPE check.
static const iree_hal_metal_driver_t* iree_hal_metal_driver_const_cast(
    const iree_hal_driver_t* base_value) {
  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_metal_driver_vtable);
  const iree_hal_metal_driver_t* metal_driver = (const iree_hal_metal_driver_t*)base_value;
  return metal_driver;
}
| |
// Returns a retained (+1) array of the available Metal GPU devices; the caller must release it.
// The returned array is empty when no Metal device could be obtained.
static NSArray<id<MTLDevice>>* iree_hal_metal_device_copy(void) {
#if defined(IREE_PLATFORM_MACOS)
  // On macOS there may be more than one GPU device.
  return MTLCopyAllDevices();  // +1
#else
  // On other Apple platforms there is only the single system default GPU device.
  id<MTLDevice> device = MTLCreateSystemDefaultDevice();  // +1, or nil when Metal is unavailable
  if (!device) {
    // Inserting nil into an array would raise an exception; report "no devices" instead.
    return [[NSArray alloc] init];  // +1
  }
  NSArray<id<MTLDevice>>* devices = [[NSArray alloc] initWithObjects:device, nil];  // +1
  // The array holds its own reference now; balance the one returned by the create call so the
  // device is not leaked when the array is eventually released.
  [device release];
  return devices;
#endif  // IREE_PLATFORM_MACOS
}
| |
// Validates user-provided device parameters. Currently the only rule is that the arena block size
// must be at least 4096 bytes; anything smaller yields IREE_STATUS_INVALID_ARGUMENT.
static iree_status_t iree_hal_metal_device_check_params(
    const iree_hal_metal_device_params_t* params) {
  const bool block_size_ok = params->arena_block_size >= 4096;
  if (block_size_ok) return iree_ok_status();
  return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                          "arena block size too small (< 4096 bytes)");
}
| |
// Allocates and initializes a Metal driver. The driver struct and a copy of |identifier| share a
// single allocation (the identifier characters live directly after the struct). On success the
// new driver is returned through |out_driver|.
static iree_status_t iree_hal_metal_driver_create_internal(
    iree_string_view_t identifier, const iree_hal_metal_device_params_t* device_params,
    iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) {
  iree_hal_metal_driver_t* metal_driver = NULL;
  const iree_host_size_t struct_size = iree_sizeof_struct(*metal_driver);
  const iree_host_size_t allocation_size = struct_size + identifier.size;
  IREE_RETURN_IF_ERROR(
      iree_allocator_malloc(host_allocator, allocation_size, (void**)&metal_driver));

  iree_hal_resource_initialize(&iree_hal_metal_driver_vtable, &metal_driver->resource);
  metal_driver->host_allocator = host_allocator;
  metal_driver->device_params = *device_params;

  // Copy the identifier into the trailing storage and point the driver's view at that copy.
  char* identifier_storage = (char*)metal_driver + struct_size;
  iree_string_view_append_to_buffer(identifier, &metal_driver->identifier, identifier_storage);

  // Snapshot (and retain) the Metal devices that are available right now.
  metal_driver->devices = iree_hal_metal_device_copy();

  *out_driver = (iree_hal_driver_t*)metal_driver;
  return iree_ok_status();
}
| |
// Creates a Metal HAL driver registered under |identifier|.
//
// |device_params| is validated first and then copied into the driver; it must not be NULL because
// validation dereferences it. On success |out_driver| receives the new driver; on failure it is
// set to NULL and the failing status is returned.
IREE_API_EXPORT iree_status_t iree_hal_metal_driver_create(
    iree_string_view_t identifier, const iree_hal_metal_device_params_t* device_params,
    iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) {
  IREE_ASSERT_ARGUMENT(device_params);
  IREE_ASSERT_ARGUMENT(out_driver);
  IREE_TRACE_ZONE_BEGIN(z0);

  // Never leave the output indeterminate, even when validation or allocation fails.
  *out_driver = NULL;

  IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, iree_hal_metal_device_check_params(device_params));
  iree_status_t status =
      iree_hal_metal_driver_create_internal(identifier, device_params, host_allocator, out_driver);

  IREE_TRACE_ZONE_END(z0);
  return status;
}
| |
// Tears down a Metal driver: drops the driver's reference on the device array and then frees the
// driver allocation itself.
static void iree_hal_metal_driver_destroy(iree_hal_driver_t* base_driver) {
  iree_hal_metal_driver_t* metal_driver = iree_hal_metal_driver_cast(base_driver);
  // Read the allocator out of the struct up front; the struct is freed with it below.
  iree_allocator_t allocator = metal_driver->host_allocator;
  IREE_TRACE_ZONE_BEGIN(z0);

  NSArray<id<MTLDevice>>* retained_devices = metal_driver->devices;
  [retained_devices release];  // -1
  iree_allocator_free(allocator, metal_driver);

  IREE_TRACE_ZONE_END(z0);
}
| |
// Fills |out_device_info| (which must point to valid memory) from the Metal |device| handle.
// The path and name strings are appended to the storage at |buffer_ptr|. On success
// |out_buffer_ptr| receives |buffer_ptr| advanced by what the append calls returned; on failure it
// is left equal to the incoming |buffer_ptr|.
static iree_status_t iree_hal_metal_populate_device_info(id<MTLDevice> device, uint8_t* buffer_ptr,
                                                         uint8_t** out_buffer_ptr,
                                                         iree_hal_device_info_t* out_device_info) {
  // Until everything has succeeded, report the buffer position as unchanged.
  *out_buffer_ptr = buffer_ptr;
  uint8_t* cursor = buffer_ptr;

  memset(out_device_info, 0, sizeof(*out_device_info));
  out_device_info->device_id = METAL_DEVICE_TO_DEVICE_ID(device);

  // Metal devices have no 128-bit UUID, so the path is the 64-bit registry ID rendered as
  // exactly 16 hex digits.
  char path_chars[16 + 1] = {0};
  snprintf(path_chars, sizeof(path_chars), "%016" PRIx64, device.registryID);
  const iree_string_view_t path_view =
      iree_make_string_view(path_chars, IREE_ARRAYSIZE(path_chars) - 1);
  cursor += iree_string_view_append_to_buffer(path_view, &out_device_info->path, (char*)cursor);

  // Reject names that would not fit in the per-device name storage.
  const char* name_chars = [device.name UTF8String];
  const size_t name_length = strlen(name_chars);
  if (name_length >= IREE_HAL_METAL_MAX_DEVICE_NAME_LENGTH) {
    return iree_make_status(IREE_STATUS_OUT_OF_RANGE, "device name out of range");
  }
  const iree_string_view_t name_view = iree_make_string_view(name_chars, name_length);
  cursor += iree_string_view_append_to_buffer(name_view, &out_device_info->name, (char*)cursor);

  *out_buffer_ptr = cursor;
  return iree_ok_status();
}
| |
// Enumerates the devices captured when the driver was created.
//
// On success |out_device_infos| receives one allocation from |host_allocator| holding
// |out_device_info_count| info structs followed by their path/name strings; the caller frees it
// with the same allocator. When the driver has no devices, the count is 0 and the pointer is NULL.
// On failure both outputs are left as 0/NULL and nothing needs to be freed.
static iree_status_t iree_hal_metal_driver_query_available_devices(
    iree_hal_driver_t* base_driver, iree_allocator_t host_allocator,
    iree_host_size_t* out_device_info_count, iree_hal_device_info_t** out_device_infos) {
  iree_hal_metal_driver_t* driver = iree_hal_metal_driver_cast(base_driver);
  NSArray<id<MTLDevice>>* devices = driver->devices;
  // Keep the full width of NSUInteger instead of narrowing to `unsigned`.
  const iree_host_size_t device_count = (iree_host_size_t)devices.count;

  *out_device_info_count = 0;
  *out_device_infos = NULL;

  // Nothing to describe: return an empty list rather than issuing a zero-byte allocation request
  // (NOTE(review): the allocator may reject zero-sized requests -- confirm).
  if (device_count == 0) return iree_ok_status();

  // Each device gets one info struct plus fixed-size storage for its path and name strings.
  const iree_host_size_t single_info_size =
      sizeof(iree_hal_device_info_t) +
      (IREE_HAL_METAL_MAX_DEVICE_PATH_LENGTH + IREE_HAL_METAL_MAX_DEVICE_NAME_LENGTH) *
          sizeof(char);
  const iree_host_size_t total_size = device_count * single_info_size;
  iree_hal_device_info_t* device_infos = NULL;
  iree_status_t status = iree_allocator_malloc(host_allocator, total_size, (void**)&device_infos);

  if (iree_status_is_ok(status)) {
    // All path and name strings are appended after the array of info structs.
    uint8_t* buffer_ptr = (uint8_t*)device_infos + device_count * sizeof(iree_hal_device_info_t);
    for (iree_host_size_t i = 0; i < device_count; ++i) {
      status = iree_hal_metal_populate_device_info(devices[i], buffer_ptr, &buffer_ptr,
                                                   &device_infos[i]);
      if (!iree_status_is_ok(status)) break;
    }
  }

  if (iree_status_is_ok(status)) {
    *out_device_info_count = device_count;
    *out_device_infos = device_infos;
  } else {
    iree_allocator_free(host_allocator, device_infos);
  }
  return status;
}
| |
// Returns the first GPU family that |device| reports support for, probing the specific Apple
// families from newest (Apple8) to oldest (Apple1) and then the common families (Common3 down to
// Common1). Returns 0 when the device supports none of the families considered here.
static MTLGPUFamily iree_hal_metal_apple_gpu_family_query(id<MTLDevice> device) {
  // Probe order matters: the first supported entry wins.
  static const MTLGPUFamily kProbeOrder[] = {
      // Specific Apple GPU families.
      MTLGPUFamilyApple8,
      MTLGPUFamilyApple7,
      MTLGPUFamilyApple6,
      MTLGPUFamilyApple5,
      MTLGPUFamilyApple4,
      MTLGPUFamilyApple3,
      MTLGPUFamilyApple2,
      MTLGPUFamilyApple1,
      // Common families.
      MTLGPUFamilyCommon3,
      MTLGPUFamilyCommon2,
      MTLGPUFamilyCommon1,
  };
  const size_t probe_count = sizeof(kProbeOrder) / sizeof(kProbeOrder[0]);

  for (size_t i = 0; i < probe_count; ++i) {
    const MTLGPUFamily candidate = kProbeOrder[i];
    if ([device supportsFamily:candidate]) return candidate;
  }
  return 0;
}
| |
// Maps a GPU family to the short label used in device info dumps. Families without an entry in
// the table map to the empty string (never NULL).
static const char* iree_hal_metal_get_gpu_family_name(MTLGPUFamily family) {
  static const struct {
    MTLGPUFamily family;
    const char* label;
  } kFamilyLabels[] = {
      {MTLGPUFamilyApple8, "apple8(a15/m2)"},
      {MTLGPUFamilyApple7, "apple7(a14/m1)"},
      {MTLGPUFamilyApple6, "apple6(a13)"},
      {MTLGPUFamilyApple5, "apple5(a12)"},
      {MTLGPUFamilyApple4, "apple4(a11)"},
      {MTLGPUFamilyApple3, "apple3(a9/a10)"},
      {MTLGPUFamilyApple2, "apple2(a8)"},
      {MTLGPUFamilyApple1, "apple1(a7)"},
      {MTLGPUFamilyCommon3, "common3"},
      {MTLGPUFamilyCommon2, "common2"},
      {MTLGPUFamilyCommon1, "common1"},
  };
  const size_t label_count = sizeof(kFamilyLabels) / sizeof(kFamilyLabels[0]);

  for (size_t i = 0; i < label_count; ++i) {
    if (kFamilyLabels[i].family == family) return kFamilyLabels[i].label;
  }
  return "";
}
| |
// Returns the argument buffer tier label for |family|: "2" for Apple6 through Apple8, "1" for
// Apple2 through Apple5, and "unknown" for everything else.
static inline const char* iree_hal_metal_get_argument_buffer_tier_str(MTLGPUFamily family) {
  const bool in_tier2_range = (MTLGPUFamilyApple6 <= family) && (family <= MTLGPUFamilyApple8);
  const bool in_tier1_range = (MTLGPUFamilyApple2 <= family) && (family <= MTLGPUFamilyApple5);
  if (in_tier2_range) {
    return "2";
  } else if (in_tier1_range) {
    return "1";
  } else {
    return "unknown";
  }
}
| |
// True when |family| lies in the Apple7..Apple8 range, which this file reports as supporting
// SIMD-scoped matrix multiply.
static inline bool iree_hal_metal_support_simd_matrix_multiply(MTLGPUFamily family) {
  if (family < MTLGPUFamilyApple7) return false;
  return family <= MTLGPUFamilyApple8;
}
| |
// True when |family| lies in the Apple7..Apple8 range, which this file reports as supporting
// SIMD-scoped reduction.
static inline bool iree_hal_metal_support_simd_reduce(MTLGPUFamily family) {
  if (family < MTLGPUFamilyApple7) return false;
  return family <= MTLGPUFamilyApple8;
}
| |
// True when |family| lies in the Apple6..Apple8 range, which this file reports as supporting
// SIMD-scoped permute.
static inline bool iree_hal_metal_support_simd_permute(MTLGPUFamily family) {
  if (family < MTLGPUFamilyApple6) return false;
  return family <= MTLGPUFamilyApple8;
}
| |
// True only for Apple8, the single family this file reports as supporting SIMD-scoped
// shift-and-fill (the range check has identical lower and upper bounds).
static inline bool iree_hal_metal_support_simd_shift_and_fill(MTLGPUFamily family) {
  return family == MTLGPUFamilyApple8;
}
| |
// Appends a human-readable, multi-line description of the Metal device identified by |device_id|
// to |builder|: GPU family, unified memory flag, argument buffer tier, resource limits, and the
// SIMD-scoped operations this file associates with the device's GPU family.
//
// |base_driver| is not consulted; |device_id| is interpreted directly as an id<MTLDevice>.
// Returns the first error reported by the string builder, if any.
static iree_status_t iree_hal_metal_driver_dump_device_info(iree_hal_driver_t* base_driver,
                                                            iree_hal_device_id_t device_id,
                                                            iree_string_builder_t* builder) {
  (void)base_driver;
  id<MTLDevice> device = DEVICE_ID_TO_METAL_DEVICE(device_id);
  MTLGPUFamily apple_gpu_family = iree_hal_metal_apple_gpu_family_query(device);

  // Dump GPU family information.
  IREE_RETURN_IF_ERROR(iree_string_builder_append_cstring(builder, "- gpu-family:"));
  const char* apple_family_str = iree_hal_metal_get_gpu_family_name(apple_gpu_family);
  // The name lookup yields "" (not NULL) for unrecognized families, so test for emptiness;
  // otherwise a stray separator space would be emitted with nothing after it.
  if (apple_family_str && apple_family_str[0] != '\0') {
    IREE_RETURN_IF_ERROR(iree_string_builder_append_cstring(builder, " "));
    IREE_RETURN_IF_ERROR(iree_string_builder_append_cstring(builder, apple_family_str));
  }
  if ([device supportsFamily:MTLGPUFamilyMetal3]) {
    IREE_RETURN_IF_ERROR(iree_string_builder_append_cstring(builder, " metal3"));
  }

  // Dump memory information.
  IREE_RETURN_IF_ERROR(iree_string_builder_append_format(builder, "\n- unified-memory: %d",
                                                         (int)device.hasUnifiedMemory));

  // Dump argument buffer tier.
  IREE_RETURN_IF_ERROR(iree_string_builder_append_cstring(builder, "\n- argument-buffer-tier: "));
  IREE_RETURN_IF_ERROR(iree_string_builder_append_cstring(
      builder, iree_hal_metal_get_argument_buffer_tier_str(apple_gpu_family)));

  // Dump resource limits.
  IREE_RETURN_IF_ERROR(iree_string_builder_append_cstring(builder, "\n- max-buffer-size: "));
  {
    // Reported in whole mebibytes.
    uint32_t max_buffer_mb = (uint32_t)(device.maxBufferLength / 1024u / 1024u);
    IREE_RETURN_IF_ERROR(iree_string_builder_append_format(builder, "%uMB", max_buffer_mb));
  }
  IREE_RETURN_IF_ERROR(
      iree_string_builder_append_cstring(builder, "\n- max-threadgroup-memory-size: "));
  {
    // Reported in whole kibibytes.
    uint32_t max_memory_kb = (uint32_t)(device.maxThreadgroupMemoryLength / 1024u);
    IREE_RETURN_IF_ERROR(iree_string_builder_append_format(builder, "%uKB", max_memory_kb));
  }
  IREE_RETURN_IF_ERROR(
      iree_string_builder_append_cstring(builder, "\n- max-threads-per-threadgroup: "));
  {
    MTLSize threads = device.maxThreadsPerThreadgroup;
    // MTLSize fields are NSUInteger; cast explicitly so the arguments always match %lu.
    IREE_RETURN_IF_ERROR(iree_string_builder_append_format(
        builder, "(%lu, %lu, %lu)", (unsigned long)threads.width, (unsigned long)threads.height,
        (unsigned long)threads.depth));
  }

  // Dump SIMD-scoped operation features.
  IREE_RETURN_IF_ERROR(iree_string_builder_append_cstring(builder, "\n- simd-scoped-operation:"));
  if (iree_hal_metal_support_simd_matrix_multiply(apple_gpu_family)) {
    IREE_RETURN_IF_ERROR(iree_string_builder_append_cstring(builder, " matmul"));
  }
  if (iree_hal_metal_support_simd_reduce(apple_gpu_family)) {
    IREE_RETURN_IF_ERROR(iree_string_builder_append_cstring(builder, " reduce"));
  }
  if (iree_hal_metal_support_simd_permute(apple_gpu_family)) {
    IREE_RETURN_IF_ERROR(iree_string_builder_append_cstring(builder, " permute"));
  }
  if (iree_hal_metal_support_simd_shift_and_fill(apple_gpu_family)) {
    IREE_RETURN_IF_ERROR(iree_string_builder_append_cstring(builder, " shift-and-fill"));
  }
  IREE_RETURN_IF_ERROR(iree_string_builder_append_cstring(builder, "\n"));

  return iree_ok_status();
}
| |
// Looks up the |device_index|-th device in the driver's device list.
//
// On success |found_device| receives the device (not additionally retained here). When the index
// is out of range, IREE_STATUS_NOT_FOUND is returned and |found_device| is set to nil.
// |host_allocator| is unused.
static iree_status_t iree_hal_metal_driver_find_device_by_index(iree_hal_driver_t* base_driver,
                                                                uint32_t device_index,
                                                                iree_allocator_t host_allocator,
                                                                id<MTLDevice>* found_device) {
  iree_hal_metal_driver_t* driver = iree_hal_metal_driver_cast(base_driver);
  IREE_TRACE_ZONE_BEGIN(z0);
  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (uint64_t)device_index);

  // Give the output a defined value on every path.
  *found_device = nil;

  NSArray<id<MTLDevice>>* devices = driver->devices;
  if (device_index >= devices.count) {
    IREE_TRACE_ZONE_END(z0);
    // Both values are unsigned; %u keeps large indices from being printed as negative numbers.
    return iree_make_status(IREE_STATUS_NOT_FOUND, "%u devices enumerated; device #%u not found",
                            (unsigned)devices.count, (unsigned)device_index);
  }
  *found_device = devices[device_index];

  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}
| |
// Creates a HAL device for the Metal device identified by |device_id|.
//
// IREE_HAL_DEVICE_ID_DEFAULT selects the first device in the driver's list. Any other value is
// interpreted as an id<MTLDevice> handle and must be one of the devices held by this driver;
// otherwise IREE_STATUS_NOT_FOUND is returned. |param_count| and |params| are currently unused.
static iree_status_t iree_hal_metal_driver_create_device_by_id(iree_hal_driver_t* base_driver,
                                                               iree_hal_device_id_t device_id,
                                                               iree_host_size_t param_count,
                                                               const iree_string_pair_t* params,
                                                               iree_allocator_t host_allocator,
                                                               iree_hal_device_t** out_device) {
  iree_hal_metal_driver_t* driver = iree_hal_metal_driver_cast(base_driver);
  IREE_TRACE_ZONE_BEGIN(z0);

  id<MTLDevice> device = nil;
  if (device_id == IREE_HAL_DEVICE_ID_DEFAULT) {
    // Default to the first Metal device in the list. Pass an explicit index rather than reusing
    // the device ID value, which is a different kind of quantity.
    IREE_RETURN_AND_END_ZONE_IF_ERROR(
        z0, iree_hal_metal_driver_find_device_by_index(base_driver, /*device_index=*/0,
                                                       host_allocator, &device));
  } else {
    device = DEVICE_ID_TO_METAL_DEVICE(device_id);
    // The ID is an opaque integer supplied by the caller. Confirm by pointer identity (no message
    // is sent to the candidate) that it names a device this driver retains before using it.
    if ([driver->devices indexOfObjectIdenticalTo:device] == NSNotFound) {
      IREE_TRACE_ZONE_END(z0);
      return iree_make_status(IREE_STATUS_NOT_FOUND,
                              "Metal device with ID %p not found among enumerated devices",
                              (void*)device_id);
    }
  }

  iree_string_view_t device_name = iree_make_cstring_view("metal");

  iree_status_t status = iree_hal_metal_device_create(device_name, &driver->device_params, device,
                                                      host_allocator, out_device);

  IREE_TRACE_ZONE_END(z0);
  return status;
}
| |
// Creates a HAL device for the Metal device whose registryID equals |device_registry_id|.
// The driver's device list is searched in order and the first match is used; when nothing
// matches, IREE_STATUS_NOT_FOUND is returned. |driver_name| is unused.
static iree_status_t iree_hal_metal_driver_create_device_by_registry_id(
    iree_hal_driver_t* base_driver, iree_string_view_t driver_name, uint64_t device_registry_id,
    iree_host_size_t param_count, const iree_string_pair_t* params, iree_allocator_t host_allocator,
    iree_hal_device_t** out_device) {
  iree_hal_metal_driver_t* driver = iree_hal_metal_driver_cast(base_driver);
  IREE_TRACE_ZONE_BEGIN(z0);

  // Walk the retained devices looking for the requested registry ID.
  id<MTLDevice> matching_device = nil;
  for (id<MTLDevice> candidate in driver->devices) {
    if (candidate.registryID == device_registry_id) {
      matching_device = candidate;
      break;
    }
  }

  if (matching_device == nil) {
    IREE_TRACE_ZONE_END(z0);
    return iree_make_status(IREE_STATUS_NOT_FOUND,
                            "Metal device with device registry ID %016" PRIx64 " not found",
                            device_registry_id);
  }

  // Delegate the actual creation to the by-ID path using the matched handle.
  const iree_hal_device_id_t matching_id = METAL_DEVICE_TO_DEVICE_ID(matching_device);
  iree_status_t status = iree_hal_metal_driver_create_device_by_id(
      base_driver, matching_id, param_count, params, host_allocator, out_device);

  IREE_TRACE_ZONE_END(z0);
  return status;
}
| |
// Creates a HAL device from a textual |device_path|. The path is interpreted as, in order:
//   1. empty: the default device;
//   2. a hex string that parses into exactly 8 bytes: a device registry ID;
//   3. a decimal number: an index into the driver's device list.
// Anything else yields IREE_STATUS_UNIMPLEMENTED.
static iree_status_t iree_hal_metal_driver_create_device_by_path(
    iree_hal_driver_t* base_driver, iree_string_view_t driver_name, iree_string_view_t device_path,
    iree_host_size_t param_count, const iree_string_pair_t* params, iree_allocator_t host_allocator,
    iree_hal_device_t** out_device) {
  if (iree_string_view_is_empty(device_path)) {
    return iree_hal_metal_driver_create_device_by_id(
        base_driver, IREE_HAL_DEVICE_ID_DEFAULT, param_count, params, host_allocator, out_device);
  }

  // Try parsing as a device registry ID.
  uint8_t registry_id_bytes[8] = {0};
  if (iree_string_view_parse_hex_bytes(device_path, IREE_ARRAYSIZE(registry_id_bytes),
                                       registry_id_bytes)) {
    // Device paths are produced by printing the registry ID with "%016" PRIx64, i.e. most
    // significant digits first. Assuming the parser stores bytes in textual order, byte 0 is the
    // most significant one, so rebuild the value arithmetically. Reinterpreting the byte array as
    // a uint64_t would byte-swap the ID on little-endian hosts (and violates aliasing/alignment
    // rules).
    uint64_t device_registry_id = 0;
    for (iree_host_size_t i = 0; i < IREE_ARRAYSIZE(registry_id_bytes); ++i) {
      device_registry_id = (device_registry_id << 8) | (uint64_t)registry_id_bytes[i];
    }
    return iree_hal_metal_driver_create_device_by_registry_id(
        base_driver, driver_name, device_registry_id, param_count, params, host_allocator,
        out_device);
  }

  // Fallback and try to parse as a device index.
  uint32_t device_index = 0;
  if (iree_string_view_atoi_uint32(device_path, &device_index)) {
    id<MTLDevice> found_device = nil;
    IREE_RETURN_IF_ERROR(iree_hal_metal_driver_find_device_by_index(base_driver, device_index,
                                                                    host_allocator, &found_device));
    return iree_hal_metal_driver_create_device_by_id(
        base_driver, METAL_DEVICE_TO_DEVICE_ID(found_device), param_count, params, host_allocator,
        out_device);
  }

  return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "unsupported device path");
}
| |
// Function table binding the generic HAL driver interface to the Metal implementations defined
// in this file. Its address is also what the driver cast helpers hand to IREE_HAL_ASSERT_TYPE.
static const iree_hal_driver_vtable_t iree_hal_metal_driver_vtable = {
    // Lifetime.
    .destroy = iree_hal_metal_driver_destroy,
    // Device enumeration and inspection.
    .query_available_devices = iree_hal_metal_driver_query_available_devices,
    .dump_device_info = iree_hal_metal_driver_dump_device_info,
    // Device creation.
    .create_device_by_id = iree_hal_metal_driver_create_device_by_id,
    .create_device_by_path = iree_hal_metal_driver_create_device_by_path,
};