blob: 417f6503431cfa6b2debe78a3af4a13c36e3a746 [file] [log] [blame]
// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree_pjrt/cpu/client.h"
#include "iree/hal/drivers/local_sync/sync_driver.h"
#include "iree/hal/drivers/local_task/task_driver.h"
#include "iree/hal/local/plugins/registration/init.h"
#include "iree/task/api.h"
namespace iree::pjrt::cpu {
CPUClientInstance::CPUClientInstance(std::unique_ptr<Platform> platform)
    : ClientInstance(std::move(platform)) {
  // Begin with a default-initialized task topology. CreateDriver() repopulates
  // it from flags when the multi-threaded task driver is selected, and the
  // destructor deinitializes it.
  iree_task_topology_initialize(&task_topology_options_);

  // The platform name has to agree with the name used at registration time,
  // which is an unfortunate action-at-a-distance coupling.
  // TODO: Get this when constructing the client so it is guaranteed to
  // match.
  cached_platform_name_ = "iree_cpu";
}
CPUClientInstance::~CPUClientInstance() {
  // Drop the references acquired by InitializeDeps() and CreateDriver().
  iree_hal_allocator_release(device_allocator_);
  if (executor_ != nullptr) {
    iree_task_executor_release(executor_);
  }
  // Loaders were created against plugin_manager_, so they are released
  // before the plugin manager itself.
  for (iree_host_size_t index = 0; index < loader_count_; ++index) {
    iree_hal_executable_loader_release(loaders_[index]);
  }
  if (plugin_manager_ != nullptr) {
    iree_hal_executable_plugin_manager_release(plugin_manager_);
  }
  // Pairs with the iree_task_topology_initialize() call in the constructor.
  iree_task_topology_deinitialize(&task_topology_options_);
}
iree_status_t CPUClientInstance::InitializeDeps() {
  // Creates the shared dependencies used by either driver flavor:
  // a plugin manager, every available executable loader, and a heap device
  // allocator. On failure, whatever was already assigned is released by the
  // destructor.

  // Plugin manager, created with a capacity of 0.
  IREE_RETURN_IF_ERROR(iree_hal_executable_plugin_manager_create(
      /*capacity=*/0, host_allocator_, &plugin_manager_));

  // Fill loaders_ (bounded by its array size) and record how many were made.
  IREE_RETURN_IF_ERROR(iree_hal_create_all_available_executable_loaders(
      plugin_manager_, IREE_ARRAYSIZE(loaders_), &loader_count_, loaders_,
      host_allocator_));

  // Heap-backed device allocator named "local".
  IREE_RETURN_IF_ERROR(iree_hal_allocator_create_heap(
      IREE_SV("local"), host_allocator_, host_allocator_, &device_allocator_));

  return iree_ok_status();
}
iree_status_t CPUClientInstance::CreateDriver(iree_hal_driver_t** out_driver) {
  // Builds either the multi-threaded "local-task" driver or, when
  // single_threaded_debug_ is set, the "local-sync" driver, and writes it to
  // *out_driver.
  //
  // Drivers are instantiated explicitly rather than via registration because
  // server-library style uses are assumed to want non-default control.
  // TODO: There is substantial configuration available.
  IREE_RETURN_IF_ERROR(InitializeDeps());

  if (!single_threaded_debug_) {
    // Task-system driver: one queue driven by executor_.
    iree_task_executor_options_t executor_options;
    iree_task_executor_options_initialize(&executor_options);
    iree_hal_task_device_params_t device_params;
    iree_hal_task_device_params_initialize(&device_params);

    // Executor and topology options come from flags, which is awkward here
    // because this embedding offers no way to set those flags.
    IREE_RETURN_IF_ERROR(
        iree_task_executor_options_initialize_from_flags(&executor_options));
    // TODO: Do something smarter than pinning to NUMA node 0.
    IREE_RETURN_IF_ERROR(iree_task_topology_initialize_from_flags(
        /*node_id=*/0, &task_topology_options_));
    IREE_RETURN_IF_ERROR(iree_task_executor_create(
        executor_options, &task_topology_options_, host_allocator_,
        &executor_));
    IREE_RETURN_IF_ERROR(iree_hal_task_driver_create(
        IREE_SV("local-task"), &device_params, /*queue_count=*/1, &executor_,
        loader_count_, loaders_, device_allocator_, host_allocator_,
        out_driver));
  } else {
    // Synchronous driver, intended for debugging.
    logger().debug("Creating single threaded CPU driver (debugging)");
    iree_hal_sync_device_params_t device_params;
    iree_hal_sync_device_params_initialize(&device_params);
    IREE_RETURN_IF_ERROR(iree_hal_sync_driver_create(
        IREE_SV("local-sync"), &device_params, loader_count_, loaders_,
        device_allocator_, host_allocator_, out_driver));
  }

  logger().debug("CPU driver created");
  return iree_ok_status();
}
bool CPUClientInstance::SetDefaultCompilerFlags(CompilerJob* compiler_job) {
  // The CPU client targets the llvm-cpu HAL backend; the result of SetFlag is
  // passed straight back to the caller.
  static constexpr char kTargetBackendsFlag[] =
      "--iree-hal-target-backends=llvm-cpu";
  return compiler_job->SetFlag(kTargetBackendsFlag);
}
} // namespace iree::pjrt::cpu