Cleanup of TargetRegistry after #15468.
diff --git a/compiler/plugins/target/MetalSPIRV/BUILD.bazel b/compiler/plugins/target/MetalSPIRV/BUILD.bazel
index f0c729f..9773eff 100644
--- a/compiler/plugins/target/MetalSPIRV/BUILD.bazel
+++ b/compiler/plugins/target/MetalSPIRV/BUILD.bazel
@@ -28,6 +28,7 @@
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
"//compiler/src/iree/compiler/Codegen/SPIRV",
"//compiler/src/iree/compiler/Codegen/Utils",
+ "//compiler/src/iree/compiler/Dialect/Flow/IR",
"//compiler/src/iree/compiler/Dialect/HAL/Target",
"//compiler/src/iree/compiler/PluginAPI",
"//compiler/src/iree/compiler/Utils",
diff --git a/compiler/plugins/target/MetalSPIRV/CMakeLists.txt b/compiler/plugins/target/MetalSPIRV/CMakeLists.txt
index b18dde3..4dd1b06 100644
--- a/compiler/plugins/target/MetalSPIRV/CMakeLists.txt
+++ b/compiler/plugins/target/MetalSPIRV/CMakeLists.txt
@@ -38,6 +38,7 @@
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
iree::compiler::Codegen::SPIRV
iree::compiler::Codegen::Utils
+ iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::Target
iree::compiler::PluginAPI
iree::compiler::Utils
diff --git a/compiler/plugins/target/MetalSPIRV/MetalSPIRVTarget.cpp b/compiler/plugins/target/MetalSPIRV/MetalSPIRVTarget.cpp
index 190f79b..6622edf 100644
--- a/compiler/plugins/target/MetalSPIRV/MetalSPIRVTarget.cpp
+++ b/compiler/plugins/target/MetalSPIRV/MetalSPIRVTarget.cpp
@@ -9,6 +9,7 @@
#include "compiler/plugins/target/MetalSPIRV/SPIRVToMSL.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
#include "iree/compiler/Codegen/SPIRV/Passes.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
#include "iree/compiler/PluginAPI/Client.h"
#include "iree/compiler/Utils/FlatbufferUtils.h"
diff --git a/compiler/plugins/target/WebGPUSPIRV/CMakeLists.txt b/compiler/plugins/target/WebGPUSPIRV/CMakeLists.txt
index 9831fb6..d98dcf2 100644
--- a/compiler/plugins/target/WebGPUSPIRV/CMakeLists.txt
+++ b/compiler/plugins/target/WebGPUSPIRV/CMakeLists.txt
@@ -49,6 +49,7 @@
SPIRV-Tools
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
iree::compiler::Codegen::SPIRV
+ iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::Target
iree::compiler::PluginAPI
iree::compiler::Utils
diff --git a/compiler/plugins/target/WebGPUSPIRV/WebGPUSPIRVTarget.cpp b/compiler/plugins/target/WebGPUSPIRV/WebGPUSPIRVTarget.cpp
index c8d6901..c2dd7c0 100644
--- a/compiler/plugins/target/WebGPUSPIRV/WebGPUSPIRVTarget.cpp
+++ b/compiler/plugins/target/WebGPUSPIRV/WebGPUSPIRVTarget.cpp
@@ -8,6 +8,7 @@
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
#include "iree/compiler/Codegen/SPIRV/Passes.h"
#include "iree/compiler/Codegen/WGSL/Passes.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
#include "iree/compiler/PluginAPI/Client.h"
diff --git a/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp b/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp
index b84201b..e648f09 100644
--- a/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp
+++ b/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp
@@ -345,15 +345,13 @@
pluginActivationStatus = pluginSession.activatePlugins(&context);
// Initialize target registry, bootstrapping with the static globals.
- // TODO(15468): remove the static registration mechanism so the merge
- // from global is not required.
targetRegistry.mergeFrom(IREE::HAL::TargetRegistry::getGlobal());
- IREE::HAL::TargetBackendList pluginTargetBackendList;
- pluginSession.populateHALTargetBackends(pluginTargetBackendList);
- targetRegistry.mergeFrom(pluginTargetBackendList);
IREE::HAL::TargetDeviceList pluginTargetDeviceList;
pluginSession.populateHALTargetDevices(pluginTargetDeviceList);
targetRegistry.mergeFrom(pluginTargetDeviceList);
+ IREE::HAL::TargetBackendList pluginTargetBackendList;
+ pluginSession.populateHALTargetBackends(pluginTargetBackendList);
+ targetRegistry.mergeFrom(pluginTargetBackendList);
}
}
return pluginActivationStatus;
diff --git a/compiler/src/iree/compiler/API/Internal/IREEOptToolEntryPoint.cpp b/compiler/src/iree/compiler/API/Internal/IREEOptToolEntryPoint.cpp
index 9307b84..b621724 100644
--- a/compiler/src/iree/compiler/API/Internal/IREEOptToolEntryPoint.cpp
+++ b/compiler/src/iree/compiler/API/Internal/IREEOptToolEntryPoint.cpp
@@ -99,14 +99,14 @@
// of target backends. However, no such layering exists for the opt tool.
// Since it tests passes that are default initialized, we just configure the
// global registry that such constructors depend on.
- TargetBackendList pluginTargetBackendList;
- pluginSession.populateHALTargetBackends(pluginTargetBackendList);
- const_cast<TargetRegistry &>(TargetRegistry::getGlobal())
- .mergeFrom(pluginTargetBackendList);
TargetDeviceList pluginTargetDeviceList;
pluginSession.populateHALTargetDevices(pluginTargetDeviceList);
const_cast<TargetRegistry &>(TargetRegistry::getGlobal())
.mergeFrom(pluginTargetDeviceList);
+ TargetBackendList pluginTargetBackendList;
+ pluginSession.populateHALTargetBackends(pluginTargetBackendList);
+ const_cast<TargetRegistry &>(TargetRegistry::getGlobal())
+ .mergeFrom(pluginTargetBackendList);
// When reading from stdin and the input is a tty, it is often a user mistake
// and the process "appears to be stuck". Print a message to let the user know
diff --git a/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp b/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp
index 9df3295..ae9eaf4 100644
--- a/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp
+++ b/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp
@@ -7,6 +7,7 @@
#include "iree/compiler/ConstEval/PassDetail.h"
#include "iree/compiler/ConstEval/Passes.h"
#include "iree/compiler/ConstEval/Runtime.h"
+#include "iree/compiler/Dialect/HAL/Target/TargetOptions.h"
#include "iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.h"
#include "iree/compiler/Dialect/Util/Analysis/Constant/OpOracle.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Target/BUILD.bazel
index 91224fa..f04a329 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/BUILD.bazel
@@ -17,11 +17,13 @@
srcs = [
"TargetBackend.cpp",
"TargetDevice.cpp",
+ "TargetOptions.cpp",
"TargetRegistry.cpp",
],
hdrs = [
"TargetBackend.h",
"TargetDevice.h",
+ "TargetOptions.h",
"TargetRegistry.h",
],
deps = [
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Target/CMakeLists.txt
index 1db6d8f..b8b5e13 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/CMakeLists.txt
@@ -16,10 +16,12 @@
HDRS
"TargetBackend.h"
"TargetDevice.h"
+ "TargetOptions.h"
"TargetRegistry.h"
SRCS
"TargetBackend.cpp"
"TargetDevice.cpp"
+ "TargetOptions.cpp"
"TargetRegistry.cpp"
DEPS
LLVMSupport
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp
index 581ac6f..a668bd5 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp
@@ -9,88 +9,13 @@
#include <algorithm>
#include "llvm/ADT/StringExtras.h"
-#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Path.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Support/FileUtilities.h"
-IREE_DEFINE_COMPILER_OPTION_FLAGS(
- mlir::iree_compiler::IREE::HAL::TargetOptions);
-
namespace mlir::iree_compiler::IREE::HAL {
-void TargetOptions::bindOptions(OptionsBinder &binder) {
- static llvm::cl::OptionCategory halTargetOptionsCategory(
- "IREE HAL executable target options");
-
- // This function is called as part of registering the pass
- // TranslateExecutablesPass. Pass registry is also staticly
- // initialized, so targetBackendsFlags needs to be here to be initialized
- // first.
- binder.list<std::string>(
- "iree-hal-target-backends", targets,
- llvm::cl::desc("Target backends for executable compilation."),
- llvm::cl::ZeroOrMore, llvm::cl::cat(halTargetOptionsCategory));
-
- binder.opt<int>(
- "iree-hal-executable-debug-level", debugLevel,
- llvm::cl::desc("Debug level for executable translation (0-3)."),
- llvm::cl::cat(halTargetOptionsCategory));
-
- binder.opt<std::string>(
- "iree-hal-dump-executable-files-to", executableFilesPath,
- llvm::cl::desc(
- "Meta flag for all iree-hal-dump-executable-* options. Path to write "
- "executable files (sources, benchmarks, intermediates, binaries) "
- "to."),
- llvm::cl::callback([&](const std::string &path) {
- if (executableSourcesPath.empty())
- executableSourcesPath = path;
- if (executableConfigurationsPath.empty())
- executableConfigurationsPath = path;
- if (executableBenchmarksPath.empty())
- executableBenchmarksPath = path;
- if (executableIntermediatesPath.empty())
- executableIntermediatesPath = path;
- if (executableBinariesPath.empty())
- executableBinariesPath = path;
- }),
- llvm::cl::cat(halTargetOptionsCategory));
-
- binder.opt<std::string>(
- "iree-hal-dump-executable-sources-to", executableSourcesPath,
- llvm::cl::desc("Path to write individual hal.executable input "
- "source listings into (- for stdout)."),
- llvm::cl::cat(halTargetOptionsCategory));
-
- binder.opt<std::string>(
- "iree-hal-dump-executable-configurations-to",
- executableConfigurationsPath,
- llvm::cl::desc("Path to write individual hal.executable input source "
- "listings into, after translation strategy selection and "
- "before starting translation (- for stdout)."),
- llvm::cl::cat(halTargetOptionsCategory));
-
- binder.opt<std::string>(
- "iree-hal-dump-executable-benchmarks-to", executableBenchmarksPath,
- llvm::cl::desc("Path to write standalone hal.executable benchmarks into "
- "(- for stdout)."),
- llvm::cl::cat(halTargetOptionsCategory));
-
- binder.opt<std::string>("iree-hal-dump-executable-intermediates-to",
- executableIntermediatesPath,
- llvm::cl::desc("Path to write translated executable "
- "intermediates (.bc, .o, etc) into."),
- llvm::cl::cat(halTargetOptionsCategory));
-
- binder.opt<std::string>(
- "iree-hal-dump-executable-binaries-to", executableBinariesPath,
- llvm::cl::desc(
- "Path to write translated and serialized executable binaries into."),
- llvm::cl::cat(halTargetOptionsCategory));
-}
-
SmallVector<std::string>
gatherExecutableTargetNames(IREE::HAL::ExecutableOp executableOp) {
SmallVector<std::string> targetNames;
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetBackend.h b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetBackend.h
index bff153c..388fafd 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetBackend.h
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetBackend.h
@@ -12,9 +12,7 @@
#include <string>
#include <vector>
-#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
-#include "iree/compiler/Utils/OptionUtils.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Dialect.h"
@@ -22,45 +20,6 @@
namespace mlir::iree_compiler::IREE::HAL {
-// TODO(benvanik): remove this and replace with the pass pipeline options.
-// Controls executable translation targets.
-struct TargetOptions {
- // TODO(benvanik): multiple targets of the same type, etc.
- std::vector<std::string> targets;
-
- // Coarse debug level for executable translation across all targets.
- // Each target backend can use this to control its own flags, with values
- // generally corresponding to the gcc-style levels 0-3:
- // 0: no debug information
- // 1: minimal debug information
- // 2: default debug information
- // 3: maximal debug information
- int debugLevel = 2;
-
- // Default path to write executable files into.
- std::string executableFilesPath;
-
- // A path to write individual executable source listings into (before
- // configuration).
- std::string executableSourcesPath;
-
- // A path to write individual executable source listings into (after
- // configuration).
- std::string executableConfigurationsPath;
-
- // A path to write standalone executable benchmarks into.
- std::string executableBenchmarksPath;
-
- // A path to write executable intermediates into.
- std::string executableIntermediatesPath;
-
- // A path to write translated and serialized executable binaries into.
- std::string executableBinariesPath;
-
- void bindOptions(OptionsBinder &binder);
- using FromFlags = OptionsFromFlags<TargetOptions>;
-};
-
// HAL executable target backend interface.
// Multiple backends can be registered and targeted during a single compilation.
// The flow->hal conversion process will use registered TargetBackend interfaces
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetDevice.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetDevice.cpp
index 8641a33..bace88a 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetDevice.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetDevice.cpp
@@ -6,8 +6,6 @@
#include "iree/compiler/Dialect/HAL/Target/TargetDevice.h"
-#include "mlir/IR/Dialect.h"
-
namespace mlir::iree_compiler::IREE::HAL {
// TODO(benvanik): add device options.
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetOptions.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetOptions.cpp
new file mode 100644
index 0000000..c00cb63
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetOptions.cpp
@@ -0,0 +1,87 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/HAL/Target/TargetOptions.h"
+
+#include "llvm/Support/CommandLine.h"
+
+IREE_DEFINE_COMPILER_OPTION_FLAGS(
+ mlir::iree_compiler::IREE::HAL::TargetOptions);
+
+namespace mlir::iree_compiler::IREE::HAL {
+
+void TargetOptions::bindOptions(OptionsBinder &binder) {
+ static llvm::cl::OptionCategory halTargetOptionsCategory(
+ "IREE HAL executable target options");
+
+ // This function is called as part of registering the pass
+ // TranslateExecutablesPass. Pass registry is also staticly
+ // initialized, so targetBackendsFlags needs to be here to be initialized
+ // first.
+ binder.list<std::string>(
+ "iree-hal-target-backends", targets,
+ llvm::cl::desc("Target backends for executable compilation."),
+ llvm::cl::ZeroOrMore, llvm::cl::cat(halTargetOptionsCategory));
+
+ binder.opt<int>(
+ "iree-hal-executable-debug-level", debugLevel,
+ llvm::cl::desc("Debug level for executable translation (0-3)"),
+ llvm::cl::init(2), llvm::cl::cat(halTargetOptionsCategory));
+
+ binder.opt<std::string>(
+ "iree-hal-dump-executable-files-to", executableFilesPath,
+ llvm::cl::desc(
+ "Meta flag for all iree-hal-dump-executable-* options. Path to write "
+ "executable files (sources, benchmarks, intermediates, binaries) "
+ "to."),
+ llvm::cl::callback([&](const std::string &path) {
+ if (executableSourcesPath.empty())
+ executableSourcesPath = path;
+ if (executableConfigurationsPath.empty())
+ executableConfigurationsPath = path;
+ if (executableBenchmarksPath.empty())
+ executableBenchmarksPath = path;
+ if (executableIntermediatesPath.empty())
+ executableIntermediatesPath = path;
+ if (executableBinariesPath.empty())
+ executableBinariesPath = path;
+ }),
+ llvm::cl::cat(halTargetOptionsCategory));
+
+ binder.opt<std::string>(
+ "iree-hal-dump-executable-sources-to", executableSourcesPath,
+ llvm::cl::desc("Path to write individual hal.executable input "
+ "source listings into (- for stdout)."),
+ llvm::cl::cat(halTargetOptionsCategory));
+
+ binder.opt<std::string>(
+ "iree-hal-dump-executable-configurations-to",
+ executableConfigurationsPath,
+ llvm::cl::desc("Path to write individual hal.executable input source "
+ "listings into, after translation strategy selection and "
+ "before starting translation (- for stdout)."),
+ llvm::cl::cat(halTargetOptionsCategory));
+
+ binder.opt<std::string>(
+ "iree-hal-dump-executable-benchmarks-to", executableBenchmarksPath,
+ llvm::cl::desc("Path to write standalone hal.executable benchmarks into "
+ "(- for stdout)."),
+ llvm::cl::cat(halTargetOptionsCategory));
+
+ binder.opt<std::string>("iree-hal-dump-executable-intermediates-to",
+ executableIntermediatesPath,
+ llvm::cl::desc("Path to write translated executable "
+ "intermediates (.bc, .o, etc) into."),
+ llvm::cl::cat(halTargetOptionsCategory));
+
+ binder.opt<std::string>(
+ "iree-hal-dump-executable-binaries-to", executableBinariesPath,
+ llvm::cl::desc(
+ "Path to write translated and serialized executable binaries into."),
+ llvm::cl::cat(halTargetOptionsCategory));
+}
+
+} // namespace mlir::iree_compiler::IREE::HAL
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetOptions.h b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetOptions.h
new file mode 100644
index 0000000..711e0e1
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetOptions.h
@@ -0,0 +1,58 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_HAL_TARGET_TARGETOPTIONS_H_
+#define IREE_COMPILER_DIALECT_HAL_TARGET_TARGETOPTIONS_H_
+
+#include <string>
+#include <vector>
+
+#include "iree/compiler/Utils/OptionUtils.h"
+
+namespace mlir::iree_compiler::IREE::HAL {
+
+// TODO(benvanik): remove this and replace with the pass pipeline options.
+// Controls executable translation targets.
+struct TargetOptions {
+ // TODO(benvanik): multiple targets of the same type, etc.
+ std::vector<std::string> targets;
+
+ // Coarse debug level for executable translation across all targets.
+ // Each target backend can use this to control its own flags, with values
+ // generally corresponding to the gcc-style levels 0-3:
+ // 0: no debug information
+ // 1: minimal debug information
+ // 2: default debug information
+ // 3: maximal debug information
+ int debugLevel;
+
+ // Default path to write executable files into.
+ std::string executableFilesPath;
+
+ // A path to write individual executable source listings into (before
+ // configuration).
+ std::string executableSourcesPath;
+
+ // A path to write individual executable source listings into (after
+ // configuration).
+ std::string executableConfigurationsPath;
+
+ // A path to write standalone executable benchmarks into.
+ std::string executableBenchmarksPath;
+
+ // A path to write executable intermediates into.
+ std::string executableIntermediatesPath;
+
+ // A path to write translated and serialized executable binaries into.
+ std::string executableBinariesPath;
+
+ void bindOptions(OptionsBinder &binder);
+ using FromFlags = OptionsFromFlags<TargetOptions>;
+};
+
+} // namespace mlir::iree_compiler::IREE::HAL
+
+#endif // IREE_COMPILER_DIALECT_HAL_TARGET_TARGETOPTIONS_H_
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetRegistry.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetRegistry.cpp
index f24a3c4..34687e4 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetRegistry.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetRegistry.cpp
@@ -11,79 +11,43 @@
namespace mlir::iree_compiler::IREE::HAL {
//===----------------------------------------------------------------------===//
-// TargetRegistration
+// TargetRegistry
//===----------------------------------------------------------------------===//
-// Returns the static registry of translator names to translation functions.
-static TargetRegistry &getMutableTargetRegistry() {
+// static
+TargetRegistry &TargetRegistry::getMutableTargetRegistry() {
static TargetRegistry global;
return global;
}
-TargetBackendRegistration::TargetBackendRegistration(
- llvm::StringRef name, TargetFactoryFn<TargetBackend> fn,
- bool registerStaticGlobal)
- : TargetRegistration<TargetBackend>(std::move(fn)) {
- if (registerStaticGlobal) {
- auto ®istry = getMutableTargetRegistry();
- if (registry.backendRegistrations.contains(name)) {
- llvm::report_fatal_error(
- "Attempting to overwrite an existing translation backend");
- }
- assert(initFn &&
- "Attempting to register an empty backend factory function");
- registry.backendRegistrations[name] = this;
- }
-}
-
-TargetDeviceRegistration::TargetDeviceRegistration(
- llvm::StringRef name, TargetFactoryFn<TargetDevice> fn,
- bool registerStaticGlobal)
- : TargetRegistration<TargetDevice>(std::move(fn)) {
- if (registerStaticGlobal) {
- auto ®istry = getMutableTargetRegistry();
- if (registry.deviceRegistrations.contains(name)) {
- llvm::report_fatal_error(
- "Attempting to overwrite an existing target device");
- }
- assert(initFn && "Attempting to register an empty device factory function");
- registry.deviceRegistrations[name] = this;
- }
-}
-
-//===----------------------------------------------------------------------===//
-// TargetRegistry
-//===----------------------------------------------------------------------===//
-
+// static
const TargetRegistry &TargetRegistry::getGlobal() {
return getMutableTargetRegistry();
}
+void TargetRegistry::mergeFrom(const TargetDeviceList &targetDevices) {
+ for (auto &it : targetDevices.entries) {
+ if (deviceRegistrations.contains(it.first)) {
+ llvm::report_fatal_error("Attempting to overwrite an existing device");
+ }
+ auto registration = std::make_unique<TargetDeviceRegistration>(it.second);
+ deviceRegistrations[it.first] = registration.get();
+ ownedDeviceRegistrations.push_back(std::move(registration));
+ }
+}
+
void TargetRegistry::mergeFrom(const TargetBackendList &targetBackends) {
for (auto &it : targetBackends.entries) {
if (backendRegistrations.contains(it.first)) {
llvm::report_fatal_error(
"Attempting to overwrite an existing translation backend");
}
- auto registration = std::make_unique<TargetBackendRegistration>(
- it.first, it.second, /*registerStaticGlobal=*/false);
+ auto registration = std::make_unique<TargetBackendRegistration>(it.second);
backendRegistrations[it.first] = registration.get();
ownedBackendRegistrations.push_back(std::move(registration));
}
}
-void TargetRegistry::mergeFrom(const TargetDeviceList &targetDevices) {
- for (auto &it : targetDevices.entries) {
- if (deviceRegistrations.contains(it.first)) {
- llvm::report_fatal_error("Attempting to overwrite an existing device");
- }
- auto registration = std::make_unique<TargetDeviceRegistration>(
- it.first, it.second, /*registerStaticGlobal=*/false);
- deviceRegistrations[it.first] = registration.get();
- ownedDeviceRegistrations.push_back(std::move(registration));
- }
-}
-
void TargetRegistry::mergeFrom(const TargetRegistry ®istry) {
for (auto &it : registry.deviceRegistrations) {
if (deviceRegistrations.contains(it.first())) {
@@ -100,16 +64,6 @@
}
}
-std::vector<std::string> TargetRegistry::getRegisteredTargetBackends() const {
- std::vector<std::string> result;
- for (auto &entry : backendRegistrations) {
- result.push_back(entry.getKey().str());
- }
- std::sort(result.begin(), result.end(),
- [](const auto &a, const auto &b) { return a < b; });
- return result;
-}
-
std::vector<std::string> TargetRegistry::getRegisteredTargetDevices() const {
std::vector<std::string> result;
for (auto &entry : deviceRegistrations) {
@@ -120,14 +74,14 @@
return result;
}
-std::shared_ptr<TargetBackend>
-TargetRegistry::getTargetBackend(StringRef targetName) const {
+std::vector<std::string> TargetRegistry::getRegisteredTargetBackends() const {
+ std::vector<std::string> result;
for (auto &entry : backendRegistrations) {
- if (entry.getKey() == targetName) {
- return entry.getValue()->acquire();
- }
+ result.push_back(entry.getKey().str());
}
- return {};
+ std::sort(result.begin(), result.end(),
+ [](const auto &a, const auto &b) { return a < b; });
+ return result;
}
std::shared_ptr<TargetDevice>
@@ -140,20 +94,14 @@
return {};
}
-SmallVector<std::shared_ptr<TargetBackend>>
-TargetRegistry::getTargetBackends(ArrayRef<std::string> targetNames) const {
- SmallVector<std::pair<std::string, std::shared_ptr<TargetBackend>>> matches;
- for (auto &targetName : targetNames) {
- auto targetBackend = getTargetBackend(targetName);
- if (targetBackend) {
- matches.push_back(std::make_pair(targetName, std::move(targetBackend)));
+std::shared_ptr<TargetBackend>
+TargetRegistry::getTargetBackend(StringRef targetName) const {
+ for (auto &entry : backendRegistrations) {
+ if (entry.getKey() == targetName) {
+ return entry.getValue()->acquire();
}
}
- // To ensure deterministic builds we sort matches by name.
- std::sort(matches.begin(), matches.end(),
- [](const auto &a, const auto &b) { return a.first < b.first; });
- return llvm::to_vector(llvm::map_range(
- matches, [](auto match) { return std::move(match.second); }));
+ return {};
}
SmallVector<std::shared_ptr<TargetDevice>>
@@ -172,6 +120,22 @@
matches, [](auto match) { return std::move(match.second); }));
}
+SmallVector<std::shared_ptr<TargetBackend>>
+TargetRegistry::getTargetBackends(ArrayRef<std::string> targetNames) const {
+ SmallVector<std::pair<std::string, std::shared_ptr<TargetBackend>>> matches;
+ for (auto &targetName : targetNames) {
+ auto targetBackend = getTargetBackend(targetName);
+ if (targetBackend) {
+ matches.push_back(std::make_pair(targetName, std::move(targetBackend)));
+ }
+ }
+ // To ensure deterministic builds we sort matches by name.
+ std::sort(matches.begin(), matches.end(),
+ [](const auto &a, const auto &b) { return a.first < b.first; });
+ return llvm::to_vector(llvm::map_range(
+ matches, [](auto match) { return std::move(match.second); }));
+}
+
} // namespace mlir::iree_compiler::IREE::HAL
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetRegistry.h b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetRegistry.h
index 21d9f65..a435588 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetRegistry.h
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetRegistry.h
@@ -41,20 +41,8 @@
std::once_flag initFlag;
std::shared_ptr<T> cachedValue;
};
-class TargetBackendRegistration : public TargetRegistration<TargetBackend> {
-public:
- // TODO(#15468): remove the registerStaticGlobal mode once callers are
- // migrated and move the constructor to the template type.
- TargetBackendRegistration(StringRef name, TargetFactoryFn<TargetBackend> fn,
- bool registerStaticGlobal = true);
-};
-class TargetDeviceRegistration : public TargetRegistration<TargetDevice> {
-public:
- // TODO(#15468): remove the registerStaticGlobal mode once callers are
- // migrated and move the constructor to the template type.
- TargetDeviceRegistration(StringRef name, TargetFactoryFn<TargetDevice> fn,
- bool registerStaticGlobal = true);
-};
+using TargetBackendRegistration = TargetRegistration<TargetBackend>;
+using TargetDeviceRegistration = TargetRegistration<TargetDevice>;
template <typename T>
class TargetFactoryList {
@@ -77,50 +65,50 @@
// A concrete target registry.
class TargetRegistry {
public:
+ // Returns the global registry.
+ // This should only be used at initialization-time for built-in devices.
+ // All other usage should use scoped registries.
+ static TargetRegistry &getMutableTargetRegistry();
// Returns the read-only global registry.
// This is used by passes which depend on it from their default constructor.
static const TargetRegistry &getGlobal();
- // Merge from a list of of target backends.
- // The receiving registry will own the registration entries.
- void mergeFrom(const TargetBackendList &targetBackends);
// Merge from a list of of target devices.
// The receiving registry will own the registration entries.
void mergeFrom(const TargetDeviceList &targetDevices);
+ // Merge from a list of of target backends.
+ // The receiving registry will own the registration entries.
+ void mergeFrom(const TargetBackendList &targetBackends);
// Initialize from an existing registry. This registry will not own the
// backing registration entries. The source registry must remain live for the
// life of this.
void mergeFrom(const TargetRegistry ®istry);
- // Returns a list of registered target backends.
- std::vector<std::string> getRegisteredTargetBackends() const;
// Returns a list of registered target devices.
std::vector<std::string> getRegisteredTargetDevices() const;
+ // Returns a list of registered target backends.
+ std::vector<std::string> getRegisteredTargetBackends() const;
- // Returns the target backend with the given name.
- std::shared_ptr<TargetBackend> getTargetBackend(StringRef targetName) const;
// Returns the target device with the given name.
std::shared_ptr<TargetDevice> getTargetDevice(StringRef targetName) const;
+ // Returns the target backend with the given name.
+ std::shared_ptr<TargetBackend> getTargetBackend(StringRef targetName) const;
- // Returns one backend per entry in |targetNames|.
- SmallVector<std::shared_ptr<TargetBackend>>
- getTargetBackends(ArrayRef<std::string> targetNames) const;
// Returns one device per entry in |targetNames|.
SmallVector<std::shared_ptr<TargetDevice>>
getTargetDevices(ArrayRef<std::string> targetNames) const;
+ // Returns one backend per entry in |targetNames|.
+ SmallVector<std::shared_ptr<TargetBackend>>
+ getTargetBackends(ArrayRef<std::string> targetNames) const;
private:
- llvm::StringMap<TargetBackendRegistration *> backendRegistrations;
- llvm::SmallVector<std::unique_ptr<TargetBackendRegistration>>
- ownedBackendRegistrations;
llvm::StringMap<TargetDeviceRegistration *> deviceRegistrations;
llvm::SmallVector<std::unique_ptr<TargetDeviceRegistration>>
ownedDeviceRegistrations;
-
- // TODO(#15468): remove this when not used by LLVMCPU/VulkanSPIRV.
- friend class TargetBackendRegistration;
- friend class TargetDeviceRegistration;
+ llvm::StringMap<TargetBackendRegistration *> backendRegistrations;
+ llvm::SmallVector<std::unique_ptr<TargetBackendRegistration>>
+ ownedBackendRegistrations;
};
} // namespace mlir::iree_compiler::IREE::HAL
@@ -140,9 +128,12 @@
TargetRegistryRef(const mlir::iree_compiler::IREE::HAL::TargetRegistry *value)
: value(value) {}
operator bool() const noexcept {
- return value->getRegisteredTargetBackends() !=
- mlir::iree_compiler::IREE::HAL::TargetRegistry::getGlobal()
- .getRegisteredTargetBackends();
+ return value->getRegisteredTargetDevices() !=
+ mlir::iree_compiler::IREE::HAL::TargetRegistry::getGlobal()
+ .getRegisteredTargetDevices() &&
+ value->getRegisteredTargetBackends() !=
+ mlir::iree_compiler::IREE::HAL::TargetRegistry::getGlobal()
+ .getRegisteredTargetBackends();
}
const mlir::iree_compiler::IREE::HAL::TargetRegistry *operator->() const {
return value;
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.h
index 1305d20..ec5c9c5 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.h
@@ -9,6 +9,8 @@
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/HAL/Target/TargetBackend.h"
+#include "iree/compiler/Dialect/HAL/Target/TargetDevice.h"
+#include "iree/compiler/Dialect/HAL/Target/TargetOptions.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
#include "llvm/ADT/StringMap.h"
#include "mlir/IR/BuiltinOps.h"
diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/Passes.h b/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/Passes.h
index ed0c8d9..111af4d 100644
--- a/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/Passes.h
@@ -8,6 +8,7 @@
#define IREE_COMPILER_MODULES_HAL_INLINE_TRANSFORMS_PASSES_H_
#include "iree/compiler/Dialect/HAL/Target/TargetBackend.h"
+#include "iree/compiler/Dialect/HAL/Target/TargetOptions.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
#include "iree/compiler/Modules/HAL/Inline/IR/HALInlineOps.h"
#include "llvm/ADT/StringMap.h"
diff --git a/compiler/src/iree/compiler/Modules/HAL/Loader/Transforms/Passes.h b/compiler/src/iree/compiler/Modules/HAL/Loader/Transforms/Passes.h
index 6f65082..d956f24 100644
--- a/compiler/src/iree/compiler/Modules/HAL/Loader/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Modules/HAL/Loader/Transforms/Passes.h
@@ -8,6 +8,7 @@
#define IREE_COMPILER_MODULES_HAL_LOADER_TRANSFORMS_PASSES_H_
#include "iree/compiler/Dialect/HAL/Target/TargetBackend.h"
+#include "iree/compiler/Dialect/HAL/Target/TargetOptions.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
#include "iree/compiler/Modules/HAL/Loader/IR/HALLoaderOps.h"
#include "llvm/ADT/StringMap.h"
diff --git a/compiler/src/iree/compiler/Pipelines/Pipelines.h b/compiler/src/iree/compiler/Pipelines/Pipelines.h
index 1b06390..1fba747 100644
--- a/compiler/src/iree/compiler/Pipelines/Pipelines.h
+++ b/compiler/src/iree/compiler/Pipelines/Pipelines.h
@@ -7,6 +7,7 @@
#ifndef IREE_COMPILER_PIPELINES_PIPELINES_H_
#define IREE_COMPILER_PIPELINES_PIPELINES_H_
+#include "iree/compiler/Dialect/HAL/Target/TargetOptions.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
#include "iree/compiler/Dialect/VM/Conversion/TargetOptions.h"
#include "iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.h"
diff --git a/compiler/src/iree/compiler/PluginAPI/Client.h b/compiler/src/iree/compiler/PluginAPI/Client.h
index 3a9652d..0272e25 100644
--- a/compiler/src/iree/compiler/PluginAPI/Client.h
+++ b/compiler/src/iree/compiler/PluginAPI/Client.h
@@ -179,15 +179,15 @@
// it should emit an appropriate diagnostic.
LogicalResult activate(MLIRContext *context);
+ // Populates new HAL target devices, if any, into the given list.
+ // Targets will be merged into the plugin session-owned registry.
+ virtual void populateHALTargetDevices(IREE::HAL::TargetDeviceList &targets) {}
+
// Populates new HAL target backends, if any, into the given list.
// Targets will be merged into the plugin session-owned registry.
virtual void
populateHALTargetBackends(IREE::HAL::TargetBackendList &targets) {}
- // Populates new HAL target devices, if any, into the given list.
- // Targets will be merged into the plugin session-owned registry.
- virtual void populateHALTargetDevices(IREE::HAL::TargetDeviceList &targets) {}
-
protected:
// Called from registerDialects() prior to initializing the context and
// prior to onActivate().
diff --git a/compiler/src/iree/compiler/PluginAPI/PluginManager.cpp b/compiler/src/iree/compiler/PluginAPI/PluginManager.cpp
index f2b5628..3f37b9f 100644
--- a/compiler/src/iree/compiler/PluginAPI/PluginManager.cpp
+++ b/compiler/src/iree/compiler/PluginAPI/PluginManager.cpp
@@ -193,13 +193,6 @@
return success();
}
-void PluginManagerSession::populateHALTargetBackends(
- IREE::HAL::TargetBackendList &list) {
- for (auto *s : initializedSessions) {
- s->populateHALTargetBackends(list);
- }
-}
-
void PluginManagerSession::populateHALTargetDevices(
IREE::HAL::TargetDeviceList &list) {
for (auto *s : initializedSessions) {
@@ -207,4 +200,11 @@
}
}
+void PluginManagerSession::populateHALTargetBackends(
+ IREE::HAL::TargetBackendList &list) {
+ for (auto *s : initializedSessions) {
+ s->populateHALTargetBackends(list);
+ }
+}
+
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/PluginAPI/PluginManager.h b/compiler/src/iree/compiler/PluginAPI/PluginManager.h
index 7d5c4ff..7f820dc 100644
--- a/compiler/src/iree/compiler/PluginAPI/PluginManager.h
+++ b/compiler/src/iree/compiler/PluginAPI/PluginManager.h
@@ -133,14 +133,14 @@
}
}
- // Populates the given list of HAL target backends for all initialized
- // plugins.
- void populateHALTargetBackends(IREE::HAL::TargetBackendList &list);
-
// Populates the given list of HAL target devices for all initialized
// plugins.
void populateHALTargetDevices(IREE::HAL::TargetDeviceList &list);
+ // Populates the given list of HAL target backends for all initialized
+ // plugins.
+ void populateHALTargetBackends(IREE::HAL::TargetBackendList &list);
+
private:
PluginManagerOptions &options;
// At construction, uninitialized plugin sessions are created for all