Explicitly ordering transpose and reshape patterns when legalizing mhlo.dot_general
1. RankReducedDotGeneral assumes the batchingDims are at the leading positions and that the parallelDims are consecutive.
2. The rank also needs to be reduced after the parallelDims are collapsed.
[Solution]:
Move TransposeGenericDotGeneral in front of RankReducedDotGeneral (the two are merged into a single TransposeReshapeGenericDotGeneral pattern), and reduce the rank after collapsing the parallelDims.
Fixes https://github.com/google/iree/issues/7272
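
For illustration, here is a sketch of how the combined pattern rewrites a
dot_general with 1-D batching and multiple lhs parallel dims, using the shapes
from the new lit test below (SSA names are illustrative; attributes elided):

  // Input (batch = 4, lhs parallel = 64x155, contraction = 36):
  %0 = "mhlo.dot_general"(%arg0, %arg1) {...}
      : (tensor<64x155x4x36xf32>, tensor<309x4x36xf32>) -> tensor<4x64x155x309xf32>

  // After the rewrite:
  %t0 = "mhlo.transpose"(%arg0) {permutation = dense<[2, 0, 1, 3]> : tensor<4xi64>}
      : (tensor<64x155x4x36xf32>) -> tensor<4x64x155x36xf32>
  %t1 = "mhlo.transpose"(%arg1) {permutation = dense<[1, 2, 0]> : tensor<3xi64>}
      : (tensor<309x4x36xf32>) -> tensor<4x36x309xf32>
  // Collapse the lhs parallel dims (64 * 155 = 9920); %t1 is already rank 3.
  %r0 = "mhlo.reshape"(%t0) : (tensor<4x64x155x36xf32>) -> tensor<4x9920x36xf32>
  %d = "mhlo.dot_general"(%r0, %t1) {...}
      : (tensor<4x9920x36xf32>, tensor<4x36x309xf32>) -> tensor<4x9920x309xf32>
  %r = "mhlo.reshape"(%d) : (tensor<4x9920x309xf32>) -> tensor<4x64x155x309xf32>

Doing the transposes before the reshapes guarantees that the batch, parallel,
and contraction dims are consecutive, so each group can be collapsed into a
single dim.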
diff --git a/iree/compiler/InputConversion/MHLO/MHLOToMHLOPreprocessing.cpp b/iree/compiler/InputConversion/MHLO/MHLOToMHLOPreprocessing.cpp
index ca32815..abf39f2 100644
--- a/iree/compiler/InputConversion/MHLO/MHLOToMHLOPreprocessing.cpp
+++ b/iree/compiler/InputConversion/MHLO/MHLOToMHLOPreprocessing.cpp
@@ -493,11 +493,16 @@
// inserts transposes so the dot_general always has the form:
// {batch_dims, parallel_dims, contraction_dims}.
// {batch_dims, contraction_dims, parallel_dims}
-class TransposeGenericDotGeneral : public OpRewritePattern<mhlo::DotGeneralOp> {
+// After that, the batch, contraction, and parallel dims are in consecutive
+// order and do not split the domain. This pattern then inserts reshapes to
+// collapse the consecutive contraction and parallel dims so that it always
+// generates a rank-3 dot_general op.
+class TransposeReshapeGenericDotGeneral
+ : public OpRewritePattern<mhlo::DotGeneralOp> {
public:
using OpRewritePattern<mhlo::DotGeneralOp>::OpRewritePattern;
- Value TransposeIfNonConsecutive(OpBuilder b, Location loc, Value src,
+ Value TransposeIfNonConsecutive(OpBuilder &b, Location loc, Value src,
ArrayRef<int64_t> targetOrder) const {
if (isConsecutive(targetOrder)) return src;
auto type = src.getType().cast<RankedTensorType>();
@@ -510,6 +515,23 @@
b.getI64TensorAttr(targetOrder));
}
+ Value ReshapeIfMorethan3D(OpBuilder &b, Location loc, Value src,
+ size_t dimsBorder0, size_t dimsBorder1) const {
+ auto type = src.getType().cast<RankedTensorType>();
+ if (type.getRank() <= 3) return src;
+ auto shape = type.getShape();
+    SmallVector<int64_t, 4> resultShape = {
+        std::accumulate(shape.begin(), shape.begin() + dimsBorder0,
+                        int64_t{1}, std::multiplies<int64_t>()),
+        std::accumulate(shape.begin() + dimsBorder0,
+                        shape.begin() + dimsBorder1, int64_t{1},
+                        std::multiplies<int64_t>()),
+        std::accumulate(shape.begin() + dimsBorder1, shape.end(), int64_t{1},
+                        std::multiplies<int64_t>())};
+    return b.create<mhlo::ReshapeOp>(
+        loc, RankedTensorType::get(resultShape, type.getElementType()), src);
+ }
+
LogicalResult matchAndRewrite(mhlo::DotGeneralOp op,
PatternRewriter &rewriter) const override {
auto lhsShapeType = op.lhs().getType().dyn_cast<RankedTensorType>();
@@ -559,167 +581,53 @@
lhsTargetOrder);
Value rhs = TransposeIfNonConsecutive(rewriter, op.getLoc(), op.rhs(),
rhsTargetOrder);
- if (lhs == op.lhs() && rhs == op.rhs()) return failure();
+    // After the transposes above, lhs always has the form {batch_dims,
+    // parallel_dims, contraction_dims} and rhs has the form {batch_dims,
+    // contraction_dims, parallel_dims}; the logic below relies on this.
+    // TODO(#7443): If transpose performance is taken into account, this
+    // assumption may no longer hold.
int64_t numLhsContractionDims = lhsContractingDims.size();
int64_t lhsContractionBase = lhsShapeType.getRank() - numLhsContractionDims;
int64_t rhsContractionBase = rhsBatchingDims.size();
int64_t numRhsContractionDims =
rhsContractionBase + rhsContractingDims.size();
- auto lhsBatchingDimsAttr =
- llvm::to_vector<4>(llvm::seq<int64_t>(0, lhsBatchingDims.size()));
- auto rhsBatchingDimsAttr =
- llvm::to_vector<4>(llvm::seq<int64_t>(0, rhsBatchingDims.size()));
- auto lhsContractingDimsAttr = llvm::to_vector<4>(
- llvm::seq<int64_t>(lhsContractionBase, lhsShapeType.getRank()));
- auto rhsContractingDimsAttr = llvm::to_vector<4>(
- llvm::seq<int64_t>(rhsContractionBase, numRhsContractionDims));
+
+    lhs = ReshapeIfMorethan3D(rewriter, op.getLoc(), lhs,
+                              lhsBatchingDims.size(), lhsContractionBase);
+ rhs = ReshapeIfMorethan3D(rewriter, op.getLoc(), rhs,
+ rhsBatchingDims.size(), numRhsContractionDims);
+
+ if (lhs == op.lhs() && rhs == op.rhs()) return failure();
+
auto dimensionNumbers = mhlo::DotDimensionNumbersAttr::get(
- rewriter.getContext(), lhsBatchingDimsAttr, rhsBatchingDimsAttr,
- lhsContractingDimsAttr, rhsContractingDimsAttr);
+ rewriter.getContext(), /*lhsBatchingDimensions=*/0,
+ /*rhsBatchingDimensions=*/0,
+ /*lhsContractingDimensions=*/2, /*rhsContractingDimensions=*/1);
+ auto lhsNewType = lhs.getType().cast<RankedTensorType>();
+ auto rhsNewType = rhs.getType().cast<RankedTensorType>();
+
+    // If the lhs or rhs shape was collapsed, the result must be reshaped back.
+ bool needReshapeResult = lhsNewType.getRank() < lhsShapeType.getRank() ||
+ rhsNewType.getRank() < rhsShapeType.getRank();
+    // By convention the result dims are ordered {batch, lhs parallel, rhs parallel}.
+ SmallVector<int64_t, 4> newShape = {lhsNewType.getShape()[0],
+ lhsNewType.getShape()[1],
+ rhsNewType.getShape()[2]};
+ auto newResultType =
+ needReshapeResult
+ ? RankedTensorType::get(newShape, resultType.getElementType())
+ : op.getType();
Value result = rewriter.create<mhlo::DotGeneralOp>(
- op.getLoc(), op.getType(), lhs, rhs, dimensionNumbers,
- op.precision_configAttr());
- rewriter.replaceOp(op, result);
- return success();
- }
-};
-
-// Rewrite mhlo.dot_general to operate on rank-3 tensors when reduction dims are
-// in consecutive order and not spliting the domain. This pattern inserts
-// reshapes to collapse consecutive reduction and parallel dims to always
-// generate a rank-3 dot_general op.
-class RankReducedDotGeneral : public OpRewritePattern<mhlo::DotGeneralOp> {
- public:
- using OpRewritePattern<mhlo::DotGeneralOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(mhlo::DotGeneralOp op,
- PatternRewriter &rewriter) const override {
- auto lhsShapeType = op.lhs().getType().dyn_cast<ShapedType>();
- auto rhsShapeType = op.rhs().getType().dyn_cast<ShapedType>();
- auto resultType = op.getResult().getType().dyn_cast<ShapedType>();
-
- if (!lhsShapeType || !rhsShapeType || !resultType) return failure();
- if (!lhsShapeType.hasStaticShape() || !rhsShapeType.hasStaticShape())
- return failure();
- if (resultType.getRank() <= 3) return failure();
-
- mhlo::DotDimensionNumbersAttr dimNumbers = op.dot_dimension_numbers();
- auto lhsBatchingDims =
- llvm::to_vector<4>(dimNumbers.getLhsBatchingDimensions());
- auto rhsBatchingDims =
- llvm::to_vector<4>(dimNumbers.getRhsBatchingDimensions());
- auto lhsContractingDims =
- llvm::to_vector<4>(dimNumbers.getLhsContractingDimensions());
- auto rhsContractingDims =
- llvm::to_vector<4>(dimNumbers.getRhsContractingDimensions());
-
- if (lhsBatchingDims.empty() || rhsBatchingDims.empty()) return failure();
-
- llvm::sort(lhsBatchingDims);
- llvm::sort(lhsContractingDims);
- llvm::sort(rhsBatchingDims);
- llvm::sort(rhsContractingDims);
-
- auto isDomainSplit = [](ArrayRef<int64_t> shape,
- ArrayRef<int64_t> batchingDims,
- ArrayRef<int64_t> contractingDims) {
- // Batching and contracting are contiguous.
- if ((contractingDims.front() - batchingDims.back()) == 1) return false;
- // Contracting dims are inner most.
- if (contractingDims.back() == (shape.size() - 1)) return false;
- return true;
- };
-
- if (!isConsecutive(lhsBatchingDims) || !isConsecutive(lhsContractingDims) ||
- !isConsecutive(rhsBatchingDims) || !isConsecutive(rhsContractingDims))
- return failure();
-
- if (isDomainSplit(lhsShapeType.getShape(), lhsBatchingDims,
- lhsContractingDims) ||
- isDomainSplit(rhsShapeType.getShape(), rhsBatchingDims,
- rhsContractingDims))
- return failure();
-
- // Collapsing shape into a rank-3 tensor, returns newCollabsedShape
- // contraction and parallel dim indices.
- auto computeCollapsedShape = [](ArrayRef<int64_t> shape,
- ArrayRef<int64_t> batchingDims,
- ArrayRef<int64_t> contractingDims) {
- auto newRank =
- shape.size() - batchingDims.size() - contractingDims.size() + 2;
- auto batchingSize = std::accumulate(
- batchingDims.begin(), batchingDims.end(), 1,
- [shape](const int64_t accum, const int64_t index) -> int64_t {
- return accum * shape[index];
- });
- auto contractingSize = std::accumulate(
- contractingDims.begin(), contractingDims.end(), 1,
- [shape](const int64_t accum, const int64_t index) -> int64_t {
- return accum * shape[index];
- });
-
- int parallelDimIndex, contractingDimIndex, parallelDimSize = 1;
- if (contractingDims.front() - batchingDims.back() > 1) {
- parallelDimIndex = 1;
- contractingDimIndex = 2;
- for (int i = batchingDims.back() + 1; i < contractingDims.front();
- ++i) {
- parallelDimSize *= shape[i];
- }
- } else {
- contractingDimIndex = 1;
- parallelDimIndex = 2;
- for (int i = contractingDims.back() + 1; i < shape.size(); ++i) {
- parallelDimSize *= shape[i];
- }
- }
- llvm::SmallVector<int64_t, 4> newShape(newRank);
- newShape[0] = batchingSize;
- newShape[contractingDimIndex] = contractingSize;
- newShape[parallelDimIndex] = parallelDimSize;
- return std::make_tuple(newShape, contractingDimIndex, parallelDimIndex);
- };
-
- int lhsContractingDimIndex, rhsContractingDimIndex, lhsParallelDimIndex,
- rhsParallelDimIndex;
- SmallVector<int64_t, 4> lhsNewShape, rhsNewShape;
- std::tie(lhsNewShape, lhsContractingDimIndex, lhsParallelDimIndex) =
- computeCollapsedShape(lhsShapeType.getShape(), lhsBatchingDims,
- lhsContractingDims);
-
- std::tie(rhsNewShape, rhsContractingDimIndex, rhsParallelDimIndex) =
- computeCollapsedShape(rhsShapeType.getShape(), rhsBatchingDims,
- rhsContractingDims);
- SmallVector<int64_t, 4> resultNewShape = {lhsNewShape[0],
- lhsNewShape[lhsParallelDimIndex],
- rhsNewShape[rhsParallelDimIndex]};
- Type dotGeneralResultType =
- RankedTensorType::get(resultNewShape, resultType.getElementType());
-
- auto loc = op.getLoc();
- Value reshapedLhs = rewriter.create<mhlo::ReshapeOp>(
- loc, RankedTensorType::get(lhsNewShape, lhsShapeType.getElementType()),
- op.lhs());
- Value reshapedRhs = rewriter.create<mhlo::ReshapeOp>(
- loc, RankedTensorType::get(rhsNewShape, rhsShapeType.getElementType()),
- op.rhs());
- auto dimensionNumbers = mhlo::DotDimensionNumbersAttr::get(
- rewriter.getContext(),
- /*lhs_batching_dimensions=*/{0},
- /*rhs_batching_dimensions=*/{0},
- /*lhs_contracting_dimensions=*/{lhsContractingDimIndex},
- /*rhs_contracting_dimensions=*/
- {rhsContractingDimIndex});
- Value dotGeneralResult = rewriter.create<mhlo::DotGeneralOp>(
- loc, dotGeneralResultType, reshapedLhs, reshapedRhs, dimensionNumbers,
+ op.getLoc(), newResultType, lhs, rhs, dimensionNumbers,
op.precision_configAttr());
- Value result =
- rewriter.create<mhlo::ReshapeOp>(loc, resultType, dotGeneralResult);
+ if (needReshapeResult) {
+ result =
+ rewriter.create<mhlo::ReshapeOp>(op.getLoc(), resultType, result);
+ }
rewriter.replaceOp(op, result);
-
return success();
}
};
@@ -915,7 +823,7 @@
// dot_general canoncalization patterns.
mhlo::PopulateGeneralDotOpLoweringPatterns(&patterns, context);
- patterns.insert<RankReducedDotGeneral, TransposeGenericDotGeneral>(context);
+ patterns.insert<TransposeReshapeGenericDotGeneral>(context);
// Unary elementwise op.
patterns.insert<
diff --git a/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing_canoncalize_dot_general.mlir b/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing_canoncalize_dot_general.mlir
index fbef414..ee8ddf6 100644
--- a/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing_canoncalize_dot_general.mlir
+++ b/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing_canoncalize_dot_general.mlir
@@ -56,10 +56,10 @@
return %0 : tensor<1x8x32x32xf32>
}
// CHECK: dot_general_to_dot_general_rank_reduced_a_transposed(%[[ARG0:.+]]: tensor<1x8x64x32xf32>, %[[ARG1:.+]]: tensor<1x8x64x32xf32>) -> tensor<1x8x32x32xf32>
-// CHECK: %[[ARG0_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG0]]) : (tensor<1x8x64x32xf32>) -> tensor<8x64x32xf32>
+// CHECK: %[[ARG0_RESHAPED_TR:.+]] = "mhlo.transpose"(%[[ARG0]]) {permutation = dense<[0, 1, 3, 2]> : tensor<4xi64>} : (tensor<1x8x64x32xf32>) -> tensor<1x8x32x64xf32>
+// CHECK: %[[ARG0_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG0_RESHAPED_TR]]) : (tensor<1x8x32x64xf32>) -> tensor<8x32x64xf32>
// CHECK: %[[ARG1_RSSHAPED:.+]] = "mhlo.reshape"(%[[ARG1]]) : (tensor<1x8x64x32xf32>) -> tensor<8x64x32xf32>
-// CHECK: %[[ARG0_RESHAPED_TR:.+]] = "mhlo.transpose"(%[[ARG0_RESHAPED]]) {permutation = dense<[0, 2, 1]> : tensor<3xi64>} : (tensor<8x64x32xf32>) -> tensor<8x32x64xf32>
-// CHECK: %[[DOT_RESULT:.+]] = "mhlo.dot_general"(%[[ARG0_RESHAPED_TR]], %[[ARG1_RSSHAPED]])
+// CHECK: %[[DOT_RESULT:.+]] = "mhlo.dot_general"(%[[ARG0_RESHAPED]], %[[ARG1_RSSHAPED]])
// CHECK: %[[RESULT:.+]] = "mhlo.reshape"(%[[DOT_RESULT]]) : (tensor<8x32x32xf32>) -> tensor<1x8x32x32xf32>
// -----
@@ -69,7 +69,7 @@
dot_dimension_numbers = #mhlo.dot<
lhs_batching_dimensions = [0, 1],
lhs_contracting_dimensions = [3],
- rhs_batching_dimensions = [0,1 ],
+ rhs_batching_dimensions = [0, 1],
rhs_contracting_dimensions = [3],
>,
precision_config = ["DEFAULT", "DEFAULT"]
@@ -77,10 +77,10 @@
return %0 : tensor<1x8x32x32xf32>
}
// CHECK: dot_general_to_dot_general_rank_reduced_b_transposed(%[[ARG0:.+]]: tensor<1x8x32x64xf32>, %[[ARG1:.+]]: tensor<1x8x32x64xf32>) -> tensor<1x8x32x32xf32>
+// CHECK: %[[ARG1_RESHAPED_TR:.+]] = "mhlo.transpose"(%[[ARG1]]) {permutation = dense<[0, 1, 3, 2]> : tensor<4xi64>} : (tensor<1x8x32x64xf32>) -> tensor<1x8x64x32xf32>
// CHECK: %[[ARG0_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG0]]) : (tensor<1x8x32x64xf32>) -> tensor<8x32x64xf32>
-// CHECK: %[[ARG1_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG1]]) : (tensor<1x8x32x64xf32>) -> tensor<8x32x64xf32>
-// CHECK: %[[ARG1_RESHAPED_TR:.+]] = "mhlo.transpose"(%[[ARG1_RESHAPED]]) {permutation = dense<[0, 2, 1]> : tensor<3xi64>} : (tensor<8x32x64xf32>) -> tensor<8x64x32xf32>
-// CHECK: %[[DOT_RESULT:.+]] = "mhlo.dot_general"(%[[ARG0_RESHAPED]], %[[ARG1_RESHAPED_TR]])
+// CHECK: %[[ARG1_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG1_RESHAPED_TR]]) : (tensor<1x8x64x32xf32>) -> tensor<8x64x32xf32>
+// CHECK: %[[DOT_RESULT:.+]] = "mhlo.dot_general"(%[[ARG0_RESHAPED]], %[[ARG1_RESHAPED]])
// CHECK: %[[RESULT:.+]] = "mhlo.reshape"(%[[DOT_RESULT]]) : (tensor<8x32x32xf32>) -> tensor<1x8x32x32xf32>
// CHECK: return %[[RESULT]] : tensor<1x8x32x32xf32>
@@ -100,11 +100,11 @@
return %0 : tensor<1x8x32x32xf32>
}
// CHECK: dot_general_to_dot_general_rank_reduced_ab_transposed(%[[ARG0:.+]]: tensor<1x8x64x32xf32>, %[[ARG1:.+]]: tensor<1x8x32x64xf32>) -> tensor<1x8x32x32xf32>
-// CHECK: %[[ARG0_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG0]]) : (tensor<1x8x64x32xf32>) -> tensor<8x64x32xf32>
-// CHECK: %[[ARG1_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG1]]) : (tensor<1x8x32x64xf32>) -> tensor<8x32x64xf32>
-// CHECK: %[[ARG0_RESHAPED_TR:.+]] = "mhlo.transpose"(%[[ARG0_RESHAPED]]) {permutation = dense<[0, 2, 1]> : tensor<3xi64>} : (tensor<8x64x32xf32>) -> tensor<8x32x64xf32>
-// CHECK: %[[ARG1_RESHAPED_TR:.+]] = "mhlo.transpose"(%[[ARG1_RESHAPED]]) {permutation = dense<[0, 2, 1]> : tensor<3xi64>} : (tensor<8x32x64xf32>) -> tensor<8x64x32xf32>
-// CHECK: %[[DOT_RESULT:.+]] = "mhlo.dot_general"(%[[ARG0_RESHAPED_TR]], %[[ARG1_RESHAPED_TR]])
+// CHECK: %[[ARG0_RESHAPED_TR:.+]] = "mhlo.transpose"(%[[ARG0]]) {permutation = dense<[0, 1, 3, 2]> : tensor<4xi64>} : (tensor<1x8x64x32xf32>) -> tensor<1x8x32x64xf32>
+// CHECK: %[[ARG1_RESHAPED_TR:.+]] = "mhlo.transpose"(%[[ARG1]]) {permutation = dense<[0, 1, 3, 2]> : tensor<4xi64>} : (tensor<1x8x32x64xf32>) -> tensor<1x8x64x32xf32>
+// CHECK: %[[ARG0_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG0_RESHAPED_TR]]) : (tensor<1x8x32x64xf32>) -> tensor<8x32x64xf32>
+// CHECK: %[[ARG1_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG1_RESHAPED_TR]]) : (tensor<1x8x64x32xf32>) -> tensor<8x64x32xf32>
+// CHECK: %[[DOT_RESULT:.+]] = "mhlo.dot_general"(%[[ARG0_RESHAPED]], %[[ARG1_RESHAPED]])
// CHECK: %[[RESULT:.+]] = "mhlo.reshape"(%[[DOT_RESULT]]) : (tensor<8x32x32xf32>) -> tensor<1x8x32x32xf32>
// CHECK: return %[[RESULT]] : tensor<1x8x32x32xf32>
@@ -135,3 +135,31 @@
// CHECK: %[[DOT_RESULT:.+]] = "mhlo.dot_general"(%[[ARG0_RESHAPED]], %[[ARG1_RESHAPED]])
// CHECK: %[[RESULT:.+]] = "mhlo.reshape"(%[[DOT_RESULT]]) : (tensor<8x1x512xf32>) -> tensor<1x8x1x512xf32>
// CHECK: return %[[RESULT]] : tensor<1x8x1x512xf32>
+
+// -----
+
+func @dot_general_1d_batching_1d_contracting(%arg0: tensor<64x155x4x36xf32>, %arg1: tensor<309x4x36xf32>) -> tensor<4x64x155x309xf32> {
+ %0 = "mhlo.dot_general"(%arg0, %arg1) {
+ dot_dimension_numbers = #mhlo.dot<
+ lhs_batching_dimensions = [2],
+ rhs_batching_dimensions = [1],
+ lhs_contracting_dimensions = [3],
+ rhs_contracting_dimensions = [2]
+ >} : (tensor<64x155x4x36xf32>, tensor<309x4x36xf32>) -> tensor<4x64x155x309xf32>
+ return %0 : tensor<4x64x155x309xf32>
+}
+
+// CHECK-LABEL: func @dot_general_1d_batching_1d_contracting
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]
+// CHECK: %[[ARG0_RESHAPED_TR:.+]] = "mhlo.transpose"(%[[ARG0]])
+// CHECK-SAME: {permutation = dense<[2, 0, 1, 3]> : tensor<4xi64>}
+// CHECK-SAME: (tensor<64x155x4x36xf32>) -> tensor<4x64x155x36xf32>
+// CHECK: %[[ARG1_RESHAPED_TR:.+]] = "mhlo.transpose"(%[[ARG1]])
+// CHECK-SAME: {permutation = dense<[1, 2, 0]> : tensor<3xi64>}
+// CHECK-SAME: (tensor<309x4x36xf32>) -> tensor<4x36x309xf32>
+// CHECK: %[[ARG0_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG0_RESHAPED_TR]])
+// CHECK-SAME: (tensor<4x64x155x36xf32>) -> tensor<4x9920x36xf32>
+// CHECK: %[[DOT_RESULT:.+]] = "mhlo.dot_general"(%[[ARG0_RESHAPED]], %[[ARG1_RESHAPED_TR]])
+// CHECK-SAME: (tensor<4x9920x36xf32>, tensor<4x36x309xf32>) -> tensor<4x9920x309xf32>
+// CHECK: "mhlo.reshape"(%[[DOT_RESULT]]) : (tensor<4x9920x309xf32>) -> tensor<4x64x155x309xf32>
diff --git a/iree/test/e2e/xla_ops/dot_general.mlir b/iree/test/e2e/xla_ops/dot_general.mlir
index 48e106f..d56b98a 100644
--- a/iree/test/e2e/xla_ops/dot_general.mlir
+++ b/iree/test/e2e/xla_ops/dot_general.mlir
@@ -141,3 +141,33 @@
check.expect_almost_eq_const(%res, dense<409.596> : tensor<4x32x64xf32>) : tensor<4x32x64xf32>
return
}
+
+func @dot_general_nontrivial_batching_multiple_parallel_dimension() {
+ %lhs = util.unfoldable_constant dense<[
+ [[[0.0], [1.0]], [[2.0], [3.0]], [[ 4.0], [ 5.0]]],
+ [[[6.0], [7.0]], [[8.0], [9.0]], [[10.0], [11.0]]]
+ ]> : tensor<2x3x2x1xf32>
+ %rhs = util.unfoldable_constant dense<[
+ [[0.0], [1.0]], [[2.0], [3.0]]
+ ]> : tensor<2x2x1xf32>
+ %res = "mhlo.dot_general"(%lhs, %rhs) {
+ dot_dimension_numbers = #mhlo.dot<
+ lhs_batching_dimensions = [2],
+ rhs_batching_dimensions = [1],
+ lhs_contracting_dimensions = [3],
+ rhs_contracting_dimensions = [2]
+ >,
+ precision_config = ["DEFAULT", "DEFAULT"]
+ } : (tensor<2x3x2x1xf32>, tensor<2x2x1xf32>) -> tensor<2x2x3x2xf32>
+ check.expect_almost_eq_const(%res, dense<[
+ [
+ [[0.0, 0.0], [0.0, 4.0], [0.0, 8.0]],
+ [[0.0, 12.0], [0.0, 16.0], [0.0, 20.0]]
+ ],
+ [
+ [[1.0, 3.0], [3.0, 9.0], [ 5.0, 15.0]],
+ [[7.0, 21.0], [9.0, 27.0], [11.0, 33.0]]
+ ]
+ ]> : tensor<2x2x3x2xf32>) : tensor<2x2x3x2xf32>
+ return
+}
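
For reference, the expected constant above follows from the result layout
{batch, lhs parallel dims, rhs parallel dim}: with a size-1 contraction dim,
res[b, i, j, k] = lhs[i, j, b, 0] * rhs[k, b, 0], e.g.
res[1, 1, 2, 1] = 11.0 * 3.0 = 33.0.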