[metal] Add device parameters to driver/device creation APIs This commit defines `iree_hal_metal_device_params_t` for controlling major Metal device behavior. Right now we expose arena block size and command dispatch type.
diff --git a/experimental/metal/api.h b/experimental/metal/api.h index a156b5f..a27eb88 100644 --- a/experimental/metal/api.h +++ b/experimental/metal/api.h
@@ -17,6 +17,38 @@ #endif // __cplusplus //===----------------------------------------------------------------------===// +// iree_hal_metal_device_params_t +//===----------------------------------------------------------------------===// + +typedef enum iree_hal_metal_command_dispatch_type_e { + // Dispatch commands in command buffer in parallel. + IREE_HAL_METAL_COMMAND_DISPATCH_TYPE_CONCURRENT = 0, + // Dispatch commands in command buffer sequentially. + IREE_HAL_METAL_COMMAND_DISPATCH_TYPE_SERIAL = 1, +} iree_hal_metal_command_dispatch_type_t; + +// Parameters configuring an iree_hal_metal_device_t. +// Must be initialized with iree_hal_metal_device_params_initialize prior to +// use. +typedef struct iree_hal_metal_device_params_t { + // Total size of each block in the device shared block pool. + // Larger sizes will lower overhead and ensure the heap isn't hit for + // transient allocations while also increasing memory consumption. + iree_host_size_t arena_block_size; + + // Command dispatch type in command buffers. + // Normally we want to dispatch commands in command buffers in parallel, given + // that IREE performs explicit dependency tracking and synchronization by + // itself. Though being able to specify serial command dispatching helps + // debugging in certain cases. + iree_hal_metal_command_dispatch_type_t command_dispatch_type; +} iree_hal_metal_device_params_t; + +// Initializes |out_params| to default values. +void iree_hal_metal_device_params_initialize( + iree_hal_metal_device_params_t* out_params); + +//===----------------------------------------------------------------------===// // iree_hal_metal_driver_t //===----------------------------------------------------------------------===// @@ -25,8 +57,9 @@ // // |out_driver| must be released by the caller (see iree_hal_driver_release). IREE_API_EXPORT iree_status_t iree_hal_metal_driver_create( - iree_string_view_t identifier, iree_allocator_t host_allocator, - iree_hal_driver_t** out_driver); + iree_string_view_t identifier, + const iree_hal_metal_device_params_t* device_params, + iree_allocator_t host_allocator, iree_hal_driver_t** out_driver); #ifdef __cplusplus } // extern "C"
diff --git a/experimental/metal/metal_device.h b/experimental/metal/metal_device.h index 996d5d6..3348913 100644 --- a/experimental/metal/metal_device.h +++ b/experimental/metal/metal_device.h
@@ -9,6 +9,7 @@ #import <Metal/Metal.h> +#include "experimental/metal/api.h" #include "iree/base/api.h" #include "iree/hal/api.h" @@ -16,12 +17,18 @@ extern "C" { #endif // __cplusplus -// Creates a Metal device. -iree_status_t iree_hal_metal_device_create(iree_hal_driver_t* driver, - iree_string_view_t identifier, - id<MTLDevice> device, - iree_allocator_t host_allocator, - iree_hal_device_t** out_device); +// Creates a Metal device by wrapping |device| from the given |driver| with the +// specific |params|. +// +// |out_device| must be released by the caller (see iree_hal_device_release). +iree_status_t iree_hal_metal_device_create( + iree_string_view_t identifier, const iree_hal_metal_device_params_t* params, + id<MTLDevice> device, iree_allocator_t host_allocator, + iree_hal_device_t** out_device); + +// Returns the parameters used for creating the device. +const iree_hal_metal_device_params_t* iree_hal_metal_device_params( + const iree_hal_device_t* device); #ifdef __cplusplus } // extern "C"
diff --git a/experimental/metal/metal_device.m b/experimental/metal/metal_device.m index a706448..8aca8a2 100644 --- a/experimental/metal/metal_device.m +++ b/experimental/metal/metal_device.m
@@ -6,6 +6,7 @@ #include "experimental/metal/metal_device.h" +#include "experimental/metal/api.h" #include "experimental/metal/direct_allocator.h" #include "experimental/metal/metal_shared_event.h" #include "experimental/metal/nop_executable_cache.h" @@ -42,11 +43,27 @@ return (iree_hal_metal_device_t*)base_value; } -static iree_status_t iree_hal_metal_device_create_internal(iree_hal_driver_t* driver, - iree_string_view_t identifier, - id<MTLDevice> metal_device, - iree_allocator_t host_allocator, - iree_hal_device_t** out_device) { +static const iree_hal_metal_device_t* iree_hal_metal_device_const_cast( + const iree_hal_device_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_metal_device_vtable); + return (const iree_hal_metal_device_t*)base_value; +} + +void iree_hal_metal_device_params_initialize(iree_hal_metal_device_params_t* out_params) { + memset(out_params, 0, sizeof(*out_params)); + out_params->arena_block_size = 32 * 1024; + out_params->command_dispatch_type = IREE_HAL_METAL_COMMAND_DISPATCH_TYPE_CONCURRENT; +} + +const iree_hal_metal_device_params_t* iree_hal_metal_device_params( + const iree_hal_device_t* base_device) { + const iree_hal_metal_device_t* device = iree_hal_metal_device_const_cast(base_device); + return &device->params; +} + +static iree_status_t iree_hal_metal_device_create_internal( + iree_string_view_t identifier, const iree_hal_metal_device_params_t* params, + id<MTLDevice> metal_device, iree_allocator_t host_allocator, iree_hal_device_t** out_device) { iree_hal_metal_device_t* device = NULL; iree_host_size_t total_size = iree_sizeof_struct(*device) + identifier.size; @@ -73,14 +90,15 @@ return status; } -iree_status_t iree_hal_metal_device_create(iree_hal_driver_t* driver, iree_string_view_t identifier, +iree_status_t iree_hal_metal_device_create(iree_string_view_t identifier, + const iree_hal_metal_device_params_t* params, id<MTLDevice> device, iree_allocator_t host_allocator, iree_hal_device_t** out_device) { IREE_ASSERT_ARGUMENT(out_device); IREE_TRACE_ZONE_BEGIN(z0); iree_status_t status = - iree_hal_metal_device_create_internal(driver, identifier, device, host_allocator, out_device); + iree_hal_metal_device_create_internal(identifier, params, device, host_allocator, out_device); IREE_TRACE_ZONE_END(z0); return status;
diff --git a/experimental/metal/metal_driver.m b/experimental/metal/metal_driver.m index 9519f80..9f7be10 100644 --- a/experimental/metal/metal_driver.m +++ b/experimental/metal/metal_driver.m
@@ -32,6 +32,9 @@ // multiple Metal versions can be exposed in the same process. iree_string_view_t identifier; + // Parameters used to control device behavior. + iree_hal_metal_device_params_t device_params; + // The list of GPUs available when creating the driver. We retain them here to make sure // id<MTLDevice>, which is used for creating devices and such, remains valid. NSArray<id<MTLDevice>>* devices; @@ -44,6 +47,12 @@ return (iree_hal_metal_driver_t*)base_value; } +static const iree_hal_metal_driver_t* iree_hal_metal_driver_const_cast( + const iree_hal_driver_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_metal_driver_vtable); + return (const iree_hal_metal_driver_t*)base_value; +} + // Returns an retained array of available Metal GPU devices; the caller should release later. static NSArray<id<MTLDevice>>* iree_hal_metal_device_copy() { #if defined(IREE_PLATFORM_MACOS) @@ -57,9 +66,18 @@ #endif // IREE_PLATFORM_MACOS } -static iree_status_t iree_hal_metal_driver_create_internal(iree_string_view_t identifier, - iree_allocator_t host_allocator, - iree_hal_driver_t** out_driver) { +static iree_status_t iree_hal_metal_device_check_params( + const iree_hal_metal_device_params_t* params) { + if (params->arena_block_size < 4096) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "arena block size too small (< 4096 bytes)"); + } + return iree_ok_status(); +} + +static iree_status_t iree_hal_metal_driver_create_internal( + iree_string_view_t identifier, const iree_hal_metal_device_params_t* device_params, + iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) { iree_hal_metal_driver_t* driver = NULL; iree_host_size_t total_size = iree_sizeof_struct(*driver) + identifier.size; IREE_RETURN_IF_ERROR(iree_allocator_malloc(host_allocator, total_size, (void**)&driver)); @@ -68,6 +86,7 @@ driver->host_allocator = host_allocator; iree_string_view_append_to_buffer(identifier, &driver->identifier, (char*)driver + iree_sizeof_struct(*driver)); + driver->device_params = *device_params; // Get all available Metal devices. driver->devices = iree_hal_metal_device_copy(); @@ -76,14 +95,15 @@ return iree_ok_status(); } -IREE_API_EXPORT iree_status_t iree_hal_metal_driver_create(iree_string_view_t identifier, - iree_allocator_t host_allocator, - iree_hal_driver_t** out_driver) { +IREE_API_EXPORT iree_status_t iree_hal_metal_driver_create( + iree_string_view_t identifier, const iree_hal_metal_device_params_t* device_params, + iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) { IREE_ASSERT_ARGUMENT(out_driver); IREE_TRACE_ZONE_BEGIN(z0); + IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, iree_hal_metal_device_check_params(device_params)); iree_status_t status = - iree_hal_metal_driver_create_internal(identifier, host_allocator, out_driver); + iree_hal_metal_driver_create_internal(identifier, device_params, host_allocator, out_driver); IREE_TRACE_ZONE_END(z0); return status; @@ -331,6 +351,7 @@ const iree_string_pair_t* params, iree_allocator_t host_allocator, iree_hal_device_t** out_device) { + iree_hal_metal_driver_t* driver = iree_hal_metal_driver_cast(base_driver); IREE_TRACE_ZONE_BEGIN(z0); id<MTLDevice> device = nil; @@ -344,8 +365,8 @@ iree_string_view_t device_name = iree_make_cstring_view("metal"); - iree_status_t status = - iree_hal_metal_device_create(base_driver, device_name, device, host_allocator, out_device); + iree_status_t status = iree_hal_metal_device_create(device_name, &driver->device_params, device, + host_allocator, out_device); IREE_TRACE_ZONE_END(z0); return status;
diff --git a/experimental/metal/registration/driver_module.c b/experimental/metal/registration/driver_module.c index fb14b3e..3126cfd 100644 --- a/experimental/metal/registration/driver_module.c +++ b/experimental/metal/registration/driver_module.c
@@ -45,8 +45,11 @@ IREE_TRACE_ZONE_BEGIN(z0); - iree_status_t status = - iree_hal_metal_driver_create(driver_name, host_allocator, out_driver); + iree_hal_metal_device_params_t device_params; + iree_hal_metal_device_params_initialize(&device_params); + + iree_status_t status = iree_hal_metal_driver_create( + driver_name, &device_params, host_allocator, out_driver); IREE_TRACE_ZONE_END(z0);