//===- LinalgComprehensiveBufferizePass.cpp - Bufferize Linalg on tensors -===//
//
// Convert from Linalg ops on tensors to Linalg ops on buffers in a single pass.
// This will aggressively try to perform inplace bufferization and will fail if
// any allocation tries to cross function boundaries or if the pattern
// tensor_load(tensor_to_memref(x)) is deemed unsafe (very conservative impl for
// now).
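//
// For illustration (schematic, exact IR abbreviated), a tensor op such as
//   %D = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
//                      outs(%C : tensor<?x?xf32>) -> tensor<?x?xf32>
// is rewritten to operate on buffers, reusing the buffer of %C when the
// analysis proves the update can be performed inplace:
//   linalg.matmul ins(%a, %b : memref<?x?xf32>, memref<?x?xf32>)
//                 outs(%c : memref<?x?xf32>)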
//
//===----------------------------------------------------------------------===//
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/SCF/Passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/Passes.h"
#define DEBUG_TYPE "linalg-comprehensive-bufferize-inplace"
using namespace mlir;
using namespace linalg;
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
namespace {
struct LinalgComprehensiveBufferizePass
: public PassWrapper<LinalgComprehensiveBufferizePass,
OperationPass<ModuleOp>> {
LinalgComprehensiveBufferizePass()
: enablingPassPipeline(OpPassManager("func")) {
enablingPassPipeline.addPass(createCanonicalizerPass());
enablingPassPipeline.addPass(createCSEPass());
enablingPassPipeline.addPass(createLoopInvariantCodeMotionPass());
}
LinalgComprehensiveBufferizePass(const LinalgComprehensiveBufferizePass &pass)
: enablingPassPipeline(pass.enablingPassPipeline) {}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<LinalgDialect, scf::SCFDialect, StandardOpsDialect>();
}
void runOnOperation() override;
void runEnablingTransforms(FuncOp funcOp);
void bufferizeFuncOpInternals(FuncOp funcOp);
Option<bool> disableInPlace{
*this, "disable-inplace",
llvm::cl::desc(
"Disables inplace buferization. This is for testing purposes."),
llvm::cl::init(false)};
/// Dynamic pass pipeline of transformations that enable better inplace
/// bufferization.
OpPassManager enablingPassPipeline;
};
} // namespace
//===----------------------------------------------------------------------===//
// Bufferization-specific attribute manipulation.
//===----------------------------------------------------------------------===//
/// Attribute marker to specify operands that can be bufferized inplace.
constexpr StringLiteral kInPlaceAttrName = "__inplace_attr__";
/// Attribute marker to specify results that fold onto input arguments.
constexpr StringLiteral kResultFoldArgAttrName = "__result_fold_arg_attr__";
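/// For illustration (schematic, values invented), an annotated op and an
/// annotated function may look like:
///   linalg.matmul ins(%A, %B : ...) outs(%C : ...)
///     {__inplace_attr__ = ["none", "none", "true"]}
///   func @f(%arg0: tensor<4xf32>) attributes {__inplace_attr__ = ["true"]}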
/// Specify whether an op operand (or function argument) may be bufferized
/// inplace. `None` is the default, i.e. no decision has been taken yet.
enum class InPlaceSpec {
False,
True,
None,
};
static StringRef stringify(InPlaceSpec val) {
switch (val) {
case InPlaceSpec::False:
return "false";
case InPlaceSpec::True:
return "true";
case InPlaceSpec::None:
return "none";
}
return "";
}
static Optional<InPlaceSpec> symbolize(StringRef str) {
return StringSwitch<Optional<InPlaceSpec>>(str)
.Case("false", InPlaceSpec::False)
.Case("true", InPlaceSpec::True)
.Case("none", InPlaceSpec::None)
.Default(None);
}
/// Set the attribute entry `kInPlaceAttrName`@`idx` to `inplace`.
/// If the attribute does not exist yet, add a blanket array attribute filled
/// with InPlaceSpec::None before setting `kInPlaceAttrName`@`idx` to `inplace`.
static void setInplace(Operation *op, unsigned idx = 0,
InPlaceSpec inplace = InPlaceSpec::True) {
auto attr = op->getAttr(kInPlaceAttrName);
assert(!attr || attr.isa<ArrayAttr>());
SmallVector<StringRef> pos;
if (!attr) {
auto funcOp = dyn_cast<FuncOp>(op);
pos = funcOp ? SmallVector<StringRef>(funcOp.getNumArguments(),
stringify(InPlaceSpec::None))
: SmallVector<StringRef>(op->getNumOperands(),
stringify(InPlaceSpec::None));
} else {
pos = llvm::to_vector<4>(
attr.cast<ArrayAttr>().getAsValueRange<StringAttr>());
}
LLVM_DEBUG(DBGS() << "Set inplace=" << stringify(inplace) << ": " << *op
<< " @idx=" << idx << "\n");
pos[idx] = stringify(inplace);
op->setAttr(kInPlaceAttrName, OpBuilder(op).getStrArrayAttr(pos));
}
static InPlaceSpec getInplace(Operation *op, unsigned operandIndex = 0) {
auto attr = op->getAttr(kInPlaceAttrName).dyn_cast_or_null<ArrayAttr>();
if (!attr) return InPlaceSpec::None;
assert(attr.size() > operandIndex);
// Must return a proper value.
return *symbolize(
*(attr.getAsValueRange<StringAttr>().begin() + operandIndex));
}
static Optional<int64_t> getResultFoldArgIndex(FuncOp op, unsigned resultIdx) {
auto attr = op->getAttr(kResultFoldArgAttrName).dyn_cast_or_null<ArrayAttr>();
if (!attr) return llvm::None;
APInt val = *(attr.getAsValueRange<IntegerAttr>().begin() + resultIdx);
int64_t res = val.getSExtValue();
if (res < 0) return llvm::None;
return res;
}
//===----------------------------------------------------------------------===//
// Bufferization-specific MemRefType support.
//===----------------------------------------------------------------------===//
/// Return the contiguous MemRefType (i.e. with canonical/empty layout map) to
/// which `type` can be bufferized, assuming `type` is a RankedTensorType.
static MemRefType getContiguousMemRefType(Type type,
ArrayRef<AffineMap> layout = {},
unsigned addressSpace = 0) {
RankedTensorType tensorType = type.cast<RankedTensorType>();
return MemRefType::get(tensorType.getShape(), tensorType.getElementType(),
layout, addressSpace);
}
/// Return a MemRefType to which the `tensorType` can be bufferized in a
/// composable fashion. The layout is the most dynamic possible and is expected
/// to canonicalize away once bufferization is finished.
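/// For example (illustrative), a `tensor<4x?xf32>` would map to the equivalent
/// of `memref<4x?xf32, offset: ?, strides: [?, ?]>` (printed as the
/// corresponding strided affine map).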
static MemRefType getDynamicMemRefType(RankedTensorType tensorType,
unsigned addressSpace = 0) {
// TODO: address space decisions to connect with the actual alloc.
int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
SmallVector<int64_t> dynamicStrides(tensorType.getRank(),
ShapedType::kDynamicStrideOrOffset);
AffineMap stridedLayout = makeStridedLinearLayoutMap(
dynamicStrides, dynamicOffset, tensorType.getContext());
return MemRefType::get(tensorType.getShape(), tensorType.getElementType(),
stridedLayout, addressSpace);
}
// Transfer all `dim` ops on `tensor` to `memref`.
static void transferDimOpsToMemref(Value tensor, Value memref) {
for (OpOperand &opOperand : llvm::make_early_inc_range(tensor.getUses())) {
if (isa<DimOp>(opOperand.getOwner())) {
opOperand.set(memref);
}
}
}
//===----------------------------------------------------------------------===//
// Bufferization-specific BlockAndValueMapping support with debugging.
//===----------------------------------------------------------------------===//
/// Wrapper for better debugging.
static void map(BlockAndValueMapping &bvm, ValueRange key, ValueRange value) {
if (key.empty()) return;
LLVM_DEBUG(DBGS() << "Map: " << key.front() << " to " << value.front()
<< "\n");
return bvm.map(key, value);
}
/// Wrapper for better debugging.
static void map(BlockAndValueMapping &bvm, Value key, Value value) {
LLVM_DEBUG(DBGS() << "Map: " << key << " to " << value << "\n");
return bvm.map(key, value);
}
/// Wrapper for better debugging.
static Value lookup(BlockAndValueMapping &bvm, Value key) {
if (!bvm.lookupOrNull(key)) {
MemRefType memRefType =
getDynamicMemRefType(key.getType().cast<RankedTensorType>());
Operation *op = key.getDefiningOp() ? key.getDefiningOp()
: key.getParentBlock()->getParentOp();
OpBuilder b(op->getContext());
// No InsertionGuard needed here.
if (auto blockArg = key.dyn_cast<BlockArgument>())
b.setInsertionPointToStart(blockArg.getParentBlock());
else
b.setInsertionPointAfter(op);
map(bvm, key, b.create<TensorToMemrefOp>(op->getLoc(), memRefType, key));
}
return bvm.lookup(key);
}
//===----------------------------------------------------------------------===//
// Bufferization-specific inplace pattern matching support.
//===----------------------------------------------------------------------===//
/// First assign `op` if `slice.back()` isa `T`, then check condition.
/// If anything fails just return failure. Otherwise update `sliceRef` by
/// dropping `sliceRef.back()`, then return success().
template <typename T>
static LogicalResult matchAndDropBack(
ArrayRef<Operation *> &sliceRef, T &op,
llvm::function_ref<LogicalResult(T)> condition = nullptr) {
if (sliceRef.empty()) return failure();
op = dyn_cast<T>(sliceRef.back());
if (!op || (condition && failed(condition(op)))) return failure();
sliceRef = sliceRef.drop_back();
return success();
}
/// First assign `op1`/`op2` if `slice.front()`/`slice.back()` isa `T1`/`T2`,
/// respectively. Then check condition. If anything fails just return failure.
/// Otherwise update `sliceRef` by dropping `sliceRef.front()` and
/// `sliceRef.back()`, then return success().
template <typename T1, typename T2>
static LogicalResult matchAndDropEnclosingPair(
ArrayRef<Operation *> &sliceRef, T1 &op1, T2 &op2,
llvm::function_ref<LogicalResult(T1, T2)> condition = nullptr) {
if (sliceRef.size() < 2) return failure();
op1 = dyn_cast<T1>(sliceRef.front());
op2 = dyn_cast<T2>(sliceRef.back());
if (!op1 || !op2 || (condition && failed(condition(op1, op2))))
return failure();
sliceRef = sliceRef.drop_front().drop_back();
return success();
}
//===----------------------------------------------------------------------===//
// Bufferization-specific scoped alloc/dealloc insertion support.
//===----------------------------------------------------------------------===//
// TODO: need to hoist this across function boundaries. Maybe by using
// init_tensor + subtensor_insert.
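// A sketch of the IR this helper emits for a `tensor<?x16xf32>` shaped value
// %t (illustrative; SSA names and exact std-dialect syntax invented):
//   %d0  = dim %t, 0                        // one `dim` per dynamic size
//   %buf = alloc(%d0) : memref<?x16xf32>
//   ...                                     // uses of %buf
//   dealloc %buf : memref<?x16xf32>         // just before the block terminator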
static Value createNewAllocDeallocPairForShapedValue(
OpBuilder &b, Location loc, Value shapedValue,
SmallVector<Value, 4> dynOperands = {}) {
MemRefType memRefType = shapedValue.getType().dyn_cast<MemRefType>();
assert(memRefType || shapedValue.getType().dyn_cast<RankedTensorType>());
// TODO: non-zero address space.
// TODO: layout information if relevant.
if (!memRefType) memRefType = getContiguousMemRefType(shapedValue.getType());
OpBuilder::InsertionGuard g(b);
if (auto bbArg = shapedValue.dyn_cast<BlockArgument>()) {
b.setInsertionPointToStart(bbArg.getOwner());
loc = bbArg.getOwner()->getParentOp()->getLoc();
} else {
b.setInsertionPointAfter(shapedValue.getDefiningOp());
loc = shapedValue.getDefiningOp()->getLoc();
}
// If the dynOperands are not passed explicitly, compute them.
// This circumvents currently missing dim(init_tensor) canonicalizations.
if (dynOperands.empty()) {
for (auto dim : llvm::enumerate(memRefType.getShape()))
if (dim.value() == ShapedType::kDynamicSize)
dynOperands.push_back(b.create<DimOp>(loc, shapedValue, dim.index()));
}
Value allocated = b.create<AllocOp>(loc, memRefType, dynOperands);
b.setInsertionPoint(allocated.getParentBlock()->getTerminator());
b.create<DeallocOp>(loc, allocated);
return allocated;
}
//===----------------------------------------------------------------------===//
// Bufferization-specific inplace analysis support.
//===----------------------------------------------------------------------===//
/// Walk back the chain of known ops all the way to function arguments:
/// - if an AllocOp, AllocaOp or InitTensorOp is met, return true.
/// - if a LinalgOp is met, return true: either it is already known to trace
/// back to a function arg that is writeable or it is already guaranteed to
/// create an AllocOp into which we can write.
/// - if the function argument is marked inplace, return true.
/// - if the function argument is not marked inplace, return false.
/// - if an unknown op is encountered, abort for now.
static bool livesInWritableMemoryLocation(Value v) {
LLVM_DEBUG(DBGS() << "Start livesInWritableMemoryLocation @" << v << "\n");
bool done = false, res = false;
while (!done) {
// Scalar or vector value comes from a load, just return true.
if (!v.getType()
.isa<MemRefType, RankedTensorType, UnrankedMemRefType,
UnrankedTensorType>())
return true;
if (auto bbArg = v.dyn_cast<BlockArgument>()) {
llvm::TypeSwitch<Operation *, void>(bbArg.getOwner()->getParentOp())
.Case([&](scf::ForOp forOp) {
v = forOp.getIterOperands()[bbArg.getArgNumber() - /*iv=*/1];
})
.Case([&](FuncOp funcOp) {
assert(bbArg.getType().isa<TensorType>() &&
"already bufferized func");
if (getInplace(funcOp, bbArg.getArgNumber()) != InPlaceSpec::True)
res = false;
else
res = true;
done = true;
})
.Default([&](Operation *op) {
llvm::errs() << "In function:\n" << *op->getParentOfType<FuncOp>();
llvm::errs() << "\nUnsupported livesInWritableMemoryLocation "
<< *op << "\nstarting from value: " << v;
abort();
});
continue;
}
auto opResult = v.cast<OpResult>();
llvm::TypeSwitch<Operation *, void>(opResult.getOwner())
.Case([&](LinalgOp linalgOp) {
// TODO: uses implicit knowledge that output tensor matches result
// 1-1.
v = linalgOp.getOutputTensors()[opResult.getResultNumber()];
})
.Case<TensorToMemrefOp, TensorLoadOp, tensor::CastOp>(
[&](Operation *op) { v = op->getOperand(0); })
.Case<linalg::InitTensorOp, AllocOp, AllocaOp>([&](Operation *op) {
res = true;
done = true;
})
.Default([&](Operation *op) {
llvm::errs() << "In function:\n" << *op->getParentOfType<FuncOp>();
llvm::errs() << "\nUnsupported livesInWritableMemoryLocation " << *op
<< "\nstarting from value: " << v;
abort();
});
}
return res;
}
namespace {
// Represent an inplace action that is to be committed as an Operation attribute
// upon successful detection of a chain of ops that can be run inplace.
struct InPlaceAction {
Operation *op;
SmallVector<unsigned> outputIndices;
};
} // namespace
/// Find simple forms of destructive update which writes over a yielded tensor
/// without ever reading from it. For now, we only allow:
/// ```
/// vector.transfer_write -> subtensor_insert -> yield
/// ```
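///
/// For illustration (schematic, names and types invented), the bbArg %t below
/// is a matching candidate: its single use is the subtensor_insert, whose
/// result feeds the yield at the same position.
/// ```
///   %r = scf.for ... iter_args(%t = %init) -> (tensor<8x16xf32>) {
///     %w = vector.transfer_write %vec, %small[...] : vector<...>, tensor<...>
///     %i = subtensor_insert %w into %t[...] [...] [...] : ...
///     scf.yield %i : tensor<8x16xf32>
///   }
/// ```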
static void iterativeOverwritesAnalysis(Operation *parentOp,
ArrayRef<BlockArgument> candidates) {
if (!isa<scf::ForOp, FuncOp>(parentOp)) return;
for (auto en : llvm::enumerate(candidates)) {
Value candidate = en.value();
if (!candidate.getType().isa<ShapedType>()) continue;
LLVM_DEBUG(llvm::dbgs() << "\n\n");
LLVM_DEBUG(DBGS() << "Iterative overwrite analysis on candidate: "
<< candidate << "\nof:\n"
<< *parentOp << "\n");
if (!livesInWritableMemoryLocation(candidate)) continue;
llvm::SetVector<Operation *> slice;
getForwardSlice(candidate, &slice, [&](Operation *op) {
// Skip any extra nesting between parentOp and op.
return op == parentOp || op->getBlock()->getParentOp() == parentOp;
});
LLVM_DEBUG(DBGS() << "Iterative overwrite TRY:\n");
LLVM_DEBUG(llvm::for_each(
slice, [](Operation *op) { DBGS() << "Slice op: " << *op << "\n"; }));
// bbArg must be used exactly by one subtensor_insert + yield.
if (!candidate.hasOneUse()) {
LLVM_DEBUG(DBGS() << "bbArg does not have exactly 1 use."
"\nIterative overwrite FAIL\n");
continue;
}
if (slice.size() != 2) {
LLVM_DEBUG(DBGS() << "Need exactly 2 ops in slice. "
"\nIterative overwrite FAIL\n");
continue;
}
auto sliceRef = slice.getArrayRef();
// Match yieldOp and update sliceRef.
scf::YieldOp yieldOp;
if (failed(matchAndDropBack(sliceRef, yieldOp))) continue;
// Match subTensorInsertOp and update sliceRef.
SubTensorInsertOp subTensorInsertOp;
if (failed(matchAndDropBack(sliceRef, subTensorInsertOp))) continue;
// Optional vector::TransferWriteOp.
auto vectorTransferWriteOp =
subTensorInsertOp.source().getDefiningOp<vector::TransferWriteOp>();
// subtensor_insert must be used exactly by the yield at index `idx`.
unsigned idx = en.index();
if (!subTensorInsertOp.result().hasOneUse() ||
!isa<scf::YieldOp>(*subTensorInsertOp.result().getUsers().begin()) ||
subTensorInsertOp.result().getUses().begin()->getOperandNumber() !=
idx) {
LLVM_DEBUG(DBGS() << "SubTensorInsertOp does not have a single YieldOp "
"use. \nIterative overwrite chain FAIL\n");
continue;
}
setInplace(parentOp, en.index());
if (vectorTransferWriteOp) setInplace(vectorTransferWriteOp);
setInplace(subTensorInsertOp);
setInplace(yieldOp, en.index());
LLVM_DEBUG(DBGS() << "Iterative overwrite chain SUCCESS\n");
}
}
/// Return success if all offsets, sizes and strides of `op1` and `op2` are
/// equal.
static LogicalResult sameOffsetsSizesAndStrides(
OffsetSizeAndStrideOpInterface op1, OffsetSizeAndStrideOpInterface op2) {
if (op1.static_offsets().size() != op2.static_offsets().size())
return failure();
if (op1.static_sizes().size() != op2.static_sizes().size()) return failure();
if (op1.static_strides().size() != op2.static_strides().size())
return failure();
for (auto it : llvm::zip(op1.getMixedOffsets(), op2.getMixedOffsets()))
if (!isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it)))
return failure();
for (auto it : llvm::zip(op1.getMixedSizes(), op2.getMixedSizes()))
if (!isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it)))
return failure();
for (auto it : llvm::zip(op1.getMixedStrides(), op2.getMixedStrides()))
if (!isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it)))
return failure();
return success();
}
static LogicalResult matchingVectorTransfersAtSource(
vector::TransferReadOp read, vector::TransferWriteOp write,
Value subtensor) {
// Either we have a pair of matching transfer read/write or none.
if (read && !write) {
LLVM_DEBUG(DBGS() << "Slice has transferReadOp but no transferWriteOp"
"\nDestructive update chain FAIL\n");
return failure();
}
if (!read && write) {
LLVM_DEBUG(DBGS() << "Slice has transferWriteOp but no transferReadOp"
"\nDestructive update chain FAIL\n");
return failure();
}
if (read && write) {
// If we have a pair of matching read/write, the tensor and vector shape
// must exactly match (i.e. this is a vectorization).
if (read.source() != subtensor) {
LLVM_DEBUG(DBGS() << "transferReadOp.source() != subTensor.result()"
"\nDestructive update chain FAIL\n");
return failure();
}
if (write.source() != subtensor) {
LLVM_DEBUG(DBGS() << "transferWriteOp.source() != subTensor.result()"
"\nDestructive update chain FAIL\n");
return failure();
}
if (read.getShapedType().getShape() != read.getVectorType().getShape()) {
LLVM_DEBUG(DBGS() << "transferReadOp source and result shapes mismatch"
"\nDestructive update chain FAIL\n");
return failure();
}
if (write.getShapedType().getShape() != write.getVectorType().getShape()) {
LLVM_DEBUG(DBGS() << "transferWriteOp source and result shapes mismatch"
"\nDestructive update chain FAIL\n");
return failure();
}
}
return success();
}
/// In the case of a FuncOp, we look for:
/// `candidate -> subtensor -> vector.transfer_read(*) -> ...
/// vector.transfer_write(*) -> subtensor_insert -> return`.
/// sliceRef is automatically updated to match `...`.
///
/// (*) represents an optional op in the chain, if a subtensor or
/// vector.transfer is included, the matching op must be included too.
static LogicalResult detectDestructiveUpdatePattern(
FuncOp parentOp, BlockArgument candidate, ArrayRef<Operation *> &sliceRef,
SmallVector<InPlaceAction> &inPlaceActions) {
if (!parentOp) return failure();
ReturnOp terminator;
// Match returnOp and update sliceRef.
if (failed(matchAndDropBack(sliceRef, terminator))) {
LLVM_DEBUG(DBGS() << "destructive update slice must end with a known "
"terminator.\nDestructive update chain FAIL\n");
return failure();
}
return success();
}
/// In the case of an scf::ForOp, we look for:
/// `candidate -> subtensor -> vector.transfer_read(*) -> ...
/// vector.transfer_write(*) -> subtensor_insert -> yield`.
/// sliceRef is automatically updated to match `...`.
///
/// (*) represents an optional op in the chain, if a subtensor or
/// vector.transfer is included, the matching op must be included too.
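///
/// For illustration (schematic, syntax abbreviated), a matching chain rooted
/// at the bbArg %t, where subtensor and subtensor_insert use the same offsets,
/// sizes and strides:
/// ```
///   %st = subtensor %t[...] [...] [...]
///   %v  = vector.transfer_read %st[...]          // optional
///   ... payload computing %v2 from %v ...
///   %w  = vector.transfer_write %v2, %st[...]    // optional
///   %i  = subtensor_insert %w into %t[...] [...] [...]
///   scf.yield ..., %i, ...
/// ```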
static LogicalResult detectDestructiveUpdatePattern(
scf::ForOp parentOp, BlockArgument candidate,
ArrayRef<Operation *> &sliceRef,
SmallVector<InPlaceAction> &inPlaceActions) {
if (!parentOp) return failure();
scf::YieldOp terminator;
SubTensorOp subTensorOp;
SubTensorInsertOp subTensorInsertOp;
vector::TransferReadOp vectorTransferReadOp;
vector::TransferWriteOp vectorTransferWriteOp;
// bbArg must be used exactly by one subtensor / subtensor_insert pair.
if (candidate.use_empty() || candidate.hasOneUse() ||
std::next(candidate.getUsers().begin(), 2) !=
candidate.getUsers().end()) {
LLVM_DEBUG(DBGS() << "bbArg does not have exactly 2 uses."
"\nDestructive update chain FAIL\n");
return failure();
}
if (sliceRef.size() < 3) {
LLVM_DEBUG(DBGS() << "scf::ForOp destructive updated must have >= 3 ops."
"\nDestructive update chain FAIL\n");
return failure();
}
// Match yieldOp and update sliceRef.
if (failed(matchAndDropBack(sliceRef, terminator))) {
LLVM_DEBUG(DBGS() << "destructive update slice must end with a known "
"terminator.\nDestructive update chain FAIL\n");
return failure();
}
// Match subtensor pair and update sliceRef.
// subtensor / subtensor_insert must match.
auto matchSubTensors = [](SubTensorOp st, SubTensorInsertOp sti) {
auto res = sameOffsetsSizesAndStrides(st, sti);
if (failed(res))
LLVM_DEBUG(DBGS() << "subtensor ops don't match: " << st << " and " << sti
<< "\nDestructive update chain FAIL\n");
return res;
};
if (failed(matchAndDropEnclosingPair<SubTensorOp, SubTensorInsertOp>(
sliceRef, subTensorOp, subTensorInsertOp, matchSubTensors)))
return failure();
// subtensor_insert must be used exactly by the terminator at index `idx`.
unsigned idx = candidate.getArgNumber() - /*#iv=*/1; // adjust for ForOp iv.
if (!subTensorInsertOp.result().hasOneUse() ||
terminator != *subTensorInsertOp.result().getUsers().begin() ||
terminator->getOperand(idx) != subTensorInsertOp.result()) {
LLVM_DEBUG(
DBGS() << "SubTensorInsertOp does not have a single terminator use "
"at the right index.\nDestructive update chain FAIL\n");
return failure();
}
// Maybe match vector transfer pair and update sliceRef.
// If we find one, the other must be present and match too.
auto matchTransfers = [&](vector::TransferReadOp read,
vector::TransferWriteOp write) {
return matchingVectorTransfersAtSource(read, write, subTensorOp.result());
};
if (failed(matchAndDropEnclosingPair<vector::TransferReadOp,
vector::TransferWriteOp>(
sliceRef, vectorTransferReadOp, vectorTransferWriteOp,
matchTransfers)) &&
(vectorTransferReadOp || vectorTransferWriteOp))
return failure();
// Commit what has been detected.
inPlaceActions.push_back(InPlaceAction{subTensorOp});
if (vectorTransferReadOp)
inPlaceActions.push_back(InPlaceAction{vectorTransferReadOp});
if (vectorTransferWriteOp)
inPlaceActions.push_back(InPlaceAction{vectorTransferWriteOp});
inPlaceActions.push_back(InPlaceAction{subTensorInsertOp});
inPlaceActions.push_back(InPlaceAction{terminator, {idx}});
return success();
}
/// Iterate over bbArgs of `parentOp` and determine if they are the root of a
/// destructive update chain such as:
/// ```
/// scf.for bbArg -> subtensor -> DAG of admissible inPlaceActions
/// -> subtensor_insert -> yield.
/// ```
/// Such a representation is related to traditional loop nest + memory analysis
/// but provides a simpler abstraction.
/// In traditional memory-based dependence analysis, one would need to analyze
/// all possible interleavings of possibly aliasing loads and stores in the
/// context of the k-common surrounding loops.
/// With scf.for + subtensor + subtensor_insert + yield, more ordering semantics
/// are available as well as dealiasing thanks to SSA use-def chains.
static void destructiveUpdateAnalysis(Operation *parentOp,
ArrayRef<BlockArgument> candidates) {
for (auto en : llvm::enumerate(candidates)) {
BlockArgument candidate = en.value();
if (!candidate.getType().isa<ShapedType>()) continue;
LLVM_DEBUG(llvm::dbgs() << "\n\n");
LLVM_DEBUG(DBGS() << "Destructive update analysis on candidate: "
<< candidate << "\nof:\n"
<< *parentOp << "\n");
if (!livesInWritableMemoryLocation(candidate)) continue;
llvm::SetVector<Operation *> slice;
getForwardSlice(candidate, &slice, [&](Operation *op) {
// Skip any extra nesting between parentOp and op.
return op == parentOp || op->getBlock()->getParentOp() == parentOp;
});
LLVM_DEBUG(DBGS() << "Slice:\n");
for (auto *op : slice) LLVM_DEBUG(DBGS() << *op << "\n");
SmallVector<InPlaceAction> inPlaceActions;
inPlaceActions.reserve(slice.size());
ArrayRef<Operation *> sliceRef = slice.getArrayRef();
if (failed(detectDestructiveUpdatePattern(dyn_cast<scf::ForOp>(parentOp),
candidate, sliceRef,
inPlaceActions)) &&
failed(detectDestructiveUpdatePattern(
dyn_cast<FuncOp>(parentOp), candidate, sliceRef, inPlaceActions))) {
LLVM_DEBUG(DBGS() << "Failed to detect: Destructive update chain FAIL\n");
continue;
}
// Add the parent op itself and commit the detected pattern eagerly to
// simplify the implementation.
inPlaceActions.push_back(
{parentOp, {static_cast<unsigned int>(en.index())}});
for (auto &action : inPlaceActions) {
if (action.outputIndices.empty()) setInplace(action.op);
for (unsigned idx : action.outputIndices) setInplace(action.op, idx);
}
}
parentOp->walk([](Operation *op) {
if (isa<TensorLoadOp, TensorToMemrefOp>(op)) setInplace(op);
if (auto linalgOp = dyn_cast<LinalgOp>(op)) {
// For now, just check that the output tensor operand and the corresponding
// result each have a single use. In the future we can build a cost-model to
// take care of diamond dependences.
unsigned resultIdx = 0;
for (auto &opOperand : linalgOp.getOutputTensorsOpOperands()) {
if (opOperand->get().hasOneUse() &&
linalgOp->getResult(resultIdx).hasOneUse())
setInplace(op, opOperand->getOperandNumber());
++resultIdx;
}
}
});
}
static FuncOp getCalledFunction(CallOpInterface callOp) {
SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
if (!sym) return nullptr;
return dyn_cast_or_null<FuncOp>(
SymbolTable::lookupNearestSymbolFrom(callOp, sym));
}
static void inplaceAnalysisFuncOpInternals(FuncOp funcOp) {
funcOp.walk([&](scf::ForOp forOp) {
iterativeOverwritesAnalysis(forOp, forOp.getRegionIterArgs());
});
iterativeOverwritesAnalysis(funcOp, funcOp.getArguments());
funcOp.walk([&](scf::ForOp forOp) {
destructiveUpdateAnalysis(forOp, forOp.getRegionIterArgs());
});
destructiveUpdateAnalysis(funcOp, funcOp.getArguments());
}
/// Analyse a `callOp` to a FuncOp and determine whether any of its tensor
/// operands could be safely written inplace after they are converted to buffer
/// form by a bufferization process. Iterate on the uses of callOp's operands
/// to determine whether all such uses dominate callOp. If any use of an
/// operand does not dominate `callOp`, this means that the operand tensor
/// value may be needed somewhere else and it is illegal to update in-place
/// after bufferization. Add a `kInPlaceAttrName` string attribute to `callOp`
/// to carry the result of this analysis until bufferization is completed. The
/// "meet" of all `kInPlaceAttrName` for all `callOp` to a given FuncOp
/// determines the `kInPlaceAttrName` for that FuncOp.
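///
/// For illustration (schematic): given
///   %r = call @f(%A) : (tensor<4xf32>) -> tensor<4xf32>
///   "some_use"(%x, %A) : (f32, tensor<4xf32>) -> ()
/// the use of %A that follows the call does not dominate it, so operand 0 of
/// the call (and, by meet over all calls, argument 0 of @f) is marked "false".
/// If the call were the last use of %A, it would be marked "true".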
static void inplaceFunctionArgumentAnalysis(CallOpInterface callOp,
DominanceInfo &domInfo) {
FuncOp funcOp = getCalledFunction(callOp);
if (!funcOp) return;
if (llvm::none_of(callOp->getOperandTypes(),
[](Type t) { return t.isa<TensorType>(); }))
return;
LLVM_DEBUG(DBGS() << "Begin inplaceFunctionArgumentAnalysis within:\n"
<< *callOp->getParentOfType<FuncOp>()
<< "callOp: " << *callOp << "\n";);
for (OpOperand &opOperand : callOp->getOpOperands()) {
Value tensor = opOperand.get();
if (!tensor.getType().isa<TensorType>()) continue;
unsigned idx = opOperand.getOperandNumber();
LLVM_DEBUG(DBGS() << "tensor @idx=" << idx << ": " << tensor << "\n");
// For now, assume any use is a read.
// Write-only is a non-problem: will represent with shapes in the future.
// If any use of the tensor does not properly dominate callOp, we can't
// bufferize the tensor inplace.
InPlaceSpec callInPlace = InPlaceSpec::True;
for (auto &use : tensor.getUses()) {
Operation *user = use.getOwner();
if (domInfo.properlyDominates(user, callOp)) continue;
if (use.getOperandNumber() == idx) continue;
LLVM_DEBUG(DBGS() << "non-properly dominate user: " << *user << "\n");
callInPlace = InPlaceSpec::False;
break;
}
// CallOp instance can immediately determine whether it allows inplace.
setInplace(callOp, idx, callInPlace);
// FuncOp inplace is the meet of all the calls.
InPlaceSpec funcInPlace = getInplace(funcOp, idx);
if (funcInPlace == InPlaceSpec::False) continue;
setInplace(funcOp, idx, callInPlace);
}
LLVM_DEBUG(DBGS() << "End inplaceFunctionArgumentAnalysis within:\n"
<< *callOp->getParentOfType<FuncOp>()
<< "callOp: " << *callOp << "\n";);
}
//===----------------------------------------------------------------------===//
// Bufferization as simple BlockAndValueMapping rewrites / without
// conversions.
//===----------------------------------------------------------------------===//
/// Non-conversion equivalent of the core MLIR Linalg bufferization patterns.
/// This works on mixed tensor + buffer Linalg ops: some results may have been
/// already bufferized by a previous destructive update bufferization.
/// Allocate the output buffers for the remaining tensor output operands of
/// the Linalg op. If the tensor is an "init" tensor (i.e. its value is
/// actually used in the payload region), we additionally copy the original
/// value into the newly allocated buffer.
static LogicalResult allocateBuffersForResults(
OpBuilder &b, Location loc, LinalgOp op,
SmallVectorImpl<Value> &resultBuffers, BlockAndValueMapping &bvm) {
// Lazily compute loopRanges.
SmallVector<Range, 4> loopRanges;
// Linalg invariant: output tensors and result match 1-1.
assert(op.getNumOutputTensors() == op->getNumResults());
for (auto &opOperand : op.getOutputOpOperands()) {
Value output = opOperand.get();
if (output.getType().isa<MemRefType>()) {
resultBuffers.push_back(output);
continue;
}
// If output tensor is marked inplace, just use the buffer.
if (getInplace(op, opOperand.getOperandNumber()) == InPlaceSpec::True) {
resultBuffers.push_back(lookup(bvm, output));
continue;
}
Value dimTensor = bvm.lookupOrDefault(output);
Value alloc = createNewAllocDeallocPairForShapedValue(b, loc, dimTensor);
resultBuffers.push_back(alloc);
// Additionally, if the payload reads the output value, copy the original
// value into the newly allocated buffer.
if (op.payloadUsesValueFromOpOperand(&opOperand))
b.create<CopyOp>(loc, lookup(bvm, output), alloc);
}
map(bvm, op->getResults(), resultBuffers);
for (auto it : llvm::zip(op->getResults(), resultBuffers)) {
transferDimOpsToMemref(std::get<0>(it), std::get<1>(it));
}
return success();
}
// Non-conversion equivalent of the core MLIR Linalg bufferization patterns.
static void finalizeBufferAllocation(OpBuilder &b, LinalgOp op,
ValueRange inputs, ValueRange outputs,
BlockAndValueMapping &bvm) {
SmallVector<Value, 8> newOperands = inputs;
newOperands.append(outputs.begin(), outputs.end());
auto otherOperands = op.getAssumedNonShapedOperands();
newOperands.append(otherOperands.begin(), otherOperands.end());
Location loc = op.getLoc();
op.clone(b, loc, /*resultTypes=*/TypeRange{}, newOperands);
// Replace the results of the old op with the new output buffers.
map(bvm, op.getOperation()->getResults(), outputs);
for (auto it : llvm::zip(op.getOperation()->getResults(), outputs)) {
transferDimOpsToMemref(std::get<0>(it), std::get<1>(it));
}
if (!op.hasTensorSemantics()) op->erase();
}
/// Generic conversion pattern that matches any LinalgOp. This avoids
/// template instantiating one pattern for each LinalgOp.
/// This works on mixed tensor + buffer Linalg ops: some results may have been
/// already bufferized by a previous destructive update bufferization.
static LogicalResult convertAnyLinalgOp(OpBuilder &b, LinalgOp op,
BlockAndValueMapping &bvm) {
if (op.hasBufferSemantics()) return failure();
LLVM_DEBUG(DBGS() << "convertAnyLinalgOp: " << *op << "\n");
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(op);
Location loc = op.getLoc();
SmallVector<Value, 2> newInputBuffers;
newInputBuffers.reserve(op.getNumInputs());
for (Value v : op.getInputs()) {
newInputBuffers.push_back(lookup(bvm, v));
}
SmallVector<Value, 2> newOutputBuffers;
if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm)))
assert(false);
// Delegate to the linalg generic pattern.
if (auto genericOp = dyn_cast<GenericOp>(op.getOperation())) {
finalizeBufferAllocation(b, genericOp, newInputBuffers, newOutputBuffers,
bvm);
return success();
}
SmallVector<Value, 2> newResults;
for (OpOperand &outputOpOperand : op.getOutputOpOperands()) {
Value output = outputOpOperand.get();
if (output.getType().isa<MemRefType>()) continue;
auto tensorType = output.getType().cast<RankedTensorType>();
OpBuilder::InsertionGuard g(b);
b.setInsertionPointAfter(op);
Value tensor = b.create<TensorLoadOp>(
loc, tensorType,
newOutputBuffers[outputOpOperand.getOperandNumber() -
op.getNumInputs()]);
newResults.push_back(tensor);
map(bvm, tensor,
newOutputBuffers[outputOpOperand.getOperandNumber() -
op.getNumInputs()]);
}
// Can't just map.
// map(bvm, op.getOutputs(), newOutputBuffers);
// map(bvm, op->getResults(), newResults);
// Must explicitly push the value out because consumer ops are not guaranteed
// to pull the value from bvm (e.g. scf.for, whose core bufferization uses
// conversion patterns).
op->replaceAllUsesWith(newResults);
finalizeBufferAllocation(b, op, newInputBuffers, newOutputBuffers, bvm);
return success();
}
static LogicalResult convertTransferOp(OpBuilder &b,
VectorTransferOpInterface op,
BlockAndValueMapping &bvm) {
if (op.getShapedType().isa<MemRefType>()) return failure();
assert(op->getNumResults() == 1);
Value outputTensor = op->getResult(0);
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(op);
Location loc = op.getLoc();
Value newInputBuffer = lookup(bvm, op.source());
if (auto tensorType =
op->getResult(0).getType().dyn_cast<RankedTensorType>()) {
Value tensor = bvm.lookupOrDefault(outputTensor);
Value alloc = createNewAllocDeallocPairForShapedValue(b, loc, tensor);
map(bvm, op->getResult(0), alloc);
transferDimOpsToMemref(op->getResult(0), alloc);
}
// Replace the tensor operand.
if (auto readOp = dyn_cast<vector::TransferReadOp>(op.getOperation())) {
readOp.sourceMutable().assign(newInputBuffer);
} else {
auto writeOp = cast<vector::TransferWriteOp>(op.getOperation());
// Create a new transfer_write on buffer that doesn't have a return value.
// Leave the previous transfer_write as dead code; it still has uses at this
// point.
b.create<vector::TransferWriteOp>(
loc, writeOp.vector(), newInputBuffer, writeOp.indices(),
writeOp.permutation_map(),
writeOp.masked() ? *writeOp.masked() : ArrayAttr());
Value tensor = b.create<TensorLoadOp>(
loc, writeOp.getResult(0).getType().cast<RankedTensorType>(),
newInputBuffer);
SmallVector<Value, 1> newResult(1, {tensor});
writeOp.replaceAllUsesWith(newResult);
map(bvm, tensor, newInputBuffer);
}
return success();
}
static LogicalResult convertInitTensorOp(OpBuilder &b,
InitTensorOp initTensorOp,
BlockAndValueMapping &bvm) {
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(initTensorOp);
Value alloc = createNewAllocDeallocPairForShapedValue(
b, initTensorOp->getLoc(), initTensorOp.result(), initTensorOp.sizes());
map(bvm, initTensorOp.result(), alloc);
return success();
}
static LogicalResult convertPadTensorOp(OpBuilder &b, PadTensorOp padTensorOp,
BlockAndValueMapping &bvm) {
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(padTensorOp);
auto tensorType = padTensorOp.result().getType().cast<RankedTensorType>();
auto sourceMemRef = lookup(bvm, padTensorOp.source());
auto sourceMemRefType = sourceMemRef.getType().cast<MemRefType>();
auto memRefType =
getContiguousMemRefType(tensorType, sourceMemRefType.getAffineMaps(),
sourceMemRefType.getMemorySpaceAsInt());
Value res =
b.create<MemRefCastOp>(padTensorOp.getLoc(), memRefType, sourceMemRef);
map(bvm, padTensorOp.result(), res);
return success();
}
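/// Bufferize subtensor_insert as a subview of the destination buffer plus an
/// optional copy. A sketch of the emitted IR (illustrative, syntax
/// abbreviated):
///   %sv = subview %dst_memref[...] [...] [...]
///   linalg.copy(%src_memref, %sv)     // only when not provably inplace
///   %t  = tensor_load %dst_memref     // keeps remaining tensor users valid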
static LogicalResult convertSubTensorInsertOp(
OpBuilder &b, SubTensorInsertOp subTensorInsertOp,
BlockAndValueMapping &bvm) {
Location loc = subTensorInsertOp.getLoc();
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(subTensorInsertOp);
Value dstMemref = lookup(bvm, subTensorInsertOp.dest());
auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
Value srcMemref = lookup(bvm, subTensorInsertOp.source());
auto subviewMemRefType =
SubViewOp::inferRankReducedResultType(
subTensorInsertOp.getSourceType().getRank(), dstMemrefType,
subTensorInsertOp.getMixedOffsets(),
subTensorInsertOp.getMixedSizes(),
subTensorInsertOp.getMixedStrides())
.cast<MemRefType>();
// Take a subview of the dst.
Value subView = b.create<SubViewOp>(
loc, subviewMemRefType, dstMemref, subTensorInsertOp.getMixedOffsets(),
subTensorInsertOp.getMixedSizes(), subTensorInsertOp.getMixedStrides());
// Linalg op and vector.transfer_write producers directly write into their
// output buffer. If the producer is not one of these ops, or if the
// subtensor_insert is not marked inplace, we need to copy.
bool isInPlaceProducer =
subTensorInsertOp.source().getDefiningOp<LinalgOp>() ||
subTensorInsertOp.source().getDefiningOp<vector::TransferWriteOp>();
if (!isInPlaceProducer || getInplace(subTensorInsertOp) != InPlaceSpec::True)
b.create<CopyOp>(subTensorInsertOp.getLoc(), srcMemref, subView);
Value tensor = b.create<TensorLoadOp>(
loc, subTensorInsertOp->getResult(0).getType(), dstMemref);
SmallVector<Value, 1> newResult(1, {tensor});
subTensorInsertOp->replaceAllUsesWith(newResult);
map(bvm, tensor, dstMemref);
return success();
}
static LogicalResult convertSubTensorOp(OpBuilder &b, SubTensorOp subTensor,
BlockAndValueMapping &bvm) {
Location loc = subTensor.getLoc();
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(subTensor);
Value srcMemref = lookup(bvm, subTensor.source());
auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
auto dstTensorType = subTensor.result().getType().cast<RankedTensorType>();
auto subviewMemRefType =
SubViewOp::inferRankReducedResultType(
dstTensorType.getRank(), srcMemrefType, subTensor.getMixedOffsets(),
subTensor.getMixedSizes(), subTensor.getMixedStrides())
.cast<MemRefType>();
Value subView = b.create<SubViewOp>(
loc, subviewMemRefType, srcMemref, subTensor.getMixedOffsets(),
subTensor.getMixedSizes(), subTensor.getMixedStrides());
map(bvm, subTensor.result(), subView);
return success();
}
static LogicalResult convertTensorCastOp(OpBuilder &b, tensor::CastOp castOp,
BlockAndValueMapping &bvm) {
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(castOp);
auto sourceMemRefType =
lookup(bvm, castOp.source()).getType().dyn_cast<MemRefType>();
Type memRefType;
TensorType tensorType = castOp.getResult().getType().cast<TensorType>();
if (tensorType.isa<UnrankedTensorType>()) {
memRefType = UnrankedMemRefType::get(
tensorType.getElementType(), sourceMemRefType.getMemorySpaceAsInt());
} else {
memRefType =
getContiguousMemRefType(tensorType, sourceMemRefType.getAffineMaps(),
sourceMemRefType.getMemorySpaceAsInt());
}
Value res = b.create<MemRefCastOp>(castOp.getLoc(), memRefType,
lookup(bvm, castOp.source()));
map(bvm, castOp.getResult(), res);
return success();
}
static void bufferizeFunctionCallBoundaries(FuncOp funcOp) {
// kResultFoldArgAttrName is set once funcOp is bufferized.
if (funcOp->getAttr(kResultFoldArgAttrName)) return;
SmallVector<int64_t> resultArgumentFolding(
funcOp.type().cast<FunctionType>().getNumResults(), -1);
LLVM_DEBUG(DBGS() << "Begin bufferizeFunctionCallBoundaries:\n" << funcOp);
// Take the terminator (assume the last block is the only one that has it).
auto returnOp = cast<ReturnOp>(funcOp.body().back().getTerminator());
for (OpOperand &returnOpOperand : returnOp->getOpOperands()) {
Value returnValue = returnOpOperand.get();
unsigned returnIndex = returnOpOperand.getOperandNumber();
if (!returnValue.getType().isa<RankedTensorType>()) continue;
// If returned value is a bbArg, it only folds if it is a function
// argument.
BlockArgument bbArg = returnValue.dyn_cast<BlockArgument>();
if (bbArg) {
if (returnValue == funcOp.getArgument(bbArg.getArgNumber()))
resultArgumentFolding[returnIndex] = bbArg.getArgNumber();
else
continue;
}
// Otherwise we look for tensor_load(tensor_to_memref(bbarg)).
auto tensorLoadOp = returnValue.getDefiningOp<TensorLoadOp>();
if (!tensorLoadOp) continue;
auto tensorToMemRefOp =
tensorLoadOp.memref().getDefiningOp<TensorToMemrefOp>();
if (!tensorToMemRefOp) continue;
// If returned value is a bbArg, it only folds if it is a function
// argument.
bbArg = tensorToMemRefOp.tensor().dyn_cast<BlockArgument>();
if (bbArg) {
if (bbArg == funcOp.getArgument(bbArg.getArgNumber()))
resultArgumentFolding[returnIndex] = bbArg.getArgNumber();
else
continue;
}
}
funcOp->setAttr(kResultFoldArgAttrName,
OpBuilder(funcOp).getI64ArrayAttr(resultArgumentFolding));
OpBuilder b(returnOp);
SmallVector<Value> returnValues;
for (auto en : enumerate(resultArgumentFolding)) {
LLVM_DEBUG(DBGS() << "return idx: " << en.index() << " folds on "
<< en.value() << "\n");
// Return value folds on some input.
if (en.value() >= 0) continue;
// Return value does not fold, add it to the new return op.
Value unfolded = returnOp->getOperand(en.index());
if (auto tensorLoadOp = unfolded.getDefiningOp<TensorLoadOp>()) {
unfolded = tensorLoadOp.memref();
for (Operation *user : llvm::make_early_inc_range(unfolded.getUsers()))
if (isa<DeallocOp>(user)) user->erase();
}
returnValues.push_back(unfolded);
llvm::errs() << "return val does not fold: " << returnValues.back() << "\n";
}
b.create<ReturnOp>(returnOp.getLoc(), returnValues);
returnOp->erase();
auto argTypes = llvm::to_vector<4>(
llvm::map_range(funcOp.getArguments(), [](BlockArgument bbArg) -> Type {
// TODO: non-zero address space.
// TODO: layout information if relevant.
if (auto tensorType = bbArg.getType().dyn_cast<RankedTensorType>())
return getContiguousMemRefType(tensorType);
return bbArg.getType();
}));
funcOp.setType(FunctionType::get(funcOp->getContext(), argTypes,
ValueRange{returnValues}.getTypes()));
Block &frontBlock = funcOp.body().front();
for (unsigned idx = 0, e = frontBlock.getNumArguments(); idx < e; ++idx) {
auto bbArg = frontBlock.getArgument(0);
auto tensorType = bbArg.getType().dyn_cast<RankedTensorType>();
if (!tensorType) {
frontBlock.addArgument(bbArg.getType());
bbArg.replaceAllUsesWith(frontBlock.getArguments().back());
} else {
// TODO: non-zero address space.
// TODO: layout information if relevant.
Value memref =
frontBlock.addArgument(getContiguousMemRefType(tensorType));
OpBuilder b(funcOp->getContext());
// No InsertionGuard needed here.
b.setInsertionPointToStart(&frontBlock);
Value tensor = b.create<TensorLoadOp>(funcOp->getLoc(), memref);
bbArg.replaceAllUsesWith(tensor);
}
frontBlock.eraseArgument(0);
}
LLVM_DEBUG(DBGS() << "End bufferizeFunctionCallBoundaries:\n" << funcOp);
}
/// Bufferize a single function call.
/// Look for the following pattern for each result to determine whether it can
/// fold onto an argument:
/// ```
/// func @foo(%A: tensor<...>, ..., %Z: tensor<...>) ->
/// (tensor<...>, ..., tensor<...>)
/// #inplace_attr_specification
/// {
/// %p = tensor_to_memref(%some_arg): ...
/// ... // uses of %p (read or writes)
/// %t = tensor_load %p: ...
/// return ..., %t, ...: ..., tensor<...>, ...
/// }
/// ```
static void bufferizeFunctionCall(CallOpInterface callOp,
DominanceInfo &domInfo) {
FuncOp funcOp = getCalledFunction(callOp);
if (!funcOp) return;
if (funcOp.body().empty()) return;
// Only bufferizes the first time `funcOp` is encountered.
bufferizeFunctionCallBoundaries(funcOp);
SmallVector<Value> newOperands;
for (Value v : callOp->getOperands()) {
if (!v.getType().isa<RankedTensorType>()) {
newOperands.push_back(v);
continue;
}
if (auto tensorLoadOp = v.getDefiningOp<TensorLoadOp>()) {
newOperands.push_back(tensorLoadOp.memref());
continue;
}
llvm::errs() << "operand: " << v << "\n";
llvm_unreachable("Operand does not come from a tensor_load");
}
assert(isa<CallOp>(callOp.getOperation()) && "expected a CallOp");
OpBuilder b(callOp);
Operation *newCallOp = b.create<CallOp>(
callOp.getLoc(), funcOp.sym_name(),
funcOp.type().cast<FunctionType>().getResults(), newOperands);
newCallOp->setAttrs(callOp.getAttrs());
int numFoldedArgsSoFar = 0;
for (unsigned callRetIdx = 0, e = callOp->getNumResults(); callRetIdx < e;
++callRetIdx) {
unsigned newCallReturnIdx = callRetIdx - numFoldedArgsSoFar;
auto maybeFoldedArgIndex = getResultFoldArgIndex(funcOp, callRetIdx);
if (maybeFoldedArgIndex) ++numFoldedArgsSoFar;
// If not a ranked tensor, no changes, just replace the new result.
if (!callOp->getResult(callRetIdx).getType().isa<RankedTensorType>()) {
assert(!maybeFoldedArgIndex);
callOp->getResult(callRetIdx)
.replaceAllUsesWith(newCallOp->getResult(newCallReturnIdx));
continue;
}
// If the old callOp result is a ranked tensor that does not fold on some
// input, then there must be an allocated return value.
// That value should be deallocated by the caller.
// That value should be lifted out of the callee at the first enclosing
// parallel scope. This lifting should be done to (the meet of) all
// callers before we can hoist the alloc out of the funcOp.
Value resultMemref = (maybeFoldedArgIndex)
? newOperands[*maybeFoldedArgIndex]
: newCallOp->getResult(newCallReturnIdx);
callOp->getResult(callRetIdx)
.replaceAllUsesWith(
b.create<TensorLoadOp>(callOp.getLoc(), resultMemref));
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(callOp->getBlock()->getTerminator());
// If function returns a memref, it must be freed.
if (!maybeFoldedArgIndex)
b.create<DeallocOp>(callOp.getLoc(), resultMemref);
}
callOp->erase();
}
//===----------------------------------------------------------------------===//
// Bufferization passes.
//===----------------------------------------------------------------------===//
// Transformations that run iteratively with bufferization.
void LinalgComprehensiveBufferizePass::runEnablingTransforms(FuncOp funcOp) {
if (failed(runPipeline(enablingPassPipeline, funcOp)))
return signalPassFailure();
linalg::hoistRedundantVectorTransfers(funcOp);
linalg::hoistRedundantVectorTransfersOnTensor(funcOp);
}
void LinalgComprehensiveBufferizePass::bufferizeFuncOpInternals(FuncOp funcOp) {
LLVM_DEBUG(DBGS() << "Start BufferizeFuncOpInternals:\n" << funcOp);
OpBuilder b(funcOp->getContext());
BlockAndValueMapping bvm;
bool changed = true;
// It is likely overkill to do this in a loop with canonicalization and
// hoisting but until we stabilize bufferization, c'est la vie.
while (changed) {
changed = false;
runEnablingTransforms(funcOp);
// CSE changes the result of the analysis, need to compute/mark/invalidate
// at each iteration.
inplaceAnalysisFuncOpInternals(funcOp);
auto guard = llvm::make_scope_exit([&] {
funcOp.walk([&](Operation *op) { op->removeAttr(kInPlaceAttrName); });
});
funcOp.walk([&](Operation *operation) {
llvm::TypeSwitch<Operation *, void>(operation)
// TensorLoadOp is not allowed to just fold into the memref!
// If it may alias, it must clone.
.Case([&](TensorLoadOp op) {
// TODO: reduce amount of surprise.
if (auto tensorToMemRef =
op.memref().getDefiningOp<TensorToMemrefOp>()) {
// Folding is allowed when tensor_to_memref immediately
// precedes tensor_load -> no interleaved aliasing.
if (tensorToMemRef->getNextNode() == op) {
map(bvm, op.result(), op.memref());
changed = true;
}
// TODO: else clone.
}
})
.Case([&](TensorToMemrefOp op) {
// TODO: reduce amount of surprise.
Value repl = bvm.lookupOrDefault(op.tensor());
if (op.memref() != repl) {
op.memref().replaceAllUsesWith(repl);
op->erase();
}
})
.Case([&](InitTensorOp op) {
changed = succeeded(convertInitTensorOp(b, op, bvm));
})
.Case([&](SubTensorOp op) {
changed = succeeded(convertSubTensorOp(b, op, bvm));
})
.Case([&](SubTensorInsertOp op) {
changed = succeeded(convertSubTensorInsertOp(b, op, bvm));
})
.Case([&](tensor::CastOp op) {
changed = succeeded(convertTensorCastOp(b, op, bvm));
})
.Case([&](PadTensorOp op) {
changed = succeeded(convertPadTensorOp(b, op, bvm));
})
.Case([&](LinalgOp op) {
changed = succeeded(convertAnyLinalgOp(b, op, bvm));
})
.Case([&](VectorTransferOpInterface op) {
changed = succeeded(convertTransferOp(b, op, bvm));
});
});
LLVM_DEBUG(DBGS() << "BufferizeFuncOpInternals step:\n" << funcOp);
}
}
namespace mlir {
std::unique_ptr<Pass> createLinalgComprehensiveBufferizePass() {
return std::make_unique<LinalgComprehensiveBufferizePass>();
}
namespace linalg {
void registerLinalgComprehensiveBufferizePass() {
PassRegistration<LinalgComprehensiveBufferizePass> pass(
"linalg-comprehensive-bufferize-inplace",
"Perform all required bufferization incantations to convert code with "
"Linalg ops on tensors to buffers with inplace optimizations.");
}
} // namespace linalg
} // namespace mlir
void LinalgComprehensiveBufferizePass::runOnOperation() {
ModuleOp module = getOperation();
DominanceInfo domInfo(module);
module.walk([&](CallOpInterface callOp) {
inplaceFunctionArgumentAnalysis(callOp, domInfo);
});
module.walk([&](FuncOp funcOp) { bufferizeFuncOpInternals(funcOp); });
// Recompute domInfo.
domInfo = DominanceInfo(module);
module.walk(
[&](CallOpInterface callOp) { bufferizeFunctionCall(callOp, domInfo); });
PassManager pm(module.getContext());
pm.addPass(createCanonicalizerPass());
(void)pm.run(module);
// Cleanups and sanity checks.
module.walk([&](Operation *op) {
op->removeAttr(kInPlaceAttrName);
op->removeAttr(kResultFoldArgAttrName);
if (auto tensorLoadOp = dyn_cast<TensorLoadOp>(op)) {
if (tensorLoadOp.memref().getDefiningOp<TensorToMemrefOp>()) {
op->getParentOfType<ModuleOp>()->dump();
op->emitWarning(
"Most likely incorrect pattern: tensor_load(tensor_to_memref)");
abort();
}
}
if (auto callOp = dyn_cast<CallOpInterface>(op)) {
for (auto result : callOp->getResults()) {
if (result.getType().isa<MemRefType>()) {
op->getParentOfType<ModuleOp>()->dump();
op->emitWarning(
"Most likely incorrect pattern: function returning memref -> "
"alloc needs to be hoisted out of function boundary");
abort();
}
}
}
});
}