blob: 34687e4381956c65e9cdc90e84f51e3c4d082022 [file]
// Copyright 2019 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
#include <algorithm>
namespace mlir::iree_compiler::IREE::HAL {
//===----------------------------------------------------------------------===//
// TargetRegistry
//===----------------------------------------------------------------------===//
// static
// Returns the single process-wide registry instance.
// The instance is a function-local static, so it is constructed on the first
// call and the same object is returned on every subsequent call.
TargetRegistry &TargetRegistry::getMutableTargetRegistry() {
  static TargetRegistry instance;
  return instance;
}
// static
// Read-only accessor for the process-wide registry returned by
// getMutableTargetRegistry().
const TargetRegistry &TargetRegistry::getGlobal() {
  const TargetRegistry &registry = getMutableTargetRegistry();
  return registry;
}
// Adds every device in `targetDevices` to this registry.
//
// A TargetDeviceRegistration is created for each entry and owned by this
// registry (ownedDeviceRegistrations). deviceRegistrations maps the name to a
// non-owning pointer to that object.
//
// Aborts via report_fatal_error if a name is already registered. The
// offending name is included in the message so the duplicate can be found.
void TargetRegistry::mergeFrom(const TargetDeviceList &targetDevices) {
  for (auto &it : targetDevices.entries) {
    if (deviceRegistrations.contains(it.first)) {
      llvm::report_fatal_error(
          llvm::Twine("Attempting to overwrite an existing device: '") +
          it.first + "'");
    }
    auto registration = std::make_unique<TargetDeviceRegistration>(it.second);
    // Publish the raw pointer first, then hand ownership to the owned list.
    deviceRegistrations[it.first] = registration.get();
    ownedDeviceRegistrations.push_back(std::move(registration));
  }
}
// Adds every backend in `targetBackends` to this registry.
//
// A TargetBackendRegistration is created for each entry and owned by this
// registry (ownedBackendRegistrations). backendRegistrations maps the name to
// a non-owning pointer to that object.
//
// Aborts via report_fatal_error if a name is already registered. The
// offending name is included in the message so the duplicate can be found.
void TargetRegistry::mergeFrom(const TargetBackendList &targetBackends) {
  for (auto &it : targetBackends.entries) {
    if (backendRegistrations.contains(it.first)) {
      llvm::report_fatal_error(
          llvm::Twine(
              "Attempting to overwrite an existing translation backend: '") +
          it.first + "'");
    }
    auto registration = std::make_unique<TargetBackendRegistration>(it.second);
    // Publish the raw pointer first, then hand ownership to the owned list.
    backendRegistrations[it.first] = registration.get();
    ownedBackendRegistrations.push_back(std::move(registration));
  }
}
// Makes every device and backend known to `registry` visible through this
// registry as well.
//
// Only the name -> registration pointers are copied. The registration objects
// remain owned by `registry`, so `registry` must outlive this registry.
// (NOTE(review): confirm all callers guarantee this, e.g. by merging from the
// global registry.)
//
// Aborts via report_fatal_error if any name is already registered here. The
// offending name is included in the message.
void TargetRegistry::mergeFrom(const TargetRegistry &registry) {
  for (auto &it : registry.deviceRegistrations) {
    if (deviceRegistrations.contains(it.first())) {
      llvm::report_fatal_error(
          llvm::Twine("Attempting to overwrite an existing device: '") +
          it.first() + "'");
    }
    deviceRegistrations[it.first()] = it.second;
  }
  for (auto &it : registry.backendRegistrations) {
    if (backendRegistrations.contains(it.first())) {
      llvm::report_fatal_error(
          llvm::Twine(
              "Attempting to overwrite an existing translation backend: '") +
          it.first() + "'");
    }
    backendRegistrations[it.first()] = it.second;
  }
}
// Returns the names of all registered target devices in ascending
// lexicographic order.
std::vector<std::string> TargetRegistry::getRegisteredTargetDevices() const {
  std::vector<std::string> result;
  result.reserve(deviceRegistrations.size());
  for (auto &entry : deviceRegistrations)
    result.push_back(entry.getKey().str());
  // Hash-map iteration order is not deterministic, so sort the names to give
  // callers a stable listing. The default std::string ordering is used.
  std::sort(result.begin(), result.end());
  return result;
}
// Returns the names of all registered target backends in ascending
// lexicographic order.
std::vector<std::string> TargetRegistry::getRegisteredTargetBackends() const {
  std::vector<std::string> result;
  result.reserve(backendRegistrations.size());
  for (auto &entry : backendRegistrations)
    result.push_back(entry.getKey().str());
  // Hash-map iteration order is not deterministic, so sort the names to give
  // callers a stable listing. The default std::string ordering is used.
  std::sort(result.begin(), result.end());
  return result;
}
// Returns the device registered under `targetName` by calling acquire() on
// its registration.
// Returns an empty pointer if no device with that name is registered.
std::shared_ptr<TargetDevice>
TargetRegistry::getTargetDevice(StringRef targetName) const {
  // Use the map's keyed lookup instead of scanning every entry.
  auto it = deviceRegistrations.find(targetName);
  if (it == deviceRegistrations.end())
    return {};
  return it->getValue()->acquire();
}
// Returns the backend registered under `targetName` by calling acquire() on
// its registration.
// Returns an empty pointer if no backend with that name is registered.
std::shared_ptr<TargetBackend>
TargetRegistry::getTargetBackend(StringRef targetName) const {
  // Use the map's keyed lookup instead of scanning every entry.
  auto it = backendRegistrations.find(targetName);
  if (it == backendRegistrations.end())
    return {};
  return it->getValue()->acquire();
}
// Resolves each name in `targetNames` via getTargetDevice().
// Names with no registered device are silently skipped. The surviving devices
// are returned ordered by their name.
SmallVector<std::shared_ptr<TargetDevice>>
TargetRegistry::getTargetDevices(ArrayRef<std::string> targetNames) const {
  SmallVector<std::pair<std::string, std::shared_ptr<TargetDevice>>> matches;
  for (auto &targetName : targetNames) {
    if (auto targetDevice = getTargetDevice(targetName))
      matches.emplace_back(targetName, std::move(targetDevice));
  }
  // To ensure deterministic builds we sort matches by name.
  std::sort(matches.begin(), matches.end(),
            [](const auto &a, const auto &b) { return a.first < b.first; });
  // Move the devices out of the pairs. Binding by reference avoids copying
  // each name string and bumping each shared_ptr refcount.
  SmallVector<std::shared_ptr<TargetDevice>> result;
  result.reserve(matches.size());
  for (auto &match : matches)
    result.push_back(std::move(match.second));
  return result;
}
// Resolves each name in `targetNames` via getTargetBackend().
// Names with no registered backend are silently skipped. The surviving
// backends are returned ordered by their name.
SmallVector<std::shared_ptr<TargetBackend>>
TargetRegistry::getTargetBackends(ArrayRef<std::string> targetNames) const {
  SmallVector<std::pair<std::string, std::shared_ptr<TargetBackend>>> matches;
  for (auto &targetName : targetNames) {
    if (auto targetBackend = getTargetBackend(targetName))
      matches.emplace_back(targetName, std::move(targetBackend));
  }
  // To ensure deterministic builds we sort matches by name.
  std::sort(matches.begin(), matches.end(),
            [](const auto &a, const auto &b) { return a.first < b.first; });
  // Move the backends out of the pairs. Binding by reference avoids copying
  // each name string and bumping each shared_ptr refcount.
  SmallVector<std::shared_ptr<TargetBackend>> result;
  result.reserve(matches.size());
  for (auto &match : matches)
    result.push_back(std::move(match.second));
  return result;
}
} // namespace mlir::iree_compiler::IREE::HAL
//===----------------------------------------------------------------------===//
// TargetRegistryRef
//===----------------------------------------------------------------------===//
namespace llvm::cl {
template class basic_parser<TargetRegistryRef>;
} // namespace llvm::cl
using TargetRegistryRef = llvm::cl::TargetRegistryRef;
// Parses the value of a TargetRegistryRef command-line option.
//
// Only the literal "global" is accepted. It binds `Val` to
// TargetRegistry::getGlobal(). Any other value is reported through
// `O.error(...)`.
//
// Returns true on error and false on success, following the llvm::cl::parser
// convention.
bool llvm::cl::parser<TargetRegistryRef>::parse(Option &O, StringRef ArgName,
                                                StringRef Arg,
                                                TargetRegistryRef &Val) {
  // We could parse a list of target backends and create a new registry with
  // just that subset, but ownership of such a registry gets tricky. For now
  // only the global registry can be selected.
  if (Arg != "global")
    return O.error("'" + Arg +
                   "' is not a valid target registry; only 'global' is "
                   "supported");
  Val.value = &mlir::iree_compiler::IREE::HAL::TargetRegistry::getGlobal();
  return false;
}
// Prints this option's current value next to its default. Only the global
// registry can be selected, so both are printed as "global" regardless of
// `V` and `Default`.
void llvm::cl::parser<TargetRegistryRef>::printOptionDiff(
    const Option &O, TargetRegistryRef V, const OptVal &Default,
    size_t GlobalWidth) const {
  printOptionName(O, GlobalWidth);
  auto &os = outs();
  os << "= " << "global";
  os.indent(2) << " (default: global)\n";
}
// Out-of-line definition of the parser's anchor() hook. In LLVM convention
// this gives the class a home translation unit for its vtable.
// (NOTE(review): assumes anchor() is declared virtual in the header.)
void llvm::cl::parser<TargetRegistryRef>::anchor() {
  // Intentionally empty.
}