blob: 0aae4bd80772945d6322cffc4cd2e53f42b5eea9 [file] [log] [blame] [edit]
// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Modules/HAL/Inline/IR/HALInlineDialect.h"
#include "iree/compiler/Modules/HAL/Inline/Transforms/Passes.h"
#include "iree/compiler/Utils/IntegerSet.h"
#include "iree/compiler/Utils/ModuleUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
namespace mlir::iree_compiler::IREE::HAL::Inline {
#define GEN_PASS_DEF_INLINEEXECUTABLESPASS
#include "iree/compiler/Modules/HAL/Inline/Transforms/Passes.h.inc"
namespace {
class InlineExecutablesPass final
    : public impl::InlineExecutablesPassBase<InlineExecutablesPass> {
public:
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<IREE::Util::UtilDialect, IREE::HAL::HALDialect,
                    IREE::HAL::Inline::HALInlineDialect, arith::ArithDialect,
                    func::FuncDialect, scf::SCFDialect>();
  }

  // Inlines the code of every hal.executable variant into the host module,
  // erases the executables, and then tags each stream.cmd.dispatch found in
  // top-level functions with a `hal_inline.target` symbol naming the generated
  // dispatch wrapper function.
  void runOnOperation() override {
    mlir::ModuleOp moduleOp = getOperation();

    // Inline variants and produce a map from the export reference
    // (@executable::@variant::@export) to the dispatch function name.
    DenseMap<Attribute, Attribute> exportToFuncMap;
    SymbolTableCollection symbolTables;
    for (auto executableOp : llvm::make_early_inc_range(
             moduleOp.getOps<IREE::HAL::ExecutableOp>())) {
      // Inline each variant.
      for (auto variantOp :
           executableOp.getOps<IREE::HAL::ExecutableVariantOp>()) {
        if (failed(inlineVariant(executableOp, variantOp, moduleOp,
                                 exportToFuncMap, symbolTables))) {
          return signalPassFailure();
        }
      }
      // Drop executable after information has been extracted and the workgroup
      // code has been inlined.
      executableOp.erase();
    }

    // Annotate all dispatches with the target function.
    for (auto funcOp : moduleOp.getOps<mlir::FunctionOpInterface>()) {
      auto result = funcOp.walk([&](IREE::Stream::CmdDispatchOp dispatchOp) {
        // Specify new target function that conversion can use to make the
        // call. We only support single variant dispatches when inline.
        auto entryPointAttrs = dispatchOp.getEntryPoints().getValue();
        if (entryPointAttrs.size() != 1) {
          dispatchOp.emitOpError()
              << "multiple variant targets not supported with the inline HAL";
          return WalkResult::interrupt();
        }
        // lookup() (unlike operator[]) does not insert a null entry; a missing
        // mapping is reported instead of crashing inside cast<>.
        auto targetFuncName = llvm::dyn_cast_if_present<StringAttr>(
            exportToFuncMap.lookup(entryPointAttrs.front()));
        if (!targetFuncName) {
          dispatchOp.emitOpError()
              << "references an export with no inlined dispatch function: "
              << entryPointAttrs.front();
          return WalkResult::interrupt();
        }
        dispatchOp->setAttr("hal_inline.target",
                            FlatSymbolRefAttr::get(targetFuncName));
        return WalkResult::advance();
      });
      if (result.wasInterrupted()) {
        return signalPassFailure();
      }
    }
  }

