[StableHLO] Make reduce lowering more robust (#14046)
Check whether a reduce op is supported before attempting to lower it.
This way the patterns can be handed any reduce, including ones that
would normally be folded away by canonicalization patterns.
Issue: https://github.com/openxla/iree/issues/14042
Issue: https://github.com/openxla/iree/issues/12678
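
Concretely, the patterns now bail out with a match failure on two
degenerate forms: a reduce over an empty dimension list, and a reduce
whose input has a zero-extent dimension. A minimal sketch of the second
case, mirroring the reduce_zero_ext test added below (%arg and %init
stand in for suitably typed values):

  // Reduces a zero-element tensor; there is no meaningful reduction
  // loop to emit, so the pattern now rejects it instead of crashing.
  %0 = stablehlo.reduce(%arg init: %init) across dimensions = [0]
         : (tensor<0xi32>, tensor<i32>) -> tensor<i32>
   reducer(%lhs: tensor<i32>, %rhs: tensor<i32>) {
    %sum = stablehlo.add %lhs, %rhs : tensor<i32>
    stablehlo.return %sum : tensor<i32>
   }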
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgReduce.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgReduce.cpp
index 0d0c315..cc9d782 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgReduce.cpp
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgReduce.cpp
@@ -10,6 +10,7 @@
#include "iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.h"
#include "iree/compiler/InputConversion/StableHLO/Rewriters.h"
+#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -17,6 +18,23 @@
namespace mlir::iree_compiler::stablehlo {
namespace {
+/// Returns true when reduction `op` is not supported and should be filtered
+/// out.
+static bool isUnsupported(mlir::stablehlo::ReduceOp op) {
+ // Empty reductions are not supported. We expect canonicalization patterns to
+ // handle them.
+ if (op.getDimensions().empty()) return true;
+
+  // We require all reduce input shapes to be the same, up to the element
+  // types, so we can just use the first operand as a representative.
+ if (auto inputTy =
+ dyn_cast<RankedTensorType>(op.getInputs().getType().front())) {
+ return llvm::is_contained(inputTy.getShape(), 0);
+ }
+
+ return false;
+}
+
/// Returns a permutation AffineMap that puts all reduction dimensions to the
/// last. The order of parallel loops and reduction loops are all sorted. E.g.,
/// if `rank` is 4 and `reductionDims` is {1, 3}, then
@@ -85,6 +103,11 @@
LogicalResult matchAndRewrite(
mlir::stablehlo::ReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ if (isUnsupported(op)) {
+ return rewriter.notifyMatchFailure(op,
+ "unsupported reduce (noop or empty)");
+ }
+
Location loc = op.getLoc();
int numOperands = static_cast<int>(adaptor.getInputs().size());
@@ -154,11 +177,11 @@
rewriter.inlineRegionBefore(op.getBody(), region, region.end());
TypeConverter::SignatureConversion signatureConverter(numOperands * 2);
- // The mhlo ReduceOp requires that the seed be used as a LHS operand inside
- // the region, and the seed is encoded in linalg in the intial out value, so
- // modify the signature of the block and the value mappings, so the output
- // args will correlate with the original LHS and the inputs correlate with
- // the original RHS.
+ // The stablehlo ReduceOp requires that the seed be used as a LHS operand
+ // inside the region, and the seed is encoded in linalg in the initial out
+ // value, so modify the signature of the block and the value mappings, so
+ // the output args will correlate with the original LHS and the inputs
+ // correlate with the original RHS.
for (auto [idx, val] : llvm::enumerate(op.getInputs())) {
signatureConverter.addInputs(
/*origInputNo=*/idx + numOperands,
@@ -188,6 +211,11 @@
LogicalResult matchAndRewrite(
mlir::stablehlo::ReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ if (isUnsupported(op)) {
+ return rewriter.notifyMatchFailure(op,
+ "unsupported reduce (noop or empty)");
+ }
+
auto reductionDims =
llvm::to_vector(op.getDimensions().getValues<int64_t>());
// stablehlo.reduce doesn't specify the order of the reduction dimensions.
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/test/stablehlo_to_linalg_reduce.mlir b/compiler/src/iree/compiler/InputConversion/StableHLO/test/stablehlo_to_linalg_reduce.mlir
index 28bd8bb..83d7ca6 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/test/stablehlo_to_linalg_reduce.mlir
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/test/stablehlo_to_linalg_reduce.mlir
@@ -358,6 +358,42 @@
// -----
+// Make sure we do not crash on unsupported reductions.
+
+// CHECK-LABEL: func.func @reduce_noop
+// CHECK: stablehlo.reduce
+// CHECK-PRIMITIVE-LABEL: func.func @reduce_noop
+// CHECK-PRIMITIVE: stablehlo.reduce
+func.func @reduce_noop(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
+ %0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
+ %1 = stablehlo.reduce(%arg0 init: %0) across dimensions = [] : (tensor<4x8xf32>, tensor<f32>) -> tensor<4x8xf32>
+ reducer(%arg1: tensor<f32>, %arg2: tensor<f32>) {
+ %4 = stablehlo.add %arg1, %arg2 : tensor<f32>
+ stablehlo.return %4 : tensor<f32>
+ }
+ func.return %1 : tensor<4x8xf32>
+}
+
+// CHECK-LABEL: func.func @reduce_zero_ext
+// CHECK: stablehlo.reduce
+// CHECK-PRIMITIVE-LABEL: func.func @reduce_zero_ext
+// CHECK-PRIMITIVE: stablehlo.reduce
+func.func @reduce_zero_ext(%arg0: tensor<0xi1>) -> tensor<i32> {
+ %0 = stablehlo.constant dense<false> : tensor<i1>
+ %1 = stablehlo.constant dense<false> : tensor<0xi1>
+ %2 = stablehlo.compare NE, %arg0, %1, UNSIGNED : (tensor<0xi1>, tensor<0xi1>) -> tensor<0xi1>
+ %3 = stablehlo.convert %2 : (tensor<0xi1>) -> tensor<0xi32>
+ %4 = stablehlo.constant dense<0> : tensor<i32>
+ %5 = stablehlo.reduce(%3 init: %4) across dimensions = [0] : (tensor<0xi32>, tensor<i32>) -> tensor<i32>
+ reducer(%arg1: tensor<i32>, %arg2: tensor<i32>) {
+ %6 = stablehlo.add %arg1, %arg2 : tensor<i32>
+ stablehlo.return %6 : tensor<i32>
+ }
+ return %5 : tensor<i32>
+}
+
+// -----
+
// CHECK-LABEL: func @reduce_window_min_nhwc
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]