blob: 161b6886ecf6b2d285c54a5c521a151c2f4ebbd4 [file] [log] [blame]
// Copyright 2019 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/HAL/Target/TargetBackend.h"
#include <algorithm>
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Dialect.h"
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace HAL {
// Registers the HAL executable target command-line options with |binder|,
// grouping them under a shared "IREE HAL executable target options" category.
void TargetOptions::bindOptions(OptionsBinder &binder) {
  // This function is called while registering TranslateExecutablesPass, and
  // pass registration itself happens during static initialization. The
  // category is therefore a function-local static: it is constructed on first
  // use here instead of depending on cross-TU static initialization order.
  static llvm::cl::OptionCategory halTargetOptionsCategory(
      "IREE HAL executable target options");
  const auto inHalTargetCategory = llvm::cl::cat(halTargetOptionsCategory);

  // List of backend names to compile executables for.
  binder.list<std::string>(
      "iree-hal-target-backends", targets,
      llvm::cl::desc("Target backends for executable compilation."),
      llvm::cl::ZeroOrMore, inHalTargetCategory);

  // Optional destination for per-executable source listings.
  binder.opt<std::string>(
      "iree-hal-dump-executable-sources-to", sourceListingPath,
      llvm::cl::desc("Path to write individual hal.executable input "
                     "source listings into (- for stdout)."),
      inHalTargetCategory);
}
// Gives |op| a fresh symbol name of the form "<originalName>_<N>", using the
// smallest N >= 0 whose candidate is neither mapped to an op in
// |targetSymbolMap| nor (when |optionalSymbolTable| is non-null) found in
// |optionalSymbolTable|. Every use of |op| located within |moduleOp| is
// rewritten to the new name.
//
// Note: |moduleOp| is only searched for uses. Name availability is decided
// solely by |targetSymbolMap| and |optionalSymbolTable|.
static void renameWithDisambiguatedName(
    Operation *op, Operation *moduleOp,
    DenseMap<StringRef, Operation *> &targetSymbolMap,
    SymbolTable *optionalSymbolTable) {
  StringRef baseName = SymbolTable::getSymbolName(op).getValue();

  auto isNameTaken = [&](StringRef candidate) -> bool {
    if (targetSymbolMap.lookup(candidate)) return true;
    return optionalSymbolTable && optionalSymbolTable->lookup(candidate);
  };

  // Probe "<base>_0", "<base>_1", ... until a free name is found.
  std::string freshName;
  for (int suffix = 0;; ++suffix) {
    freshName = llvm::formatv("{0}_{1}", baseName, suffix).str();
    if (!isNameTaken(freshName)) break;
  }

  // Rewrite uses while |op| still carries its original name, then rename |op|.
  SymbolTableCollection symbolTables;
  SymbolUserMap userMap(symbolTables, moduleOp);
  userMap.replaceAllUsesWith(
      op, mlir::StringAttr::get(op->getContext(), freshName));
  SymbolTable::setSymbolName(op, freshName);
}
// TODO(benvanik): replace with iree/compiler/Utils/ModuleUtils.h version.
// Only difference is one has the symbol map that we don't even need.
// Destructively merges |sourceModuleOp| into |targetModuleOp|. Every
// non-terminator op in the source block is moved (not cloned) into the target
// block, ahead of the target's terminator if it has one. |sourceModuleOp| is
// erased once all of its ops have been processed. |targetSymbolMap| is updated
// with the current names of all symbols moved into or renamed in
// |targetModuleOp|.
//
// Symbol-name conflicts against |targetSymbolMap| are resolved as follows:
//   * Private source symbol that is equivalent (ignoring locations) to the
//     existing target op: the source op is not moved. It stays behind in
//     |sourceModuleOp| and is erased with it.
//   * Private source symbol that is not equivalent: the source op is renamed.
//   * Public/nested source symbol vs. private target symbol: the target op is
//     renamed and the source op keeps the original name.
//   * Public/nested source symbol vs. public/nested target symbol: an error is
//     emitted and failure is returned. Ops processed before the conflict have
//     already been moved, so the IR is left partially merged.
static LogicalResult mergeModuleInto(
    Operation *sourceModuleOp, Operation *targetModuleOp,
    DenseMap<StringRef, Operation *> &targetSymbolMap) {
  auto &sourceBlock = sourceModuleOp->getRegion(0).front();
  auto &targetBlock = targetModuleOp->getRegion(0).front();
  SymbolTable sourceSymbolTable(sourceModuleOp);

  // Snapshot the source ops first: each one is moved out of |sourceBlock| as
  // we go, which would otherwise disturb the iteration.
  auto allOps = llvm::to_vector<8>(
      llvm::map_range(sourceBlock, [&](Operation &op) { return &op; }));
  for (auto &op : allOps) {
    if (op->hasTrait<OpTrait::IsTerminator>()) continue;
    if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) {
      auto symbolName = symbolOp.getName();
      // Resolve symbol name conflicts. Use lookup() rather than operator[] so
      // probing for a conflict does not insert a null entry into the map.
      if (auto targetOp = targetSymbolMap.lookup(symbolName)) {
        if (symbolOp.getVisibility() == SymbolTable::Visibility::Private) {
          // Private symbols can be safely folded into duplicates or renamed.
          if (OperationEquivalence::isEquivalentTo(
                  targetOp, op, OperationEquivalence::exactValueMatch,
                  OperationEquivalence::exactValueMatch,
                  OperationEquivalence::Flags::IgnoreLocations)) {
            // Optimization: skip over duplicate private symbols.
            // We could let CSE do this later, but we may as well check here.
            continue;
          }
          // Preserve the op but give it a unique name.
          renameWithDisambiguatedName(op, sourceModuleOp, targetSymbolMap,
                                      &sourceSymbolTable);
        } else {
          // The source symbol has 'nested' or 'public' visibility.
          if (SymbolTable::getSymbolVisibility(targetOp) !=
              SymbolTable::Visibility::Private) {
            // Oops! Both symbols are public and we can't safely rename either.
            // If you hit this with ops that you think are safe to rename, mark
            // them private.
            //
            // Note: we could also skip linking between executables with
            // conflicting symbol names. We think such conflicts will be better
            // fixed in other ways, so we'll emit an error until we find a case
            // where that isn't true.
            return op->emitError()
                   << "multiple public symbols with the name: " << symbolName;
          }
          // Keep the original name for our new op, rename the target op.
          renameWithDisambiguatedName(targetOp, targetModuleOp,
                                      targetSymbolMap,
                                      /*optionalSymbolTable=*/nullptr);
          // Track the renamed target op under its new name. Otherwise later
          // conflict checks (and later disambiguations) would not see that
          // the name is taken, and duplicate symbols could be created.
          targetSymbolMap[SymbolTable::getSymbolName(targetOp).getValue()] =
              targetOp;
        }
      }
      // Record the (possibly renamed) source op under its final name. For a
      // renamed target op this also replaces its stale entry at |symbolName|.
      targetSymbolMap[SymbolTable::getSymbolName(op).getValue()] = op;
    }
    if (!targetBlock.empty() &&
        targetBlock.back().hasTrait<OpTrait::IsTerminator>()) {
      op->moveBefore(&targetBlock.back());
    } else {
      op->moveBefore(&targetBlock, targetBlock.end());
    }
  }

  // All wanted ops have been moved out; delete the (now residual) source op.
  sourceModuleOp->erase();
  return success();
}
// Rewrites entry point references according to |replacements|, which maps an
// entry point's old SymbolRefAttr to its new SymbolRefAttr. Visits every
// CommandBufferDispatchSymbolOp nested (at any depth) inside a function-like op
// that is a direct child of |moduleOp|. References with no entry in
// |replacements| are left unchanged.
static void replaceEntryPointUses(
    mlir::ModuleOp moduleOp,
    const DenseMap<Attribute, Attribute> &replacements) {
  auto retargetDispatch =
      [&](IREE::HAL::CommandBufferDispatchSymbolOp dispatchOp) {
        auto match = replacements.find(dispatchOp.entry_point());
        if (match == replacements.end()) return;
        dispatchOp.entry_pointAttr(match->second.cast<SymbolRefAttr>());
      };
  for (auto funcLikeOp : moduleOp.getOps<FunctionOpInterface>()) {
    funcLikeOp.walk(retargetDispatch);
  }
}
// Links every variant of |sourceExecutableOps| whose target backend equals
// name() into |linkedTargetOp|, a variant of |linkedExecutableOp|:
//   * Each source entry point is re-created at the beginning of
//     |linkedTargetOp|'s body. It keeps its symbol name and layout and gets a
//     sequential ordinal starting at 0.
//   * Each matching source variant's inner module (as returned by
//     |getInnerModuleFn|) is merged into the linked inner module via
//     mergeModuleInto. The source variant is then erased.
//   * Source executables left without any variants are erased.
//   * Dispatch references within |moduleOp| are rewritten from
//     @source_executable::@variant::@entry to the linked symbols.
// If |linkedTargetOp| ends up with no entry points, it and
// |linkedExecutableOp| are erased.
//
// Returns failure if a merge fails. The IR may already be partially modified
// at that point.
LogicalResult TargetBackend::linkExecutablesInto(
    mlir::ModuleOp moduleOp,
    ArrayRef<IREE::HAL::ExecutableOp> sourceExecutableOps,
    IREE::HAL::ExecutableOp linkedExecutableOp,
    IREE::HAL::ExecutableVariantOp linkedTargetOp,
    std::function<Operation *(mlir::ModuleOp moduleOp)> getInnerModuleFn,
    OpBuilder &builder) {
  MLIRContext *context = builder.getContext();
  int ordinalCounter = 0;
  DenseMap<StringRef, Operation *> linkedSymbols;
  DenseMap<Attribute, Attribute> refRemapping;
  auto entryPointBuilder = OpBuilder::atBlockBegin(linkedTargetOp.getBody());
  auto linkedInnerModule = getInnerModuleFn(linkedTargetOp.getInnerModule());

  for (auto executableOp : sourceExecutableOps) {
    // Snapshot the variants: matching ones are erased while we iterate.
    auto candidateVariants = llvm::to_vector<4>(
        executableOp.getOps<IREE::HAL::ExecutableVariantOp>());
    for (auto candidateVariant : candidateVariants) {
      // Variants targeting other backends are left where they are.
      bool isForThisBackend =
          candidateVariant.target().getBackend().getValue() == name();
      if (!isForThisBackend) continue;

      for (auto sourceEntryPoint :
           candidateVariant.getOps<IREE::HAL::ExecutableEntryPointOp>()) {
        auto linkedEntryPoint =
            entryPointBuilder.create<IREE::HAL::ExecutableEntryPointOp>(
                sourceEntryPoint.getLoc(), sourceEntryPoint.sym_nameAttr(),
                builder.getIndexAttr(ordinalCounter++),
                sourceEntryPoint.layout(), ArrayAttr{}, IntegerAttr{});
        // Remember old -> new fully-qualified reference; dispatch sites are
        // rewritten in a single pass once everything has been linked.
        auto fromRef = SymbolRefAttr::get(
            context, executableOp.getName(),
            {SymbolRefAttr::get(candidateVariant),
             SymbolRefAttr::get(sourceEntryPoint)});
        auto toRef = SymbolRefAttr::get(
            context, linkedExecutableOp.getName(),
            {SymbolRefAttr::get(linkedTargetOp),
             SymbolRefAttr::get(linkedEntryPoint)});
        refRemapping[fromRef] = toRef;
      }

      // Fold this variant's inner module into the linked inner module.
      Operation *variantInnerModule =
          getInnerModuleFn(candidateVariant.getInnerModule());
      LogicalResult mergeResult = mergeModuleInto(
          variantInnerModule, linkedInnerModule, linkedSymbols);
      if (failed(mergeResult)) return failure();
      candidateVariant.erase();
    }

    bool hasNoVariantsLeft =
        executableOp.getOps<IREE::HAL::ExecutableVariantOp>().empty();
    if (hasNoVariantsLeft) executableOp.erase();
  }

  // Point dispatches at the relocated @executable::@target::@entry symbols.
  replaceEntryPointUses(moduleOp, refRemapping);

  // Discard the linked containers if no entry point was linked into them.
  bool hasLinkedEntryPoints =
      !linkedTargetOp.getOps<IREE::HAL::ExecutableEntryPointOp>().empty();
  if (!hasLinkedEntryPoints) {
    linkedTargetOp.erase();
    linkedExecutableOp.erase();
  }
  return success();
}
} // namespace HAL
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir