// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree-dialects/Dialect/LinalgExt/LinalgExtBufferization.h"
#include "mlir/IR/BuiltinOps.h"
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;
using namespace mlir::iree_compiler::IREE::LinalgExt;
using bufferization::AnalysisState;
using bufferization::BufferizableOpInterface;
using bufferization::BufferizationState;
using bufferization::BufferRelation;
using bufferization::getMemRefType;
using bufferization::replaceOpWithBufferizedValues;
using bufferization::replaceOpWithNewBufferizedOp;
using tensor::ExtractSliceOp;
/// Return the destinations that an InParallelOp is inserting into. One per
/// ParallelInsertSliceOp.
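///
/// Illustrative sketch of the expected structure (assembly syntax is
/// approximate and only meant for orientation): operand 1 of each
/// ParallelInsertSliceOp is the destination tensor whose updated value the
/// InParallelOp returns.
///
///   %r = iree_linalg_ext.in_parallel %num_threads -> (tensor<?xf32>) {
///   ^bb0(%thread_id: index):
///     ...
///     iree_linalg_ext.perform_concurrently {
///       iree_linalg_ext.parallel_insert_slice %source into %dest[%off][%sz][1]
///           : tensor<?xf32> into tensor<?xf32>
///     }
///   }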
static SmallVector<OpOperand *> getInsertionDest(InParallelOp inParallelOp) {
Operation *terminator = inParallelOp.region().front().getTerminator();
auto performConcOp = dyn_cast<PerformConcurrentlyOp>(terminator);
assert(performConcOp && "expected PerformConcurrentlyOp as terminator");
SmallVector<OpOperand *> result;
performConcOp.walk([&](ParallelInsertSliceOp insertOp) {
result.push_back(&insertOp->getOpOperand(1) /*dest*/);
});
return result;
}
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace LinalgExt {
/// Bufferization of InParallelOp. This also bufferizes the terminator of the
/// region. There are op interfaces for the terminators (PerformConcurrentlyOp
/// and ParallelInsertSliceOp), but these are used only during analysis, not
/// for bufferization.
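///
/// Rough before/after sketch (illustrative only; the exact copy op depends on
/// the bufferization options): each ParallelInsertSliceOp in the terminator is
/// rewritten into a memref.subview of the destination buffer plus a copy
/// placed right before the terminator, and the new InParallelOp returns no
/// tensors.
///
///   iree_linalg_ext.in_parallel %num_threads {
///   ^bb0(%thread_id: index):
///     ...
///     %sv = memref.subview %dest_buffer[%off][%sz][1] : memref<?xf32> to ...
///     <copy>(%src_buffer, %sv)  // e.g. memref.copy or linalg.copy
///     iree_linalg_ext.perform_concurrently {}
///   }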
struct InParallelOpInterface
: public BufferizableOpInterface::ExternalModel<InParallelOpInterface,
InParallelOp> {
SmallVector<OpOperand *>
getAliasingOpOperand(Operation *op, OpResult opResult,
const AnalysisState &state) const {
// Get OpOperand (dest) from corresponding ParallelInsertSliceOp.
auto inParallelOp = cast<InParallelOp>(op);
return {getInsertionDest(inParallelOp)[opResult.getResultNumber()]};
}
bool isMemoryWrite(Operation *op, OpResult opResult,
const AnalysisState &state) const {
// This op is a memory write. Stop lookup here to avoid finding false
// conflicts involving this op and one of the ops in the region. This is
// similar to how scf.if ops are analyzed.
return true;
}
bool isAllocationHoistingBarrier(Operation *op) const { return true; }
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const AnalysisState &state) const {
return BufferRelation::Equivalent;
}
LogicalResult bufferize(Operation *op, RewriterBase &b,
BufferizationState &state) const {
OpBuilder::InsertionGuard g(b);
auto inParallelOp = cast<InParallelOp>(op);
Block *body = &inParallelOp.region().front();
Operation *oldTerminator = body->getTerminator();
assert(isa<PerformConcurrentlyOp>(oldTerminator) &&
"unexpected terminator");
// Gather new results of the InParallelOp.
SmallVector<Value> newResults;
for (OpResult opResult : inParallelOp->getOpResults()) {
SmallVector<OpOperand *> insertDestOperands =
state.getAnalysisState().getAliasingOpOperand(opResult);
assert(insertDestOperands.size() == 1 &&
"expected exactly one aliasing OpOperand");
// Insert copies right before the PerformConcurrentlyOp terminator. They
// should not be placed inside the terminator (which would be the default
// insertion point).
Value buffer = *state.getBuffer(
b, *insertDestOperands.front(), /*forceInPlace=*/false,
/*customCopyInsertionPoint=*/oldTerminator);
newResults.push_back(buffer);
Value destTensor = insertDestOperands.front()->get();
// Replace all uses of the insert dest tensor inside the InParallelOp
// with the result buffer.
OpBuilder::InsertionGuard g(b);
b.setInsertionPointToStart(body);
Value toTensorOp =
b.create<bufferization::ToTensorOp>(inParallelOp.getLoc(), buffer);
for (OpOperand &use : destTensor.getUses())
if (body->findAncestorOpInBlock(*use.getOwner()))
// This is a use inside the InParallelOp.
use.set(toTensorOp);
}
// Create new InParallelOp without any results.
TypeRange newResultTypes;
auto newInParallelOp = b.create<InParallelOp>(
inParallelOp.getLoc(), newResultTypes, inParallelOp.num_threads());
// Delete terminator.
newInParallelOp.getBody()->getTerminator()->erase();
// Move over block contents of the old op.
b.mergeBlocks(inParallelOp.getBody(), newInParallelOp.getBody(),
{newInParallelOp.getBody()->getArgument(0)});
// Bufferize terminator.
auto performConcurrentlyOp =
cast<PerformConcurrentlyOp>(newInParallelOp.getBody()->getTerminator());
b.setInsertionPoint(performConcurrentlyOp);
WalkResult walkResult =
performConcurrentlyOp.walk([&](ParallelInsertSliceOp insertOp) {
Location loc = insertOp.getLoc();
Type srcType = getMemRefType(
insertOp.source().getType().cast<RankedTensorType>(),
state.getOptions());
Type destType =
getMemRefType(insertOp.dest().getType().cast<RankedTensorType>(),
state.getOptions());
// ParallelInsertSliceOp bufferizes to a copy.
auto srcMemref = b.create<bufferization::ToMemrefOp>(
loc, srcType, insertOp.source());
auto destMemref = b.create<bufferization::ToMemrefOp>(
loc, destType, insertOp.dest());
Value subview = b.create<memref::SubViewOp>(
loc, destMemref, insertOp.getMixedOffsets(),
insertOp.getMixedSizes(), insertOp.getMixedStrides());
// This memcpy will fold away if everything bufferizes in-place.
if (failed(createMemCpy(b, insertOp.getLoc(), srcMemref, subview,
state.getOptions())))
return WalkResult::interrupt();
b.eraseOp(insertOp);
return WalkResult::advance();
});
if (walkResult.wasInterrupted())
return failure();
// Replace the op.
replaceOpWithBufferizedValues(b, op, newResults);
return success();
}
};
/// Nothing to do for PerformConcurrentlyOp.
struct PerformConcurrentlyOpInterface
: public BufferizableOpInterface::ExternalModel<
PerformConcurrentlyOpInterface, PerformConcurrentlyOp> {
LogicalResult bufferize(Operation *op, RewriterBase &b,
BufferizationState &state) const {
assert(false && "op does not have any tensor OpOperands / OpResults");
return failure();
}
};
/// Return true if the (ExtractSliceOp, ParallelInsertSliceOp) pair matches,
/// i.e., the ExtractSliceOp's source and the ParallelInsertSliceOp's dest are
/// equivalent buffered values and the offsets/sizes/strides are the same.
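///
/// For example (illustrative; parallel_insert_slice syntax approximate), the
/// following pair matches because the extract reads from a tensor equivalent
/// to the insert's destination, with identical offsets, sizes, and strides:
///
///   %0 = tensor.extract_slice %t[%a][%c][1] : tensor<?xf32> to tensor<?xf32>
///   ...
///   iree_linalg_ext.parallel_insert_slice %1 into %t[%a][%c][1]
///       : tensor<?xf32> into tensor<?xf32>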
static bool areEquivalentExtractSliceOps(const AnalysisState &state,
ExtractSliceOp st,
ParallelInsertSliceOp sti) {
if (!st || !sti)
return false;
if (st != sti &&
!state.areEquivalentBufferizedValues(st.source(), sti.dest()))
return false;
if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
return false;
return true;
}
/// Return true if `value` originates from an ExtractSliceOp that matches the
/// given ParallelInsertSliceOp.
static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
ParallelInsertSliceOp insertOp) {
auto condition = [&](Value val) {
if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
return true;
return false;
};
return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
condition);
}
/// Analysis of ParallelInsertSliceOp.
struct ParallelInsertSliceOpInterface
: public BufferizableOpInterface::ExternalModel<
ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
if (&opOperand != &op->getOpOperand(1) /*dest*/)
return {};
// ParallelInsertSliceOp itself has no results. Tensors are returned via
// the parent op.
auto inParallelOp = op->getParentOfType<InParallelOp>();
assert(inParallelOp &&
"could not find valid owner of parallel_insert_slice");
// The i-th ParallelInsertSliceOp in the terminator corresponds to the i-th
// OpResult of the parent InParallelOp.
Block *block = op->getBlock();
unsigned int opIdx = 0;
for (ParallelInsertSliceOp insertOp :
block->getOps<ParallelInsertSliceOp>()) {
if (insertOp.getOperation() == op)
break;
++opIdx;
}
assert(opIdx < inParallelOp->getNumResults() &&
"could not find op inside terminator op");
return {inParallelOp->getResult(opIdx)};
}
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return &opOperand == &op->getOpOperand(1) /*dest*/;
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const AnalysisState &state) const {
return BufferRelation::Equivalent;
}
LogicalResult bufferize(Operation *op, RewriterBase &b,
BufferizationState &state) const {
// Will be bufferized as part of InParallelOp.
return failure();
}
// TODO: This is copied from TensorInterfaceImpl.cpp. Find a way to share
// the code.
bool isNotConflicting(Operation *op, OpOperand *uRead,
OpOperand *uConflictingWrite,
const AnalysisState &state) const {
Operation *readingOp = uRead->getOwner();
Operation *conflictingWritingOp = uConflictingWrite->getOwner();
// Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
// uRead is an InsertSliceOp...
if (auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(readingOp)) {
// As an example, consider the following IR.
//
// %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
// %1 = linalg.fill %cst, %0 {inplace= [true] }
// %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
// {inplace= [true] }
// TODO: Use insertSliceOp.getDestOpOperand etc. when available.
if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
insertSliceOp))
// Case 1: The main insight is that InsertSliceOp reads only part of
// the destination tensor. The overwritten area is not read. If
// uConflictingWrite writes into exactly the memory location that is
// being read by uRead, this is not a conflict.
//
// In the above example:
// uRead = OpOperand 1 (%t) of tensor.insert_slice
// uConflictingWrite = OpOperand 1 (%0) of linalg.fill
//
// The read of %t does not conflict with the write of the FillOp
// (same aliases!) because the area that the FillOp operates on is
// exactly the one that is *not* read via %t.
return true;
if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
// Case 2: The read of the source tensor and the write to the dest
// tensor via an InsertSliceOp is not a conflict if the read is
// reading exactly that part of an equivalent tensor that the
// InsertSliceOp is writing.
//
// In the above example:
// uRead = OpOperand 0 (%1) of tensor.insert_slice
// uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
return true;
}
// If uConflictingWrite is an InsertSliceOp...
if (auto insertSliceOp =
dyn_cast<ParallelInsertSliceOp>(conflictingWritingOp))
// As an example, consider the following IR.
//
// %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
// %1 = linalg.fill %cst, %0 {inplace= [true] }
// %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
// {inplace= [true] }
// %3 = vector.transfer_read %1, %cst
//
// In the above example:
// uRead = OpOperand 0 (%1) of vector.transfer_read
// uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
// lastWrite = %1
//
// This is not a conflict because the InsertSliceOp overwrites the
// memory segment of %1 with the exact same data. (Effectively, there
// is no memory write here.)
if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
state.areEquivalentBufferizedValues(uRead->get(),
insertSliceOp.source()) &&
hasMatchingExtractSliceOp(state, insertSliceOp.source(),
insertSliceOp))
return true;
return false;
}
};
} // namespace LinalgExt
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir
void mlir::iree_compiler::IREE::LinalgExt::
registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
registry.addOpInterface<InParallelOp, InParallelOpInterface>();
registry
.addOpInterface<PerformConcurrentlyOp, PerformConcurrentlyOpInterface>();
registry
.addOpInterface<ParallelInsertSliceOp, ParallelInsertSliceOpInterface>();
}
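
// Example usage (illustrative; integration points vary): a pass or tool that
// runs bufferization on IR containing these ops would typically register the
// external models on its DialectRegistry before constructing the MLIRContext:
//
//   DialectRegistry registry;
//   mlir::iree_compiler::IREE::LinalgExt::
//       registerBufferizableOpInterfaceExternalModels(registry);
//   MLIRContext context(registry);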