Handle convolution with swapped input/output in kernel (#3496)
diff --git a/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp b/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp
index e6fe664..ada669a 100644
--- a/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp
@@ -38,6 +38,11 @@
llvm::cl::desc("Extract padding attributes from conv op"),
llvm::cl::init(true));
+static llvm::cl::opt<bool> orderConvFeatures(
+ "iree-flow-order-conv-features",
+ llvm::cl::desc("Guarantees input/output features ordered for conv kernel"),
+ llvm::cl::init(true));  // On by default; gates OrderConvFeatureDimensions.
+
static llvm::cl::opt<bool> conv1x1toDot(
"iree-flow-1x1-conv-to-dot",
llvm::cl::desc("Rewrites mhlo.conv with 1x1 filter into mhlo.dot"),
@@ -153,6 +158,63 @@
}
};
+struct OrderConvFeatureDimensions : public OpRewritePattern<mhlo::ConvOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  // Transposes the conv kernel so that its input feature dimension precedes
+  // its output feature dimension, and updates the dimension numbers to match.
+  LogicalResult matchAndRewrite(mhlo::ConvOp op,
+                                PatternRewriter &rewriter) const override {
+    auto dimNums = op.dimension_numbers();
+    auto kernelInDim = dimNums.kernel_input_feature_dimension().getInt();
+    auto kernelOutDim = dimNums.kernel_output_feature_dimension().getInt();
+
+    // Already ordered (input feature dimension not after output): no rewrite.
+    if (kernelInDim <= kernelOutDim) return failure();
+
+    auto kernel = op.rhs();
+    auto kernelType = kernel.getType().cast<ShapedType>();
+    if (!kernelType.hasRank()) return failure();
+
+    // Identity permutation with the two feature dimensions exchanged.
+    llvm::SmallVector<int64_t, 4> permutation(kernelType.getRank());
+    std::iota(permutation.begin(), permutation.end(), 0);
+    std::swap(permutation[kernelInDim], permutation[kernelOutDim]);
+
+    // The transposed kernel's shape has the same two extents exchanged.
+    auto kernelShape = kernelType.getShape();
+    llvm::SmallVector<int64_t, 4> swappedShape(kernelShape.begin(),
+                                               kernelShape.end());
+    std::swap(swappedShape[kernelInDim], swappedShape[kernelOutDim]);
+    auto swappedType =
+        RankedTensorType::get(swappedShape, kernelType.getElementType());
+    auto transposedKernel = rewriter.create<mhlo::TransposeOp>(
+        op.getLoc(), swappedType, kernel,
+        rewriter.getI64TensorAttr(permutation));
+
+    // Rebuild the dimension numbers with the kernel's output and input
+    // feature entries passed in each other's place; all others are reused.
+    auto swappedDimNums = mhlo::ConvDimensionNumbers::get(
+        dimNums.input_batch_dimension(), dimNums.input_feature_dimension(),
+        dimNums.input_spatial_dimensions(),
+        dimNums.kernel_output_feature_dimension(),
+        dimNums.kernel_input_feature_dimension(),
+        dimNums.kernel_spatial_dimensions(),
+        dimNums.output_batch_dimension(), dimNums.output_feature_dimension(),
+        dimNums.output_spatial_dimensions(), op.getContext());
+
+    // Recreate the conv on the transposed kernel with the original
+    // attributes, then overwrite only its dimension numbers.
+    SmallVector<Value, 2> newOperands = {op.lhs(), transposedKernel};
+    auto reorderedConv = rewriter.create<mhlo::ConvOp>(
+        op.getLoc(), op.getType(), newOperands, op.getAttrs());
+    reorderedConv.dimension_numbersAttr(swappedDimNums);
+
+    rewriter.replaceOp(op, {reorderedConv.getResult()});
+    return success();
+  }
+};
+
class ExtractReduceWindowOpPaddingAttributes
: public OpRewritePattern<mhlo::ReduceWindowOp> {
public:
@@ -719,6 +781,9 @@
if (extractPadFromConv) {
patterns.insert<ExtractConvOpPaddingAttributes>(context);
}
+ if (orderConvFeatures) {
+ patterns.insert<OrderConvFeatureDimensions>(context);
+ }
if (conv1x1toDot) {
patterns.insert<Lower1x1ConvolutionToDotOp>(context);
}
diff --git a/iree/test/e2e/xla_ops/convolution.mlir b/iree/test/e2e/xla_ops/convolution.mlir
index 6ac1719..72b9458 100644
--- a/iree/test/e2e/xla_ops/convolution.mlir
+++ b/iree/test/e2e/xla_ops/convolution.mlir
@@ -30,6 +30,38 @@
return
}
+// Conv2d without padding whose kernel lists the output feature dimension (2)
+// before the input feature dimension (3). With a 4x5 input and a 3x2 window at
+// stride 1, the result has (4-3+1) x (5-2+1) = 2x4 spatial positions.
+func @conv2d_nopadding_reorder_features() attributes { iree.module.export } {
+  %inputs = iree.unfoldable_constant dense<[[
+      [[ 1.0,  2.0], [ 3.0,  4.0], [ 5.0,  6.0], [ 7.0,  8.0], [ 9.0, 10.0]],
+      [[11.0, 12.0], [13.0, 14.0], [15.0, 16.0], [17.0, 18.0], [19.0, 20.0]],
+      [[21.0, 22.0], [23.0, 24.0], [25.0, 26.0], [27.0, 28.0], [29.0, 30.0]],
+      [[31.0, 32.0], [33.0, 34.0], [35.0, 36.0], [37.0, 38.0], [39.0, 40.0]]]]> : tensor<1x4x5x2xf32>
+  %weights = iree.unfoldable_constant dense<[
+      [[[ 1.0,  2.0]], [[ 3.0,  4.0]]],
+      [[[ 5.0,  6.0]], [[ 7.0,  8.0]]],
+      [[[ 9.0, 10.0]], [[11.0, 12.0]]]]> : tensor<3x2x1x2xf32>
+  %res = "mhlo.convolution"(%inputs, %weights) {
+       batch_group_count = 1 : i64,
+       dimension_numbers = {
+         input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64,
+         input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
+         kernel_input_feature_dimension = 3 : i64,
+         kernel_output_feature_dimension = 2 : i64,
+         kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
+         output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64,
+         output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>},
+       feature_group_count = 1 : i64,
+       rhs_dilation = dense<1> : tensor<2xi64>,
+       window_strides = dense<1> : tensor<2xi64>} : (tensor<1x4x5x2xf32>, tensor<3x2x1x2xf32>) -> tensor<1x2x4x1xf32>
+  check.expect_almost_eq_const(%res, dense<[[
+      [[1310.0],[1466.0],[1622.0],[1778.0]],
+      [[2090.0],[2246.0],[2402.0],[2558.0]]]]> : tensor<1x2x4x1xf32>) : tensor<1x2x4x1xf32>
+  return
+}
+
func @conv2d_1452x3221_same() attributes { iree.module.export } {
%inputs = iree.unfoldable_constant dense<[[
[[ 1.0, 2.0], [ 3.0, 4.0], [ 5.0, 6.0], [ 7.0, 8.0], [ 9.0, 10.0]],