Canonicalizes mhlo.dot_general to a rank-3 mhlo.dot_general or mhlo.dot (#3225)
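
Adds HLO preprocessing patterns that canonicalize mhlo.dot_general:
RankReducedDotGeneral reshapes higher-rank dot_generals down to rank-3,
TransposeRank3GenericDotGeneral transposes rank-3 operands into the
{batch, parallel, contraction} x {batch, contraction, parallel} form, and
the upstream general-dot lowering handles the no-batch-dims case by
emitting mhlo.dot. A sketch of the no-batch-dims rewrite, taken from the
lit tests added below (attributes elided):

  %0 = "mhlo.dot_general"(%arg0, %arg1) {...}
      : (tensor<1x32x128x4xf32>, tensor<128x4x8x64xf32>) -> tensor<1x32x8x64xf32>

becomes

  %lhs = "mhlo.reshape"(%arg0) : (tensor<1x32x128x4xf32>) -> tensor<32x512xf32>
  %rhs = "mhlo.reshape"(%arg1) : (tensor<128x4x8x64xf32>) -> tensor<512x512xf32>
  %dot = "mhlo.dot"(%lhs, %rhs) : (tensor<32x512xf32>, tensor<512x512xf32>) -> tensor<32x512xf32>
  %0 = "mhlo.reshape"(%dot) : (tensor<32x512xf32>) -> tensor<1x32x8x64xf32>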
diff --git a/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp b/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp
index 8123574..e6fe664 100644
--- a/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include <numeric>
+
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
@@ -56,6 +58,13 @@
[](APInt v) -> bool { return !v.isNullValue(); });
}
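+// Returns a 1-D 64-bit integer elements attribute of the given values, e.g.
+// make1DElementsAttr(rewriter, {0, 2, 1}) yields
+// dense<[0, 2, 1]> : tensor<3xi64>.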
+static DenseIntElementsAttr make1DElementsAttr(PatternRewriter &rewriter,
+ ArrayRef<int64_t> integers) {
+ auto type = RankedTensorType::get({static_cast<int64_t>(integers.size())},
+ rewriter.getIntegerType(64));
+ return DenseIntElementsAttr::get(type, integers);
+}
+
class DecomposeLog1PPattern : public OpRewritePattern<mhlo::Log1pOp> {
public:
using OpRewritePattern<mhlo::Log1pOp>::OpRewritePattern;
@@ -324,6 +333,240 @@
}
};
+// Rewrites a rank-3 mhlo.dot_general so that the lhs contraction dimension
+// is innermost (dim 2) and the rhs contraction dimension immediately follows
+// the batch dimension (dim 1). The pattern inserts transposes so the
+// dot_general always has the form:
+//   {batch, parallel, contraction} . {batch, contraction, parallel}
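+//
+// For example (a sketch; dot_dimension_numbers elided):
+//   %0 = "mhlo.dot_general"(%lhs, %rhs) {...}
+//       : (tensor<8x64x32xf32>, tensor<8x64x32xf32>) -> tensor<8x32x32xf32>
+// becomes
+//   %t = "mhlo.transpose"(%lhs) {permutation = dense<[0, 2, 1]> : tensor<3xi64>}
+//       : (tensor<8x64x32xf32>) -> tensor<8x32x64xf32>
+//   %0 = "mhlo.dot_general"(%t, %rhs) {...}
+//       : (tensor<8x32x64xf32>, tensor<8x64x32xf32>) -> tensor<8x32x32xf32>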
+class TransposeRank3GenericDotGeneral
+ : public OpRewritePattern<mhlo::DotGeneralOp> {
+ public:
+ using OpRewritePattern<mhlo::DotGeneralOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(mhlo::DotGeneralOp op,
+ PatternRewriter &rewriter) const override {
+ auto lhsShapeType = op.lhs().getType().dyn_cast<RankedTensorType>();
+ auto rhsShapeType = op.rhs().getType().dyn_cast<RankedTensorType>();
+ auto resultType = op.getResult().getType().dyn_cast<RankedTensorType>();
+
+ if (!lhsShapeType || !rhsShapeType || !resultType) return failure();
+ if (resultType.getRank() != 3) return failure();
+
+ if (op.dot_dimension_numbers().lhs_contracting_dimensions().size() != 1 ||
+ op.dot_dimension_numbers().rhs_contracting_dimensions().size() != 1)
+ return failure();
+
+ int64_t lhsBatchDim = (*op.dot_dimension_numbers()
+ .lhs_batching_dimensions()
+ .int_value_begin())
+ .getSExtValue();
+ int64_t rhsBatchDim = (*op.dot_dimension_numbers()
+ .rhs_batching_dimensions()
+ .int_value_begin())
+ .getSExtValue();
+ int64_t lhsContractionDim = (*op.dot_dimension_numbers()
+ .lhs_contracting_dimensions()
+ .int_value_begin())
+ .getSExtValue();
+ int64_t rhsContractionDim = (*op.dot_dimension_numbers()
+ .rhs_contracting_dimensions()
+ .int_value_begin())
+ .getSExtValue();
+    // Only accept dot_general ops whose batch dimension is the outermost
+    // dimension (dim 0) on both operands.
+ if (lhsBatchDim != 0 || rhsBatchDim != 0) return failure();
+ // No transposes are needed.
+ if (lhsContractionDim == 2 && rhsContractionDim == 1) return failure();
+
+ Value lhs = op.lhs(), rhs = op.rhs();
+
+    // lhs is {batch, contraction, parallel}: transpose it to
+    // {batch, parallel, contraction}.
+    if (lhsContractionDim == 1) {
+      Type transposedType = RankedTensorType::get(
+          {lhsShapeType.getDimSize(0), lhsShapeType.getDimSize(2),
+           lhsShapeType.getDimSize(1)},
+          lhsShapeType.getElementType());
+      lhs = rewriter.create<mhlo::TransposeOp>(
+          op.getLoc(), transposedType, lhs,
+          make1DElementsAttr(rewriter, {0, 2, 1}));
+    }
+
+    // rhs is {batch, parallel, contraction}: transpose it to
+    // {batch, contraction, parallel}.
+    if (rhsContractionDim == 2) {
+      Type transposedType = RankedTensorType::get(
+          {rhsShapeType.getDimSize(0), rhsShapeType.getDimSize(2),
+           rhsShapeType.getDimSize(1)},
+          rhsShapeType.getElementType());
+      rhs = rewriter.create<mhlo::TransposeOp>(
+          op.getLoc(), transposedType, rhs,
+          make1DElementsAttr(rewriter, {0, 2, 1}));
+    }
+
+ auto dimensionNumbers = mhlo::DotDimensionNumbers::get(
+ /*lhs_batching_dimensions=*/make1DElementsAttr(rewriter, {0}),
+ /*rhs_batching_dimensions=*/make1DElementsAttr(rewriter, {0}),
+ /*lhs_contracting_dimensions=*/make1DElementsAttr(rewriter, {2}),
+ /*rhs_contracting_dimensions=*/
+ make1DElementsAttr(rewriter, {1}), rewriter.getContext());
+
+ Value result = rewriter.create<mhlo::DotGeneralOp>(
+ op.getLoc(), op.getType(), lhs, rhs, dimensionNumbers,
+ op.precision_configAttr());
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
+// Rewrites mhlo.dot_general to operate on rank-3 tensors when the reduction
+// dims are consecutive and do not split the domain (i.e. parallel dims do not
+// appear on both sides of the contracting dims). The pattern inserts reshapes
+// that collapse the consecutive batching, contracting, and parallel dims so
+// the result is always a rank-3 dot_general op.
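+//
+// For example (a sketch; dot_dimension_numbers elided):
+//   %0 = "mhlo.dot_general"(%lhs, %rhs) {...}
+//       : (tensor<1x8x32x64xf32>, tensor<1x8x64x32xf32>) -> tensor<1x8x32x32xf32>
+// becomes
+//   %l = "mhlo.reshape"(%lhs) : (tensor<1x8x32x64xf32>) -> tensor<8x32x64xf32>
+//   %r = "mhlo.reshape"(%rhs) : (tensor<1x8x64x32xf32>) -> tensor<8x64x32xf32>
+//   %d = "mhlo.dot_general"(%l, %r) {...}
+//       : (tensor<8x32x64xf32>, tensor<8x64x32xf32>) -> tensor<8x32x32xf32>
+//   %0 = "mhlo.reshape"(%d) : (tensor<8x32x32xf32>) -> tensor<1x8x32x32xf32>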
+class RankReducedDotGeneral : public OpRewritePattern<mhlo::DotGeneralOp> {
+ public:
+ using OpRewritePattern<mhlo::DotGeneralOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(mhlo::DotGeneralOp op,
+ PatternRewriter &rewriter) const override {
+ auto lhsShapeType = op.lhs().getType().dyn_cast<ShapedType>();
+ auto rhsShapeType = op.rhs().getType().dyn_cast<ShapedType>();
+ auto resultType = op.getResult().getType().dyn_cast<ShapedType>();
+
+ if (!lhsShapeType || !rhsShapeType || !resultType) return failure();
+ if (!lhsShapeType.hasStaticShape() || !rhsShapeType.hasStaticShape())
+ return failure();
+ if (resultType.getRank() <= 3) return failure();
+
+ mhlo::DotDimensionNumbers dimNumbers = op.dot_dimension_numbers();
+ auto lhsBatchingDims = llvm::to_vector<4>(
+ llvm::map_range(dimNumbers.lhs_batching_dimensions(),
+ [](APInt v) { return v.getSExtValue(); }));
+ auto rhsBatchingDims = llvm::to_vector<4>(
+ llvm::map_range(dimNumbers.rhs_batching_dimensions(),
+ [](APInt v) { return v.getSExtValue(); }));
+ auto lhsContractingDims = llvm::to_vector<4>(
+ llvm::map_range(dimNumbers.lhs_contracting_dimensions(),
+ [](APInt v) { return v.getSExtValue(); }));
+ auto rhsContractingDims = llvm::to_vector<4>(
+ llvm::map_range(dimNumbers.rhs_contracting_dimensions(),
+ [](APInt v) { return v.getSExtValue(); }));
+
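+    // Dot ops with no batching dims are instead handled by the general dot
+    // lowering patterns (lowering to mhlo.dot) registered alongside this one.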
+ if (lhsBatchingDims.empty() || rhsBatchingDims.empty()) return failure();
+
+ llvm::sort(lhsBatchingDims);
+ llvm::sort(lhsContractingDims);
+ llvm::sort(rhsBatchingDims);
+ llvm::sort(rhsContractingDims);
+
+    auto isConsecutive = [](ArrayRef<int64_t> array) {
+      for (int i = 1, e = array.size(); i < e; ++i) {
+        if (array[i] - array[i - 1] != 1) return false;
+      }
+      return true;
+    };
+
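+    // The domain is "split" when parallel dims remain on both sides of the
+    // contracting dims (e.g. shape {B, P, C, P} with batching dims {0} and
+    // contracting dims {2}); such dims cannot collapse into a single
+    // parallel dim.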
+    auto isDomainSplit = [](ArrayRef<int64_t> shape,
+                            ArrayRef<int64_t> batchingDims,
+                            ArrayRef<int64_t> contractingDims) {
+      // Contracting dims immediately follow the batching dims: not split.
+      if ((contractingDims.front() - batchingDims.back()) == 1) return false;
+      // Contracting dims are innermost: not split.
+      if (contractingDims.back() == static_cast<int64_t>(shape.size()) - 1)
+        return false;
+      return true;
+    };
+
+ if (!isConsecutive(lhsBatchingDims) || !isConsecutive(lhsContractingDims) ||
+ !isConsecutive(rhsBatchingDims) || !isConsecutive(rhsContractingDims))
+ return failure();
+
+ if (isDomainSplit(lhsShapeType.getShape(), lhsBatchingDims,
+ lhsContractingDims) ||
+ isDomainSplit(rhsShapeType.getShape(), rhsBatchingDims,
+ rhsContractingDims))
+ return failure();
+
+    // Collapses the shape into a rank-3 tensor; returns the new collapsed
+    // shape plus the contraction and parallel dim indices within it.
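+    // E.g. shape {1, 8, 32, 64} with batching dims {0, 1} and contracting
+    // dims {3} collapses to newShape = {8, 32, 64} with parallelDimIndex = 1
+    // and contractingDimIndex = 2.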
+ auto computeCollapsedShape = [](ArrayRef<int64_t> shape,
+ ArrayRef<int64_t> batchingDims,
+ ArrayRef<int64_t> contractingDims) {
+      // All batching dims collapse into one dim, as do all contracting dims
+      // and all parallel dims, so the collapsed shape always has rank 3.
+      auto batchingSize = std::accumulate(
+          batchingDims.begin(), batchingDims.end(), int64_t{1},
+          [shape](const int64_t accum, const int64_t index) -> int64_t {
+            return accum * shape[index];
+          });
+      auto contractingSize = std::accumulate(
+          contractingDims.begin(), contractingDims.end(), int64_t{1},
+          [shape](const int64_t accum, const int64_t index) -> int64_t {
+            return accum * shape[index];
+          });
+
+      int parallelDimIndex, contractingDimIndex;
+      int64_t parallelDimSize = 1;
+ if (contractingDims.front() - batchingDims.back() > 1) {
+ parallelDimIndex = 1;
+ contractingDimIndex = 2;
+        for (int64_t i = batchingDims.back() + 1;
+             i < contractingDims.front(); ++i) {
+ parallelDimSize *= shape[i];
+ }
+ } else {
+ contractingDimIndex = 1;
+ parallelDimIndex = 2;
+        for (int64_t i = contractingDims.back() + 1, e = shape.size(); i < e;
+             ++i) {
+ parallelDimSize *= shape[i];
+ }
+ }
+      llvm::SmallVector<int64_t, 4> newShape(3);
+ newShape[0] = batchingSize;
+ newShape[contractingDimIndex] = contractingSize;
+ newShape[parallelDimIndex] = parallelDimSize;
+ return std::make_tuple(newShape, contractingDimIndex, parallelDimIndex);
+ };
+
+ int lhsContractingDimIndex, rhsContractingDimIndex, lhsParallelDimIndex,
+ rhsParallelDimIndex;
+ SmallVector<int64_t, 4> lhsNewShape, rhsNewShape;
+ std::tie(lhsNewShape, lhsContractingDimIndex, lhsParallelDimIndex) =
+ computeCollapsedShape(lhsShapeType.getShape(), lhsBatchingDims,
+ lhsContractingDims);
+
+ std::tie(rhsNewShape, rhsContractingDimIndex, rhsParallelDimIndex) =
+ computeCollapsedShape(rhsShapeType.getShape(), rhsBatchingDims,
+ rhsContractingDims);
+ SmallVector<int64_t, 4> resultNewShape = {lhsNewShape[0],
+ lhsNewShape[lhsParallelDimIndex],
+ rhsNewShape[rhsParallelDimIndex]};
+ Type dotGeneralResultType =
+ RankedTensorType::get(resultNewShape, resultType.getElementType());
+
+ auto loc = op.getLoc();
+ Value reshapedLhs = rewriter.create<mhlo::ReshapeOp>(
+ loc, RankedTensorType::get(lhsNewShape, lhsShapeType.getElementType()),
+ op.lhs());
+ Value reshapedRhs = rewriter.create<mhlo::ReshapeOp>(
+ loc, RankedTensorType::get(rhsNewShape, rhsShapeType.getElementType()),
+ op.rhs());
+ auto dimensionNumbers = mhlo::DotDimensionNumbers::get(
+ /*lhs_batching_dimensions=*/make1DElementsAttr(rewriter, {0}),
+ /*rhs_batching_dimensions=*/make1DElementsAttr(rewriter, {0}),
+ /*lhs_contracting_dimensions=*/
+ make1DElementsAttr(rewriter, {lhsContractingDimIndex}),
+ /*rhs_contracting_dimensions=*/
+ make1DElementsAttr(rewriter, {rhsContractingDimIndex}),
+ rewriter.getContext());
+ Value dotGeneralResult = rewriter.create<mhlo::DotGeneralOp>(
+ loc, dotGeneralResultType, reshapedLhs, reshapedRhs, dimensionNumbers,
+ op.precision_configAttr());
+
+ Value result =
+ rewriter.create<mhlo::ReshapeOp>(loc, resultType, dotGeneralResult);
+ rewriter.replaceOp(op, result);
+
+ return success();
+ }
+};
+
// clang-format off
//
// Reorder BroadcastInDimOp and N-ary elementwise op.
@@ -425,6 +668,11 @@
AdjustDepthwiseFilterShape, DecomposeLog1PPattern,
DecomposeExpM1Pattern>(context);
+  // dot_general canonicalization patterns.
+ mhlo::PopulateGeneralDotOpLoweringPatterns(&patterns, context);
+ patterns.insert<RankReducedDotGeneral, TransposeRank3GenericDotGeneral>(
+ context);
+
// Unary elementwise op.
patterns.insert<
ReorderBroadcastInDimOpAndElementwiseOp<mhlo::AbsOp>,
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/hlo_to_hlo_preprocessing_canoncalize_dot_general.mlir b/iree/compiler/Dialect/Flow/Transforms/test/hlo_to_hlo_preprocessing_canoncalize_dot_general.mlir
new file mode 100644
index 0000000..3447436
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Transforms/test/hlo_to_hlo_preprocessing_canoncalize_dot_general.mlir
@@ -0,0 +1,104 @@
+// RUN: iree-opt -split-input-file -verify-diagnostics -iree-flow-hlo-to-hlo-preprocessing %s | IreeFileCheck %s
+
+func @dot_general_to_dot(%arg0: tensor<1x32x128x4xf32>, %arg1: tensor<128x4x8x64xf32>) -> tensor<1x32x8x64xf32> {
+ %0 = "mhlo.dot_general"(%arg0, %arg1) {
+ dot_dimension_numbers = {
+ lhs_batching_dimensions = dense<> : tensor<0xi64>,
+ lhs_contracting_dimensions = dense<[2, 3]> : tensor<2xi64>,
+ rhs_batching_dimensions = dense<> : tensor<0xi64>,
+ rhs_contracting_dimensions = dense<[0, 1]> : tensor<2xi64>
+ }, name = "dot_general_to_dot", precision_config = ["DEFAULT", "DEFAULT"]
+ } : (tensor<1x32x128x4xf32>, tensor<128x4x8x64xf32>) -> tensor<1x32x8x64xf32>
+ return %0 : tensor<1x32x8x64xf32>
+}
+
+// CHECK: dot_general_to_dot(%[[ARG0:.+]]: tensor<1x32x128x4xf32>, %[[ARG1:.+]]: tensor<128x4x8x64xf32>) -> tensor<1x32x8x64xf32>
+// CHECK: %[[ARG0_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG0]]) : (tensor<1x32x128x4xf32>) -> tensor<32x512xf32>
+// CHECK: %[[ARG1_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG1]]) : (tensor<128x4x8x64xf32>) -> tensor<512x512xf32>
+// CHECK: %[[DOT:.+]] = "mhlo.dot"(%[[ARG0_RESHAPED]], %[[ARG1_RESHAPED]])
+// CHECK: %[[RESULT:.+]] = "mhlo.reshape"(%[[DOT]]) : (tensor<32x512xf32>) -> tensor<1x32x8x64xf32>
+// CHECK: return %[[RESULT]] : tensor<1x32x8x64xf32>
+
+// -----
+
+func @dot_general_to_dot_general_rank_reduced(%arg0: tensor<1x8x32x64xf32>, %arg1 : tensor<1x8x64x32xf32>) -> tensor<1x8x32x32xf32> {
+ %0 = "mhlo.dot_general"(%arg0, %arg1) {
+ dot_dimension_numbers = {
+ lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>,
+ lhs_contracting_dimensions = dense<3> : tensor<1xi64>,
+ rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>,
+ rhs_contracting_dimensions = dense<2> : tensor<1xi64>
+ }, name = "dot_general_to_dot", precision_config = ["DEFAULT", "DEFAULT"]
+ } : (tensor<1x8x32x64xf32>, tensor<1x8x64x32xf32>) -> tensor<1x8x32x32xf32>
+ return %0 : tensor<1x8x32x32xf32>
+}
+// CHECK: dot_general_to_dot_general_rank_reduced(%[[ARG0:.+]]: tensor<1x8x32x64xf32>, %[[ARG1:.+]]: tensor<1x8x64x32xf32>) -> tensor<1x8x32x32xf32>
+// CHECK: %[[ARG0_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG0]]) : (tensor<1x8x32x64xf32>) -> tensor<8x32x64xf32>
+// CHECK: %[[ARG1_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG1]]) : (tensor<1x8x64x32xf32>) -> tensor<8x64x32xf32>
+// CHECK: %[[DOT_RESULT:.+]] = "mhlo.dot_general"(%[[ARG0_RESHAPED]], %[[ARG1_RESHAPED]])
+// CHECK: %[[RESULT:.+]] = "mhlo.reshape"(%[[DOT_RESULT]]) : (tensor<8x32x32xf32>) -> tensor<1x8x32x32xf32>
+// CHECK: return %[[RESULT]] : tensor<1x8x32x32xf32>
+
+// -----
+
+func @dot_general_to_dot_general_rank_reduced_a_transposed(%arg0: tensor<1x8x64x32xf32>, %arg1: tensor<1x8x64x32xf32>) -> tensor<1x8x32x32xf32> {
+ %0 = "mhlo.dot_general"(%arg0, %arg1) {
+ dot_dimension_numbers = {
+ lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>,
+ lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
+ rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>,
+ rhs_contracting_dimensions = dense<2> : tensor<1xi64>
+ }, name = "dot_general_to_dot_trans_a", precision_config = ["DEFAULT", "DEFAULT"]
+ } : (tensor<1x8x64x32xf32>, tensor<1x8x64x32xf32>) -> tensor<1x8x32x32xf32>
+ return %0 : tensor<1x8x32x32xf32>
+}
+// CHECK: dot_general_to_dot_general_rank_reduced_a_transposed(%[[ARG0:.+]]: tensor<1x8x64x32xf32>, %[[ARG1:.+]]: tensor<1x8x64x32xf32>) -> tensor<1x8x32x32xf32>
+// CHECK: %[[ARG0_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG0]]) : (tensor<1x8x64x32xf32>) -> tensor<8x64x32xf32>
+// CHECK: %[[ARG1_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG1]]) : (tensor<1x8x64x32xf32>) -> tensor<8x64x32xf32>
+// CHECK: %[[ARG0_RESHAPED_TR:.+]] = "mhlo.transpose"(%[[ARG0_RESHAPED]]) {permutation = dense<[0, 2, 1]> : tensor<3xi64>} : (tensor<8x64x32xf32>) -> tensor<8x32x64xf32>
+// CHECK: %[[DOT_RESULT:.+]] = "mhlo.dot_general"(%[[ARG0_RESHAPED_TR]], %[[ARG1_RESHAPED]])
+// CHECK: %[[RESULT:.+]] = "mhlo.reshape"(%[[DOT_RESULT]]) : (tensor<8x32x32xf32>) -> tensor<1x8x32x32xf32>
+// CHECK: return %[[RESULT]] : tensor<1x8x32x32xf32>
+
+// -----
+
+func @dot_general_to_dot_general_rank_reduced_b_transposed(%arg0: tensor<1x8x32x64xf32>, %arg1: tensor<1x8x32x64xf32>) -> tensor<1x8x32x32xf32> {
+ %0 = "mhlo.dot_general"(%arg0, %arg1) {
+ dot_dimension_numbers = {
+ lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>,
+ lhs_contracting_dimensions = dense<3> : tensor<1xi64>,
+ rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>,
+ rhs_contracting_dimensions = dense<3> : tensor<1xi64>
+ }, name = "dot_general_to_dot_trans_b", precision_config = ["DEFAULT", "DEFAULT"]
+ } : (tensor<1x8x32x64xf32>, tensor<1x8x32x64xf32>) -> tensor<1x8x32x32xf32>
+ return %0 : tensor<1x8x32x32xf32>
+}
+// CHECK: dot_general_to_dot_general_rank_reduced_b_transposed(%[[ARG0:.+]]: tensor<1x8x32x64xf32>, %[[ARG1:.+]]: tensor<1x8x32x64xf32>) -> tensor<1x8x32x32xf32>
+// CHECK: %[[ARG0_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG0]]) : (tensor<1x8x32x64xf32>) -> tensor<8x32x64xf32>
+// CHECK: %[[ARG1_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG1]]) : (tensor<1x8x32x64xf32>) -> tensor<8x32x64xf32>
+// CHECK: %[[ARG1_RESHAPED_TR:.+]] = "mhlo.transpose"(%[[ARG1_RESHAPED]]) {permutation = dense<[0, 2, 1]> : tensor<3xi64>} : (tensor<8x32x64xf32>) -> tensor<8x64x32xf32>
+// CHECK: %[[DOT_RESULT:.+]] = "mhlo.dot_general"(%[[ARG0_RESHAPED]], %[[ARG1_RESHAPED_TR]])
+// CHECK: %[[RESULT:.+]] = "mhlo.reshape"(%[[DOT_RESULT]]) : (tensor<8x32x32xf32>) -> tensor<1x8x32x32xf32>
+// CHECK: return %[[RESULT]] : tensor<1x8x32x32xf32>
+
+// -----
+
+func @dot_general_to_dot_general_rank_reduced_ab_transposed(%arg0: tensor<1x8x64x32xf32>, %arg1: tensor<1x8x32x64xf32>) -> tensor<1x8x32x32xf32> {
+ %0 = "mhlo.dot_general"(%arg0, %arg1) {
+ dot_dimension_numbers = {
+ lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>,
+ lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
+ rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>,
+ rhs_contracting_dimensions = dense<3> : tensor<1xi64>
+ }, name = "dot_general_to_dot_trans_ab", precision_config = ["DEFAULT", "DEFAULT"]
+ } : (tensor<1x8x64x32xf32>, tensor<1x8x32x64xf32>) -> tensor<1x8x32x32xf32>
+ return %0 : tensor<1x8x32x32xf32>
+}
+// CHECK: dot_general_to_dot_general_rank_reduced_ab_transposed(%[[ARG0:.+]]: tensor<1x8x64x32xf32>, %[[ARG1:.+]]: tensor<1x8x32x64xf32>) -> tensor<1x8x32x32xf32>
+// CHECK: %[[ARG0_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG0]]) : (tensor<1x8x64x32xf32>) -> tensor<8x64x32xf32>
+// CHECK: %[[ARG1_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG1]]) : (tensor<1x8x32x64xf32>) -> tensor<8x32x64xf32>
+// CHECK: %[[ARG0_RESHAPED_TR:.+]] = "mhlo.transpose"(%[[ARG0_RESHAPED]]) {permutation = dense<[0, 2, 1]> : tensor<3xi64>} : (tensor<8x64x32xf32>) -> tensor<8x32x64xf32>
+// CHECK: %[[ARG1_RESHAPED_TR:.+]] = "mhlo.transpose"(%[[ARG1_RESHAPED]]) {permutation = dense<[0, 2, 1]> : tensor<3xi64>} : (tensor<8x32x64xf32>) -> tensor<8x64x32xf32>
+// CHECK: %[[DOT_RESULT:.+]] = "mhlo.dot_general"(%[[ARG0_RESHAPED_TR]], %[[ARG1_RESHAPED_TR]])
+// CHECK: %[[RESULT:.+]] = "mhlo.reshape"(%[[DOT_RESULT]]) : (tensor<8x32x32xf32>) -> tensor<1x8x32x32xf32>
+// CHECK: return %[[RESULT]] : tensor<1x8x32x32xf32>
diff --git a/iree/test/e2e/xla_ops/BUILD b/iree/test/e2e/xla_ops/BUILD
index 55f8428..2bbd6b4 100644
--- a/iree/test/e2e/xla_ops/BUILD
+++ b/iree/test/e2e/xla_ops/BUILD
@@ -52,6 +52,7 @@
"cosine.mlir",
"divide.mlir",
"dot.mlir",
+ "dot_general.mlir",
"exponential.mlir",
"exponential_minus_one.mlir",
"floor.mlir",
@@ -104,6 +105,7 @@
"cosine.mlir",
"divide.mlir",
"dot.mlir",
+ "dot_general.mlir",
"exponential.mlir",
"exponential_minus_one.mlir",
"floor.mlir",
diff --git a/iree/test/e2e/xla_ops/CMakeLists.txt b/iree/test/e2e/xla_ops/CMakeLists.txt
index ae2a8c1..4bd7219 100644
--- a/iree/test/e2e/xla_ops/CMakeLists.txt
+++ b/iree/test/e2e/xla_ops/CMakeLists.txt
@@ -45,6 +45,7 @@
"cosine.mlir"
"divide.mlir"
"dot.mlir"
+ "dot_general.mlir"
"exponential.mlir"
"exponential_minus_one.mlir"
"floor.mlir"
@@ -97,6 +98,7 @@
"cosine.mlir"
"divide.mlir"
"dot.mlir"
+ "dot_general.mlir"
"exponential.mlir"
"exponential_minus_one.mlir"
"floor.mlir"