Implement executable linking in LLVM (IR and AOT). (#3551)
Fixes https://github.com/google/iree/issues/3137, follows https://github.com/google/iree/pull/3283.
This also refactors the LLVM targets to share more code, since they have the same behavior until serialization.
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/iree/compiler/Dialect/HAL/IR/HALOps.cpp
index fc9c64a..8c337a0 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOps.cpp
+++ b/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -1500,6 +1500,17 @@
})));
}
+bool InterfaceOp::isEquivalentTo(InterfaceOp other) {
+ auto bindings = llvm::to_vector<4>(getBlock().getOps<InterfaceBindingOp>());
+ auto otherBindings =
+ llvm::to_vector<4>(other.getBlock().getOps<InterfaceBindingOp>());
+ return bindings.size() == otherBindings.size() &&
+ llvm::all_of(llvm::zip(bindings, otherBindings), [](auto bindings) {
+ return OperationEquivalence::isEquivalentTo(std::get<0>(bindings),
+ std::get<1>(bindings));
+ });
+}
+
//===----------------------------------------------------------------------===//
// hal.interface.binding
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.td b/iree/compiler/Dialect/HAL/IR/HALOps.td
index 9ad3f54..80ad645 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -2124,6 +2124,10 @@
// TODO(benvanik): replace with a nested typed attr that works.
// Array of HAL_DescriptorSetLayoutBindingArrayAttr.
ArrayAttr getExecutableSetLayoutsAttr();
+
+ // Returns true if the all bindings in the interface match exactly those
+ // in |other| (including order).
+ bool isEquivalentTo(IREE::HAL::InterfaceOp other);
}];
}
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/BUILD b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/BUILD
index 91e3255..594a0be 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/BUILD
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/BUILD
@@ -39,8 +39,8 @@
],
deps = [
":LLVMAOTTargetLinker",
- "//iree/compiler/Conversion/LinalgToLLVM",
"//iree/compiler/Dialect/HAL/Target",
+ "//iree/compiler/Dialect/HAL/Target/LLVM:LLVMBaseTarget",
"//iree/compiler/Dialect/HAL/Target/LLVM:LLVMIRPasses",
"//iree/compiler/Dialect/HAL/Target/LLVM:LLVMTargetOptions",
"//iree/schemas:dylib_executable_def_cc_fbs",
@@ -52,12 +52,7 @@
"@llvm-project//llvm:Support",
"@llvm-project//llvm:X86AsmParser",
"@llvm-project//llvm:X86CodeGen",
- "@llvm-project//mlir:Affine",
- "@llvm-project//mlir:LLVMDialect",
- "@llvm-project//mlir:LinalgOps",
- "@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:TargetLLVMIR",
- "@llvm-project//mlir:VectorOps",
],
)
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/CMakeLists.txt
index 82fb0e4..ff46a62 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/CMakeLists.txt
@@ -35,14 +35,9 @@
LLVMSupport
LLVMX86AsmParser
LLVMX86CodeGen
- MLIRAffine
- MLIRLLVMIR
- MLIRLinalg
- MLIRSCF
MLIRTargetLLVMIR
- MLIRVector
- iree::compiler::Conversion::LinalgToLLVM
iree::compiler::Dialect::HAL::Target
+ iree::compiler::Dialect::HAL::Target::LLVM::LLVMBaseTarget
iree::compiler::Dialect::HAL::Target::LLVM::LLVMIRPasses
iree::compiler::Dialect::HAL::Target::LLVM::LLVMTargetOptions
iree::schemas::dylib_executable_def_cc_fbs
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LLVMAOTTarget.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LLVMAOTTarget.cpp
index 6fa091b..199daf3 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LLVMAOTTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LLVMAOTTarget.cpp
@@ -16,20 +16,14 @@
#include <cstdlib>
-#include "iree/compiler/Conversion/LinalgToLLVM/Passes.h"
#include "iree/compiler/Dialect/HAL/Target/LLVM/AOT/LLVMAOTTargetLinker.h"
+#include "iree/compiler/Dialect/HAL/Target/LLVM/LLVMBaseTarget.h"
#include "iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRPasses.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
#include "iree/schemas/dylib_executable_def_generated.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
-#include "llvm/Support/Mutex.h"
#include "llvm/Support/TargetSelect.h"
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
-#include "mlir/Dialect/SCF/SCF.h"
-#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Target/LLVMIR.h"
namespace mlir {
@@ -37,32 +31,17 @@
namespace IREE {
namespace HAL {
-class LLVMAOTTargetBackend final : public TargetBackend {
+class LLVMAOTTargetBackend final : public LLVMBaseTargetBackend {
public:
- LLVMAOTTargetBackend(LLVMTargetOptions options)
- : options_(std::move(options)) {}
+ explicit LLVMAOTTargetBackend(LLVMTargetOptions options)
+ : LLVMBaseTargetBackend(options) {}
// NOTE: we could vary these based on the options, such as by arch/etc.
std::string name() const override { return "llvm_aot"; }
std::string filter_pattern() const override { return "dylib*"; }
- void getDependentDialects(DialectRegistry& registry) const override {
- // clang-format off
- registry.insert<AffineDialect,
- linalg::LinalgDialect,
- LLVM::LLVMDialect,
- scf::SCFDialect,
- vector::VectorDialect>();
- // clang-format on
- }
-
- void buildTranslationPassPipeline(ExecutableTargetOp targetOp,
- OpPassManager& passManager) override {
- buildLLVMTransformPassPipeline(passManager);
- }
-
LogicalResult serializeExecutable(IREE::HAL::ExecutableTargetOp targetOp,
- OpBuilder& executableBuilder) override {
+ OpBuilder &executableBuilder) override {
// Perform the translation in a separate context to avoid any
// multi-threading issues.
llvm::LLVMContext context;
@@ -109,7 +88,7 @@
}
std::string sharedLibData;
- const char* linkerToolPath = std::getenv("IREE_LLVMAOT_LINKER_PATH");
+ const char *linkerToolPath = std::getenv("IREE_LLVMAOT_LINKER_PATH");
if (linkerToolPath != nullptr) {
auto sharedLibDataStatus = linkLLVMAOTObjects(linkerToolPath, objData);
if (!sharedLibDataStatus.ok()) {
@@ -147,18 +126,6 @@
return success();
}
-
- std::array<Value, 3> calculateDispatchWorkgroupCount(
- Location loc, IREE::HAL::ExecutableOp executableOp,
- IREE::HAL::ExecutableEntryPointOp entryPointOp, Value workload,
- OpBuilder& builder) override {
- // For now we are not tiling and just dispatch everything as 1,1,1.
- auto constantOne = builder.createOrFold<mlir::ConstantIndexOp>(loc, 1);
- return {constantOne, constantOne, constantOne};
- }
-
- private:
- LLVMTargetOptions options_;
};
void registerLLVMAOTTargetBackends(
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/BUILD b/iree/compiler/Dialect/HAL/Target/LLVM/BUILD
index 7c0a843..b644afe 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/BUILD
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/BUILD
@@ -30,6 +30,28 @@
)
cc_library(
+ name = "LLVMBaseTarget",
+ srcs = [
+ "LLVMBaseTarget.cpp",
+ ],
+ hdrs = [
+ "LLVMBaseTarget.h",
+ ],
+ deps = [
+ ":LLVMIRPasses",
+ ":LLVMTargetOptions",
+ "//iree/compiler/Conversion/LinalgToLLVM",
+ "//iree/compiler/Dialect/HAL/Target",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:Affine",
+ "@llvm-project//mlir:LLVMDialect",
+ "@llvm-project//mlir:LinalgOps",
+ "@llvm-project//mlir:SCFDialect",
+ "@llvm-project//mlir:VectorOps",
+ ],
+)
+
+cc_library(
name = "LLVMIRPasses",
srcs = [
"LLVMIRPasses.cpp",
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt
index 43b2e50..ed01729 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt
@@ -20,6 +20,27 @@
iree_cc_library(
NAME
+ LLVMBaseTarget
+ HDRS
+ "LLVMBaseTarget.h"
+ SRCS
+ "LLVMBaseTarget.cpp"
+ DEPS
+ ::LLVMIRPasses
+ ::LLVMTargetOptions
+ LLVMSupport
+ MLIRAffine
+ MLIRLLVMIR
+ MLIRLinalg
+ MLIRSCF
+ MLIRVector
+ iree::compiler::Conversion::LinalgToLLVM
+ iree::compiler::Dialect::HAL::Target
+ PUBLIC
+)
+
+iree_cc_library(
+ NAME
LLVMIRPasses
HDRS
"LLVMIRPasses.h"
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/IR/BUILD b/iree/compiler/Dialect/HAL/Target/LLVM/IR/BUILD
index 3728d92..2194e66 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/IR/BUILD
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/IR/BUILD
@@ -37,18 +37,13 @@
"LLVMIRTarget.h",
],
deps = [
- "//iree/compiler/Conversion/LinalgToLLVM",
"//iree/compiler/Dialect/HAL/Target",
+ "//iree/compiler/Dialect/HAL/Target/LLVM:LLVMBaseTarget",
"//iree/compiler/Dialect/HAL/Target/LLVM:LLVMIRPasses",
"//iree/compiler/Dialect/HAL/Target/LLVM:LLVMTargetOptions",
"//iree/schemas:llvmir_executable_def_cc_fbs",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:Support",
- "@llvm-project//mlir:Affine",
- "@llvm-project//mlir:LLVMDialect",
- "@llvm-project//mlir:LinalgOps",
- "@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:TargetLLVMIR",
- "@llvm-project//mlir:VectorOps",
],
)
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/IR/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/LLVM/IR/CMakeLists.txt
index a9d447f..345fe5a 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/IR/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/IR/CMakeLists.txt
@@ -28,14 +28,9 @@
DEPS
LLVMCore
LLVMSupport
- MLIRAffine
- MLIRLLVMIR
- MLIRLinalg
- MLIRSCF
MLIRTargetLLVMIR
- MLIRVector
- iree::compiler::Conversion::LinalgToLLVM
iree::compiler::Dialect::HAL::Target
+ iree::compiler::Dialect::HAL::Target::LLVM::LLVMBaseTarget
iree::compiler::Dialect::HAL::Target::LLVM::LLVMIRPasses
iree::compiler::Dialect::HAL::Target::LLVM::LLVMTargetOptions
iree::schemas::llvmir_executable_def_cc_fbs
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/IR/LLVMIRTarget.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/IR/LLVMIRTarget.cpp
index fac9612..3a03471 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/IR/LLVMIRTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/IR/LLVMIRTarget.cpp
@@ -14,19 +14,13 @@
#include "iree/compiler/Dialect/HAL/Target/LLVM/IR/LLVMIRTarget.h"
-#include "iree/compiler/Conversion/LinalgToLLVM/Passes.h"
+#include "iree/compiler/Dialect/HAL/Target/LLVM/LLVMBaseTarget.h"
#include "iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRPasses.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
#include "iree/schemas/llvmir_executable_def_generated.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
-#include "llvm/Support/Mutex.h"
#include "llvm/Support/TargetSelect.h"
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
-#include "mlir/Dialect/SCF/SCF.h"
-#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Target/LLVMIR.h"
namespace mlir {
@@ -34,32 +28,17 @@
namespace IREE {
namespace HAL {
-class LLVMIRTargetBackend final : public TargetBackend {
+class LLVMIRTargetBackend final : public LLVMBaseTargetBackend {
public:
- LLVMIRTargetBackend(LLVMTargetOptions options)
- : options_(std::move(options)) {}
+ explicit LLVMIRTargetBackend(LLVMTargetOptions options)
+ : LLVMBaseTargetBackend(options) {}
// NOTE: we could vary these based on the options, such as by arch/etc.
std::string name() const override { return "llvm_ir"; }
std::string filter_pattern() const override { return "llvm-ir*"; }
- void getDependentDialects(DialectRegistry& registry) const override {
- // clang-format off
- registry.insert<AffineDialect,
- linalg::LinalgDialect,
- LLVM::LLVMDialect,
- scf::SCFDialect,
- vector::VectorDialect>();
- // clang-format on
- }
-
- void buildTranslationPassPipeline(ExecutableTargetOp targetOp,
- OpPassManager& passManager) override {
- buildLLVMTransformPassPipeline(passManager);
- }
-
LogicalResult serializeExecutable(IREE::HAL::ExecutableTargetOp targetOp,
- OpBuilder& executableBuilder) override {
+ OpBuilder &executableBuilder) override {
// Perform the translation to LLVM in a separate context to avoid
// multi-threading issues.
llvm::LLVMContext context;
@@ -121,18 +100,6 @@
return success();
}
-
- std::array<Value, 3> calculateDispatchWorkgroupCount(
- Location loc, IREE::HAL::ExecutableOp executableOp,
- IREE::HAL::ExecutableEntryPointOp entryPointOp, Value workload,
- OpBuilder& builder) override {
- // For now we are not tiling and just dispatch everything as 1,1,1.
- auto constantOne = builder.createOrFold<mlir::ConstantIndexOp>(loc, 1);
- return {constantOne, constantOne, constantOne};
- }
-
- private:
- LLVMTargetOptions options_;
};
void registerLLVMIRTargetBackends(
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMBaseTarget.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMBaseTarget.cpp
new file mode 100644
index 0000000..a0c82c2
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMBaseTarget.cpp
@@ -0,0 +1,198 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Dialect/HAL/Target/LLVM/LLVMBaseTarget.h"
+
+#include "iree/compiler/Conversion/LinalgToLLVM/Passes.h"
+#include "iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRPasses.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace HAL {
+
+namespace {
+
+// Destructively merges |sourceModuleOp| into |targetModuleOp|.
+// |targetSymbolTable| is updated with the new symbols.
+void mergeModuleInto(mlir::ModuleOp sourceModuleOp,
+ mlir::ModuleOp targetModuleOp,
+ DenseMap<StringRef, Operation *> &targetSymbolMap) {
+ auto allOps = llvm::to_vector<8>(llvm::map_range(
+ *sourceModuleOp.getBody(), [&](Operation &op) { return &op; }));
+ for (auto &op : allOps) {
+ if (op->isKnownTerminator()) continue;
+ if (auto symbolInterface = dyn_cast<SymbolOpInterface>(op)) {
+ if (targetSymbolMap.count(symbolInterface.getName())) {
+ // TODO(scotttodd): compare ops to ensure we aren't copying different
+ // things with the same name.
+ continue;
+ }
+ targetSymbolMap[symbolInterface.getName()] = op;
+ }
+ op->moveBefore(&targetModuleOp.getBody()->back());
+ }
+
+ // Now that we're done cloning its ops, delete the original target op.
+ sourceModuleOp.erase();
+}
+
+// Replaces each usage of an entry point with its original symbol name with a
+// new symbol name.
+void replaceEntryPointUses(mlir::ModuleOp moduleOp,
+ const DenseMap<Attribute, Attribute> &replacements) {
+ for (auto funcOp : moduleOp.getOps<mlir::FuncOp>()) {
+ funcOp.walk([&](IREE::HAL::CommandBufferDispatchSymbolOp dispatchOp) {
+ auto it = replacements.find(dispatchOp.entry_point());
+ if (it != replacements.end()) {
+ dispatchOp.entry_pointAttr(it->second.cast<SymbolRefAttr>());
+ }
+ });
+ }
+}
+
+} // namespace
+
+LLVMBaseTargetBackend::LLVMBaseTargetBackend(LLVMTargetOptions options)
+ : options_(std::move(options)) {}
+
+void LLVMBaseTargetBackend::getDependentDialects(
+ DialectRegistry ®istry) const {
+ // clang-format off
+ registry.insert<AffineDialect,
+ linalg::LinalgDialect,
+ LLVM::LLVMDialect,
+ scf::SCFDialect,
+ vector::VectorDialect>();
+ // clang-format on
+}
+
+void LLVMBaseTargetBackend::buildTranslationPassPipeline(
+ ExecutableTargetOp targetOp, OpPassManager &passManager) {
+ buildLLVMTransformPassPipeline(passManager);
+}
+
+LogicalResult LLVMBaseTargetBackend::linkExecutables(mlir::ModuleOp moduleOp) {
+ OpBuilder builder = OpBuilder::atBlockBegin(moduleOp.getBody());
+ auto executableOps =
+ llvm::to_vector<8>(moduleOp.getOps<IREE::HAL::ExecutableOp>());
+
+ // Create our new "linked" hal.executable.
+ std::string linkedExecutableName = llvm::formatv("linked_{0}", name());
+ auto linkedExecutableOp = builder.create<IREE::HAL::ExecutableOp>(
+ moduleOp.getLoc(), linkedExecutableName);
+ SymbolTable::setSymbolVisibility(linkedExecutableOp,
+ SymbolTable::Visibility::Private);
+ // Add our hal.executable.target with an empty module.
+ builder.setInsertionPointToStart(linkedExecutableOp.getBody());
+ auto linkedTargetOp = builder.create<IREE::HAL::ExecutableTargetOp>(
+ moduleOp.getLoc(), name(), filter_pattern());
+ builder.setInsertionPoint(&linkedTargetOp.getBlock().back());
+ auto linkedModuleOp = builder.create<ModuleOp>(moduleOp.getLoc());
+
+ llvm::SmallVector<IREE::HAL::InterfaceOp, 4> interfaceOps;
+ int nextEntryPointOrdinal = 0;
+ DenseMap<StringRef, Operation *> symbolMap;
+ DenseMap<Attribute, Attribute> entryPointRefReplacements;
+ auto linkedExecutableBuilder =
+ OpBuilder::atBlockBegin(linkedExecutableOp.getBody());
+ auto linkedTargetBuilder = OpBuilder::atBlockBegin(linkedTargetOp.getBody());
+ for (auto executableOp : executableOps) {
+ auto targetOps = llvm::to_vector<4>(
+ executableOp.getOps<IREE::HAL::ExecutableTargetOp>());
+ for (auto targetOp : targetOps) {
+ // Only process targets matching our pattern.
+ if (!matchPattern(targetOp.target_backend_filter(), filter_pattern())) {
+ continue;
+ }
+
+ IREE::HAL::InterfaceOp interfaceOpForExecutable;
+ for (auto interfaceOp : interfaceOps) {
+ if (interfaceOp.isEquivalentTo(executableOp.getFirstInterfaceOp())) {
+ interfaceOpForExecutable = interfaceOp;
+ break;
+ }
+ }
+ if (!interfaceOpForExecutable) {
+ interfaceOpForExecutable = dyn_cast<IREE::HAL::InterfaceOp>(
+ linkedExecutableBuilder.clone(*executableOp.getFirstInterfaceOp()));
+ interfaceOpForExecutable.setName(
+ llvm::formatv("legacy_io_{0}", interfaceOps.size()).str());
+ interfaceOps.push_back(interfaceOpForExecutable);
+ }
+
+ // Clone entry point ops and queue remapping ordinals and updating
+ // symbol refs.
+ for (auto entryPointOp :
+ targetOp.getOps<IREE::HAL::ExecutableEntryPointOp>()) {
+ auto newEntryPointOp =
+ linkedTargetBuilder.create<IREE::HAL::ExecutableEntryPointOp>(
+ entryPointOp.getLoc(), entryPointOp.sym_nameAttr(),
+ builder.getI32IntegerAttr(nextEntryPointOrdinal++),
+ builder.getSymbolRefAttr(interfaceOpForExecutable.getName()),
+ entryPointOp.signatureAttr());
+
+ // Add to replacement table for fixing up dispatch calls referencing
+ // this entry point.
+ auto oldSymbolRefAttr = builder.getSymbolRefAttr(
+ executableOp.getName(), {builder.getSymbolRefAttr(targetOp),
+ builder.getSymbolRefAttr(entryPointOp)});
+ auto newSymbolRefAttr = builder.getSymbolRefAttr(
+ linkedExecutableOp.getName(),
+ {builder.getSymbolRefAttr(linkedTargetOp),
+ builder.getSymbolRefAttr(newEntryPointOp)});
+ entryPointRefReplacements[oldSymbolRefAttr] = newSymbolRefAttr;
+ }
+
+ mergeModuleInto(targetOp.getInnerModule(), linkedModuleOp, symbolMap);
+
+ targetOp.erase();
+ }
+
+ if (executableOp.getOps<IREE::HAL::ExecutableTargetOp>().empty()) {
+ executableOp.erase();
+ }
+ }
+
+ // Update references to @executable::@target::@entry symbols.
+ replaceEntryPointUses(moduleOp, entryPointRefReplacements);
+
+ // Remove if we didn't add anything.
+ if (linkedTargetOp.getOps<IREE::HAL::ExecutableEntryPointOp>().empty()) {
+ linkedTargetOp.erase();
+ linkedExecutableOp.erase();
+ }
+
+ return success();
+}
+
+std::array<Value, 3> LLVMBaseTargetBackend::calculateDispatchWorkgroupCount(
+ Location loc, IREE::HAL::ExecutableOp executableOp,
+ IREE::HAL::ExecutableEntryPointOp entryPointOp, Value workload,
+ OpBuilder &builder) {
+ // For now we are not tiling and just dispatch everything as 1,1,1.
+ auto constantOne = builder.createOrFold<mlir::ConstantIndexOp>(loc, 1);
+ return {constantOne, constantOne, constantOne};
+}
+
+} // namespace HAL
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMBaseTarget.h b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMBaseTarget.h
new file mode 100644
index 0000000..52d0dbd
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMBaseTarget.h
@@ -0,0 +1,52 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_DIALECT_HAL_TARGET_LLVM_LLVMBASETARGET_H_
+#define IREE_COMPILER_DIALECT_HAL_TARGET_LLVM_LLVMBASETARGET_H_
+
+#include "iree/compiler/Dialect/HAL/Target/LLVM/LLVMTargetOptions.h"
+#include "iree/compiler/Dialect/HAL/Target/TargetBackend.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace HAL {
+
+// Base target for LLVM ahead-of-time (AOT) and just-in-time (JIT) backends.
+class LLVMBaseTargetBackend : public TargetBackend {
+ public:
+ explicit LLVMBaseTargetBackend(LLVMTargetOptions options);
+
+ void getDependentDialects(DialectRegistry ®istry) const override;
+
+ void buildTranslationPassPipeline(ExecutableTargetOp targetOp,
+ OpPassManager &passManager) override;
+
+ LogicalResult linkExecutables(mlir::ModuleOp moduleOp) override;
+
+ std::array<Value, 3> calculateDispatchWorkgroupCount(
+ Location loc, IREE::HAL::ExecutableOp executableOp,
+ IREE::HAL::ExecutableEntryPointOp entryPointOp, Value workload,
+ OpBuilder &builder) override;
+
+ protected:
+ LLVMTargetOptions options_;
+};
+
+} // namespace HAL
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_DIALECT_HAL_TARGET_LLVM_LLVMBASETARGET_H_
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRPasses.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRPasses.cpp
index e91441d..0725b17 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRPasses.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRPasses.cpp
@@ -32,7 +32,7 @@
namespace HAL {
std::unique_ptr<llvm::TargetMachine> createTargetMachine(
- const LLVMTargetOptions& targetOptions) {
+ const LLVMTargetOptions &targetOptions) {
std::string errorMessage;
auto target = llvm::TargetRegistry::lookupTarget(targetOptions.targetTriple,
errorMessage);
@@ -44,9 +44,9 @@
return machine;
}
-LogicalResult runLLVMIRPasses(const LLVMTargetOptions& options,
- llvm::TargetMachine* machine,
- llvm::Module* module) {
+LogicalResult runLLVMIRPasses(const LLVMTargetOptions &options,
+ llvm::TargetMachine *machine,
+ llvm::Module *module) {
llvm::LoopAnalysisManager loopAnalysisManager;
llvm::FunctionAnalysisManager functionAnalysisManager;
llvm::CGSCCAnalysisManager cGSCCAnalysisManager;
@@ -78,8 +78,8 @@
return success();
}
-LogicalResult runEmitObjFilePasses(llvm::TargetMachine* machine,
- llvm::Module* module, std::string* objData) {
+LogicalResult runEmitObjFilePasses(llvm::TargetMachine *machine,
+ llvm::Module *module, std::string *objData) {
llvm::SmallVector<char, 0> stream_buffer;
{
// TODO(ataei): Use non legacy pass mamanger for this.
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRPasses.h b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRPasses.h
index 37ee1ba..af109c7 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRPasses.h
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRPasses.h
@@ -29,16 +29,16 @@
// Creates target machine form target options.
std::unique_ptr<llvm::TargetMachine> createTargetMachine(
- const LLVMTargetOptions& options);
+ const LLVMTargetOptions &options);
// Creates and runs LLVMIR optimization passes defined in LLVMTargetOptions.
-LogicalResult runLLVMIRPasses(const LLVMTargetOptions& options,
- llvm::TargetMachine* machine,
- llvm::Module* module);
+LogicalResult runLLVMIRPasses(const LLVMTargetOptions &options,
+ llvm::TargetMachine *machine,
+ llvm::Module *module);
// Emits compiled module obj for the target machine.
-LogicalResult runEmitObjFilePasses(llvm::TargetMachine* machine,
- llvm::Module* module, std::string* objData);
+LogicalResult runEmitObjFilePasses(llvm::TargetMachine *machine,
+ llvm::Module *module, std::string *objData);
} // namespace HAL
} // namespace IREE
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/test/binaryop_test.mlir b/iree/compiler/Dialect/HAL/Target/LLVM/test/binaryop_test.mlir
index 388d7ec..0a66d94 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/test/binaryop_test.mlir
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/test/binaryop_test.mlir
@@ -11,7 +11,7 @@
}
}
-// CHECK-LABEL: hal.executable @simpleMath_ex_dispatch_0
+// CHECK-LABEL: hal.executable @linked_llvm_ir
// CHECK-DAG: hal.executable.binary attributes {
// CHECK-SAME: data = dense
// CHECK-SAME: format = 1280071245 : i32} {
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/test/matmul_op.mlir b/iree/compiler/Dialect/HAL/Target/LLVM/test/matmul_op.mlir
index 271be8b..6744a70 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/test/matmul_op.mlir
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/test/matmul_op.mlir
@@ -11,7 +11,7 @@
}
}
-// CHECK-LABEL: hal.executable @simpleMath_ex_dispatch_0
+// CHECK-LABEL: hal.executable @linked_llvm_ir
// CHECK-DAG: hal.executable.binary attributes {
// CHECK-SAME: data = dense
// CHECK-SAME: format = 1280071245 : i32} {
diff --git a/iree/compiler/Dialect/HAL/Target/TargetBackend.h b/iree/compiler/Dialect/HAL/Target/TargetBackend.h
index 4f03b80..e7b6f6e 100644
--- a/iree/compiler/Dialect/HAL/Target/TargetBackend.h
+++ b/iree/compiler/Dialect/HAL/Target/TargetBackend.h
@@ -264,7 +264,7 @@
// hal.interface.binding @arg0, set=0, binding=0, ...
// hal.interface.binding @arg1, set=0, binding=1, ...
// }
- // hal.executable.target "target-backend" {
+ // hal.executable.target @target, filter="target-backend" {
// hal.executable.entry_point @main attributes {
// interface = @main_io,
// ordinal = 0 : i32,
@@ -277,7 +277,7 @@
// As output:
// hal.executable @some_executable {
// hal.interface @main_io ...
- // hal.executable.target "target-backend" {
+ // hal.executable.target @target, filter="target-backend" {
// hal.executable.entry_point @main ...
// module { spv.module { ... } }
// }
@@ -289,10 +289,44 @@
virtual void buildTranslationPassPipeline(
IREE::HAL::ExecutableTargetOp targetOp, OpPassManager &passManager) = 0;
- // TODO(benvanik): define linkage rules.
- // Major thing to figure out here is how to rewrite the executable references.
- // We may want to move executable selection into the hal.device.switch of the
- // dispatches so that they can be more easily replaced per-target.
+ // Links compatible executables within the provided |moduleOp| together into
+ // zero or more new linked executables. Implementations should move
+ // executable contents (including interfaces, entry points, and functions)
+ // into new executables and update any relevant references as they do so.
+ //
+ // Which executables to link together and how many new executables to produce
+ // are left to implementations to determine. For example, an implementation
+ // may choose to link all executables (even with different interfaces) into
+ // a single combined executable, or it could choose to limit the number linked
+ // together in order to shard binary size across multiple executables.
+ //
+ // The input |moduleOp| may contain executables containing multiple targets,
+ // so implementations should check target backend filters against their own
+ // `filter_pattern()` prior to modifying them.
+ //
+ // Sample output structure:
+ // hal.executable @linked_executable {
+ // hal.interface @legacy_io_0 { ... }
+ // hal.interface @legacy_io_1 { ... }
+ // hal.executable.target @target, filter="target-backend" {
+ // hal.executable.entry_point @main_dispatch_0 attributes { ... }
+ // hal.executable.entry_point @main_dispatch_1 attributes { ... }
+ // hal.executable.entry_point @main_dispatch_2 attributes { ... }
+ // module {
+ // func @main_0(...) { ... }
+ // func @main_1(...) { ... }
+ // func @main_2(...) { ... }
+ // }
+ // }
+ // }
+ // // Other targets within executables are not modified
+ // hal.executable @main_dispatch_0 {
+ // hal.interface @legacy_io { ... }
+ // hal.executable.target @other, filter="other" {
+ // hal.executable.entry_point @main_dispatch_0 attributes { ... }
+ // module { ... }
+ // }
+ // }
virtual LogicalResult linkExecutables(mlir::ModuleOp moduleOp) {
return success();
}
diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp b/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp
index ab3c3a4..dbd63b6 100644
--- a/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp
@@ -41,23 +41,42 @@
namespace {
-bool areInterfacesEquivalent(IREE::HAL::InterfaceOp lhs,
- IREE::HAL::InterfaceOp rhs) {
- auto lhsBindings = lhs.getBlock().getOps<IREE::HAL::InterfaceBindingOp>();
- auto rhsBindings = rhs.getBlock().getOps<IREE::HAL::InterfaceBindingOp>();
- auto lhsIt = lhsBindings.begin(), lhsEnd = lhsBindings.end();
- auto rhsIt = rhsBindings.begin(), rhsEnd = rhsBindings.end();
- for (; lhsIt != lhsEnd && rhsIt != rhsEnd; ++lhsIt, ++rhsIt) {
- // Assume bindings are in order, check equivalence of each pairing.
- if (!OperationEquivalence::isEquivalentTo(*lhsIt, *rhsIt)) return false;
+// Destructively merges |sourceModuleOp| into |targetModuleOp|.
+// |targetSymbolTable| is updated with the new symbols.
+void mergeModuleInto(IREE::VM::ModuleOp sourceModuleOp,
+ IREE::VM::ModuleOp targetModuleOp,
+ DenseMap<StringRef, Operation *> &targetSymbolMap) {
+ auto allOps = llvm::to_vector<8>(llvm::map_range(
+ sourceModuleOp.getBlock(), [&](Operation &op) { return &op; }));
+ for (auto &op : allOps) {
+ if (op->isKnownTerminator()) continue;
+ if (auto symbolInterface = dyn_cast<SymbolOpInterface>(op)) {
+ if (targetSymbolMap.count(symbolInterface.getName())) {
+ // TODO(scotttodd): compare ops to ensure we aren't copying different
+ // things with the same name.
+ continue;
+ }
+ targetSymbolMap[symbolInterface.getName()] = op;
+ }
+ op->moveBefore(&targetModuleOp.getBlock().back());
}
- if (lhsIt != lhsEnd || rhsIt != rhsEnd) {
- // Not finished iterating through one, number of interface bindings differ.
- return false;
- }
+ // Now that we're done cloning its ops, delete the original target op.
+ sourceModuleOp.erase();
+}
- return true;
+// Replaces each usage of an entry point with its original symbol name with a
+// new symbol name.
+void replaceEntryPointUses(mlir::ModuleOp moduleOp,
+ const DenseMap<Attribute, Attribute> &replacements) {
+ for (auto funcOp : moduleOp.getOps<mlir::FuncOp>()) {
+ funcOp.walk([&](IREE::HAL::CommandBufferDispatchSymbolOp dispatchOp) {
+ auto it = replacements.find(dispatchOp.entry_point());
+ if (it != replacements.end()) {
+ dispatchOp.entry_pointAttr(it->second.cast<SymbolRefAttr>());
+ }
+ });
+ }
}
} // namespace
@@ -92,38 +111,6 @@
}
LogicalResult linkExecutables(mlir::ModuleOp moduleOp) override {
- // --- Linking overview ---
- //
- // We start with a `module` containing multiple `hal.executable`s, each with
- // potentially multiple `hal.executable.target`s. We want to move all
- // compatible VMLA functions into a new "linked" executable, de-duping
- // symbols, and updating references as we go.
- //
- // Sample IR after:
- // hal.executable @linked_vmla {
- // hal.interface @legacy_io_0 { ... }
- // hal.interface @legacy_io_1 { ... }
- // hal.executable.target @vmla, filter="vmla" {
- // hal.executable.entry_point @main_dispatch_0 attributes { ... }
- // hal.executable.entry_point @main_dispatch_1 attributes { ... }
- // hal.executable.entry_point @main_dispatch_2 attributes { ... }
- // module {
- // vm.module @module {
- // vm.func @main_0(...) { ... }
- // vm.func @main_1(...) { ... }
- // vm.func @main_2(...) { ... }
- // }
- // }
- // }
- // }
- // hal.executable @main_dispatch_0 {
- // hal.interface @legacy_io { ... }
- // hal.executable.target @other, filter="other" {
- // hal.executable.entry_point @main_dispatch_0 attributes { ... }
- // module { ... }
- // }
- // }
-
OpBuilder builder = OpBuilder::atBlockBegin(moduleOp.getBody());
auto executableOps =
llvm::to_vector<8>(moduleOp.getOps<IREE::HAL::ExecutableOp>());
@@ -163,8 +150,7 @@
IREE::HAL::InterfaceOp interfaceOpForExecutable;
for (auto interfaceOp : interfaceOps) {
- if (areInterfacesEquivalent(interfaceOp,
- executableOp.getFirstInterfaceOp())) {
+ if (interfaceOp.isEquivalentTo(executableOp.getFirstInterfaceOp())) {
interfaceOpForExecutable = interfaceOp;
break;
}
@@ -230,45 +216,6 @@
return success();
}
- // Destructively merges |sourceModuleOp| into |targetModuleOp|.
- // |targetSymbolTable| is updated with the new symbols.
- void mergeModuleInto(IREE::VM::ModuleOp sourceModuleOp,
- IREE::VM::ModuleOp targetModuleOp,
- DenseMap<StringRef, Operation *> &targetSymbolMap) {
- auto allOps = llvm::to_vector<8>(llvm::map_range(
- sourceModuleOp.getBlock(), [&](Operation &op) { return &op; }));
- for (auto &op : allOps) {
- if (op->isKnownTerminator()) continue;
- if (auto symbolInterface = dyn_cast<SymbolOpInterface>(op)) {
- if (targetSymbolMap.count(symbolInterface.getName())) {
- // TODO(scotttodd): compare ops to ensure we aren't copying different
- // things with the same name.
- continue;
- }
- targetSymbolMap[symbolInterface.getName()] = op;
- }
- op->moveBefore(&targetModuleOp.getBlock().back());
- }
-
- // Now that we're done cloning its ops, delete the original target op.
- sourceModuleOp.erase();
- }
-
- // Replaces each usage of an entry point with its original symbol name with a
- // new symbol name.
- void replaceEntryPointUses(
- mlir::ModuleOp moduleOp,
- const DenseMap<Attribute, Attribute> &replacements) {
- for (auto funcOp : moduleOp.getOps<mlir::FuncOp>()) {
- funcOp.walk([&](IREE::HAL::CommandBufferDispatchSymbolOp dispatchOp) {
- auto it = replacements.find(dispatchOp.entry_point());
- if (it != replacements.end()) {
- dispatchOp.entry_pointAttr(it->second.cast<SymbolRefAttr>());
- }
- });
- }
- }
-
LogicalResult serializeExecutable(IREE::HAL::ExecutableTargetOp targetOp,
OpBuilder &executableBuilder) override {
// Serialize the VM module to bytes.