blob: d7824584b86453272a385fdcc37a159fa13f1845 [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include <memory>
#include <utility>
#include "iree/compiler/Dialect/HAL/Analysis/BindingLayout.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/HAL/Target/TargetBackend.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#define DEBUG_TYPE "iree-hal-materialize-interfaces"
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace HAL {
namespace {
//===----------------------------------------------------------------------===//
// hal.executable.source materialization
//===----------------------------------------------------------------------===//
static LogicalResult materializeExecutableFromSourceOp(
IREE::HAL::ExecutableSourceOp sourceOp,
ArrayRef<IREE::HAL::ExecutableTargetAttr> targetAttrs) {
OpBuilder moduleBuilder(sourceOp);
// Create the op that will contain the translated executable.
auto executableOp = moduleBuilder.create<IREE::HAL::ExecutableOp>(
sourceOp.getLoc(), sourceOp.getName());
executableOp.setVisibility(sourceOp.getVisibility());
// With this hand-authored path all variants have the same layout and entry
// points and we can just clone them.
auto sourceEntryPointOps =
sourceOp.getOps<IREE::HAL::ExecutableEntryPointOp>();
auto sourceModuleOp = sourceOp.getInnerModule();
// Materialize all of the hal.executable.variant ops for all backends we are
// targeting.
SymbolTable targetSymbolTable(executableOp);
OpBuilder targetBuilder(&executableOp.getBlock().back());
for (auto targetAttr : targetAttrs) {
auto targetVariantOp = targetBuilder.create<IREE::HAL::ExecutableVariantOp>(
sourceOp->getLoc(), targetAttr.getSymbolNameFragment(), targetAttr);
targetSymbolTable.insert(targetVariantOp);
OpBuilder variantBuilder(&targetVariantOp.getBlock().back());
for (auto sourceEntryPointOp : sourceEntryPointOps) {
variantBuilder.clone(*sourceEntryPointOp);
}
variantBuilder.clone(*sourceModuleOp);
}
// Remove the original.
sourceOp.erase();
return success();
}
static LogicalResult materializeExecutablesFromSourceOps(
mlir::ModuleOp moduleOp) {
auto sourceOps =
llvm::to_vector<32>(moduleOp.getOps<IREE::HAL::ExecutableSourceOp>());
for (auto sourceOp : sourceOps) {
// Gather a list of all #hal.executable.targets that we should produce
// variants for.
auto targetAttrs =
IREE::HAL::DeviceTargetAttr::lookupExecutableTargets(sourceOp);
if (targetAttrs.empty()) {
return sourceOp.emitError()
<< "no executable targets specified for translation";
}
if (failed(materializeExecutableFromSourceOp(sourceOp, targetAttrs))) {
return failure();
}
}
return success();
}
//===----------------------------------------------------------------------===//
// Interface definition
//===----------------------------------------------------------------------===//
// Verifies that all types used with the given entry point are supportable.
static LogicalResult verifyEntryPointTypes(mlir::func::FuncOp entryFuncOp) {
for (auto inputType :
llvm::enumerate(entryFuncOp.getFunctionType().getInputs())) {
if (inputType.value().isa<IREE::Stream::BindingType>() ||
inputType.value().isInteger(32)) {
// OK - directly translates to a HAL interface binding.
} else {
return entryFuncOp.emitError()
<< "unsupported interface function argument " << inputType.index()
<< " type " << inputType.value()
<< "; requires !stream.binding or i32 operands only";
}
}
return success();
}
// Creates an executable layout attr from the analysis results.
static IREE::HAL::ExecutableLayoutAttr makeExecutableLayoutAttr(
const ExecutableLayout &executableLayout, OpBuilder &builder) {
SmallVector<IREE::HAL::DescriptorSetLayoutAttr> setLayoutAttrs;
for (const auto &setLayout : executableLayout.setLayouts) {
SmallVector<IREE::HAL::DescriptorSetBindingAttr> bindingAttrs;
for (const auto &binding : setLayout.bindings) {
bindingAttrs.push_back(IREE::HAL::DescriptorSetBindingAttr::get(
builder.getContext(), binding.ordinal, binding.type));
}
setLayoutAttrs.push_back(IREE::HAL::DescriptorSetLayoutAttr::get(
builder.getContext(), setLayout.ordinal, bindingAttrs));
}
return IREE::HAL::ExecutableLayoutAttr::get(
builder.getContext(), executableLayout.pushConstantCount, setLayoutAttrs);
}
// Converts the usage of the given primitive |arg| to interface methods.
static void convertOperandUsage(mlir::func::FuncOp sourceFuncOp,
BlockArgument arg, unsigned pushConstantIdx,
OpBuilder &builder) {
auto alignmentAttr = sourceFuncOp.getArgAttrOfType<IntegerAttr>(
arg.getArgNumber(), "stream.alignment");
auto valuesAttr = sourceFuncOp.getArgAttrOfType<ArrayAttr>(arg.getArgNumber(),
"stream.values");
auto loadOp = builder.create<IREE::HAL::InterfaceConstantLoadOp>(
arg.getLoc(), arg.getType(), builder.getIndexAttr(pushConstantIdx),
alignmentAttr, valuesAttr);
arg.replaceAllUsesWith(loadOp);
}
// Converts the usage of the given !stream.binding |arg| to interface methods.
static void convertBindingUsage(
mlir::func::FuncOp sourceFuncOp, BlockArgument arg,
IREE::HAL::DescriptorSetLayoutAttr setLayoutAttr,
IREE::HAL::DescriptorSetBindingAttr bindingAttr) {
if (arg.use_empty()) return; // no-op
for (auto &use : llvm::make_early_inc_range(arg.getUses())) {
auto oldOp = dyn_cast<IREE::Stream::BindingSubspanOp>(use.getOwner());
assert(oldOp && "bindings are only usable by stream.binding.subspan");
OpBuilder builder(oldOp);
auto alignmentAttr = sourceFuncOp.getArgAttrOfType<IntegerAttr>(
arg.getArgNumber(), "stream.alignment");
auto newOp = builder.create<IREE::HAL::InterfaceBindingSubspanOp>(
oldOp.getLoc(), oldOp.getType(), APInt(64, setLayoutAttr.getOrdinal()),
APInt(64, bindingAttr.getOrdinal()), bindingAttr.getType(),
oldOp.byte_offset(), oldOp.dynamic_dims(), alignmentAttr);
oldOp.replaceAllUsesWith(newOp.result());
oldOp.erase();
}
}
// Clones |sourceFuncOp| and updates its signature to match the |interfaceOp|
// and use the HAL interface access primitives.
static mlir::func::FuncOp cloneFuncWithInterface(
mlir::func::FuncOp sourceFuncOp, const ExecutableLayout &executableLayout,
IREE::HAL::ExecutableLayoutAttr layoutAttr) {
// Clone so that we can do a bunch of unsafe in-place updates.
auto clonedFuncOp = sourceFuncOp.clone();
// Strip all arguments as functions take all I/O through the interface API.
clonedFuncOp.setType(FunctionType::get(clonedFuncOp.getContext(), {}, {}));
auto *entryBlock = &clonedFuncOp.front();
auto entryBuilder = OpBuilder::atBlockBegin(entryBlock);
// Change the interface from arguments to hal.interface.* methods.
// We do push constant compatible operands first so that they are available
// for use by the binding accessors.
unsigned operandIdx = 0;
for (auto arg : entryBlock->getArguments()) {
if (!arg.getType().isa<IREE::Stream::BindingType>()) {
convertOperandUsage(sourceFuncOp, arg, operandIdx++, entryBuilder);
}
}
unsigned resourceIdx = 0;
for (auto arg : entryBlock->getArguments()) {
if (!arg.getType().isa<IREE::Stream::BindingType>()) continue;
auto setBinding = executableLayout.resourceMap[resourceIdx++];
auto setLayoutAttr = layoutAttr.getSetLayouts()[setBinding.first];
auto bindingAttr = setLayoutAttr.getBindings()[setBinding.second];
convertBindingUsage(sourceFuncOp, arg, setLayoutAttr, bindingAttr);
}
// Remove all arguments now that we've turned them into lookup ops.
entryBlock->eraseArguments([](auto arg) { return true; });
return clonedFuncOp;
}
// Annotates |dispatchOp| with resource binding to interface binding mappings.
// TODO(benvanik): have a HAL op with structured information instead.
static void annotateDispatchSite(IREE::Stream::CmdDispatchOp dispatchOp,
const ExecutableLayout &executableLayout) {
SmallVector<Attribute> bindingAttrs;
for (auto setBinding : executableLayout.resourceMap) {
bindingAttrs.push_back(IREE::HAL::InterfaceBindingAttr::get(
dispatchOp.getContext(), setBinding.first, setBinding.second));
}
dispatchOp->setAttr("hal.interface.bindings",
ArrayAttr::get(dispatchOp.getContext(), bindingAttrs));
}
// Adds the entry point ops with assigned ordinals for each entry function.
// The entry points will all use the provided |interfaceOp|.
static LogicalResult declareEntryPointOps(
IREE::Stream::ExecutableOp sourceExecutableOp,
IREE::HAL::ExecutableOp targetExecutableOp,
const BindingLayoutAnalysis &layoutAnalysis) {
auto variantOps =
targetExecutableOp.getBlock().getOps<IREE::HAL::ExecutableVariantOp>();
OpBuilder executableBuilder(&targetExecutableOp.getBlock().front());
// For each exported function create a HAL entry point and dispatch thunk.
int nextOrdinal = 0;
for (auto exportOp :
sourceExecutableOp.body().getOps<IREE::Stream::ExecutableExportOp>()) {
int ordinal = nextOrdinal++;
auto sourceFuncOp =
sourceExecutableOp.getInnerModule().lookupSymbol<mlir::func::FuncOp>(
exportOp.function_ref());
if (failed(verifyEntryPointTypes(sourceFuncOp))) return failure();
const auto &executableLayout = layoutAnalysis.getExecutableLayout(exportOp);
// Create the interface for this entry point based on the analysis of its
// usage within the program.
auto layoutAttr =
makeExecutableLayoutAttr(executableLayout, executableBuilder);
// Clone the source function and update it to use the new interface.
auto baseFuncOp =
cloneFuncWithInterface(sourceFuncOp, executableLayout, layoutAttr);
// Clone the updated function into each variant.
for (auto variantOp : variantOps) {
// Declare the entry point on the target.
OpBuilder targetBuilder(&variantOp.getBlock().front());
targetBuilder.create<IREE::HAL::ExecutableEntryPointOp>(
exportOp.getLoc(),
targetBuilder.getStringAttr(exportOp.function_ref()),
targetBuilder.getIndexAttr(ordinal), layoutAttr, ArrayAttr{},
IntegerAttr{});
// Clone the updated interface-based function into the target.
auto targetFuncOp = baseFuncOp.clone();
variantOp.getInnerModule().push_back(targetFuncOp);
}
// Update all dispatch sites with the binding information.
for (auto dispatchOp : layoutAnalysis.getExportDispatches(exportOp)) {
annotateDispatchSite(dispatchOp, executableLayout);
}
baseFuncOp.erase();
}
return success();
}
//===----------------------------------------------------------------------===//
// flow.dispatch.* info op conversion
//===----------------------------------------------------------------------===//
namespace {
template <typename SrcOp, typename DstOp>
struct ConvertDispatchWorkgroupInfoPattern final
: public OpRewritePattern<SrcOp> {
using OpRewritePattern<SrcOp>::OpRewritePattern;
LogicalResult matchAndRewrite(SrcOp op,
PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<DstOp>(op, op.getResult().getType(),
op.dimensionAttr());
return success();
}
};
struct InlineConstantWorkgroupSizePattern
: public OpRewritePattern<IREE::HAL::InterfaceWorkgroupSizeOp> {
using OpRewritePattern<IREE::HAL::InterfaceWorkgroupSizeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(IREE::HAL::InterfaceWorkgroupSizeOp sizeOp,
PatternRewriter &rewriter) const override {
// Lookup the entry point matching the parent.
auto funcOp = sizeOp->getParentOfType<mlir::func::FuncOp>();
auto variantOp = funcOp->getParentOfType<IREE::HAL::ExecutableVariantOp>();
auto entryPointOp = dyn_cast<IREE::HAL::ExecutableEntryPointOp>(
SymbolTable::lookupSymbolIn(variantOp, funcOp.getName()));
assert(entryPointOp &&
"must have an entry point corresponding to the parent func");
auto workgroupSizeAttr = entryPointOp.workgroup_sizeAttr();
if (!workgroupSizeAttr) return failure();
uint64_t dimIdx = sizeOp.dimension().getZExtValue();
auto dimAttr = workgroupSizeAttr[dimIdx];
rewriter.replaceOpWithNewOp<arith::ConstantOp>(sizeOp, dimAttr,
rewriter.getIndexType());
return success();
}
};
} // namespace
static LogicalResult convertFlowInfoOps(IREE::HAL::ExecutableOp executableOp) {
RewritePatternSet patterns(executableOp.getContext());
patterns.insert<
ConvertDispatchWorkgroupInfoPattern<IREE::Flow::DispatchWorkgroupIDOp,
IREE::HAL::InterfaceWorkgroupIDOp>,
ConvertDispatchWorkgroupInfoPattern<IREE::Flow::DispatchWorkgroupCountOp,
IREE::HAL::InterfaceWorkgroupCountOp>,
ConvertDispatchWorkgroupInfoPattern<IREE::Flow::DispatchWorkgroupSizeOp,
IREE::HAL::InterfaceWorkgroupSizeOp>,
InlineConstantWorkgroupSizePattern>(executableOp.getContext());
return applyPatternsAndFoldGreedily(executableOp, std::move(patterns));
}
//===----------------------------------------------------------------------===//
// -iree-hal-materialize-interfaces
//===----------------------------------------------------------------------===//
class MaterializeInterfacesPass
: public PassWrapper<MaterializeInterfacesPass, OperationPass<ModuleOp>> {
public:
MaterializeInterfacesPass() = default;
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<IREE::HAL::HALDialect>();
}
StringRef getArgument() const override {
return "iree-hal-materialize-interfaces";
}
StringRef getDescription() const override {
return "Materializes hal.executable ops from stream.executable ops";
}
void runOnOperation() override {
SymbolTable symbolTable(getOperation());
// Handle any hand-authored executables; these only need variant expansion
// and no layout analysis as the user specified the layout themselves.
if (failed(materializeExecutablesFromSourceOps(getOperation()))) {
return signalPassFailure();
}
const auto &layoutAnalysis = getAnalysis<BindingLayoutAnalysis>();
// Processes all executables within the input module and produce the
// output HAL ops. We should ensure all deduping is performed prior to
// this when it's easier to diff IR and where we still have the flow
// context.
auto sourceOps = llvm::to_vector<32>(
getOperation().getOps<IREE::Stream::ExecutableOp>());
for (auto sourceOp : sourceOps) {
auto exportOps = sourceOp.getOps<IREE::Stream::ExecutableExportOp>();
if (exportOps.empty()) continue;
// Gather a list of all #hal.executable.targets that we should produce
// variants for.
auto targetAttrs =
IREE::HAL::DeviceTargetAttr::lookupExecutableTargets(sourceOp);
if (targetAttrs.empty()) {
sourceOp.emitError()
<< "no executable targets specified for translation";
return signalPassFailure();
}
// Create the op that will contain the translated executable.
OpBuilder builder = OpBuilder::atBlockEnd(getOperation().getBody());
builder.setInsertionPointAfter(sourceOp);
auto executableOp = builder.create<IREE::HAL::ExecutableOp>(
sourceOp.getLoc(), sourceOp.getName());
executableOp.setVisibility(sourceOp.getVisibility());
// Materialize all of the hal.executable.variant ops for all backends we
// are targeting.
SymbolTable targetSymbolTable(executableOp);
OpBuilder targetBuilder(&executableOp.getBlock().back());
for (auto targetAttr : targetAttrs) {
auto targetContainerOp =
targetBuilder.create<IREE::HAL::ExecutableVariantOp>(
sourceOp->getLoc(), targetAttr.getSymbolNameFragment(),
targetAttr);
targetSymbolTable.insert(targetContainerOp);
OpBuilder containerBuilder(&targetContainerOp.getBlock().back());
containerBuilder.create<mlir::ModuleOp>(sourceOp->getLoc());
}
// Define interfaces for each exported function based on analysis.
if (failed(
declareEntryPointOps(sourceOp, executableOp, layoutAnalysis))) {
return signalPassFailure();
}
// Convert interface-related flow.dispatch.* ops to their hal.interface.*
// versions.
if (failed(convertFlowInfoOps(executableOp))) {
return signalPassFailure();
}
sourceOp.erase();
}
}
};
} // namespace
std::unique_ptr<OperationPass<ModuleOp>> createMaterializeInterfacesPass() {
return std::make_unique<MaterializeInterfacesPass>();
}
static PassRegistration<MaterializeInterfacesPass> pass([] {
return std::make_unique<MaterializeInterfacesPass>();
});
} // namespace HAL
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir