[StableHLO] Port reduce canon patterns (#14045)

It turns out that the linalg lowering depends on these patterns.

The code is ported from the equivalent folds and rewrite patterns in the
mlir-hlo repo. The only notable difference is that this implementation
fixes bugs where rewrite patterns performed op updates without going
through the pattern rewriter. This change also adds a new test for the
second noop case.
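
To illustrate the bug class being fixed (a minimal hypothetical
sketch; `SomeOp` and `replacement` are made-up names, not code from
this change): mutating a matched op directly leaves the greedy
driver's worklist and listeners out of sync, so in-place updates must
be announced through the rewriter, e.g. with `updateRootInPlace`:

    // Hypothetical sketch; `SomeOp` and `replacement` stand in for real names.
    struct Example final : OpRewritePattern<SomeOp> {
      using OpRewritePattern::OpRewritePattern;
      LogicalResult matchAndRewrite(SomeOp op,
                                    PatternRewriter &rewriter) const override {
        Value replacement = ...;
        // Buggy: op->setOperand(0, replacement);  // bypasses the rewriter
        // Fixed: announce the in-place update to the rewriter:
        rewriter.updateRootInPlace(op, [&] { op->setOperand(0, replacement); });
        return success();
      }
    };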

Fixes: https://github.com/openxla/iree/issues/14042
Issue: https://github.com/openxla/iree/issues/12678
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Canonicalization.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Canonicalization.cpp
index 3cb8e65..cc72647 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Canonicalization.cpp
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Canonicalization.cpp
@@ -694,6 +694,84 @@
   }
 };
 
+struct NoopReduceOpCanon final : OpRewritePattern<mlir::stablehlo::ReduceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::stablehlo::ReduceOp op,
+                                PatternRewriter &rewriter) const override {
+    // No dimensions to reduce.
+    if (op.getDimensions().empty()) {
+      rewriter.replaceOp(op, op.getInputs());
+      return success();
+    }
+
+    // If all values returned from the ReduceOp region are defined outside the
+    // region, replace the ReduceOp with those values.
+    if (auto retOp = dyn_cast<mlir::stablehlo::ReturnOp>(
+            op.getBody().front().getTerminator())) {
+      Region *retRegion = retOp->getParentRegion();
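+      // Bail out if any returned value is defined inside the reducer region.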
+      if (llvm::any_of(retOp.getResults(), [retRegion](Value result) {
+            return result.getParentRegion() == retRegion;
+          })) {
+        return failure();
+      }
+
+      rewriter.replaceOp(op, retOp.getResults());
+      return success();
+    }
+
+    return failure();
+  }
+};
+
+struct EmptyReduceOpCanon final : OpRewritePattern<mlir::stablehlo::ReduceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::stablehlo::ReduceOp op,
+                                PatternRewriter &rewriter) const override {
+    // We require all reduce shapes to be the same, up to the element types, so
+    // we can just use the first operand and the first result as representatives.
+    auto inputTy = dyn_cast<RankedTensorType>(op.getInputs().getType().front());
+    if (!inputTy) {
+      return rewriter.notifyMatchFailure(op.getLoc(),
+                                         "unranked input unsupported");
+    }
+
+    if (!llvm::is_contained(inputTy.getShape(), 0)) return failure();
+
+    Location loc = op.getLoc();
+    DenseIntElementsAttr empty = rewriter.getI64TensorAttr({});
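+    // Reducing an empty tensor yields the init value, so each result can be
+    // materialized as a broadcast of the corresponding init.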
+    if (inputTy.hasStaticShape()) {
+      SmallVector<Value> broadcasts(op.getNumResults());
+      for (auto [bcast, init, outTy] : llvm::zip_equal(
+               broadcasts, op.getInitValues(), op.getResultTypes())) {
+        bcast = rewriter.create<mlir::stablehlo::BroadcastInDimOp>(loc, outTy,
+                                                                   init, empty);
+      }
+      rewriter.replaceOp(op, broadcasts);
+      return success();
+    }
+
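+    // For dynamic result shapes, reify them and emit dynamic broadcasts.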
+    SmallVector<Value> shapes;
+    if (failed(op.reifyReturnTypeShapes(rewriter, op.getOperands(), shapes))) {
+      return failure();
+    }
+
+    SmallVector<Value> broadcasts(op.getNumResults());
+    for (auto [bcast, init, shape, outTy] : llvm::zip_equal(
+             broadcasts, op.getInitValues(), shapes, op.getResultTypes())) {
+      bcast = rewriter.create<mlir::stablehlo::DynamicBroadcastInDimOp>(
+          loc, outTy, init, shape, empty);
+    }
+    rewriter.replaceOp(op, broadcasts);
+    return success();
+  }
+};
+
 struct DynamicReshapeOpCanon final
     : OpRewritePattern<mlir::stablehlo::DynamicReshapeOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -922,6 +996,8 @@
       BroadcastInDimOpCanon, DynamicBroadcastInDimOpNotActuallyDynamic,
       ChainedDynamicBroadcastInDimCanonicalization,
       DynamicBroadcastInDimAllDimsNonExpanding,
+      // Reduce op.
+      NoopReduceOpCanon, EmptyReduceOpCanon,
       // Shape manipulation(-ish) ops.
       ConcatenateOpCanon, ConvertOpCanon, DynamicReshapeOpCanon, GatherOpCanon,
       ReshapeOpCanon, TransposeOpCanon>(context, benefit);
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalization.mlir b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalization.mlir
index 911b21a..abc297a 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalization.mlir
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalization.mlir
@@ -427,7 +427,7 @@
 
 // CHECK-LABEL: func @dynamic_broadcast_in_dim_op_not_actually_dynamic_constant_shape
 func.func @dynamic_broadcast_in_dim_op_not_actually_dynamic_constant_shape(%arg0: tensor<i32>) -> tensor<4x32xi32> {
-  %0 = mhlo.constant dense<[4, 32]> : tensor<2xi32>
+  %0 = stablehlo.constant dense<[4, 32]> : tensor<2xi32>
   // CHECK: %[[RESULT:.+]] = stablehlo.broadcast_in_dim %arg0, dims = [] : (tensor<i32>) -> tensor<4x32xi32>
   %1 = stablehlo.dynamic_broadcast_in_dim %arg0, %0, dims = [] : (tensor<i32>, tensor<2xi32>) -> tensor<?x32xi32>
   %2 = stablehlo.dynamic_reshape %1, %0 : (tensor<?x32xi32>, tensor<2xi32>) -> tensor<4x32xi32>
@@ -584,3 +584,50 @@
   // CHECK-NEXT: %[[V1:.*]] = stablehlo.reshape %[[V0]] : (tensor<1x2xui32>) -> tensor<2xui32>
   // CHECK-NEXT: return %[[V1]] : tensor<2xui32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @reduce_noop_1
+// CHECK-SAME:   ([[ARG0:%.+]]: tensor<4x8xf32>)
+func.func @reduce_noop_1(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
+  %0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
+  %1 = stablehlo.reduce(%arg0 init: %0) across dimensions = [] : (tensor<4x8xf32>, tensor<f32>) -> tensor<4x8xf32>
+    reducer(%arg1: tensor<f32>, %arg2: tensor<f32>) {
+    %4 = stablehlo.add %arg1, %arg2 : tensor<f32>
+    stablehlo.return %4 : tensor<f32>
+  }
+  // CHECK: return [[ARG0]] : tensor<4x8xf32>
+  func.return %1 : tensor<4x8xf32>
+}
+
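+// A reduce whose region returns a value defined outside of it is a noop.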
+// CHECK-LABEL: func.func @reduce_noop_2
+// CHECK-SAME:   ([[ARG0:%.+]]: tensor<4x8xi32>, [[ARG1:%.+]]: tensor<i32>)
+func.func @reduce_noop_2(%arg0: tensor<4x8xi32>, %arg1: tensor<i32>) -> tensor<i32> {
+  %0 = stablehlo.constant dense<0> : tensor<i32>
+  %1 = stablehlo.reduce(%arg0 init: %0) across dimensions = [0, 1] : (tensor<4x8xi32>, tensor<i32>) -> tensor<i32>
+    reducer(%b1: tensor<i32>, %b2: tensor<i32>) {
+    stablehlo.return %arg1 : tensor<i32>
+  }
+  // CHECK: return [[ARG1]] : tensor<i32>
+  func.return %1 : tensor<i32>
+}
+
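+// A reduce over a zero-sized dimension folds to a broadcast of the init value.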
+// CHECK-LABEL: func.func @reduce_zero_ext
+func.func @reduce_zero_ext(%arg0: tensor<0xi1>) -> tensor<i32> {
+  %0 = stablehlo.constant dense<false> : tensor<i1>
+  %1 = stablehlo.constant dense<false> : tensor<0xi1>
+  %2 = stablehlo.compare  NE, %arg0, %1, UNSIGNED : (tensor<0xi1>, tensor<0xi1>) -> tensor<0xi1>
+  %3 = stablehlo.convert %2 : (tensor<0xi1>) -> tensor<0xi32>
+  %4 = stablehlo.constant dense<0> : tensor<i32>
+  %5 = stablehlo.reduce(%3 init: %4) across dimensions = [0] : (tensor<0xi32>, tensor<i32>) -> tensor<i32>
+    reducer(%arg1: tensor<i32>, %arg2: tensor<i32>)  {
+    %6 = stablehlo.add %arg1, %arg2 : tensor<i32>
+    stablehlo.return %6 : tensor<i32>
+  }
+
+  // CHECK: [[CST:%.+]] = stablehlo.constant dense<0> : tensor<i32>
+  // CHECK: return [[CST]] : tensor<i32>
+  return %5 : tensor<i32>
+}