Canonicalizes mhlo.dot_general to a rank-3 mhlo.dot_general or mhlo.dot (#3225)
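
Adds two preprocessing patterns: TransposeRank3GenericDotGeneral, which
transposes rank-3 dot_general operands into the canonical
{batch, parallel, contraction} x {batch, contraction, parallel} form, and
RankReducedDotGeneral, which reshapes higher-rank dot_general ops with
consecutive batching and contracting dimensions down to rank 3. Also pulls in
mhlo's general dot lowering patterns so unbatched cases lower to mhlo.dot.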

diff --git a/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp b/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp
index 8123574..e6fe664 100644
--- a/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp
@@ -12,6 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include <numeric>
+
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Casting.h"
 #include "mlir/Dialect/Shape/IR/Shape.h"
@@ -56,6 +58,13 @@
                       [](APInt v) -> bool { return !v.isNullValue(); });
 }
 
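+// Returns a rank-1 i64 DenseIntElementsAttr holding the given integers, e.g.
+// a transpose permutation such as {0, 2, 1}.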
+static DenseIntElementsAttr make1DElementsAttr(PatternRewriter &rewriter,
+                                               ArrayRef<int64_t> integers) {
+  auto type = RankedTensorType::get({static_cast<int64_t>(integers.size())},
+                                    rewriter.getIntegerType(64));
+  return DenseIntElementsAttr::get(type, integers);
+}
+
 class DecomposeLog1PPattern : public OpRewritePattern<mhlo::Log1pOp> {
  public:
   using OpRewritePattern<mhlo::Log1pOp>::OpRewritePattern;
@@ -324,6 +333,240 @@
   }
 };
 
+// Rewrites a rank-3 mhlo.dot_general so that the lhs contraction dimension
+// is innermost (dim 2) and the rhs contraction dimension is the one right
+// after the batch dimension (dim 1). The pattern inserts transposes so the
+// dot_general always has the form:
+//   {batch_dim, parallel, contraction} . {batch_dim, contraction, parallel}
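+//
+// For example (a sketch with illustrative shapes), a dot_general with
+//   lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [1]
+//   : (tensor<8x64x32xf32>, tensor<8x64x16xf32>) -> tensor<8x32x16xf32>
+// becomes an "mhlo.transpose" of the lhs with permutation [0, 2, 1],
+// producing tensor<8x32x64xf32>, followed by a dot_general contracting
+// lhs dim 2 with rhs dim 1.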
+class TransposeRank3GenericDotGeneral
+    : public OpRewritePattern<mhlo::DotGeneralOp> {
+ public:
+  using OpRewritePattern<mhlo::DotGeneralOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mhlo::DotGeneralOp op,
+                                PatternRewriter &rewriter) const override {
+    auto lhsShapeType = op.lhs().getType().dyn_cast<RankedTensorType>();
+    auto rhsShapeType = op.rhs().getType().dyn_cast<RankedTensorType>();
+    auto resultType = op.getResult().getType().dyn_cast<RankedTensorType>();
+
+    if (!lhsShapeType || !rhsShapeType || !resultType) return failure();
+    if (resultType.getRank() != 3) return failure();
+
+    if (op.dot_dimension_numbers().lhs_contracting_dimensions().size() != 1 ||
+        op.dot_dimension_numbers().rhs_contracting_dimensions().size() != 1)
+      return failure();
+
+    int64_t lhsBatchDim = (*op.dot_dimension_numbers()
+                                .lhs_batching_dimensions()
+                                .int_value_begin())
+                              .getSExtValue();
+    int64_t rhsBatchDim = (*op.dot_dimension_numbers()
+                                .rhs_batching_dimensions()
+                                .int_value_begin())
+                              .getSExtValue();
+    int64_t lhsContractionDim = (*op.dot_dimension_numbers()
+                                      .lhs_contracting_dimensions()
+                                      .int_value_begin())
+                                    .getSExtValue();
+    int64_t rhsContractionDim = (*op.dot_dimension_numbers()
+                                      .rhs_contracting_dimensions()
+                                      .int_value_begin())
+                                    .getSExtValue();
+    // Only accept rank-3 tensors whose batch dimension is the outermost
+    // dimension on both operands.
+    if (lhsBatchDim != 0 || rhsBatchDim != 0) return failure();
+    // Already in canonical form; no transposes are needed.
+    if (lhsContractionDim == 2 && rhsContractionDim == 1) return failure();
+
+    Value lhs = op.lhs(), rhs = op.rhs();
+
+    // Transpose the lhs {batch_dim, contraction, parallel} case into
+    // {batch_dim, parallel, contraction}.
+    if (lhsContractionDim == 1) {
+      Type transposedType = RankedTensorType::get(
+          {lhsShapeType.getDimSize(0), lhsShapeType.getDimSize(2),
+           lhsShapeType.getDimSize(1)},
+          lhsShapeType.getElementType());
+      lhs = rewriter.create<mhlo::TransposeOp>(
+          op.getLoc(), transposedType, lhs,
+          make1DElementsAttr(rewriter, {0, 2, 1}));
+    }
+
+    // Transpose the rhs {batch_dim, parallel, contraction} case into
+    // {batch_dim, contraction, parallel}.
+    if (rhsContractionDim == 2) {
+      Type transposedType = RankedTensorType::get(
+          {rhsShapeType.getDimSize(0), rhsShapeType.getDimSize(2),
+           rhsShapeType.getDimSize(1)},
+          rhsShapeType.getElementType());
+      rhs = rewriter.create<mhlo::TransposeOp>(
+          op.getLoc(), transposedType, rhs,
+          make1DElementsAttr(rewriter, {0, 2, 1}));
+    }
+
+    auto dimensionNumbers = mhlo::DotDimensionNumbers::get(
+        /*lhs_batching_dimensions=*/make1DElementsAttr(rewriter, {0}),
+        /*rhs_batching_dimensions=*/make1DElementsAttr(rewriter, {0}),
+        /*lhs_contracting_dimensions=*/make1DElementsAttr(rewriter, {2}),
+        /*rhs_contracting_dimensions=*/
+        make1DElementsAttr(rewriter, {1}), rewriter.getContext());
+
+    Value result = rewriter.create<mhlo::DotGeneralOp>(
+        op.getLoc(), op.getType(), lhs, rhs, dimensionNumbers,
+        op.precision_configAttr());
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+// Rewrites mhlo.dot_general to operate on rank-3 tensors when the reduction
+// dims are consecutive and do not split the parallel domain. The pattern
+// inserts reshapes that collapse the consecutive batching, contracting, and
+// parallel dims, so a rank-3 dot_general op is always generated.
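+//
+// For example (shapes taken from the tests for this pattern), a dot_general
+// with batching dims [0, 1], lhs contracting dim 3, and rhs contracting dim 2
+//   : (tensor<1x8x32x64xf32>, tensor<1x8x64x32xf32>) -> tensor<1x8x32x32xf32>
+// is rewritten to reshape both operands to rank 3, perform
+//   (tensor<8x32x64xf32>, tensor<8x64x32xf32>) -> tensor<8x32x32xf32>,
+// and reshape the result back to tensor<1x8x32x32xf32>.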
+class RankReducedDotGeneral : public OpRewritePattern<mhlo::DotGeneralOp> {
+ public:
+  using OpRewritePattern<mhlo::DotGeneralOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mhlo::DotGeneralOp op,
+                                PatternRewriter &rewriter) const override {
+    auto lhsShapeType = op.lhs().getType().dyn_cast<ShapedType>();
+    auto rhsShapeType = op.rhs().getType().dyn_cast<ShapedType>();
+    auto resultType = op.getResult().getType().dyn_cast<ShapedType>();
+
+    if (!lhsShapeType || !rhsShapeType || !resultType) return failure();
+    if (!lhsShapeType.hasStaticShape() || !rhsShapeType.hasStaticShape())
+      return failure();
+    if (resultType.getRank() <= 3) return failure();
+
+    mhlo::DotDimensionNumbers dimNumbers = op.dot_dimension_numbers();
+    auto lhsBatchingDims = llvm::to_vector<4>(
+        llvm::map_range(dimNumbers.lhs_batching_dimensions(),
+                        [](APInt v) { return v.getSExtValue(); }));
+    auto rhsBatchingDims = llvm::to_vector<4>(
+        llvm::map_range(dimNumbers.rhs_batching_dimensions(),
+                        [](APInt v) { return v.getSExtValue(); }));
+    auto lhsContractingDims = llvm::to_vector<4>(
+        llvm::map_range(dimNumbers.lhs_contracting_dimensions(),
+                        [](APInt v) { return v.getSExtValue(); }));
+    auto rhsContractingDims = llvm::to_vector<4>(
+        llvm::map_range(dimNumbers.rhs_contracting_dimensions(),
+                        [](APInt v) { return v.getSExtValue(); }));
+
+    if (lhsBatchingDims.empty() || rhsBatchingDims.empty()) return failure();
+
+    llvm::sort(lhsBatchingDims);
+    llvm::sort(lhsContractingDims);
+    llvm::sort(rhsBatchingDims);
+    llvm::sort(rhsContractingDims);
+
+    auto isConsecutive = [](ArrayRef<int64_t> array) {
+      for (size_t i = 1; i < array.size(); ++i) {
+        if (array[i] - array[i - 1] != 1) return false;
+      }
+      return true;
+    };
+
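+    // Returns true when the parallel dims are split into two groups, one
+    // between the batching and contracting dims and one after the contracting
+    // dims; such a split cannot be collapsed into a single parallel dim.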
+    auto isDomainSplit = [](ArrayRef<int64_t> shape,
+                            ArrayRef<int64_t> batchingDims,
+                            ArrayRef<int64_t> contractingDims) {
+      // Batching and contracting are contiguous.
+      if ((contractingDims.front() - batchingDims.back()) == 1) return false;
+      // Contracting dims are inner most.
+      if (contractingDims.back() == (shape.size() - 1)) return false;
+      return true;
+    };
+
+    if (!isConsecutive(lhsBatchingDims) || !isConsecutive(lhsContractingDims) ||
+        !isConsecutive(rhsBatchingDims) || !isConsecutive(rhsContractingDims))
+      return failure();
+
+    if (isDomainSplit(lhsShapeType.getShape(), lhsBatchingDims,
+                      lhsContractingDims) ||
+        isDomainSplit(rhsShapeType.getShape(), rhsBatchingDims,
+                      rhsContractingDims))
+      return failure();
+
+    // Collapses the shape to rank 3 and returns the new collapsed shape
+    // together with the contracting and parallel dim indices.
+    auto computeCollapsedShape = [](ArrayRef<int64_t> shape,
+                                    ArrayRef<int64_t> batchingDims,
+                                    ArrayRef<int64_t> contractingDims) {
+      // The batching, contracting, and parallel dims each collapse into a
+      // single dim, so the collapsed shape is always rank 3.
+      auto batchingSize = std::accumulate(
+          batchingDims.begin(), batchingDims.end(), int64_t{1},
+          [shape](const int64_t accum, const int64_t index) -> int64_t {
+            return accum * shape[index];
+          });
+      auto contractingSize = std::accumulate(
+          contractingDims.begin(), contractingDims.end(), int64_t{1},
+          [shape](const int64_t accum, const int64_t index) -> int64_t {
+            return accum * shape[index];
+          });
+
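+      // The collapsed parallel dim keeps the position the original parallel
+      // dims had: between the batching and contracting dims, or innermost.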
+      int parallelDimIndex, contractingDimIndex;
+      int64_t parallelDimSize = 1;
+      if (contractingDims.front() - batchingDims.back() > 1) {
+        parallelDimIndex = 1;
+        contractingDimIndex = 2;
+        for (int64_t i = batchingDims.back() + 1; i < contractingDims.front();
+             ++i) {
+          parallelDimSize *= shape[i];
+        }
+      } else {
+        contractingDimIndex = 1;
+        parallelDimIndex = 2;
+        for (int64_t i = contractingDims.back() + 1,
+                     e = static_cast<int64_t>(shape.size());
+             i < e; ++i) {
+          parallelDimSize *= shape[i];
+        }
+      }
+      llvm::SmallVector<int64_t, 4> newShape(3);
+      newShape[0] = batchingSize;
+      newShape[contractingDimIndex] = contractingSize;
+      newShape[parallelDimIndex] = parallelDimSize;
+      return std::make_tuple(newShape, contractingDimIndex, parallelDimIndex);
+    };
+
+    int lhsContractingDimIndex, rhsContractingDimIndex, lhsParallelDimIndex,
+        rhsParallelDimIndex;
+    SmallVector<int64_t, 4> lhsNewShape, rhsNewShape;
+    std::tie(lhsNewShape, lhsContractingDimIndex, lhsParallelDimIndex) =
+        computeCollapsedShape(lhsShapeType.getShape(), lhsBatchingDims,
+                              lhsContractingDims);
+
+    std::tie(rhsNewShape, rhsContractingDimIndex, rhsParallelDimIndex) =
+        computeCollapsedShape(rhsShapeType.getShape(), rhsBatchingDims,
+                              rhsContractingDims);
+    SmallVector<int64_t, 4> resultNewShape = {lhsNewShape[0],
+                                              lhsNewShape[lhsParallelDimIndex],
+                                              rhsNewShape[rhsParallelDimIndex]};
+    Type dotGeneralResultType =
+        RankedTensorType::get(resultNewShape, resultType.getElementType());
+
+    auto loc = op.getLoc();
+    Value reshapedLhs = rewriter.create<mhlo::ReshapeOp>(
+        loc, RankedTensorType::get(lhsNewShape, lhsShapeType.getElementType()),
+        op.lhs());
+    Value reshapedRhs = rewriter.create<mhlo::ReshapeOp>(
+        loc, RankedTensorType::get(rhsNewShape, rhsShapeType.getElementType()),
+        op.rhs());
+    auto dimensionNumbers = mhlo::DotDimensionNumbers::get(
+        /*lhs_batching_dimensions=*/make1DElementsAttr(rewriter, {0}),
+        /*rhs_batching_dimensions=*/make1DElementsAttr(rewriter, {0}),
+        /*lhs_contracting_dimensions=*/
+        make1DElementsAttr(rewriter, {lhsContractingDimIndex}),
+        /*rhs_contracting_dimensions=*/
+        make1DElementsAttr(rewriter, {rhsContractingDimIndex}),
+        rewriter.getContext());
+    Value dotGeneralResult = rewriter.create<mhlo::DotGeneralOp>(
+        loc, dotGeneralResultType, reshapedLhs, reshapedRhs, dimensionNumbers,
+        op.precision_configAttr());
+
+    Value result =
+        rewriter.create<mhlo::ReshapeOp>(loc, resultType, dotGeneralResult);
+    rewriter.replaceOp(op, result);
+
+    return success();
+  }
+};
+
 // clang-format off
 //
 // Reorder BroadcastInDimOp and N-ary elementwise op.
@@ -425,6 +668,11 @@
                     AdjustDepthwiseFilterShape, DecomposeLog1PPattern,
                     DecomposeExpM1Pattern>(context);
 
+    // dot_general canonicalization patterns.
+    mhlo::PopulateGeneralDotOpLoweringPatterns(&patterns, context);
+    patterns.insert<RankReducedDotGeneral, TransposeRank3GenericDotGeneral>(
+        context);
+
     // Unary elementwise op.
     patterns.insert<
         ReorderBroadcastInDimOpAndElementwiseOp<mhlo::AbsOp>,
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/hlo_to_hlo_preprocessing_canonicalize_dot_general.mlir b/iree/compiler/Dialect/Flow/Transforms/test/hlo_to_hlo_preprocessing_canonicalize_dot_general.mlir
new file mode 100644
index 0000000..3447436
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Transforms/test/hlo_to_hlo_preprocessing_canonicalize_dot_general.mlir
@@ -0,0 +1,104 @@
+// RUN: iree-opt -split-input-file -verify-diagnostics -iree-flow-hlo-to-hlo-preprocessing %s | IreeFileCheck %s
+
+func @dot_general_to_dot(%arg0: tensor<1x32x128x4xf32>, %arg1: tensor<128x4x8x64xf32>) -> tensor<1x32x8x64xf32> {
+  %0 = "mhlo.dot_general"(%arg0, %arg1) {
+      dot_dimension_numbers = {
+        lhs_batching_dimensions = dense<> : tensor<0xi64>,
+        lhs_contracting_dimensions = dense<[2, 3]> : tensor<2xi64>,
+        rhs_batching_dimensions = dense<> : tensor<0xi64>,
+        rhs_contracting_dimensions = dense<[0, 1]> : tensor<2xi64>
+      }, name = "dot_general_to_dot", precision_config = ["DEFAULT", "DEFAULT"]
+    } : (tensor<1x32x128x4xf32>, tensor<128x4x8x64xf32>) -> tensor<1x32x8x64xf32>
+  return %0 : tensor<1x32x8x64xf32>
+}
+
+// CHECK: dot_general_to_dot(%[[ARG0:.+]]: tensor<1x32x128x4xf32>, %[[ARG1:.+]]: tensor<128x4x8x64xf32>) -> tensor<1x32x8x64xf32>
+// CHECK: %[[ARG0_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG0]]) : (tensor<1x32x128x4xf32>) -> tensor<32x512xf32>
+// CHECK: %[[ARG1_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG1]]) : (tensor<128x4x8x64xf32>) -> tensor<512x512xf32>
+// CHECK: %[[DOT:.+]] = "mhlo.dot"(%[[ARG0_RESHAPED]], %[[ARG1_RESHAPED]])
+// CHECK: %[[RESULT:.+]] = "mhlo.reshape"(%[[DOT]]) : (tensor<32x512xf32>) -> tensor<1x32x8x64xf32>
+// CHECK: return %[[RESULT]] : tensor<1x32x8x64xf32>
+
+// -----
+
+func @dot_general_to_dot_general_rank_reduced(%arg0: tensor<1x8x32x64xf32>, %arg1 : tensor<1x8x64x32xf32>) -> tensor<1x8x32x32xf32> {
+  %0 = "mhlo.dot_general"(%arg0, %arg1) {
+    dot_dimension_numbers = {
+      lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>,
+      lhs_contracting_dimensions = dense<3> : tensor<1xi64>,
+      rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>,
+      rhs_contracting_dimensions = dense<2> : tensor<1xi64>
+    }, name = "dot_general_to_dot", precision_config = ["DEFAULT", "DEFAULT"]
+  } : (tensor<1x8x32x64xf32>, tensor<1x8x64x32xf32>) -> tensor<1x8x32x32xf32>
+  return %0 : tensor<1x8x32x32xf32>
+}
+// CHECK: dot_general_to_dot_general_rank_reduced(%[[ARG0:.+]]: tensor<1x8x32x64xf32>, %[[ARG1:.+]]: tensor<1x8x64x32xf32>) -> tensor<1x8x32x32xf32>
+// CHECK: %[[ARG0_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG0]]) : (tensor<1x8x32x64xf32>) -> tensor<8x32x64xf32>
+// CHECK: %[[ARG1_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG1]]) : (tensor<1x8x64x32xf32>) -> tensor<8x64x32xf32>
+// CHECK: %[[DOT_RESULT:.+]] = "mhlo.dot_general"(%[[ARG0_RESHAPED]], %[[ARG1_RESHAPED]])
+// CHECK: %[[RESULT:.+]] = "mhlo.reshape"(%[[DOT_RESULT]]) : (tensor<8x32x32xf32>) -> tensor<1x8x32x32xf32>
+// CHECK: return %[[RESULT]] : tensor<1x8x32x32xf32>
+
+// -----
+
+func @dot_general_to_dot_general_rank_reduced_a_transposed(%arg0: tensor<1x8x64x32xf32>, %arg1: tensor<1x8x64x32xf32>) -> tensor<1x8x32x32xf32> {
+  %0 = "mhlo.dot_general"(%arg0, %arg1) {
+    dot_dimension_numbers = {
+      lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>,
+      lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
+      rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>,
+      rhs_contracting_dimensions = dense<2> : tensor<1xi64>
+    }, name = "dot_general_to_dot_trans_a", precision_config = ["DEFAULT", "DEFAULT"]
+  } : (tensor<1x8x64x32xf32>, tensor<1x8x64x32xf32>) -> tensor<1x8x32x32xf32>
+  return %0 : tensor<1x8x32x32xf32>
+}
+// CHECK: dot_general_to_dot_general_rank_reduced_a_transposed(%[[ARG0:.+]]: tensor<1x8x64x32xf32>, %[[ARG1:.+]]: tensor<1x8x64x32xf32>) -> tensor<1x8x32x32xf32>
+// CHECK: %[[ARG0_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG0]]) : (tensor<1x8x64x32xf32>) -> tensor<8x64x32xf32>
+// CHECK: %[[ARG1_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG1]]) : (tensor<1x8x64x32xf32>) -> tensor<8x64x32xf32>
+// CHECK: %[[ARG0_RESHAPED_TR:.+]] = "mhlo.transpose"(%[[ARG0_RESHAPED]]) {permutation = dense<[0, 2, 1]> : tensor<3xi64>} : (tensor<8x64x32xf32>) -> tensor<8x32x64xf32>
+// CHECK: %[[DOT_RESULT:.+]] = "mhlo.dot_general"(%[[ARG0_RESHAPED_TR]], %[[ARG1_RESHAPED]])
+// CHECK: %[[RESULT:.+]] = "mhlo.reshape"(%[[DOT_RESULT]]) : (tensor<8x32x32xf32>) -> tensor<1x8x32x32xf32>
+
+// -----
+
+func @dot_general_to_dot_general_rank_reduced_b_transposed(%arg0: tensor<1x8x32x64xf32>, %arg1: tensor<1x8x32x64xf32>) -> tensor<1x8x32x32xf32> {
+  %0 = "mhlo.dot_general"(%arg0, %arg1) {
+    dot_dimension_numbers = {
+      lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>,
+      lhs_contracting_dimensions = dense<3> : tensor<1xi64>,
+      rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>,
+      rhs_contracting_dimensions = dense<3> : tensor<1xi64>
+    }, name = "dot_general_to_dot_trans_b", precision_config = ["DEFAULT", "DEFAULT"]
+  } : (tensor<1x8x32x64xf32>, tensor<1x8x32x64xf32>) -> tensor<1x8x32x32xf32>
+  return %0 : tensor<1x8x32x32xf32>
+}
+// CHECK: dot_general_to_dot_general_rank_reduced_b_transposed(%[[ARG0:.+]]: tensor<1x8x32x64xf32>, %[[ARG1:.+]]: tensor<1x8x32x64xf32>) -> tensor<1x8x32x32xf32>
+// CHECK: %[[ARG0_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG0]]) : (tensor<1x8x32x64xf32>) -> tensor<8x32x64xf32>
+// CHECK: %[[ARG1_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG1]]) : (tensor<1x8x32x64xf32>) -> tensor<8x32x64xf32>
+// CHECK: %[[ARG1_RESHAPED_TR:.+]] = "mhlo.transpose"(%[[ARG1_RESHAPED]]) {permutation = dense<[0, 2, 1]> : tensor<3xi64>} : (tensor<8x32x64xf32>) -> tensor<8x64x32xf32>
+// CHECK: %[[DOT_RESULT:.+]] = "mhlo.dot_general"(%[[ARG0_RESHAPED]], %[[ARG1_RESHAPED_TR]])
+// CHECK: %[[RESULT:.+]] = "mhlo.reshape"(%[[DOT_RESULT]]) : (tensor<8x32x32xf32>) -> tensor<1x8x32x32xf32>
+// CHECK: return %[[RESULT]] : tensor<1x8x32x32xf32>
+
+// -----
+
+func @dot_general_to_dot_general_rank_reduced_ab_transposed(%arg0: tensor<1x8x64x32xf32>, %arg1: tensor<1x8x32x64xf32>) -> tensor<1x8x32x32xf32> {
+  %0 = "mhlo.dot_general"(%arg0, %arg1) {
+    dot_dimension_numbers = {
+      lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>,
+      lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
+      rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>,
+      rhs_contracting_dimensions = dense<3> : tensor<1xi64>
+    }, name = "dot_general_to_dot_trans_ab", precision_config = ["DEFAULT", "DEFAULT"]
+  } : (tensor<1x8x64x32xf32>, tensor<1x8x32x64xf32>) -> tensor<1x8x32x32xf32>
+  return %0 : tensor<1x8x32x32xf32>
+}
+// CHECK: dot_general_to_dot_general_rank_reduced_ab_transposed(%[[ARG0:.+]]: tensor<1x8x64x32xf32>, %[[ARG1:.+]]: tensor<1x8x32x64xf32>) -> tensor<1x8x32x32xf32>
+// CHECK: %[[ARG0_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG0]]) : (tensor<1x8x64x32xf32>) -> tensor<8x64x32xf32>
+// CHECK: %[[ARG1_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG1]]) : (tensor<1x8x32x64xf32>) -> tensor<8x32x64xf32>
+// CHECK: %[[ARG0_RESHAPED_TR:.+]] = "mhlo.transpose"(%[[ARG0_RESHAPED]]) {permutation = dense<[0, 2, 1]> : tensor<3xi64>} : (tensor<8x64x32xf32>) -> tensor<8x32x64xf32>
+// CHECK: %[[ARG1_RESHAPED_TR:.+]] = "mhlo.transpose"(%[[ARG1_RESHAPED]]) {permutation = dense<[0, 2, 1]> : tensor<3xi64>} : (tensor<8x32x64xf32>) -> tensor<8x64x32xf32>
+// CHECK: %[[DOT_RESULT:.+]] = "mhlo.dot_general"(%[[ARG0_RESHAPED_TR]], %[[ARG1_RESHAPED_TR]])
+// CHECK: %[[RESULT:.+]] = "mhlo.reshape"(%[[DOT_RESULT]]) : (tensor<8x32x32xf32>) -> tensor<1x8x32x32xf32>
+// CHECK: return %[[RESULT]] : tensor<1x8x32x32xf32>
diff --git a/iree/test/e2e/xla_ops/BUILD b/iree/test/e2e/xla_ops/BUILD
index 55f8428..2bbd6b4 100644
--- a/iree/test/e2e/xla_ops/BUILD
+++ b/iree/test/e2e/xla_ops/BUILD
@@ -52,6 +52,7 @@
         "cosine.mlir",
         "divide.mlir",
         "dot.mlir",
+        "dot_general.mlir",
         "exponential.mlir",
         "exponential_minus_one.mlir",
         "floor.mlir",
@@ -104,6 +105,7 @@
         "cosine.mlir",
         "divide.mlir",
         "dot.mlir",
+        "dot_general.mlir",
         "exponential.mlir",
         "exponential_minus_one.mlir",
         "floor.mlir",
diff --git a/iree/test/e2e/xla_ops/CMakeLists.txt b/iree/test/e2e/xla_ops/CMakeLists.txt
index ae2a8c1..4bd7219 100644
--- a/iree/test/e2e/xla_ops/CMakeLists.txt
+++ b/iree/test/e2e/xla_ops/CMakeLists.txt
@@ -45,6 +45,7 @@
     "cosine.mlir"
     "divide.mlir"
     "dot.mlir"
+    "dot_general.mlir"
     "exponential.mlir"
     "exponential_minus_one.mlir"
     "floor.mlir"
@@ -97,6 +98,7 @@
     "cosine.mlir"
     "divide.mlir"
     "dot.mlir"
+    "dot_general.mlir"
     "exponential.mlir"
     "exponential_minus_one.mlir"
     "floor.mlir"