// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace Flow {

namespace {

// Converts a linalg.conv_2d_nhwc_hwcf op with a 1x1 filter window into a
// linalg.matmul op.
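//
// A sketch of the rewrite (illustrative only; the exact linalg op syntax may
// differ across MLIR versions): for an input of type tensor<1xHxWxCxf32> and
// a filter of type tensor<1x1xCxFxf32>, the pattern produces roughly
//   linalg.tensor_collapse_shape %input  -> tensor<(H*W)xCxf32>
//   linalg.tensor_collapse_shape %filter -> tensor<CxFxf32>
//   linalg.matmul : (H*W x C) * (C x F) -> (H*W x F)
//   linalg.tensor_expand_shape %result  -> tensor<1xHxWxFxf32>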
class Convert1x1ConvolutionMatmulOp
: public OpRewritePattern<linalg::Conv2DNhwcHwcfOp> {
public:
  using OpRewritePattern<linalg::Conv2DNhwcHwcfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
PatternRewriter &rewriter) const override {
ShapedType inputShapeType =
convOp.getInputOperand(0)->get().getType().cast<ShapedType>();
ShapedType filterShapeType =
convOp.getInputOperand(1)->get().getType().cast<ShapedType>();
ShapedType outputShapeType =
convOp.getOutputOperand(0)->get().getType().cast<ShapedType>();
auto inputShape = inputShapeType.getShape();
auto filterShape = filterShapeType.getShape();
auto outputShape = outputShapeType.getShape();
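
    // Only 1x1 filter windows are rewritten; a 1x1 convolution (with unit
    // stride and dilation, checked below) is exactly a matmul over the
    // flattened spatial dimensions.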
if (filterShape[0] != 1 || filterShape[1] != 1) return failure();
// TODO(ataei): Support conversion to linalg.batch_matmul.
if (inputShape[0] != 1) return failure();
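
    // Only unit strides and dilations are supported.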
if (!llvm::all_of(convOp.strides(), [](APInt element) {
return element.getSExtValue() == 1;
}))
return failure();
if (!llvm::all_of(convOp.dilations(), [](APInt element) {
return element.getSExtValue() == 1;
}))
return failure();
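
    // Group the first three dimensions together: N, H and W for the NHWC
    // input and output tensors, and the unit H and W together with C for the
    // HWCF filter.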
SmallVector<ReassociationIndices, 4> reassociationIndices = {{0, 1, 2},
{3}};
auto reshapedInputType =
RankedTensorType::get({inputShape[1] * inputShape[2], inputShape[3]},
inputShapeType.getElementType());
auto reshapedFilterType = RankedTensorType::get(
{filterShape[2], filterShape[3]}, filterShapeType.getElementType());
auto reshapedOutputType =
RankedTensorType::get({outputShape[1] * outputShape[2], outputShape[3]},
outputShapeType.getElementType());
Value input = convOp.getInputOperand(0)->get();
Value filter = convOp.getInputOperand(1)->get();
Value output = convOp.getOutputOperand(0)->get();
auto loc = convOp.getLoc();
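
    // Reshape the operands to 2-D tensors, perform the matmul, and expand the
    // result back to the original 4-D output shape.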
Value reshapedInput = rewriter.create<linalg::TensorCollapseShapeOp>(
loc, reshapedInputType, input, reassociationIndices);
Value reshapedFilter = rewriter.create<linalg::TensorCollapseShapeOp>(
loc, reshapedFilterType, filter, reassociationIndices);
Value reshapedOutput = rewriter.create<linalg::TensorCollapseShapeOp>(
loc, reshapedOutputType, output, reassociationIndices);
auto matmulResult = rewriter.create<linalg::MatmulOp>(
loc, reshapedOutputType, ArrayRef<Value>{reshapedInput, reshapedFilter},
ArrayRef<Value>{reshapedOutput});
auto reshapedResult = rewriter.create<linalg::TensorExpandShapeOp>(
loc, outputShapeType, matmulResult.getResults()[0],
reassociationIndices);
rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});
return success();
}
};
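
// Pass that greedily applies the 1x1-convolution-to-matmul pattern to a
// function.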
struct ConvertConv2D1x1ConvToMatmulPass
: public ConvertConv2D1x1ConvToMatmulBase<
ConvertConv2D1x1ConvToMatmulPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<linalg::LinalgDialect>();
}
void runOnOperation() override {
MLIRContext *context = &getContext();
    OwningRewritePatternList patterns(context);
patterns.insert<Convert1x1ConvolutionMatmulOp>(context);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
} // namespace

std::unique_ptr<OperationPass<mlir::FuncOp>>
createConvertConv2D1x1ToMatmulPass() {
return std::make_unique<ConvertConv2D1x1ConvToMatmulPass>();
}
} // namespace Flow
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir