[metal] Add device parameters to driver/device creation APIs
This commit defines `iree_hal_metal_device_params_t` for
controlling major Metal device behavior. Right now we expose
arena block size and command dispatch type.
diff --git a/experimental/metal/api.h b/experimental/metal/api.h
index a156b5f..a27eb88 100644
--- a/experimental/metal/api.h
+++ b/experimental/metal/api.h
@@ -17,6 +17,38 @@
#endif // __cplusplus
//===----------------------------------------------------------------------===//
+// iree_hal_metal_device_params_t
+//===----------------------------------------------------------------------===//
+
+typedef enum iree_hal_metal_command_dispatch_type_e {
+ // Dispatch commands in command buffer in parallel.
+ IREE_HAL_METAL_COMMAND_DISPATCH_TYPE_CONCURRENT = 0,
+ // Dispatch commands in command buffer sequentially.
+ IREE_HAL_METAL_COMMAND_DISPATCH_TYPE_SERIAL = 1,
+} iree_hal_metal_command_dispatch_type_t;
+
+// Parameters configuring an iree_hal_metal_device_t.
+// Must be initialized with iree_hal_metal_device_params_initialize prior to
+// use.
+typedef struct iree_hal_metal_device_params_t {
+ // Total size of each block in the device shared block pool.
+ // Larger sizes will lower overhead and ensure the heap isn't hit for
+ // transient allocations while also increasing memory consumption.
+ iree_host_size_t arena_block_size;
+
+ // Command dispatch type in command buffers.
+ // Normally we want to dispatch commands in command buffers in parallel, given
+ // that IREE performs explicit dependency tracking and synchronization by
+ // itself. Though being able to specify serial command dispatching helps
+ // debugging in certain cases.
+ iree_hal_metal_command_dispatch_type_t command_dispatch_type;
+} iree_hal_metal_device_params_t;
+
+// Initializes |out_params| to default values.
+void iree_hal_metal_device_params_initialize(
+ iree_hal_metal_device_params_t* out_params);
+
+//===----------------------------------------------------------------------===//
// iree_hal_metal_driver_t
//===----------------------------------------------------------------------===//
@@ -25,8 +57,9 @@
//
// |out_driver| must be released by the caller (see iree_hal_driver_release).
IREE_API_EXPORT iree_status_t iree_hal_metal_driver_create(
- iree_string_view_t identifier, iree_allocator_t host_allocator,
- iree_hal_driver_t** out_driver);
+ iree_string_view_t identifier,
+ const iree_hal_metal_device_params_t* device_params,
+ iree_allocator_t host_allocator, iree_hal_driver_t** out_driver);
#ifdef __cplusplus
} // extern "C"
diff --git a/experimental/metal/metal_device.h b/experimental/metal/metal_device.h
index 996d5d6..3348913 100644
--- a/experimental/metal/metal_device.h
+++ b/experimental/metal/metal_device.h
@@ -9,6 +9,7 @@
#import <Metal/Metal.h>
+#include "experimental/metal/api.h"
#include "iree/base/api.h"
#include "iree/hal/api.h"
@@ -16,12 +17,18 @@
extern "C" {
#endif // __cplusplus
-// Creates a Metal device.
-iree_status_t iree_hal_metal_device_create(iree_hal_driver_t* driver,
- iree_string_view_t identifier,
- id<MTLDevice> device,
- iree_allocator_t host_allocator,
- iree_hal_device_t** out_device);
+// Creates a Metal device by wrapping |device| from the given |driver| with the
+// specific |params|.
+//
+// |out_device| must be released by the caller (see iree_hal_device_release).
+iree_status_t iree_hal_metal_device_create(
+ iree_string_view_t identifier, const iree_hal_metal_device_params_t* params,
+ id<MTLDevice> device, iree_allocator_t host_allocator,
+ iree_hal_device_t** out_device);
+
+// Returns the parameters used for creating the device.
+const iree_hal_metal_device_params_t* iree_hal_metal_device_params(
+ const iree_hal_device_t* device);
#ifdef __cplusplus
} // extern "C"
diff --git a/experimental/metal/metal_device.m b/experimental/metal/metal_device.m
index a706448..8aca8a2 100644
--- a/experimental/metal/metal_device.m
+++ b/experimental/metal/metal_device.m
@@ -6,6 +6,7 @@
#include "experimental/metal/metal_device.h"
+#include "experimental/metal/api.h"
#include "experimental/metal/direct_allocator.h"
#include "experimental/metal/metal_shared_event.h"
#include "experimental/metal/nop_executable_cache.h"
@@ -42,11 +43,27 @@
return (iree_hal_metal_device_t*)base_value;
}
-static iree_status_t iree_hal_metal_device_create_internal(iree_hal_driver_t* driver,
- iree_string_view_t identifier,
- id<MTLDevice> metal_device,
- iree_allocator_t host_allocator,
- iree_hal_device_t** out_device) {
+static const iree_hal_metal_device_t* iree_hal_metal_device_const_cast(
+ const iree_hal_device_t* base_value) {
+ IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_metal_device_vtable);
+ return (const iree_hal_metal_device_t*)base_value;
+}
+
+void iree_hal_metal_device_params_initialize(iree_hal_metal_device_params_t* out_params) {
+ memset(out_params, 0, sizeof(*out_params));
+ out_params->arena_block_size = 32 * 1024;
+ out_params->command_dispatch_type = IREE_HAL_METAL_COMMAND_DISPATCH_TYPE_CONCURRENT;
+}
+
+const iree_hal_metal_device_params_t* iree_hal_metal_device_params(
+ const iree_hal_device_t* base_device) {
+ const iree_hal_metal_device_t* device = iree_hal_metal_device_const_cast(base_device);
+ return &device->params;
+}
+
+static iree_status_t iree_hal_metal_device_create_internal(
+ iree_string_view_t identifier, const iree_hal_metal_device_params_t* params,
+ id<MTLDevice> metal_device, iree_allocator_t host_allocator, iree_hal_device_t** out_device) {
iree_hal_metal_device_t* device = NULL;
iree_host_size_t total_size = iree_sizeof_struct(*device) + identifier.size;
@@ -73,14 +90,15 @@
return status;
}
-iree_status_t iree_hal_metal_device_create(iree_hal_driver_t* driver, iree_string_view_t identifier,
+iree_status_t iree_hal_metal_device_create(iree_string_view_t identifier,
+ const iree_hal_metal_device_params_t* params,
id<MTLDevice> device, iree_allocator_t host_allocator,
iree_hal_device_t** out_device) {
IREE_ASSERT_ARGUMENT(out_device);
IREE_TRACE_ZONE_BEGIN(z0);
iree_status_t status =
- iree_hal_metal_device_create_internal(driver, identifier, device, host_allocator, out_device);
+ iree_hal_metal_device_create_internal(identifier, params, device, host_allocator, out_device);
IREE_TRACE_ZONE_END(z0);
return status;
diff --git a/experimental/metal/metal_driver.m b/experimental/metal/metal_driver.m
index 9519f80..9f7be10 100644
--- a/experimental/metal/metal_driver.m
+++ b/experimental/metal/metal_driver.m
@@ -32,6 +32,9 @@
// multiple Metal versions can be exposed in the same process.
iree_string_view_t identifier;
+ // Parameters used to control device behavior.
+ iree_hal_metal_device_params_t device_params;
+
// The list of GPUs available when creating the driver. We retain them here to make sure
// id<MTLDevice>, which is used for creating devices and such, remains valid.
NSArray<id<MTLDevice>>* devices;
@@ -44,6 +47,12 @@
return (iree_hal_metal_driver_t*)base_value;
}
+static const iree_hal_metal_driver_t* iree_hal_metal_driver_const_cast(
+ const iree_hal_driver_t* base_value) {
+ IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_metal_driver_vtable);
+ return (const iree_hal_metal_driver_t*)base_value;
+}
+
// Returns an retained array of available Metal GPU devices; the caller should release later.
static NSArray<id<MTLDevice>>* iree_hal_metal_device_copy() {
#if defined(IREE_PLATFORM_MACOS)
@@ -57,9 +66,18 @@
#endif // IREE_PLATFORM_MACOS
}
-static iree_status_t iree_hal_metal_driver_create_internal(iree_string_view_t identifier,
- iree_allocator_t host_allocator,
- iree_hal_driver_t** out_driver) {
+static iree_status_t iree_hal_metal_device_check_params(
+ const iree_hal_metal_device_params_t* params) {
+ if (params->arena_block_size < 4096) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "arena block size too small (< 4096 bytes)");
+ }
+ return iree_ok_status();
+}
+
+static iree_status_t iree_hal_metal_driver_create_internal(
+ iree_string_view_t identifier, const iree_hal_metal_device_params_t* device_params,
+ iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) {
iree_hal_metal_driver_t* driver = NULL;
iree_host_size_t total_size = iree_sizeof_struct(*driver) + identifier.size;
IREE_RETURN_IF_ERROR(iree_allocator_malloc(host_allocator, total_size, (void**)&driver));
@@ -68,6 +86,7 @@
driver->host_allocator = host_allocator;
iree_string_view_append_to_buffer(identifier, &driver->identifier,
(char*)driver + iree_sizeof_struct(*driver));
+ driver->device_params = *device_params;
// Get all available Metal devices.
driver->devices = iree_hal_metal_device_copy();
@@ -76,14 +95,15 @@
return iree_ok_status();
}
-IREE_API_EXPORT iree_status_t iree_hal_metal_driver_create(iree_string_view_t identifier,
- iree_allocator_t host_allocator,
- iree_hal_driver_t** out_driver) {
+IREE_API_EXPORT iree_status_t iree_hal_metal_driver_create(
+ iree_string_view_t identifier, const iree_hal_metal_device_params_t* device_params,
+ iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) {
IREE_ASSERT_ARGUMENT(out_driver);
IREE_TRACE_ZONE_BEGIN(z0);
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, iree_hal_metal_device_check_params(device_params));
iree_status_t status =
- iree_hal_metal_driver_create_internal(identifier, host_allocator, out_driver);
+ iree_hal_metal_driver_create_internal(identifier, device_params, host_allocator, out_driver);
IREE_TRACE_ZONE_END(z0);
return status;
@@ -331,6 +351,7 @@
const iree_string_pair_t* params,
iree_allocator_t host_allocator,
iree_hal_device_t** out_device) {
+ iree_hal_metal_driver_t* driver = iree_hal_metal_driver_cast(base_driver);
IREE_TRACE_ZONE_BEGIN(z0);
id<MTLDevice> device = nil;
@@ -344,8 +365,8 @@
iree_string_view_t device_name = iree_make_cstring_view("metal");
- iree_status_t status =
- iree_hal_metal_device_create(base_driver, device_name, device, host_allocator, out_device);
+ iree_status_t status = iree_hal_metal_device_create(device_name, &driver->device_params, device,
+ host_allocator, out_device);
IREE_TRACE_ZONE_END(z0);
return status;
diff --git a/experimental/metal/registration/driver_module.c b/experimental/metal/registration/driver_module.c
index fb14b3e..3126cfd 100644
--- a/experimental/metal/registration/driver_module.c
+++ b/experimental/metal/registration/driver_module.c
@@ -45,8 +45,11 @@
IREE_TRACE_ZONE_BEGIN(z0);
- iree_status_t status =
- iree_hal_metal_driver_create(driver_name, host_allocator, out_driver);
+ iree_hal_metal_device_params_t device_params;
+ iree_hal_metal_device_params_initialize(&device_params);
+
+ iree_status_t status = iree_hal_metal_driver_create(
+ driver_name, &device_params, host_allocator, out_driver);
IREE_TRACE_ZONE_END(z0);