[StableHLO] Port reduce_window to linalg lowering (#13128)

This ports the reduce_window op lowering from mlir-hlo. For the details,
see the initial import: https://github.com/openxla/iree/pull/12957.

The code is cleaned up to match the existing StableHLO -> linalg
conversion patterns. The pattern population logic is updated so that we
do not depend on the order in which patterns are added -- we use pattern
benefit to specify the order instead.

This also moves the previously ported reduction patterns to the same
file as reduce_window.

Issue: https://github.com/openxla/iree/issues/12678
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/BUILD.bazel b/compiler/src/iree/compiler/InputConversion/StableHLO/BUILD.bazel
index 4837d58..aa45d90 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/BUILD.bazel
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/BUILD.bazel
@@ -52,6 +52,7 @@
         "StableHLOToLinalg.cpp",
         "StableHLOToLinalgDotProd.cpp",
         "StableHLOToLinalgPointwise.cpp",
+        "StableHLOToLinalgReduce.cpp",
         "TypeConversion.cpp",
     ],
     hdrs = [
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/CMakeLists.txt b/compiler/src/iree/compiler/InputConversion/StableHLO/CMakeLists.txt
index a3ea3b2..0829d8d 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/CMakeLists.txt
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/CMakeLists.txt
@@ -49,6 +49,7 @@
     "StableHLOToLinalg.cpp"
     "StableHLOToLinalgDotProd.cpp"
     "StableHLOToLinalgPointwise.cpp"
+    "StableHLOToLinalgReduce.cpp"
     "TypeConversion.cpp"
   DEPS
     ::PassHeaders
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.cpp
index 873122c..8ba8e6f 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.cpp
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.cpp
@@ -159,4 +159,12 @@
          parentOp->getContext()->getLoadedDialect<linalg::LinalgDialect>();
 }
 
+SmallVector<int64_t, 4> extract1DVector(DenseIntElementsAttr elements) {
+  SmallVector<int64_t, 4> ret;
+  for (const APInt& element : elements) {
+    ret.push_back(element.getLimitedValue());
+  }
+  return ret;
+}
+
 }  // namespace mlir::iree_compiler::stablehlo
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.h b/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.h
index f33b180..33b0a6a 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.h
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.h
@@ -100,6 +100,14 @@
 /// Returns true if parent op is linalg.
 bool isInBodyOfLinalgOps(Operation* op);
 
+/// Extracts integer values from the attribute |elements|.
+SmallVector<int64_t, 4> extract1DVector(DenseIntElementsAttr elements);
+
+/// Returns true if the given |attr| is a splat of the given |value|.
+inline bool isSplatValue(DenseIntElementsAttr attr, uint64_t value) {
+  return attr.isSplat() && attr.getSplatValue<uint64_t>() == value;
+}
+
 }  // namespace mlir::iree_compiler::stablehlo
 
 #endif  // IREE_COMPILER_INPUTCONVERSION_STABLEHLO_LEGALIZE_TO_LINALG_UTILS_H_
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Rewriters.h b/compiler/src/iree/compiler/InputConversion/StableHLO/Rewriters.h
index e13f961..b29aca2 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Rewriters.h
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Rewriters.h
@@ -33,6 +33,12 @@
     MLIRContext* context, TypeConverter& typeConverter,
     RewritePatternSet* patterns);
 
+/// Populates the patterns that convert from reduction StableHLO ops to Linalg
+/// on tensors.
+void populateStableHloReductionToLinalgConversionPatterns(
+    MLIRContext* context, TypeConverter& typeConverter,
+    RewritePatternSet* patterns, bool enablePrimitiveOps);
+
 /// Populates the patterns that convert scalar StableHLO ops to Arith ops.
 void populateScalarHloToArithConversionPatterns(
     MLIRContext* context, TypeConverter& typeConverter,
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalg.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalg.cpp
index 13e7529..3b76980 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalg.cpp
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalg.cpp
@@ -66,38 +66,6 @@
   return getResultValue(op).getType().cast<ShapedType>();
 }
 
-SmallVector<int64_t, 4> extract1DVector(DenseIntElementsAttr elements) {
-  SmallVector<int64_t, 4> ret;
-  for (const APInt& element : elements) {
-    ret.push_back(element.getLimitedValue());
-  }
-  return ret;
-}
-
-/// Returns a permutation AffineMap that puts all reduction dimensions to the
-/// last. The order of parallel loops and reduction loops are all sorted. E.g.,
-/// if `rank` is 4 and `reductionDims` is {1, 3}, then
-/// "(d0, d1, d2, d3) -> (d0, d2, d1, d3)" is used. The inverse permutation of
-/// the AffineMap is returned.
-AffineMap getTransposeMapForReduction(MLIRContext* context, int rank,
-                                      ArrayRef<int64_t> reductionDims) {
-  llvm::SmallSetVector<int, 4> s;
-  for (auto dim : reductionDims) s.insert(dim);
-
-  SmallVector<unsigned, 4> permutation;
-  for (int i = 0; i < rank; ++i)
-    if (!s.count(i)) permutation.push_back(i);
-  for (auto dim : reductionDims) permutation.push_back(dim);
-
-  auto map = AffineMap::getPermutationMap(permutation, context);
-  return inversePermutation(map);
-}
-
-/// Returns true if the given `attr` is a splat of the given `value`.
-bool isSplatValue(DenseIntElementsAttr attr, uint64_t value) {
-  return attr.isSplat() && attr.getSplatValue<uint64_t>() == value;
-}
-
 /// Extracts an element from a tensor and optionally converts it to an index
 /// type, based on the tensor's pre-type conversion type.
 Value extractIndexFromTensor(OpBuilder& builder, Location loc, Value tensor,
@@ -1787,251 +1755,6 @@
   }
 };
 
