Explicitly order transpose and reshape patterns when legalizing mhlo.dot_general

[Problem]:
1. RankReducedDotGeneral assumes that the batching dims come first and that
   the parallel dims are consecutive, which does not hold for every
   dot_general (see the example below).
2. The rank reduction also has to happen after the parallel dims have been
   collapsed.
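
For example, the dot_general from the new lit test added below has its
batching dim in a non-leading position and two lhs parallel dims, so it
exercises both problems:

  %0 = "mhlo.dot_general"(%arg0, %arg1) {
    dot_dimension_numbers = #mhlo.dot<
      lhs_batching_dimensions = [2],
      rhs_batching_dimensions = [1],
      lhs_contracting_dimensions = [3],
      rhs_contracting_dimensions = [2]
    >} : (tensor<64x155x4x36xf32>, tensor<309x4x36xf32>) -> tensor<4x64x155x309xf32>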

[Solution]:
Run the transposes of TransposeGenericDotGeneral first and only reduce the
rank after the parallel dims have been collapsed. The two patterns are merged
into a single TransposeReshapeGenericDotGeneral pattern.
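
With the merged pattern the dot_general above lowers roughly as follows (a
sketch that mirrors the CHECK lines of the new lit test; the SSA names are
illustrative):

  %lhs_t = "mhlo.transpose"(%arg0) {permutation = dense<[2, 0, 1, 3]> : tensor<4xi64>}
      : (tensor<64x155x4x36xf32>) -> tensor<4x64x155x36xf32>
  %rhs_t = "mhlo.transpose"(%arg1) {permutation = dense<[1, 2, 0]> : tensor<3xi64>}
      : (tensor<309x4x36xf32>) -> tensor<4x36x309xf32>
  %lhs_r = "mhlo.reshape"(%lhs_t) : (tensor<4x64x155x36xf32>) -> tensor<4x9920x36xf32>
  %dot = "mhlo.dot_general"(%lhs_r, %rhs_t) {
    dot_dimension_numbers = #mhlo.dot<
      lhs_batching_dimensions = [0],
      rhs_batching_dimensions = [0],
      lhs_contracting_dimensions = [2],
      rhs_contracting_dimensions = [1]
    >} : (tensor<4x9920x36xf32>, tensor<4x36x309xf32>) -> tensor<4x9920x309xf32>
  %result = "mhlo.reshape"(%dot) : (tensor<4x9920x309xf32>) -> tensor<4x64x155x309xf32>

The rhs needs no reshape here because it is already rank 3 after its transpose.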

Fixes https://github.com/google/iree/issues/7272
diff --git a/iree/compiler/InputConversion/MHLO/MHLOToMHLOPreprocessing.cpp b/iree/compiler/InputConversion/MHLO/MHLOToMHLOPreprocessing.cpp
index ca32815..abf39f2 100644
--- a/iree/compiler/InputConversion/MHLO/MHLOToMHLOPreprocessing.cpp
+++ b/iree/compiler/InputConversion/MHLO/MHLOToMHLOPreprocessing.cpp
@@ -493,11 +493,16 @@
 // inserts transposes so the dot_general always has the form:
 // {batch_dims, parallel_dims, contraction_dims}.
 //   {batch_dims, contraction_dims, parallel_dims}
-class TransposeGenericDotGeneral : public OpRewritePattern<mhlo::DotGeneralOp> {
+// After those transposes, the batch, parallel, and contraction dims are each
+// consecutive and do not split the domain. This pattern then inserts reshapes
+// to collapse the consecutive parallel and contraction dims so that it always
+// generates a rank-3 dot_general op.
+class TransposeReshapeGenericDotGeneral
+    : public OpRewritePattern<mhlo::DotGeneralOp> {
  public:
   using OpRewritePattern<mhlo::DotGeneralOp>::OpRewritePattern;
 
-  Value TransposeIfNonConsecutive(OpBuilder b, Location loc, Value src,
+  Value TransposeIfNonConsecutive(OpBuilder &b, Location loc, Value src,
                                   ArrayRef<int64_t> targetOrder) const {
     if (isConsecutive(targetOrder)) return src;
     auto type = src.getType().cast<RankedTensorType>();
@@ -510,6 +515,23 @@
         b.getI64TensorAttr(targetOrder));
   }
 
+  Value ReshapeIfMorethan3D(OpBuilder &b, Location loc, Value src,
+                            size_t dimsBorder0, size_t dimsBorder1) const {
+    auto type = src.getType().cast<RankedTensorType>();
+    if (type.getRank() <= 3) return src;
+    auto shape = type.getShape();
+    SmallVector<int64_t, 4> result_shape = {
+        std::accumulate(shape.begin(), shape.begin() + dimsBorder0, 1,
+                        std::multiplies<int64_t>()),
+        std::accumulate(shape.begin() + dimsBorder0,
+                        shape.begin() + dimsBorder1, 1,
+                        std::multiplies<int64_t>()),
+        std::accumulate(shape.begin() + dimsBorder1, shape.end(), 1,
+                        std::multiplies<int64_t>())};
+    return b.create<mhlo::ReshapeOp>(
+        loc, RankedTensorType::get(result_shape, type.getElementType()), src);
+  }
+
   LogicalResult matchAndRewrite(mhlo::DotGeneralOp op,
                                 PatternRewriter &rewriter) const override {
     auto lhsShapeType = op.lhs().getType().dyn_cast<RankedTensorType>();
@@ -559,167 +581,53 @@
                                           lhsTargetOrder);
     Value rhs = TransposeIfNonConsecutive(rewriter, op.getLoc(), op.rhs(),
                                           rhsTargetOrder);
-    if (lhs == op.lhs() && rhs == op.rhs()) return failure();
 
+    // After the transposes above the lhs dims are in {batch_dims,
+    // parallel_dims, contraction_dims} order and the rhs dims in {batch_dims,
+    // contraction_dims, parallel_dims} order; the logic below relies on this.
+    // TODO(#7443): these assumptions may not hold once transpose performance
+    // is taken into account.
     int64_t numLhsContractionDims = lhsContractingDims.size();
     int64_t lhsContractionBase = lhsShapeType.getRank() - numLhsContractionDims;
     int64_t rhsContractionBase = rhsBatchingDims.size();
     int64_t numRhsContractionDims =
         rhsContractionBase + rhsContractingDims.size();
-    auto lhsBatchingDimsAttr =
-        llvm::to_vector<4>(llvm::seq<int64_t>(0, lhsBatchingDims.size()));
-    auto rhsBatchingDimsAttr =
-        llvm::to_vector<4>(llvm::seq<int64_t>(0, rhsBatchingDims.size()));
-    auto lhsContractingDimsAttr = llvm::to_vector<4>(
-        llvm::seq<int64_t>(lhsContractionBase, lhsShapeType.getRank()));
-    auto rhsContractingDimsAttr = llvm::to_vector<4>(
-        llvm::seq<int64_t>(rhsContractionBase, numRhsContractionDims));
+
+    lhs = ReshapeIfMorethan3D(rewriter, op.getLoc(), lhs,
+                              rhsBatchingDims.size(), lhsContractionBase);
+    rhs = ReshapeIfMorethan3D(rewriter, op.getLoc(), rhs,
+                              rhsBatchingDims.size(), numRhsContractionDims);
+
+    if (lhs == op.lhs() && rhs == op.rhs()) return failure();
+
     auto dimensionNumbers = mhlo::DotDimensionNumbersAttr::get(
-        rewriter.getContext(), lhsBatchingDimsAttr, rhsBatchingDimsAttr,
-        lhsContractingDimsAttr, rhsContractingDimsAttr);
+        rewriter.getContext(), /*lhsBatchingDimensions=*/0,
+        /*rhsBatchingDimensions=*/0,
+        /*lhsContractingDimensions=*/2, /*rhsContractingDimensions=*/1);
+    auto lhsNewType = lhs.getType().cast<RankedTensorType>();
+    auto rhsNewType = rhs.getType().cast<RankedTensorType>();
+
+    // If the lhs or rhs shape was collapsed, the result must be reshaped back.
+    bool needReshapeResult = lhsNewType.getRank() < lhsShapeType.getRank() ||
+                             rhsNewType.getRank() < rhsShapeType.getRank();
+    // The result dims follow the convention {batch, lhs parallel, rhs parallel}.
+    SmallVector<int64_t, 4> newShape = {lhsNewType.getShape()[0],
+                                        lhsNewType.getShape()[1],
+                                        rhsNewType.getShape()[2]};
+    auto newResultType =
+        needReshapeResult
+            ? RankedTensorType::get(newShape, resultType.getElementType())
+            : op.getType();
 
     Value result = rewriter.create<mhlo::DotGeneralOp>(
-        op.getLoc(), op.getType(), lhs, rhs, dimensionNumbers,
-        op.precision_configAttr());
-    rewriter.replaceOp(op, result);
-    return success();
-  }
-};
-
-// Rewrite mhlo.dot_general to operate on rank-3 tensors when reduction dims are
-// in consecutive order and not spliting the domain. This pattern inserts
-// reshapes to collapse consecutive reduction and parallel dims to always
-// generate a rank-3 dot_general op.
-class RankReducedDotGeneral : public OpRewritePattern<mhlo::DotGeneralOp> {
- public:
-  using OpRewritePattern<mhlo::DotGeneralOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(mhlo::DotGeneralOp op,
-                                PatternRewriter &rewriter) const override {
-    auto lhsShapeType = op.lhs().getType().dyn_cast<ShapedType>();
-    auto rhsShapeType = op.rhs().getType().dyn_cast<ShapedType>();
-    auto resultType = op.getResult().getType().dyn_cast<ShapedType>();
-
-    if (!lhsShapeType || !rhsShapeType || !resultType) return failure();
-    if (!lhsShapeType.hasStaticShape() || !rhsShapeType.hasStaticShape())
-      return failure();
-    if (resultType.getRank() <= 3) return failure();
-
-    mhlo::DotDimensionNumbersAttr dimNumbers = op.dot_dimension_numbers();
-    auto lhsBatchingDims =
-        llvm::to_vector<4>(dimNumbers.getLhsBatchingDimensions());
-    auto rhsBatchingDims =
-        llvm::to_vector<4>(dimNumbers.getRhsBatchingDimensions());
-    auto lhsContractingDims =
-        llvm::to_vector<4>(dimNumbers.getLhsContractingDimensions());
-    auto rhsContractingDims =
-        llvm::to_vector<4>(dimNumbers.getRhsContractingDimensions());
-
-    if (lhsBatchingDims.empty() || rhsBatchingDims.empty()) return failure();
-
-    llvm::sort(lhsBatchingDims);
-    llvm::sort(lhsContractingDims);
-    llvm::sort(rhsBatchingDims);
-    llvm::sort(rhsContractingDims);
-
-    auto isDomainSplit = [](ArrayRef<int64_t> shape,
-                            ArrayRef<int64_t> batchingDims,
-                            ArrayRef<int64_t> contractingDims) {
-      // Batching and contracting are contiguous.
-      if ((contractingDims.front() - batchingDims.back()) == 1) return false;
-      // Contracting dims are inner most.
-      if (contractingDims.back() == (shape.size() - 1)) return false;
-      return true;
-    };
-
-    if (!isConsecutive(lhsBatchingDims) || !isConsecutive(lhsContractingDims) ||
-        !isConsecutive(rhsBatchingDims) || !isConsecutive(rhsContractingDims))
-      return failure();
-
-    if (isDomainSplit(lhsShapeType.getShape(), lhsBatchingDims,
-                      lhsContractingDims) ||
-        isDomainSplit(rhsShapeType.getShape(), rhsBatchingDims,
-                      rhsContractingDims))
-      return failure();
-
-    // Collapsing shape into a rank-3 tensor, returns newCollabsedShape
-    // contraction and parallel dim indices.
-    auto computeCollapsedShape = [](ArrayRef<int64_t> shape,
-                                    ArrayRef<int64_t> batchingDims,
-                                    ArrayRef<int64_t> contractingDims) {
-      auto newRank =
-          shape.size() - batchingDims.size() - contractingDims.size() + 2;
-      auto batchingSize = std::accumulate(
-          batchingDims.begin(), batchingDims.end(), 1,
-          [shape](const int64_t accum, const int64_t index) -> int64_t {
-            return accum * shape[index];
-          });
-      auto contractingSize = std::accumulate(
-          contractingDims.begin(), contractingDims.end(), 1,
-          [shape](const int64_t accum, const int64_t index) -> int64_t {
-            return accum * shape[index];
-          });
-
-      int parallelDimIndex, contractingDimIndex, parallelDimSize = 1;
-      if (contractingDims.front() - batchingDims.back() > 1) {
-        parallelDimIndex = 1;
-        contractingDimIndex = 2;
-        for (int i = batchingDims.back() + 1; i < contractingDims.front();
-             ++i) {
-          parallelDimSize *= shape[i];
-        }
-      } else {
-        contractingDimIndex = 1;
-        parallelDimIndex = 2;
-        for (int i = contractingDims.back() + 1; i < shape.size(); ++i) {
-          parallelDimSize *= shape[i];
-        }
-      }
-      llvm::SmallVector<int64_t, 4> newShape(newRank);
-      newShape[0] = batchingSize;
-      newShape[contractingDimIndex] = contractingSize;
-      newShape[parallelDimIndex] = parallelDimSize;
-      return std::make_tuple(newShape, contractingDimIndex, parallelDimIndex);
-    };
-
-    int lhsContractingDimIndex, rhsContractingDimIndex, lhsParallelDimIndex,
-        rhsParallelDimIndex;
-    SmallVector<int64_t, 4> lhsNewShape, rhsNewShape;
-    std::tie(lhsNewShape, lhsContractingDimIndex, lhsParallelDimIndex) =
-        computeCollapsedShape(lhsShapeType.getShape(), lhsBatchingDims,
-                              lhsContractingDims);
-
-    std::tie(rhsNewShape, rhsContractingDimIndex, rhsParallelDimIndex) =
-        computeCollapsedShape(rhsShapeType.getShape(), rhsBatchingDims,
-                              rhsContractingDims);
-    SmallVector<int64_t, 4> resultNewShape = {lhsNewShape[0],
-                                              lhsNewShape[lhsParallelDimIndex],
-                                              rhsNewShape[rhsParallelDimIndex]};
-    Type dotGeneralResultType =
-        RankedTensorType::get(resultNewShape, resultType.getElementType());
-
-    auto loc = op.getLoc();
-    Value reshapedLhs = rewriter.create<mhlo::ReshapeOp>(
-        loc, RankedTensorType::get(lhsNewShape, lhsShapeType.getElementType()),
-        op.lhs());
-    Value reshapedRhs = rewriter.create<mhlo::ReshapeOp>(
-        loc, RankedTensorType::get(rhsNewShape, rhsShapeType.getElementType()),
-        op.rhs());
-    auto dimensionNumbers = mhlo::DotDimensionNumbersAttr::get(
-        rewriter.getContext(),
-        /*lhs_batching_dimensions=*/{0},
-        /*rhs_batching_dimensions=*/{0},
-        /*lhs_contracting_dimensions=*/{lhsContractingDimIndex},
-        /*rhs_contracting_dimensions=*/
-        {rhsContractingDimIndex});
-    Value dotGeneralResult = rewriter.create<mhlo::DotGeneralOp>(
-        loc, dotGeneralResultType, reshapedLhs, reshapedRhs, dimensionNumbers,
+        op.getLoc(), newResultType, lhs, rhs, dimensionNumbers,
         op.precision_configAttr());
 
-    Value result =
-        rewriter.create<mhlo::ReshapeOp>(loc, resultType, dotGeneralResult);
+    if (needReshapeResult) {
+      result =
+          rewriter.create<mhlo::ReshapeOp>(op.getLoc(), resultType, result);
+    }
     rewriter.replaceOp(op, result);
-
     return success();
   }
 };
@@ -915,7 +823,7 @@
 
     // dot_general canoncalization patterns.
     mhlo::PopulateGeneralDotOpLoweringPatterns(&patterns, context);
-    patterns.insert<RankReducedDotGeneral, TransposeGenericDotGeneral>(context);
+    patterns.insert<TransposeReshapeGenericDotGeneral>(context);
 
     // Unary elementwise op.
     patterns.insert<
diff --git a/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing_canoncalize_dot_general.mlir b/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing_canoncalize_dot_general.mlir
index fbef414..ee8ddf6 100644
--- a/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing_canoncalize_dot_general.mlir
+++ b/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing_canoncalize_dot_general.mlir
@@ -56,10 +56,10 @@
   return %0 : tensor<1x8x32x32xf32>
 }
 // CHECK: dot_general_to_dot_general_rank_reduced_a_transposed(%[[ARG0:.+]]: tensor<1x8x64x32xf32>, %[[ARG1:.+]]: tensor<1x8x64x32xf32>) -> tensor<1x8x32x32xf32>
-// CHECK: %[[ARG0_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG0]]) : (tensor<1x8x64x32xf32>) -> tensor<8x64x32xf32>
+// CHECK: %[[ARG0_RESHAPED_TR:.+]] = "mhlo.transpose"(%[[ARG0]]) {permutation = dense<[0, 1, 3, 2]> : tensor<4xi64>} : (tensor<1x8x64x32xf32>) -> tensor<1x8x32x64xf32>
+// CHECK: %[[ARG0_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG0_RESHAPED_TR]]) : (tensor<1x8x32x64xf32>) -> tensor<8x32x64xf32>
 // CHECK: %[[ARG1_RSSHAPED:.+]] = "mhlo.reshape"(%[[ARG1]]) : (tensor<1x8x64x32xf32>) -> tensor<8x64x32xf32>
-// CHECK: %[[ARG0_RESHAPED_TR:.+]] = "mhlo.transpose"(%[[ARG0_RESHAPED]]) {permutation = dense<[0, 2, 1]> : tensor<3xi64>} : (tensor<8x64x32xf32>) -> tensor<8x32x64xf32>
-// CHECK: %[[DOT_RESULT:.+]] = "mhlo.dot_general"(%[[ARG0_RESHAPED_TR]], %[[ARG1_RSSHAPED]])
+// CHECK: %[[DOT_RESULT:.+]] = "mhlo.dot_general"(%[[ARG0_RESHAPED]], %[[ARG1_RSSHAPED]])
 // CHECK: %[[RESULT:.+]] = "mhlo.reshape"(%[[DOT_RESULT]]) : (tensor<8x32x32xf32>) -> tensor<1x8x32x32xf32>
 
 // -----
@@ -69,7 +69,7 @@
     dot_dimension_numbers = #mhlo.dot<
       lhs_batching_dimensions = [0, 1],
       lhs_contracting_dimensions = [3],
-      rhs_batching_dimensions = [0,1 ],
+      rhs_batching_dimensions = [0, 1],
       rhs_contracting_dimensions = [3],
     >,
     precision_config = ["DEFAULT", "DEFAULT"]
@@ -77,10 +77,10 @@
   return %0 : tensor<1x8x32x32xf32>
 }
 // CHECK: dot_general_to_dot_general_rank_reduced_b_transposed(%[[ARG0:.+]]: tensor<1x8x32x64xf32>, %[[ARG1:.+]]: tensor<1x8x32x64xf32>) -> tensor<1x8x32x32xf32>
+// CHECK: %[[ARG1_RESHAPED_TR:.+]] = "mhlo.transpose"(%[[ARG1]]) {permutation = dense<[0, 1, 3, 2]> : tensor<4xi64>} : (tensor<1x8x32x64xf32>) -> tensor<1x8x64x32xf32>
 // CHECK: %[[ARG0_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG0]]) : (tensor<1x8x32x64xf32>) -> tensor<8x32x64xf32>
-// CHECK: %[[ARG1_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG1]]) : (tensor<1x8x32x64xf32>) -> tensor<8x32x64xf32>
-// CHECK: %[[ARG1_RESHAPED_TR:.+]] = "mhlo.transpose"(%[[ARG1_RESHAPED]]) {permutation = dense<[0, 2, 1]> : tensor<3xi64>} : (tensor<8x32x64xf32>) -> tensor<8x64x32xf32>
-// CHECK: %[[DOT_RESULT:.+]] = "mhlo.dot_general"(%[[ARG0_RESHAPED]], %[[ARG1_RESHAPED_TR]])
+// CHECK: %[[ARG1_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG1_RESHAPED_TR]]) : (tensor<1x8x64x32xf32>) -> tensor<8x64x32xf32>
+// CHECK: %[[DOT_RESULT:.+]] = "mhlo.dot_general"(%[[ARG0_RESHAPED]], %[[ARG1_RESHAPED]])
 // CHECK: %[[RESULT:.+]] = "mhlo.reshape"(%[[DOT_RESULT]]) : (tensor<8x32x32xf32>) -> tensor<1x8x32x32xf32>
 // CHECK: return %[[RESULT]] : tensor<1x8x32x32xf32>
 
@@ -100,11 +100,11 @@
   return %0 : tensor<1x8x32x32xf32>
 }
 // CHECK: dot_general_to_dot_general_rank_reduced_ab_transposed(%[[ARG0:.+]]: tensor<1x8x64x32xf32>, %[[ARG1:.+]]: tensor<1x8x32x64xf32>) -> tensor<1x8x32x32xf32>
-// CHECK: %[[ARG0_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG0]]) : (tensor<1x8x64x32xf32>) -> tensor<8x64x32xf32>
-// CHECK: %[[ARG1_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG1]]) : (tensor<1x8x32x64xf32>) -> tensor<8x32x64xf32>
-// CHECK: %[[ARG0_RESHAPED_TR:.+]] = "mhlo.transpose"(%[[ARG0_RESHAPED]]) {permutation = dense<[0, 2, 1]> : tensor<3xi64>} : (tensor<8x64x32xf32>) -> tensor<8x32x64xf32>
-// CHECK: %[[ARG1_RESHAPED_TR:.+]] = "mhlo.transpose"(%[[ARG1_RESHAPED]]) {permutation = dense<[0, 2, 1]> : tensor<3xi64>} : (tensor<8x32x64xf32>) -> tensor<8x64x32xf32>
-// CHECK: %[[DOT_RESULT:.+]] = "mhlo.dot_general"(%[[ARG0_RESHAPED_TR]], %[[ARG1_RESHAPED_TR]])
+// CHECK: %[[ARG0_RESHAPED_TR:.+]] = "mhlo.transpose"(%[[ARG0]]) {permutation = dense<[0, 1, 3, 2]> : tensor<4xi64>} : (tensor<1x8x64x32xf32>) -> tensor<1x8x32x64xf32>
+// CHECK: %[[ARG1_RESHAPED_TR:.+]] = "mhlo.transpose"(%[[ARG1]]) {permutation = dense<[0, 1, 3, 2]> : tensor<4xi64>} : (tensor<1x8x32x64xf32>) -> tensor<1x8x64x32xf32>
+// CHECK: %[[ARG0_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG0_RESHAPED_TR]]) : (tensor<1x8x32x64xf32>) -> tensor<8x32x64xf32>
+// CHECK: %[[ARG1_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG1_RESHAPED_TR]]) : (tensor<1x8x64x32xf32>) -> tensor<8x64x32xf32>
+// CHECK: %[[DOT_RESULT:.+]] = "mhlo.dot_general"(%[[ARG0_RESHAPED]], %[[ARG1_RESHAPED]])
 // CHECK: %[[RESULT:.+]] = "mhlo.reshape"(%[[DOT_RESULT]]) : (tensor<8x32x32xf32>) -> tensor<1x8x32x32xf32>
 // CHECK: return %[[RESULT]] : tensor<1x8x32x32xf32>
 
@@ -135,3 +135,31 @@
 // CHECK:         %[[DOT_RESULT:.+]] = "mhlo.dot_general"(%[[ARG0_RESHAPED]], %[[ARG1_RESHAPED]])
 // CHECK:         %[[RESULT:.+]] = "mhlo.reshape"(%[[DOT_RESULT]]) : (tensor<8x1x512xf32>) -> tensor<1x8x1x512xf32>
 // CHECK:         return %[[RESULT]] : tensor<1x8x1x512xf32>
+
+// -----
+
+func @dot_general_1d_batching_1d_contracting(%arg0: tensor<64x155x4x36xf32>, %arg1: tensor<309x4x36xf32>) -> tensor<4x64x155x309xf32> {
+  %0 = "mhlo.dot_general"(%arg0, %arg1) {
+    dot_dimension_numbers = #mhlo.dot<
+      lhs_batching_dimensions = [2],
+      rhs_batching_dimensions = [1],
+      lhs_contracting_dimensions = [3],
+      rhs_contracting_dimensions = [2]
+    >} : (tensor<64x155x4x36xf32>, tensor<309x4x36xf32>) -> tensor<4x64x155x309xf32>
+  return %0 : tensor<4x64x155x309xf32>
+}
+
+// CHECK-LABEL: func @dot_general_1d_batching_1d_contracting
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]
+// CHECK: %[[ARG0_RESHAPED_TR:.+]] = "mhlo.transpose"(%[[ARG0]]) 
+// CHECK-SAME: {permutation = dense<[2, 0, 1, 3]> : tensor<4xi64>} 
+// CHECK-SAME: (tensor<64x155x4x36xf32>) -> tensor<4x64x155x36xf32>
+// CHECK: %[[ARG1_RESHAPED_TR:.+]] = "mhlo.transpose"(%[[ARG1]])
+// CHECK-SAME: {permutation = dense<[1, 2, 0]> : tensor<3xi64>}
+// CHECK-SAME: (tensor<309x4x36xf32>) -> tensor<4x36x309xf32>
+// CHECK: %[[ARG0_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG0_RESHAPED_TR]])
+// CHECK-SAME: (tensor<4x64x155x36xf32>) -> tensor<4x9920x36xf32>
+// CHECK: %[[DOT_RESULT:.+]] = "mhlo.dot_general"(%[[ARG0_RESHAPED]], %[[ARG1_RESHAPED_TR]])
+// CHECK-SAME: (tensor<4x9920x36xf32>, tensor<4x36x309xf32>) -> tensor<4x9920x309xf32>
+// CHECK: "mhlo.reshape"(%[[DOT_RESULT]]) : (tensor<4x9920x309xf32>) -> tensor<4x64x155x309xf32>
diff --git a/iree/test/e2e/xla_ops/dot_general.mlir b/iree/test/e2e/xla_ops/dot_general.mlir
index 48e106f..d56b98a 100644
--- a/iree/test/e2e/xla_ops/dot_general.mlir
+++ b/iree/test/e2e/xla_ops/dot_general.mlir
@@ -141,3 +141,33 @@
   check.expect_almost_eq_const(%res, dense<409.596> : tensor<4x32x64xf32>) : tensor<4x32x64xf32>
   return
 }
+
+func @dot_general_nontrivial_batching_multiple_parallel_dimensions() {
+  %lhs = util.unfoldable_constant dense<[
+    [[[0.0], [1.0]], [[2.0], [3.0]], [[ 4.0], [ 5.0]]], 
+    [[[6.0], [7.0]], [[8.0], [9.0]], [[10.0], [11.0]]]
+  ]> : tensor<2x3x2x1xf32>
+  %rhs = util.unfoldable_constant dense<[
+    [[0.0], [1.0]], [[2.0], [3.0]]
+  ]> : tensor<2x2x1xf32>
+  %res = "mhlo.dot_general"(%lhs, %rhs) {
+    dot_dimension_numbers = #mhlo.dot<
+      lhs_batching_dimensions = [2],
+      rhs_batching_dimensions = [1],
+      lhs_contracting_dimensions = [3],
+      rhs_contracting_dimensions = [2]
+    >,
+    precision_config = ["DEFAULT", "DEFAULT"]
+  } : (tensor<2x3x2x1xf32>, tensor<2x2x1xf32>) -> tensor<2x2x3x2xf32>
+  check.expect_almost_eq_const(%res, dense<[
+    [
+      [[0.0,  0.0], [0.0,  4.0], [0.0,  8.0]],
+      [[0.0, 12.0], [0.0, 16.0], [0.0, 20.0]]
+    ],
+    [
+      [[1.0,  3.0], [3.0,  9.0], [ 5.0, 15.0]],
+      [[7.0, 21.0], [9.0, 27.0], [11.0, 33.0]]
+    ]
+  ]> : tensor<2x2x3x2xf32>) : tensor<2x2x3x2xf32>
+  return
+}