blob: 8ef1c4e532990ff2042f9f8936d28d9bea959687 [file]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//===--------------- FusionPreprocessing.cpp ------------------------------===//
//
// Miscellaneous patterns run before fusion.
//
//===----------------------------------------------------------------------===//
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "iree/compiler/DispatchCreation/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir::iree_compiler::DispatchCreation {
#define GEN_PASS_DEF_FUSIONPREPROCESSINGPASS
#include "iree/compiler/DispatchCreation/Passes.h.inc"
namespace {
//===----------------------------------------------------------------------===//
// ElementwiseOpInterchangePattern
//===----------------------------------------------------------------------===//
/// Interchanges the loops of an elementwise `linalg.generic` so that, when
/// possible, all of its input indexing maps become identity.
///
/// The pattern only fires when the op is elementwise, has exactly one result
/// and at least one input, and every input uses the same non-identity
/// permutation indexing map.
struct ElementwiseOpInterchangePattern final
    : public OpRewritePattern<linalg::GenericOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Structural preconditions: elementwise, single result, has inputs.
    if (!linalg::isElementwise(genericOp)) {
      return failure();
    }
    if (genericOp.getNumResults() != 1 || genericOp.getNumDpsInputs() == 0) {
      return failure();
    }

    // The first indexing map belongs to the first input. It must be a
    // non-identity permutation, and the output map must be a permutation as
    // well. `isElementwise` checks for permutation maps too, but that check
    // may be removed, so it is repeated here.
    AffineMap sharedInputMap = genericOp.getIndexingMapsArray().front();
    OpOperand *init = genericOp.getDpsInitOperand(0);
    if (sharedInputMap.isIdentity()) {
      return failure();
    }
    if (!sharedInputMap.isPermutation()) {
      return failure();
    }
    if (!genericOp.getMatchingIndexingMap(init).isPermutation()) {
      return failure();
    }

    // Every input has to use exactly that same map.
    bool hasDifferingInputMap = llvm::any_of(
        genericOp.getDpsInputOperands(), [&](OpOperand *input) {
          return genericOp.getMatchingIndexingMap(input) != sharedInputMap;
        });
    if (hasDifferingInputMap) {
      return failure();
    }

