Moving the LocalDevice impl out of LLVM-CPU/VMVX.
It's still registered there for backwards compatibility with the
iree-hal-target-backends flag.
diff --git a/compiler/plugins/target/LLVMCPU/BUILD.bazel b/compiler/plugins/target/LLVMCPU/BUILD.bazel
index f8fe708..1cbc486 100644
--- a/compiler/plugins/target/LLVMCPU/BUILD.bazel
+++ b/compiler/plugins/target/LLVMCPU/BUILD.bazel
@@ -37,6 +37,7 @@
"//compiler/src/iree/compiler/Codegen/LLVMCPU",
"//compiler/src/iree/compiler/Codegen/Utils",
"//compiler/src/iree/compiler/Dialect/HAL/Target",
+ "//compiler/src/iree/compiler/Dialect/HAL/Target/Devices",
"//compiler/src/iree/compiler/Dialect/HAL/Utils:LLVMLinkerUtils",
"//compiler/src/iree/compiler/Dialect/LinalgExt/IR",
"//compiler/src/iree/compiler/Dialect/Util/IR",
diff --git a/compiler/plugins/target/LLVMCPU/CMakeLists.txt b/compiler/plugins/target/LLVMCPU/CMakeLists.txt
index cd4d3fb..cf27486 100644
--- a/compiler/plugins/target/LLVMCPU/CMakeLists.txt
+++ b/compiler/plugins/target/LLVMCPU/CMakeLists.txt
@@ -56,6 +56,7 @@
iree::compiler::Codegen::LLVMCPU
iree::compiler::Codegen::Utils
iree::compiler::Dialect::HAL::Target
+ iree::compiler::Dialect::HAL::Target::Devices
iree::compiler::Dialect::HAL::Utils::LLVMLinkerUtils
iree::compiler::Dialect::LinalgExt::IR
iree::compiler::Dialect::Util::IR
diff --git a/compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp b/compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp
index fbcb44c..864e6f6 100644
--- a/compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp
+++ b/compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp
@@ -19,8 +19,9 @@
#include "iree/compiler/Codegen/LLVMCPU/Passes.h"
#include "iree/compiler/Codegen/LLVMCPU/Utils.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
-#include "iree/compiler/Dialect/HAL/Utils/LLVMLinkerUtils.h"
+#include "iree/compiler/Dialect/HAL/Target/Devices/LocalDevice.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
+#include "iree/compiler/Dialect/HAL/Utils/LLVMLinkerUtils.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/PluginAPI/Client.h"
#include "iree/compiler/Utils/ModuleUtils.h"
@@ -122,51 +123,6 @@
return success();
}
-class LLVMCPUTargetDevice final : public TargetDevice {
-public:
- LLVMCPUTargetDevice() = default;
-
- IREE::HAL::DeviceTargetAttr
- getDefaultDeviceTarget(MLIRContext *context,
- const TargetRegistry &targetRegistry) const override {
- Builder b(context);
- SmallVector<NamedAttribute> configItems;
-
- auto configAttr = b.getDictionaryAttr(configItems);
-
- // If we had multiple target environments we would generate one target attr
- // per environment, with each setting its own environment attribute.
- SmallVector<IREE::HAL::ExecutableTargetAttr> executableTargetAttrs;
- targetRegistry.getTargetBackend("llvm-cpu")
- ->getDefaultExecutableTargets(context, "llvm-cpu", configAttr,
- executableTargetAttrs);
-
- return IREE::HAL::DeviceTargetAttr::get(context,
- b.getStringAttr("llvm-cpu"),
- configAttr, executableTargetAttrs);
- }
-
- std::optional<IREE::HAL::DeviceTargetAttr>
- getHostDeviceTarget(MLIRContext *context,
- const TargetRegistry &targetRegistry) const override {
- Builder b(context);
- SmallVector<NamedAttribute> configItems;
-
- auto configAttr = b.getDictionaryAttr(configItems);
-
- // If we had multiple target environments we would generate one target attr
- // per environment, with each setting its own environment attribute.
- SmallVector<IREE::HAL::ExecutableTargetAttr> executableTargetAttrs;
- targetRegistry.getTargetBackend("llvm-cpu")
- ->getHostExecutableTargets(context, "llvm-cpu", configAttr,
- executableTargetAttrs);
-
- return IREE::HAL::DeviceTargetAttr::get(context,
- b.getStringAttr("llvm-cpu"),
- configAttr, executableTargetAttrs);
- }
-};
-
class LLVMCPUTargetBackend final : public TargetBackend {
public:
explicit LLVMCPUTargetBackend(LLVMTargetOptions options)
@@ -858,9 +814,17 @@
: public PluginSession<LLVMCPUSession, LLVMCPUTargetCLOptions,
PluginActivationPolicy::DefaultActivated> {
void populateHALTargetDevices(IREE::HAL::TargetDeviceList &targets) {
+ // TODO(multi-device): move local device registration out.
+ // This exists here for backwards compat with the old
+ // iree-hal-target-backends flag that needs to look up the device by backend
+ // name.
// #hal.device.target<"llvm-cpu", ...
- targets.add("llvm-cpu",
- [=]() { return std::make_shared<LLVMCPUTargetDevice>(); });
+ targets.add("llvm-cpu", [=]() {
+ LocalDevice::Options localDeviceOptions;
+ localDeviceOptions.defaultTargetBackends.push_back("llvm-cpu");
+ localDeviceOptions.defaultHostBackends.push_back("llvm-cpu");
+ return std::make_shared<LocalDevice>(localDeviceOptions);
+ });
}
void populateHALTargetBackends(IREE::HAL::TargetBackendList &targets) {
// #hal.executable.target<"llvm-cpu", ...
diff --git a/compiler/plugins/target/VMVX/BUILD.bazel b/compiler/plugins/target/VMVX/BUILD.bazel
index 6d6c7d7..31d8b87 100644
--- a/compiler/plugins/target/VMVX/BUILD.bazel
+++ b/compiler/plugins/target/VMVX/BUILD.bazel
@@ -26,6 +26,7 @@
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
"//compiler/src/iree/compiler/Codegen/VMVX",
"//compiler/src/iree/compiler/Dialect/HAL/Target",
+ "//compiler/src/iree/compiler/Dialect/HAL/Target/Devices",
"//compiler/src/iree/compiler/Dialect/LinalgExt/IR",
"//compiler/src/iree/compiler/Dialect/VM/Conversion",
"//compiler/src/iree/compiler/Dialect/VM/IR",
diff --git a/compiler/plugins/target/VMVX/CMakeLists.txt b/compiler/plugins/target/VMVX/CMakeLists.txt
index a3124f3..7af0a88 100644
--- a/compiler/plugins/target/VMVX/CMakeLists.txt
+++ b/compiler/plugins/target/VMVX/CMakeLists.txt
@@ -30,6 +30,7 @@
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
iree::compiler::Codegen::VMVX
iree::compiler::Dialect::HAL::Target
+ iree::compiler::Dialect::HAL::Target::Devices
iree::compiler::Dialect::LinalgExt::IR
iree::compiler::Dialect::VM::Conversion
iree::compiler::Dialect::VM::IR
diff --git a/compiler/plugins/target/VMVX/VMVXTarget.cpp b/compiler/plugins/target/VMVX/VMVXTarget.cpp
index f7e3a8d..7620b5a 100644
--- a/compiler/plugins/target/VMVX/VMVXTarget.cpp
+++ b/compiler/plugins/target/VMVX/VMVXTarget.cpp
@@ -6,6 +6,7 @@
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
#include "iree/compiler/Codegen/VMVX/Passes.h"
+#include "iree/compiler/Dialect/HAL/Target/Devices/LocalDevice.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/VM/Conversion/ConversionTarget.h"
@@ -53,38 +54,6 @@
b.getDictionaryAttr(configItems));
}
-// TODO(benvanik): move to a CPU device registration outside of VMVX.
-class VMVXTargetDevice final : public TargetDevice {
-public:
- VMVXTargetDevice() = default;
-
- IREE::HAL::DeviceTargetAttr
- getDefaultDeviceTarget(MLIRContext *context,
- const TargetRegistry &targetRegistry) const override {
- Builder b(context);
- SmallVector<NamedAttribute> configItems;
-
- auto configAttr = b.getDictionaryAttr(configItems);
-
- // If we had multiple target environments we would generate one target attr
- // per environment, with each setting its own environment attribute.
- // If we had multiple target environments we would generate one target attr
- // per environment, with each setting its own environment attribute.
- SmallVector<IREE::HAL::ExecutableTargetAttr> executableTargetAttrs;
- targetRegistry.getTargetBackend("vmvx")->getDefaultExecutableTargets(
- context, "vmvx", configAttr, executableTargetAttrs);
-
- return IREE::HAL::DeviceTargetAttr::get(context, b.getStringAttr("vmvx"),
- configAttr, executableTargetAttrs);
- }
-
- std::optional<IREE::HAL::DeviceTargetAttr>
- getHostDeviceTarget(MLIRContext *context,
- const TargetRegistry &targetRegistry) const override {
- return getDefaultDeviceTarget(context, targetRegistry);
- }
-};
-
class VMVXTargetBackend final : public TargetBackend {
public:
VMVXTargetBackend(const VMVXOptions &options) : options(options) {}
@@ -265,20 +234,28 @@
: public PluginSession<VMVXSession, VMVXOptions,
PluginActivationPolicy::DefaultActivated> {
void populateHALTargetDevices(IREE::HAL::TargetDeviceList &targets) {
- // TODO(benvanik): move to a CPU device registration outside of VMVX. Note
- // that the inline device does need to be special.
+ // TODO(multi-device): move local device registration out.
+ // This exists here for backwards compat with the old
+ // iree-hal-target-backends flag that needs to look up the device by backend
+ // name.
+ // Note that the inline device does need to be special.
// #hal.device.target<"vmvx", ...
- targets.add("vmvx", [&]() { return std::make_shared<VMVXTargetDevice>(); });
+ targets.add("vmvx", [=]() {
+ LocalDevice::Options localDeviceOptions;
+ localDeviceOptions.defaultTargetBackends.push_back("vmvx");
+ localDeviceOptions.defaultHostBackends.push_back("vmvx");
+ return std::make_shared<LocalDevice>(localDeviceOptions);
+ });
// #hal.device.target<"vmvx-inline", ...
targets.add("vmvx-inline",
- [&]() { return std::make_shared<VMVXInlineTargetDevice>(); });
+ [=]() { return std::make_shared<VMVXInlineTargetDevice>(); });
}
void populateHALTargetBackends(IREE::HAL::TargetBackendList &targets) {
// #hal.executable.target<"vmvx", ...
targets.add("vmvx",
- [&]() { return std::make_shared<VMVXTargetBackend>(options); });
+ [=]() { return std::make_shared<VMVXTargetBackend>(options); });
// #hal.executable.target<"vmvx-inline", ...
- targets.add("vmvx-inline", [&]() {
+ targets.add("vmvx-inline", [=]() {
return std::make_shared<VMVXInlineTargetBackend>(options);
});
}
diff --git a/compiler/plugins/target/VMVX/test/smoketest.mlir b/compiler/plugins/target/VMVX/test/smoketest.mlir
index ef56028..b640d12 100644
--- a/compiler/plugins/target/VMVX/test/smoketest.mlir
+++ b/compiler/plugins/target/VMVX/test/smoketest.mlir
@@ -2,7 +2,7 @@
module attributes {
hal.device.targets = [
- #hal.device.target<"vmvx", [
+ #hal.device.target<"local", [
#hal.executable.target<"vmvx", "vmvx-bytecode-fb">
]>
]
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp
index 82d38ae..1ff5f24 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp
@@ -206,6 +206,13 @@
return configAttrs;
}
+void DeviceTargetAttr::getExecutableTargets(
+ SetVector<IREE::HAL::ExecutableTargetAttr> &resultAttrs) {
+ for (auto attr : getExecutableTargets()) {
+ resultAttrs.insert(attr);
+ }
+}
+
// Returns a set of all configuration attributes from all device targets.
// Returns nullopt if any target is missing a configuration attribute.
static std::optional<SmallVector<DictionaryAttr>>
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td
index 3221fe4..37ab489 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td
@@ -547,6 +547,10 @@
static std::optional<StaticRange<APInt>>
lookupConfigAttrRange(Operation *op, StringRef name);
+ // Returns zero or more executable targets that this device supports.
+ void getExecutableTargets(
+ SetVector<IREE::HAL::ExecutableTargetAttr> &resultAttrs);
+
// Returns a list of all target executable configurations that may be
// required for the given operation.
static SmallVector<IREE::HAL::ExecutableTargetAttr, 4>
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Target/BUILD.bazel
index f04a329..63207aa 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/BUILD.bazel
@@ -27,13 +27,14 @@
"TargetRegistry.h",
],
deps = [
- "//compiler/src/iree/compiler/Dialect/Flow/IR",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
"//compiler/src/iree/compiler/Dialect/Util/IR",
"//compiler/src/iree/compiler/Utils",
"@llvm-project//llvm:Support",
+ "@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Target/CMakeLists.txt
index b8b5e13..5c6a84b 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/CMakeLists.txt
@@ -25,11 +25,12 @@
"TargetRegistry.cpp"
DEPS
LLVMSupport
+ MLIRArithDialect
MLIRIR
MLIRPass
+ MLIRSCFDialect
MLIRSupport
MLIRTransforms
- iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::Util::IR
iree::compiler::Utils
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/Devices/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Target/Devices/BUILD.bazel
new file mode 100644
index 0000000..bdfe673
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/Devices/BUILD.bazel
@@ -0,0 +1,33 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_compiler_cc_library(
+ name = "Devices",
+ srcs = [
+ "LocalDevice.cpp",
+ ],
+ hdrs = [
+ "LocalDevice.h",
+ ],
+ deps = [
+ "//compiler/src/iree/compiler/Dialect/HAL/IR",
+ "//compiler/src/iree/compiler/Dialect/HAL/Target",
+ "//compiler/src/iree/compiler/Utils",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:Support",
+ "@llvm-project//mlir:Transforms",
+ ],
+)
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/Devices/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Target/Devices/CMakeLists.txt
new file mode 100644
index 0000000..3a996ae
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/Devices/CMakeLists.txt
@@ -0,0 +1,32 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# compiler/src/iree/compiler/Dialect/HAL/Target/Devices/BUILD.bazel #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+ NAME
+ Devices
+ HDRS
+ "LocalDevice.h"
+ SRCS
+ "LocalDevice.cpp"
+ DEPS
+ LLVMSupport
+ MLIRIR
+ MLIRPass
+ MLIRSupport
+ MLIRTransforms
+ iree::compiler::Dialect::HAL::IR
+ iree::compiler::Dialect::HAL::Target
+ iree::compiler::Utils
+ PUBLIC
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/Devices/LocalDevice.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/Devices/LocalDevice.cpp
new file mode 100644
index 0000000..302ad62
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/Devices/LocalDevice.cpp
@@ -0,0 +1,94 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/HAL/Target/Devices/LocalDevice.h"
+
+#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
+#include "llvm/Support/CommandLine.h"
+
+IREE_DEFINE_COMPILER_OPTION_FLAGS(
+ mlir::iree_compiler::IREE::HAL::LocalDevice::Options);
+
+namespace mlir::iree_compiler::IREE::HAL {
+
+void LocalDevice::Options::bindOptions(OptionsBinder &binder) {
+ static llvm::cl::OptionCategory optionsCategory(
+ "IREE HAL local device options");
+
+ binder.list<std::string>(
+ "iree-hal-local-target-device-backends", defaultTargetBackends,
+ llvm::cl::desc(
+ "Default target backends for local device executable compilation."),
+ llvm::cl::ZeroOrMore, llvm::cl::cat(optionsCategory));
+ binder.list<std::string>(
+ "iree-hal-local-host-device-backends", defaultHostBackends,
+ llvm::cl::desc(
+ "Default host backends for local device executable compilation."),
+ llvm::cl::ZeroOrMore, llvm::cl::cat(optionsCategory));
+}
+
+LocalDevice::LocalDevice(const LocalDevice::Options options)
+ : options(options) {}
+
+IREE::HAL::DeviceTargetAttr LocalDevice::getDefaultDeviceTarget(
+ MLIRContext *context, const TargetRegistry &targetRegistry) const {
+ Builder b(context);
+ SmallVector<NamedAttribute> configItems;
+
+ // TODO(benvanik): flags for common queries.
+
+ auto configAttr = b.getDictionaryAttr(configItems);
+
+ SmallVector<IREE::HAL::ExecutableTargetAttr> executableTargetAttrs;
+ for (auto backendName : options.defaultTargetBackends) {
+ auto targetBackend = targetRegistry.getTargetBackend(backendName);
+ if (!targetBackend) {
+ llvm::errs() << "Default target backend not registered: " << backendName
+ << "\n";
+ return {};
+ }
+ targetBackend->getDefaultExecutableTargets(context, "local", configAttr,
+ executableTargetAttrs);
+ }
+
+ return IREE::HAL::DeviceTargetAttr::get(context, b.getStringAttr("local"),
+ configAttr, executableTargetAttrs);
+}
+
+std::optional<IREE::HAL::DeviceTargetAttr>
+LocalDevice::getHostDeviceTarget(MLIRContext *context,
+ const TargetRegistry &targetRegistry) const {
+ Builder b(context);
+ SmallVector<NamedAttribute> configItems;
+
+ // TODO(benvanik): flags for overrides or ask LLVM for info about the host.
+
+ auto configAttr = b.getDictionaryAttr(configItems);
+
+ SmallVector<IREE::HAL::ExecutableTargetAttr> executableTargetAttrs;
+ for (auto backendName : options.defaultHostBackends) {
+ auto targetBackend = targetRegistry.getTargetBackend(backendName);
+ if (!targetBackend) {
+ llvm::errs() << "Default host backend not registered: " << backendName
+ << "\n";
+ return std::nullopt;
+ }
+ targetBackend->getHostExecutableTargets(context, "local", configAttr,
+ executableTargetAttrs);
+ }
+
+ return IREE::HAL::DeviceTargetAttr::get(context, b.getStringAttr("local"),
+ configAttr, executableTargetAttrs);
+}
+
+Value LocalDevice::buildDeviceTargetMatch(
+ Location loc, Value device, IREE::HAL::DeviceTargetAttr targetAttr,
+ OpBuilder &builder) const {
+ return buildDeviceIDAndExecutableFormatsMatch(
+ loc, device, "local*", targetAttr.getExecutableTargets(), builder);
+}
+
+} // namespace mlir::iree_compiler::IREE::HAL
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/Devices/LocalDevice.h b/compiler/src/iree/compiler/Dialect/HAL/Target/Devices/LocalDevice.h
new file mode 100644
index 0000000..8e18374
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/Devices/LocalDevice.h
@@ -0,0 +1,50 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_HAL_TARGET_DEVICE_LOCALDEVICE_H_
+#define IREE_COMPILER_DIALECT_HAL_TARGET_DEVICE_LOCALDEVICE_H_
+
+#include <string>
+#include <vector>
+
+#include "iree/compiler/Dialect/HAL/Target/TargetDevice.h"
+#include "iree/compiler/Utils/OptionUtils.h"
+
+namespace mlir::iree_compiler::IREE::HAL {
+
+class LocalDevice final : public TargetDevice {
+public:
+ struct Options {
+ // A list of default target backends for local devices.
+ std::vector<std::string> defaultTargetBackends;
+ // A list of default host backends for local devices.
+ std::vector<std::string> defaultHostBackends;
+
+ void bindOptions(OptionsBinder &binder);
+ using FromFlags = OptionsFromFlags<Options>;
+ };
+
+ explicit LocalDevice(const Options options);
+
+ IREE::HAL::DeviceTargetAttr
+ getDefaultDeviceTarget(MLIRContext *context,
+ const TargetRegistry &targetRegistry) const override;
+
+ std::optional<IREE::HAL::DeviceTargetAttr>
+ getHostDeviceTarget(MLIRContext *context,
+ const TargetRegistry &targetRegistry) const override;
+
+ Value buildDeviceTargetMatch(Location loc, Value device,
+ IREE::HAL::DeviceTargetAttr targetAttr,
+ OpBuilder &builder) const override;
+
+private:
+ const Options options;
+};
+
+} // namespace mlir::iree_compiler::IREE::HAL
+
+#endif // IREE_COMPILER_DIALECT_HAL_TARGET_DEVICE_LOCALDEVICE_H_
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetDevice.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetDevice.cpp
index bace88a..1695c50 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetDevice.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetDevice.cpp
@@ -6,8 +6,66 @@
#include "iree/compiler/Dialect/HAL/Target/TargetDevice.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+
namespace mlir::iree_compiler::IREE::HAL {
-// TODO(benvanik): add device options.
+// virtual
+Value TargetDevice::buildDeviceTargetMatch(
+ Location loc, Value device, IREE::HAL::DeviceTargetAttr targetAttr,
+ OpBuilder &builder) const {
+ return buildDeviceIDAndExecutableFormatsMatch(
+ loc, device, targetAttr.getDeviceID(), targetAttr.getExecutableTargets(),
+ builder);
+}
+
+Value buildDeviceIDAndExecutableFormatsMatch(
+ Location loc, Value device, StringRef deviceIDPattern,
+ ArrayRef<IREE::HAL::ExecutableTargetAttr> executableTargetAttrs,
+ OpBuilder &builder) {
+ // Match first on the device ID, as that's the top-level filter.
+ Value idMatch = IREE::HAL::DeviceQueryOp::createI1(
+ loc, device, "hal.device.id", deviceIDPattern, builder);
+
+ // If there are executable formats defined we should check at least one of
+ // them is supported.
+ if (executableTargetAttrs.empty()) {
+ return idMatch; // just device ID
+ } else {
+ auto ifOp = builder.create<scf::IfOp>(loc, builder.getI1Type(), idMatch,
+ true, true);
+ auto thenBuilder = ifOp.getThenBodyBuilder();
+ Value anyFormatMatch = buildExecutableFormatMatch(
+ loc, device, executableTargetAttrs, thenBuilder);
+ thenBuilder.create<scf::YieldOp>(loc, anyFormatMatch);
+ auto elseBuilder = ifOp.getElseBodyBuilder();
+ Value falseValue = elseBuilder.create<arith::ConstantIntOp>(loc, 0, 1);
+ elseBuilder.create<scf::YieldOp>(loc, falseValue);
+ return ifOp.getResult(0);
+ }
+}
+
+Value buildExecutableFormatMatch(
+ Location loc, Value device,
+ ArrayRef<IREE::HAL::ExecutableTargetAttr> executableTargetAttrs,
+ OpBuilder &builder) {
+ if (executableTargetAttrs.empty())
+ return builder.create<arith::ConstantIntOp>(loc, 1, 1);
+ Value anyFormatMatch;
+ for (auto executableTargetAttr : executableTargetAttrs) {
+ Value formatMatch = IREE::HAL::DeviceQueryOp::createI1(
+ loc, device, "hal.executable.format",
+ executableTargetAttr.getFormat().getValue(), builder);
+ if (!anyFormatMatch) {
+ anyFormatMatch = formatMatch;
+ } else {
+ anyFormatMatch =
+ builder.create<arith::OrIOp>(loc, anyFormatMatch, formatMatch);
+ }
+ }
+ return anyFormatMatch;
+}
} // namespace mlir::iree_compiler::IREE::HAL
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetDevice.h b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetDevice.h
index 357e51f..2ce53f4 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetDevice.h
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetDevice.h
@@ -38,10 +38,31 @@
return {};
}
+ // Builds an expression that returns an i1 indicating whether the given
+ // |device| matches the |targetAttr| requirements.
+ virtual Value buildDeviceTargetMatch(Location loc, Value device,
+ IREE::HAL::DeviceTargetAttr targetAttr,
+ OpBuilder &builder) const;
+
// TODO(benvanik): pipeline registration for specialization of host code at
// various stages.
};
+// Builds an expression that returns an i1 indicating whether the given
+// |device| matches the device ID string pattern and executable target
+// requirements.
+Value buildDeviceIDAndExecutableFormatsMatch(
+ Location loc, Value device, StringRef deviceIDPattern,
+ ArrayRef<IREE::HAL::ExecutableTargetAttr> executableTargetAttrs,
+ OpBuilder &builder);
+
+// Builds a match expression that returns an i1 indicating whether the given
+// |device| supports any one of the |executableTargetAttrs|.
+Value buildExecutableFormatMatch(
+ Location loc, Value device,
+ ArrayRef<IREE::HAL::ExecutableTargetAttr> executableTargetAttrs,
+ OpBuilder &builder);
+
} // namespace mlir::iree_compiler::IREE::HAL
#endif // IREE_COMPILER_DIALECT_HAL_TARGET_TARGETDEVICE_H_
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel
index 2847a83..33e9561 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel
@@ -57,6 +57,7 @@
"//compiler/src/iree/compiler/Dialect/HAL/IR",
"//compiler/src/iree/compiler/Dialect/HAL/IR:HALDialect",
"//compiler/src/iree/compiler/Dialect/HAL/Target",
+ "//compiler/src/iree/compiler/Dialect/HAL/Target/Devices",
"//compiler/src/iree/compiler/Dialect/Stream/IR",
"//compiler/src/iree/compiler/Dialect/Stream/Transforms",
"//compiler/src/iree/compiler/Dialect/Util/Conversion",
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
index 8c05849..6cce442 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
@@ -71,6 +71,7 @@
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::HAL::IR::HALDialect
iree::compiler::Dialect::HAL::Target
+ iree::compiler::Dialect::HAL::Target::Devices
iree::compiler::Dialect::Stream::IR
iree::compiler::Dialect::Stream::Transforms
iree::compiler::Dialect::Util::Conversion
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
index baf469d..5fe1fd6 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
@@ -11,6 +11,7 @@
#include "iree/compiler/Codegen/Common/CPU/Passes.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/HAL/Target/Devices/LocalDevice.h"
#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
#include "iree/compiler/Utils/OptionUtils.h"
#include "iree/compiler/Utils/PassUtils.h"
@@ -542,6 +543,16 @@
// Force the flags to be bound.
// TODO(benvanik): remove the global flags and only rely on pipeline flags.
(void)IREE::HAL::TargetOptions::FromFlags::get();
+ // TODO(multi-device): move the local device registration somewhere more
+ // centralized. For now we piggy-back on the pass registration as that's where
+ // the local device is used.
+ (void)IREE::HAL::LocalDevice::Options::FromFlags::get();
+ IREE::HAL::TargetDeviceList deviceList;
+ deviceList.add("local", [=]() {
+ return std::make_shared<LocalDevice>(
+ IREE::HAL::LocalDevice::Options::FromFlags::get());
+ });
+ IREE::HAL::TargetRegistry::getMutableTargetRegistry().mergeFrom(deviceList);
// Generated.
registerPasses();
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/assign_target_devices.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/assign_target_devices.mlir
index a889c7d..46c0a5e 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/assign_target_devices.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/assign_target_devices.mlir
@@ -3,21 +3,21 @@
// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetBackends=vmvx,vmvx-inline})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-2
// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetBackends=vmvx,vmvx})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-EQ
-// TARGET-1: #device_target_vmvx = #hal.device.target<"vmvx"
+// TARGET-1: #device_target_local = #hal.device.target<"local"
-// TARGET-2: #device_target_vmvx = #hal.device.target<"vmvx"
+// TARGET-2: #device_target_local = #hal.device.target<"local"
// TARGET-2: #device_target_vmvx_inline = #hal.device.target<"vmvx-inline"
-// TARGET-EQ: #device_target_vmvx = #hal.device.target<"vmvx"
+// TARGET-EQ: #device_target_local = #hal.device.target<"local"
// CHECK: module
// TARGET-0: @module {
// TARGET-1: @module attributes {
-// TARGET-1-SAME: hal.device.targets = [#device_target_vmvx]
+// TARGET-1-SAME: hal.device.targets = [#device_target_local]
// TARGET-2: @module attributes {
-// TARGET-2-SAME: hal.device.targets = [#device_target_vmvx, #device_target_vmvx_inline]}
+// TARGET-2-SAME: hal.device.targets = [#device_target_local, #device_target_vmvx_inline]}
// TARGET-EQ: @module attributes {
-// TARGET-EQ-SAME: hal.device.targets = [#device_target_vmvx]}
+// TARGET-EQ-SAME: hal.device.targets = [#device_target_local]}
module @module {}
// -----
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/fixup_legacy_sync.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/fixup_legacy_sync.mlir
index 1a913ed..f47bcd2 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/fixup_legacy_sync.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/fixup_legacy_sync.mlir
@@ -30,7 +30,7 @@
// Tests for a no-op if there are no devices requiring legacy mode.
module attributes {hal.device.targets = [
- #hal.device.target<"vmvx", {}>,
+ #hal.device.target<"local", {}>,
#hal.device.target<"vulkan", {}>
]} {
// CHECK-LABEL: @legacy_mode_not_required
@@ -46,7 +46,7 @@
// Tests that any device requiring legacy_sync will trigger the pass.
module attributes {hal.device.targets = [
- #hal.device.target<"vmvx", {}>,
+ #hal.device.target<"local", {}>,
#hal.device.target<"vulkan", {legacy_sync}>
]} {
// CHECK-LABEL: @mixed_legacy_mode_required
diff --git a/tests/compiler_driver/precompile.mlir b/tests/compiler_driver/precompile.mlir
index ecc3d0b..5cdd117 100644
--- a/tests/compiler_driver/precompile.mlir
+++ b/tests/compiler_driver/precompile.mlir
@@ -7,4 +7,4 @@
}
// Just check that we have the right target and executable targets.
-// CHECK: module attributes {hal.device.targets = [#hal.device.target<"vmvx", [#hal.executable.target<"vmvx"
+// CHECK: module attributes {hal.device.targets = [#hal.device.target<"local", [#hal.executable.target<"vmvx"