-SmallVector<Value, 8> getReduceOpEmptyTensorDynSizes(
-    OpBuilder& b, Location loc, Value arg, ShapedType resultType,
-    ArrayRef<int64_t> reductionDims) {
-  llvm::SmallSetVector<int, 4> s;
-  for (auto dim : reductionDims) s.insert(dim);
-
-  SmallVector<unsigned, 4> parallelDims;
-  SmallVector<Value, 8> dynShape;
-  int rank = arg.getType().cast<RankedTensorType>().getRank();
-  for (int i = 0, j = 0; i < rank; ++i) {
-    if (s.count(i)) continue;
-    if (!resultType.isDynamicDim(j++)) continue;
-    dynShape.push_back(b.create<tensor::DimOp>(loc, arg, i));
-  }
-
-  return dynShape;
-}
-
-class ReduceRegionReturnOpConversion
-    : public OpConversionPattern<stablehlo::ReturnOp> {
- public:
-  using OpConversionPattern<stablehlo::ReturnOp>::OpConversionPattern;
-  LogicalResult matchAndRewrite(
-      stablehlo::ReturnOp op, OpAdaptor adaptor,
-      ConversionPatternRewriter& rewriter) const final {
-    if (!isInBodyOfLinalgOps(op)) {
-      return failure();
-    }
-    SmallVector<Value, 4> operands(adaptor.getOperands());
-    for (size_t i = 0; i < operands.size(); ++i) {
-      if (operands[i].getType().isa<ShapedType>()) {
-        auto loc = operands[i].getLoc();
-        operands[i] = rewriter.create<tensor::ExtractOp>(loc, operands[i]);
-      }
-    }
-    rewriter.replaceOpWithNewOp<linalg::YieldOp>(op, operands);
-    return success();
-  }
-};
-
-class ReduceOpToGenericConverter
-    : public OpConversionPattern<stablehlo::ReduceOp> {
- public:
-  using OpConversionPattern<stablehlo::ReduceOp>::OpConversionPattern;
-  LogicalResult matchAndRewrite(
-      stablehlo::ReduceOp op, OpAdaptor adaptor,
-      ConversionPatternRewriter& rewriter) const final {
-    Location loc = op.getLoc();
-
-    int numOperands = static_cast<int>(adaptor.getInputs().size());
-
-    if (llvm::any_of(adaptor.getInputs(), [](Value v) {
-          return !v.getType().isa<RankedTensorType>();
-        })) {
-      return rewriter.notifyMatchFailure(op, "expects known-rank args");
-    }
-    auto srcRank =
-        adaptor.getInputs()[0].getType().cast<ShapedType>().getRank();
-
-    SmallVector<int64_t, 4> reductionDims = extract1DVector(op.getDimensions());
-
-    SmallVector<Type> resultTypes;
-    if (failed(typeConverter->convertTypes(op.getResultTypes(), resultTypes)))
-      return failure();
-
-    SmallVector<Value> outputs;
-    SmallVector<AffineMap, 3> indexingMaps;
-    for (auto [operand, initValue, resultType] :
-         llvm::zip(adaptor.getInputs(), adaptor.getInitValues(), resultTypes)) {
-      // Check if init_value is constant. If so, inline the value into the
-      // region.
-      initValue = rewriter.createOrFold<tensor::ExtractOp>(loc, initValue);
-
-      SmallVector<Value, 8> dynShape = getReduceOpEmptyTensorDynSizes(
-          rewriter, loc, operand, resultType, reductionDims);
-      auto emptyTensor = getEmptyTensor(rewriter, loc, resultType, dynShape);
-      Value filledTensor =
-          rewriter.create<linalg::FillOp>(loc, initValue, emptyTensor).result();
-      outputs.push_back(filledTensor);
-    }
-
-    // Prepare indexing maps for linalg generic op. The elements are for src
-    // and dst. Transpose `src` to make the reduction loops be the innermost,
-    // because it's easier to fully utilize processors.
-    indexingMaps.append(
-        numOperands, getTransposeMapForReduction(rewriter.getContext(),
-                                                 (int)srcRank, reductionDims));
-
-    // The indexing map of `dst` should drop the reduction loops. Since the
-    // reduction loops now are all in the innermost, drops
-    // `reduction_dims.size()` dimensions. We don't need an inverse
-    // permutation here because they are the same.
-    SmallVector<AffineExpr, 4> exprs;
-    for (int i = 0, e = srcRank - reductionDims.size(); i < e; ++i)
-      exprs.push_back(rewriter.getAffineDimExpr(i));
-    indexingMaps.append(numOperands,
-                        AffineMap::get(srcRank, /*symbolCount=*/0, exprs,
-                                       rewriter.getContext()));
-
-    auto linalgOp = rewriter.create<linalg::GenericOp>(
-        loc, /*resultTensorTypes=*/resultTypes, adaptor.getInputs(),
-        /*outputBuffers=*/ValueRange{outputs}, indexingMaps,
-        getParallelAndReductionIterators(srcRank, reductionDims.size()),
-        /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(op));
-
-    // Convert the signature of the body. The reduce op region apply function
-    // has a signature (lhs, rhs) -> output, all of the same tensor type t.
-    // This is converted to a function with the same signature but with
-    // element types. E.g., "(tensor<f32>, tensor<f32>) -> tensor<f32>" will
-    // be converted to "(f32, f32, f32)".
-    Region& region = linalgOp.getRegion();
-    rewriter.inlineRegionBefore(op.getBody(), region, region.end());
-    TypeConverter::SignatureConversion signatureConverter(numOperands * 2);
-
-    // The mhlo ReduceOp requires that the seed be used as a LHS operand inside
-    // the region, and the seed is encoded in linalg in the intial out value, so
-    // modify the signature of the block and the value mappings, so the output
-    // args will correlate with the original LHS and the inputs correlate with
-    // the original RHS.
-    for (const auto& [idx, val] : llvm::enumerate(op.getInputs())) {
-      signatureConverter.addInputs(
-          /*origInputNo=*/idx + numOperands,
-          // type for the new operand number 'idx'.
-          typeConverter->convertType(
-              val.getType().cast<ShapedType>().getElementType()));
-    }
-    for (const auto& [idx, val] : llvm::enumerate(op.getInitValues())) {
-      signatureConverter.addInputs(
-          /*origInputNo=*/idx,
-          // type for the new operand number 'idx' + 'numOperands'.
-          typeConverter->convertType(
-              val.getType().cast<ShapedType>().getElementType()));
-    }
-
-    rewriter.applySignatureConversion(&region, signatureConverter,
-                                      getTypeConverter());
-    rewriter.replaceOp(op, linalgOp.getResults());
-    return success();
-  }
-};
-
-struct ReduceOpToReduceConverter
-    : public OpConversionPattern<stablehlo::ReduceOp> {
-  using OpConversionPattern<stablehlo::ReduceOp>::OpConversionPattern;
-
-  LogicalResult matchAndRewrite(
-      stablehlo::ReduceOp op, OpAdaptor adaptor,
-      ConversionPatternRewriter& rewriter) const final {
-    auto reductionDims =
-        llvm::to_vector(op.getDimensions().getValues<int64_t>());
-    // mhlo.reduce doesn't specify the order of the reduction dimensions.
-    llvm::sort(reductionDims);
-
-    auto toRankedTensor = [](Value v) -> RankedTensorType {
-      return v.getType().dyn_cast<RankedTensorType>();
-    };
-
-    SmallVector<Value> outputs;
-    SmallVector<RankedTensorType> operandTypes, initTypes;
-    SmallVector<Type> resultTypes;
-    if (failed(typeConverter->convertTypes(op.getResultTypes(), resultTypes)))
-      return failure();
-
-    Location loc = op.getLoc();
-    for (auto [operand, initValue, resultType] :
-         llvm::zip(adaptor.getInputs(), adaptor.getInitValues(), resultTypes)) {
-      auto initType = toRankedTensor(initValue);
-      if (!initType)
-        return rewriter.notifyMatchFailure(op,
-                                           "expects known-rank init values");
-      initTypes.push_back(initType);
-      auto operandType = toRankedTensor(operand);
-      if (!operandType)
-        return rewriter.notifyMatchFailure(op, "expects known-rank operands");
-      operandTypes.push_back(operandType);
-      initValue = rewriter.createOrFold<tensor::ExtractOp>(loc, initValue);
-      auto tensorResultType = resultType.cast<RankedTensorType>();
-      // For linalg.reduce, the result type's dimensions must match the input's
-      // dimensions, whereas MHLO allows replacing static dimensions with
-      // dynamic ones.
-      SmallVector<int64_t> resultShape;
-      SmallVector<Value, 8> dynShape;
-      for (auto [index, dim] :
-           llvm::enumerate(operand.getType().cast<ShapedType>().getShape())) {
-        if (!llvm::is_contained(reductionDims, index)) {
-          resultShape.push_back(dim);
-          if (ShapedType::isDynamic(dim)) {
-            dynShape.push_back(
-                rewriter.create<tensor::DimOp>(loc, operand, index));
-          }
-        }
-      }
-
-      Value emptyTensor = rewriter.create<tensor::EmptyOp>(
-          loc, resultShape, tensorResultType.getElementType(), dynShape);
-      Value filledTensor =
-          rewriter.create<linalg::FillOp>(loc, initValue, emptyTensor).result();
-      outputs.push_back(filledTensor);
-    }
-
-    auto linalgOp = rewriter.create<linalg::ReduceOp>(
-        loc, adaptor.getInputs(), outputs, reductionDims,
-        /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(op));
-
-    Region& region = linalgOp.getRegion();
-    rewriter.inlineRegionBefore(op.getBody(), region, region.end());
-
-    // Convert the signature of the body. The reduce op 'computation' region
-    // apply function has a signature with tensor types, this is converted to a
-    // function with element types. E.g. the signature "(tensor<f32>,
-    // tensor<f32>) -> tensor<f32>" will be converted to "(f32, f32) -> f32".
-    // Also, we need to swap the operands of the function. The mhlo.reduce op
-    // expects the init values to be the first parameters of the apply function,
-    // while the linalg.reduction op expects the init values as the last
-    // parameters of the 'combiner' region apply function.
-    TypeConverter::SignatureConversion signatureConverter(
-        linalgOp.getNumDpsInputs() * 2);
-    assert(linalgOp.getNumDpsInputs() == linalgOp.getNumDpsInits());
-    for (const auto& [idx, val] : llvm::enumerate(operandTypes)) {
-      signatureConverter.addInputs(
-          /*origInputNo=*/idx + linalgOp.getNumDpsInputs(),
-          // type for new operand number 'idx'.
-          typeConverter->convertType(val.getElementType()));
-    }
-    for (const auto& [idx, val] : llvm::enumerate(initTypes)) {
-      signatureConverter.addInputs(
-          /*origInputNo=*/idx,
-          // type for new operand number 'idx' + linalgOp.getNumInputs()
-          typeConverter->convertType(val.getElementType()));
-    }
-    rewriter.applySignatureConversion(&region, signatureConverter,
-                                      getTypeConverter());
-
-    // Cast the result to the correct type.
-    SmallVector<Value> results;
-    for (auto [result, resultType] :
-         llvm::zip(linalgOp.getResults(), resultTypes)) {
-      results.push_back(
-          rewriter.createOrFold<tensor::CastOp>(loc, resultType, result));
-    }
-    rewriter.replaceOp(op, results);
-    return success();
-  }
-};
-
 /// Converts xla-hlo.select_and_scatter op to a sequence of linalg.generics ops.
 /// The current version computes the scattered index and populates the correct
 /// value for each tile. It does not currently handle overlapping tiles.
@@ -2727,8 +2450,7 @@
       PadOpConversion,
       PadOpNegativePaddingConversion,
       TorchIndexSelectOpConversion,
-      SelectAndScatterNoOverlapConverter,
-      ReduceRegionReturnOpConversion
+      SelectAndScatterNoOverlapConverter
       >(typeConverter, context);
 
   detail::populatePointwiseStableHloToLinalgConversionPatterns(
@@ -2742,7 +2464,6 @@
       IotaToMapConverter<stablehlo::IotaOp>,
       IotaToMapConverter<stablehlo::DynamicIotaOp>,
       MapOpToMapConverter,
-      ReduceOpToReduceConverter,
       TransposeOpToTransposeConverter
     >(typeConverter, context);
   } else {
@@ -2753,17 +2474,18 @@
       HloBroadcastInDimConverter,
       HloDynamicBroadcastInDimConverter,
       MapOpToGenericConverter,
-      ReduceOpToGenericConverter,
       TransposeConverter<stablehlo::TransposeOp>
     >(typeConverter, context);
   }
 
   // clang-format on
 
-  // TODO(#12678): Handle the convolution and reduce_window ops.
+  // TODO(#12678): Handle the convolution op.
 
   detail::populateStableHloDotProdToLinalgConversionPatterns(
       context, typeConverter, patterns);
