// blob: 069ea10bfee1967d3ce9a9b2a84df3d341bde73d [file] [log] [blame]
// Copyright 2019 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/HAL/Target/TargetBackend.h"
#include <algorithm>
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Dialect.h"
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace HAL {
// Returns a TargetOptions whose |targets| field is copied from the values
// given to the `iree-hal-target-backends` command line flag.
TargetOptions getTargetOptionsFromFlags() {
  static llvm::cl::OptionCategory halTargetOptionsCategory(
      "IREE HAL executable target options");

  // This function is called while registering TranslateExecutablesPass, and
  // the pass registry is itself statically initialized. Defining the flag as
  // a function-local static ensures it is constructed before it is first
  // read. The flag object is heap-allocated and is not freed here.
  static llvm::cl::list<std::string> *backendsFlag =
      new llvm::cl::list<std::string>{
          "iree-hal-target-backends",
          llvm::cl::desc("Target backends for executable compilation"),
          llvm::cl::ZeroOrMore, llvm::cl::cat(halTargetOptionsCategory)};

  TargetOptions options;
  options.targets = *backendsFlag;
  return options;
}
// static
// Returns true if the whole of |value| matches the glob-style |pattern|.
//
// Wildcards understood in |pattern|:
//   '*' matches any run of characters, including an empty run.
//   '?' matches exactly one character (never zero).
// Every other pattern character must equal the corresponding |value|
// character.
//
// The match is an iterative scan with two cursors. Whenever a '*' is seen its
// position is remembered; on a later mismatch the scan rewinds to just after
// that '*' and lets it absorb one more character of |value|. This avoids
// recursion and keeps the cost bounded by |value|.size() * |pattern|.size().
bool TargetBackend::matchPattern(StringRef value, StringRef pattern) {
  size_t valueIndex = 0;
  size_t patternIndex = 0;
  // Position of the most recent '*' in |pattern| (npos if none seen yet) and
  // the |value| position that '*' currently extends up to.
  size_t lastStarPatternIndex = StringRef::npos;
  size_t lastStarValueIndex = 0;

  while (valueIndex < value.size()) {
    const bool patternHasChar = patternIndex < pattern.size();
    // Wildcards are tested before literal equality so that a '*' occurring in
    // |value| is never mistaken for a literal match of a pattern '*'.
    if (patternHasChar && pattern[patternIndex] == '*') {
      lastStarPatternIndex = patternIndex++;
      lastStarValueIndex = valueIndex;
    } else if (patternHasChar && (pattern[patternIndex] == '?' ||
                                  pattern[patternIndex] == value[valueIndex])) {
      ++valueIndex;
      ++patternIndex;
    } else if (lastStarPatternIndex != StringRef::npos) {
      // Mismatch: backtrack and let the last '*' consume one more character.
      patternIndex = lastStarPatternIndex + 1;
      valueIndex = ++lastStarValueIndex;
    } else {
      return false;
    }
  }

  // |value| is exhausted; only trailing '*' may remain in |pattern| as each
  // can match the empty string. A remaining '?' or literal is a mismatch.
  while (patternIndex < pattern.size() && pattern[patternIndex] == '*') {
    ++patternIndex;
  }
  return patternIndex == pattern.size();
}
// static
// Builds a BufferConstraintsAttr populated with fixed default limits: 1GiB for
// the maximum allocation size and buffer range, and 16 for the minimum offset
// and range alignments. All values are encoded as index attributes.
BufferConstraintsAttr TargetBackend::makeDefaultBufferConstraints(
    MLIRContext *context) {
  // Picked to represent what we kind of want on CPU today.
  constexpr uint64_t kOneGiB = 1024ull * 1024ull * 1024ull;
  constexpr uint64_t kMinAlignment = 16ull;

  Builder attrBuilder(context);
  auto maxAllocationSizeAttr = attrBuilder.getIndexAttr(kOneGiB);
  auto minBufferOffsetAlignmentAttr = attrBuilder.getIndexAttr(kMinAlignment);
  auto maxBufferRangeAttr = attrBuilder.getIndexAttr(kOneGiB);
  auto minBufferRangeAlignmentAttr = attrBuilder.getIndexAttr(kMinAlignment);
  return BufferConstraintsAttr::get(
      maxAllocationSizeAttr, minBufferOffsetAlignmentAttr, maxBufferRangeAttr,
      minBufferRangeAlignmentAttr);
}
// Returns the buffer constraints reported by this backend. This
// implementation forwards to makeDefaultBufferConstraints unchanged.
BufferConstraintsAttr TargetBackend::queryBufferConstraints(
    MLIRContext *context) {
  BufferConstraintsAttr defaultConstraints =
      makeDefaultBufferConstraints(context);
  return defaultConstraints;
}
// Gives the symbol |op| a new name of the form `<original>_<N>` and rewrites
// all of its uses found under |moduleOp| to refer to the new name.
//
// N starts at 0 and is incremented until the candidate name has no non-null
// entry in |targetSymbolMap| and, when |optionalSymbolTable| is provided, is
// not found in that symbol table either. |targetSymbolMap| is only read here;
// it is not updated with the new name.
static void renameWithDisambiguatedName(
    Operation *op, Operation *moduleOp,
    DenseMap<StringRef, Operation *> &targetSymbolMap,
    SymbolTable *optionalSymbolTable) {
  StringRef baseName = SymbolTable::getSymbolName(op);

  // True if |candidate| is already claimed by either lookup structure.
  auto isNameTaken = [&](StringRef candidate) -> bool {
    if (targetSymbolMap.lookup(candidate)) return true;
    return optionalSymbolTable && optionalSymbolTable->lookup(candidate);
  };

  // Try increasing numeric suffixes until an unused name is found.
  std::string uniqueName;
  for (int suffix = 0;; ++suffix) {
    uniqueName = llvm::formatv("{0}_{1}", baseName, suffix).str();
    if (!isNameTaken(uniqueName)) break;
  }

  // Redirect existing users first, then rename the symbol itself.
  SymbolTableCollection symbolTables;
  SymbolUserMap userMap(symbolTables, moduleOp);
  userMap.replaceAllUsesWith(op, uniqueName);
  SymbolTable::setSymbolName(op, uniqueName);
}
// Declares a variant for this backend inside |executableOp|: creates an
// ExecutableVariantOp built from name() and filter_pattern(), and then a
// ModuleOp nested inside that variant. Both new ops take their location from
// |sourceOp|.
void TargetBackend::declareVariantOps(IREE::Flow::ExecutableOp sourceOp,
                                      IREE::HAL::ExecutableOp executableOp) {
  auto loc = sourceOp.getLoc();

  // Each builder is anchored on the last op currently in the parent block.
  Operation *lastExecutableOp = &executableOp.getBlock().back();
  OpBuilder executableBuilder(lastExecutableOp);
  auto variantOp = executableBuilder.create<IREE::HAL::ExecutableVariantOp>(
      loc, name(), filter_pattern());

  Operation *lastVariantOp = &variantOp.getBlock().back();
  OpBuilder variantBuilder(lastVariantOp);
  variantBuilder.create<ModuleOp>(loc);
}
// Destructively merges |sourceModuleOp| into |targetModuleOp|: every
// non-terminator op in the first block of |sourceModuleOp| is moved into the
// first block of |targetModuleOp| (before its terminator, if it has one) and
// |sourceModuleOp| is then erased.
//
// |targetSymbolMap| maps symbol names to the ops already placed in the target
// and is updated with every symbol that is moved or renamed here.
//
// If a private symbol in |sourceModuleOp| conflicts with another symbol
// (public or private) tracked in |targetSymbolMap|, it is dropped when the two
// ops are equivalent and renamed otherwise. If a non-private source symbol
// conflicts with a private target symbol, the target symbol is renamed.
//
// Fails (after emitting an error on the offending op) if a non-private symbol
// in |sourceModuleOp| conflicts with a non-private symbol tracked in
// |targetSymbolMap|. Ops processed before the failure have already been moved.
static LogicalResult mergeModuleInto(
    Operation *sourceModuleOp, Operation *targetModuleOp,
    DenseMap<StringRef, Operation *> &targetSymbolMap) {
  auto &sourceBlock = sourceModuleOp->getRegion(0).front();
  auto &targetBlock = targetModuleOp->getRegion(0).front();
  SymbolTable sourceSymbolTable(sourceModuleOp);

  // Snapshot the op list up front: ops are moved out of |sourceBlock| while
  // iterating.
  auto allOps = llvm::to_vector<8>(
      llvm::map_range(sourceBlock, [&](Operation &op) { return &op; }));

  for (auto &op : allOps) {
    if (op->hasTrait<OpTrait::IsTerminator>()) continue;
    if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) {
      auto symbolName = symbolOp.getName();

      // Resolve symbol name conflicts.
      if (auto targetOp = targetSymbolMap[symbolName]) {
        if (symbolOp.getVisibility() == SymbolTable::Visibility::Private) {
          // Private symbols can be safely folded into duplicates or renamed.
          if (OperationEquivalence::isEquivalentTo(targetOp, op)) {
            // Optimization: skip over duplicate private symbols.
            // We could let CSE do this later, but we may as well check here.
            continue;
          } else {
            // Preserve the op but give it a unique name.
            renameWithDisambiguatedName(op, sourceModuleOp, targetSymbolMap,
                                        &sourceSymbolTable);
          }
        } else {
          // The source symbol has 'nested' or 'public' visibility.
          if (SymbolTable::getSymbolVisibility(targetOp) !=
              SymbolTable::Visibility::Private) {
            // Oops! Both symbols are public and we can't safely rename either.
            // If you hit this with ops that you think are safe to rename, mark
            // them private.
            //
            // Note: we could also skip linking between executables with
            // conflicting symbol names. We think such conflicts will be better
            // fixed in other ways, so we'll emit an error until we find a case
            // where that isn't true.
            return op->emitError()
                   << "multiple public symbols with the name: " << symbolName;
          } else {
            // Keep the original name for our new op, rename the target op.
            renameWithDisambiguatedName(targetOp, targetModuleOp,
                                        targetSymbolMap,
                                        /*optionalSymbolTable=*/nullptr);
            // Track the renamed target op under its new name. Without this a
            // later rename could select the same disambiguated name and the
            // target module would end up with two symbols of that name.
            targetSymbolMap[SymbolTable::getSymbolName(targetOp)] = targetOp;
          }
        }
      }
      // Record the (possibly renamed) source symbol as living in the target.
      targetSymbolMap[SymbolTable::getSymbolName(op)] = op;
    }

    // Keep any terminator of the target block as its last op.
    if (!targetBlock.empty() &&
        targetBlock.back().hasTrait<OpTrait::IsTerminator>()) {
      op->moveBefore(&targetBlock.back());
    } else {
      op->moveBefore(&targetBlock, targetBlock.end());
    }
  }

  // Now that we're done moving its ops, delete the original source module op.
  sourceModuleOp->erase();
  return success();
}
// Rewrites entry point references on dispatch ops. Every
// CommandBufferDispatchSymbolOp nested in a FuncOp directly contained in
// |moduleOp| whose entry_point attribute is a key in |replacements| has that
// attribute replaced with the mapped value (cast to a SymbolRefAttr).
// Dispatch ops whose entry point is not in the table are left untouched.
static void replaceEntryPointUses(
    mlir::ModuleOp moduleOp,
    const DenseMap<Attribute, Attribute> &replacements) {
  for (auto funcOp : moduleOp.getOps<mlir::FuncOp>()) {
    funcOp.walk([&](IREE::HAL::CommandBufferDispatchSymbolOp dispatchOp) {
      auto replacementIt = replacements.find(dispatchOp.entry_point());
      if (replacementIt == replacements.end()) return;
      auto newEntryPoint = replacementIt->second.cast<SymbolRefAttr>();
      dispatchOp.entry_pointAttr(newEntryPoint);
    });
  }
}
// Links the variants of |sourceExecutableOps| that match this backend's
// filter_pattern() into |linkedTargetOp| (a variant of |linkedExecutableOp|).
//
// For each matching source variant:
//   - each entry point gets a new ExecutableEntryPointOp in |linkedTargetOp|
//     with a freshly assigned ordinal (0, 1, 2, ... in visiting order);
//   - the interface the entry point refers to is cloned into
//     |linkedExecutableOp| as `io_<N>`, unless an equivalent interface was
//     already cloned, in which case that one is reused;
//   - the variant's inner module (as selected by |getInnerModuleFn|) is merged
//     into the linked variant's inner module;
//   - the source variant is erased.
// Source executables left without any variant are erased as well.
//
// Afterwards dispatch ops in |moduleOp| that referenced an old
// @executable::@variant::@entry symbol are updated to the linked one. If no
// entry point ended up in |linkedTargetOp|, both |linkedTargetOp| and
// |linkedExecutableOp| are erased.
//
// |builder| is only used to construct attributes.
//
// Returns failure if merging an inner module fails; the IR may already have
// been partially modified at that point.
LogicalResult TargetBackend::linkExecutablesInto(
    mlir::ModuleOp moduleOp,
    ArrayRef<IREE::HAL::ExecutableOp> sourceExecutableOps,
    IREE::HAL::ExecutableOp linkedExecutableOp,
    IREE::HAL::ExecutableVariantOp linkedTargetOp,
    std::function<Operation *(mlir::ModuleOp moduleOp)> getInnerModuleFn,
    OpBuilder &builder) {
  llvm::SmallVector<IREE::HAL::InterfaceOp, 4> linkedInterfaceOps;
  int nextEntryPointOrdinal = 0;
  DenseMap<StringRef, Operation *> targetSymbolMap;
  DenseMap<Attribute, Attribute> entryPointRefReplacements;

  auto linkedExecutableBuilder =
      OpBuilder::atBlockBegin(linkedExecutableOp.getBody());
  auto linkedTargetBuilder = OpBuilder::atBlockBegin(linkedTargetOp.getBody());
  auto linkedModuleOp = getInnerModuleFn(linkedTargetOp.getInnerModule());

  // Iterate over all source executable ops, linking as many as we can.
  for (auto sourceExecutableOp : sourceExecutableOps) {
    // Snapshot the variants: matching ones are erased while iterating.
    auto variantOps = llvm::to_vector<4>(
        sourceExecutableOp.getOps<IREE::HAL::ExecutableVariantOp>());
    for (auto variantOp : variantOps) {
      // Only process targets matching our pattern.
      if (!matchPattern(variantOp.target_backend_filter(), filter_pattern())) {
        continue;
      }

      // Clone entry point ops and queue remapping ordinals and updating
      // symbol refs.
      for (auto entryPointOp :
           variantOp.getOps<IREE::HAL::ExecutableEntryPointOp>()) {
        // Lookup the interface used by this entry point.
        auto sourceInterfaceOp =
            SymbolTable::lookupNearestSymbolFrom<IREE::HAL::InterfaceOp>(
                sourceExecutableOp, entryPointOp.interfaceAttr());
        assert(sourceInterfaceOp && "cannot find source interface");

        // Reuse an already-linked interface if an equivalent one exists.
        IREE::HAL::InterfaceOp linkedInterfaceOp;
        for (auto interfaceOp : linkedInterfaceOps) {
          if (interfaceOp.isEquivalentTo(sourceInterfaceOp)) {
            linkedInterfaceOp = interfaceOp;
            break;
          }
        }
        if (!linkedInterfaceOp) {
          // The clone of an InterfaceOp is always an InterfaceOp, so use
          // cast<> (which asserts) rather than dyn_cast<> whose null result
          // would be dereferenced below.
          linkedInterfaceOp = cast<IREE::HAL::InterfaceOp>(
              linkedExecutableBuilder.clone(*sourceInterfaceOp));
          linkedInterfaceOp.setName(
              llvm::formatv("io_{0}", linkedInterfaceOps.size()).str());
          linkedInterfaceOps.push_back(linkedInterfaceOp);
        }

        auto newEntryPointOp =
            linkedTargetBuilder.create<IREE::HAL::ExecutableEntryPointOp>(
                entryPointOp.getLoc(), entryPointOp.sym_nameAttr(),
                builder.getIndexAttr(nextEntryPointOrdinal++),
                builder.getSymbolRefAttr(linkedInterfaceOp.getName()),
                ArrayAttr{}, IntegerAttr{});

        // Add to replacement table for fixing up dispatch calls referencing
        // this entry point.
        auto oldSymbolRefAttr =
            builder.getSymbolRefAttr(sourceExecutableOp.getName(),
                                     {builder.getSymbolRefAttr(variantOp),
                                      builder.getSymbolRefAttr(entryPointOp)});
        auto newSymbolRefAttr = builder.getSymbolRefAttr(
            linkedExecutableOp.getName(),
            {builder.getSymbolRefAttr(linkedTargetOp),
             builder.getSymbolRefAttr(newEntryPointOp)});
        entryPointRefReplacements[oldSymbolRefAttr] = newSymbolRefAttr;
      }

      // Merge the existing module into the new linked module op.
      auto sourceModuleOp = getInnerModuleFn(variantOp.getInnerModule());
      if (failed(mergeModuleInto(sourceModuleOp, linkedModuleOp,
                                 targetSymbolMap))) {
        return failure();
      }

      variantOp.erase();
    }

    // Drop source executables that no longer contain any variants.
    if (sourceExecutableOp.getOps<IREE::HAL::ExecutableVariantOp>().empty()) {
      sourceExecutableOp.erase();
    }
  }

  // Update references to @executable::@target::@entry symbols.
  replaceEntryPointUses(moduleOp, entryPointRefReplacements);

  // Remove if we didn't add anything.
  if (linkedTargetOp.getOps<IREE::HAL::ExecutableEntryPointOp>().empty()) {
    linkedTargetOp.erase();
    linkedExecutableOp.erase();
  }

  return success();
}
// Returns the workgroup size to use for a dispatch as three index values.
// This implementation ignores |executableOp|, |entryPointOp| and |workload|
// and always produces the constant size [1, 1, 1] via |builder|.
std::array<Value, 3> TargetBackend::calculateDispatchWorkgroupSize(
    Location loc, IREE::HAL::ExecutableOp executableOp,
    IREE::HAL::ExecutableEntryPointOp entryPointOp, ValueRange workload,
    OpBuilder &builder) {
  // When no workgroup size is specified we just assume [1,1,1].
  // This yields a workgroup count that models the extents of the workload.
  Value sizeX = builder.createOrFold<mlir::ConstantIndexOp>(loc, 1);
  Value sizeY = builder.createOrFold<mlir::ConstantIndexOp>(loc, 1);
  Value sizeZ = builder.createOrFold<mlir::ConstantIndexOp>(loc, 1);
  return {sizeX, sizeY, sizeZ};
}
// Computes the workgroup count by inlining the block returned by
// |entryPointOp|.getBlock() at |builder|'s insertion point.
//
// The block's leading arguments are bound to the |workload| values in order,
// every op except the terminator is cloned with that mapping, and the first
// three (remapped) operands of the terminating hal.return are returned.
static std::array<Value, 3> calculateDispatchWorkgroupCountFromRegion(
    Location loc, IREE::HAL::ExecutableEntryPointOp entryPointOp,
    ValueRange workload, OpBuilder &builder) {
  Block *countBlock = entryPointOp.getBlock();

  // Bind block argument i to workload value i.
  BlockAndValueMapping mapping;
  unsigned argIndex = 0;
  for (Value workloadValue : workload) {
    mapping.map(countBlock->getArgument(argIndex++), workloadValue);
  }

  // Clone the computation; the mapping is extended with each cloned result.
  for (Operation &bodyOp : countBlock->without_terminator()) {
    builder.clone(bodyOp, mapping);
  }

  // Verifier of EntryPointOp checks that the return has 3 values.
  auto returnOp = cast<IREE::HAL::ReturnOp>(countBlock->getTerminator());
  SmallVector<Value, 4> counts;
  for (Value operand : returnOp.operands()) {
    counts.push_back(mapping.lookup(operand));
  }
  return {counts[0], counts[1], counts[2]};
}
// Returns the workgroup count for dispatching |entryPointOp| over |workload|.
//
// If the entry point provides a body region (getBody() is non-null) the count
// is computed by inlining that region. Otherwise it is derived from the
// workload and the size returned by calculateDispatchWorkgroupSize.
std::array<Value, 3> TargetBackend::calculateDispatchWorkgroupCount(
    Location loc, IREE::HAL::ExecutableOp executableOp,
    IREE::HAL::ExecutableEntryPointOp entryPointOp, ValueRange workload,
    OpBuilder &builder) {
  Region *countRegion = entryPointOp.getBody();
  if (!countRegion) {
    // No region available: divide the workload by the workgroup size.
    auto workgroupSize = calculateDispatchWorkgroupSize(
        loc, executableOp, entryPointOp, workload, builder);
    return calculateDispatchWorkgroupCount(loc, workload, workgroupSize,
                                           builder);
  }
  return calculateDispatchWorkgroupCountFromRegion(loc, entryPointOp, workload,
                                                   builder);
}
// Emits the arithmetic that turns |workload| and |workgroupSize| into a 3-D
// workgroup count.
//
// With at most three workload dimensions, missing dimensions are treated as 1
// and each count is the unsigned division
// (workload[i] + workgroupSize[i] - 1) / workgroupSize[i].
// With more than three dimensions the workload is first multiplied into a
// single flat value which is then distributed across the three dimensions
// (see the TODO below: this remapping is a known placeholder).
std::array<Value, 3> TargetBackend::calculateDispatchWorkgroupCount(
    Location loc, ValueRange workload,
    const std::array<Value, 3> &workgroupSize, OpBuilder &builder) {
  std::array<Value, 3> counts;
  Value one = builder.createOrFold<mlir::ConstantIndexOp>(loc, 1);

  // Emits (value + divisor) - 1, the numerator of a round-up division.
  auto addDivisorMinusOne = [&](Value value, Value divisor) -> Value {
    Value sum = builder.createOrFold<mlir::AddIOp>(loc, value, divisor);
    return builder.createOrFold<mlir::SubIOp>(loc, sum, one);
  };

  const size_t rank = workload.size();
  if (rank > 3) {
    // TODO(#4140): remapping of N-D to 3-D: this is not how you do this!
    Value remaining = one;
    for (Value dimension : workload) {
      remaining = builder.createOrFold<mlir::MulIOp>(loc, remaining, dimension);
    }
    for (size_t i = 0; i < 3; ++i) {
      Value numerator = addDivisorMinusOne(remaining, workgroupSize[i]);
      Value count = builder.createOrFold<mlir::UnsignedDivIOp>(
          loc, numerator, workgroupSize[i]);
      counts[i] = count;
      // Multiply back out and subtract from invocations.
      Value consumed =
          builder.createOrFold<mlir::MulIOp>(loc, count, numerator);
      remaining = builder.createOrFold<mlir::SubIOp>(loc, remaining, consumed);
    }
    return counts;
  }

  // 1-D to 3-D are easy (pad 2 to 0 dimensions) and divide by workgroup size.
  for (size_t i = 0; i < 3; ++i) {
    Value dimension = i < rank ? workload[i] : one;
    Value numerator = addDivisorMinusOne(dimension, workgroupSize[i]);
    counts[i] = builder.createOrFold<mlir::UnsignedDivIOp>(loc, numerator,
                                                           workgroupSize[i]);
  }
  return counts;
}
// Records a dispatch for this backend into |switchRewriter|.
//
// Adds a condition region keyed on a DeviceMatchIDAttr built from
// filter_pattern(). The region receives the command buffer followed by the
// workgroup count values from |dispatchState| as arguments. Inside the region
// the workgroup count is recomputed via calculateDispatchWorkgroupCount, a
// CommandBufferDispatchSymbolOp targeting
// @executable::@<entry point parent>::@entry_point is emitted, and the region
// is terminated with a hal.return. Always returns success.
LogicalResult TargetBackend::recordDispatch(
    Location loc, DispatchState dispatchState,
    DeviceSwitchRewriter &switchRewriter) {
  // Region arguments: [commandBuffer, workgroupCount...].
  SmallVector<Value, 4> conditionArgs;
  conditionArgs.push_back(dispatchState.commandBuffer);
  for (auto count : dispatchState.workgroupCount) {
    conditionArgs.push_back(count);
  }

  auto *region = switchRewriter.addConditionRegion(
      IREE::HAL::DeviceMatchIDAttr::get(filter_pattern(), loc.getContext()),
      conditionArgs);
  auto &entryBlock = region->front();

  // Inside the region the captured values are read back from the block
  // arguments, in the same order they were passed above.
  auto commandBuffer = entryBlock.getArgument(0);
  SmallVector<Value, 3> originalWorkgroupCount;
  for (unsigned i = 0, e = dispatchState.workgroupCount.size(); i < e; ++i) {
    originalWorkgroupCount.push_back(entryBlock.getArgument(1 + i));
  }

  auto builder = OpBuilder::atBlockBegin(&entryBlock);
  auto variantSymRef =
      builder.getSymbolRefAttr(dispatchState.entryPointOp->getParentOp());
  auto entrySymRef = builder.getSymbolRefAttr(dispatchState.entryPointOp);
  auto entryPointSymRef = builder.getSymbolRefAttr(
      dispatchState.executableOp.getName(), {variantSymRef, entrySymRef});

  auto remappedWorkgroupCount = calculateDispatchWorkgroupCount(
      loc, dispatchState.executableOp, dispatchState.entryPointOp,
      originalWorkgroupCount, builder);
  builder.create<IREE::HAL::CommandBufferDispatchSymbolOp>(
      loc, commandBuffer, entryPointSymRef, remappedWorkgroupCount[0],
      remappedWorkgroupCount[1], remappedWorkgroupCount[2]);
  builder.create<IREE::HAL::ReturnOp>(loc);
  return success();
}
} // namespace HAL
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir