//===- LinalgComprehensiveBufferizePass.cpp - Bufferize Linalg on tensors -===//
//
// Convert from Linalg ops on tensors to Linalg ops on buffers in a single pass.
// Aggressively try to perform inPlace bufferization and fail if any allocation
// tries to cross function boundaries or if the pattern
// `tensor_load(tensor_to_memref(x))` is deemed unsafe (very conservative impl
// for now).
//
//===----------------------------------------------------------------------===//
#include <type_traits>
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/SCF/Passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/UseDefLists.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/Passes.h"
#define DEBUG_TYPE "linalg-comprehensive-bufferize-inplace"
using namespace mlir;
using namespace linalg;
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
/// Comprehensive Linalg bufferize pass that aims at avoiding phase-ordering and
/// safety + optimization issues that are present in upstream bufferization.
/// At the same time, this pass only cares about enabling aggressive inPlace
/// bufferization for linalg ops and scf.for, **including across function
/// boundaries**.
/// In particular, no branching behavior is supported at the moment besides
/// function calls and scf.for.
/// This ModulePass consists of the following steps:
/// 1. perform a `funcArgumentsInPlaceAnalysis` which traverses all CallOps and
/// determines whether any tensor operand could potentially bufferize to a
/// buffer that can be updated inPlace (i.e. an in-out buffer).
/// Such operands are ones whose value is not read in any other op at the
/// caller site.
/// As a result of this analysis, CallOp operands are marked with
/// `kInPlaceResultsAttrName`. The "meet" of all `kInPlaceResultsAttrName`
/// for all `callOp` to a given FuncOp determines the
/// `kInPlaceResultsAttrName` for that FuncOp.
/// 2. traverse each FuncOp and perform bufferization within the function
/// boundaries. Bufferization occurs by:
/// a. performing an inPlace analysis `inPlaceAnalysisFuncOpInternals`
/// which marks each operation within the function with the
/// `kInPlaceResultsAttrName` attribute.
/// b. traversing each operation in the function and rewriting it in
/// buffer form and keeping a BlockAndValueMapping mapping of the
/// rewrites.
/// New allocations are introduced during this step.
/// TODO: Allocation + depending op hoisting to outermost enclosing
/// sequential scope.
/// 3. once bufferization within function boundaries is done, the next step
/// runs `bufferizeFunctionsAndCalls`, which involves:
/// a. detecting `function_arg -> tensor_to_memref -> tensor_load -> return`
/// patterns for each FuncOp, which determines the `tiedResultMap` between
/// function args and results.
/// b. rewrite function arguments and returns in buffer forms, skipping the
/// tensors that appear in the `tiedResultMap`.
/// c. bufferize the CallOps using the callee's `tiedResultMap`.
///
/// TensorToMemRefOps are only ever inserted as a transient abstraction for
/// function arguments that have not yet been bufferized.
/// All other places either allocate or forward existing buffers.
///
/// TensorLoadOps are only ever inserted as a transient abstraction for
/// terminators (return, scf.yield).
/// The `function_arg -> tensor_to_memref -> tensor_load -> return` pattern is
/// used to analyze which function result ties to a function operand.
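///
/// As an illustrative sketch (op names and exact syntax are indicative only,
/// not a verified end-to-end example), a function such as:
/// ```
///    func @f(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>)
///        -> tensor<?x?xf32> {
///      %D = linalg.matmul ins(%A, %B: ...) outs(%C: ...) -> tensor<?x?xf32>
///      return %D: tensor<?x?xf32>
///    }
/// ```
/// is expected, when %C is deemed writeable inPlace at every call site, to
/// bufferize roughly to:
/// ```
///    func @f(%A: memref<?x?xf32, ...>, %B: memref<?x?xf32, ...>,
///            %C: memref<?x?xf32, ...>) {
///      linalg.matmul ins(%A, %B: ...) outs(%C: ...)
///      return
///    }
/// ```
/// with the result tied to %C and dropped from the bufferized signature.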
namespace {
struct LinalgComprehensiveBufferizePass
: public PassWrapper<LinalgComprehensiveBufferizePass,
OperationPass<ModuleOp>> {
LinalgComprehensiveBufferizePass()
: enablingPassPipeline(OpPassManager("func")) {
enablingPassPipeline.addPass(createCanonicalizerPass());
enablingPassPipeline.addPass(createCSEPass());
enablingPassPipeline.addPass(createLoopInvariantCodeMotionPass());
}
LinalgComprehensiveBufferizePass(const LinalgComprehensiveBufferizePass &pass)
: enablingPassPipeline(pass.enablingPassPipeline) {}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<LinalgDialect, scf::SCFDialect, StandardOpsDialect>();
}
void runOnOperation() override;
void runEnablingTransforms(FuncOp funcOp);
void bufferizeFuncOpInternals(FuncOp funcOp, BlockAndValueMapping &bvm);
void inPlaceAnalysisFuncOpInternals(FuncOp funcOp,
const DominanceInfo &domInfo);
/// Dynamic pass pipeline of transformations that enable better inPlace
/// bufferization.
OpPassManager enablingPassPipeline;
};
} // namespace
//===----------------------------------------------------------------------===//
// Forward declarations.
//===----------------------------------------------------------------------===//
/// Return a MemRefType to which the `tensorType` can be bufferized in a
/// composable fashion. The layout must be the most dynamic possible and
/// canonicalize away once bufferization is finished.
static MemRefType getDynamicMemRefType(RankedTensorType tensorType,
unsigned addressSpace = 0);
//===----------------------------------------------------------------------===//
// Bufferization-specific attribute manipulation.
//===----------------------------------------------------------------------===//
/// Attribute marker to specify results that can be bufferized inPlace.
constexpr StringLiteral kInPlaceResultsAttrName = "__inplace_results_attr__";
/// Attribute marker to specify func/call arguments that can be written inPlace
/// from the perspective of the caller.
constexpr StringLiteral kInPlaceArgsAttrName = "__inplace_args_attr__";
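/// For example, with the encoding used in this file, an op with two results
/// where only the first bufferizes inPlace carries:
///   __inplace_results_attr__ = ["true", "none"]
/// and a call whose second operand may be written inPlace carries:
///   __inplace_args_attr__ = ["none", "true"]
/// (illustrative values; entries are one of "true", "false" or "none").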
/// InPlace specification for a result or argument. `None` is the default used
/// when no attribute entry has been set yet.
enum class InPlaceSpec {
False,
True,
None,
};
static StringRef stringify(InPlaceSpec val) {
switch (val) {
case InPlaceSpec::False:
return "false";
case InPlaceSpec::True:
return "true";
case InPlaceSpec::None:
return "none";
}
return "";
}
static Optional<InPlaceSpec> symbolize(StringRef str) {
return StringSwitch<Optional<InPlaceSpec>>(str)
.Case("false", InPlaceSpec::False)
.Case("true", InPlaceSpec::True)
.Case("none", InPlaceSpec::None)
.Default(None);
}
static FuncOp getCalledFunction(CallOpInterface callOp) {
SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
if (!sym) return nullptr;
return dyn_cast_or_null<FuncOp>(
SymbolTable::lookupNearestSymbolFrom(callOp, sym));
}
/// Factor out the logic that matches tied OpResult to BlockArgument.
/// For FuncOp the analysis is dependent on the result of bufferization so we
/// always return null.
static OpResult getTiedOpResult(BlockArgument &bbArg) {
Operation *op = bbArg.getOwner()->getParentOp();
if (auto forOp = dyn_cast<scf::ForOp>(op))
return forOp->getResult(bbArg.getArgNumber() - /*#iv=*/1);
if (auto funcOp = dyn_cast<FuncOp>(op)) return OpResult();
op->dump();
llvm_unreachable("Unsupported op");
}
/// Factor out the logic that matches tied OpResult to OpOperand.
/// For CallOp, the analysis is dependent on the result of bufferization of the
/// callee, so we always return null.
/// For terminators there is no possible operand/result tie, so we always return
/// null.
/// Other ops are enumerated on a case-by-case basis for now.
/// TODO: we should really have a TiedOpInterface for this.
static OpResult getTiedOpResult(OpOperand &opOperand) {
Operation *op = opOperand.getOwner();
if (auto forOp = dyn_cast<scf::ForOp>(op))
return forOp->getResult(opOperand.getOperandNumber() -
forOp.getNumControlOperands());
if (auto linalgOp = dyn_cast<LinalgOp>(op)) {
if (opOperand.getOperandNumber() < linalgOp.getNumInputs())
return OpResult();
return linalgOp->getResult(opOperand.getOperandNumber() -
linalgOp.getNumInputs());
}
if (isa<SubTensorOp, SubTensorInsertOp, tensor::CastOp,
vector::TransferReadOp, vector::TransferWriteOp>(op))
return op->getResult(0);
if (op->hasTrait<mlir::OpTrait::IsTerminator>()) return OpResult();
if (isa<CallOpInterface, vector::PrintOp, vector::ContractionOp>(op))
return OpResult();
op->dump();
llvm_unreachable("Unsupported op");
}
namespace detail {
static void setInPlaceFuncOrCallArgument(
Operation *op, unsigned idx, InPlaceSpec inPlace = InPlaceSpec::True) {
auto funcOp = dyn_cast<FuncOp>(op);
auto callOp = dyn_cast<CallOpInterface>(op);
assert((funcOp || callOp) && "must be func or call");
unsigned numArgs =
funcOp ? funcOp.getNumArguments() : callOp->getNumOperands();
auto attr = op->getAttr(kInPlaceArgsAttrName).dyn_cast_or_null<ArrayAttr>();
SmallVector<StringRef> inPlaceVector =
attr ? SmallVector<StringRef>(
llvm::to_vector<4>(attr.getAsValueRange<StringAttr>()))
: SmallVector<StringRef>(numArgs, stringify(InPlaceSpec::None));
LLVM_DEBUG(DBGS() << "Set inPlace=" << stringify(inPlace) << ": " << *op
<< " @idx=" << idx << "\n");
inPlaceVector[idx] = stringify(inPlace);
op->setAttr(kInPlaceArgsAttrName,
OpBuilder(op).getStrArrayAttr(inPlaceVector));
}
} // namespace detail
static void setInPlaceFuncArgument(BlockArgument arg,
InPlaceSpec inPlace = InPlaceSpec::True) {
::detail::setInPlaceFuncOrCallArgument(arg.getOwner()->getParentOp(),
arg.getArgNumber(), inPlace);
}
static void setInPlaceCallArgument(OpOperand &operand,
InPlaceSpec inPlace = InPlaceSpec::True) {
::detail::setInPlaceFuncOrCallArgument(operand.getOwner(),
operand.getOperandNumber(), inPlace);
}
static void setInPlaceOpResult(OpResult opResult,
InPlaceSpec inPlace = InPlaceSpec::True) {
if (!opResult) return;
Operation *op = opResult.getOwner();
auto attr =
op->getAttr(kInPlaceResultsAttrName).dyn_cast_or_null<ArrayAttr>();
SmallVector<StringRef> inPlaceVector =
attr ? SmallVector<StringRef>(
llvm::to_vector<4>(attr.getAsValueRange<StringAttr>()))
: SmallVector<StringRef>(op->getNumResults(),
stringify(InPlaceSpec::None));
LLVM_DEBUG(DBGS() << "Set inPlace=" << stringify(inPlace) << ": " << *op
<< " @idx=" << opResult.getResultNumber() << "\n");
inPlaceVector[opResult.getResultNumber()] = stringify(inPlace);
op->setAttr(kInPlaceResultsAttrName,
OpBuilder(op).getStrArrayAttr(inPlaceVector));
}
/// Get the entry of the `kInPlaceResultsAttrName` string array attribute that
/// corresponds to `opResult` (indexed by its result number). If the attribute
/// does not exist yet, return InPlaceSpec::None.
static InPlaceSpec getInPlace(OpResult opResult) {
if (!opResult) return InPlaceSpec::None;
Operation *op = opResult.getOwner();
auto attr =
op->getAttr(kInPlaceResultsAttrName).dyn_cast_or_null<ArrayAttr>();
if (!attr) return InPlaceSpec::None;
// Must return a proper value.
return *symbolize(*(attr.getAsValueRange<StringAttr>().begin() +
opResult.getResultNumber()));
}
namespace detail {
static InPlaceSpec getInPlaceFuncOrCallArgName(Operation *op, unsigned idx) {
auto funcOp = dyn_cast<FuncOp>(op);
auto callOp = dyn_cast<CallOpInterface>(op);
assert((funcOp || callOp) && "must be func or call");
auto attr = op->getAttr(kInPlaceArgsAttrName).dyn_cast_or_null<ArrayAttr>();
if (!attr) return InPlaceSpec::None;
// Must return a proper value.
return *symbolize(*(attr.getAsValueRange<StringAttr>().begin() + idx));
}
} // namespace detail
/// Get inPlace information depending on the owner of `bbArg`:
/// 1. if not a FuncOp, get the information from `kInPlaceResultsAttrName`
/// for the tied op result.
/// 2. otherwise, get the information from `kInPlaceArgsAttrName`
static InPlaceSpec getInPlace(BlockArgument bbArg) {
if (!isa<FuncOp>(bbArg.getOwner()->getParentOp()))
return getInPlace(getTiedOpResult(bbArg));
return ::detail::getInPlaceFuncOrCallArgName(bbArg.getOwner()->getParentOp(),
bbArg.getArgNumber());
}
//===----------------------------------------------------------------------===//
// Bufferization-specific BlockAndValueMapping support with debugging.
//===----------------------------------------------------------------------===//
/// Wrapper for better debugging.
static void map(BlockAndValueMapping &bvm, ValueRange key, ValueRange value) {
if (key.empty()) return;
LLVM_DEBUG(DBGS() << "Map: " << key.front() << " to " << value.front()
<< "\n");
return bvm.map(key, value);
}
/// Wrapper for better debugging.
static void map(BlockAndValueMapping &bvm, Value key, Value value) {
LLVM_DEBUG(DBGS() << "Map: " << key << " to " << value << "\n");
return bvm.map(key, value);
}
/// Wrapper for better debugging.
static Value lookup(BlockAndValueMapping &bvm, Value key) {
// TODO: if key comes from bbArg, forward.
assert(key.getType().isa<RankedTensorType>());
if (!bvm.lookupOrNull(key)) {
if (auto bbArg = key.dyn_cast<BlockArgument>()) {
if (isa<FuncOp>(key.getParentBlock()->getParentOp()))
key.getParentBlock()->getParentOp()->dump();
else
key.getParentBlock()->getParentOp()->getParentOfType<FuncOp>()->dump();
bbArg.getOwner()->getParentOp()->dump();
} else {
key.getDefiningOp()->getParentOfType<FuncOp>()->dump();
}
llvm::errs() << "NO VALUE FOR KEY: " << key << "\n";
abort();
}
return bvm.lookup(key);
}
//===----------------------------------------------------------------------===//
// Bufferization-specific support.
//===----------------------------------------------------------------------===//
/// For now, assume any use is a read. Write-only uses are a non-problem and
/// will be represented with shapes in the future.
/// If any use of the tensor does not properly dominate `opOperand.getOwner()`,
/// then the tensor cannot be bufferized inPlace.
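/// Illustrative sketch (indicative syntax, `some_op_reading` is a placeholder):
/// ```
///    %r = linalg.matmul ins(...) outs(%C : tensor<?x?xf32>) -> tensor<?x?xf32>
///    %u = "some_op_reading"(%C) : (tensor<?x?xf32>) -> tensor<?x?xf32>
/// ```
/// The use of %C by `some_op_reading` does not properly dominate the matmul, so
/// bufferizing the matmul output inPlace would clobber a value that is still
/// read: the matmul operand has an interfering tensor read.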
bool hasInterferingTensorRead(OpOperand &opOperand,
const DominanceInfo &domInfo) {
if (!opOperand.get().getType().isa<RankedTensorType>()) return false;
for (auto &use : opOperand.get().getUses()) {
Operation *user = use.getOwner();
if (domInfo.properlyDominates(user, opOperand.getOwner())) continue;
if (user == opOperand.getOwner() &&
use.getOperandNumber() == opOperand.getOperandNumber())
continue;
LLVM_DEBUG(DBGS() << "found interfering read operand #"
<< opOperand.getOperandNumber()
<< " in op: " << *opOperand.getOwner() << "\n");
return true;
}
LLVM_DEBUG(DBGS() << "no interfering read\n");
return false;
}
//===----------------------------------------------------------------------===//
// Bufferization-specific MemRefType support.
//===----------------------------------------------------------------------===//
/// Return the contiguous MemRefType (i.e. with canonical/empty layout map) to
/// which `type` can be bufferized, assuming `type` is a RankedTensorType.
static MemRefType getContiguousMemRefType(Type type,
ArrayRef<AffineMap> layout = {},
unsigned addressSpace = 0) {
RankedTensorType tensorType = type.cast<RankedTensorType>();
return MemRefType::get(tensorType.getShape(), tensorType.getElementType(),
layout, addressSpace);
}
/// Return a MemRefType to which the `tensorType` can be bufferized in a
/// composable fashion. The layout must be the most dynamic possible and
/// canonicalize away once bufferization is finished.
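/// For example, assuming this is the layout produced by
/// `makeStridedLinearLayoutMap`, a `tensor<4x?xf32>` would map to something
/// like:
///   memref<4x?xf32,
///          affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>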
static MemRefType getDynamicMemRefType(RankedTensorType tensorType,
unsigned addressSpace) {
// TODO: address space decisions to connect with the actual alloc.
int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
SmallVector<int64_t> dynamicStrides(tensorType.getRank(),
ShapedType::kDynamicStrideOrOffset);
AffineMap stridedLayout = makeStridedLinearLayoutMap(
dynamicStrides, dynamicOffset, tensorType.getContext());
return MemRefType::get(tensorType.getShape(), tensorType.getElementType(),
stridedLayout, addressSpace);
}
// Transfer all `dim` ops on `tensor` to `memref`.
static void transferDimOpsToMemref(Value tensor, Value memref) {
for (OpOperand &opOperand : llvm::make_early_inc_range(tensor.getUses())) {
if (isa<DimOp>(opOperand.getOwner())) {
opOperand.set(memref);
}
}
}
//===----------------------------------------------------------------------===//
// Bufferization-specific inPlace pattern matching support.
//===----------------------------------------------------------------------===//
/// First assign `op` if `slice.back()` isa `T`, then check condition.
/// If anything fails just return failure. Otherwise update `sliceRef` by
/// dropping `sliceRef.back()`, then return success().
template <typename T>
static LogicalResult matchAndDropBack(
ArrayRef<Operation *> &sliceRef, T &op,
llvm::function_ref<LogicalResult(T)> condition = nullptr) {
if (sliceRef.empty()) return failure();
op = dyn_cast<T>(sliceRef.back());
if (!op || (condition && failed(condition(op)))) return failure();
sliceRef = sliceRef.drop_back();
return success();
}
/// First assign `op1`/`op2` if `slice.front()`/`slice.back()` isa `T1`/`T2`,
/// respectively. Then check condition. If anything fails just return failure.
/// Otherwise update `sliceRef` by dropping `sliceRef.front()` and
/// `sliceRef.back()`, then return success().
template <typename T1, typename T2>
static LogicalResult matchAndDropEnclosingPair(
ArrayRef<Operation *> &sliceRef, T1 &op1, T2 &op2,
llvm::function_ref<LogicalResult(T1, T2)> condition = nullptr) {
if (sliceRef.size() < 2) return failure();
op1 = dyn_cast<T1>(sliceRef.front());
op2 = dyn_cast<T2>(sliceRef.back());
if (!op1 || !op2 || (condition && failed(condition(op1, op2))))
return failure();
sliceRef = sliceRef.drop_front().drop_back();
return success();
}
//===----------------------------------------------------------------------===//
// Bufferization-specific scoped alloc/dealloc insertion support.
//===----------------------------------------------------------------------===//
/// Create an AllocOp/DeallocOp pair, where the AllocOp is after
/// `shapedValue.getDefiningOp` (or at the top of the block in case of a bbArg)
/// and the DeallocOp is at the end of the block.
/// Since this may insert **after** the op defining `shapedValue`, there is
/// a risk of abstraction gap with what the caller may legitimately expect.
/// As a consequence, this function should not be called with `b` rooted around
/// `shapedValue.getDefiningOp()`, as the insertion point may shift.
// TODO: need a better API to make things less surprising while avoiding
// implicit state passed across function boundaries: this still significantly
// beats mutating the insertion point for `b`.
// TODO: need to hoist this across function boundaries. Maybe by using
// init_tensor + subtensor_insert before bufferization.
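// Illustrative placement sketch (indicative syntax) for a shaped value %v of
// type tensor<?xf32> defined inside a block:
//    %v = ... : tensor<?xf32>
//    %d = dim %v, %c0 : tensor<?xf32>       // dynamic size, computed here
//    %alloc = alloc(%d) : memref<?xf32>     // inserted right after %v
//    ...
//    dealloc %alloc : memref<?xf32>         // inserted before the terminator
//    <block terminator>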
static Value createNewAllocDeallocPairForShapedValue(
OpBuilder &b, Location loc, Value shapedValue,
SmallVector<Value, 4> dynOperands = {}) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
MemRefType memRefType = shapedValue.getType().dyn_cast<MemRefType>();
assert(memRefType || shapedValue.getType().dyn_cast<RankedTensorType>());
// TODO: non-zero address space.
// TODO: layout information if relevant.
if (!memRefType) memRefType = getContiguousMemRefType(shapedValue.getType());
if (auto bbArg = shapedValue.dyn_cast<BlockArgument>()) {
b.setInsertionPointToStart(bbArg.getOwner());
loc = bbArg.getOwner()->getParentOp()->getLoc();
} else {
b.setInsertionPointAfter(shapedValue.getDefiningOp());
loc = shapedValue.getDefiningOp()->getLoc();
}
// If the dynOperands are not passed explicitly, compute them.
// This circumvents currently missing dim(init_tensor) canonicalizations.
if (dynOperands.empty()) {
for (auto dim : llvm::enumerate(memRefType.getShape()))
if (dim.value() == ShapedType::kDynamicSize)
dynOperands.push_back(b.create<DimOp>(loc, shapedValue, dim.index()));
}
Value allocated = b.create<AllocOp>(loc, memRefType, dynOperands);
b.setInsertionPoint(allocated.getParentBlock()->getTerminator());
b.create<DeallocOp>(loc, allocated);
return allocated;
}
//===----------------------------------------------------------------------===//
// Bufferization-specific inPlace analysis support.
//===----------------------------------------------------------------------===//
/// Return success if all offsets, sizes and strides are equal.
static LogicalResult sameOffsetsSizesAndStrides(
OffsetSizeAndStrideOpInterface op1, OffsetSizeAndStrideOpInterface op2) {
if (op1.static_offsets().size() != op2.static_offsets().size())
return failure();
if (op1.static_sizes().size() != op2.static_sizes().size()) return failure();
if (op1.static_strides().size() != op2.static_strides().size())
return failure();
for (auto it : llvm::zip(op1.getMixedOffsets(), op2.getMixedOffsets()))
if (!isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it)))
return failure();
for (auto it : llvm::zip(op1.getMixedSizes(), op2.getMixedSizes()))
if (!isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it)))
return failure();
for (auto it : llvm::zip(op1.getMixedStrides(), op2.getMixedStrides()))
if (!isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it)))
return failure();
return success();
}
static LogicalResult matchingVectorTransfersAtSource(
vector::TransferReadOp read, vector::TransferWriteOp write,
Value subtensor) {
// Either we have a pair of matching transfer read/write or none.
if (read && !write) {
LLVM_DEBUG(DBGS() << "Slice has transferReadOp but no transferWriteOp"
"\nDestructive update chain FAIL\n");
return failure();
}
if (!read && write) {
LLVM_DEBUG(DBGS() << "Slice has transferWriteOp but no transferReadOp"
"\nDestructive update chain FAIL\n");
return failure();
}
if (read && write) {
// If we have a pair of matching read/write, the tensor and vector shape
// must exactly match (i.e. this is a vectorization).
if (read.source() != subtensor) {
LLVM_DEBUG(DBGS() << "transferReadOp.source() != subTensor.result()"
"\nDestructive update chain FAIL\n");
return failure();
}
if (write.source() != subtensor) {
LLVM_DEBUG(DBGS() << "transferWriteOp.source() != subTensor.result()"
"\nDestructive update chain FAIL\n");
return failure();
}
if (read.getShapedType().getShape() != read.getVectorType().getShape()) {
LLVM_DEBUG(DBGS() << "transferReadOp source and result shapes mismatch"
"\nDestructive update chain FAIL\n");
return failure();
}
if (write.getShapedType().getShape() != write.getVectorType().getShape()) {
LLVM_DEBUG(DBGS() << "transferWriteOp source and result shapes mismatch"
"\nDestructive update chain FAIL\n");
return failure();
}
}
return success();
}
/// Detect whether `v` has a single user that is exactly `terminatorOp`.
/// If `bbArg` comes from an scf::ForOp, additionally check that the operand
/// index within the terminator matches the scf.for result tied to `bbArg`.
template <typename TerminatorOp>
static LogicalResult isInPlaceSingleUseTerminatorValue(
Value v, TerminatorOp terminatorOp, BlockArgument bbArg) {
if (!v.hasOneUse() || *v.getUsers().begin() != terminatorOp) return failure();
if (isa<scf::ForOp>(bbArg.getOwner()->getParentOp()))
return (getTiedOpResult(bbArg).getResultNumber() ==
v.getUses().begin()->getOperandNumber())
? success()
: failure();
if (isa<FuncOp>(bbArg.getOwner()->getParentOp())) return success();
llvm_unreachable("isInPlaceSingleUseOperand: unsupported op");
}
/// Detect the simple overwrite pattern:
/// ```
/// candidate -> vector.transfer_write(**) -> subtensor_insert(**) -> term
/// ```
///
/// (**) represents an optional op in the chain; at least one must be present.
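/// Illustrative sketch (indicative syntax) of a match inside an scf.for, where
/// the bbArg %t is fully overwritten before being yielded:
/// ```
///    scf.for ... iter_args(%t = %init) -> (tensor<8xf32>) {
///      %w = vector.transfer_write %vec, %t[%c0] : vector<8xf32>, tensor<8xf32>
///      scf.yield %w : tensor<8xf32>
///    }
/// ```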
template <typename ContainerOp, typename TerminatorOp>
static LogicalResult detectOverWritePattern(
Operation *parentOp, BlockArgument candidate,
ArrayRef<Operation *> &sliceRef, SmallVector<OpResult> &inPlaceOpResults) {
if (!parentOp || !isa<ContainerOp>(parentOp)) return failure();
ArrayRef<Operation *> tmpSliceRef = sliceRef;
if (!candidate.hasOneUse()) {
LLVM_DEBUG(
DBGS()
<< "FAILURE: partial overwrite pattern -> bbArg needs exactly 1 use\n");
return failure();
}
TerminatorOp terminatorOp;
// Match terminator and update tmpSliceRef.
if (failed(matchAndDropBack(tmpSliceRef, terminatorOp))) {
LLVM_DEBUG(DBGS() << "FAILURE: partial overwrite pattern -> must end with "
"known terminator\n");
return failure();
}
SubTensorInsertOp subTensorInsertOp;
vector::TransferWriteOp vectorTransferWriteOp;
// Maybe match subTensorInsertOp and update tmpSliceRef.
(void)matchAndDropBack(tmpSliceRef, subTensorInsertOp);
// Maybe match vectorTransferWriteOp and update tmpSliceRef.
(void)matchAndDropBack(tmpSliceRef, vectorTransferWriteOp);
// subtensor_insert must be used exactly by the terminator at index matching
// the candidate BlockArgument.
if (subTensorInsertOp) {
if (failed(isInPlaceSingleUseTerminatorValue(subTensorInsertOp.result(),
terminatorOp, candidate))) {
LLVM_DEBUG(
DBGS() << "FAILURE: partial overwrite pattern -> subtensor_insert "
"single use must match terminator\n");
return failure();
}
} else if (vectorTransferWriteOp) {
// transfer_write must be used exactly by the terminator at index matching
// the candidate BlockArgument.
if (failed(isInPlaceSingleUseTerminatorValue(vectorTransferWriteOp.result(),
terminatorOp, candidate))) {
LLVM_DEBUG(
DBGS() << "FAILURE: partial overwrite pattern -> "
"vector.transfer_write single use must match terminator\n");
return failure();
}
} else {
LLVM_DEBUG(DBGS() << "FAILURE: partial overwrite pattern -> need at least "
"a subtensor_insert or a vector.transfer_write\n");
return failure();
}
// Commit what has been detected.
if (vectorTransferWriteOp)
inPlaceOpResults.push_back(vectorTransferWriteOp->getResult(0));
if (subTensorInsertOp)
inPlaceOpResults.push_back(subTensorInsertOp->getResult(0));
// No action for the terminator.
tmpSliceRef = sliceRef;
LLVM_DEBUG(DBGS() << "SUCCESS: partial overwrite pattern\n");
return success();
}
template <typename ContainerOp, typename TerminatorOp>
static LogicalResult detectLinalgReturn(
Operation *parentOp, BlockArgument candidate,
ArrayRef<Operation *> &sliceRef, SmallVector<OpResult> &inPlaceOpResults) {
if (!parentOp || !isa<ContainerOp>(parentOp)) return failure();
ArrayRef<Operation *> tmpSliceRef = sliceRef;
TerminatorOp terminatorOp;
// Match returnOp and update tmpSliceRef.
if (failed(matchAndDropBack(tmpSliceRef, terminatorOp))) {
LLVM_DEBUG(DBGS() << "FAILURE: linalg return pattern -> slice must end "
"with a known terminator\n");
return failure();
}
// bbArg must have a single use.
if (!candidate.hasOneUse()) {
LLVM_DEBUG(
DBGS() << "FAILURE: linalg return pattern -> bbArg with != 1 use\n");
return failure();
}
LinalgOp linalgOp;
// Match linalgOp with a single output tensor for now and update tmpSliceRef.
if (succeeded(matchAndDropBack(tmpSliceRef, linalgOp))) {
if (linalgOp.getNumOutputTensors() != 1 ||
// For now, just check that the operand and corresponding result have
// no additional uses. In the future we can build a cost-model to take
// care of diamond dependences.
!linalgOp.getOutputTensors().front().hasOneUse() ||
!linalgOp->getResult(0).hasOneUse()) {
LLVM_DEBUG(DBGS() << "FAILURE: linalg return pattern -> slice must end "
"with linalg op\n");
// BREAK DUMP DEBUG HERE
return failure();
}
}
scf::ForOp forOp;
// Match forOp with a single output tensor for now and update tmpSliceRef.
// TODO: support more than single result.
if (succeeded(matchAndDropBack(tmpSliceRef, forOp))) {
if (forOp->getNumResults() != 1 ||
// For now, just check that the operand and corresponding result have
// no additional uses. In the future we can build a cost-model to take
// care of diamond dependences.
!forOp.getIterOperands().front().hasOneUse() ||
!forOp->getResult(0).hasOneUse()) {
LLVM_DEBUG(DBGS() << "FAILURE: linalg return pattern -> slice must end "
"with forOp op\n");
return failure();
}
}
if (!linalgOp && !forOp) {
LLVM_DEBUG(DBGS() << "FAILURE: linalg return pattern -> ASFDASFA\n");
return failure();
}
// Commit what has been detected.
// TODO: support more than single result.
if (linalgOp) inPlaceOpResults.push_back(linalgOp->getResult(0));
if (forOp) inPlaceOpResults.push_back(forOp->getResult(0));
tmpSliceRef = sliceRef;
LLVM_DEBUG(DBGS() << "SUCCESS: linalg return pattern\n");
return success();
}
/// In the case of an scf::ForOp, we look for:
/// `candidate -> subtensor -> vector.transfer_read(*) -> ...
/// vector.transfer_write(*) -> subtensor_insert -> yield`.
/// `sliceRef` is automatically updated to match `...`.
///
/// (*) represents an optional op in the chain; if one of the vector.transfer
/// ops is present, the matching one must be present too.
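/// Illustrative sketch (indicative syntax) of a match:
/// ```
///    scf.for %i = ... iter_args(%t = %init) -> (tensor<?xf32>) {
///      %st = subtensor %t[%i] [%sz] [1] : tensor<?xf32> to tensor<?xf32>
///      // ... ops producing %new from %st ...
///      %r = subtensor_insert %new into %t[%i] [%sz] [1]
///          : tensor<?xf32> into tensor<?xf32>
///      scf.yield %r : tensor<?xf32>
///    }
/// ```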
template <typename ContainerOp, typename TerminatorOp>
static LogicalResult detectDestructiveUpdatePattern(
Operation *parentOp, BlockArgument candidate,
ArrayRef<Operation *> &sliceRef, SmallVector<OpResult> &inPlaceOpResults) {
if (!parentOp || !isa<ContainerOp>(parentOp)) return failure();
ArrayRef<Operation *> tmpSliceRef = sliceRef;
// bbArg must be used exactly by one subtensor / subtensor_insert pair.
if (candidate.use_empty() || candidate.hasOneUse() ||
std::next(candidate.getUsers().begin(), 2) !=
candidate.getUsers().end()) {
LLVM_DEBUG(
DBGS() << "FAILURE: destructive updates -> bbArg with != 2 uses\n");
return failure();
}
if (tmpSliceRef.size() < 3) {
LLVM_DEBUG(
DBGS() << "FAILURE: destructive updates -> slice must have >= 3 ops\n");
return failure();
}
// Match yieldOp and update tmpSliceRef.
TerminatorOp terminatorOp;
if (failed(matchAndDropBack(tmpSliceRef, terminatorOp))) {
LLVM_DEBUG(
DBGS() << "FAILURE: destructive updates -> slice unknown terminator\n");
return failure();
}
// Match subtensor pair and update tmpSliceRef.
// subtensor / subtensor_insert must match.
SubTensorOp subTensorOp;
SubTensorInsertOp subTensorInsertOp;
auto matchSubTensors = [](SubTensorOp st, SubTensorInsertOp sti) {
auto res = sameOffsetsSizesAndStrides(st, sti);
if (failed(res))
LLVM_DEBUG(
DBGS()
<< "FAILURE: destructive updates -> subtensor ops don't match: " << st
<< " and " << sti);
return res;
};
if (failed(matchAndDropEnclosingPair<SubTensorOp, SubTensorInsertOp>(
tmpSliceRef, subTensorOp, subTensorInsertOp, matchSubTensors)))
return failure();
// subtensor_insert must be used exactly by the terminator at index matching
// the candidate BlockArgument.
if (failed(isInPlaceSingleUseTerminatorValue(subTensorInsertOp.result(),
terminatorOp, candidate))) {
LLVM_DEBUG(DBGS() << "FAILURE: destructive updates -> SubTensorInsertOp "
"does not have a single terminator use "
"at the right index\n");
return failure();
}
// Maybe match vector transfer pair and update tmpSliceRef.
// If we find one, the other must be present and match too.
vector::TransferReadOp vectorTransferReadOp;
vector::TransferWriteOp vectorTransferWriteOp;
auto matchTransfers = [&](vector::TransferReadOp read,
vector::TransferWriteOp write) {
return matchingVectorTransfersAtSource(read, write, subTensorOp.result());
};
if (failed(matchAndDropEnclosingPair<vector::TransferReadOp,
vector::TransferWriteOp>(
tmpSliceRef, vectorTransferReadOp, vectorTransferWriteOp,
matchTransfers)) &&
(vectorTransferReadOp || vectorTransferWriteOp))
return failure();
// Commit what has been detected.
inPlaceOpResults.push_back(subTensorOp->getResult(0));
if (vectorTransferReadOp)
inPlaceOpResults.push_back(vectorTransferReadOp->getResult(0));
if (vectorTransferWriteOp)
inPlaceOpResults.push_back(vectorTransferWriteOp->getResult(0));
inPlaceOpResults.push_back(subTensorInsertOp->getResult(0));
// No action for the terminator.
tmpSliceRef = sliceRef;
LLVM_DEBUG(DBGS() << "SUCCESS: destructive updates pattern\n");
return success();
}
namespace detail {
// TODO: generalize and refactor.
// TODO: do we need more safeguards for setting ops inPlace ?
// The following uses internal knowledge of the position of tied operand /
// results. A proper TieOperandInterface would be much better.
static void propagateInPlace(const SmallVector<OpOperand *> &initialWorklist,
const DominanceInfo &domInfo) {
LLVM_DEBUG(DBGS() << "Start propagateInPlace from initial WL\n");
for (OpOperand *operand : initialWorklist)
LLVM_DEBUG(DBGS() << "WL item: " << operand->get() << " used by "
<< *operand->getOwner() << "\n");
SmallVector<OpOperand *> worklist(initialWorklist);
for (unsigned idx = 0; idx < worklist.size(); ++idx) {
OpOperand &operand = *worklist[idx];
LLVM_DEBUG(DBGS() << "WL item: " << *operand.getOwner() << "\n");
// If the owner turns out to be a CallOp without `kInPlaceArgsAttrName`
// this will be a noop.
if (operand.get().getType().isa<RankedTensorType>() &&
!hasInterferingTensorRead(operand, domInfo)) {
LLVM_DEBUG(DBGS() << "no interfering read\n");
setInPlaceOpResult(getTiedOpResult(operand));
}
LLVM_DEBUG(DBGS() << "propagatedInPlace: " << *operand.getOwner() << "\n");
// The use may have interfering reads that prevent it from being written
// inPlace, but the values it produces are still themselves candidates for
// inPlace at their point of use.
for (Value v : operand.getOwner()->getResults()) {
LLVM_DEBUG(DBGS() << "propagate result: " << v << "\n");
for (auto &use : v.getUses()) {
LLVM_DEBUG(DBGS() << "add use to WL: " << use.get() << "\n");
worklist.push_back(&use);
}
}
}
}
} // namespace detail
static void propagateInPlace(OpOperand &opOperand,
const DominanceInfo &domInfo) {
SmallVector<OpOperand *> worklist{&opOperand};
::detail::propagateInPlace(worklist, domInfo);
}
static void propagateInPlace(BlockArgument &bbArg,
const DominanceInfo &domInfo) {
SmallVector<OpOperand *> worklist;
for (auto &use : bbArg.getUses()) worklist.push_back(&use);
::detail::propagateInPlace(worklist, domInfo);
}
/// Iterate over bbArgs of `parentOp` and determine if they are the root of a
/// known destructive update chain. Such a destructive update is related to
/// traditional loop nest + memory analysis but provides a simpler
/// abstraction. In traditional memory-based dependence analysis, one would
/// need to analyze all possible interleavings of possibly aliasing loads and
/// stores in the context of the k common surrounding loops. With scf.for +
/// subtensor + subtensor_insert + yield, more ordering semantics are
/// available, as well as dealiasing, thanks to SSA use-def chains.
static void destructiveUpdateAnalysis(Block *block,
const DominanceInfo &domInfo) {
Operation *parentOp = block->getParentOp();
// In this loop, we do not check whether `candidate` can itself be bufferized
// inPlace: this is not a consideration for the inside of `block`.
for (BlockArgument candidate : block->getArguments()) {
LLVM_DEBUG(llvm::dbgs() << "\n\n");
LLVM_DEBUG(DBGS() << "Destructive update analysis on candidate: "
<< candidate << "\nof:\n"
<< *parentOp << "\n");
if (!candidate.getType().isa<ShapedType>()) {
LLVM_DEBUG(DBGS() << "Not a tensor\n");
continue;
}
llvm::SetVector<Operation *> slice;
getForwardSlice(candidate, &slice, [&](Operation *op) {
// Skip any extra nesting between parentOp and op.
return op == parentOp || op->getBlock()->getParentOp() == parentOp;
});
LLVM_DEBUG(DBGS() << "Slice:\n");
for (auto *op : slice) LLVM_DEBUG(DBGS() << *op << "\n");
SmallVector<OpResult> inPlaceOpResults;
inPlaceOpResults.reserve(slice.size());
ArrayRef<Operation *> sliceRef = slice.getArrayRef();
if (failed(detectDestructiveUpdatePattern<scf::ForOp, scf::YieldOp>(
parentOp, candidate, sliceRef, inPlaceOpResults)) &&
failed(detectOverWritePattern<scf::ForOp, scf::YieldOp>(
parentOp, candidate, sliceRef, inPlaceOpResults)) &&
failed(detectLinalgReturn<scf::ForOp, scf::YieldOp>(
parentOp, candidate, sliceRef, inPlaceOpResults)) &&
failed(detectDestructiveUpdatePattern<FuncOp, ReturnOp>(
parentOp, candidate, sliceRef, inPlaceOpResults)) &&
failed(detectOverWritePattern<FuncOp, ReturnOp>(
parentOp, candidate, sliceRef, inPlaceOpResults)) &&
failed(detectLinalgReturn<FuncOp, ReturnOp>(
parentOp, candidate, sliceRef, inPlaceOpResults))) {
LLVM_DEBUG(DBGS() << "Failed to detect a destructive update pattern\n");
continue;
}
// Mark ops inPlace eagerly.
for (auto &res : inPlaceOpResults) setInPlaceOpResult(res);
propagateInPlace(candidate, domInfo);
}
}
void LinalgComprehensiveBufferizePass::inPlaceAnalysisFuncOpInternals(
FuncOp funcOp, const DominanceInfo &domInfo) {
if (!funcOp || funcOp->getNumRegions() == 0 || funcOp.body().empty()) return;
// Start propagating from InitTensorOps.
funcOp.walk<WalkOrder::PreOrder>([&](InitTensorOp initTensorOp) {
for (auto &use : initTensorOp->getUses()) propagateInPlace(use, domInfo);
});
// Start propagating from FuncOp bbArgs.
destructiveUpdateAnalysis(&funcOp.body().front(), domInfo);
// Start propagating from scf::ForOps.
funcOp.walk<WalkOrder::PreOrder>([&](scf::ForOp forOp) {
destructiveUpdateAnalysis(&forOp.region().front(), domInfo);
});
}
/// Analyze a `callOp` to a FuncOp and determine whether any of its tensor
/// operands could be safely written inPlace after it is converted to buffer
/// form by a bufferization process. Iterate on the uses of callOp's operands
/// to determine whether all such uses dominate callOp. If any use of an
/// operand does not dominate `callOp`, this means that the operand tensor
/// value may be needed somewhere else and it is illegal to update in-place
/// after bufferization. Add a `kInPlaceResultsAttrName` string attribute to
/// `callOp` to carry the result of this analysis until bufferization is
/// completed. The "meet" of all `kInPlaceResultsAttrName` for all `callOp` to a
/// given FuncOp determines the `kInPlaceResultsAttrName` for that FuncOp.
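/// Illustrative sketch (indicative syntax, `some_op_reading` is a placeholder):
/// in
/// ```
///    %r = call @f(%A) : (tensor<?xf32>) -> tensor<?xf32>
///    "some_op_reading"(%A) : (tensor<?xf32>) -> ()
/// ```
/// the later read of %A does not dominate the call, so the operand (and the
/// corresponding FuncOp argument) is marked "false"; without such a later use
/// it would be marked "true".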
static void funcArgumentsInPlaceAnalysis(CallOpInterface callOp,
const DominanceInfo &domInfo) {
FuncOp funcOp = getCalledFunction(callOp);
if (!funcOp || funcOp.body().empty()) return;
if (llvm::none_of(callOp->getOperandTypes(),
[](Type t) { return t.isa<TensorType>(); }))
return;
LLVM_DEBUG(DBGS() << "Begin funcArgumentsInPlaceAnalysis within:\n"
<< *callOp->getParentOfType<FuncOp>()
<< "callOp: " << *callOp << "\n";);
for (OpOperand &opOperand : callOp->getOpOperands()) {
Value tensor = opOperand.get();
if (!tensor.getType().isa<TensorType>()) continue;
unsigned idx = opOperand.getOperandNumber();
LLVM_DEBUG(DBGS() << "tensor @idx=" << idx << ": " << tensor << "\n");
// FuncOp inPlace is the meet of all the calls. If we already know it
// cannot be bufferized inPlace, just skip. Can't easily connect arguments
// to results in FuncOp: use explicit idx.
InPlaceSpec funcInPlace = getInPlace(funcOp.getArgument(idx));
if (funcInPlace == InPlaceSpec::False) continue;
InPlaceSpec callInPlace = hasInterferingTensorRead(opOperand, domInfo)
? InPlaceSpec::False
: InPlaceSpec::True;
setInPlaceCallArgument(opOperand, callInPlace);
setInPlaceFuncArgument(funcOp.getArgument(idx), callInPlace);
}
LLVM_DEBUG(DBGS() << "End funcArgumentsInPlaceAnalysis within:\n"
<< *callOp->getParentOfType<FuncOp>()
<< "callOp: " << *callOp << "\n";);
}
//===----------------------------------------------------------------------===//
// Bufferization as simple BlockAndValueMapping rewrites / without
// conversions.
//===----------------------------------------------------------------------===//
/// Non-conversion equivalent of the core MLIR Linalg bufferization patterns.
/// This works on mixed tensor + buffer Linalg ops: some results may have been
/// already bufferized by a previous destructive update bufferization.
/// Allocate the output buffers for the remaining tensor output operands of
/// the Linalg op. If the tensor is an "init" tensor (i.e. its value is
/// actually used in the payload region), we additionally copy the original
/// value into the newly allocated buffer.
static LogicalResult allocateBuffersForResults(
OpBuilder &b, Location loc, LinalgOp op,
SmallVectorImpl<Value> &resultBuffers, BlockAndValueMapping &bvm) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
// Lazily compute loopRanges.
SmallVector<Range, 4> loopRanges;
// Linalg invariant: output tensors and result match 1-1.
assert(op.getNumOutputTensors() == op->getNumResults());
for (auto &opOperand : op.getOutputOpOperands()) {
Value output = opOperand.get();
if (output.getType().isa<MemRefType>()) {
resultBuffers.push_back(output);
continue;
}
// If output tensor is marked inPlace, just use the buffer.
// The following uses internal knowledge of the position of tied operand /
// results. A proper TieOperandInterface would be much better.
if (getInPlace(getTiedOpResult(opOperand)) == InPlaceSpec::True) {
resultBuffers.push_back(lookup(bvm, output));
continue;
}
Value dimTensor = bvm.lookupOrDefault(output);
Value alloc = createNewAllocDeallocPairForShapedValue(b, loc, dimTensor);
b.setInsertionPointAfter(alloc.getDefiningOp());
resultBuffers.push_back(alloc);
// Additionally, if the output tensor's original value is used in the payload
// region, copy it into the newly allocated buffer.
if (op.payloadUsesValueFromOpOperand(&opOperand))
b.create<CopyOp>(loc, lookup(bvm, output), alloc);
}
map(bvm, op->getResults(), resultBuffers);
for (auto it : llvm::zip(op->getResults(), resultBuffers)) {
transferDimOpsToMemref(std::get<0>(it), std::get<1>(it));
}
return success();
}
// Non-conversion equivalent of the core MLIR Linalg bufferization patterns.
static void finalizeBufferAllocation(OpBuilder &b, LinalgOp op,
ValueRange inputs, ValueRange outputs,
BlockAndValueMapping &bvm) {
SmallVector<Value, 8> newOperands = inputs;
newOperands.append(outputs.begin(), outputs.end());
auto otherOperands = op.getAssumedNonShapedOperands();
newOperands.append(otherOperands.begin(), otherOperands.end());
Location loc = op.getLoc();
op.clone(b, loc, /*resultTypes=*/TypeRange{}, newOperands);
// Replace the results of the old op with the new output buffers.
map(bvm, op.getOperation()->getResults(), outputs);
for (auto it : llvm::zip(op.getOperation()->getResults(), outputs)) {
transferDimOpsToMemref(std::get<0>(it), std::get<1>(it));
}
if (!op.hasTensorSemantics()) op->erase();
}
/// Generic conversion pattern that matches any LinalgOp. This avoids
/// template instantiating one pattern for each LinalgOp.
/// This works on mixed tensor + buffer Linalg ops: some results may have been
/// already bufferized by a previous destructive update bufferization.
static LogicalResult convertAnyLinalgOp(OpBuilder &b, LinalgOp op,
BlockAndValueMapping &bvm) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
if (op.hasBufferSemantics()) return failure();
LLVM_DEBUG(DBGS() << "convert: " << *op << "\n");
b.setInsertionPoint(op);
Location loc = op.getLoc();
SmallVector<Value, 2> newInputBuffers;
newInputBuffers.reserve(op.getNumInputs());
for (Value v : op.getInputs()) {
newInputBuffers.push_back(lookup(bvm, v));
}
SmallVector<Value, 2> newOutputBuffers;
if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm)))
assert(false);
// Delegate to the linalg generic pattern.
if (auto genericOp = dyn_cast<GenericOp>(op.getOperation())) {
finalizeBufferAllocation(b, genericOp, newInputBuffers, newOutputBuffers,
bvm);
return success();
}
finalizeBufferAllocation(b, op, newInputBuffers, newOutputBuffers, bvm);
return success();
}
static LogicalResult convertTransferOp(OpBuilder &b,
VectorTransferOpInterface op,
BlockAndValueMapping &bvm) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(op);
Location loc = op.getLoc();
if (op.getShapedType().isa<MemRefType>()) return failure();
LLVM_DEBUG(DBGS() << "convert: " << *op << "\n");
/// transfer_read from buffer
if (auto readOp = dyn_cast<vector::TransferReadOp>(op.getOperation())) {
readOp.sourceMutable().assign(lookup(bvm, op.source()));
return success();
}
auto inPlace = getInPlace(op->getResult(0));
auto writeOp = cast<vector::TransferWriteOp>(op.getOperation());
// If transfer_write is not inPlace, allocate a new buffer.
Value newInputBuffer;
if (inPlace != InPlaceSpec::True) {
newInputBuffer =
createNewAllocDeallocPairForShapedValue(b, loc, writeOp.result());
b.setInsertionPointAfter(newInputBuffer.getDefiningOp());
map(bvm, writeOp.result(), newInputBuffer);
transferDimOpsToMemref(writeOp.result(), newInputBuffer);
} else {
// An inPlace write will result in tensor_load(x) which must canonicalize
// away with one of its uses.
newInputBuffer = lookup(bvm, writeOp.source());
}
// Create a new transfer_write on buffer that doesn't have a return value.
// Leave the previous transfer_write to dead code as it still has uses at
// this point.
b.create<vector::TransferWriteOp>(
loc, writeOp.vector(), newInputBuffer, writeOp.indices(),
writeOp.permutation_map(),
writeOp.masked() ? *writeOp.masked() : ArrayAttr());
map(bvm, op->getResult(0), newInputBuffer);
return success();
}
/// FuncOp always creates TensorToMemRef ops.
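/// Illustrative sketch (indicative syntax): for each ranked tensor bbArg %t, a
/// transient
///    %m = tensor_to_memref %t : memref<?x?xf32, #fully_dynamic_layout>
/// is created at the top of the entry block and recorded in the mapping, so
/// that subsequent rewrites looking up %t get the buffer %m.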
static LogicalResult convertFuncOp(OpBuilder &b, FuncOp funcOp,
BlockAndValueMapping &bvm) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPointToStart(&funcOp.body().front());
for (auto bbArg : funcOp.getArguments()) {
auto rankedTensorType = bbArg.getType().dyn_cast<RankedTensorType>();
if (!rankedTensorType) continue;
MemRefType memRefType = getDynamicMemRefType(rankedTensorType);
Value tensorToMemref =
b.create<TensorToMemrefOp>(funcOp.getLoc(), memRefType, bbArg);
map(bvm, bbArg, tensorToMemref);
}
return success();
}
static LogicalResult convertScfForOp(OpBuilder &b, scf::ForOp forOp,
BlockAndValueMapping &bvm) {
LLVM_DEBUG(DBGS() << "convert: " << *forOp << "\n");
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPointToStart(forOp.getBody());
// If inPlace, just forward the buffer.
// Otherwise alloc and copy.
b.setInsertionPointAfter(forOp);
for (auto it : llvm::zip(forOp.getRegionIterArgs(), forOp->getResults())) {
BlockArgument bbArg = std::get<0>(it);
if (!bbArg.getType().isa<RankedTensorType>()) continue;
OpResult opResult = std::get<1>(it);
Value operand = forOp.getIterOperands()[opResult.getResultNumber()];
Value operandBuffer = lookup(bvm, operand);
if (getInPlace(bbArg) != InPlaceSpec::True) {
Value alloc =
createNewAllocDeallocPairForShapedValue(b, forOp.getLoc(), operand);
// If the tensor comes from `linalg::InitTensorOp`, the value is
// uninitialized and we do not need to copy.
if (!operand.getDefiningOp<linalg::InitTensorOp>())
b.create<linalg::CopyOp>(forOp.getLoc(), operandBuffer, alloc);
operandBuffer = alloc;
}
map(bvm, bbArg, operandBuffer);
map(bvm, opResult, operandBuffer);
}
return success();
}
static LogicalResult convertScfYieldOp(OpBuilder &b, scf::YieldOp yieldOp,
BlockAndValueMapping &bvm) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(yieldOp);
scf::ForOp forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp());
assert(forOp && "only support scf::ForOp parent for scf::YieldOp");
for (OpOperand &operand : yieldOp->getOpOperands()) {
auto rankedTensorType =
operand.get().getType().dyn_cast<RankedTensorType>();
if (!rankedTensorType) continue;
auto bbArg = forOp.getRegionIterArgs()[operand.getOperandNumber()];
if (getInPlace(bbArg) == InPlaceSpec::True)
operand.set(bbArg);
else
operand.set(b.create<TensorLoadOp>(yieldOp.getLoc(), lookup(bvm, bbArg)));
}
return success();
}
static LogicalResult convertReturnOp(OpBuilder &b, ReturnOp returnOp,
BlockAndValueMapping &bvm) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(returnOp);
FuncOp funcOp = cast<FuncOp>(returnOp->getParentOp());
assert(funcOp && "only support scf::ForOp parent for scf::YieldOp");
for (OpOperand &operand : returnOp->getOpOperands()) {
auto rankedTensorType =
operand.get().getType().dyn_cast<RankedTensorType>();
if (!rankedTensorType) continue;
operand.set(
b.create<TensorLoadOp>(returnOp.getLoc(), lookup(bvm, operand.get())));
}
return success();
}
/// InitTensor always allocates.
/// TODO: hoist across function boundaries prior to bufferization.
static LogicalResult convertInitTensorOp(OpBuilder &b,
InitTensorOp initTensorOp,
BlockAndValueMapping &bvm) {
LLVM_DEBUG(DBGS() << "convert: " << *initTensorOp << "\n");
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(initTensorOp);
Value alloc = createNewAllocDeallocPairForShapedValue(
b, initTensorOp->getLoc(), initTensorOp.result(), initTensorOp.sizes());
map(bvm, initTensorOp.result(), alloc);
return success();
}
// This implementation is a shortcut that assumes the tile size divides the
// problem size and is generally incorrect.
// TODO: revisit this.
static LogicalResult convertPadTensorOp(OpBuilder &b, PadTensorOp padTensorOp,
BlockAndValueMapping &bvm) {
LLVM_DEBUG(DBGS() << "convert: " << *padTensorOp << "\n");
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(padTensorOp);
auto tensorType = padTensorOp.result().getType().cast<RankedTensorType>();
auto sourceMemRef = lookup(bvm, padTensorOp.source());
auto sourceMemRefType = sourceMemRef.getType().cast<MemRefType>();
auto memRefType =
getContiguousMemRefType(tensorType, sourceMemRefType.getAffineMaps(),
sourceMemRefType.getMemorySpaceAsInt());
Value res =
b.create<MemRefCastOp>(padTensorOp.getLoc(), memRefType, sourceMemRef);
map(bvm, padTensorOp.result(), res);
return success();
}
/// SubTensorInsertOp never allocates but may copy if it is not marked
/// inPlace.
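/// Illustrative sketch (indicative syntax) of the inPlace case:
///    %r = subtensor_insert %src into %dst[%o] [%s] [1]
///        : tensor<?xf32> into tensor<?xf32>
/// becomes roughly
///    %sv = subview %dst_buffer[%o] [%s] [1] : memref<?xf32> to memref<?xf32, ...>
///    linalg.copy(%src_buffer, %sv)  // elided when the producer already wrote
///                                   // directly into the buffer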
static LogicalResult convertSubTensorInsertOp(
OpBuilder &b, SubTensorInsertOp subTensorInsertOp,
BlockAndValueMapping &bvm) {
LLVM_DEBUG(DBGS() << "convert: " << *subTensorInsertOp << "\n");
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(subTensorInsertOp);
Location loc = subTensorInsertOp.getLoc();
Value dstMemref;
auto inPlace = getInPlace(subTensorInsertOp->getResult(0));
// subtensor_insert must be inPlace, otherwise this is considered a bug.
if (inPlace != InPlaceSpec::True) {
llvm_unreachable("SubTensorInsertOp must be inPlace");
} else {
// An inPlace write will result in tensor_load(x) which must canonicalize
// away with one of its uses.
dstMemref = lookup(bvm, subTensorInsertOp.dest());
}
auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
Value srcMemref = lookup(bvm, subTensorInsertOp.source());
auto subviewMemRefType =
SubViewOp::inferRankReducedResultType(
subTensorInsertOp.getSourceType().getRank(), dstMemrefType,
subTensorInsertOp.getMixedOffsets(),
subTensorInsertOp.getMixedSizes(),
subTensorInsertOp.getMixedStrides())
.cast<MemRefType>();
// Take a subview of the dst.
Value subView = b.create<SubViewOp>(
loc, subviewMemRefType, dstMemref, subTensorInsertOp.getMixedOffsets(),
subTensorInsertOp.getMixedSizes(), subTensorInsertOp.getMixedStrides());
// Linalg op and vector.transfer_write producers directly write their output
// buffer. If the producer is not one of these ops, we need to copy.
Value source = subTensorInsertOp.source();
InPlaceSpec inPlaceProducer = InPlaceSpec::None;
if (isa<LinalgOp, vector::TransferWriteOp>(source.getDefiningOp()))
inPlaceProducer = getInPlace(source.cast<OpResult>());
if (inPlaceProducer != InPlaceSpec::True)
b.create<CopyOp>(subTensorInsertOp.getLoc(), srcMemref, subView);
map(bvm, subTensorInsertOp.result(), subView);
return success();
}
/// SubTensorOp never allocates or copies.
static LogicalResult convertSubTensorOp(OpBuilder &b, SubTensorOp subTensor,
BlockAndValueMapping &bvm) {
LLVM_DEBUG(DBGS() << "convert: " << *subTensor << "\n");
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(subTensor);
Location loc = subTensor.getLoc();
Value srcMemref = lookup(bvm, subTensor.source());
auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
auto dstTensorType = subTensor.result().getType().cast<RankedTensorType>();
auto subviewMemRefType =
SubViewOp::inferRankReducedResultType(
dstTensorType.getRank(), srcMemrefType, subTensor.getMixedOffsets(),
subTensor.getMixedSizes(), subTensor.getMixedStrides())
.cast<MemRefType>();
Value subView = b.create<SubViewOp>(
loc, subviewMemRefType, srcMemref, subTensor.getMixedOffsets(),
subTensor.getMixedSizes(), subTensor.getMixedStrides());
map(bvm, subTensor.result(), subView);
return success();
}
/// TensorCastOp just lowers to MemRefCastOp.
static LogicalResult convertTensorCastOp(OpBuilder &b, tensor::CastOp castOp,
BlockAndValueMapping &bvm) {
LLVM_DEBUG(DBGS() << "convert: " << *castOp << "\n");
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(castOp);
auto sourceMemRefType =
lookup(bvm, castOp.source()).getType().dyn_cast<MemRefType>();
Type memRefType;
TensorType tensorType = castOp.getResult().getType().cast<TensorType>();
if (tensorType.isa<UnrankedTensorType>()) {
memRefType = UnrankedMemRefType::get(
tensorType.getElementType(), sourceMemRefType.getMemorySpaceAsInt());
} else {
memRefType =
getContiguousMemRefType(tensorType, sourceMemRefType.getAffineMaps(),
sourceMemRefType.getMemorySpaceAsInt());
}
Value res = b.create<MemRefCastOp>(castOp.getLoc(), memRefType,
lookup(bvm, castOp.source()));
map(bvm, castOp.getResult(), res);
return success();
}
/// Return a FuncOp block argument if the `returnOperand` is produced by an
/// inPlace update pattern. Return the function argument that `returnOperand`
/// traces back to, if the following pattern is detected:
/// ```
/// func @foo(%A: tensor<...>, ..., %Z: tensor<...>) ->
/// (tensor<...>, ..., tensor<...>)
/// #inPlace_attr_specification
/// {
/// %p = tensor_to_memref(%some_arg): ...
/// ... // uses of %p (read or writes)
/// %t = tensor_load %p: ...
/// return ..., %t, ...: ..., tensor<...>, ...
/// }
/// ```
/// Otherwise return nullptr.
static BlockArgument analyzeTiedFuncOpResults(OpOperand &returnOperand) {
assert(isa<ReturnOp>(returnOperand.getOwner()));
FuncOp funcOp =
cast<FuncOp>(returnOperand.get().getParentBlock()->getParentOp());
Value returnValue = returnOperand.get();
// Only consider ranked tensors for folding.
if (!returnValue.getType().isa<RankedTensorType>()) return BlockArgument();
// If returned value is a bbArg, it folds iff it is a function argument.
if (auto bbArg = returnValue.dyn_cast<BlockArgument>())
return (bbArg == funcOp.getArgument(bbArg.getArgNumber()))
? bbArg
: BlockArgument();
// Otherwise we look for tensor_load(tensor_to_memref(bbArg)).
auto tensorLoadOp = returnValue.getDefiningOp<TensorLoadOp>();
if (!tensorLoadOp) return BlockArgument();
auto tensorToMemRefOp =
tensorLoadOp.memref().getDefiningOp<TensorToMemrefOp>();
if (!tensorToMemRefOp) return BlockArgument();
// If the tensor_to_memref operand is a bbArg, it only folds if it is a
// function argument.
if (auto bbArg = tensorToMemRefOp.tensor().dyn_cast<BlockArgument>())
return (bbArg == funcOp.getArgument(bbArg.getArgNumber()))
? bbArg
: BlockArgument();
return BlockArgument();
}
static bool hasOnlyTensorToMemRefUses(Value v) {
for (auto &use : v.getUses())
if (!isa<TensorToMemrefOp>(use.getOwner())) return false;
return true;
}
/// Search `funcOp` for the following pattern for each result to determine
/// whether it can fold onto an argument:
/// ```
/// func @foo(%A: tensor<...>, ..., %Z: tensor<...>) ->
/// (tensor<...>, ..., tensor<...>)
/// #inPlace_attr_specification
/// {
/// %p = tensor_to_memref(%some_arg): ...
/// ... // uses of %p (read or writes)
/// %t = tensor_load %p: ...
/// return ..., %t, ...: ..., tensor<...>, ...
/// }
/// ```
/// Information for such inPlace-bufferizable operands and the corresponding
/// result is added to `tiedResultsMap`.
/// Rewrite the `funcOp` arguments, return values and terminator into buffer
/// form (using the canonical memref layout for now), according to the
/// inPlace-bufferizable information added to `tiedResultsMap`.
static void bufferizeFuncOpBoundary(
FuncOp funcOp, DenseMap<FuncOp, SmallVector<int64_t>> &tiedResultsMap) {
// Bail on pure declarations.
if (funcOp.getBody().empty()) return;
LLVM_DEBUG(DBGS() << "Begin bufferizeFuncOpBoundary:\n" << funcOp);
// 1. Analyze inplace return patterns and set an entry in `tiedResultsMap`.
// Assume the last block terminator is the funcOp return.
// TODO: Double-check this.
auto returnOp = cast<ReturnOp>(funcOp.body().back().getTerminator());
SmallVector<int64_t> resultArgumentFolding(
funcOp.type().cast<FunctionType>().getNumResults(), -1);
for (OpOperand &returnOperand : returnOp->getOpOperands()) {
BlockArgument bbArg = analyzeTiedFuncOpResults(returnOperand);
if (!bbArg) continue;
    // Even when the bbArg is non-null, we still need to check that the
    // corresponding function argument is inPlace writeable.
if (getInPlace(bbArg) != InPlaceSpec::True) continue;
// Mark bbArg as inPlace bufferizable.
unsigned returnIndex = returnOperand.getOperandNumber();
resultArgumentFolding[returnIndex] = bbArg.getArgNumber();
}
tiedResultsMap.insert(std::make_pair(funcOp, resultArgumentFolding));
LLVM_DEBUG(
DBGS() << "Computed tiedResultsMap:"
<< OpBuilder(funcOp).getIndexArrayAttr(resultArgumentFolding));
// 2. Traverse terminator, skip return values that are inPlace bufferizable.
OpBuilder b(returnOp);
SmallVector<Value> returnValues;
for (auto en : enumerate(resultArgumentFolding)) {
LLVM_DEBUG(DBGS() << "return idx: " << en.index()
<< " inPlace bufferizable on input " << en.value()
<< "\n");
// Return value folds on some input.
if (en.value() >= 0) continue;
// Return value does not fold, add it to the new return op.
Value unfolded = returnOp->getOperand(en.index());
if (auto tensorLoadOp = unfolded.getDefiningOp<TensorLoadOp>()) {
unfolded = tensorLoadOp.memref();
for (Operation *user : llvm::make_early_inc_range(unfolded.getUsers()))
if (isa<DeallocOp>(user)) user->erase();
}
returnValues.push_back(unfolded);
    // Returning a memref that does not fold onto an inPlace function argument
    // is not supported yet.
    if (unfolded.getType().isa<MemRefType>()) {
      funcOp->dump();
      llvm::errs() << "Unsupported memref return that does not fold onto an "
                      "inPlace function argument: "
                   << returnValues.back() << "\n";
      abort();
    }
}
// 3. Rewrite the terminator without the inPlace bufferizable values.
b.create<ReturnOp>(returnOp.getLoc(), returnValues);
returnOp->erase();
// 4. Rewrite the FuncOp type to buffer form.
// TODO: Generalize the use of contiguous MemRef at the function boundary.
auto argTypes = llvm::to_vector<4>(
llvm::map_range(funcOp.getArguments(), [](BlockArgument bbArg) -> Type {
// TODO: non-zero address space.
// TODO: layout information if relevant.
if (auto tensorType = bbArg.getType().dyn_cast<RankedTensorType>())
return getContiguousMemRefType(tensorType);
return bbArg.getType();
}));
funcOp.setType(FunctionType::get(funcOp->getContext(), argTypes,
ValueRange{returnValues}.getTypes()));
// 5. Rewrite the bbArgs.
Block &frontBlock = funcOp.body().front();
unsigned numArgs = frontBlock.getNumArguments();
// Iterate on the original `numArgs` and replace them in order.
// This guarantees the argument order still matches after the rewrite.
for (unsigned idx = 0; idx < numArgs; ++idx) {
auto bbArg = frontBlock.getArgument(0);
auto tensorType = bbArg.getType().dyn_cast<RankedTensorType>();
if (!tensorType) {
frontBlock.addArgument(bbArg.getType());
bbArg.replaceAllUsesWith(frontBlock.getArguments().back());
} else {
// TODO: non-zero address space.
// TODO: layout information if relevant.
Value memref =
frontBlock.addArgument(getContiguousMemRefType(tensorType));
OpBuilder b(funcOp->getContext());
// No InsertionGuard needed here.
b.setInsertionPointToStart(&frontBlock);
      // If the bbArg is only used by TensorToMemrefOps, we can directly
      // replace those uses with a simple MemRefCastOp from the new memref
      // argument.
if (hasOnlyTensorToMemRefUses(bbArg)) {
for (auto &use : llvm::make_early_inc_range(bbArg.getUses())) {
Value tensorToMemRef = use.getOwner()->getResult(0);
tensorToMemRef.replaceAllUsesWith(b.create<MemRefCastOp>(
funcOp.getLoc(), tensorToMemRef.getType(), memref));
use.getOwner()->erase();
}
} else {
        // Otherwise, some uses are not TensorToMemrefOps; insert a
        // TensorLoadOp so that subsequent canonicalizations can fold
        // `tensor_to_memref(tensor_load(x)) -> x`.
Value tensor = b.create<TensorLoadOp>(funcOp->getLoc(), memref);
bbArg.replaceAllUsesWith(tensor);
}
}
frontBlock.eraseArgument(0);
}
LLVM_DEBUG(DBGS() << "End bufferizeFuncOpBoundary:\n" << funcOp);
}
/// Bufferize a single function call. Fold results that have a nonnegative entry
/// in `tiedResults` onto the proper operand.
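///
/// For illustration only (placeholder names and types; the memref `%m` is the
/// buffer resolved for `%t` via `bvm` or via its producing tensor_load), a
/// call whose single result is tied to operand 0, such as:
/// ```
///   %r = call @foo(%t) : (tensor<16xf32>) -> tensor<16xf32>
/// ```
/// is rewritten to roughly:
/// ```
///   call @foo(%m) : (memref<16xf32>) -> ()
///   %r_new = tensor_load %m : memref<16xf32>
/// ```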
static void bufferizeOneFunctionCall(CallOpInterface callOp,
BlockAndValueMapping &bvm,
const DominanceInfo &domInfo,
const SmallVector<int64_t> &tiedResults) {
FuncOp funcOp = getCalledFunction(callOp);
assert(funcOp && !funcOp.body().empty());
LLVM_DEBUG(DBGS() << "Begin bufferizeOneFunctionCall: " << callOp << "\n");
  // 1. Rewrite tensor operands as memrefs. For now, each tensor operand must
  // resolve to either:
  //   a. the memref recorded for it in `bvm`, or
  //   b. the memref fed to the tensor_load that defines it, provided that
  //      memref does not itself come from a tensor_to_memref.
SmallVector<Value> newOperands(callOp->getOperands());
for (Value &v : newOperands) {
if (!v.getType().isa<RankedTensorType>()) continue;
if ((v = bvm.lookupOrNull(v))) continue;
// TODO: how dangerous is this at this point in spacetime ?
if (auto tensorLoadOp = v.getDefiningOp<TensorLoadOp>()) {
      // Also covers the case where the memref has no defining op (e.g. it is
      // a function argument).
      if (!tensorLoadOp.memref().getDefiningOp<TensorToMemrefOp>()) {
v = tensorLoadOp.memref();
continue;
}
}
llvm::errs() << "operand: " << v << "\n";
llvm_unreachable("Operand does not come from a tensor_load");
}
// 2. Clone the CallOp with its attributes.
assert(isa<CallOp>(callOp.getOperation()) && "expected a CallOp");
OpBuilder b(callOp);
Operation *newCallOp = b.create<CallOp>(
callOp.getLoc(), funcOp.sym_name(),
funcOp.type().cast<FunctionType>().getResults(), newOperands);
newCallOp->setAttrs(callOp.getAttrs());
// 3. Prepare replacements for the old CallOp results.
unsigned newCallOpResultIndex = 0;
SmallVector<Value> replacements;
replacements.reserve(callOp->getNumResults());
for (OpResult oldRes : callOp->getResults()) {
    // If not a ranked tensor, nothing to bufferize: just forward the
    // corresponding new result.
if (!oldRes.getType().isa<RankedTensorType>()) {
replacements.push_back(newCallOp->getResult(newCallOpResultIndex++));
continue;
}
// Disallow memref returns for now as they are generally ambiguous. This
// means we must have a non-negative `operandIndex`.
// TODO: when such cases occur, add an Alloc hoisting pass and create new
// inPlace function arguments.
int64_t operandIndex = tiedResults[oldRes.getResultNumber()];
if (operandIndex < 0) {
callOp->getParentOfType<FuncOp>().dump();
llvm_unreachable("Unsupported result memref");
}
// If the old callOp result is a ranked tensor that does not fold on some
// input, then there must be an allocated return value.
// That value should be deallocated by the caller.
// TODO: That value should be lifted out of the callee at the first
// enclosing parallel scope. This lifting should be done to (the meet of)
// all callers before we can hoist the alloc out of the funcOp.
OpBuilder::InsertionGuard g(b);
b.setInsertionPointAfter(callOp);
replacements.push_back(
b.create<TensorLoadOp>(callOp.getLoc(), newOperands[operandIndex]));
}
callOp->replaceAllUsesWith(replacements);
callOp->erase();
LLVM_DEBUG(DBGS() << "Bufferized neighborhood:\n"
<< *newCallOp->getParentOp() << "\n");
LLVM_DEBUG(DBGS() << "End bufferizeOneFunctionCall.\n");
}
/// Perform bufferization at each FuncOp boundary and all CallOps within
/// `moduleOp`.
static void bufferizeFunctionsAndCalls(ModuleOp moduleOp,
BlockAndValueMapping &bvm) {
  // For each function, analyze the boundary tensor_load(tensor_to_memref(bbarg))
  // patterns that result from bufferizing the FuncOp internals and rewrite the
  // function arguments / return values accordingly.
  // `tiedResultsMap` is filled, for each FuncOp, with the operand index that
  // each result folds onto (-1 when it does not fold).
DominanceInfo domInfo = DominanceInfo(moduleOp);
DenseMap<FuncOp, SmallVector<int64_t>> tiedResultsMap;
moduleOp.walk(
[&](FuncOp funcOp) { bufferizeFuncOpBoundary(funcOp, tiedResultsMap); });
  // Bufferize calls; a `tiedResultsMap` entry must be present for the callee.
moduleOp.walk([&](CallOpInterface callOp) {
FuncOp funcOp = getCalledFunction(callOp);
if (!funcOp || funcOp.body().empty()) return;
bufferizeOneFunctionCall(callOp, bvm, domInfo,
tiedResultsMap.lookup(funcOp));
});
}
//===----------------------------------------------------------------------===//
// Bufferization passes.
//===----------------------------------------------------------------------===//
// Transformations that run iteratively with bufferization.
void LinalgComprehensiveBufferizePass::runEnablingTransforms(FuncOp funcOp) {
  if (failed(runPipeline(enablingPassPipeline, funcOp)))
    return signalPassFailure();
linalg::hoistRedundantVectorTransfers(funcOp);
linalg::hoistRedundantVectorTransfersOnTensor(funcOp);
}
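/// Bufferize the ops contained in `funcOp`: convert the function arguments,
/// then walk the body in pre-order and dispatch on the supported op kinds
/// (scf.for, InitTensorOp, SubTensor/SubTensorInsert, tensor.cast, PadTensorOp,
/// LinalgOp, vector transfer ops, scf.yield and ReturnOp), recording the
/// tensor-to-buffer correspondences in `bvm`.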
void LinalgComprehensiveBufferizePass::bufferizeFuncOpInternals(
FuncOp funcOp, BlockAndValueMapping &bvm) {
if (!funcOp || funcOp->getNumRegions() == 0 || funcOp.body().empty()) return;
LLVM_DEBUG(DBGS() << "Begin BufferizeFuncOpInternals:\n" << funcOp << "\n");
OpBuilder b(funcOp->getContext());
auto guard = llvm::make_scope_exit([&] {
funcOp.walk(
[&](Operation *op) { op->removeAttr(kInPlaceResultsAttrName); });
});
  // Start by converting `funcOp` arguments.
(void)succeeded(convertFuncOp(b, funcOp, bvm));
funcOp.walk<WalkOrder::PreOrder>([&](Operation *operation) {
llvm::TypeSwitch<Operation *, void>(operation)
.Case([&](scf::ForOp op) {
(void)succeeded(convertScfForOp(b, op, bvm));
})
.Case([&](InitTensorOp op) {
(void)succeeded(convertInitTensorOp(b, op, bvm));
})
.Case([&](SubTensorOp op) {
(void)succeeded(convertSubTensorOp(b, op, bvm));
})
.Case([&](SubTensorInsertOp op) {
(void)succeeded(convertSubTensorInsertOp(b, op, bvm));
})
.Case([&](tensor::CastOp op) {
(void)succeeded(convertTensorCastOp(b, op, bvm));
})
.Case([&](PadTensorOp op) {
(void)succeeded(convertPadTensorOp(b, op, bvm));
})
.Case([&](LinalgOp op) {
(void)succeeded(convertAnyLinalgOp(b, op, bvm));
})
.Case([&](VectorTransferOpInterface op) {
(void)succeeded(convertTransferOp(b, op, bvm));
})
.Case([&](scf::YieldOp op) {
(void)succeeded(convertScfYieldOp(b, op, bvm));
})
.Case(
[&](ReturnOp op) { (void)succeeded(convertReturnOp(b, op, bvm)); });
});
LLVM_DEBUG(DBGS() << "End BufferizeFuncOpInternals:\n" << funcOp << "\n");
}
namespace mlir {
std::unique_ptr<Pass> createLinalgComprehensiveBufferizePass() {
return std::make_unique<LinalgComprehensiveBufferizePass>();
}
namespace linalg {
void registerLinalgComprehensiveBufferizePass() {
PassRegistration<LinalgComprehensiveBufferizePass> pass(
"linalg-comprehensive-bufferize-inplace",
"Perform all required bufferization incantations to convert code with "
"Linalg ops on tensors to buffers with inPlace optimizations.");
}
} // namespace linalg
} // namespace mlir
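/// Post-transformation sanity checks: drop the inPlace bookkeeping attributes
/// and abort on patterns that indicate incorrect bufferization, i.e. any
/// remaining tensor_load(tensor_to_memref) pair or any CallOp result of
/// memref type.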
static void postTransformSanityChecks(ModuleOp moduleOp,
BlockAndValueMapping &bvm) {
moduleOp.walk([&](Operation *op) {
op->removeAttr(kInPlaceResultsAttrName);
op->removeAttr(kInPlaceArgsAttrName);
assert(!isa<TensorToMemrefOp>(op));
if (auto tensorLoadOp = dyn_cast<TensorLoadOp>(op)) {
if (tensorLoadOp.memref().getDefiningOp<TensorToMemrefOp>()) {
op->getParentOfType<ModuleOp>()->dump();
op->emitWarning(
"Most likely incorrect pattern: tensor_load(tensor_to_memref)");
abort();
}
return;
}
if (auto callOp = dyn_cast<CallOpInterface>(op)) {
for (auto result : callOp->getResults()) {
if (result.getType().isa<MemRefType>()) {
op->getParentOfType<ModuleOp>()->dump();
op->emitWarning(
"Most likely incorrect pattern: function returning memref -> "
"alloc needs to be hoisted out of function boundary");
abort();
}
}
return;
}
});
}
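/// Top-level driver: run the enabling transformations, the inPlace analysis of
/// calls and function internals, per-function bufferization, boundary and call
/// bufferization, a cleanup pipeline and the final sanity checks.
/// As an illustrative usage sketch (assuming the pass is registered in an
/// mlir-opt-like tool), it would typically be exercised as:
/// ```
///   mlir-opt -linalg-comprehensive-bufferize-inplace foo.mlir
/// ```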
void LinalgComprehensiveBufferizePass::runOnOperation() {
ModuleOp moduleOp = getOperation();
  // 0. Perform a bunch of enabling transformations related to
  // canonicalization, CSE and hoisting.
moduleOp.walk([&](FuncOp funcOp) { runEnablingTransforms(funcOp); });
  // 1. Perform an inPlace analysis to mark the arguments/operands of all calls
  // and functions that can be bufferized inPlace. The information set on a
  // FuncOp is the meet of the information set on all the CallOps calling that
  // FuncOp.
DominanceInfo domInfo(moduleOp);
moduleOp.walk([&](CallOpInterface callOp) {
funcArgumentsInPlaceAnalysis(callOp, domInfo);
});
// 2. Bufferize destructive update patterns within function boundaries.
BlockAndValueMapping bvm;
moduleOp.walk([&](FuncOp funcOp) {
// Perform bufferization within the funcOp boundary. This produces IR
// in a form on which `bufferizeFuncOpBoundary` can decide whether return
// values can fold onto operands.
inPlaceAnalysisFuncOpInternals(funcOp, domInfo);
bufferizeFuncOpInternals(funcOp, bvm);
});
// 3. Perform bufferization at each FuncOp boundary and all CallOps.
bufferizeFunctionsAndCalls(moduleOp, bvm);
// 4. Run cleanup pipeline.
moduleOp.walk([&](FuncOp funcOp) { runEnablingTransforms(funcOp); });
// 5. Sanity checks.
postTransformSanityChecks(moduleOp, bvm);
}