  // For each export in |variantOp|, creates a private `__dispatch_*` wrapper
  // function in the variant's inner module that loops over the workgroup grid
  // and calls the export's body function. Afterwards, merges the inner module
  // into |targetModuleOp| ahead of |executableOp|. Each export reference
  // (@executable::@variant::@export) is recorded in |exportToFuncMap| with the
  // name of its wrapper.
  LogicalResult inlineVariant(IREE::HAL::ExecutableOp executableOp,
                              IREE::HAL::ExecutableVariantOp variantOp,
                              mlir::ModuleOp targetModuleOp,
                              DenseMap<Attribute, Attribute> &exportToFuncMap,
                              SymbolTableCollection &symbolTables) {
    auto innerModuleOp = variantOp.getInnerModule();
    if (!innerModuleOp) {
      return variantOp.emitOpError(
          "has no inner module to inline with the inline HAL");
    }
    // Bind by reference: a copy would leave the table cached in
    // |symbolTables| unaware of the dispatch functions inserted below.
    SymbolTable &innerSymbolTable = symbolTables.getSymbolTable(innerModuleOp);
    auto innerModuleBuilder = OpBuilder::atBlockEnd(innerModuleOp.getBody());

    // We want to merge the module ahead of the exported functions to ensure
    // initializer order is preserved.
    OpBuilder targetModuleBuilder(executableOp);

    // Build each dispatch function wrapper.
    auto indexType = innerModuleBuilder.getIndexType();
    auto i32Type = innerModuleBuilder.getI32Type();
    auto bufferType = innerModuleBuilder.getType<IREE::Util::BufferType>();
    for (auto exportOp : variantOp.getExportOps()) {
      // The workload is every workgroup count argument except the first one
      // (buildDispatchFunc passes a null device for that one).
      Block *workgroupCountBody = exportOp.getWorkgroupCountBody();
      if (!workgroupCountBody || workgroupCountBody->getNumArguments() == 0) {
        return exportOp.emitOpError(
            "requires a workgroup count region to be inlined");
      }
      unsigned workloadArgCount = workgroupCountBody->getNumArguments() - 1;

      // Resolve the workgroup body before creating anything so that a missing
      // body leaves the inner module untouched.
      auto bodyFuncOp =
          innerSymbolTable.lookup<FunctionOpInterface>(exportOp.getName());
      if (!bodyFuncOp) {
        return exportOp.emitOpError("missing body function");
      }

      // Build dispatch function signature that the stream.cmd.dispatch ops
      // will map to.
      auto layoutAttr = exportOp.getLayout();
      size_t bindingCount = layoutAttr.getBindings().size();
      SmallVector<Type> inputTypes;
      inputTypes.append(workloadArgCount, indexType);        // workload
      inputTypes.append(layoutAttr.getConstants(), i32Type); // constants
      inputTypes.append(bindingCount, bufferType);           // buffers
      inputTypes.append(bindingCount, indexType);            // offsets
      inputTypes.append(bindingCount, indexType);            // lengths
      auto dispatchFuncType =
          innerModuleBuilder.getFunctionType(inputTypes, {});

      // Create the function and insert into the module. The symbol table may
      // rename it on conflict, so the final name is read back below.
      auto dispatchFuncOp = IREE::Util::FuncOp::create(
          exportOp.getLoc(),
          ("__dispatch_" + executableOp.getName() + "_" + exportOp.getName())
              .str(),
          dispatchFuncType);
      dispatchFuncOp.setPrivate();
      innerSymbolTable.insert(dispatchFuncOp,
                              innerModuleBuilder.getInsertionPoint());
      innerModuleBuilder.setInsertionPointAfter(dispatchFuncOp);

      // Flatten the body signature once; multiple exports may share a body.
      if (bodyFuncOp.isPublic()) {
        if (failed(rewriteWorkgroupSignature(layoutAttr, bindingCount,
                                             bodyFuncOp))) {
          return failure();
        }
        bodyFuncOp.setPrivate(); // so we only do it once
      }

      // Build the dispatch function by calling the target function in a loop.
      buildDispatchFunc(exportOp, layoutAttr, bindingCount, bodyFuncOp,
                        dispatchFuncOp);

      // Map from what the stream.cmd.dispatch ops is using to the new function.
      auto exportTargetAttr = SymbolRefAttr::get(
          executableOp.getNameAttr(),
          {
              FlatSymbolRefAttr::get(variantOp.getNameAttr()),
              FlatSymbolRefAttr::get(exportOp.getNameAttr()),
          });
      exportToFuncMap[exportTargetAttr] = dispatchFuncOp.getNameAttr();
    }

    // Merge the source executable module into the target host module.
    if (failed(mergeModuleInto(innerModuleOp, targetModuleOp,
                               targetModuleBuilder))) {
      return failure();
    }
    return success();
  }

  // Rewrites a workgroup body function signature to a flattened list.
  //
  // Body (as translated):
  //   (local_memory, [constants], [bindings],
  //    workgroup_x, workgroup_y, workgroup_z,
  //    workgroup_size_x, workgroup_size_y, workgroup_size_z,
  //    workgroup_count_x, workgroup_count_y, workgroup_count_z)
  //
  // Body after rewrite:
  //   (local_memory, constants..., bindings...,
  //    workgroup_x, workgroup_y, workgroup_z,
  //    workgroup_size_x, workgroup_size_y, workgroup_size_z,
  //    workgroup_count_x, workgroup_count_y, workgroup_count_z)
  //
  // To make this process easier and lighten the load on the downstream passes
  // we muck with the ABI to pass a flattened list of constants and bindings.
  // Whenever better IPO and util.list optimizations are added we could back
  // this out to keep things vanilla and have fewer places making assumptions
  // about the function signatures.
  //
  // Fails (with a diagnostic) if the body has no region, does not have the
  // translated argument layout above, or accesses constants/bindings in a way
  // that cannot be statically forwarded.
  LogicalResult
  rewriteWorkgroupSignature(IREE::HAL::PipelineLayoutAttr layoutAttr,
                            size_t bindingCount,
                            FunctionOpInterface bodyFuncOp) {
    // local_memory + constants + bindings + id/size/count xyz tuples.
    constexpr unsigned kExpectedArgCount = 3 + 3 * /*xyz=*/3;
    if (bodyFuncOp.isExternal()) {
      return bodyFuncOp.emitOpError(
          "workgroup body function must have a body to be inlined");
    }
    auto *entryBlock = &bodyFuncOp.front();
    unsigned originalArgCount = entryBlock->getNumArguments();
    if (originalArgCount != kExpectedArgCount) {
      return bodyFuncOp.emitOpError()
             << "expected workgroup body function to have "
             << kExpectedArgCount << " arguments but it has "
             << originalArgCount;
    }

    auto builder = OpBuilder::atBlockBegin(entryBlock);
    auto indexType = builder.getIndexType();
    auto i32Type = builder.getI32Type();
    auto bufferType = builder.getType<IREE::Util::BufferType>();

    // There may be nicer ways of doing this but I can't find them.
    // We build a new list of argument types and insert them as we go. This
    // lets us map the arguments over and replace usage such that by the end we
    // can slice off the original arguments as they'll have no more uses.
    SmallVector<Type> newArgTypes;
    unsigned argOffset = 0;

    // Local memory is carried across as-is.
    auto localMemoryArg = entryBlock->getArgument(argOffset++);
    newArgTypes.push_back(bufferType);
    localMemoryArg.replaceAllUsesWith(
        entryBlock->addArgument(bufferType, localMemoryArg.getLoc()));

    // Expand push constants by replacing buffer accesses with the flattened
    // args.
    newArgTypes.append(layoutAttr.getConstants(), i32Type);
    auto constantBuffer = entryBlock->getArgument(argOffset++);
    SmallVector<Value> constantArgs;
    for (int64_t i = 0; i < layoutAttr.getConstants(); ++i) {
      constantArgs.push_back(
          entryBlock->addArgument(i32Type, constantBuffer.getLoc()));
    }
    if (failed(replaceBufferAccesses(constantBuffer, constantArgs))) {
      return failure();
    }

    // Expand buffer list by replacing list accesses with the flattened args.
    newArgTypes.append(bindingCount, bufferType);
    auto bindingList = entryBlock->getArgument(argOffset++);
    SmallVector<Value> bindingArgs;
    for (size_t i = 0; i < bindingCount; ++i) {
      bindingArgs.push_back(
          entryBlock->addArgument(bufferType, bindingList.getLoc()));
    }
    if (failed(replaceListAccesses(bindingList, bindingArgs))) {
      return failure();
    }

    // Take care of the workgroup id/size/count tuples: the new args are index
    // typed and are cast to i32 for the existing uses.
    for (unsigned i = 0; i < 3 * /*xyz=*/3; ++i) {
      newArgTypes.push_back(indexType);
      auto oldArg = entryBlock->getArgument(argOffset++);
      auto newArg = entryBlock->addArgument(indexType, oldArg.getLoc());
      oldArg.replaceAllUsesWith(arith::IndexCastOp::create(
          builder, oldArg.getLoc(), i32Type, newArg));
    }

    // Erase the original (now unused) args in one shot.
    entryBlock->eraseArguments(0, originalArgCount);

    // Update function signature to reflect the entry block args.
    bodyFuncOp.setType(
        builder.getFunctionType(newArgTypes, bodyFuncOp.getResultTypes()));
    return success();
  }

  // Replaces trivial constant index accesses to a buffer with their values.
  // This is an extremely poor optimization that we should remove if buffer
  // ever gets store-load forwarding - we could just create the buffer, store
  // the elements, and let that take care of the rest. Today it doesn't do that.
  //
  // util.buffer.load offsets are treated as byte offsets into 32-bit elements;
  // memref.load indices are element indices into a flattened 1-D memref.
  // Out-of-bounds or unaligned constant offsets are reported as errors.
  LogicalResult replaceBufferAccesses(Value buffer, ValueRange elements) {
    constexpr int64_t kElementSize = sizeof(uint32_t);
    const int64_t elementCount = static_cast<int64_t>(elements.size());
    for (auto *user : llvm::make_early_inc_range(buffer.getUsers())) {
      if (auto sizeOp = dyn_cast<IREE::Util::BufferSizeOp>(user)) {
        // Ignored but we need to get rid of it.
        // TODO(benvanik): see if we can allow this through; today it will pin
        // the function argument (constants most likely) and cause us to fail
        // to remove it later on.
        OpBuilder builder(sizeOp);
        Value dummySize = arith::ConstantIndexOp::create(
            builder, sizeOp.getLoc(), 0xCAFEF00D);
        sizeOp.replaceAllUsesWith(dummySize);
        sizeOp.erase();
        continue;
      }

      if (auto loadOp = dyn_cast<IREE::Util::BufferLoadOp>(user)) {
        APInt offset;
        if (!matchPattern(loadOp.getSourceOffset(), m_ConstantInt(&offset))) {
          return loadOp.emitOpError(
              "unhandled dynamic buffer access; must be static");
        }
        int64_t byteOffset = offset.getSExtValue();
        if (byteOffset < 0 || byteOffset % kElementSize != 0 ||
            byteOffset / kElementSize >= elementCount) {
          return loadOp.emitOpError()
                 << "constant buffer byte offset " << byteOffset
                 << " is unaligned or out of bounds for " << elementCount
                 << " 32-bit elements";
        }
        loadOp.replaceAllUsesWith(elements[byteOffset / kElementSize]);
        loadOp.erase();
        continue;
      }

      if (auto loadOp = dyn_cast<memref::LoadOp>(user)) {
        if (loadOp.getIndices().size() != 1) {
          return loadOp.emitOpError(
              "expected memrefs to have been flattened before inlining "
              "executables");
        }
        APInt index;
        if (!matchPattern(loadOp.getIndices()[0], m_ConstantInt(&index))) {
          return loadOp.emitOpError(
              "unhandled dynamic buffer access; must be static");
        }
        int64_t elementIndex = index.getSExtValue();
        if (elementIndex < 0 || elementIndex >= elementCount) {
          return loadOp.emitOpError()
                 << "constant buffer index " << elementIndex
                 << " is out of bounds for " << elementCount << " elements";
        }
        loadOp.replaceAllUsesWith(elements[elementIndex]);
        loadOp.erase();
        continue;
      }

      return user->emitOpError(
          "unhandled buffer access op; only loads are supported");
    }
    return success();
  }

  // Replaces trivial constant index accesses to a list with their values.
  // util.list store-load forwarding could do this instead.
  // Out-of-bounds constant indices are reported as errors.
  LogicalResult replaceListAccesses(Value list, ValueRange elements) {
    const int64_t elementCount = static_cast<int64_t>(elements.size());
    for (auto *user : llvm::make_early_inc_range(list.getUsers())) {
      auto getOp = dyn_cast<IREE::Util::ListGetOp>(user);
      if (!getOp) {
        return user->emitOpError(
            "unhandled list access op; only gets are supported");
      }
      APInt index;
      if (!matchPattern(getOp.getIndex(), m_ConstantInt(&index))) {
        return getOp.emitOpError(
            "unhandled dynamic list access; must be static");
      }
      int64_t elementIndex = index.getSExtValue();
      if (elementIndex < 0 || elementIndex >= elementCount) {
        return getOp.emitOpError()
               << "list index " << elementIndex << " is out of bounds for "
               << elementCount << " elements";
      }
      getOp.replaceAllUsesWith(elements[elementIndex]);
      getOp.erase();
    }
    return success();
  }

