blob: 5f290d0ddab46f72e9759a6dd620a60eb6dba78e [file] [log] [blame]
// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree_pjrt/rocm/client.h"
#include "experimental/rocm/registration/driver_module.h"
namespace iree::pjrt::rocm {
ROCMClientInstance::ROCMClientInstance(std::unique_ptr<Platform> platform)
: ClientInstance(std::move(platform)) {
// Seems that it must match how registered. Action at a distance not
// great.
// TODO: Get this when constructing the client so it is guaranteed to
// match.
cached_platform_name_ = "iree_rocm";
IREE_CHECK_OK(iree_hal_rocm_driver_module_register(driver_registry_));
}
ROCMClientInstance::~ROCMClientInstance() {}
iree_status_t ROCMClientInstance::CreateDriver(iree_hal_driver_t** out_driver) {
iree_string_view_t driver_name = iree_make_cstring_view("rocm");
IREE_RETURN_IF_ERROR(iree_hal_driver_registry_try_create(
driver_registry_, driver_name, host_allocator_, out_driver));
logger().debug("ROCM driver created");
return iree_ok_status();
}
bool ROCMClientInstance::SetDefaultCompilerFlags(CompilerJob* compiler_job) {
return compiler_job->SetFlag("--iree-hal-target-backends=rocm");
}
} // namespace iree::pjrt::rocm