Moving the LocalDevice impl out of LLVM-CPU/VMVX. It's still registered there for backwards compatibility with the iree-hal-target-backends flag.
diff --git a/compiler/plugins/target/LLVMCPU/BUILD.bazel b/compiler/plugins/target/LLVMCPU/BUILD.bazel index f8fe708..1cbc486 100644 --- a/compiler/plugins/target/LLVMCPU/BUILD.bazel +++ b/compiler/plugins/target/LLVMCPU/BUILD.bazel
@@ -37,6 +37,7 @@ "//compiler/src/iree/compiler/Codegen/LLVMCPU", "//compiler/src/iree/compiler/Codegen/Utils", "//compiler/src/iree/compiler/Dialect/HAL/Target", + "//compiler/src/iree/compiler/Dialect/HAL/Target/Devices", "//compiler/src/iree/compiler/Dialect/HAL/Utils:LLVMLinkerUtils", "//compiler/src/iree/compiler/Dialect/LinalgExt/IR", "//compiler/src/iree/compiler/Dialect/Util/IR",
diff --git a/compiler/plugins/target/LLVMCPU/CMakeLists.txt b/compiler/plugins/target/LLVMCPU/CMakeLists.txt index cd4d3fb..cf27486 100644 --- a/compiler/plugins/target/LLVMCPU/CMakeLists.txt +++ b/compiler/plugins/target/LLVMCPU/CMakeLists.txt
@@ -56,6 +56,7 @@ iree::compiler::Codegen::LLVMCPU iree::compiler::Codegen::Utils iree::compiler::Dialect::HAL::Target + iree::compiler::Dialect::HAL::Target::Devices iree::compiler::Dialect::HAL::Utils::LLVMLinkerUtils iree::compiler::Dialect::LinalgExt::IR iree::compiler::Dialect::Util::IR
diff --git a/compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp b/compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp index fbcb44c..864e6f6 100644 --- a/compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp +++ b/compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp
@@ -19,8 +19,9 @@ #include "iree/compiler/Codegen/LLVMCPU/Passes.h" #include "iree/compiler/Codegen/LLVMCPU/Utils.h" #include "iree/compiler/Codegen/Utils/Utils.h" -#include "iree/compiler/Dialect/HAL/Utils/LLVMLinkerUtils.h" +#include "iree/compiler/Dialect/HAL/Target/Devices/LocalDevice.h" #include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h" +#include "iree/compiler/Dialect/HAL/Utils/LLVMLinkerUtils.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h" #include "iree/compiler/PluginAPI/Client.h" #include "iree/compiler/Utils/ModuleUtils.h" @@ -122,51 +123,6 @@ return success(); } -class LLVMCPUTargetDevice final : public TargetDevice { -public: - LLVMCPUTargetDevice() = default; - - IREE::HAL::DeviceTargetAttr - getDefaultDeviceTarget(MLIRContext *context, - const TargetRegistry &targetRegistry) const override { - Builder b(context); - SmallVector<NamedAttribute> configItems; - - auto configAttr = b.getDictionaryAttr(configItems); - - // If we had multiple target environments we would generate one target attr - // per environment, with each setting its own environment attribute. - SmallVector<IREE::HAL::ExecutableTargetAttr> executableTargetAttrs; - targetRegistry.getTargetBackend("llvm-cpu") - ->getDefaultExecutableTargets(context, "llvm-cpu", configAttr, - executableTargetAttrs); - - return IREE::HAL::DeviceTargetAttr::get(context, - b.getStringAttr("llvm-cpu"), - configAttr, executableTargetAttrs); - } - - std::optional<IREE::HAL::DeviceTargetAttr> - getHostDeviceTarget(MLIRContext *context, - const TargetRegistry &targetRegistry) const override { - Builder b(context); - SmallVector<NamedAttribute> configItems; - - auto configAttr = b.getDictionaryAttr(configItems); - - // If we had multiple target environments we would generate one target attr - // per environment, with each setting its own environment attribute. - SmallVector<IREE::HAL::ExecutableTargetAttr> executableTargetAttrs; - targetRegistry.getTargetBackend("llvm-cpu") - ->getHostExecutableTargets(context, "llvm-cpu", configAttr, - executableTargetAttrs); - - return IREE::HAL::DeviceTargetAttr::get(context, - b.getStringAttr("llvm-cpu"), - configAttr, executableTargetAttrs); - } -}; - class LLVMCPUTargetBackend final : public TargetBackend { public: explicit LLVMCPUTargetBackend(LLVMTargetOptions options) @@ -858,9 +814,17 @@ : public PluginSession<LLVMCPUSession, LLVMCPUTargetCLOptions, PluginActivationPolicy::DefaultActivated> { void populateHALTargetDevices(IREE::HAL::TargetDeviceList &targets) { + // TODO(multi-device): move local device registration out. + // This exists here for backwards compat with the old + // iree-hal-target-backends flag that needs to look up the device by backend + // name. // #hal.device.target<"llvm-cpu", ... - targets.add("llvm-cpu", - [=]() { return std::make_shared<LLVMCPUTargetDevice>(); }); + targets.add("llvm-cpu", [=]() { + LocalDevice::Options localDeviceOptions; + localDeviceOptions.defaultTargetBackends.push_back("llvm-cpu"); + localDeviceOptions.defaultHostBackends.push_back("llvm-cpu"); + return std::make_shared<LocalDevice>(localDeviceOptions); + }); } void populateHALTargetBackends(IREE::HAL::TargetBackendList &targets) { // #hal.executable.target<"llvm-cpu", ...
diff --git a/compiler/plugins/target/VMVX/BUILD.bazel b/compiler/plugins/target/VMVX/BUILD.bazel index 6d6c7d7..31d8b87 100644 --- a/compiler/plugins/target/VMVX/BUILD.bazel +++ b/compiler/plugins/target/VMVX/BUILD.bazel
@@ -26,6 +26,7 @@ "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect", "//compiler/src/iree/compiler/Codegen/VMVX", "//compiler/src/iree/compiler/Dialect/HAL/Target", + "//compiler/src/iree/compiler/Dialect/HAL/Target/Devices", "//compiler/src/iree/compiler/Dialect/LinalgExt/IR", "//compiler/src/iree/compiler/Dialect/VM/Conversion", "//compiler/src/iree/compiler/Dialect/VM/IR",
diff --git a/compiler/plugins/target/VMVX/CMakeLists.txt b/compiler/plugins/target/VMVX/CMakeLists.txt index a3124f3..7af0a88 100644 --- a/compiler/plugins/target/VMVX/CMakeLists.txt +++ b/compiler/plugins/target/VMVX/CMakeLists.txt
@@ -30,6 +30,7 @@ iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect iree::compiler::Codegen::VMVX iree::compiler::Dialect::HAL::Target + iree::compiler::Dialect::HAL::Target::Devices iree::compiler::Dialect::LinalgExt::IR iree::compiler::Dialect::VM::Conversion iree::compiler::Dialect::VM::IR
diff --git a/compiler/plugins/target/VMVX/VMVXTarget.cpp b/compiler/plugins/target/VMVX/VMVXTarget.cpp index f7e3a8d..7620b5a 100644 --- a/compiler/plugins/target/VMVX/VMVXTarget.cpp +++ b/compiler/plugins/target/VMVX/VMVXTarget.cpp
@@ -6,6 +6,7 @@ #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h" #include "iree/compiler/Codegen/VMVX/Passes.h" +#include "iree/compiler/Dialect/HAL/Target/Devices/LocalDevice.h" #include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h" #include "iree/compiler/Dialect/VM/Conversion/ConversionTarget.h" @@ -53,38 +54,6 @@ b.getDictionaryAttr(configItems)); } -// TODO(benvanik): move to a CPU device registration outside of VMVX. -class VMVXTargetDevice final : public TargetDevice { -public: - VMVXTargetDevice() = default; - - IREE::HAL::DeviceTargetAttr - getDefaultDeviceTarget(MLIRContext *context, - const TargetRegistry &targetRegistry) const override { - Builder b(context); - SmallVector<NamedAttribute> configItems; - - auto configAttr = b.getDictionaryAttr(configItems); - - // If we had multiple target environments we would generate one target attr - // per environment, with each setting its own environment attribute. - // If we had multiple target environments we would generate one target attr - // per environment, with each setting its own environment attribute. - SmallVector<IREE::HAL::ExecutableTargetAttr> executableTargetAttrs; - targetRegistry.getTargetBackend("vmvx")->getDefaultExecutableTargets( - context, "vmvx", configAttr, executableTargetAttrs); - - return IREE::HAL::DeviceTargetAttr::get(context, b.getStringAttr("vmvx"), - configAttr, executableTargetAttrs); - } - - std::optional<IREE::HAL::DeviceTargetAttr> - getHostDeviceTarget(MLIRContext *context, - const TargetRegistry &targetRegistry) const override { - return getDefaultDeviceTarget(context, targetRegistry); - } -}; - class VMVXTargetBackend final : public TargetBackend { public: VMVXTargetBackend(const VMVXOptions &options) : options(options) {} @@ -265,20 +234,28 @@ : public PluginSession<VMVXSession, VMVXOptions, PluginActivationPolicy::DefaultActivated> { void populateHALTargetDevices(IREE::HAL::TargetDeviceList &targets) { - // TODO(benvanik): move to a CPU device registration outside of VMVX. Note - // that the inline device does need to be special. + // TODO(multi-device): move local device registration out. + // This exists here for backwards compat with the old + // iree-hal-target-backends flag that needs to look up the device by backend + // name. + // Note that the inline device does need to be special. // #hal.device.target<"vmvx", ... - targets.add("vmvx", [&]() { return std::make_shared<VMVXTargetDevice>(); }); + targets.add("vmvx", [=]() { + LocalDevice::Options localDeviceOptions; + localDeviceOptions.defaultTargetBackends.push_back("vmvx"); + localDeviceOptions.defaultHostBackends.push_back("vmvx"); + return std::make_shared<LocalDevice>(localDeviceOptions); + }); // #hal.device.target<"vmvx-inline", ... targets.add("vmvx-inline", - [&]() { return std::make_shared<VMVXInlineTargetDevice>(); }); + [=]() { return std::make_shared<VMVXInlineTargetDevice>(); }); } void populateHALTargetBackends(IREE::HAL::TargetBackendList &targets) { // #hal.executable.target<"vmvx", ... targets.add("vmvx", - [&]() { return std::make_shared<VMVXTargetBackend>(options); }); + [=]() { return std::make_shared<VMVXTargetBackend>(options); }); // #hal.executable.target<"vmvx-inline", ... - targets.add("vmvx-inline", [&]() { + targets.add("vmvx-inline", [=]() { return std::make_shared<VMVXInlineTargetBackend>(options); }); }
diff --git a/compiler/plugins/target/VMVX/test/smoketest.mlir b/compiler/plugins/target/VMVX/test/smoketest.mlir index ef56028..b640d12 100644 --- a/compiler/plugins/target/VMVX/test/smoketest.mlir +++ b/compiler/plugins/target/VMVX/test/smoketest.mlir
@@ -2,7 +2,7 @@ module attributes { hal.device.targets = [ - #hal.device.target<"vmvx", [ + #hal.device.target<"local", [ #hal.executable.target<"vmvx", "vmvx-bytecode-fb"> ]> ]
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp index 82d38ae..1ff5f24 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp
@@ -206,6 +206,13 @@ return configAttrs; } +void DeviceTargetAttr::getExecutableTargets( + SetVector<IREE::HAL::ExecutableTargetAttr> &resultAttrs) { + for (auto attr : getExecutableTargets()) { + resultAttrs.insert(attr); + } +} + // Returns a set of all configuration attributes from all device targets. // Returns nullopt if any target is missing a configuration attribute. static std::optional<SmallVector<DictionaryAttr>>
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td index 3221fe4..37ab489 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td
@@ -547,6 +547,10 @@ static std::optional<StaticRange<APInt>> lookupConfigAttrRange(Operation *op, StringRef name); + // Returns zero or more executable targets that this device supports. + void getExecutableTargets( + SetVector<IREE::HAL::ExecutableTargetAttr> &resultAttrs); + // Returns a list of all target executable configurations that may be // required for the given operation. static SmallVector<IREE::HAL::ExecutableTargetAttr, 4>
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Target/BUILD.bazel index f04a329..63207aa 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/BUILD.bazel
@@ -27,13 +27,14 @@ "TargetRegistry.h", ], deps = [ - "//compiler/src/iree/compiler/Dialect/Flow/IR", "//compiler/src/iree/compiler/Dialect/HAL/IR", "//compiler/src/iree/compiler/Dialect/Util/IR", "//compiler/src/iree/compiler/Utils", "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", ],
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Target/CMakeLists.txt index b8b5e13..5c6a84b 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/CMakeLists.txt
@@ -25,11 +25,12 @@ "TargetRegistry.cpp" DEPS LLVMSupport + MLIRArithDialect MLIRIR MLIRPass + MLIRSCFDialect MLIRSupport MLIRTransforms - iree::compiler::Dialect::Flow::IR iree::compiler::Dialect::HAL::IR iree::compiler::Dialect::Util::IR iree::compiler::Utils
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/Devices/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Target/Devices/BUILD.bazel new file mode 100644 index 0000000..bdfe673 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/Devices/BUILD.bazel
@@ -0,0 +1,33 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library") + +package( + default_visibility = ["//visibility:public"], + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +iree_compiler_cc_library( + name = "Devices", + srcs = [ + "LocalDevice.cpp", + ], + hdrs = [ + "LocalDevice.h", + ], + deps = [ + "//compiler/src/iree/compiler/Dialect/HAL/IR", + "//compiler/src/iree/compiler/Dialect/HAL/Target", + "//compiler/src/iree/compiler/Utils", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + ], +)
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/Devices/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Target/Devices/CMakeLists.txt new file mode 100644 index 0000000..3a996ae --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/Devices/CMakeLists.txt
@@ -0,0 +1,32 @@ +################################################################################ +# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# compiler/src/iree/compiler/Dialect/HAL/Target/Devices/BUILD.bazel # +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. # +################################################################################ + +iree_add_all_subdirs() + +iree_cc_library( + NAME + Devices + HDRS + "LocalDevice.h" + SRCS + "LocalDevice.cpp" + DEPS + LLVMSupport + MLIRIR + MLIRPass + MLIRSupport + MLIRTransforms + iree::compiler::Dialect::HAL::IR + iree::compiler::Dialect::HAL::Target + iree::compiler::Utils + PUBLIC +) + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/Devices/LocalDevice.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/Devices/LocalDevice.cpp new file mode 100644 index 0000000..302ad62 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/Devices/LocalDevice.cpp
@@ -0,0 +1,94 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/HAL/Target/Devices/LocalDevice.h" + +#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h" +#include "llvm/Support/CommandLine.h" + +IREE_DEFINE_COMPILER_OPTION_FLAGS( + mlir::iree_compiler::IREE::HAL::LocalDevice::Options); + +namespace mlir::iree_compiler::IREE::HAL { + +void LocalDevice::Options::bindOptions(OptionsBinder &binder) { + static llvm::cl::OptionCategory optionsCategory( + "IREE HAL local device options"); + + binder.list<std::string>( + "iree-hal-local-target-device-backends", defaultTargetBackends, + llvm::cl::desc( + "Default target backends for local device executable compilation."), + llvm::cl::ZeroOrMore, llvm::cl::cat(optionsCategory)); + binder.list<std::string>( + "iree-hal-local-host-device-backends", defaultHostBackends, + llvm::cl::desc( + "Default host backends for local device executable compilation."), + llvm::cl::ZeroOrMore, llvm::cl::cat(optionsCategory)); +} + +LocalDevice::LocalDevice(const LocalDevice::Options options) + : options(options) {} + +IREE::HAL::DeviceTargetAttr LocalDevice::getDefaultDeviceTarget( + MLIRContext *context, const TargetRegistry &targetRegistry) const { + Builder b(context); + SmallVector<NamedAttribute> configItems; + + // TODO(benvanik): flags for common queries. + + auto configAttr = b.getDictionaryAttr(configItems); + + SmallVector<IREE::HAL::ExecutableTargetAttr> executableTargetAttrs; + for (auto backendName : options.defaultTargetBackends) { + auto targetBackend = targetRegistry.getTargetBackend(backendName); + if (!targetBackend) { + llvm::errs() << "Default target backend not registered: " << backendName + << "\n"; + return {}; + } + targetBackend->getDefaultExecutableTargets(context, "local", configAttr, + executableTargetAttrs); + } + + return IREE::HAL::DeviceTargetAttr::get(context, b.getStringAttr("local"), + configAttr, executableTargetAttrs); +} + +std::optional<IREE::HAL::DeviceTargetAttr> +LocalDevice::getHostDeviceTarget(MLIRContext *context, + const TargetRegistry &targetRegistry) const { + Builder b(context); + SmallVector<NamedAttribute> configItems; + + // TODO(benvanik): flags for overrides or ask LLVM for info about the host. + + auto configAttr = b.getDictionaryAttr(configItems); + + SmallVector<IREE::HAL::ExecutableTargetAttr> executableTargetAttrs; + for (auto backendName : options.defaultHostBackends) { + auto targetBackend = targetRegistry.getTargetBackend(backendName); + if (!targetBackend) { + llvm::errs() << "Default host backend not registered: " << backendName + << "\n"; + return std::nullopt; + } + targetBackend->getHostExecutableTargets(context, "local", configAttr, + executableTargetAttrs); + } + + return IREE::HAL::DeviceTargetAttr::get(context, b.getStringAttr("local"), + configAttr, executableTargetAttrs); +} + +Value LocalDevice::buildDeviceTargetMatch( + Location loc, Value device, IREE::HAL::DeviceTargetAttr targetAttr, + OpBuilder &builder) const { + return buildDeviceIDAndExecutableFormatsMatch( + loc, device, "local*", targetAttr.getExecutableTargets(), builder); +} + +} // namespace mlir::iree_compiler::IREE::HAL
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/Devices/LocalDevice.h b/compiler/src/iree/compiler/Dialect/HAL/Target/Devices/LocalDevice.h new file mode 100644 index 0000000..8e18374 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/Devices/LocalDevice.h
@@ -0,0 +1,50 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_DIALECT_HAL_TARGET_DEVICE_LOCALDEVICE_H_ +#define IREE_COMPILER_DIALECT_HAL_TARGET_DEVICE_LOCALDEVICE_H_ + +#include <string> +#include <vector> + +#include "iree/compiler/Dialect/HAL/Target/TargetDevice.h" +#include "iree/compiler/Utils/OptionUtils.h" + +namespace mlir::iree_compiler::IREE::HAL { + +class LocalDevice final : public TargetDevice { +public: + struct Options { + // A list of default target backends for local devices. + std::vector<std::string> defaultTargetBackends; + // A list of default host backends for local devices. + std::vector<std::string> defaultHostBackends; + + void bindOptions(OptionsBinder &binder); + using FromFlags = OptionsFromFlags<Options>; + }; + + explicit LocalDevice(const Options options); + + IREE::HAL::DeviceTargetAttr + getDefaultDeviceTarget(MLIRContext *context, + const TargetRegistry &targetRegistry) const override; + + std::optional<IREE::HAL::DeviceTargetAttr> + getHostDeviceTarget(MLIRContext *context, + const TargetRegistry &targetRegistry) const override; + + Value buildDeviceTargetMatch(Location loc, Value device, + IREE::HAL::DeviceTargetAttr targetAttr, + OpBuilder &builder) const override; + +private: + const Options options; +}; + +} // namespace mlir::iree_compiler::IREE::HAL + +#endif // IREE_COMPILER_DIALECT_HAL_TARGET_DEVICE_LOCALDEVICE_H_
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetDevice.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetDevice.cpp index bace88a..1695c50 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetDevice.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetDevice.cpp
@@ -6,8 +6,66 @@ #include "iree/compiler/Dialect/HAL/Target/TargetDevice.h" +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" + namespace mlir::iree_compiler::IREE::HAL { -// TODO(benvanik): add device options. +// virtual +Value TargetDevice::buildDeviceTargetMatch( + Location loc, Value device, IREE::HAL::DeviceTargetAttr targetAttr, + OpBuilder &builder) const { + return buildDeviceIDAndExecutableFormatsMatch( + loc, device, targetAttr.getDeviceID(), targetAttr.getExecutableTargets(), + builder); +} + +Value buildDeviceIDAndExecutableFormatsMatch( + Location loc, Value device, StringRef deviceIDPattern, + ArrayRef<IREE::HAL::ExecutableTargetAttr> executableTargetAttrs, + OpBuilder &builder) { + // Match first on the device ID, as that's the top-level filter. + Value idMatch = IREE::HAL::DeviceQueryOp::createI1( + loc, device, "hal.device.id", deviceIDPattern, builder); + + // If there are executable formats defined we should check at least one of + // them is supported. + if (executableTargetAttrs.empty()) { + return idMatch; // just device ID + } else { + auto ifOp = builder.create<scf::IfOp>(loc, builder.getI1Type(), idMatch, + true, true); + auto thenBuilder = ifOp.getThenBodyBuilder(); + Value anyFormatMatch = buildExecutableFormatMatch( + loc, device, executableTargetAttrs, thenBuilder); + thenBuilder.create<scf::YieldOp>(loc, anyFormatMatch); + auto elseBuilder = ifOp.getElseBodyBuilder(); + Value falseValue = elseBuilder.create<arith::ConstantIntOp>(loc, 0, 1); + elseBuilder.create<scf::YieldOp>(loc, falseValue); + return ifOp.getResult(0); + } +} + +Value buildExecutableFormatMatch( + Location loc, Value device, + ArrayRef<IREE::HAL::ExecutableTargetAttr> executableTargetAttrs, + OpBuilder &builder) { + if (executableTargetAttrs.empty()) + return builder.create<arith::ConstantIntOp>(loc, 1, 1); + Value anyFormatMatch; + for (auto executableTargetAttr : executableTargetAttrs) { + Value formatMatch = IREE::HAL::DeviceQueryOp::createI1( + loc, device, "hal.executable.format", + executableTargetAttr.getFormat().getValue(), builder); + if (!anyFormatMatch) { + anyFormatMatch = formatMatch; + } else { + anyFormatMatch = + builder.create<arith::OrIOp>(loc, anyFormatMatch, formatMatch); + } + } + return anyFormatMatch; +} } // namespace mlir::iree_compiler::IREE::HAL
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetDevice.h b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetDevice.h index 357e51f..2ce53f4 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetDevice.h +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetDevice.h
@@ -38,10 +38,31 @@ return {}; } + // Builds an expression that returns an i1 indicating whether the given + // |device| matches the |targetAttr| requirements. + virtual Value buildDeviceTargetMatch(Location loc, Value device, + IREE::HAL::DeviceTargetAttr targetAttr, + OpBuilder &builder) const; + // TODO(benvanik): pipeline registration for specialization of host code at // various stages. }; +// Builds an expression that returns an i1 indicating whether the given +// |device| matches the device ID string pattern and executable target +// requirements. +Value buildDeviceIDAndExecutableFormatsMatch( + Location loc, Value device, StringRef deviceIDPattern, + ArrayRef<IREE::HAL::ExecutableTargetAttr> executableTargetAttrs, + OpBuilder &builder); + +// Builds a match expression that returns an i1 indicating whether the given +// |device| supports any one of the |executableTargetAttrs|. +Value buildExecutableFormatMatch( + Location loc, Value device, + ArrayRef<IREE::HAL::ExecutableTargetAttr> executableTargetAttrs, + OpBuilder &builder); + } // namespace mlir::iree_compiler::IREE::HAL #endif // IREE_COMPILER_DIALECT_HAL_TARGET_TARGETDEVICE_H_
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel index 2847a83..33e9561 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel
@@ -57,6 +57,7 @@ "//compiler/src/iree/compiler/Dialect/HAL/IR", "//compiler/src/iree/compiler/Dialect/HAL/IR:HALDialect", "//compiler/src/iree/compiler/Dialect/HAL/Target", + "//compiler/src/iree/compiler/Dialect/HAL/Target/Devices", "//compiler/src/iree/compiler/Dialect/Stream/IR", "//compiler/src/iree/compiler/Dialect/Stream/Transforms", "//compiler/src/iree/compiler/Dialect/Util/Conversion",
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt index 8c05849..6cce442 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
@@ -71,6 +71,7 @@ iree::compiler::Dialect::HAL::IR iree::compiler::Dialect::HAL::IR::HALDialect iree::compiler::Dialect::HAL::Target + iree::compiler::Dialect::HAL::Target::Devices iree::compiler::Dialect::Stream::IR iree::compiler::Dialect::Stream::Transforms iree::compiler::Dialect::Util::Conversion
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp index baf469d..5fe1fd6 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
@@ -11,6 +11,7 @@ #include "iree/compiler/Codegen/Common/CPU/Passes.h" #include "iree/compiler/Dialect/HAL/IR/HALDialect.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" +#include "iree/compiler/Dialect/HAL/Target/Devices/LocalDevice.h" #include "iree/compiler/Dialect/Util/Transforms/Passes.h" #include "iree/compiler/Utils/OptionUtils.h" #include "iree/compiler/Utils/PassUtils.h" @@ -542,6 +543,16 @@ // Force the flags to be bound. // TODO(benvanik): remove the global flags and only rely on pipeline flags. (void)IREE::HAL::TargetOptions::FromFlags::get(); + // TODO(multi-device): move the local device registration somewhere more + // centralized. For now we piggy-back on the pass registration as that's where + // the local device is used. + (void)IREE::HAL::LocalDevice::Options::FromFlags::get(); + IREE::HAL::TargetDeviceList deviceList; + deviceList.add("local", [=]() { + return std::make_shared<LocalDevice>( + IREE::HAL::LocalDevice::Options::FromFlags::get()); + }); + IREE::HAL::TargetRegistry::getMutableTargetRegistry().mergeFrom(deviceList); // Generated. registerPasses();
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/assign_target_devices.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/assign_target_devices.mlir index a889c7d..46c0a5e 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/assign_target_devices.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/assign_target_devices.mlir
@@ -3,21 +3,21 @@ // RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetBackends=vmvx,vmvx-inline})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-2 // RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targetBackends=vmvx,vmvx})' %s | FileCheck %s --check-prefix=CHECK --check-prefix=TARGET-EQ -// TARGET-1: #device_target_vmvx = #hal.device.target<"vmvx" +// TARGET-1: #device_target_local = #hal.device.target<"local" -// TARGET-2: #device_target_vmvx = #hal.device.target<"vmvx" +// TARGET-2: #device_target_local = #hal.device.target<"local" // TARGET-2: #device_target_vmvx_inline = #hal.device.target<"vmvx-inline" -// TARGET-EQ: #device_target_vmvx = #hal.device.target<"vmvx" +// TARGET-EQ: #device_target_local = #hal.device.target<"local" // CHECK: module // TARGET-0: @module { // TARGET-1: @module attributes { -// TARGET-1-SAME: hal.device.targets = [#device_target_vmvx] +// TARGET-1-SAME: hal.device.targets = [#device_target_local] // TARGET-2: @module attributes { -// TARGET-2-SAME: hal.device.targets = [#device_target_vmvx, #device_target_vmvx_inline]} +// TARGET-2-SAME: hal.device.targets = [#device_target_local, #device_target_vmvx_inline]} // TARGET-EQ: @module attributes { -// TARGET-EQ-SAME: hal.device.targets = [#device_target_vmvx]} +// TARGET-EQ-SAME: hal.device.targets = [#device_target_local]} module @module {} // -----
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/fixup_legacy_sync.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/fixup_legacy_sync.mlir index 1a913ed..f47bcd2 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/fixup_legacy_sync.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/fixup_legacy_sync.mlir
@@ -30,7 +30,7 @@ // Tests for a no-op if there are no devices requiring legacy mode. module attributes {hal.device.targets = [ - #hal.device.target<"vmvx", {}>, + #hal.device.target<"local", {}>, #hal.device.target<"vulkan", {}> ]} { // CHECK-LABEL: @legacy_mode_not_required @@ -46,7 +46,7 @@ // Tests that any device requiring legacy_sync will trigger the pass. module attributes {hal.device.targets = [ - #hal.device.target<"vmvx", {}>, + #hal.device.target<"local", {}>, #hal.device.target<"vulkan", {legacy_sync}> ]} { // CHECK-LABEL: @mixed_legacy_mode_required
diff --git a/tests/compiler_driver/precompile.mlir b/tests/compiler_driver/precompile.mlir index ecc3d0b..5cdd117 100644 --- a/tests/compiler_driver/precompile.mlir +++ b/tests/compiler_driver/precompile.mlir
@@ -7,4 +7,4 @@ } // Just check that we have the right target and executable targets. -// CHECK: module attributes {hal.device.targets = [#hal.device.target<"vmvx", [#hal.executable.target<"vmvx" +// CHECK: module attributes {hal.device.targets = [#hal.device.target<"local", [#hal.executable.target<"vmvx"