Remove redundant reshape checks in dot general preprocessing (#15319)
The final reshape does not need a guard to produce the correct output
size: it can be emitted unconditionally, as canonicalizers should clean
up any unneeded transforms.
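
For illustration, here is a sketch (on hypothetical shapes) of the
rewrite the new GeneralDotRemoveBatch pattern performs: a dot_general
whose single batching dimension is a leading unit dimension on both
operands becomes a batchless dot_general bracketed by reshapes.

    // Before: leading unit batch dimension on both operands.
    %0 = stablehlo.dot_general %lhs, %rhs,
           batching_dims = [0] x [0], contracting_dims = [2] x [1]
           : (tensor<1x3x4xf32>, tensor<1x4x5xf32>) -> tensor<1x3x5xf32>

    // After: the unit dimension is reshaped away, the contracting
    // dimensions shift down by one, and a trailing reshape restores
    // the original result type (folded later if redundant).
    %l = stablehlo.reshape %lhs : (tensor<1x3x4xf32>) -> tensor<3x4xf32>
    %r = stablehlo.reshape %rhs : (tensor<1x4x5xf32>) -> tensor<4x5xf32>
    %d = stablehlo.dot_general %l, %r, contracting_dims = [1] x [0]
           : (tensor<3x4xf32>, tensor<4x5xf32>) -> tensor<3x5xf32>
    %0 = stablehlo.reshape %d : (tensor<3x5xf32>) -> tensor<1x3x5xf32>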
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/DotGeneralToDot.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/DotGeneralToDot.cpp
index 6e7ecde..fca7077 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/DotGeneralToDot.cpp
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/DotGeneralToDot.cpp
@@ -148,6 +148,68 @@
return transposeReshape(arg, loc, contractDims, outerDims, shape, rewriter);
}
+struct GeneralDotRemoveBatch final
+ : OpRewritePattern<mlir::stablehlo::DotGeneralOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(mlir::stablehlo::DotGeneralOp op,
+ PatternRewriter &rewriter) const override {
+ auto lhsTy = cast<ShapedType>(op.getLhs().getType());
+ auto rhsTy = cast<ShapedType>(op.getRhs().getType());
+ auto ty = cast<ShapedType>(op.getType());
+
+ if (!ty.hasStaticShape()) {
+ return rewriter.notifyMatchFailure(op, "does not have static shape");
+ }
+
+ auto dimNumbers = op.getDotDimensionNumbers();
+ if (dimNumbers.getLhsBatchingDimensions().size() != 1 ||
+ dimNumbers.getRhsBatchingDimensions().size() != 1) {
+ return rewriter.notifyMatchFailure(op, "expected exactly one batching dimension");
+ }
+
+ if (dimNumbers.getLhsBatchingDimensions().front() != 0 ||
+ dimNumbers.getRhsBatchingDimensions().front() != 0) {
+ return rewriter.notifyMatchFailure(op, "not first dim on lhs/rhs");
+ }
+
+ if (lhsTy.getDimSize(0) != 1 || rhsTy.getDimSize(0) != 1) {
+ return rewriter.notifyMatchFailure(op, "not unary batch size");
+ }
+
+ // Drop the leading unit batch dimension; each contracting dimension shifts down by one.
+ llvm::SmallVector<int64_t> newLhsContractingDims;
+ for (auto dim : dimNumbers.getLhsContractingDimensions())
+ newLhsContractingDims.push_back(dim - 1);
+
+ llvm::SmallVector<int64_t> newRhsContractingDims;
+ for (auto dim : dimNumbers.getRhsContractingDimensions())
+ newRhsContractingDims.push_back(dim - 1);
+
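+ // Reshape away the leading unit batch dimension on each operand.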
+ auto lhs = rewriter.create<mlir::stablehlo::ReshapeOp>(
+ op.getLoc(), lhsTy.clone(lhsTy.getShape().drop_front()), op.getLhs());
+
+ auto rhs = rewriter.create<mlir::stablehlo::ReshapeOp>(
+ op.getLoc(), rhsTy.clone(rhsTy.getShape().drop_front()), op.getRhs());
+
+ auto newDimNumbers = mlir::stablehlo::DotDimensionNumbersAttr::get(
+ rewriter.getContext(),
+ /*lhsBatchingDimensions=*/{},
+ /*rhsBatchingDimensions=*/{},
+ /*lhsContractingDimensions=*/
+ newLhsContractingDims,
+ /*rhsContractingDimensions=*/
+ newRhsContractingDims);
+
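+ // Build the batchless dot_general, then restore the original result
+ // type with a reshape that canonicalization can fold if it is a no-op.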
+ auto dot = rewriter.create<mlir::stablehlo::DotGeneralOp>(
+ op.getLoc(), ty.clone(ty.getShape().drop_front()), lhs, rhs,
+ newDimNumbers, op.getPrecisionConfigAttr());
+ rewriter.replaceOpWithNewOp<mlir::stablehlo::ReshapeOp>(op, ty,
+ dot.getResult());
+ return success();
+ }
+};
+
struct GeneralDotConvert final
: OpRewritePattern<mlir::stablehlo::DotGeneralOp> {
using OpRewritePattern::OpRewritePattern;
@@ -374,7 +436,9 @@
void populatePreprocessingDotGeneralToDotPatterns(mlir::MLIRContext *context,
RewritePatternSet *patterns,
PatternBenefit benefit) {
- patterns->add<GeneralDotConvert, DotVectorOptimization>(context, benefit);
+ patterns
+ ->add<GeneralDotConvert, GeneralDotRemoveBatch, DotVectorOptimization>(
+ context, benefit);
}
} // namespace mlir::iree_compiler::stablehlo
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/StableHLOToStableHLO.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/StableHLOToStableHLO.cpp
index 30d7416..601c4e3 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/StableHLOToStableHLO.cpp
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/StableHLOToStableHLO.cpp
@@ -418,19 +418,17 @@
auto lhsNewType = cast<RankedTensorType>(lhs.getType());
auto rhsNewType = cast<RankedTensorType>(rhs.getType());
- // if lhs's shape or rhs's shape has collapsed, we need reshape the result
- bool needReshapeResult = lhsNewType.getRank() < lhsShapeType.getRank() ||
- rhsNewType.getRank() < rhsShapeType.getRank();
// Result dimensions follow the convention: batching, lhs parallel, rhs parallel.
- SmallVector<int64_t> newShape = {lhsNewType.getShape()[0],
- lhsNewType.getShape()[1]};
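+ // New result shape: batch dim, then the lhs and rhs parallel dims when present.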
+ SmallVector<int64_t, 3> newShape = {lhsNewType.getShape()[0]};
+
+ if (lhsNewType.getRank() > 2)
+ newShape.push_back(lhsNewType.getDimSize(1));
+
if (rhsNewType.getRank() > 2)
newShape.push_back(rhsNewType.getDimSize(2));
TensorType newResultType =
- needReshapeResult
- ? RankedTensorType::get(newShape, resultType.getElementType())
- : op.getType();
+ RankedTensorType::get(newShape, resultType.getElementType());
auto newOp = rewriter.create<mlir::stablehlo::DotGeneralOp>(
op.getLoc(), newResultType, lhs, rhs, dimensionNumbers,
@@ -446,10 +444,11 @@
}
Value result = newOp.getResult();
- if (needReshapeResult) {
- result = rewriter.create<mlir::stablehlo::ReshapeOp>(op.getLoc(),
- resultType, result);
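+ // Reshape back only when the collapsed result type differs from the op's type.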
+ if (op.getType() != newResultType) {
+ result = rewriter.create<mlir::stablehlo::ReshapeOp>(
+ op.getLoc(), op.getType(), result);
}
+
rewriter.replaceOp(op, result);
return success();
}
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalize_dot_general.mlir b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalize_dot_general.mlir
index bcb2ac1..f08a02c 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalize_dot_general.mlir
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalize_dot_general.mlir
@@ -17,6 +17,8 @@
return %0 : tensor<3xf32>
}
+// -----
+
// CHECK-LABEL: @dot_general_4d
func.func public @dot_general_4d(%arg0: tensor<1x2x3xf32> {stablehlo.sharding = ""}, %arg1: tensor<1x4x2x3xf32> {stablehlo.sharding = ""}) -> tensor<1x2x4xf32> {
%0 = stablehlo.dot_general %arg0, %arg1,
@@ -35,3 +37,18 @@
// CHECK-NEXT: return %[[RES]]
return %0 : tensor<1x2x4xf32>
}
+
+// -----
+
+// CHECK-LABEL: @unary_out_channel_dot
+func.func public @unary_out_channel_dot(%arg0: tensor<1x3x4xui16>, %arg1: tensor<1x4x3xui16>) -> tensor<1xui16> {
+
+ // CHECK: %[[TRANS:.+]] = stablehlo.transpose %arg0, dims = [0, 2, 1]
+ // CHECK: %[[LHS:.+]] = stablehlo.reshape %[[TRANS]] : (tensor<1x4x3xui16>) -> tensor<12xui16>
+ // CHECK: %[[RHS:.+]] = stablehlo.reshape %arg1 : (tensor<1x4x3xui16>) -> tensor<12xui16>
+ // CHECK: %[[DOT:.+]] = stablehlo.dot %[[LHS]], %[[RHS]]
+ // CHECK: %[[OUT:.+]] = stablehlo.reshape %[[DOT]] : (tensor<ui16>) -> tensor<1xui16>
+ %0 = stablehlo.dot_general %arg0, %arg1, batching_dims = [0] x [0], contracting_dims = [2, 1] x [1, 2], precision = [HIGH, HIGH] : (tensor<1x3x4xui16>, tensor<1x4x3xui16>) -> tensor<1xui16>
+ return %0 : tensor<1xui16>
+}