    // Read the permutation off the shared input map and interchange the loops
    // with it, which is meant to make all inputs identity.
    auto interchange = llvm::map_to_vector(
        sharedInputMap.getResults(), [](AffineExpr expr) -> unsigned {
          return cast<AffineDimExpr>(expr).getPosition();
        });
    return linalg::interchangeGenericOp(rewriter, genericOp, interchange);
  }
};
//===----------------------------------------------------------------------===//
// FoldSuccessiveTensorInsertSliceOps
//===----------------------------------------------------------------------===//
/// Pattern to fold
///
/// ```
/// %0 = linalg.fill ins(%cst : )
/// %1 = tensor.insert_slice %a into %0
/// %2 = linalg.fill ins(%cst : )
/// %3 = tensor.insert_slice %1 into %2
/// ```
///
/// to
///
/// ```
/// %2 = linalg.fill ins(%cst : )
/// %3 = tensor.insert_slice %a into %2
/// ```
///
/// The fold is applied only when:
///   - the destinations of both `tensor.insert_slice` ops are `linalg.fill`
///     ops that fill with the same value,
///   - both `tensor.insert_slice` ops use unit strides only, and
///   - both `tensor.insert_slice` ops carry the same number of offsets.
///
/// The resulting `tensor.insert_slice` uses the element-wise sum of the two
/// offset lists, the sizes of the inner insert, and unit strides.
struct FoldSuccessiveTensorInsertSliceOps final
    : public OpRewritePattern<tensor::InsertSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    // The inserted value must itself be produced by an insert_slice.
    auto sourceInsertSlice =
        sliceOp.getSource().getDefiningOp<tensor::InsertSliceOp>();
    if (!sourceInsertSlice) {
      return failure();
    }

    // Both insert_slices must write into fills of the same value; otherwise
    // the padding around `%a` would differ after the fold.
    auto sourceSliceFillOp =
        sourceInsertSlice.getDest().getDefiningOp<linalg::FillOp>();
    auto destSliceFillOp = sliceOp.getDest().getDefiningOp<linalg::FillOp>();
    if (!sourceSliceFillOp || !destSliceFillOp) {
      return rewriter.notifyMatchFailure(
          sliceOp, "dest of both insert_slices expected to be fill operations");
    }
    if (sourceSliceFillOp.getDpsInputOperand(0)->get() !=
        destSliceFillOp.getDpsInputOperand(0)->get()) {
      return rewriter.notifyMatchFailure(
          sliceOp, "dest of both insert_slices expected "
                   "to be fill operation with same value");
    }

    // The folded op is built with unit strides, so BOTH the inner and the
    // outer insert_slice must have unit strides for the fold to be valid.
    auto isAllConstantOne = [](OpFoldResult ofr) {
      return isConstantIntValue(ofr, 1);
    };
    if (!llvm::all_of(sourceInsertSlice.getMixedStrides(), isAllConstantOne) ||
        !llvm::all_of(sliceOp.getMixedStrides(), isAllConstantOne)) {
      return rewriter.notifyMatchFailure(
          sliceOp, "unhandled non-unit strides of slices");
    }

    SmallVector<OpFoldResult> sourceSliceOffsets =
        sourceInsertSlice.getMixedOffsets();
    SmallVector<OpFoldResult> destSliceOffsets = sliceOp.getMixedOffsets();
    // The offsets are combined pair-wise below (`zip_equal` requires equal
    // lengths), so bail out before creating any IR if the counts differ, e.g.
    // when the outer insert_slice is rank-reducing.
    if (sourceSliceOffsets.size() != destSliceOffsets.size()) {
      return rewriter.notifyMatchFailure(
          sliceOp, "unhandled insert_slices with mismatched number of offsets");
    }

    // New offset in each dimension = inner offset + outer offset.
    AffineExpr d0, d1;
    bindDims(rewriter.getContext(), d0, d1);
    AffineExpr addExpr = d0 + d1;
    SmallVector<OpFoldResult> offsets = llvm::map_to_vector(
        llvm::zip_equal(sourceSliceOffsets, destSliceOffsets), [&](auto it) {
          return affine::makeComposedFoldedAffineApply(
              rewriter, sliceOp.getLoc(), addExpr,
              {std::get<0>(it), std::get<1>(it)});
        });

    // Insert `%a` directly into the outer destination using the inner sizes.
    SmallVector<OpFoldResult> sizes = sourceInsertSlice.getMixedSizes();
    SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1));
    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        sliceOp, sourceInsertSlice.getSource(), sliceOp.getDest(), offsets,
        sizes, strides);
    return success();
  }
};
/// Pass that greedily applies the preprocessing patterns defined in this file,
/// together with upstream result-dim resolution and pack/unpack folding
/// patterns, to the operation it runs on.
struct FusionPreprocessingPass final
    : public impl::FusionPreprocessingPassBase<FusionPreprocessingPass> {
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    RewritePatternSet preprocessingPatterns(context);

    // Patterns defined in this file.
    preprocessingPatterns.add<ElementwiseOpInterchangePattern>(context);
    preprocessingPatterns.add<FoldSuccessiveTensorInsertSliceOps>(context);

    // Fold away `tensor.dim` operations that can be resolved in terms of its
    // operand shapes.
    memref::populateResolveRankedShapedTypeResultDimsPatterns(
        preprocessingPatterns);
    memref::populateResolveShapedTypeResultDimsPatterns(preprocessingPatterns);

    // Upstream folding patterns for pack and unpack operations.
    tensor::populateFoldIntoPackAndUnpackPatterns(preprocessingPatterns);

    LogicalResult status =
        applyPatternsGreedily(getOperation(), std::move(preprocessingPatterns));
    if (failed(status)) {
      signalPassFailure();
    }
  }
};
} // namespace
} // namespace mlir::iree_compiler::DispatchCreation