blob: 2442ca6aa6effa11c31c80857a61f9939133a51b [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include <memory>
#include <utility>
#include "iree/compiler/Dialect/HAL/Analysis/BindingLayout.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/HAL/Target/TargetBackend.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#define DEBUG_TYPE "iree-hal-materialize-interfaces"
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace HAL {
namespace {
//===----------------------------------------------------------------------===//
// hal.executable.variant creation
//===----------------------------------------------------------------------===//
// Creates one hal.executable.variant op inside |executableOp| for each
// #hal.executable.target attribute that applies to |sourceOp|. Every variant
// is registered with the executable's symbol table (which may unique its
// name) and is seeded with an empty mlir::ModuleOp that declareEntryPointOps
// later fills with the interface-based entry functions.
//
// Fails with an error on |sourceOp| when no executable targets are configured
// or when the executable ends up without any variant.
static LogicalResult declareVariantOps(IREE::Stream::ExecutableOp sourceOp,
                                       IREE::HAL::ExecutableOp executableOp) {
  auto targetAttrs =
      IREE::HAL::DeviceTargetAttr::lookupExecutableTargets(sourceOp);
  if (targetAttrs.empty()) {
    return sourceOp.emitError()
           << "no executable targets specified for translation";
  }

  Location loc = sourceOp.getLoc();
  Block &executableBlock = executableOp.getBlock();
  SymbolTable executableSymbols(executableOp);
  // New variants are placed immediately before the executable's terminator.
  OpBuilder variantBuilder(&executableBlock.back());
  for (auto &targetAttr : targetAttrs) {
    auto variantOp = variantBuilder.create<IREE::HAL::ExecutableVariantOp>(
        loc, targetAttr.getSymbolNameFragment(), targetAttr);
    executableSymbols.insert(variantOp);
    // Seed the variant with an empty module ahead of its own terminator.
    OpBuilder moduleBuilder(&variantOp.getBlock().back());
    moduleBuilder.create<mlir::ModuleOp>(loc);
  }

  // An executable that no target can handle is untranslatable; report it now
  // rather than failing somewhere further down the pipeline.
  auto createdVariants =
      executableBlock.getOps<IREE::HAL::ExecutableVariantOp>();
  if (!createdVariants.empty()) return success();

  auto diagnostic = sourceOp.emitError();
  diagnostic
      << "no target backend was able to handle this executable; tried = [ ";
  for (const auto &targetAttr : targetAttrs) {
    diagnostic << targetAttr.getFormat() << " ";
  }
  diagnostic << "]";
  return diagnostic;
}
//===----------------------------------------------------------------------===//
// Interface definition
//===----------------------------------------------------------------------===//
// Checks that every argument of |entryFuncOp| has a type the HAL interface can
// carry:
//   * !stream.binding -> becomes a HAL interface binding;
//   * index           -> converted to a platform bit-width later on;
//   * 32-bit integer  -> passed as a push constant.
// Any other argument type (including non-32-bit integers) emits an error on
// |entryFuncOp| and returns failure.
static LogicalResult verifyEntryPointTypes(mlir::FuncOp entryFuncOp) {
  auto inputTypes = entryFuncOp.getType().getInputs();
  for (size_t argIdx = 0, argCount = inputTypes.size(); argIdx < argCount;
       ++argIdx) {
    Type argType = inputTypes[argIdx];

    // Bindings translate directly to HAL interface bindings.
    if (argType.isa<IREE::Stream::BindingType>()) continue;

    // Index values are converted to the platform bit-width later on.
    // TODO(benvanik): pick something here that the target devices support.
    if (argType.isa<IndexType>()) continue;

    auto integerType = argType.dyn_cast<IntegerType>();
    if (!integerType) {
      return entryFuncOp.emitError()
             << "unsupported interface function argument " << argIdx
             << " type " << argType
             << "; requires tensors or simple primitive values (i32, etc)";
    }

    auto bitWidth = integerType.getIntOrFloatBitWidth();
    if (bitWidth != 32) {
      return entryFuncOp.emitError()
             << "unsupported argument " << argIdx << " bit depth " << bitWidth
             << " (" << integerType
             << "); only 32-bit values are supported right now";
    }
  }
  return success();
}
// Pairs a materialized hal.interface op with the binding ops that correspond
// to the !stream.binding arguments of a single entry function.
struct Interface {
  // The hal.interface op that owns the binding symbols (null until assigned).
  IREE::HAL::InterfaceOp op{};
  // Positionally matches the function signature's binding arguments (1:1).
  // This may cover only a subset of the bindings declared on |op|.
  SmallVector<IREE::HAL::InterfaceBindingOp> resourceBindings{};
};
// Materializes a hal.interface op named "io" at the insertion point of
// |executableBuilder| from the |executableLayout| computed by analysis.
//
// The interface records the layout's push constant count and declares one
// hal.interface.binding per binding of each set layout. Binding symbols are
// named `s<set>b<binding>` followed by an access suffix chosen from the
// binding's access bits (checked in this order; no suffix if none match):
//   _rw  Read and Write
//   _ro  Read
//   _xw  Discard and Write
//   _wo  Write
//
// The returned Interface::resourceBindings lists the binding ops in
// |executableLayout|.resourceMap order. Every (set, binding) pair in the
// resource map must name a binding declared in the set layouts.
static Interface createInterface(Location loc,
                                 const ExecutableLayout &executableLayout,
                                 OpBuilder &executableBuilder) {
  Interface interface;
  interface.op = executableBuilder.create<IREE::HAL::InterfaceOp>(loc, "io");
  interface.op.push_constantsAttr(
      executableBuilder.getIndexAttr(executableLayout.pushConstantCount));
  auto interfaceBuilder = OpBuilder::atBlockBegin(&interface.op.body().front());

  // (set ordinal, binding ordinal) -> declared binding op.
  DenseMap<std::pair<unsigned, unsigned>, IREE::HAL::InterfaceBindingOp>
      bindingMap;
  for (const auto &setLayout : executableLayout.setLayouts) {
    for (const auto &binding : setLayout.bindings) {
      std::string bindingName = "s" + std::to_string(setLayout.ordinal) + "b" +
                                std::to_string(binding.ordinal);
      if (allEnumBitsSet(binding.access,
                         IREE::HAL::MemoryAccessBitfield::Read |
                             IREE::HAL::MemoryAccessBitfield::Write)) {
        bindingName += "_rw";
      } else if (allEnumBitsSet(binding.access,
                                IREE::HAL::MemoryAccessBitfield::Read)) {
        bindingName += "_ro";
      } else if (allEnumBitsSet(binding.access,
                                IREE::HAL::MemoryAccessBitfield::Discard |
                                    IREE::HAL::MemoryAccessBitfield::Write)) {
        bindingName += "_xw";
      } else if (allEnumBitsSet(binding.access,
                                IREE::HAL::MemoryAccessBitfield::Write)) {
        bindingName += "_wo";
      }
      auto bindingOp = interfaceBuilder.create<IREE::HAL::InterfaceBindingOp>(
          interface.op.getLoc(), bindingName,
          /*set=*/APInt(64, setLayout.ordinal),
          /*binding=*/APInt(64, binding.ordinal), binding.type, binding.access);
      bindingMap.insert(
          {std::make_pair(setLayout.ordinal, binding.ordinal), bindingOp});
    }
  }

  for (const auto &setBinding : executableLayout.resourceMap) {
    // lookup() does not mutate the map (unlike operator[]). An unknown pair
    // yields a null op, which indicates an inconsistent layout; catch it here
    // instead of crashing later when the null binding is referenced.
    auto bindingOp = bindingMap.lookup(setBinding);
    assert(bindingOp &&
           "resource map references a binding not declared in any set layout");
    interface.resourceBindings.push_back(bindingOp);
  }
  return interface;
}
// Replaces every use of the primitive (non-binding) block argument |arg| with
// a hal.interface.load.constant op reading push constant |pushConstantIdx|.
// The op is created at |builder|'s insertion point. The "stream.alignment" and
// "stream.values" attributes of the matching |sourceFuncOp| argument are
// forwarded to the load (they are null when the argument lacks them).
static void convertOperandUsage(mlir::FuncOp sourceFuncOp, BlockArgument arg,
                                unsigned pushConstantIdx, OpBuilder &builder) {
  unsigned argIndex = arg.getArgNumber();
  auto alignment =
      sourceFuncOp.getArgAttrOfType<IntegerAttr>(argIndex, "stream.alignment");
  auto values =
      sourceFuncOp.getArgAttrOfType<ArrayAttr>(argIndex, "stream.values");
  auto ordinal = builder.getIndexAttr(pushConstantIdx);
  auto constantLoad = builder.create<IREE::HAL::InterfaceLoadConstantOp>(
      arg.getLoc(), arg.getType(), ordinal, alignment, values);
  arg.replaceAllUsesWith(constantLoad);
}
// Rewrites every stream.binding.subspan user of the !stream.binding block
// argument |arg| into a hal.interface.binding.subspan op that references
// |bindingOp| nested within |interfaceOp| (@interface::@binding).
//
// The new op keeps the old op's result type, byte offset and dynamic dims.
// It sets no byte length, and it forwards the "stream.alignment" attribute of
// the corresponding |sourceFuncOp| argument (null if absent).
// Every use of |arg| must be a stream.binding.subspan (asserted).
static void convertBindingUsage(mlir::FuncOp sourceFuncOp, BlockArgument arg,
                                IREE::HAL::InterfaceOp interfaceOp,
                                IREE::HAL::InterfaceBindingOp bindingOp) {
  if (arg.use_empty()) return;  // no-op

  // Both values are identical for every use of |arg|; compute them once
  // instead of once per rewritten subspan.
  auto alignmentAttr = sourceFuncOp.getArgAttrOfType<IntegerAttr>(
      arg.getArgNumber(), "stream.alignment");
  auto bindingSymbolRef = SymbolRefAttr::get(interfaceOp.sym_nameAttr(),
                                             {SymbolRefAttr::get(bindingOp)});

  // Early-increment iteration: erasing each user removes its use of |arg|.
  for (auto &use : llvm::make_early_inc_range(arg.getUses())) {
    auto oldOp = dyn_cast<IREE::Stream::BindingSubspanOp>(use.getOwner());
    assert(oldOp && "bindings are only usable by stream.binding.subspan");
    OpBuilder builder(oldOp);
    auto newOp = builder.create<IREE::HAL::InterfaceBindingSubspanOp>(
        oldOp.getLoc(), oldOp.getType(), bindingSymbolRef, oldOp.byte_offset(),
        /*byte_length=*/Value{}, oldOp.dynamic_dims(), alignmentAttr);
    oldOp.replaceAllUsesWith(newOp.result());
    oldOp.erase();
  }
}
// Returns a detached clone of |sourceFuncOp| whose signature is () -> () and
// whose former arguments have been replaced by HAL interface accessors:
//   * non-binding arguments become push constant loads, numbered from 0 in
//     argument order and emitted first so that they are available to the
//     binding accessors;
//   * !stream.binding arguments are paired, in argument order, with
//     |interface|.resourceBindings and their subspans rewritten to
//     hal.interface.binding.subspan ops.
// The caller owns (and must erase) the returned op. |executableLayout| is
// currently not consulted.
static mlir::FuncOp cloneFuncWithInterface(
    mlir::FuncOp sourceFuncOp, const ExecutableLayout &executableLayout,
    Interface &interface) {
  // Mutate a private copy so the source function is left untouched.
  auto clonedFuncOp = sourceFuncOp.clone();

  // All I/O now flows through the interface, so the signature is emptied.
  clonedFuncOp.setType(FunctionType::get(clonedFuncOp.getContext(), {}, {}));

  Block &entryBlock = clonedFuncOp.front();
  auto entryBuilder = OpBuilder::atBlockBegin(&entryBlock);

  // Pass 1: push-constant-compatible operands.
  unsigned nextPushConstant = 0;
  for (BlockArgument blockArg : entryBlock.getArguments()) {
    if (blockArg.getType().isa<IREE::Stream::BindingType>()) continue;
    // TODO(benvanik): symbolic push constant indices.
    convertOperandUsage(sourceFuncOp, blockArg, nextPushConstant++,
                        entryBuilder);
  }

  // Pass 2: resource bindings.
  unsigned nextBinding = 0;
  for (BlockArgument blockArg : entryBlock.getArguments()) {
    if (!blockArg.getType().isa<IREE::Stream::BindingType>()) continue;
    convertBindingUsage(sourceFuncOp, blockArg, interface.op,
                        interface.resourceBindings[nextBinding++]);
  }

  // Every argument has been rewritten into an interface lookup; drop them.
  entryBlock.eraseArguments([](BlockArgument) { return true; });
  return clonedFuncOp;
}
// Sets the "hal.interface.bindings" attribute on |dispatchOp|. The attribute
// is an array holding, for each of |interface|'s resource bindings (in order),
// the nested symbol reference
//   @<root of the dispatch's entry_point>::@<interface>::@<binding>.
// TODO(benvanik): have a HAL op with structured information instead.
static void annotateDispatchSite(IREE::Stream::CmdDispatchOp dispatchOp,
                                 Interface &interface) {
  auto rootRef = dispatchOp.entry_pointAttr().getRootReference();
  auto interfaceRef = SymbolRefAttr::get(interface.op);

  SmallVector<Attribute> bindingRefs;
  bindingRefs.reserve(interface.resourceBindings.size());
  for (IREE::HAL::InterfaceBindingOp bindingOp : interface.resourceBindings) {
    auto bindingRef = SymbolRefAttr::get(bindingOp);
    bindingRefs.push_back(
        SymbolRefAttr::get(rootRef, {interfaceRef, bindingRef}));
  }

  auto bindingsAttr = ArrayAttr::get(dispatchOp.getContext(), bindingRefs);
  dispatchOp->setAttr("hal.interface.bindings", bindingsAttr);
}
// Declares the entry points of |targetExecutableOp| from the exports of
// |sourceExecutableOp|. For each stream.executable.export, in order:
//   * verifies the exported function's argument types;
//   * creates a hal.interface from the export's layout in |layoutAnalysis|;
//   * in every hal.executable.variant, declares a hal.executable.entry_point
//     (sequential ordinal, referencing that interface) and appends to the
//     variant's inner module both an interface-based clone of the function
//     and a private copy of the interface op so symbol references resolve;
//   * annotates every dispatch of the export with its binding symbols.
//
// Each export receives its own interface. They are all created with the base
// name "io" and registered in the executable's symbol table, which renames
// later ones on conflict. Without this, an executable with several exports
// would contain duplicate @io symbols.
static LogicalResult declareEntryPointOps(
    IREE::Stream::ExecutableOp sourceExecutableOp,
    IREE::HAL::ExecutableOp targetExecutableOp,
    const BindingLayoutAnalysis &layoutAnalysis) {
  auto variantOps =
      targetExecutableOp.getBlock().getOps<IREE::HAL::ExecutableVariantOp>();
  OpBuilder executableBuilder(&targetExecutableOp.getBlock().front());
  // Used to unique the names of the per-export interface ops.
  SymbolTable executableSymbolTable(targetExecutableOp);

  // For each exported function create a HAL entry point and dispatch thunk.
  int nextOrdinal = 0;
  for (auto exportOp :
       sourceExecutableOp.body().getOps<IREE::Stream::ExecutableExportOp>()) {
    int ordinal = nextOrdinal++;
    auto sourceFuncOp =
        sourceExecutableOp.getInnerModule().lookupSymbol<mlir::FuncOp>(
            exportOp.function_ref());
    // A dangling export would otherwise be dereferenced as a null FuncOp.
    if (!sourceFuncOp) {
      return exportOp.emitError()
             << "exported function '" << exportOp.function_ref()
             << "' not found in the executable's inner module";
    }
    if (failed(verifyEntryPointTypes(sourceFuncOp))) return failure();
    const auto &executableLayout = layoutAnalysis.getExecutableLayout(exportOp);

    // Create the interface for this entry point based on the analysis of its
    // usage within the program.
    auto interface = createInterface(sourceFuncOp.getLoc(), executableLayout,
                                     executableBuilder);
    // Register (and rename on conflict) before any symbol reference to the
    // interface is formed below, so that all references use the final name.
    executableSymbolTable.insert(interface.op);

    // Clone the source function and update it to use the new interface.
    auto baseFuncOp =
        cloneFuncWithInterface(sourceFuncOp, executableLayout, interface);

    // Clone the updated function into each variant.
    for (auto variantOp : variantOps) {
      // Declare the entry point on the target.
      OpBuilder targetBuilder(&variantOp.getBlock().front());
      targetBuilder.create<IREE::HAL::ExecutableEntryPointOp>(
          exportOp.getLoc(),
          targetBuilder.getStringAttr(exportOp.function_ref()),
          targetBuilder.getIndexAttr(ordinal), SymbolRefAttr::get(interface.op),
          ArrayAttr{}, IntegerAttr{});

      // Clone the updated interface-based function into the target.
      auto targetFuncOp = baseFuncOp.clone();
      variantOp.getInnerModule().push_back(targetFuncOp);

      // Copy interface bindings into the target module so symbol references
      // work.
      auto inlinedInterfaceOp = interface.op.clone();
      inlinedInterfaceOp.setPrivate();
      variantOp.getInnerModule().push_back(inlinedInterfaceOp);
    }

    // Update all dispatch sites with the binding information.
    for (auto dispatchOp : layoutAnalysis.getExportDispatches(exportOp)) {
      annotateDispatchSite(dispatchOp, interface);
    }

    // The detached template clone is no longer needed.
    baseFuncOp.erase();
  }
  return success();
}
//===----------------------------------------------------------------------===//
// flow.dispatch.* info op conversion
//===----------------------------------------------------------------------===//
namespace {
// Rewrites a flow.dispatch.workgroup.* query op |SrcOp| into the equivalent
// hal.interface.workgroup.* op |DstOp|, keeping the original result type and
// dimension attribute. Always succeeds.
template <typename SrcOp, typename DstOp>
struct ConvertDispatchWorkgroupInfoPattern final
    : public OpRewritePattern<SrcOp> {
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const override {
    Type resultType = op.getResult().getType();
    auto dimension = op.dimensionAttr();
    rewriter.replaceOpWithNewOp<DstOp>(op, resultType, dimension);
    return success();
  }
};
// Replaces hal.interface.workgroup.size with an index-typed arith.constant
// taken from the workgroup_size of the hal.executable.entry_point that shares
// the enclosing function's name within the enclosing variant.
//
// Fails to match (leaving the op untouched) when:
//   * the op is not inside a func within a variant;
//   * the entry point is missing;
//   * the entry point has no workgroup_size;
//   * the queried dimension is outside the workgroup_size array.
struct InlineConstantWorkgroupSizePattern
    : public OpRewritePattern<IREE::HAL::InterfaceWorkgroupSizeOp> {
  using OpRewritePattern<IREE::HAL::InterfaceWorkgroupSizeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(IREE::HAL::InterfaceWorkgroupSizeOp sizeOp,
                                PatternRewriter &rewriter) const override {
    // Lookup the entry point matching the parent.
    auto funcOp = sizeOp->getParentOfType<mlir::FuncOp>();
    if (!funcOp) return failure();
    auto variantOp = funcOp->getParentOfType<IREE::HAL::ExecutableVariantOp>();
    if (!variantOp) return failure();
    // dyn_cast_or_null: the lookup may return null, which plain dyn_cast
    // does not accept.
    auto entryPointOp = dyn_cast_or_null<IREE::HAL::ExecutableEntryPointOp>(
        SymbolTable::lookupSymbolIn(variantOp, funcOp.getName()));
    assert(entryPointOp &&
           "must have an entry point corresponding to the parent func");
    if (!entryPointOp) return failure();

    auto workgroupSizeAttr = entryPointOp.workgroup_sizeAttr();
    if (!workgroupSizeAttr) return failure();
    uint64_t dimIdx = sizeOp.dimension().getZExtValue();
    // Never index past the end of the declared workgroup size.
    if (dimIdx >= workgroupSizeAttr.size()) return failure();
    auto dimAttr = workgroupSizeAttr[dimIdx];
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(sizeOp, dimAttr,
                                                   rewriter.getIndexType());
    return success();
  }
};
} // namespace
// Greedily rewrites the flow.dispatch.workgroup.{id,count,size} ops nested
// under |executableOp| into their hal.interface.workgroup.* counterparts. It
// also inlines workgroup sizes that the owning entry point declares. Returns
// the result of the greedy rewrite driver.
static LogicalResult convertFlowInfoOps(IREE::HAL::ExecutableOp executableOp) {
  MLIRContext *context = executableOp.getContext();

  using WorkgroupIDPattern =
      ConvertDispatchWorkgroupInfoPattern<IREE::Flow::DispatchWorkgroupIDOp,
                                          IREE::HAL::InterfaceWorkgroupIDOp>;
  using WorkgroupCountPattern =
      ConvertDispatchWorkgroupInfoPattern<IREE::Flow::DispatchWorkgroupCountOp,
                                          IREE::HAL::InterfaceWorkgroupCountOp>;
  using WorkgroupSizePattern =
      ConvertDispatchWorkgroupInfoPattern<IREE::Flow::DispatchWorkgroupSizeOp,
                                          IREE::HAL::InterfaceWorkgroupSizeOp>;

  OwningRewritePatternList rewritePatterns(context);
  rewritePatterns.insert<WorkgroupIDPattern, WorkgroupCountPattern,
                         WorkgroupSizePattern,
                         InlineConstantWorkgroupSizePattern>(context);
  return applyPatternsAndFoldGreedily(executableOp, std::move(rewritePatterns));
}
//===----------------------------------------------------------------------===//
// -iree-hal-materialize-interfaces2
//===----------------------------------------------------------------------===//
// Replaces each stream.executable that has at least one export with a
// hal.executable of the same name and visibility. The new executable
// contains:
//   * one variant per configured executable target;
//   * per-export HAL interfaces and entry points derived from
//     BindingLayoutAnalysis;
//   * hal.interface.workgroup.* ops in place of flow.dispatch.workgroup.* ops.
// stream.executable ops without exports are left untouched.
class MaterializeInterfacesPass
    : public PassWrapper<MaterializeInterfacesPass, OperationPass<ModuleOp>> {
 public:
  MaterializeInterfacesPass() = default;

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<IREE::HAL::HALDialect>();
  }

  StringRef getArgument() const override {
    return "iree-hal-materialize-interfaces2";
  }

  StringRef getDescription() const override {
    return "Materializes hal.executable ops from stream.executable ops";
  }

  void runOnOperation() override {
    const auto &layoutAnalysis = getAnalysis<BindingLayoutAnalysis>();

    // Processes all executables within the input module and produce the
    // output HAL ops. We should ensure all deduping is performed prior to
    // this when it's easier to diff IR and where we still have the flow
    // context. The source ops are snapshotted up front because the loop
    // inserts and erases ops in the module body.
    auto sourceOps = llvm::to_vector<32>(
        getOperation().getOps<IREE::Stream::ExecutableOp>());
    for (auto sourceOp : sourceOps) {
      auto exportOps = sourceOp.getOps<IREE::Stream::ExecutableExportOp>();
      if (exportOps.empty()) continue;

      // Create the op that will contain the translated executable directly
      // after the source executable it replaces.
      OpBuilder builder(sourceOp);
      builder.setInsertionPointAfter(sourceOp);
      auto executableOp = builder.create<IREE::HAL::ExecutableOp>(
          sourceOp.getLoc(), sourceOp.getName());
      executableOp.setVisibility(sourceOp.getVisibility());

      // Embed the hal.executable.variant ops for each source.
      if (failed(declareVariantOps(sourceOp, executableOp))) {
        return signalPassFailure();
      }

      // Define interfaces for each exported function based on analysis.
      if (failed(
              declareEntryPointOps(sourceOp, executableOp, layoutAnalysis))) {
        return signalPassFailure();
      }

      // Convert interface-related flow.dispatch.* ops to their hal.interface.*
      // versions.
      if (failed(convertFlowInfoOps(executableOp))) {
        return signalPassFailure();
      }

      sourceOp.erase();
    }
  }
};
} // namespace
// Returns a new instance of the pass that materializes hal.executable ops
// (with interfaces and entry points) from stream.executable ops
// (-iree-hal-materialize-interfaces2).
std::unique_ptr<OperationPass<ModuleOp>> createMaterializeInterfacesPass() {
return std::make_unique<MaterializeInterfacesPass>();
}
// Static registration so the pass is available to MLIR tooling under the name
// returned by MaterializeInterfacesPass::getArgument(). The lambda is the
// factory used to construct fresh pass instances.
static PassRegistration<MaterializeInterfacesPass> pass([] {
return std::make_unique<MaterializeInterfacesPass>();
});
} // namespace HAL
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir