// Copyright 2025 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/TensorExt/IR/TensorExtOps.h"
#include "iree/compiler/Preprocessing/Common/Passes.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/LogicalResult.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#define DEBUG_TYPE "iree-preprocessing-convert-conv-filter-to-channels-last"
namespace mlir::iree_compiler::Preprocessing {
#define GEN_PASS_DEF_CONVERTCONVFILTERTOCHANNELSLASTPASS
#include "iree/compiler/Preprocessing/Common/Passes.h.inc"
// Returns a copy of `map` with its results reordered by `perm`: result i of
// the returned map is result perm[i] of the input map. Dims and symbols are
// left untouched.
static AffineMap applyPermutationToResults(AffineMap map,
                                           ArrayRef<int64_t> perm) {
  SmallVector<AffineExpr> permutedExprs;
  permutedExprs.reserve(perm.size());
  for (int64_t resultIdx : perm) {
    permutedExprs.push_back(map.getResult(resultIdx));
  }
  return AffineMap::get(map.getNumDims(), map.getNumSymbols(), permutedExprs,
                        map.getContext());
}
// Materializes a `linalg.transpose` of `tensor` with permutation `perm`,
// using a freshly created `tensor.empty` (with permuted sizes) as the init
// operand. Returns the transposed tensor value.
static Value createTransposeOp(RewriterBase &rewriter, Location loc,
                               Value tensor, ArrayRef<int64_t> perm) {
  auto sourceType = cast<RankedTensorType>(tensor.getType());
  SmallVector<OpFoldResult> transposedSizes =
      tensor::getMixedSizes(rewriter, loc, tensor);
  applyPermutationToVector(transposedSizes, perm);
  Value init = tensor::EmptyOp::create(rewriter, loc, transposedSizes,
                                       sourceType.getElementType());
  auto transposeOp =
      linalg::TransposeOp::create(rewriter, loc, tensor, init, perm);
  return transposeOp.getResult()[0];
}
// Rewrites `convOp` into an equivalent `linalg.generic` whose filter operand
// has been transposed by `perm`; the filter indexing map is permuted to
// match, so the computed result is unchanged. The original payload region is
// moved into the new op.
static LogicalResult
convertConvFilterToTargetLayout(linalg::Conv2DNhwcHwcfOp convOp,
                                RewriterBase &rewriter,
                                SmallVector<int64_t> &perm) {
  Location loc = convOp.getLoc();
  SmallVector<AffineMap> indexingMaps = convOp.getIndexingMapsArray();
  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];
  // Transpose the filter tensor and permute its map consistently.
  AffineMap transposedFilterMap =
      applyPermutationToResults(indexingMaps[1], perm);
  Value transposedFilter = createTransposeOp(rewriter, loc, filter, perm);
  SmallVector<utils::IteratorType> iterators = convOp.getIteratorTypesArray();
  auto genericOp = linalg::GenericOp::create(
      rewriter, loc, output.getType(), ValueRange{input, transposedFilter},
      output,
      ArrayRef<AffineMap>{indexingMaps[0], transposedFilterMap,
                          indexingMaps[2]},
      iterators);
  // Reuse the same payload as the original convolution op.
  rewriter.inlineRegionBefore(convOp->getRegion(0), genericOp.getRegion(),
                              genericOp.getRegion().begin());
  rewriter.replaceOp(convOp, genericOp->getResults());
  return success();
}
namespace {
// Converts the filter of a `linalg.conv_2d_nhwc_hwcf` from HWCF to HWFC
// layout by swapping the last two filter dimensions.
struct ConvertHwcfToHwfc : public OpRewritePattern<linalg::Conv2DNhwcHwcfOp> {
  using Base::Base;
  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
                                PatternRewriter &rewriter) const override {
    // hwcf -> hwfc: exchange the input-channel and output-channel dims.
    SmallVector<int64_t> filterPerm = {0, 1, 3, 2};
    return convertConvFilterToTargetLayout(convOp, rewriter, filterPerm);
  }
};
// Converts the filter of a `linalg.conv_2d_nhwc_hwcf` from HWCF to FHWC
// layout by rotating the output-channel dimension to the front.
struct ConvertHwcfToFhwc : public OpRewritePattern<linalg::Conv2DNhwcHwcfOp> {
  using Base::Base;
  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
                                PatternRewriter &rewriter) const override {
    // hwcf -> fhwc: move the output-channel dim ahead of the window dims.
    SmallVector<int64_t> filterPerm = {3, 0, 1, 2};
    return convertConvFilterToTargetLayout(convOp, rewriter, filterPerm);
  }
};
// Rewrites a generic convolution whose filter is laid out channels-first
// (`CHWF`-like: input channel ahead of the window dims, output channel last)
// to consume a channels-last (`FHWC`-like) filter instead: the input- and
// output-channel filter dims are swapped via an explicit transpose, a
// compute barrier keeps the transpose from being fused away, and the loops
// are interchanged so the input-channel reduction follows the window loops.
//
// Fixes vs. previous version: `cast<>` asserts instead of returning null, so
// the old `!linalgOp` check was dead; and the `.back()` calls on the
// dim-position vectors were UB when `inferConvolutionDims` returns empty dim
// lists (e.g. depthwise convs have no input-channel dim) — such cases now
// conservatively fail to match.
struct ConvertGenericChwfToFhwc : public OpRewritePattern<linalg::GenericOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // linalg.generic always implements LinalgOp; only the convolution
    // interface check is meaningful here.
    auto linalgOp = cast<linalg::LinalgOp>(op.getOperation());
    if (!linalg::isaConvolutionOpInterface(linalgOp)) {
      return failure();
    }
    FailureOr<mlir::linalg::ConvolutionDimensions> convolutionDims =
        mlir::linalg::inferConvolutionDims(linalgOp);
    if (failed(convolutionDims)) {
      return failure();
    }
    OpOperand *input = linalgOp.getDpsInputOperand(0);
    OpOperand *filter = linalgOp.getDpsInputOperand(1);
    OpOperand *output = linalgOp.getDpsInitOperand(0);
    AffineMap inputMap = linalgOp.getMatchingIndexingMap(input);
    AffineMap filterMap = linalgOp.getMatchingIndexingMap(filter);
    AffineMap outputMap = linalgOp.getMatchingIndexingMap(output);
    Value inputVal = input->get();
    Value filterVal = filter->get();
    Value outputVal = output->get();
    ArrayRef<int64_t> inputShape =
        cast<ShapedType>(inputVal.getType()).getShape();
    ArrayRef<int64_t> filterShape =
        cast<ShapedType>(filterVal.getType()).getShape();
    ArrayRef<int64_t> outputShape =
        cast<ShapedType>(outputVal.getType()).getShape();
    // TODO(vivian): Once the matmul shape check below is dropped, the
    // dynamic-shape check can also be removed.
    if (ShapedType::isDynamicShape(inputShape) ||
        ShapedType::isDynamicShape(filterShape) ||
        ShapedType::isDynamicShape(outputShape)) {
      return failure();
    }
    // Returns the positions of `map`'s results that use any of the loop
    // dims in `dims` (i.e. loop dims -> tensor dimension positions).
    auto getDimPositions = [&](ArrayRef<unsigned> dims, const AffineMap &map) {
      SmallVector<int64_t> positions;
      for (auto dim : dims) {
        for (auto [idx, e] : llvm::enumerate(map.getResults())) {
          if (e.isFunctionOfDim(dim)) {
            positions.push_back(idx);
          }
        }
      }
      return positions;
    };
    // Only transpose when the input channel is the last dimension of conv
    // input. Bail out when there is no input-channel dim at all (e.g.
    // depthwise convs) to avoid calling back() on an empty vector.
    SmallVector<int64_t> cInputPos =
        getDimPositions(convolutionDims->inputChannel, inputMap);
    if (cInputPos.empty() || cInputPos.back() != inputShape.size() - 1) {
      return failure();
    }
    // Only transpose when the filter is `CHWF` layout.
    SmallVector<int64_t> fFilterPos =
        getDimPositions(convolutionDims->outputChannel, filterMap);
    SmallVector<int64_t> cFilterPos =
        getDimPositions(convolutionDims->inputChannel, filterMap);
    SmallVector<int64_t> kFilterPos =
        getDimPositions(convolutionDims->filterLoop, filterMap);
    // Guard against empty dim lists before taking back().
    if (fFilterPos.empty() || cFilterPos.empty() || kFilterPos.empty()) {
      return failure();
    }
    int64_t fPos = fFilterPos.back();
    int64_t cPos = cFilterPos.back();
    int64_t kPos = kFilterPos.back();
    if (cPos > kPos || fPos != filterShape.size() - 1) {
      return failure();
    }
    // Don't transpose if it is a matmul and the input shape is small.
    // TODO(vivian): Solve the fusion of transpose op and remove this check.
    SmallVector<int64_t> imagePos =
        getDimPositions(convolutionDims->outputImage, outputMap);
    SmallVector<int64_t> batchPos =
        getDimPositions(convolutionDims->batch, outputMap);
    SmallVector<int64_t> mPos = imagePos;
    mPos.append(batchPos.begin(), batchPos.end());
    // Product of the shape extents at the given result positions.
    auto getProduct = [](ArrayRef<int64_t> shape, ArrayRef<int64_t> pos) {
      return llvm::accumulate(pos, int64_t{1}, [&](int64_t a, int64_t idx) {
        return a * shape[idx];
      });
    };
    int64_t mSize = getProduct(outputShape, mPos);
    int64_t nSize = getProduct(filterShape, fFilterPos);
    int64_t kSize = getProduct(filterShape, cFilterPos);
    // filterProd == 1 means a 1x1 window, i.e. the conv degenerates to a
    // matmul; skip small matmuls where the transpose isn't worth it.
    int64_t filterProd = getProduct(filterShape, kFilterPos);
    bool smallShape = mSize < 384 || nSize < 384 || kSize < 384;
    if (filterProd == 1 && smallShape) {
      return failure();
    }
    // Swap the input and output channel dimension.
    SmallVector<int64_t> perm =
        llvm::to_vector(llvm::seq<int64_t>(0, filterShape.size()));
    std::swap(perm[cPos], perm[fPos]);
    Location loc = linalgOp.getLoc();
    AffineMap transposedFilterMap = applyPermutationToResults(filterMap, perm);
    Value transposedFilter = createTransposeOp(rewriter, loc, filterVal, perm);
    // Insert compute_barrier.start to avoid propagation of reshape ops and
    // undesirable fusion.
    auto barrierStartOp = IREE::TensorExt::ComputeBarrierStartOp::create(
        rewriter, loc, transposedFilter);
    SmallVector<utils::IteratorType> iterators =
        linalgOp.getIteratorTypesArray();
    auto genericOp = linalg::GenericOp::create(
        rewriter, loc, outputVal.getType(),
        ValueRange{inputVal, barrierStartOp.getResult()}, outputVal,
        ArrayRef<AffineMap>{inputMap, transposedFilterMap, outputMap},
        iterators);
    // Reuse the same payload as the original convolution op.
    rewriter.inlineRegionBefore(linalgOp->getRegion(0), genericOp.getRegion(),
                                genericOp.getRegion().begin());
    // Reorder the indexing dimensions so that the input channel loops appears
    // after the filter loops.
    unsigned numParallelLoop = genericOp.getNumParallelLoops();
    SmallVector<unsigned> interchange =
        llvm::to_vector(llvm::seq<unsigned>(0, numParallelLoop));
    interchange.append(convolutionDims->filterLoop.begin(),
                       convolutionDims->filterLoop.end());
    interchange.append(convolutionDims->inputChannel.begin(),
                       convolutionDims->inputChannel.end());
    FailureOr<linalg::GenericOp> reorderOp =
        linalg::interchangeGenericOp(rewriter, genericOp, interchange);
    if (failed(reorderOp))
      return failure();
    rewriter.replaceOp(linalgOp, reorderOp->getResults());
    return success();
  }
};
// Pass driver: selects the pattern set matching the `filter-layout` option
// ("hwfc" or "fhwc") and applies it greedily; any other layout string is a
// pass failure.
class ConvertConvFilterToChannelsLastPass
    : public iree_compiler::Preprocessing::impl::
          ConvertConvFilterToChannelsLastPassBase<
              ConvertConvFilterToChannelsLastPass> {
public:
  using iree_compiler::Preprocessing::impl::
      ConvertConvFilterToChannelsLastPassBase<
          ConvertConvFilterToChannelsLastPass>::
          ConvertConvFilterToChannelsLastPassBase;

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);
    if (filterLayout == "hwfc") {
      LDBG() << "Converting filter layout to hwfc.";
      patterns.add<ConvertHwcfToHwfc>(context);
    } else if (filterLayout == "fhwc") {
      LDBG() << "Converting filter layout to fhwc.";
      patterns.add<ConvertHwcfToFhwc, ConvertGenericChwfToFhwc>(context);
    } else {
      LDBG() << "convert-filter-to-channels-last pass didn't apply since an "
                "unsupported layout is given. Please use hwfc or fhwc as pass "
                "filter-layout option.";
      return signalPassFailure();
    }
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
      signalPassFailure();
    }
  }
};
} // namespace
} // namespace mlir::iree_compiler::Preprocessing