+  detail::populateStableHloReductionToLinalgConversionPatterns(
+      context, typeConverter, patterns, enablePrimitiveOps);
   detail::populateScalarHloToArithConversionPatterns(
       context, typeConverter, patterns, isInBodyOfLinalgOps);
   linalg::populateEraseUnusedOperandsAndResultsPatterns(*patterns);
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgReduce.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgReduce.cpp
new file mode 100644
index 0000000..5f5b54a
--- /dev/null
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgReduce.cpp
@@ -0,0 +1,698 @@
+// Copyright 2019 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+// Implements logic for lowering StableHLO reduction ops to Linalg dialect.
+// These patterns are separated out to their own file to save on the compilation
+// times.
+
+#include "iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.h"
+#include "iree/compiler/InputConversion/StableHLO/Rewriters.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "stablehlo/dialect/StablehloOps.h"
+
+namespace mlir::iree_compiler::stablehlo {
+namespace {
+namespace stablehlo = mlir::stablehlo;
+
+/// Returns a permutation AffineMap that puts all reduction dimensions to the
+/// last. The order of parallel loops and reduction loops are all sorted. E.g.,
+/// if `rank` is 4 and `reductionDims` is {1, 3}, then
+/// "(d0, d1, d2, d3) -> (d0, d2, d1, d3)" is used. The inverse permutation of
+/// the AffineMap is returned.
+AffineMap getTransposeMapForReduction(MLIRContext *context, int rank,
+                                      ArrayRef<int64_t> reductionDims) {
+  llvm::SmallSetVector<int, 4> s(reductionDims.begin(), reductionDims.end());
+
+  SmallVector<unsigned, 4> permutation;
+  for (int i = 0; i < rank; ++i) {
+    if (!s.contains(i)) {
+      permutation.push_back(i);
+    }
+  }
+
+  llvm::append_range(permutation, reductionDims);
+  auto map = AffineMap::getPermutationMap(permutation, context);
+  return inversePermutation(map);
+}
+
+SmallVector<Value, 8> getReduceOpEmptyTensorDynSizes(
+    OpBuilder &b, Location loc, Value arg, ShapedType resultType,
+    ArrayRef<int64_t> reductionDims) {
+  llvm::SmallSetVector<int, 4> s(reductionDims.begin(), reductionDims.end());
+
+  SmallVector<unsigned, 4> parallelDims;
+  SmallVector<Value, 8> dynShape;
+  int rank = cast<RankedTensorType>(arg.getType()).getRank();
+  for (int i = 0, j = 0; i < rank; ++i) {
+    if (s.contains(i)) continue;
+    if (!resultType.isDynamicDim(j++)) continue;
+    dynShape.push_back(b.create<tensor::DimOp>(loc, arg, i));
+  }
+
+  return dynShape;
+}
+
+struct ReduceRegionReturnOpConversion final
+    : OpConversionPattern<stablehlo::ReturnOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      stablehlo::ReturnOp op, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    if (!isInBodyOfLinalgOps(op)) {
+      return failure();
+    }
+
+    SmallVector<Value, 4> operands(adaptor.getOperands());
+    for (Value &operand : operands) {
+      if (isa<ShapedType>(operand.getType())) {
+        Location loc = operand.getLoc();
+        operand = rewriter.create<tensor::ExtractOp>(loc, operand);
+      }
+    }
+    rewriter.replaceOpWithNewOp<linalg::YieldOp>(op, operands);
+    return success();
+  }
+};
+
+struct ReduceOpToGenericConverter final
+    : OpConversionPattern<stablehlo::ReduceOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      stablehlo::ReduceOp op, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    int numOperands = static_cast<int>(adaptor.getInputs().size());
+
+    if (llvm::any_of(adaptor.getInputs(), [](Value v) {
+          return !isa<RankedTensorType>(v.getType());
+        })) {
+      return rewriter.notifyMatchFailure(op, "expects known-rank args");
+    }
+    auto srcRank = cast<ShapedType>(adaptor.getInputs()[0].getType()).getRank();
+
+    SmallVector<int64_t, 4> reductionDims = extract1DVector(op.getDimensions());
+
+    SmallVector<Type> resultTypes;
+    if (failed(typeConverter->convertTypes(op.getResultTypes(), resultTypes)))
+      return failure();
+
+    SmallVector<Value> outputs;
+    SmallVector<AffineMap, 3> indexingMaps;
+    for (auto [operand, initValue, resultType] : llvm::zip_equal(
+             adaptor.getInputs(), adaptor.getInitValues(), resultTypes)) {
+      // Check if init_value is constant. If so, inline the value into the
+      // region.
+      initValue = rewriter.createOrFold<tensor::ExtractOp>(loc, initValue);
+
+      SmallVector<Value, 8> dynShape = getReduceOpEmptyTensorDynSizes(
+          rewriter, loc, operand, resultType, reductionDims);
+      auto emptyTensor = getEmptyTensor(rewriter, loc, resultType, dynShape);
+      Value filledTensor =
+          rewriter.create<linalg::FillOp>(loc, initValue, emptyTensor).result();
+      outputs.push_back(filledTensor);
+    }
+
+    // Prepare indexing maps for linalg generic op. The elements are for src
+    // and dst. Transpose `src` to make the reduction loops be the innermost,
+    // because it's easier to fully utilize processors.
+    indexingMaps.append(
+        numOperands,
+        getTransposeMapForReduction(rewriter.getContext(),
+                                    static_cast<int>(srcRank), reductionDims));
+
+    // The indexing map of `dst` should drop the reduction loops. Since the
+    // reduction loops now are all in the innermost, drops
+    // `reduction_dims.size()` dimensions. We don't need an inverse
+    // permutation here because they are the same.
+    SmallVector<AffineExpr, 4> exprs;
+    for (int i = 0, e = srcRank - reductionDims.size(); i < e; ++i) {
+      exprs.push_back(rewriter.getAffineDimExpr(i));
+    }
+    indexingMaps.append(numOperands,
+                        AffineMap::get(srcRank, /*symbolCount=*/0, exprs,
+                                       rewriter.getContext()));
+
+    auto linalgOp = rewriter.create<linalg::GenericOp>(
+        loc, /*resultTensorTypes=*/resultTypes, adaptor.getInputs(),
+        /*outputBuffers=*/ValueRange{outputs}, indexingMaps,
+        getParallelAndReductionIterators(srcRank, reductionDims.size()),
+        /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(op));
+
+    // Convert the signature of the body. The reduce op region apply function
+    // has a signature (lhs, rhs) -> output, all of the same tensor type t.
+    // This is converted to a function with the same signature but with
+    // element types. E.g., "(tensor<f32>, tensor<f32>) -> tensor<f32>" will
+    // be converted to "(f32, f32, f32)".
+    Region &region = linalgOp.getRegion();
+    rewriter.inlineRegionBefore(op.getBody(), region, region.end());
+    TypeConverter::SignatureConversion signatureConverter(numOperands * 2);
+
+    // The stablehlo ReduceOp requires that the seed be used as an LHS operand
+    // inside the region, and the seed is encoded in linalg in the initial out
+    // modify the signature of the block and the value mappings, so the output
+    // args will correlate with the original LHS and the inputs correlate with
+    // the original RHS.
+    for (auto [idx, val] : llvm::enumerate(op.getInputs())) {
+      signatureConverter.addInputs(
+          /*origInputNo=*/idx + numOperands,
+          // type for the new operand number 'idx'.
+          typeConverter->convertType(
+              cast<ShapedType>(val.getType()).getElementType()));
+    }
+    for (auto [idx, val] : llvm::enumerate(op.getInitValues())) {
+      signatureConverter.addInputs(
+          /*origInputNo=*/idx,
+          // type for the new operand number 'idx' + 'numOperands'.
+          typeConverter->convertType(
+              cast<ShapedType>(val.getType()).getElementType()));
+    }
+
+    rewriter.applySignatureConversion(&region, signatureConverter,
+                                      getTypeConverter());
+    rewriter.replaceOp(op, linalgOp.getResults());
+    return success();
+  }
+};
+
+struct ReduceOpToReduceConverter final
+    : OpConversionPattern<stablehlo::ReduceOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      stablehlo::ReduceOp op, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    auto reductionDims =
+        llvm::to_vector(op.getDimensions().getValues<int64_t>());
+    // stablehlo.reduce doesn't specify the order of the reduction dimensions.
+    llvm::sort(reductionDims);
+
+    auto toRankedTensor = [](Value v) -> RankedTensorType {
+      return dyn_cast<RankedTensorType>(v.getType());
+    };
+
+    SmallVector<Value> outputs;
+    SmallVector<RankedTensorType> operandTypes, initTypes;
+    SmallVector<Type> resultTypes;
+    if (failed(typeConverter->convertTypes(op.getResultTypes(), resultTypes)))
+      return failure();
+
+    Location loc = op.getLoc();
+    for (auto [operand, initValue, resultType] : llvm::zip_equal(
+             adaptor.getInputs(), adaptor.getInitValues(), resultTypes)) {
+      auto initType = toRankedTensor(initValue);
+      if (!initType)
+        return rewriter.notifyMatchFailure(op,
+                                           "expects known-rank init values");
+      initTypes.push_back(initType);
+      auto operandType = toRankedTensor(operand);
+      if (!operandType)
+        return rewriter.notifyMatchFailure(op, "expects known-rank operands");
+      operandTypes.push_back(operandType);
+      initValue = rewriter.createOrFold<tensor::ExtractOp>(loc, initValue);
+      auto tensorResultType = cast<RankedTensorType>(resultType);
+      // For linalg.reduce, the result type's dimensions must match the input's
+      // dimensions, whereas StableHLO allows replacing static dimensions with
+      // dynamic ones.
+      SmallVector<int64_t> resultShape;
+      SmallVector<Value, 8> dynShape;
+      for (auto [index, dim] :
+           llvm::enumerate(cast<ShapedType>(operand.getType()).getShape())) {
+        if (!llvm::is_contained(reductionDims, index)) {
+          resultShape.push_back(dim);
+          if (ShapedType::isDynamic(dim)) {
+            dynShape.push_back(
+                rewriter.create<tensor::DimOp>(loc, operand, index));
+          }
+        }
+      }
+
+      Value emptyTensor = rewriter.create<tensor::EmptyOp>(
+          loc, resultShape, tensorResultType.getElementType(), dynShape);
+      Value filledTensor =
+          rewriter.create<linalg::FillOp>(loc, initValue, emptyTensor).result();
+      outputs.push_back(filledTensor);
+    }
+
+    auto linalgOp = rewriter.create<linalg::ReduceOp>(
+        loc, adaptor.getInputs(), outputs, reductionDims,
+        /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(op));
+
+    Region &region = linalgOp.getRegion();
+    rewriter.inlineRegionBefore(op.getBody(), region, region.end());
+
+    // Convert the signature of the body. The reduce op 'computation' region
+    // apply function has a signature with tensor types, this is converted to a
+    // function with element types. E.g. the signature "(tensor<f32>,
+    // tensor<f32>) -> tensor<f32>" will be converted to "(f32, f32) -> f32".
+    // Also, we need to swap the operands of the function. The stablehlo.reduce
+    // expects the init values to be the first parameters of the apply function,
+    // while the linalg.reduction op expects the init values as the last
+    // parameters of the 'combiner' region apply function.
+    TypeConverter::SignatureConversion signatureConverter(
+        linalgOp.getNumDpsInputs() * 2);
+    assert(linalgOp.getNumDpsInputs() == linalgOp.getNumDpsInits());
+    for (const auto &[idx, val] : llvm::enumerate(operandTypes)) {
+      signatureConverter.addInputs(
+          /*origInputNo=*/idx + linalgOp.getNumDpsInputs(),
+          // type for new operand number 'idx'.
+          typeConverter->convertType(val.getElementType()));
+    }
+    for (const auto &[idx, val] : llvm::enumerate(initTypes)) {
+      signatureConverter.addInputs(
+          /*origInputNo=*/idx,
+          // type for new operand number 'idx' + linalgOp.getNumInputs()
+          typeConverter->convertType(val.getElementType()));
+    }
+    rewriter.applySignatureConversion(&region, signatureConverter,
+                                      getTypeConverter());
+
+    // Cast the result to the correct type.
+    SmallVector<Value> results;
+    for (auto [result, resultType] :
+         llvm::zip(linalgOp.getResults(), resultTypes)) {
+      results.push_back(
+          rewriter.createOrFold<tensor::CastOp>(loc, resultType, result));
+    }
+    rewriter.replaceOp(op, results);
+    return success();
+  }
+};
+
+// Lowers stablehlo.reduce_window to a linalg.generic reduction. The window is
+// modeled as an extra shape-only tensor input, window strides/dilations are
+// folded into the indexing maps, and explicit padding / base dilation are
+// handled by emitting a stablehlo.pad on the inputs first.
+struct ReduceWindowOpOnTensorsGenericConversion final
+    : OpConversionPattern<stablehlo::ReduceWindowOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      stablehlo::ReduceWindowOp op, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    MLIRContext *ctx = op->getContext();
+    Location loc = op.getLoc();
+    llvm::SmallVector<Value> initValues = adaptor.getInitValues();
+    llvm::SmallVector<Type> resultTypes;
+    if (failed(typeConverter->convertTypes(op.getResultTypes(), resultTypes)))
+      return failure();
+    auto numOperands = initValues.size();
+
+    llvm::SmallVector<int64_t> windowDimensions =
+        extract1DVector(op.getWindowDimensions());
+
+    // Optional attributes; absent strides/dilations default to all-ones.
+    llvm::SmallVector<int64_t> padding;
+    if (op.getPadding()) {
+      padding = extract1DVector(*op.getPadding());
+    }
+
+    llvm::SmallVector<int64_t> baseDilations;
+    if (op.getBaseDilations()) {
+      baseDilations = extract1DVector(*op.getBaseDilations());
+    }
+
+    llvm::SmallVector<int64_t> windowStrides(windowDimensions.size(), 1);
+    if (op.getWindowStrides()) {
+      windowStrides = extract1DVector(*op.getWindowStrides());
+    }
+
+    llvm::SmallVector<int64_t> windowDilations(windowDimensions.size(), 1);
+    if (op.getWindowDilations()) {
+      windowDilations = extract1DVector(*op.getWindowDilations());
+    }
+
+    auto rank = static_cast<int64_t>(windowDimensions.size());
+    SmallVector<AffineExpr, 2> srcExprs;
+    SmallVector<AffineExpr, 2> windowExprs;
+    SmallVector<AffineExpr, 2> dstExprs;
+    SmallVector<int64_t> filteredWindowDims;
+
+    // Build the indexing expressions. A source element is addressed as
+    // 'd_i * stride_i + w_j * dilation_i', where 'w_j' is the reduction
+    // dimension introduced for window dimension 'i'. Unit window dimensions
+    // add no reduction dimension and are filtered out.
+    int windowDim = 0;
+    for (int64_t i = 0; i < rank; i++) {
+      AffineExpr srcExpr = mlir::getAffineDimExpr(i, ctx);
+
+      if (windowStrides[i] != 1) srcExpr = srcExpr * windowStrides[i];
+
+      if (windowDimensions[i] != 1) {
+        filteredWindowDims.push_back(windowDimensions[i]);
+        AffineExpr windowExpr = mlir::getAffineDimExpr(rank + windowDim, ctx);
+        windowExprs.push_back(windowExpr);
+
+        if (windowDilations[i] != 1)
+          windowExpr = windowExpr * windowDilations[i];
+
+        srcExpr = srcExpr + windowExpr;
+        windowDim++;
+      }
+
+      srcExprs.push_back(srcExpr);
+      dstExprs.push_back(mlir::getAffineDimExpr(i, ctx));
+    }
+
+    // Maps for (inputs..., window, outputs...). Rank-0 operands keep the
+    // default empty maps.
+    SmallVector<AffineMap, 4> inferredMaps(3, AffineMap::get(ctx));
+    if (rank > 0) {
+      inferredMaps =
+          AffineMap::inferFromExprList({srcExprs, windowExprs, dstExprs});
+    }
+
+    SmallVector<AffineMap, 4> indexingMaps;
+
+    indexingMaps.append(numOperands, inferredMaps[0]);
+    indexingMaps.append(1, inferredMaps[1]);
+    indexingMaps.append(numOperands, inferredMaps[2]);
+
+    // Setup the initial values: broadcast each scalar init value to the
+    // (static) result shape so it can seed the linalg output tensors.
+    llvm::SmallVector<Value> broadcastValues;
+    for (uint64_t i = 0, s = initValues.size(); i < s; i++) {
+      Value initValue = initValues[i];
+      auto resultTy = resultTypes[i].cast<ShapedType>();
+      if (!resultTy.hasStaticShape()) return failure();
+
+      auto broadcastSizes = rewriter.getI64TensorAttr(resultTy.getShape());
+      broadcastValues.push_back(rewriter.create<stablehlo::BroadcastOp>(
+          loc, resultTy, initValue, broadcastSizes));
+    }
+
+    llvm::SmallVector<Value> inputs = llvm::to_vector(adaptor.getInputs());
+
+    // Pad as necessary. Base dilation is expressed as interior padding with
+    // the init value.
+    if (llvm::any_of(padding, [](int64_t v) { return v != 0; }) ||
+        llvm::any_of(baseDilations, [](int64_t v) { return v != 1; })) {
+      llvm::SmallVector<int64_t> staticLows(rank, 0);
+      llvm::SmallVector<int64_t> staticHighs(rank, 0);
+      // 'padding' is a flattened list of (low, high) pairs per dimension.
+      for (int64_t i = 0; i < static_cast<int64_t>(padding.size()); i += 2) {
+        staticLows[i / 2] = padding[i];
+        staticHighs[i / 2] = padding[i + 1];
+      }
+      // Translate base dilation into interior padding.
+      llvm::SmallVector<int64_t> staticInteriors(rank, 0);
+      for (auto [idx, dilation] : llvm::enumerate(baseDilations)) {
+        staticInteriors[idx] = dilation - 1;
+      }
+
+      auto padAttrType =
+          RankedTensorType::get({rank}, rewriter.getIntegerType(64));
+      auto padLows = DenseIntElementsAttr::get(padAttrType, staticLows);
+      auto padHighs = DenseIntElementsAttr::get(padAttrType, staticHighs);
+      auto padInteriors =
+          DenseIntElementsAttr::get(padAttrType, staticInteriors);
+
+      // 'input' binds to a reference into 'inputs' (llvm::zip yields
+      // references), so the padded value replaces the original in place.
+      for (auto [input, initValue] : llvm::zip(inputs, initValues)) {
+        input = rewriter.create<stablehlo::PadOp>(
+            loc, input, initValue, padLows, padHighs, padInteriors);
+      }
+    }
+
+    // Add the extra input for the reduction dimension. Only its shape is
+    // meaningful; its values are never read by the region below.
+    // NOTE(review): the element type is always f32, even for non-f32
+    // operands -- confirm this is intentional.
+    inputs.push_back(rewriter.create<tensor::EmptyOp>(loc, filteredWindowDims,
+                                                      rewriter.getF32Type()));
+
+    auto linalgOp = rewriter.create<linalg::GenericOp>(
+        loc, /*resultTensors=*/resultTypes,
+        /*inputs=*/inputs,
+        /*outputs=*/broadcastValues, indexingMaps,
+        getParallelAndReductionIterators(rank + filteredWindowDims.size(),
+                                         filteredWindowDims.size()),
+        /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(op));
+
+    // Convert the signature of the body. This includes converting scalar
+    // tensors to their scalar values and inserting an additional block arg for
+    // the window arg.
+    Region &region = linalgOp.getRegion();
+    rewriter.cloneRegionBefore(op.getBody(), region, region.end());
+
+    TypeConverter::SignatureConversion signatureConverter(
+        inputs.size() + op->getNumResults() - 1);
+
+    // ReduceWindow requires that the seed be used as a LHS operand inside the
+    // region, and the seed is encoded in linalg in the initial out value, so
+    // modify the signature of the block and the value mappings, so the output
+    // args will correlate with the LHS and the inputs correlate with the RHS.
+    for (auto [i, type] : llvm::enumerate(resultTypes)) {
+      auto idx = inputs.size() + i - 1;
+      signatureConverter.addInputs(idx,
+                                   cast<ShapedType>(type).getElementType());
+    }
+
+    // New block argument for the shape-only window input; it has no
+    // counterpart in the original region, so no original index is given.
+    signatureConverter.addInputs(
+        cast<ShapedType>(inputs.back().getType()).getElementType());
+
+    for (auto [i, input] :
+         llvm::enumerate(ArrayRef<Value>(inputs).drop_back())) {
+      signatureConverter.addInputs(
+          i, cast<ShapedType>(input.getType()).getElementType());
+    }
+
+    rewriter.applySignatureConversion(&region, signatureConverter,
+                                      getTypeConverter());
+    rewriter.replaceOp(op, linalgOp.getResults());
+    return success();
+  }
+};
+
+// Lowers NHWC (rank-4) / NDHWC (rank-5) stablehlo.reduce_window ops whose
+// region reduces with a single min/max/add into the corresponding named
+// linalg pooling op. Cases this pattern rejects fall through to the
+// linalg.generic-based pattern (which is registered with a lower benefit).
+struct ReduceWindowOpConversion final
+    : OpConversionPattern<stablehlo::ReduceWindowOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  /// Get the operation used for reduction applied to `result_index`th result.
+  /// Its expected to be a binary operation that consumes `result_index`th and
+  /// `result_index + getInputs().size`th arguments of the body.
+  static Operation *getReductionOp(stablehlo::ReduceWindowOp op,
+                                   int resultIndex) {
+    auto returnOp =
+        cast<stablehlo::ReturnOp>(op.getBody().front().getTerminator())
+    Operation *computeOp = returnOp.getResults()[resultIndex].getDefiningOp();
+    if (computeOp->getNumOperands() != 2) return nullptr;
+    auto arg0 = computeOp->getOperand(0).dyn_cast<BlockArgument>();
+    auto arg1 = computeOp->getOperand(1).dyn_cast<BlockArgument>();
+    if (!arg0 || !arg1) return nullptr;
+    int64_t arg0Num = arg0.getArgNumber();
+    int64_t arg1Num = arg1.getArgNumber();
+    int64_t otherArgIndex = resultIndex + op.getInputs().size();
+    if (arg0Num == resultIndex && arg1Num == otherArgIndex) return computeOp;
+    // Operands in swapped order are only accepted for commutative reductions.
+    if (arg0Num == otherArgIndex && arg1Num == resultIndex &&
+        computeOp->hasTrait<mlir::OpTrait::IsCommutative>())
+      return computeOp;
+    return nullptr;
+  }
+
+  /// stablehlo.reduce_window is mapped to a linalg.pooling operation. The type
+  /// of the pooling is determined based on the body of the reduce window
+  /// operation. This class enumerates the different variants.
+  enum class PoolingType {
+    kInvalid,
+    k2DMin,
+    k3DMin,
+    k2DMax,
+    k3DMax,
+    k2DAdd,
+    k3DAdd,
+  };
+
+  /// Classifies the pooling by reduction op kind and result rank
+  /// (rank 4 = NHWC / 2-D pooling, rank 5 = NDHWC / 3-D pooling).
+  static PoolingType getPoolingType(stablehlo::ReduceWindowOp reduceOp,
+                                    int resultIndex) {
+    auto rank =
+        reduceOp.getResultTypes()[resultIndex].cast<ShapedType>().getRank();
+    if (Operation *op = getReductionOp(reduceOp, resultIndex)) {
+      if (isa<stablehlo::MinOp>(*op) && rank == 4) return PoolingType::k2DMin;
+      if (isa<stablehlo::MinOp>(*op) && rank == 5) return PoolingType::k3DMin;
+      if (isa<stablehlo::MaxOp>(*op) && rank == 4) return PoolingType::k2DMax;
+      if (isa<stablehlo::MaxOp>(*op) && rank == 5) return PoolingType::k3DMax;
+      if (isa<stablehlo::AddOp>(*op) && rank == 4) return PoolingType::k2DAdd;
+      if (isa<stablehlo::AddOp>(*op) && rank == 5) return PoolingType::k3DAdd;
+    }
+    return PoolingType::kInvalid;
+  }
+
+  LogicalResult matchAndRewrite(
+      stablehlo::ReduceWindowOp op, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    int rank = op.getResultTypes()[0].cast<ShapedType>().getRank();
+    if (rank != 4 && rank != 5) {
+      return rewriter.notifyMatchFailure(
+          op, "expected NHWC/NDHWC pooling-based op");
+    }
+
+    // Named pooling ops cannot express padding or base dilation.
+    if (op.getPadding() && !isSplatValue(*op.getPadding(), 0)) {
+      return rewriter.notifyMatchFailure(op, "require paddings are all zero");
+    }
+
+    if (op.getBaseDilations() && !isSplatValue(*op.getBaseDilations(), 1)) {
+      return rewriter.notifyMatchFailure(op, "expected undilated base");
+    }
+
+    // Collect the spatial window sizes (dims 1..rank-2); the batch and
+    // channel window dims must be 1, which is checked below.
+    int lastDim = rank - 1;
+    SmallVector<int64_t, 2> fakeWindowShapes;
+    for (int i = 1; i < lastDim; ++i) {
+      fakeWindowShapes.push_back(
+          op.getWindowDimensions().getValues<int64_t>()[i]);
+    }
+
+    if (op.getWindowStrides() &&
+        (op.getWindowStrides().value().getValues<int64_t>()[0] != 1 ||
+         op.getWindowStrides().value().getValues<int64_t>()[lastDim] != 1)) {
+      return rewriter.notifyMatchFailure(
+          op, "expected window_strides to be [1,x,y,(z),1]");
+    }
+    if (op.getWindowDimensions() &&
+        (op.getWindowDimensions().getValues<int64_t>()[0] != 1 ||
+         op.getWindowDimensions().getValues<int64_t>()[lastDim] != 1)) {
+      return rewriter.notifyMatchFailure(
+          op, "expected window_dimensions to be [1,x,y,(z),1]");
+    }
+
+    // Spatial strides/dilations; default to 1 when the attribute is absent.
+    Attribute strides;
+    SmallVector<int64_t> vec;
+    if (op.getWindowStridesAttr()) {
+      for (int i = 1; i < lastDim; ++i) {
+        vec.push_back(op.getWindowStrides().value().getValues<int64_t>()[i]);
+      }
+    } else {
+      vec.assign(rank - 2, 1);
+    }
+    strides = rewriter.getI64VectorAttr(vec);
+
+    Attribute dilations;
+    vec.clear();
+    if (op.getWindowDilations()) {
+      for (int i = 1; i < lastDim; ++i) {
+        vec.push_back(op.getWindowDilations().value().getValues<int64_t>()[i]);
+      }
+    } else {
+      vec.assign(rank - 2, 1);
+    }
+    dilations = rewriter.getI64VectorAttr(vec);
+
+    SmallVector<Value> poolingOps;
+
+    // Emit one named pooling op per (result, input, init value) triple.
+    ValueRange operands = adaptor.getInputs();
+    ValueRange initValues = adaptor.getInitValues();
+    for (auto it : llvm::zip(op.getResults(), operands, initValues)) {
+      OpResult result = std::get<0>(it);
+      Value input = std::get<1>(it);
+      Value initValue = std::get<2>(it);
+      auto resultType = cast<ShapedType>(result.getType());
+      if (!cast<ShapedType>(input.getType()).getElementType().isF32()) {
+        return rewriter.notifyMatchFailure(op,
+                                           "expected element type to be f32");
+      }
+
+      // Create a fake window dimension. Only its shape is used by the
+      // pooling op; its values are never read.
+      auto fakeWindowDims = rewriter.create<tensor::EmptyOp>(
+          loc, fakeWindowShapes, resultType.getElementType());
+
+      SmallVector<Value> resultDynamicDims;
+      for (const auto &en : llvm::enumerate(resultType.getShape())) {
+        if (en.value() != ShapedType::kDynamic) continue;
+        Value dimSize = rewriter.create<tensor::DimOp>(loc, input, en.index());
+        if (en.index() == 0 || static_cast<int64_t>(en.index()) == rank - 1) {
+          // batch dims and channel dims can be derived from input dims
+          // directly.
+          resultDynamicDims.push_back(dimSize);
+        } else {
+          auto i = en.index() - 1;
+          auto stride =
+              strides.cast<DenseIntElementsAttr>().getValues<int64_t>()[i];
+          auto dilation =
+              dilations.cast<DenseIntElementsAttr>().getValues<int64_t>()[i];
+          // let j = i * stride
+          // output[i] = reduce( input[j, j + window_size * dilation) )
+          Value offset = rewriter.create<arith::ConstantIndexOp>(
+              loc, fakeWindowShapes[i] * dilation);
+          dimSize = rewriter.create<arith::SubIOp>(loc, dimSize, offset);
+          dimSize = rewriter.create<arith::DivUIOp>(
+              loc, dimSize,
+              rewriter.create<arith::ConstantIndexOp>(loc, stride));
+          dimSize = rewriter.create<arith::AddIOp>(
+              loc, dimSize, rewriter.create<arith::ConstantIndexOp>(loc, 1));
+          resultDynamicDims.push_back(dimSize);
+        }
+      }
+      Value emptyTensor = rewriter.create<tensor::EmptyOp>(
+          loc, resultType.getShape(), resultType.getElementType(),
+          resultDynamicDims);
+
+      // Extract the scalar init value and fill the output tensor with it.
+      initValue = rewriter.create<tensor::ExtractOp>(loc, initValue);
+      Value filledInitTensor =
+          rewriter.create<linalg::FillOp>(loc, initValue, emptyTensor)
+              .getResult(0);
+      // Builds a named pooling op of the given type; the typed null pointer
+      // only carries the op type into the template parameter.
+      auto createOp = [&](auto *typePtr) -> linalg::LinalgOp {
+        return cast<linalg::LinalgOp>(
+            rewriter
+                .create<std::remove_pointer_t<decltype(typePtr)>>(
+                    loc, ArrayRef<Type>{resultType},
+                    ValueRange{input, fakeWindowDims.getResult()},
+                    filledInitTensor, strides, dilations,
+                    linalg::getPrunedAttributeList(op))
+                .getOperation());
+      };
+      linalg::LinalgOp poolingOp;
+      PoolingType poolingType = getPoolingType(op, result.getResultNumber());
+      switch (poolingType) {
+        case PoolingType::k2DMin: {
+          poolingOp =
+              createOp(static_cast<linalg::PoolingNhwcMinOp *>(nullptr));
+          break;
+        }
+        case PoolingType::k3DMin: {
+          poolingOp =
+              createOp(static_cast<linalg::PoolingNdhwcMinOp *>(nullptr));
+          break;
+        }
+        case PoolingType::k2DMax: {
+          poolingOp =
+              createOp(static_cast<linalg::PoolingNhwcMaxOp *>(nullptr));
+          break;
+        }
+        case PoolingType::k3DMax: {
+          poolingOp =
+              createOp(static_cast<linalg::PoolingNdhwcMaxOp *>(nullptr));
+          break;
+        }
+        case PoolingType::k2DAdd: {
+          poolingOp =
+              createOp(static_cast<linalg::PoolingNhwcSumOp *>(nullptr));
+          break;
+        }
+        case PoolingType::k3DAdd: {
+          poolingOp =
+              createOp(static_cast<linalg::PoolingNdhwcSumOp *>(nullptr));
+          break;
+        }
+        case PoolingType::kInvalid:
+          return rewriter.notifyMatchFailure(op, "unknown reduction operation");
+      }
+      poolingOps.push_back(poolingOp->getResult(0));
+    }
+    rewriter.replaceOp(op, poolingOps);
+    return success();
+  }
+};
+
+}  // namespace
+
+namespace detail {
+// Registers all reduction-related StableHLO -> linalg patterns.
+// 'enablePrimitiveOps' selects lowering stablehlo.reduce to linalg.reduce
+// instead of an equivalent linalg.generic.
+void populateStableHloReductionToLinalgConversionPatterns(
+    MLIRContext *context, TypeConverter &typeConverter,
+    RewritePatternSet *patterns, bool enablePrimitiveOps) {
+  if (enablePrimitiveOps) {
+    patterns->add<ReduceOpToReduceConverter>(typeConverter, context);
+  } else {
+    patterns->add<ReduceOpToGenericConverter>(typeConverter, context);
+  }
+  patterns->add<ReduceRegionReturnOpConversion,
+                ReduceWindowOpOnTensorsGenericConversion>(typeConverter,
+                                                          context);
+
+  // Ensure specialized patterns are higher priority than their generic
+  // versions.
+  patterns->add<ReduceWindowOpConversion>(typeConverter, context,
+                                          PatternBenefit(2));
+}
+}  // namespace detail
+}  // namespace mlir::iree_compiler::stablehlo
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/test/stablehlo_to_linalg_reduce.mlir b/compiler/src/iree/compiler/InputConversion/StableHLO/test/stablehlo_to_linalg_reduce.mlir
index 156011c..28bd8bb 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/test/stablehlo_to_linalg_reduce.mlir
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/test/stablehlo_to_linalg_reduce.mlir
@@ -355,3 +355,419 @@
 // CHECK-PRIMITIVE-NEXT:      %[[RES0:.*]] = arith.select %[[B0]], %[[RHS0]], %[[LHS0]] : f32
 // CHECK-PRIMITIVE-NEXT:      %[[RES1:.*]] = arith.select %[[B0]], %[[RHS1]], %[[LHS1]] : i32
 // CHECK-PRIMITIVE-NEXT:      linalg.yield %[[RES0]], %[[RES1]] : f32, i32
+
+// -----
+
+// CHECK-LABEL: func @reduce_window_min_nhwc
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]*]]
+// 3x3, stride-2 min pooling over NHWC; also checks that 'someattr' survives
+// the lowering onto the pooling op.
+func.func @reduce_window_min_nhwc(%arg0: tensor<1x17x17x64xf32>,
+                             %arg1: tensor<f32>) -> tensor<1x8x8x64xf32>{
+  %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({
+  ^bb0(%arg2: tensor<f32>, %arg3 : tensor<f32>):
+    %1 = stablehlo.minimum %arg2, %arg3 : tensor<f32>
+    "stablehlo.return"(%1) : (tensor<f32>) -> ()
+  }) {window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
+      window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>,
+      someattr} : (tensor<1x17x17x64xf32>, tensor<f32>) -> tensor<1x8x8x64xf32>
+  func.return %0 : tensor<1x8x8x64xf32>
+}
+// CHECK:         %[[WINDOW:.+]] = tensor.empty() : tensor<3x3xf32>
+// CHECK:         %[[INIT:.+]] = tensor.empty() : tensor<1x8x8x64xf32>
+// CHECK:         %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
+// CHECK:         %[[FILL:.+]] = linalg.fill ins(%[[INIT_VAL]] : f32) outs(%[[INIT]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
+// CHECK:         %[[RES:.+]] = linalg.pooling_nhwc_min
+// CHECK-SAME:      {dilations = dense<1> : vector<2xi64>
+// CHECK-SAME:       someattr,
+// CHECK-SAME:       strides = dense<2> : vector<2xi64>}
+// CHECK-SAME:      ins(%[[ARG0]], %[[WINDOW]] : tensor<1x17x17x64xf32>, tensor<3x3xf32>)
+// CHECK-SAME:      outs(%[[FILL]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
+
+// -----
+
+// CHECK-LABEL: func @reduce_window_max_nhwc
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]*]]
+// 3x3, stride-2 max pooling over NHWC -> linalg.pooling_nhwc_max.
+func.func @reduce_window_max_nhwc(%arg0: tensor<1x17x17x64xf32>,
+                             %arg1: tensor<f32>) -> tensor<1x8x8x64xf32>{
+  %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({
+  ^bb0(%arg2: tensor<f32>, %arg3 : tensor<f32>):
+    %1 = stablehlo.maximum %arg2, %arg3 : tensor<f32>
+    "stablehlo.return"(%1) : (tensor<f32>) -> ()
+  }) {window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
+      window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<1x17x17x64xf32>, tensor<f32>) -> tensor<1x8x8x64xf32>
+  func.return %0 : tensor<1x8x8x64xf32>
+}
+// CHECK:         %[[WINDOW:.+]] = tensor.empty() : tensor<3x3xf32>
+// CHECK:         %[[INIT:.+]] = tensor.empty() : tensor<1x8x8x64xf32>
+// CHECK:         %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
+// CHECK:         %[[FILL:.+]] = linalg.fill ins(%[[INIT_VAL]] : f32) outs(%[[INIT]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
+// CHECK:         %[[RES:.+]] = linalg.pooling_nhwc_max
+// CHECK-SAME:      {dilations = dense<1> : vector<2xi64>
+// CHECK-SAME:       strides = dense<2> : vector<2xi64>}
+// CHECK-SAME:      ins(%[[ARG0]], %[[WINDOW]] : tensor<1x17x17x64xf32>, tensor<3x3xf32>)
+// CHECK-SAME:      outs(%[[FILL]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
+
+// -----
+
+// CHECK-LABEL: func @reduce_window_sum_nhwc
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]*]]
+// 3x3, stride-2 sum pooling over NHWC -> linalg.pooling_nhwc_sum.
+func.func @reduce_window_sum_nhwc(%arg0: tensor<1x17x17x64xf32>,
+                             %arg1: tensor<f32>) -> tensor<1x8x8x64xf32>{
+  %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({
+  ^bb0(%arg2: tensor<f32>, %arg3 : tensor<f32>):
+    %1 = stablehlo.add %arg2, %arg3 : tensor<f32>
+    "stablehlo.return"(%1) : (tensor<f32>) -> ()
+  }) {window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
+      window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<1x17x17x64xf32>, tensor<f32>) -> tensor<1x8x8x64xf32>
+  func.return %0 : tensor<1x8x8x64xf32>
+}
+// CHECK:         %[[WINDOW:.+]] = tensor.empty() : tensor<3x3xf32>
+// CHECK:         %[[INIT:.+]] = tensor.empty() : tensor<1x8x8x64xf32>
+// CHECK:         %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
+// CHECK:         %[[FILL:.+]] = linalg.fill ins(%[[INIT_VAL]] : f32) outs(%[[INIT]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
+// CHECK:         %[[RES:.+]] = linalg.pooling_nhwc_sum
+// CHECK-SAME:      {dilations = dense<1> : vector<2xi64>
+// CHECK-SAME:       strides = dense<2> : vector<2xi64>}
+// CHECK-SAME:      ins(%[[ARG0]], %[[WINDOW]] : tensor<1x17x17x64xf32>, tensor<3x3xf32>)
+// CHECK-SAME:      outs(%[[FILL]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
+
+// -----
+
+// CHECK-LABEL: func @reduce_window_max_nhwc_with_cst
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]*]]
+// Max pooling seeded from an arith.constant init value (-inf); the scalar
+// constant must feed the linalg.fill directly.
+func.func @reduce_window_max_nhwc_with_cst(%arg0: tensor<1x17x17x64xf32>) -> tensor<1x8x8x64xf32> {
+  %0 = arith.constant dense<0xFF800000> : tensor<f32>
+  %1 = "stablehlo.reduce_window"(%arg0, %0) ({
+  ^bb0(%arg1: tensor<f32>, %arg2 : tensor<f32>):
+    %2 = stablehlo.maximum %arg1, %arg2 : tensor<f32>
+    "stablehlo.return"(%2) : (tensor<f32>) -> ()
+  }) {window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
+      window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<1x17x17x64xf32>, tensor<f32>) -> tensor<1x8x8x64xf32>
+  func.return %1 : tensor<1x8x8x64xf32>
+}
+
+// CHECK-DAG:     %[[CST:.+]] = arith.constant 0xFF800000
+// CHECK:         %[[WINDOW:.+]] = tensor.empty() : tensor<3x3xf32>
+// CHECK:         %[[INIT:.+]] = tensor.empty() : tensor<1x8x8x64xf32
+// CHECK:         %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
+// CHECK:         %[[RES:.+]] = linalg.pooling_nhwc_max
+// CHECK-SAME:      {dilations = dense<1> : vector<2xi64>
+// CHECK-SAME:       strides = dense<2> : vector<2xi64>}
+// CHECK-SAME:      ins(%[[ARG0]], %[[WINDOW]] : tensor<1x17x17x64xf32>, tensor<3x3xf32>)
+// CHECK-SAME:      outs(%[[FILL]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
+
+// -----
+
+// CHECK-LABEL: func @reduce_window_sum_max_nhwc
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]*]]
+// Variadic reduce_window: one op produces a sum pooling and a max pooling,
+// lowered to two separate named linalg pooling ops.
+func.func @reduce_window_sum_max_nhwc(%arg0: tensor<1x17x17x64xf32>,
+                             %arg1: tensor<f32>) -> (tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>) {
+  %0:2 = "stablehlo.reduce_window"(%arg0, %arg0, %arg1, %arg1) ({
+  ^bb0(%arg2: tensor<f32>, %arg3 : tensor<f32>, %arg4: tensor<f32>, %arg5 : tensor<f32>):
+    %1 = stablehlo.add %arg2, %arg4 : tensor<f32>
+    %2 = stablehlo.maximum %arg3, %arg5 : tensor<f32>
+    "stablehlo.return"(%1, %2) : (tensor<f32>, tensor<f32>) -> ()
+  }) {window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
+      window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<1x17x17x64xf32>, tensor<1x17x17x64xf32>, tensor<f32>, tensor<f32>) -> (tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>)
+  func.return %0#0, %0#1 : tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>
+}
+
+// CHECK:         %[[WINDOW0:.+]] = tensor.empty() : tensor<3x3xf32>
+// CHECK:         %[[INIT0:.+]] = tensor.empty() : tensor<1x8x8x64xf32>
+// CHECK:         %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
+// CHECK:         %[[FILL0:.+]] = linalg.fill ins(%[[INIT_VAL]] : f32) outs(%[[INIT0]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
+// CHECK:         %[[RES0:.+]] = linalg.pooling_nhwc_sum
+// CHECK-SAME:      {dilations = dense<1> : vector<2xi64>
+// CHECK-SAME:       strides = dense<2> : vector<2xi64>}
+// CHECK-SAME:      ins(%[[ARG0]], %[[WINDOW0]] : tensor<1x17x17x64xf32>, tensor<3x3xf32>)
+// CHECK-SAME:      outs(%[[FILL0]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
+// CHECK:         %[[WINDOW1:.+]] = tensor.empty() : tensor<3x3xf32>
+// CHECK:         %[[INIT1:.+]] = tensor.empty() : tensor<1x8x8x64xf32>
+// CHECK:         %[[INIT_VAL1:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
+// CHECK:         %[[FILL1:.+]] = linalg.fill ins(%[[INIT_VAL1]] : f32) outs(%[[INIT1]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
+// CHECK:         %[[RES1:.+]] = linalg.pooling_nhwc_max
+// CHECK-SAME:      {dilations = dense<1> : vector<2xi64>
+// CHECK-SAME:       strides = dense<2> : vector<2xi64>}
+// CHECK-SAME:      ins(%[[ARG0]], %[[WINDOW1]] : tensor<1x17x17x64xf32>, tensor<3x3xf32>)
+// CHECK-SAME:      outs(%[[FILL1]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
+// CHECK:         return %[[RES0]], %[[RES1]]
+
+// -----
+
+// Just check that this lowers successfully.
+// CHECK-LABEL: func @reduce_window_unsigned
+// Unsigned element types must round-trip through the type converter without
+// crashing; no specific output is checked.
+func.func @reduce_window_unsigned(%arg0: tensor<1x1xui32>) -> tensor<1x1xui32> {
+  %0 = stablehlo.constant dense<0> : tensor<ui32>
+  %1 = "stablehlo.reduce_window"(%arg0, %0) ({
+  ^bb0(%arg1: tensor<ui32>, %arg2: tensor<ui32>):
+    stablehlo.return %arg1 : tensor<ui32>
+  }) {
+    window_dimensions = dense<[1, 1]> : tensor<2xi64>,
+    window_strides = dense<[1, 1]> : tensor<2xi64>
+  } : (tensor<1x1xui32>, tensor<ui32>) -> tensor<1x1xui32>
+  return %1 : tensor<1x1xui32>
+}
+
+// -----
+
+// CHECK-LABEL: func @dynamic_reduce_window_sum_nhwc
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]*]]
+// Dynamic spatial dims: output extents are computed via
+// (in - window * dilation) / stride + 1 arithmetic (see CHECKs below).
+func.func @dynamic_reduce_window_sum_nhwc(%arg0: tensor<?x?x?x?xf32>,
+                                      %arg1: tensor<f32>) -> tensor<?x?x?x?xf32>{
+  %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({
+  ^bb0(%arg2: tensor<f32>, %arg3 : tensor<f32>):
+    %1 = stablehlo.add %arg2, %arg3 : tensor<f32>
+    "stablehlo.return"(%1) : (tensor<f32>) -> ()
+  }) {window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
+      window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
+  func.return %0 : tensor<?x?x?x?xf32>
+}
+// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG:     %[[C3:.+]] = arith.constant 3 : index
+// CHECK:         %[[WINDOW:.+]] = tensor.empty() : tensor<3x3xf32>
+// CHECK:         %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?x?xf32>
+// CHECK:         %[[T1:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?xf32>
+// CHECK:         %[[T2:.+]] = arith.subi %[[T1]], %[[C3]]
+// CHECK:         %[[T3:.+]] = arith.divui %[[T2]], %[[C2]]
+// CHECK:         %[[D1:.+]] = arith.addi %[[T3]], %[[C1]]
+// CHECK:         %[[T1:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?xf32>
+// CHECK:         %[[T2:.+]] = arith.subi %[[T1]], %[[C3]]
+// CHECK:         %[[T3:.+]] = arith.divui %[[T2]], %[[C2]]
+// CHECK:         %[[D2:.+]] = arith.addi %[[T3]], %[[C1]]
+// CHECK:         %[[D3:.+]] = tensor.dim %[[ARG0]], %[[C3]] : tensor<?x?x?x?xf32>
+// CHECK:         %[[INIT:.+]] = tensor.empty(%[[D0]], %[[D1]], %[[D2]], %[[D3]]) : tensor<?x?x?x?xf32>
+// CHECK:         %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
+// CHECK:         %[[FILL:.+]] = linalg.fill ins(%[[INIT_VAL]] : f32) outs(%[[INIT]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+// CHECK:         %[[RES:.+]] = linalg.pooling_nhwc_sum
+// CHECK-SAME:      {dilations = dense<1> : vector<2xi64>
+// CHECK-SAME:       strides = dense<2> : vector<2xi64>}
+// CHECK-SAME:      ins(%[[ARG0]], %[[WINDOW]] : tensor<?x?x?x?xf32>, tensor<3x3xf32>)
+// CHECK-SAME:      outs(%[[FILL]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+
+// -----
+
+// CHECK-LABEL: func @reduce_window_min_ndhwc
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]*]]
+// 3-D (NDHWC) min pooling, 3x3x3 window, stride 2 -> linalg.pooling_ndhwc_min.
+func.func @reduce_window_min_ndhwc(%arg0: tensor<1x17x17x17x64xf32>,
+                              %arg1: tensor<f32>) -> tensor<1x8x8x8x64xf32>{
+  %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({
+  ^bb0(%arg2: tensor<f32>, %arg3 : tensor<f32>):
+    %1 = stablehlo.minimum %arg2, %arg3 : tensor<f32>
+    "stablehlo.return"(%1) : (tensor<f32>) -> ()
+  }) {window_dimensions = dense<[1, 3, 3, 3, 1]> : tensor<5xi64>,
+      window_strides = dense<[1, 2, 2, 2, 1]> : tensor<5xi64>} : (tensor<1x17x17x17x64xf32>, tensor<f32>) -> tensor<1x8x8x8x64xf32>
+  func.return %0 : tensor<1x8x8x8x64xf32>
+}
+// CHECK:         %[[WINDOW:.+]] = tensor.empty() : tensor<3x3x3xf32>
+// CHECK:         %[[INIT:.+]] = tensor.empty() : tensor<1x8x8x8x64xf32>
+// CHECK:         %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
+// CHECK:         %[[FILL:.+]] = linalg.fill ins(%[[INIT_VAL]] : f32) outs(%[[INIT]] : tensor<1x8x8x8x64xf32>) -> tensor<1x8x8x8x64xf32>
+// CHECK:         %[[RES:.+]] = linalg.pooling_ndhwc_min
+// CHECK-SAME:      {dilations = dense<1> : vector<3xi64>
+// CHECK-SAME:       strides = dense<2> : vector<3xi64>}
+// CHECK-SAME:      ins(%[[ARG0]], %[[WINDOW]] : tensor<1x17x17x17x64xf32>, tensor<3x3x3xf32>)
+// CHECK-SAME:      outs(%[[FILL]] : tensor<1x8x8x8x64xf32>) -> tensor<1x8x8x8x64xf32>
+
+// -----
+
+// CHECK-LABEL: func @reduce_window_max_ndhwc
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]*]]
+// 3-D (NDHWC) max pooling, 3x3x3 window, stride 2 -> linalg.pooling_ndhwc_max.
+func.func @reduce_window_max_ndhwc(%arg0: tensor<1x17x17x17x64xf32>,
+                              %arg1: tensor<f32>) -> tensor<1x8x8x8x64xf32>{
+  %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({
+  ^bb0(%arg2: tensor<f32>, %arg3 : tensor<f32>):
+    %1 = stablehlo.maximum %arg2, %arg3 : tensor<f32>
+    "stablehlo.return"(%1) : (tensor<f32>) -> ()
+  }) {window_dimensions = dense<[1, 3, 3, 3, 1]> : tensor<5xi64>,
+      window_strides = dense<[1, 2, 2, 2, 1]> : tensor<5xi64>} : (tensor<1x17x17x17x64xf32>, tensor<f32>) -> tensor<1x8x8x8x64xf32>
+  func.return %0 : tensor<1x8x8x8x64xf32>
+}
+// CHECK:         %[[WINDOW:.+]] = tensor.empty() : tensor<3x3x3xf32>
+// CHECK:         %[[INIT:.+]] = tensor.empty() : tensor<1x8x8x8x64xf32>
+// CHECK:         %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
+// CHECK:         %[[FILL:.+]] = linalg.fill ins(%[[INIT_VAL]] : f32) outs(%[[INIT]] : tensor<1x8x8x8x64xf32>) -> tensor<1x8x8x8x64xf32>
+// CHECK:         %[[RES:.+]] = linalg.pooling_ndhwc_max
+// CHECK-SAME:      {dilations = dense<1> : vector<3xi64>
+// CHECK-SAME:       strides = dense<2> : vector<3xi64>}
+// CHECK-SAME:      ins(%[[ARG0]], %[[WINDOW]] : tensor<1x17x17x17x64xf32>, tensor<3x3x3xf32>)
+// CHECK-SAME:      outs(%[[FILL]] : tensor<1x8x8x8x64xf32>) -> tensor<1x8x8x8x64xf32>
+
+// -----
+
+// CHECK-LABEL: func @reduce_window_sum_ndhwc
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]*]]
+// 3-D (NDHWC) sum pooling, 3x3x3 window, stride 2 -> linalg.pooling_ndhwc_sum.
+func.func @reduce_window_sum_ndhwc(%arg0: tensor<1x17x17x17x64xf32>,
+                              %arg1: tensor<f32>) -> tensor<1x8x8x8x64xf32>{
+  %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({
+  ^bb0(%arg2: tensor<f32>, %arg3 : tensor<f32>):
+    %1 = stablehlo.add %arg2, %arg3 : tensor<f32>
+    "stablehlo.return"(%1) : (tensor<f32>) -> ()
+  }) {window_dimensions = dense<[1, 3, 3, 3, 1]> : tensor<5xi64>,
+      window_strides = dense<[1, 2, 2, 2, 1]> : tensor<5xi64>} : (tensor<1x17x17x17x64xf32>, tensor<f32>) -> tensor<1x8x8x8x64xf32>
+  func.return %0 : tensor<1x8x8x8x64xf32>
+}
+// CHECK:         %[[WINDOW:.+]] = tensor.empty() : tensor<3x3x3xf32>
+// CHECK:         %[[INIT:.+]] = tensor.empty() : tensor<1x8x8x8x64xf32>
+// CHECK:         %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
+// CHECK:         %[[FILL:.+]] = linalg.fill ins(%[[INIT_VAL]] : f32) outs(%[[INIT]] : tensor<1x8x8x8x64xf32>) -> tensor<1x8x8x8x64xf32>
+// CHECK:         %[[RES:.+]] = linalg.pooling_ndhwc_sum
+// CHECK-SAME:      {dilations = dense<1> : vector<3xi64>
+// CHECK-SAME:       strides = dense<2> : vector<3xi64>}
+// CHECK-SAME:      ins(%[[ARG0]], %[[WINDOW]] : tensor<1x17x17x17x64xf32>, tensor<3x3x3xf32>)
+// CHECK-SAME:      outs(%[[FILL]] : tensor<1x8x8x8x64xf32>) -> tensor<1x8x8x8x64xf32>
+
+// -----
+
+// Same sum pooling as above but with a non-trivial base_dilations attribute
+// on one spatial dimension: the named-pooling-op path does not apply, so the
+// lowering is expected to fall back to a plain linalg.generic.
+// CHECK-LABEL: func @reduce_window_sum_ndhwc_dilated_base
+// CHECK: linalg.generic
+func.func @reduce_window_sum_ndhwc_dilated_base(
+    %arg0: tensor<1x17x17x17x64xf32>,
+    %arg1: tensor<f32>) -> tensor<1x8x8x16x64xf32>{
+  %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({
+  ^bb0(%arg2: tensor<f32>, %arg3 : tensor<f32>):
+    %1 = stablehlo.add %arg2, %arg3 : tensor<f32>
+    "stablehlo.return"(%1) : (tensor<f32>) -> ()
+  }) {base_dilations = dense<[1, 1, 1, 2, 1]> : tensor<5xi64>,
+      window_dimensions = dense<[1, 3, 3, 3, 1]> : tensor<5xi64>,
+      window_strides = dense<[1, 2, 2, 2, 1]> : tensor<5xi64>} : (tensor<1x17x17x17x64xf32>, tensor<f32>) -> tensor<1x8x8x16x64xf32>
+  func.return %0 : tensor<1x8x8x16x64xf32>
+}
+
+// -----
+
+// Generic lowering path with padding, window dilation, and strides all
+// present. The init value is broadcast into the output with a linalg.generic
+// (MAP0/MAP1), the input is padded with tensor.pad using the extracted init
+// scalar, and the reduction itself is a linalg.generic whose input indexing
+// map (MAP2) folds the strides and window dilation into the access:
+// (d0 * stride0, d1 * stride1 + d2 * window_dilation1) = (d0 * 2, d1 + d2 * 2).
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0 * 2, d1 + d2 * 2)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK:      func @reduce_window_generic
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]*]]
+func.func @reduce_window_generic(%arg0: tensor<4x6xf32>, %arg1: tensor<f32>) -> tensor<4x7xf32> {
+  %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({
+  ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+    %1 = stablehlo.add %arg2, %arg3 : tensor<f32>
+    "stablehlo.return"(%1) : (tensor<f32>) -> ()
+  }) {base_dilations = dense<1> : tensor<2xi64>, padding = dense<[[0, 3], [1, 2]]> : tensor<2x2xi64>, window_dilations = dense<[1, 2]> : tensor<2xi64>, window_dimensions = dense<[1, 2]> : tensor<2xi64>, window_strides = dense<[2, 1]> : tensor<2xi64>} : (tensor<4x6xf32>, tensor<f32>) -> tensor<4x7xf32>
+  func.return %0 : tensor<4x7xf32>
+}
+// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<4x7xf32>
+// CHECK: %[[FILL:.+]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg1 : tensor<f32>) outs(%[[INIT]] : tensor<4x7xf32>)
+// CHECK: ^{{[a-z0-9_]*}}
+// CHECK-SAME: %[[IN:[a-zA-Z0-9_]*]]: f32
+// CHECK-SAME: %[[OUT:[a-zA-Z0-9_]*]]: f32
+// CHECK:   linalg.yield %[[IN]] : f32
+
+// CHECK: %[[PADVAL:.+]] = tensor.extract %arg1[] : tensor<f32>
+// CHECK: %[[PAD:.+]] = tensor.pad %arg0 low[0, 1] high[3, 2]
+// CHECK: ^{{[a-z0-9_]*}}
+// CHECK-SAME: %{{[a-zA-Z0-9_]*}}: index
+// CHECK-SAME: %{{[a-zA-Z0-9_]*}}: index
+// CHECK:   tensor.yield %[[PADVAL]] : f32
+
+// CHECK: %[[WINDOW:.+]] = tensor.empty() : tensor<2xf32>
+// CHECK: %[[REDUCE:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP3]], #[[MAP4]]], iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[PAD]], %[[WINDOW]] : tensor<7x9xf32>, tensor<2xf32>) outs(%[[FILL]] : tensor<4x7xf32>) {
+// CHECK: ^{{[a-z0-9_]*}}
+// CHECK-SAME: %[[IN:[a-zA-Z0-9_]*]]: f32
+// CHECK-SAME: %[[IN2:[a-zA-Z0-9_]*]]: f32
+// CHECK-SAME: %[[OUT:[a-zA-Z0-9_]*]]: f32
+// CHECK:   %[[ADD:.+]] = arith.addf %[[OUT]], %[[IN]] : f32
+// CHECK:   linalg.yield %[[ADD]]
+
+// CHECK: return %[[REDUCE]]
+// -----
+
+// A constant defined outside the reduce_window region but captured inside it
+// must survive the region inlining: it is expected to reappear as an
+// arith.constant feeding the arith.mulf in the linalg.generic body.
+// CHECK-LABEL: func @reduce_window_generic_captured_constant
+func.func @reduce_window_generic_captured_constant(%arg0: tensor<4x6xf32>, %arg1: tensor<f32>) -> tensor<4x7xf32> {
+  %c2 = stablehlo.constant dense<2.0> : tensor<f32>
+  %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({
+  ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+    %1 = stablehlo.add %arg2, %arg3 : tensor<f32>
+    %2 = stablehlo.multiply %1, %c2 : tensor<f32>
+    "stablehlo.return"(%2) : (tensor<f32>) -> ()
+  }) {base_dilations = dense<1> : tensor<2xi64>, padding = dense<[[0, 3], [1, 2]]> : tensor<2x2xi64>, window_dilations = dense<[1, 2]> : tensor<2xi64>, window_dimensions = dense<[1, 2]> : tensor<2xi64>, window_strides = dense<[2, 1]> : tensor<2xi64>} : (tensor<4x6xf32>, tensor<f32>) -> tensor<4x7xf32>
+  func.return %0 : tensor<4x7xf32>
+}
+
+// CHECK: %[[C2:.*]] = arith.constant 2.0
+// CHECK: linalg.generic
+// CHECK: %[[SUM:.*]] = arith.addf
+// CHECK: %[[PROD:.*]] = arith.mulf %[[SUM]], %[[C2]]
+// CHECK: linalg.yield %[[PROD]]
+
+// -----
+
+// Padding without base dilation: the padded input is expected to be
+// materialized with a single tensor.pad whose low/high amounts come straight
+// from the columns of the padding attribute ([[0,3],[1,2]] -> low[0,1]
+// high[3,2]), yielding the extracted init scalar in the pad region.
+// CHECK-LABEL: func @reduce_window_generic_padding
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]*]]
+func.func @reduce_window_generic_padding(%arg0: tensor<3x6xf32>, %arg1: tensor<f32>) -> tensor<3x7xf32> {
+  %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({
+  ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+    %1 = stablehlo.add %arg2, %arg3 : tensor<f32>
+    "stablehlo.return"(%1) : (tensor<f32>) -> ()
+  }) {padding = dense<[[0, 3], [1, 2]]> : tensor<2x2xi64>, window_dilations = dense<[1, 2]> : tensor<2xi64>, window_dimensions = dense<[1, 2]> : tensor<2xi64>, window_strides = dense<[2, 1]> : tensor<2xi64>} : (tensor<3x6xf32>, tensor<f32>) -> tensor<3x7xf32>
+  func.return %0 : tensor<3x7xf32>
+}
+// CHECK: %[[PADVAL:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
+// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]] low[0, 1] high[3, 2]
+// CHECK: tensor.yield %[[PADVAL]] : f32
+
+// -----
+
+// Base dilation without padding: tensor.pad cannot express interior padding
+// of the input, so the lowering instead fills a larger (5x6) buffer with the
+// init value and scatters the original elements into it via a strided
+// tensor.insert_slice (stride [2, 1] matching base_dilations).
+// CHECK-LABEL: func @reduce_window_generic_base_dilation
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]*]]
+func.func @reduce_window_generic_base_dilation(%arg0: tensor<3x6xf32>, %arg1: tensor<f32>) -> tensor<3x4xf32> {
+  %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({
+  ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+    %1 = stablehlo.add %arg2, %arg3 : tensor<f32>
+    "stablehlo.return"(%1) : (tensor<f32>) -> ()
+  }) {base_dilations = dense<[2, 1]> : tensor<2xi64>, window_dilations = dense<[1, 2]> : tensor<2xi64>, window_dimensions = dense<[1, 2]> : tensor<2xi64>, window_strides = dense<[2, 1]> : tensor<2xi64>} : (tensor<3x6xf32>, tensor<f32>) -> tensor<3x4xf32>
+  func.return %0 : tensor<3x4xf32>
+}
+// CHECK: %[[PADVAL:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
+// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<5x6xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[PADVAL]] : f32) outs(%[[INIT]] : tensor<5x6xf32>) -> tensor<5x6xf32>
+// CHECK: %[[PAD:.+]] = tensor.insert_slice %[[ARG0]] into %[[FILL]][0, 0] [3, 6] [2, 1] : tensor<3x6xf32> into tensor<5x6xf32>
+
+// -----
+
+// Padding and base dilation combined: one fill + insert_slice handles both.
+// The 8x9 buffer accounts for the dilated extent plus the edge padding, and
+// the insert offset [0, 1] comes from the low-padding column while the
+// stride [2, 1] comes from base_dilations.
+// CHECK-LABEL: func @reduce_window_generic_padding_base_dilation
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]*]]
+func.func @reduce_window_generic_padding_base_dilation(%arg0: tensor<3x6xf32>, %arg1: tensor<f32>) -> tensor<4x7xf32> {
+  %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({
+  ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+    %1 = stablehlo.add %arg2, %arg3 : tensor<f32>
+    "stablehlo.return"(%1) : (tensor<f32>) -> ()
+  }) {base_dilations = dense<[2, 1]> : tensor<2xi64>, padding = dense<[[0, 3], [1, 2]]> : tensor<2x2xi64>, window_dilations = dense<[1, 2]> : tensor<2xi64>, window_dimensions = dense<[1, 2]> : tensor<2xi64>, window_strides = dense<[2, 1]> : tensor<2xi64>} : (tensor<3x6xf32>, tensor<f32>) -> tensor<4x7xf32>
+  func.return %0 : tensor<4x7xf32>
+}
+// CHECK: %[[PADVAL:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
+// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<8x9xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[PADVAL]] : f32) outs(%[[INIT]] : tensor<8x9xf32>) -> tensor<8x9xf32>
+// CHECK: %[[PAD:.+]] = tensor.insert_slice %[[ARG0]] into %[[FILL]][0, 1] [3, 6] [2, 1] : tensor<3x6xf32> into tensor<8x9xf32>
+
+// -----
+
+// Degenerate rank-0 case: every window attribute is an empty tensor, so the
+// op reduces a scalar and lowers to a 0-d linalg.generic where all three
+// operands use the empty affine map () -> ().
+// CHECK: #[[MAP:.+]] = affine_map<() -> ()>
+// CHECK: func @reduce_window_generic_scalar
+func.func @reduce_window_generic_scalar(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
+  %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({
+  ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+    %1 = stablehlo.add %arg2, %arg3 : tensor<f32>
+    "stablehlo.return"(%1) : (tensor<f32>) -> ()
+  }) {base_dilations = dense<> : tensor<0xi64>, padding = dense<> : tensor<0x2xi64>, window_dilations = dense<> : tensor<0xi64>, window_dimensions = dense<> : tensor<0xi64>, window_strides = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK: linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]