  // Builds a function that calls a workgroup body and marshals arguments.
  //
  // Incoming:
  //   (workload..., constants...,
  //    binding_buffers..., binding_offsets..., binding_lengths...)
  // Body (as translated):
  //   (local_memory, [constants], [bindings],
  //    workgroup_x, workgroup_y, workgroup_z,
  //    workgroup_size_x, workgroup_size_y, workgroup_size_z,
  //    workgroup_count_x, workgroup_count_y, workgroup_count_z)
  //
  // Requires |exportOp| to have a workgroup count region (checked by
  // inlineVariant before this is called).
  void buildDispatchFunc(IREE::HAL::ExecutableExportOp exportOp,
                         IREE::HAL::PipelineLayoutAttr layoutAttr,
                         size_t bindingCount, FunctionOpInterface bodyFuncOp,
                         FunctionOpInterface dispatchFuncOp) {
    auto loc = exportOp.getLoc();
    auto builder = OpBuilder::atBlockBegin(dispatchFuncOp.addEntryBlock());
    IndexSet indexSet(loc, builder);
    auto bufferType = builder.getType<IREE::Util::BufferType>();
    SmallVector<Value> workgroupArgs;

    // Calculate the XYZ workgroup count from the export function.
    // There may be multiple exports pointing at the same body with different
    // workgroup count functions.
    unsigned workloadArgCount =
        exportOp.getWorkgroupCountBody()->getNumArguments() - 1;
    unsigned argOffset = 0;
    SmallVector<Value> workload;
    workload.reserve(workloadArgCount);
    for (unsigned i = 0; i < workloadArgCount; ++i) {
      workload.push_back(dispatchFuncOp.getArgument(argOffset++));
    }
    Value device = IREE::Util::NullOp::create(
        builder, loc, builder.getType<IREE::HAL::DeviceType>());
    std::array<Value, 3> workgroupCount =
        exportOp.calculateWorkgroupCount(loc, device, workload, builder);

    // For now we don't handle local memory.
    Value localMemory = IREE::Util::NullOp::create(builder, loc, bufferType);
    workgroupArgs.push_back(localMemory);

    // Pass all constants through.
    for (int64_t i = 0; i < layoutAttr.getConstants(); ++i) {
      workgroupArgs.push_back(dispatchFuncOp.getArgument(argOffset++));
    }

    // Pass all buffers through as subspans with the binding offset and length
    // factored in. IPO can propagate the subspans (hopefully).
    for (size_t i = 0; i < bindingCount; ++i) {
      auto bindingBuffer = dispatchFuncOp.getArgument(argOffset + i);
      auto bindingOffset =
          dispatchFuncOp.getArgument(argOffset + bindingCount + i);
      auto bindingLength = dispatchFuncOp.getArgument(argOffset + bindingCount +
                                                      bindingCount + i);
      Value bufferSize =
          IREE::Util::BufferSizeOp::create(builder, loc, bindingBuffer);
      Value bindingView = IREE::Util::BufferSubspanOp::create(
          builder, loc, bindingBuffer, bufferSize, bindingOffset,
          bindingLength);
      workgroupArgs.push_back(bindingView);
    }

    const size_t workgroupXYZOffset = workgroupArgs.size();
    workgroupArgs.append(3, nullptr);         // workgroup_xyz, set below
    workgroupArgs.append(3, indexSet.get(1)); // workgroup_size_xyz
    llvm::append_range(workgroupArgs, workgroupCount); // workgroup_count_xyz

    // Z -> Y -> X loop nest; the innermost loop calls the body with the
    // current workgroup id patched into |workgroupArgs|.
    scf::ForOp::create(
        builder, loc, indexSet.get(0), workgroupCount[2], indexSet.get(1),
        ValueRange{},
        [&](OpBuilder &forZBuilder, Location loc, Value iz, ValueRange) {
          workgroupArgs[workgroupXYZOffset + 2] = iz;
          scf::ForOp::create(
              forZBuilder, loc, indexSet.get(0), workgroupCount[1],
              indexSet.get(1), ValueRange{},
              [&](OpBuilder &forYBuilder, Location loc, Value iy, ValueRange) {
                workgroupArgs[workgroupXYZOffset + 1] = iy;
                scf::ForOp::create(
                    forYBuilder, loc, indexSet.get(0), workgroupCount[0],
                    indexSet.get(1), ValueRange{},
                    [&](OpBuilder &forXBuilder, Location loc, Value ix,
                        ValueRange) {
                      workgroupArgs[workgroupXYZOffset + 0] = ix;
                      func::CallOp::create(
                          forXBuilder, loc, bodyFuncOp.getNameAttr(),
                          bodyFuncOp.getResultTypes(), workgroupArgs);
                      scf::YieldOp::create(forXBuilder, loc);
                    });
                scf::YieldOp::create(forYBuilder, loc);
              });
          scf::YieldOp::create(forZBuilder, loc);
        });

    IREE::Util::ReturnOp::create(builder, loc);
  }
};
} // namespace
} // namespace mlir::iree_compiler::IREE::HAL::Inline