[StableHLO] Make reduce lowering more robust (#14046)

Check if reduce ops are supported. This is so that these patterns can be
given any reduce op, even those that would normally be folded away by
canonicalization patterns.

Issue: https://github.com/openxla/iree/issues/14042
Issue: https://github.com/openxla/iree/issues/12678
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgReduce.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgReduce.cpp
index 0d0c315..cc9d782 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgReduce.cpp
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgReduce.cpp
@@ -10,6 +10,7 @@
 
 #include "iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.h"
 #include "iree/compiler/InputConversion/StableHLO/Rewriters.h"
+#include "llvm/ADT/STLExtras.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -17,6 +18,23 @@
 
 namespace mlir::iree_compiler::stablehlo {
 namespace {
+/// Returns true when reduction `op` is not supported and should be filtered
+/// out.
+static bool isUnsupported(mlir::stablehlo::ReduceOp op) {
+  // Empty reductions are not supported. We expect canonicalization patterns to
+  // handle them.
+  if (op.getDimensions().empty()) return true;
+
+  // We require all reduce shapes to be the same, up to the element types, so
+  // we can just use the first operand and the first result as a representative.
+  if (auto inputTy =
+          dyn_cast<RankedTensorType>(op.getInputs().getType().front())) {
+    return llvm::is_contained(inputTy.getShape(), 0);
+  }
+
+  return false;
+}
+
 /// Returns a permutation AffineMap that puts all reduction dimensions to the
 /// last. The order of parallel loops and reduction loops are all sorted. E.g.,
 /// if `rank` is 4 and `reductionDims` is {1, 3}, then
@@ -85,6 +103,11 @@
   LogicalResult matchAndRewrite(
       mlir::stablehlo::ReduceOp op, OpAdaptor adaptor,
       ConversionPatternRewriter &rewriter) const override {
+    if (isUnsupported(op)) {
+      return rewriter.notifyMatchFailure(op,
+                                         "unsupported reduce (noop or empty)");
+    }
+
     Location loc = op.getLoc();
 
     int numOperands = static_cast<int>(adaptor.getInputs().size());
@@ -154,11 +177,11 @@
     rewriter.inlineRegionBefore(op.getBody(), region, region.end());
     TypeConverter::SignatureConversion signatureConverter(numOperands * 2);
 
-    // The mhlo ReduceOp requires that the seed be used as a LHS operand inside
-    // the region, and the seed is encoded in linalg in the intial out value, so
-    // modify the signature of the block and the value mappings, so the output
-    // args will correlate with the original LHS and the inputs correlate with
-    // the original RHS.
+    // The stablehlo ReduceOp requires that the seed be used as a LHS operand
+    // inside the region, and the seed is encoded in linalg in the initial out
+    // value, so modify the signature of the block and the value mappings, so
+    // the output args will correlate with the original LHS and the inputs
+    // correlate with the original RHS.
     for (auto [idx, val] : llvm::enumerate(op.getInputs())) {
       signatureConverter.addInputs(
           /*origInputNo=*/idx + numOperands,
@@ -188,6 +211,11 @@
   LogicalResult matchAndRewrite(
       mlir::stablehlo::ReduceOp op, OpAdaptor adaptor,
       ConversionPatternRewriter &rewriter) const override {
+    if (isUnsupported(op)) {
+      return rewriter.notifyMatchFailure(op,
+                                         "unsupported reduce (noop or empty)");
+    }
+
     auto reductionDims =
         llvm::to_vector(op.getDimensions().getValues<int64_t>());
     // stablehlo.reduce doesn't specify the order of the reduction dimensions.
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/test/stablehlo_to_linalg_reduce.mlir b/compiler/src/iree/compiler/InputConversion/StableHLO/test/stablehlo_to_linalg_reduce.mlir
index 28bd8bb..83d7ca6 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/test/stablehlo_to_linalg_reduce.mlir
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/test/stablehlo_to_linalg_reduce.mlir
@@ -358,6 +358,42 @@
 
 // -----
 
+// Make sure we do not crash on unsupported reductions.
+
+// CHECK-LABEL: func.func @reduce_noop
+// CHECK:         stablehlo.reduce
+// CHECK-PRIMITIVE-LABEL: func.func @reduce_noop
+// CHECK-PRIMITIVE:         stablehlo.reduce
+func.func @reduce_noop(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
+  %0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
+  %1 = stablehlo.reduce(%arg0 init: %0) across dimensions = [] : (tensor<4x8xf32>, tensor<f32>) -> tensor<4x8xf32>
+    reducer(%arg1: tensor<f32>, %arg2: tensor<f32>) {
+    %4 = stablehlo.add %arg1, %arg2 : tensor<f32>
+    stablehlo.return %4 : tensor<f32>
+  }
+  func.return %1 : tensor<4x8xf32>
+}
+
+// CHECK-LABEL: func.func @reduce_zero_ext
+// CHECK:         stablehlo.reduce
+// CHECK-PRIMITIVE-LABEL: func.func @reduce_zero_ext
+// CHECK-PRIMITIVE:         stablehlo.reduce
+func.func @reduce_zero_ext(%arg0: tensor<0xi1>) -> tensor<i32> {
+  %0 = stablehlo.constant dense<false> : tensor<i1>
+  %1 = stablehlo.constant dense<false> : tensor<0xi1>
+  %2 = stablehlo.compare  NE, %arg0, %1, UNSIGNED : (tensor<0xi1>, tensor<0xi1>) -> tensor<0xi1>
+  %3 = stablehlo.convert %2 : (tensor<0xi1>) -> tensor<0xi32>
+  %4 = stablehlo.constant dense<0> : tensor<i32>
+  %5 = stablehlo.reduce(%3 init: %4) across dimensions = [0] : (tensor<0xi32>, tensor<i32>) -> tensor<i32>
+    reducer(%arg1: tensor<i32>, %arg2: tensor<i32>)  {
+    %6 = stablehlo.add %arg1, %arg2 : tensor<i32>
+    stablehlo.return %6 : tensor<i32>
+  }
+  return %5 : tensor<i32>
+}
+
+// -----
+
 // CHECK-LABEL: func @reduce_window_min_nhwc
 // CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]*]]
 // CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]*]]