blob: 9a4d2d040aa2efe33fb86a3e241a615ae78582a4 [file] [log] [blame]
// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree_pjrt/cuda/client.h"
namespace iree::pjrt::cuda {
// Constructs the CUDA client instance. Ownership of |platform| is moved into
// the ClientInstance base.
//
// The cached platform name is hard-coded to "iree_cuda". It apparently has to
// agree with the name this plugin is registered under, which is an unwelcome
// action-at-a-distance coupling.
// TODO: Obtain the name when constructing the client so that it is
// guaranteed to match the registration.
CUDAClientInstance::CUDAClientInstance(std::unique_ptr<Platform> platform)
    : ClientInstance(std::move(platform)) {
  this->cached_platform_name_ = "iree_cuda";
}
// No CUDA-specific teardown is done here; cleanup is left to members and the
// ClientInstance base.
CUDAClientInstance::~CUDAClientInstance() = default;
// Creates the IREE HAL CUDA driver for this client instance.
//
// The driver is created with the identifier "cuda", a default device index of
// 0, and stream-mode command buffers. Memory is allocated through
// |host_allocator_|. On success, the driver is written to |out_driver|, a
// debug message is logged, and OK is returned. A failing status from
// iree_hal_cuda_driver_create is propagated through IREE_RETURN_IF_ERROR.
//
// TODO: Plumb through some important device params:
//   nccl_default_id
//   nccl_default_rank
//   nccl_default_count
// TODO: Switch command_buffer_mode to graphs when ready.
iree_status_t CUDAClientInstance::CreateDriver(iree_hal_driver_t** out_driver) {
  // Driver-level options: start from the initialized values, then select
  // device index 0 as the default device.
  iree_hal_cuda_driver_options_t options;
  iree_hal_cuda_driver_options_initialize(&options);
  options.default_device_index = 0;

  // Per-device parameters: start from the values set by
  // iree_hal_cuda_device_params_initialize, then force stream-based command
  // buffers.
  iree_hal_cuda_device_params_t device_params;
  iree_hal_cuda_device_params_initialize(&device_params);
  device_params.command_buffer_mode = IREE_HAL_CUDA_COMMAND_BUFFER_MODE_STREAM;

  const iree_string_view_t identifier = iree_make_cstring_view("cuda");
  IREE_RETURN_IF_ERROR(iree_hal_cuda_driver_create(
      identifier, &options, &device_params, host_allocator_, out_driver));

  logger().debug("CUDA driver created");
  return iree_ok_status();
}
// Applies the default compiler flags for CUDA compilation jobs. Currently
// this only selects the CUDA HAL target device.
// Returns whatever CompilerJob::SetFlag returns (presumably false if the flag
// is rejected — confirm against CompilerJob).
bool CUDAClientInstance::SetDefaultCompilerFlags(CompilerJob* compiler_job) {
  constexpr char kTargetDeviceFlag[] = "--iree-hal-target-device=cuda";
  return compiler_job->SetFlag(kTargetDeviceFlag);
}
} // namespace iree::pjrt::cuda