[StableHLO] Port reduce canon patterns (#14045)
It turns out that the linalg lowering depends on these patterns.
The code is ported from the equivalent folds and rewrite patterns in the
mlir-hlo repo. The only notable difference is that this implementation
fixes bugs where rewrite patterns performed op updates without going
through the pattern rewriter. Also adds a new test for the second noop
case (all values returned from the reducer are defined outside the
region).
Fixes: https://github.com/openxla/iree/issues/14042
Issue: https://github.com/openxla/iree/issues/12678
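
For context, the bug class being fixed looks roughly like the sketch below. This is a minimal illustration, not the actual mlir-hlo code; the pattern name and the `example.visited` attribute are hypothetical. The point is that any in-place update to a matched op must be routed through the rewriter (e.g. `updateRootInPlace`) so the greedy pattern driver is notified of the change.

```cpp
// Minimal sketch of the bug class, not the actual mlir-hlo code. The pattern
// name and the attribute being updated are hypothetical.
#include "mlir/IR/PatternMatch.h"
#include "stablehlo/dialect/StablehloOps.h"

using namespace mlir;

struct ExampleInPlaceUpdate final : OpRewritePattern<stablehlo::ReduceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(stablehlo::ReduceOp op,
                                PatternRewriter &rewriter) const override {
    // Guard so the pattern stops matching once it has applied.
    if (op->hasAttr("example.visited")) return failure();

    // Buggy: mutates the op behind the rewriter's back; the greedy driver
    // is never notified and may act on stale state.
    //   op->setAttr("example.visited", rewriter.getUnitAttr());

    // Correct: route the in-place update through the rewriter.
    rewriter.updateRootInPlace(op, [&] {
      op->setAttr("example.visited", rewriter.getUnitAttr());
    });
    return success();
  }
};
```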
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Canonicalization.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Canonicalization.cpp
index 3cb8e65..cc72647 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Canonicalization.cpp
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Canonicalization.cpp
@@ -694,6 +694,80 @@
}
};
+struct NoopReduceOpCanon final : OpRewritePattern<mlir::stablehlo::ReduceOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(mlir::stablehlo::ReduceOp op,
+ PatternRewriter &rewriter) const override {
+ // No dimensions to reduce.
+ if (op.getDimensions().empty()) {
+ rewriter.replaceOp(op, op.getInputs());
+ return success();
+ }
+
+ // If all returned values in the ReduceOp region exist outside the
+ // region, replace the ReduceOp with those values.
+ if (auto retOp = dyn_cast<mlir::stablehlo::ReturnOp>(
+ op.getBody().front().getTerminator())) {
+ Region *retRegion = retOp->getParentRegion();
+ if (llvm::any_of(retOp.getResults(), [retRegion](Value result) {
+ return result.getParentRegion() == retRegion;
+ })) {
+ return failure();
+ }
+
+ rewriter.replaceOp(op, retOp.getResults());
+ return success();
+ }
+
+ return failure();
+ }
+};
+
+struct EmptyReduceOpCanon final : OpRewritePattern<mlir::stablehlo::ReduceOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(mlir::stablehlo::ReduceOp op,
+ PatternRewriter &rewriter) const override {
+ // We require all reduce shapes to be the same, up to the element types, so
+ // we can just use the first operand and the first result as a representative.
+ auto elemTy = dyn_cast<RankedTensorType>(op.getInputs().getType().front());
+ if (!elemTy) {
+ return rewriter.notifyMatchFailure(op.getLoc(),
+ "unranked input unsupported");
+ }
+
+ if (!llvm::is_contained(elemTy.getShape(), 0)) return failure();
+
+ Location loc = op.getLoc();
+ DenseIntElementsAttr empty = rewriter.getI64TensorAttr({});
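+ // Static case: each result is simply the init value broadcast to the
+ // statically known result type.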
+ if (elemTy.hasStaticShape()) {
+ SmallVector<Value> broadcasts(op.getNumResults());
+ for (auto [bcast, init, outTy] : llvm::zip_equal(
+ broadcasts, op.getInitValues(), op.getResultTypes())) {
+ bcast = rewriter.create<mlir::stablehlo::BroadcastInDimOp>(loc, outTy,
+ init, empty);
+ }
+ rewriter.replaceOp(op, broadcasts);
+ return success();
+ }
+
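+ // Dynamic case: reify the result shapes and use dynamic_broadcast_in_dim
+ // to materialize the init values at the right shapes.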
+ SmallVector<Value> shapes;
+ if (failed(op.reifyReturnTypeShapes(rewriter, op.getOperands(), shapes))) {
+ return failure();
+ }
+
+ SmallVector<Value> broadcasts(op.getNumResults());
+ for (auto [bcast, init, shape, outTy] : llvm::zip_equal(
+ broadcasts, op.getInitValues(), shapes, op.getResultTypes())) {
+ bcast = rewriter.create<mlir::stablehlo::DynamicBroadcastInDimOp>(
+ loc, outTy, init, shape, empty);
+ }
+ rewriter.replaceOp(op, broadcasts);
+ return success();
+ }
+};
+
struct DynamicReshapeOpCanon final
: OpRewritePattern<mlir::stablehlo::DynamicReshapeOp> {
using OpRewritePattern::OpRewritePattern;
@@ -922,6 +996,8 @@
BroadcastInDimOpCanon, DynamicBroadcastInDimOpNotActuallyDynamic,
ChainedDynamicBroadcastInDimCanonicalization,
DynamicBroadcastInDimAllDimsNonExpanding,
+ // Reduce op.
+ NoopReduceOpCanon, EmptyReduceOpCanon,
// Shape manipulation(-ish) ops.
ConcatenateOpCanon, ConvertOpCanon, DynamicReshapeOpCanon, GatherOpCanon,
ReshapeOpCanon, TransposeOpCanon>(context, benefit);
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalization.mlir b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalization.mlir
index 911b21a..abc297a 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalization.mlir
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalization.mlir
@@ -427,7 +427,7 @@
// CHECK-LABEL: func @dynamic_broadcast_in_dim_op_not_actually_dynamic_constant_shape
func.func @dynamic_broadcast_in_dim_op_not_actually_dynamic_constant_shape(%arg0: tensor<i32>) -> tensor<4x32xi32> {
- %0 = mhlo.constant dense<[4, 32]> : tensor<2xi32>
+ %0 = stablehlo.constant dense<[4, 32]> : tensor<2xi32>
// CHECK: %[[RESULT:.+]] = stablehlo.broadcast_in_dim %arg0, dims = [] : (tensor<i32>) -> tensor<4x32xi32>
%1 = stablehlo.dynamic_broadcast_in_dim %arg0, %0, dims = [] : (tensor<i32>, tensor<2xi32>) -> tensor<?x32xi32>
%2 = stablehlo.dynamic_reshape %1, %0 : (tensor<?x32xi32>, tensor<2xi32>) -> tensor<4x32xi32>
@@ -584,3 +584,48 @@
// CHECK-NEXT: %[[V1:.*]] = stablehlo.reshape %[[V0]] : (tensor<1x2xui32>) -> tensor<2xui32>
// CHECK-NEXT: return %[[V1]] : tensor<2xui32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @reduce_noop_1
+// CHECK-SAME: ([[ARG0:%.+]]: tensor<4x8xf32>)
+func.func @reduce_noop_1(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
+ %0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
+ %1 = stablehlo.reduce(%arg0 init: %0) across dimensions = [] : (tensor<4x8xf32>, tensor<f32>) -> tensor<4x8xf32>
+ reducer(%arg1: tensor<f32>, %arg2: tensor<f32>) {
+ %4 = stablehlo.add %arg1, %arg2 : tensor<f32>
+ stablehlo.return %4 : tensor<f32>
+ }
+ // CHECK: return [[ARG0]] : tensor<4x8xf32>
+ func.return %1 : tensor<4x8xf32>
+}
+
+// CHECK-LABEL: func.func @reduce_noop_2
+// CHECK-SAME: ([[ARG0:%.+]]: tensor<4x8xi32>, [[ARG1:%.+]]: tensor<i32>)
+func.func @reduce_noop_2(%arg0: tensor<4x8xi32>, %arg1: tensor<i32>) -> tensor<i32> {
+ %0 = stablehlo.constant dense<0> : tensor<i32>
+ %1 = stablehlo.reduce(%arg0 init: %0) across dimensions = [0, 1] : (tensor<4x8xi32>, tensor<i32>) -> tensor<i32>
+ reducer(%b1: tensor<i32>, %b2: tensor<i32>) {
+ stablehlo.return %arg1 : tensor<i32>
+ }
+ // CHECK: return [[ARG1]] : tensor<i32>
+ func.return %1 : tensor<i32>
+}
+
+// CHECK-LABEL: func.func @reduce_zero_ext
+func.func @reduce_zero_ext(%arg0: tensor<0xi1>) -> tensor<i32> {
+ %0 = stablehlo.constant dense<false> : tensor<i1>
+ %1 = stablehlo.constant dense<false> : tensor<0xi1>
+ %2 = stablehlo.compare NE, %arg0, %1, UNSIGNED : (tensor<0xi1>, tensor<0xi1>) -> tensor<0xi1>
+ %3 = stablehlo.convert %2 : (tensor<0xi1>) -> tensor<0xi32>
+ %4 = stablehlo.constant dense<0> : tensor<i32>
+ %5 = stablehlo.reduce(%3 init: %4) across dimensions = [0] : (tensor<0xi32>, tensor<i32>) -> tensor<i32>
+ reducer(%arg1: tensor<i32>, %arg2: tensor<i32>) {
+ %6 = stablehlo.add %arg1, %arg2 : tensor<i32>
+ stablehlo.return %6 : tensor<i32>
+ }
+
+ // CHECK: [[CST:%.+]] = stablehlo.constant dense<0> : tensor<i32>
+ // CHECK: return [[CST]] : tensor<i32>
+ return %5 : tensor<i32>
+}