Handle convolution with swapped input/output features in kernel (#3496)
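
The conv kernel lowering expects the kernel's input feature dimension to
precede its output feature dimension; kernels imported with the reverse
ordering (e.g. HWOI instead of HWIO) were not handled. This adds a
preprocessing pattern, OrderConvFeatureDimensions, behind the new
-iree-flow-order-conv-features flag (on by default), that transposes the
kernel and swaps its feature dimension numbers so downstream lowerings
always see the input features first.

As a rough sketch (abridged from the new e2e test below; SSA names are
illustrative), a conv whose kernel dimension numbers are
kernel_output_feature_dimension = 2 and kernel_input_feature_dimension = 3
on a tensor<3x2x1x2xf32> kernel is rewritten to approximately:

    %kernel = "mhlo.transpose"(%weights)
        {permutation = dense<[0, 1, 3, 2]> : tensor<4xi64>}
        : (tensor<3x2x1x2xf32>) -> tensor<3x2x2x1xf32>
    %res = "mhlo.convolution"(%inputs, %kernel) {
        ...,
        kernel_input_feature_dimension = 2 : i64,
        kernel_output_feature_dimension = 3 : i64,
        ...} : (tensor<1x4x5x2xf32>, tensor<3x2x2x1xf32>)
            -> tensor<1x2x4x1xf32>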

diff --git a/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp b/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp
index e6fe664..ada669a 100644
--- a/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp
@@ -38,6 +38,11 @@
     llvm::cl::desc("Extract padding attributes from conv op"),
     llvm::cl::init(true));
 
+static llvm::cl::opt<bool> orderConvFeatures(
+    "iree-flow-order-conv-features",
+    llvm::cl::desc("Orders conv kernel input features before output features"),
+    llvm::cl::init(true));
+
 static llvm::cl::opt<bool> conv1x1toDot(
     "iree-flow-1x1-conv-to-dot",
     llvm::cl::desc("Rewrites mhlo.conv with 1x1 filter into mhlo.dot"),
@@ -153,6 +158,67 @@
   }
 };
 
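+// Canonicalizes the conv kernel (rhs) so that the kernel input feature
+// dimension precedes the kernel output feature dimension, transposing the
+// kernel and swapping the dimension numbers when they arrive reversed.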
+struct OrderConvFeatureDimensions : public OpRewritePattern<mhlo::ConvOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(mhlo::ConvOp op,
+                                PatternRewriter &rewriter) const override {
+    auto dimensionNumbers = op.dimension_numbers();
+    auto inputFeatureDimension =
+        dimensionNumbers.kernel_input_feature_dimension().getInt();
+    auto outputFeatureDimension =
+        dimensionNumbers.kernel_output_feature_dimension().getInt();
+
+    // Already ordered: the input feature dimension precedes the output.
+    if (inputFeatureDimension <= outputFeatureDimension) {
+      return failure();
+    }
+
+    auto rhs = op.rhs();
+    auto rhsType = rhs.getType().cast<ShapedType>();
+    if (!rhsType.hasRank()) return failure();
+
+    // Build the permutation that swaps the two kernel feature dimensions.
+    llvm::SmallVector<int64_t, 4> permutation;
+    permutation.resize(rhsType.getRank());
+    std::iota(permutation.begin(), permutation.end(), 0);
+    permutation[inputFeatureDimension] = outputFeatureDimension;
+    permutation[outputFeatureDimension] = inputFeatureDimension;
+
+    llvm::SmallVector<int64_t, 4> newShape(rhsType.getShape().begin(),
+                                           rhsType.getShape().end());
+    std::swap(newShape[inputFeatureDimension],
+              newShape[outputFeatureDimension]);
+
+    auto newRhsTy =
+        RankedTensorType::get(newShape, rhsType.getElementType());
+    auto transposeRhs = rewriter.create<mhlo::TransposeOp>(
+        op.getLoc(), newRhsTy, rhs, rewriter.getI64TensorAttr(permutation));
+
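+    // Swap the kernel feature dimensions to match the transposed kernel.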
+    auto newDimensionNumbers = mhlo::ConvDimensionNumbers::get(
+        dimensionNumbers.input_batch_dimension(),
+        dimensionNumbers.input_feature_dimension(),
+        dimensionNumbers.input_spatial_dimensions(),
+        dimensionNumbers.kernel_output_feature_dimension(),
+        dimensionNumbers.kernel_input_feature_dimension(),
+        dimensionNumbers.kernel_spatial_dimensions(),
+        dimensionNumbers.output_batch_dimension(),
+        dimensionNumbers.output_feature_dimension(),
+        dimensionNumbers.output_spatial_dimensions(), op.getContext());
+
+    SmallVector<Value, 2> operands = {op.lhs(), transposeRhs};
+    mhlo::ConvOp newConv = rewriter.create<mhlo::ConvOp>(
+        op.getLoc(), op.getType(), operands, op.getAttrs());
+    newConv.dimension_numbersAttr(newDimensionNumbers);
+
+    rewriter.replaceOp(op, {newConv.getResult()});
+    return success();
+  }
+};
+
 class ExtractReduceWindowOpPaddingAttributes
     : public OpRewritePattern<mhlo::ReduceWindowOp> {
  public:
@@ -719,6 +785,9 @@
     if (extractPadFromConv) {
       patterns.insert<ExtractConvOpPaddingAttributes>(context);
     }
+    if (orderConvFeatures) {
+      patterns.insert<OrderConvFeatureDimensions>(context);
+    }
     if (conv1x1toDot) {
       patterns.insert<Lower1x1ConvolutionToDotOp>(context);
     }
diff --git a/iree/test/e2e/xla_ops/convolution.mlir b/iree/test/e2e/xla_ops/convolution.mlir
index 6ac1719..72b9458 100644
--- a/iree/test/e2e/xla_ops/convolution.mlir
+++ b/iree/test/e2e/xla_ops/convolution.mlir
@@ -30,6 +30,38 @@
   return
 }
 
+func @conv2d_nopadding_reorder_features() attributes { iree.module.export } {
+  %inputs = iree.unfoldable_constant dense<[[
+      [[ 1.0,  2.0], [ 3.0,  4.0], [ 5.0,  6.0], [ 7.0,  8.0], [ 9.0, 10.0]],
+      [[11.0, 12.0], [13.0, 14.0], [15.0, 16.0], [17.0, 18.0], [19.0, 20.0]],
+      [[21.0, 22.0], [23.0, 24.0], [25.0, 26.0], [27.0, 28.0], [29.0, 30.0]],
+      [[31.0, 32.0], [33.0, 34.0], [35.0, 36.0], [37.0, 38.0], [39.0, 40.0]]]]> : tensor<1x4x5x2xf32>
+  %weights = iree.unfoldable_constant dense<[
+      [[[ 1.0,  2.0]], [[ 3.0,  4.0]]],
+      [[[ 5.0,  6.0]], [[ 7.0,  8.0]]],
+      [[[ 9.0, 10.0]], [[11.0, 12.0]]]]> : tensor<3x2x1x2xf32>
+  %res = "mhlo.convolution"(%inputs, %weights) {
+        batch_group_count = 1 : i64,
+        dimension_numbers = {
+          input_batch_dimension = 0 : i64,
+          input_feature_dimension = 3 : i64,
+          input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
+          kernel_input_feature_dimension = 3 : i64,
+          kernel_output_feature_dimension = 2 : i64,
+          kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
+          output_batch_dimension = 0 : i64,
+          output_feature_dimension = 3 : i64,
+          output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>},
+        feature_group_count = 1 : i64,
+        rhs_dilation = dense<1> : tensor<2xi64>,
+        window_strides = dense<1> : tensor<2xi64>} : (tensor<1x4x5x2xf32>, tensor<3x2x1x2xf32>) -> tensor<1x2x4x1xf32>
+  check.expect_almost_eq_const(%res, dense<[[
+      [[1310.0],[1466.0],[1622.0],[1778.0]],
+      [[2090.0],[2246.0],[2402.0],[2558.0]]
+  ]]> : tensor<1x2x4x1xf32>) : tensor<1x2x4x1xf32>
+  return
+}
+
 func @conv2d_1452x3221_same() attributes { iree.module.export } {
   %inputs = iree.unfoldable_constant dense<[[
       [[ 1.0,  2.0], [ 3.0,  4.0], [ 5.0,  6.0], [ 7.0,  8.0], [ 9.0, 10.0]],