blob: da3569804afda63d814438d79ad95c3018d1425b [file]
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "experimental/metal/registration/driver_module.h"
#include <inttypes.h>
#include <stddef.h>
#include "experimental/metal/api.h"
#include "iree/base/api.h"
#include "iree/base/internal/flags.h"
#include "iree/base/status.h"
#include "iree/base/tracing.h"
// Command-line flags that tune the Metal device created by this module.
// All three default to false and are read by
// iree_hal_metal_driver_factory_try_create when it fills in
// iree_hal_metal_device_params_t.

// true  -> IREE_HAL_METAL_COMMAND_DISPATCH_TYPE_SERIAL
// false -> IREE_HAL_METAL_COMMAND_DISPATCH_TYPE_CONCURRENT
IREE_FLAG(bool, metal_serial_command_dispatch, false,
"Serializes all commands within command buffers as if there were "
"barriers between each");
// true  -> IREE_HAL_METAL_COMMAND_BUFFER_RESOURCE_REFERENCE_MODE_RETAINED
// false -> IREE_HAL_METAL_COMMAND_BUFFER_RESOURCE_REFERENCE_MODE_UNRETAINED
// Per the help text, this is intended as a debugging aid for lifetime issues.
IREE_FLAG(bool, metal_command_buffer_retain_resources, false,
"Enables automatic Metal resource reference counting for diagnosing "
"resource lifetime issues");
// true  -> IREE_HAL_METAL_RESOURCE_HAZARD_TRACKING_MODE_TRACKED
// false -> IREE_HAL_METAL_RESOURCE_HAZARD_TRACKING_MODE_UNTRACKED
// Per the help text, this is intended as a debugging aid for concurrency issues.
IREE_FLAG(bool, metal_resource_hazard_tracking, false,
"Enables automatic Metal hazard tracking for diagnosing concurrency "
"issues");
// Driver-factory enumeration callback.
//
// Reports exactly one driver: name "metal", full name "Apple Metal".
// The returned table has static storage duration, so the pointer written to
// |out_driver_infos| stays valid after this call and must not be freed.
//
// |self| is unused by this factory.
// Always returns iree_ok_status().
static iree_status_t iree_hal_metal_driver_factory_enumerate(
    void* self, iree_host_size_t* out_driver_info_count,
    const iree_hal_driver_info_t** out_driver_infos) {
  IREE_ASSERT_ARGUMENT(out_driver_info_count);
  IREE_ASSERT_ARGUMENT(out_driver_infos);

  static const iree_hal_driver_info_t metal_driver_info_table[] = {
      {
          .full_name = IREE_SVL("Apple Metal"),
          .driver_name = IREE_SVL("metal"),
      },
  };

  *out_driver_infos = metal_driver_info_table;
  *out_driver_info_count = IREE_ARRAYSIZE(metal_driver_info_table);
  return iree_ok_status();
}
// Driver-factory creation callback.
//
// When |driver_name| equals "metal", this builds a device parameter block:
// it starts from iree_hal_metal_device_params_initialize(), then overrides
// three fields from the --metal_* flags. It then forwards to
// iree_hal_metal_driver_create(), which writes the result to |out_driver|.
//
// Any other |driver_name| yields IREE_STATUS_UNAVAILABLE, which lets the
// registry try other factories. |self| is unused.
static iree_status_t iree_hal_metal_driver_factory_try_create(
    void* self, iree_string_view_t driver_name, iree_allocator_t host_allocator,
    iree_hal_driver_t** out_driver) {
  IREE_ASSERT_ARGUMENT(out_driver);

  const bool requested_metal =
      iree_string_view_equal(driver_name, IREE_SV("metal"));
  if (!requested_metal) {
    // Rejected before the trace zone opens, so there is no zone to close here.
    return iree_make_status(IREE_STATUS_UNAVAILABLE,
                            "no driver '%.*s' is provided by this factory",
                            (int)driver_name.size, driver_name.data);
  }

  IREE_TRACE_ZONE_BEGIN(z0);

  iree_hal_metal_device_params_t params;
  iree_hal_metal_device_params_initialize(&params);

  // Map each boolean flag onto its corresponding enum-valued parameter.
  if (FLAG_metal_serial_command_dispatch) {
    params.command_dispatch_type = IREE_HAL_METAL_COMMAND_DISPATCH_TYPE_SERIAL;
  } else {
    params.command_dispatch_type =
        IREE_HAL_METAL_COMMAND_DISPATCH_TYPE_CONCURRENT;
  }

  if (FLAG_metal_command_buffer_retain_resources) {
    params.command_buffer_resource_reference_mode =
        IREE_HAL_METAL_COMMAND_BUFFER_RESOURCE_REFERENCE_MODE_RETAINED;
  } else {
    params.command_buffer_resource_reference_mode =
        IREE_HAL_METAL_COMMAND_BUFFER_RESOURCE_REFERENCE_MODE_UNRETAINED;
  }

  if (FLAG_metal_resource_hazard_tracking) {
    params.resource_hazard_tracking_mode =
        IREE_HAL_METAL_RESOURCE_HAZARD_TRACKING_MODE_TRACKED;
  } else {
    params.resource_hazard_tracking_mode =
        IREE_HAL_METAL_RESOURCE_HAZARD_TRACKING_MODE_UNTRACKED;
  }

  iree_status_t create_status = iree_hal_metal_driver_create(
      driver_name, &params, host_allocator, out_driver);

  IREE_TRACE_ZONE_END(z0);
  return create_status;
}
// Registers the Metal driver factory with |registry|.
//
// The factory descriptor is a function-local static, so the pointer passed
// to the registry remains valid after this function returns. Presumably the
// registry retains it for later enumerate/try_create calls (confirm against
// the registry implementation). The factory carries no state (|self| is NULL).
//
// Returns whatever iree_hal_driver_registry_register_factory returns.
IREE_API_EXPORT iree_status_t
iree_hal_metal_driver_module_register(iree_hal_driver_registry_t* registry) {
  static const iree_hal_driver_factory_t metal_driver_factory = {
      .enumerate = iree_hal_metal_driver_factory_enumerate,
      .try_create = iree_hal_metal_driver_factory_try_create,
      .self = NULL,
  };
  iree_status_t register_status =
      iree_hal_driver_registry_register_factory(registry,
                                                &metal_driver_factory);
  return register_status;
}