// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <deque>

#include "iree/compiler/Codegen/Common/Transforms.h"
#include "iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.h"
#include "iree/compiler/Codegen/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/TilingInterface.h"

#define DEBUG_TYPE "tile-dispatch-using-interface"

namespace mlir::iree_compiler {

/// Helper method to adjust the interchange vector to match the iteration
/// domain.
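/// For example (illustrative), an interchange vector of [1, 0] for a 3-d
/// iteration domain is filled out to [1, 0, 2].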
static SmallVector<int64_t>
fillInterchangeVector(ArrayRef<unsigned> interchangeVector,
size_t iterationDomainSize) {
SmallVector<int64_t> filledVector;
for (auto v : interchangeVector)
filledVector.push_back(v);
if (filledVector.size() < iterationDomainSize) {
auto range = llvm::seq<unsigned>(filledVector.size(), iterationDomainSize);
filledVector.append(range.begin(), range.end());
}
if (filledVector.size() > iterationDomainSize)
filledVector.resize(iterationDomainSize);
return filledVector;
}

/// Given the `lb` and `step` of a loop, return the lower bound and step to use
/// for a distributed loop. The iteration domain is rewritten so that
/// - lb_partitioned = lb + procId * step
/// - step_partitioned = step * nprocs
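/// For example (illustrative), with lb = 0 and step = 64, processor `p` out of
/// `N` processors starts at offset 64 * p and advances by 64 * N, i.e. it
/// visits the tiles at 64 * p, 64 * (p + N), 64 * (p + 2N), ...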
static std::tuple<OpFoldResult, OpFoldResult>
getDistributeLBAndStep(OpBuilder &b, Location loc, OpFoldResult lb,
OpFoldResult step, OpFoldResult procId,
OpFoldResult nprocs) {
AffineExpr s0, s1, s2;
bindSymbols(b.getContext(), s0, s1, s2);
auto offsetMap = AffineMap::get(0, 3, {s0 + s1 * s2});
auto stepMap = AffineMap::get(0, 2, {s0 * s1});
OpFoldResult distributeLB = affine::makeComposedFoldedAffineApply(
b, loc, offsetMap, ArrayRef<OpFoldResult>{lb, procId, step});
OpFoldResult distributeStep = affine::makeComposedFoldedAffineApply(
b, loc, stepMap, ArrayRef<OpFoldResult>{step, nprocs});
return {distributeLB, distributeStep};
}

// Helper function to replace `arith.constant` index values in `constants`
// with equivalent i64 integer attributes.
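// For example (illustrative): an OpFoldResult holding the result of
// `arith.constant 4 : index` is replaced by the attribute `4 : i64`.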
static void changeArithCstToI64Attr(OpBuilder &b,
MutableArrayRef<OpFoldResult> constants) {
for (OpFoldResult &val : constants) {
if (auto value = llvm::dyn_cast_if_present<Value>(val)) {
APInt intVal;
if (matchPattern(value, m_ConstantInt(&intVal))) {
val = b.getI64IntegerAttr(intVal.getSExtValue());
}
}
}
}

//===----------------------------------------------------------------------===//
// TileDispatchUsingSCFForOp implementation.
//===----------------------------------------------------------------------===//

/// This implementation mirrors the implementation in
/// include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h and
/// lib/Dialect/SCF/Transforms/TileUsingInterface.cpp. It is adapted to do
/// distribution and also use `flow.dispatch.tensor.load/store` instead of
/// `tensor.extract_slice/insert_slice`.
/// Generate an empty loop nest that represents the tiled loop nest shell.
/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
/// - `tileSizeVals` gives the tile sizes to use. Zero represents an untiled
///   loop.
/// - In `offsets` and `sizes` return the multi-dimensional offset and size of
///   the tile processed within the innermost loop.
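/// For a 2-d iteration space tiled along both dimensions, the generated shell
/// is roughly (illustrative):
///   scf.for %iv0 = %lb0 to %ub0 step %step0 {
///     scf.for %iv1 = %lb1 to %ub1 step %step1 {
///       // The tiled computation is generated here later by the caller.
///     }
///   }
/// where, for distributed dimensions, the bounds and steps have been rewritten
/// by getDistributeLBAndStep.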
static SmallVector<scf::ForOp> generateTileLoopNest(
OpBuilder &builder, Location loc, ArrayRef<Range> loopRanges,
ArrayRef<OpFoldResult> tileSizeVals,
ArrayRef<linalg::DistributionMethod> distributionMethod,
ArrayRef<linalg::ProcInfo> procInfo, SmallVector<OpFoldResult> &offsets,
SmallVector<OpFoldResult> &sizes, SmallVector<Value> &workgroupCount) {
assert(!loopRanges.empty() && "expected at least one loop range");
assert(loopRanges.size() == tileSizeVals.size() &&
"expected as many tile sizes as loop ranges");
assert(loopRanges.size() == distributionMethod.size() &&
"expected as many entries in distribution method list as number of "
"loop ranges");
OpBuilder::InsertionGuard guard(builder);
SmallVector<scf::ForOp> loops;
offsets.resize(loopRanges.size());
sizes.resize(loopRanges.size());
// The tile size to use (to avoid out of bounds access) is the minimum of
// `tileSize` and `ub - iv`, where `iv` is the induction variable
// of the tiled loop.
AffineExpr s0, s1, d0, s2;
bindDims(builder.getContext(), d0);
bindSymbols(builder.getContext(), s0, s1, s2);
AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, builder.getContext());
auto createBoundedTileSize = [&](OpFoldResult iv, OpFoldResult tileSize,
OpFoldResult size) -> OpFoldResult {
if (isConstantIntValue(tileSize, 1)) {
return tileSize;
}
return affine::makeComposedFoldedAffineMin(
builder, loc, minMap, ArrayRef<OpFoldResult>{iv, tileSize, size});
};
unsigned procDim = 0;
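// Map computing ceil((ub - lb) / tileSize), i.e. the number of tiles (and
// hence workgroups) along a distributed dimension.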
AffineMap numIterationsMap =
AffineMap::get(0, 3, (s1 - s0).ceilDiv(s2), builder.getContext());
for (auto [idx, loopRange] : llvm::enumerate(loopRanges)) {
// Capturing structured bindings in lambdas is a C++20 feature, so we have
// to declare a local variable for it.
int index = idx;
OpFoldResult lb = loopRange.offset;
OpFoldResult ub = loopRange.size;
OpFoldResult step = tileSizeVals[index];
// No loops if tile size is zero. Set offset and size to the loop
// offset and size.
if (isConstantIntValue(tileSizeVals[index], 0)) {
offsets[index] = lb;
sizes[index] = ub;
continue;
}
auto method = distributionMethod[index];
if (method != linalg::DistributionMethod::None) {
Value numWorkgroups = getValueOrCreateConstantIndexOp(
builder, loc,
affine::makeComposedFoldedAffineApply(
builder, loc, numIterationsMap,
{loopRange.offset, loopRange.size, tileSizeVals[index]}));
workgroupCount.push_back(numWorkgroups);
std::tie(lb, step) = getDistributeLBAndStep(builder, loc, lb, step,
procInfo[procDim].procId,
procInfo[procDim].nprocs);
procDim++;
}
if (method == linalg::DistributionMethod::CyclicNumProcsEqNumIters) {
offsets[index] = lb;
sizes[index] = createBoundedTileSize(lb, tileSizeVals[index], ub);
continue;
}
Value lbVal = getValueOrCreateConstantIndexOp(builder, loc, lb);
Value ubVal = getValueOrCreateConstantIndexOp(builder, loc, ub);
Value stepVal = getValueOrCreateConstantIndexOp(builder, loc, step);
auto loop = builder.create<scf::ForOp>(
loc, lbVal, ubVal, stepVal, ValueRange{},
[&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
ValueRange /*iterArgs*/) {
sizes[index] = createBoundedTileSize(iv, tileSizeVals[index], ub);
bodyBuilder.create<scf::YieldOp>(bodyLoc);
});
offsets[index] = loop.getInductionVar();
loops.push_back(loop);
builder.setInsertionPoint(loop.getBody()->getTerminator());
}
// Update the sizes: replace `arith.constant` index values with i64 attributes.
// TODO: tensor.extract_slice is unable to determine the
// result type if arith.constant is present. This is a workaround
// to ensure that the result type is determined.
changeArithCstToI64Attr(builder, sizes);
return loops;
}

/// Replace the `flow.dispatch.tensor.store` of the `untiledValue` with a tiled
/// `flow.dispatch.tensor.store` that writes only a tile of the result at
/// offsets given by `tiledOffsets` and sizes given by `tiledSizes`, using
/// `tiledValue` as the source.
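/// Roughly (illustrative): an existing
///   flow.dispatch.tensor.store %untiledValue, %target, offsets = [...], ...
/// is erased and replaced with
///   flow.dispatch.tensor.store %tiledValue, %target, offsets = [...], ...
/// where the new offsets/sizes are the original ones composed with
/// `tileOffsets`/`tileSizes` via mergeOffsetsSizesAndStrides.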
static LogicalResult replaceStoresWithTiledVersion(
RewriterBase &rewriter, OpResult untiledValue, Value tiledValue,
ArrayRef<OpFoldResult> tileOffsets, ArrayRef<OpFoldResult> tileSizes,
Block *innerLoopBody) {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(innerLoopBody->getTerminator());
SmallVector<IREE::Flow::DispatchTensorStoreOp> storeOps;
for (OpOperand &use : untiledValue.getUses()) {
auto storeOp = dyn_cast<IREE::Flow::DispatchTensorStoreOp>(use.getOwner());
if (storeOp && storeOp.getValue() == use.get()) {
storeOps.push_back(storeOp);
}
}
if (storeOps.empty())
return success();
if (storeOps.size() != 1) {
return rewriter.notifyMatchFailure(untiledValue.getOwner(),
"expected a single store for the op");
}
auto storeOp = storeOps.front();
SmallVector<OpFoldResult> tileStrides(tileOffsets.size(),
rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> combinedOffsets, combinedSizes, combinedStrides;
SliceAndDynamicDims clonedSliceAndVals =
cloneOffsetsSizesAndStrides(rewriter, storeOp);
if (failed(affine::mergeOffsetsSizesAndStrides(
rewriter, storeOp.getLoc(), clonedSliceAndVals.offsets,
clonedSliceAndVals.sizes, clonedSliceAndVals.strides,
storeOp.getDroppedDims(), tileOffsets, tileSizes, tileStrides,
combinedOffsets, combinedSizes, combinedStrides))) {
return rewriter.notifyMatchFailure(
storeOp, "failed to create tiled flow.dispatch.tensor.store op");
}
rewriter.create<IREE::Flow::DispatchTensorStoreOp>(
storeOp.getLoc(), tiledValue, storeOp.getTarget(),
clonedSliceAndVals.dynamicDims, combinedOffsets, combinedSizes,
combinedStrides);
rewriter.eraseOp(storeOp);
return success();
}

/// Replace all `flow.dispatch.tensor.store` operations whose source is a
/// result of `untiledOp` with tiled stores that use the corresponding value
/// in `tiledValues` as the source.
static LogicalResult replaceAllStoresWithTiledVersion(
RewriterBase &rewriter, TilingInterface untiledOp,
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
ValueRange tiledValues, Block *innerLoopBody) {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(innerLoopBody->getTerminator());
for (auto [index, result] : llvm::enumerate(untiledOp->getResults())) {
SmallVector<OpFoldResult> resultOffsets, resultSizes;
if (failed(untiledOp.getResultTilePosition(rewriter, index, offsets, sizes,
resultOffsets, resultSizes))) {
return rewriter.notifyMatchFailure(
untiledOp, "failed to rewrite destructive update");
}
if (failed(replaceStoresWithTiledVersion(
rewriter, llvm::cast<OpResult>(result), tiledValues[index],
resultOffsets, resultSizes, innerLoopBody))) {
return failure();
}
}
return success();
}

FailureOr<IREETilingResult>
tileDispatchUsingSCFFopOp(RewriterBase &rewriter, TilingInterface op,
linalg::LinalgTilingOptions options) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointAfter(op);
if (!options.tileSizeComputationFunction) {
return rewriter.notifyMatchFailure(
op, "missing tile size computation function");
}
// 1. Get the range of the loops that are represented by the operation.
SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
Location loc = op.getLoc();
size_t numLoops = iterationDomain.size();
if (numLoops == 0) {
return IREETilingResult();
}
// 2. Materialize the tile sizes. Enforce the convention that "tiling by zero"
// skips tiling a particular dimension. This convention is significantly
// simpler to handle than adjusting affine maps to account for missing
// dimensions.
SmallVector<OpFoldResult> tileSizes =
getAsOpFoldResult(options.tileSizeComputationFunction(rewriter, op));
tileSizes.resize(numLoops, rewriter.getIndexAttr(0));
IREETilingResult tilingResult;
tilingResult.tiledLoops.resize(numLoops, false);
AffineExpr s0, s1, s2, s3; // lb, ub, step, tileSize
bindSymbols(rewriter.getContext(), s0, s1, s2, s3);
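// numTiles = ceil((ub - lb) / (step * tileSize))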
AffineExpr numTilesExprs = (s1 - s0).ceilDiv(s2 * s3);
for (auto [index, iteratorType, range, tileSize] :
llvm::enumerate(op.getLoopIteratorTypes(), iterationDomain, tileSizes)) {
// If distribution is specified, only parallel loops are tiled.
if (options.distribution && iteratorType != utils::IteratorType::parallel) {
continue;
}
// If the tile size is 0, the loop isn't tiled.
if (isConstantIntValue(tileSize, 0)) {
continue;
}
// If the number of tiles is statically known to be 1, the loop isn't tiled.
OpFoldResult numTiles = affine::makeComposedFoldedAffineApply(
rewriter, loc, numTilesExprs,
{range.offset, range.size, range.stride, tileSize});
if (isConstantIntValue(numTiles, 1)) {
continue;
}
tilingResult.tiledLoops.set(index);
}
if (!tilingResult.tiledLoops.any()) {
return IREETilingResult();
}
{
SmallVector<OpFoldResult> offsets, sizes;
SmallVector<Value> workgroupCount;
// If there is an interchange specified, permute the iteration domain and
// the tile sizes.
SmallVector<int64_t> interchangeVector;
if (!options.interchangeVector.empty()) {
interchangeVector = fillInterchangeVector(options.interchangeVector,
iterationDomain.size());
}
if (!interchangeVector.empty()) {
if (!isPermutationVector(interchangeVector)) {
return rewriter.notifyMatchFailure(
op, "invalid intechange vector, not a permutation of the entire "
"iteration space");
}
applyPermutationToVector(iterationDomain, interchangeVector);
applyPermutationToVector(tileSizes, interchangeVector);
}
// If there is distribution specified, adjust the loop ranges. Note that
// permutation happens *before* distribution.
SmallVector<linalg::DistributionMethod> distributionMethods(
iterationDomain.size(), linalg::DistributionMethod::None);
SmallVector<linalg::ProcInfo> procInfo;
if (options.distribution) {
SmallVector<Range> parallelLoopRanges;
for (auto loopIdx : llvm::seq<unsigned>(0, numLoops)) {
if (tilingResult.tiledLoops.test(loopIdx)) {
AffineExpr s0, s1;
bindSymbols(rewriter.getContext(), s0, s1);
OpFoldResult parallelLoopStep = affine::makeComposedFoldedAffineApply(
rewriter, loc, s0 * s1,
{iterationDomain[loopIdx].stride, tileSizes[loopIdx]});
Range r = {iterationDomain[loopIdx].offset,
iterationDomain[loopIdx].size, parallelLoopStep};
parallelLoopRanges.emplace_back(std::move(r));
}
}
procInfo =
options.distribution->procInfo(rewriter, loc, parallelLoopRanges);
unsigned partitionedLoopIdx = 0;
for (auto loopIdx : llvm::seq<unsigned>(0, numLoops)) {
if (!tilingResult.tiledLoops.test(loopIdx)) {
continue;
}
distributionMethods[loopIdx] =
procInfo[partitionedLoopIdx++].distributionMethod;
}
}
// 3. Materialize an empty loop nest that iterates over the tiles. These
// loops for now do not return any values even if the original operation has
// results.
tilingResult.loops = generateTileLoopNest(
rewriter, loc, iterationDomain, tileSizes, distributionMethods,
procInfo, offsets, sizes, workgroupCount);
if (!interchangeVector.empty()) {
auto inversePermutation = invertPermutationVector(interchangeVector);
applyPermutationToVector(offsets, inversePermutation);
applyPermutationToVector(sizes, inversePermutation);
}
LLVM_DEBUG({
if (!tilingResult.loops.empty()) {
llvm::errs() << "LoopNest shell :\n";
tilingResult.loops.front().dump();
llvm::errs() << "\n";
}
});
// 4. Generate the tiled implementation within the inner most loop.
if (!tilingResult.loops.empty())
rewriter.setInsertionPoint(
tilingResult.loops.back().getBody()->getTerminator());
FailureOr<TilingResult> tiledImplementation =
op.getTiledImplementation(rewriter, offsets, sizes);
if (failed(tiledImplementation)) {
return rewriter.notifyMatchFailure(
op, "failed to generate tiled implementation");
}
std::swap(tilingResult.tiledOps, tiledImplementation->tiledOps);
std::swap(tilingResult.tiledValues, tiledImplementation->tiledValues);
LLVM_DEBUG({
if (!tilingResult.loops.empty()) {
llvm::errs() << "After tiled implementation :\n";
tilingResult.loops.front().dump();
llvm::errs() << "\n";
}
});
std::swap(tilingResult.tileOffsets, offsets);
std::swap(tilingResult.tileSizes, sizes);
std::swap(tilingResult.workgroupCount, workgroupCount);
}
if (op->getNumResults() == 0) {
rewriter.eraseOp(op);
return tilingResult;
}
Block *body = tilingResult.loops.empty()
? op->getBlock()
: tilingResult.loops.back().getBody();
// Rewrite all `flow.dispatch.tensor.store` operations with a tiled version
// of the store. It is valid to do this for all stores of the root untiled op.
if (failed(replaceAllStoresWithTiledVersion(
rewriter, op, tilingResult.tileOffsets, tilingResult.tileSizes,
tilingResult.tiledValues, body))) {
return failure();
}
return tilingResult;
}

//===----------------------------------------------------------------------===//
// tileAndFuseDispatchUsingSCFForOp implementation.
//===----------------------------------------------------------------------===//

/// Find all producers to fuse and return them in sorted order.
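/// For example (illustrative), for a producer chain `fill -> matmul -> generic`
/// rooted at the `generic`, this returns {fill, matmul} in topological order.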
static SmallVector<Operation *> getAllFusableProducers(TilingInterface op) {
llvm::SetVector<Operation *> producers;
std::deque<Operation *> worklist;
worklist.push_back(op);
while (!worklist.empty()) {
Operation *currOp = worklist.front();
worklist.pop_front();
for (OpOperand &operand : currOp->getOpOperands()) {
Operation *definingOp = operand.get().getDefiningOp();
auto tilingInterfaceProducer =
dyn_cast_or_null<TilingInterface>(definingOp);
if (!tilingInterfaceProducer || isa<tensor::PadOp>(definingOp) ||
producers.count(tilingInterfaceProducer)) {
continue;
}
worklist.push_back(tilingInterfaceProducer);
producers.insert(tilingInterfaceProducer);
}
}
SmallVector<Operation *> sortedOps(producers.begin(), producers.end());
mlir::computeTopologicalSorting(sortedOps);
return sortedOps;
}

/// Return all slices that are used to access a tile of the producer. Assume
/// that `tiledOps` are in "reverse" order of their appearance in the IR.
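/// For example (illustrative):
///   %0 = linalg.fill ...                            // `untiledOp` (producer)
///   %1 = tensor.extract_slice %0[...] [...] [...]   // returned slice op
///   %2 = linalg.matmul ins(%1, ...) outs(...) ...   // already-tiled consumer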
static SmallVector<tensor::ExtractSliceOp>
getAllFusableProducerUses(Operation *untiledOp,
ArrayRef<Operation *> tiledOps) {
SmallVector<tensor::ExtractSliceOp> sliceOps;
for (auto tiledOp : llvm::reverse(tiledOps)) {
for (OpOperand &operand : llvm::reverse(tiledOp->getOpOperands())) {
auto sliceOp = operand.get().getDefiningOp<tensor::ExtractSliceOp>();
if (!sliceOp || sliceOp.getSource().getDefiningOp() != untiledOp)
continue;
sliceOps.push_back(sliceOp);
}
}
return sliceOps;
}

FailureOr<IREETileAndFuseResult>
tileAndFuseDispatchUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
linalg::LinalgTilingOptions tilingOptions) {
IREETileAndFuseResult tileAndFuseResult;
auto fusableProducers = getAllFusableProducers(op);
// Apply the tiling pattern.
FailureOr<IREETilingResult> tilingResult =
tileDispatchUsingSCFFopOp(rewriter, op, tilingOptions);
if (failed(tilingResult)) {
return failure();
}
tileAndFuseResult.tiledAndFusedOps = tilingResult->tiledOps;
tileAndFuseResult.loops = tilingResult->loops;
tileAndFuseResult.workgroupCount = tilingResult->workgroupCount;
// If there is no tiling then there is nothing to do for fusion.
if (!tilingResult->tiledLoops.any()) {
return tileAndFuseResult;
}
// Get all ops to tile and fuse.
auto fusableProducersRef = llvm::ArrayRef(fusableProducers);
while (!fusableProducersRef.empty()) {
auto fusableProducer = cast<TilingInterface>(fusableProducersRef.back());
fusableProducersRef = fusableProducersRef.drop_back();
// Find all slice ops that are used to access a tile of the producer. It is
// assumed that the slice ops are returned in order, so the first slice can be
// used as the insertion point.
auto sliceOps = getAllFusableProducerUses(
fusableProducer, tileAndFuseResult.tiledAndFusedOps);
for (auto sliceOp : sliceOps) {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(sliceOp);
// Generate the tiled implementation of the producer.
OpResult untiledValue = llvm::cast<OpResult>(sliceOp.getSource());
FailureOr<TilingResult> swapSliceResult =
tensor::replaceExtractSliceWithTiledProducer(rewriter, sliceOp,
untiledValue);
if (failed(swapSliceResult) || swapSliceResult->tiledValues.size() != 1) {
return rewriter.notifyMatchFailure(sliceOp,
"fusion along slice op failed");
}
Block *body = tileAndFuseResult.loops.empty()
? op->getBlock()
: tileAndFuseResult.loops.back().getBody();
// Finally, replace any `flow.dispatch.tensor.store` operation with a
// tiled version of the operation. It is only valid to do this under
// the above assumption that the producer and consumer share the loops
// that can be tiled.
if (failed(replaceStoresWithTiledVersion(
rewriter, untiledValue, swapSliceResult->tiledValues[0],
sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(), body))) {
return failure();
}
rewriter.replaceOp(sliceOp, swapSliceResult->tiledValues[0]);
tileAndFuseResult.tiledAndFusedOps.append(swapSliceResult->tiledOps);
}
}
return tileAndFuseResult;
}

namespace {

//===----------------------------------------------------------------------===//
// SwapExtractSliceWithDispatchTensorLoad
//===----------------------------------------------------------------------===//

/// Pattern to swap `flow.dispatch.tensor.load` -> `tensor.extract_slice` with
/// a `flow.dispatch.tensor.load` of the slice.
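/// For example (illustrative):
///   %0 = flow.dispatch.tensor.load %src, offsets = [%o0], sizes = [%s0], strides = [1], ...
///   %1 = tensor.extract_slice %0[%o1] [%s1] [1]
/// folds into a single
///   %1 = flow.dispatch.tensor.load %src, offsets = [%o0 + %o1], sizes = [%s1], strides = [1], ...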
struct SwapExtractSliceWithDispatchTensorLoad
: public OpRewritePattern<tensor::ExtractSliceOp> {
using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
PatternRewriter &rewriter) const override {
auto loadOp =
sliceOp.getSource().getDefiningOp<IREE::Flow::DispatchTensorLoadOp>();
if (!loadOp)
return failure();
SmallVector<OpFoldResult> combinedOffsets, combinedSizes, combinedStrides;
if (failed(affine::mergeOffsetsSizesAndStrides(
rewriter, loadOp.getLoc(), loadOp, sliceOp, loadOp.getDroppedDims(),
combinedOffsets, combinedSizes, combinedStrides))) {
return rewriter.notifyMatchFailure(
sliceOp, "failed to fold offsets, sizes and strides with load op");
}
rewriter.replaceOpWithNewOp<IREE::Flow::DispatchTensorLoadOp>(
sliceOp, sliceOp.getType(), loadOp.getSource(), loadOp.getSourceDims(),
combinedOffsets, combinedSizes, combinedStrides);
return success();
}
};

//===----------------------------------------------------------------------===//
// SwapExtractSliceWithTensorEmpty
//===----------------------------------------------------------------------===//

/// Pattern to swap `tensor.empty` -> `tensor.extract_slice` with a
/// `tensor.empty` of the slice shape.
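/// For example (illustrative):
///   %0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
///   %1 = tensor.extract_slice %0[0, 0] [%s0, %s1] [1, 1]
/// becomes
///   %1 = tensor.empty(%s0, %s1) : tensor<?x?xf32>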
struct SwapExtractSliceWithTensorEmpty
: public OpRewritePattern<tensor::ExtractSliceOp> {
using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
PatternRewriter &rewriter) const override {
auto emptyTensorOp = sliceOp.getSource().getDefiningOp<tensor::EmptyOp>();
if (!emptyTensorOp)
return failure();
SmallVector<OpFoldResult> mixedSizes = sliceOp.getMixedSizes();
if (mixedSizes.size() != sliceOp.getType().getRank()) {
SmallVector<OpFoldResult> rankReducedMixedSizes;
rankReducedMixedSizes.reserve(sliceOp.getType().getRank());
auto droppedDims = sliceOp.getDroppedDims();
for (auto [index, size] : llvm::enumerate(mixedSizes)) {
if (droppedDims.test(index))
continue;
rankReducedMixedSizes.push_back(size);
}
std::swap(mixedSizes, rankReducedMixedSizes);
}
rewriter.replaceOpWithNewOp<tensor::EmptyOp>(
sliceOp, mixedSizes, sliceOp.getType().getElementType());
return success();
}
};
} // namespace

void populateTileAndDistributeToWorkgroupsCleanupPatterns(
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
patterns.insert<SwapExtractSliceWithDispatchTensorLoad,
SwapExtractSliceWithTensorEmpty>(context);
}

} // namespace mlir::iree_compiler