blob: 8b4b9cc26f1d6b62f75c7a1270c9855abb2c563b [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "experimental/rocm/registration/driver_module.h"
#include <inttypes.h>
#include <stddef.h>
#include "experimental/rocm/api.h"
#include "iree/base/api.h"
#include "iree/base/tracing.h"
// Driver ID advertised by iree_hal_rocm_driver_factory_enumerate and matched by
// iree_hal_rocm_driver_factory_try_create. It is the ASCII four-character code
// "ROCM" packed with 'R' in the most significant byte. It must be exactly four
// bytes; a trailing byte (e.g. 0x0a, '\n') would make it stop spelling "ROCM".
#define IREE_HAL_ROCM_DRIVER_ID 0x524F434Du  // ROCM
// Lists the drivers this factory can create: currently exactly one, "rocm".
//
// The table lives in function-local static storage, so the pointer written to
// |out_driver_infos| remains valid for the whole program and must not be freed.
// |self| is not used by this factory.
static iree_status_t iree_hal_rocm_driver_factory_enumerate(
    void *self, const iree_hal_driver_info_t **out_driver_infos,
    iree_host_size_t *out_driver_info_count) {
  (void)self;
  // NOTE: we could query supported ROCM versions or featuresets here.
  static const iree_hal_driver_info_t rocm_driver_infos[] = {
      {
          .driver_id = IREE_HAL_ROCM_DRIVER_ID,
          .driver_name = iree_string_view_literal("rocm"),
          .full_name = iree_string_view_literal("ROCM (dynamic)"),
      },
  };
  *out_driver_infos = rocm_driver_infos;
  *out_driver_info_count = IREE_ARRAYSIZE(rocm_driver_infos);
  return iree_ok_status();
}
// Creates the ROCM HAL driver if |driver_id| is the ID this factory advertises
// from iree_hal_rocm_driver_factory_enumerate.
//
// |out_driver| is reset to NULL before anything else is done. For any other
// driver ID this returns IREE_STATUS_UNAVAILABLE without creating anything.
// Otherwise it returns the status of iree_hal_rocm_driver_create, called with
// the identifier "rocm" and options from
// iree_hal_rocm_driver_options_initialize. |self| is not used by this factory.
static iree_status_t iree_hal_rocm_driver_factory_try_create(
    void *self, iree_hal_driver_id_t driver_id, iree_allocator_t host_allocator,
    iree_hal_driver_t **out_driver) {
  (void)self;
  IREE_ASSERT_ARGUMENT(out_driver);
  *out_driver = NULL;
  if (driver_id != IREE_HAL_ROCM_DRIVER_ID) {
    // Driver IDs are packed character codes, so print them as 16 hex digits
    // (one full 64-bit value) rather than as a zero-padded decimal number.
    return iree_make_status(IREE_STATUS_UNAVAILABLE,
                            "no driver with ID 0x%016" PRIx64
                            " is provided by this factory",
                            driver_id);
  }
  IREE_TRACE_ZONE_BEGIN(z0);

  // When we expose more than one driver (different rocm versions, etc) we
  // can name them here:
  iree_string_view_t identifier = iree_make_cstring_view("rocm");

  iree_hal_rocm_driver_options_t driver_options;
  iree_hal_rocm_driver_options_initialize(&driver_options);
  iree_status_t status = iree_hal_rocm_driver_create(
      identifier, &driver_options, host_allocator, out_driver);

  IREE_TRACE_ZONE_END(z0);
  return status;
}
// Registers the ROCM driver factory with |registry|.
//
// The factory descriptor is a function-local static constant, so the pointer
// handed to the registry stays valid for the lifetime of the program. The
// factory has no state (|self| is NULL); enumeration and creation are
// delegated to the two static callbacks defined in this file.
IREE_API_EXPORT iree_status_t
iree_hal_rocm_driver_module_register(iree_hal_driver_registry_t *registry) {
  static const iree_hal_driver_factory_t rocm_factory = {
      .try_create = iree_hal_rocm_driver_factory_try_create,
      .enumerate = iree_hal_rocm_driver_factory_enumerate,
      .self = NULL,
  };
  iree_status_t status =
      iree_hal_driver_registry_register_factory(registry, &rocm_factory);
  return status;
}