[StableHLO] Port reduce_window to linalg lowering (#13128)
This ports the reduce_window op lowering from mlir-hlo. For the details,
see the initial import: https://github.com/openxla/iree/pull/12957.
The code is cleaned up to match the existing StableHLO -> linalg
conversion patterns. The pattern population logic is updated so that we
do not depend on the order in which patterns are added -- we use pattern
benefit to specify the order instead.
This also moves the previously ported reduction patterns to the same
file as reduce_window.
Issue: https://github.com/openxla/iree/issues/12678
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/BUILD.bazel b/compiler/src/iree/compiler/InputConversion/StableHLO/BUILD.bazel
index 4837d58..aa45d90 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/BUILD.bazel
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/BUILD.bazel
@@ -52,6 +52,7 @@
"StableHLOToLinalg.cpp",
"StableHLOToLinalgDotProd.cpp",
"StableHLOToLinalgPointwise.cpp",
+ "StableHLOToLinalgReduce.cpp",
"TypeConversion.cpp",
],
hdrs = [
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/CMakeLists.txt b/compiler/src/iree/compiler/InputConversion/StableHLO/CMakeLists.txt
index a3ea3b2..0829d8d 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/CMakeLists.txt
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/CMakeLists.txt
@@ -49,6 +49,7 @@
"StableHLOToLinalg.cpp"
"StableHLOToLinalgDotProd.cpp"
"StableHLOToLinalgPointwise.cpp"
+ "StableHLOToLinalgReduce.cpp"
"TypeConversion.cpp"
DEPS
::PassHeaders
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.cpp
index 873122c..8ba8e6f 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.cpp
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.cpp
@@ -159,4 +159,12 @@
parentOp->getContext()->getLoadedDialect<linalg::LinalgDialect>();
}
+SmallVector<int64_t, 4> extract1DVector(DenseIntElementsAttr elements) {
+ SmallVector<int64_t, 4> ret;
+ for (const APInt& element : elements) {
+ ret.push_back(element.getLimitedValue());
+ }
+ return ret;
+}
+
} // namespace mlir::iree_compiler::stablehlo
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.h b/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.h
index f33b180..33b0a6a 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.h
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.h
@@ -100,6 +100,14 @@
/// Returns true if parent op is linalg.
bool isInBodyOfLinalgOps(Operation* op);
+/// Extracts integer values from the attribute |elements|.
+SmallVector<int64_t, 4> extract1DVector(DenseIntElementsAttr elements);
+
+/// Returns true if the given |attr| is a splat of the given |value|.
+inline bool isSplatValue(DenseIntElementsAttr attr, uint64_t value) {
+ return attr.isSplat() && attr.getSplatValue<uint64_t>() == value;
+}
+
} // namespace mlir::iree_compiler::stablehlo
#endif // IREE_COMPILER_INPUTCONVERSION_STABLEHLO_LEGALIZE_TO_LINALG_UTILS_H_
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Rewriters.h b/compiler/src/iree/compiler/InputConversion/StableHLO/Rewriters.h
index e13f961..b29aca2 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Rewriters.h
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Rewriters.h
@@ -33,6 +33,12 @@
MLIRContext* context, TypeConverter& typeConverter,
RewritePatternSet* patterns);
+/// Populates the patterns that convert from reduction StableHLO ops to Linalg
+/// on tensors.
+void populateStableHloReductionToLinalgConversionPatterns(
+ MLIRContext* context, TypeConverter& typeConverter,
+ RewritePatternSet* patterns, bool enablePrimitiveOps);
+
/// Populates the patterns that convert scalar StableHLO ops to Arith ops.
void populateScalarHloToArithConversionPatterns(
MLIRContext* context, TypeConverter& typeConverter,
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalg.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalg.cpp
index 13e7529..3b76980 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalg.cpp
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalg.cpp
@@ -66,38 +66,6 @@
return getResultValue(op).getType().cast<ShapedType>();
}
-SmallVector<int64_t, 4> extract1DVector(DenseIntElementsAttr elements) {
- SmallVector<int64_t, 4> ret;
- for (const APInt& element : elements) {
- ret.push_back(element.getLimitedValue());
- }
- return ret;
-}
-
-/// Returns a permutation AffineMap that puts all reduction dimensions to the
-/// last. The order of parallel loops and reduction loops are all sorted. E.g.,
-/// if `rank` is 4 and `reductionDims` is {1, 3}, then
-/// "(d0, d1, d2, d3) -> (d0, d2, d1, d3)" is used. The inverse permutation of
-/// the AffineMap is returned.
-AffineMap getTransposeMapForReduction(MLIRContext* context, int rank,
- ArrayRef<int64_t> reductionDims) {
- llvm::SmallSetVector<int, 4> s;
- for (auto dim : reductionDims) s.insert(dim);
-
- SmallVector<unsigned, 4> permutation;
- for (int i = 0; i < rank; ++i)
- if (!s.count(i)) permutation.push_back(i);
- for (auto dim : reductionDims) permutation.push_back(dim);
-
- auto map = AffineMap::getPermutationMap(permutation, context);
- return inversePermutation(map);
-}
-
-/// Returns true if the given `attr` is a splat of the given `value`.
-bool isSplatValue(DenseIntElementsAttr attr, uint64_t value) {
- return attr.isSplat() && attr.getSplatValue<uint64_t>() == value;
-}
-
/// Extracts an element from a tensor and optionally converts it to an index
/// type, based on the tensor's pre-type conversion type.
Value extractIndexFromTensor(OpBuilder& builder, Location loc, Value tensor,
@@ -1787,251 +1755,6 @@
}
};
-SmallVector<Value, 8> getReduceOpEmptyTensorDynSizes(
- OpBuilder& b, Location loc, Value arg, ShapedType resultType,
- ArrayRef<int64_t> reductionDims) {
- llvm::SmallSetVector<int, 4> s;
- for (auto dim : reductionDims) s.insert(dim);
-
- SmallVector<unsigned, 4> parallelDims;
- SmallVector<Value, 8> dynShape;
- int rank = arg.getType().cast<RankedTensorType>().getRank();
- for (int i = 0, j = 0; i < rank; ++i) {
- if (s.count(i)) continue;
- if (!resultType.isDynamicDim(j++)) continue;
- dynShape.push_back(b.create<tensor::DimOp>(loc, arg, i));
- }
-
- return dynShape;
-}
-
-class ReduceRegionReturnOpConversion
- : public OpConversionPattern<stablehlo::ReturnOp> {
- public:
- using OpConversionPattern<stablehlo::ReturnOp>::OpConversionPattern;
- LogicalResult matchAndRewrite(
- stablehlo::ReturnOp op, OpAdaptor adaptor,
- ConversionPatternRewriter& rewriter) const final {
- if (!isInBodyOfLinalgOps(op)) {
- return failure();
- }
- SmallVector<Value, 4> operands(adaptor.getOperands());
- for (size_t i = 0; i < operands.size(); ++i) {
- if (operands[i].getType().isa<ShapedType>()) {
- auto loc = operands[i].getLoc();
- operands[i] = rewriter.create<tensor::ExtractOp>(loc, operands[i]);
- }
- }
- rewriter.replaceOpWithNewOp<linalg::YieldOp>(op, operands);
- return success();
- }
-};
-
-class ReduceOpToGenericConverter
- : public OpConversionPattern<stablehlo::ReduceOp> {
- public:
- using OpConversionPattern<stablehlo::ReduceOp>::OpConversionPattern;
- LogicalResult matchAndRewrite(
- stablehlo::ReduceOp op, OpAdaptor adaptor,
- ConversionPatternRewriter& rewriter) const final {
- Location loc = op.getLoc();
-
- int numOperands = static_cast<int>(adaptor.getInputs().size());
-
- if (llvm::any_of(adaptor.getInputs(), [](Value v) {
- return !v.getType().isa<RankedTensorType>();
- })) {
- return rewriter.notifyMatchFailure(op, "expects known-rank args");
- }
- auto srcRank =
- adaptor.getInputs()[0].getType().cast<ShapedType>().getRank();
-
- SmallVector<int64_t, 4> reductionDims = extract1DVector(op.getDimensions());
-
- SmallVector<Type> resultTypes;
- if (failed(typeConverter->convertTypes(op.getResultTypes(), resultTypes)))
- return failure();
-
- SmallVector<Value> outputs;
- SmallVector<AffineMap, 3> indexingMaps;
- for (auto [operand, initValue, resultType] :
- llvm::zip(adaptor.getInputs(), adaptor.getInitValues(), resultTypes)) {
- // Check if init_value is constant. If so, inline the value into the
- // region.
- initValue = rewriter.createOrFold<tensor::ExtractOp>(loc, initValue);
-
- SmallVector<Value, 8> dynShape = getReduceOpEmptyTensorDynSizes(
- rewriter, loc, operand, resultType, reductionDims);
- auto emptyTensor = getEmptyTensor(rewriter, loc, resultType, dynShape);
- Value filledTensor =
- rewriter.create<linalg::FillOp>(loc, initValue, emptyTensor).result();
- outputs.push_back(filledTensor);
- }
-
- // Prepare indexing maps for linalg generic op. The elements are for src
- // and dst. Transpose `src` to make the reduction loops be the innermost,
- // because it's easier to fully utilize processors.
- indexingMaps.append(
- numOperands, getTransposeMapForReduction(rewriter.getContext(),
- (int)srcRank, reductionDims));
-
- // The indexing map of `dst` should drop the reduction loops. Since the
- // reduction loops now are all in the innermost, drops
- // `reduction_dims.size()` dimensions. We don't need an inverse
- // permutation here because they are the same.
- SmallVector<AffineExpr, 4> exprs;
- for (int i = 0, e = srcRank - reductionDims.size(); i < e; ++i)
- exprs.push_back(rewriter.getAffineDimExpr(i));
- indexingMaps.append(numOperands,
- AffineMap::get(srcRank, /*symbolCount=*/0, exprs,
- rewriter.getContext()));
-
- auto linalgOp = rewriter.create<linalg::GenericOp>(
- loc, /*resultTensorTypes=*/resultTypes, adaptor.getInputs(),
- /*outputBuffers=*/ValueRange{outputs}, indexingMaps,
- getParallelAndReductionIterators(srcRank, reductionDims.size()),
- /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(op));
-
- // Convert the signature of the body. The reduce op region apply function
- // has a signature (lhs, rhs) -> output, all of the same tensor type t.
- // This is converted to a function with the same signature but with
- // element types. E.g., "(tensor<f32>, tensor<f32>) -> tensor<f32>" will
- // be converted to "(f32, f32, f32)".
- Region& region = linalgOp.getRegion();
- rewriter.inlineRegionBefore(op.getBody(), region, region.end());
- TypeConverter::SignatureConversion signatureConverter(numOperands * 2);
-
- // The mhlo ReduceOp requires that the seed be used as a LHS operand inside
- // the region, and the seed is encoded in linalg in the intial out value, so
- // modify the signature of the block and the value mappings, so the output
- // args will correlate with the original LHS and the inputs correlate with
- // the original RHS.
- for (const auto& [idx, val] : llvm::enumerate(op.getInputs())) {
- signatureConverter.addInputs(
- /*origInputNo=*/idx + numOperands,
- // type for the new operand number 'idx'.
- typeConverter->convertType(
- val.getType().cast<ShapedType>().getElementType()));
- }
- for (const auto& [idx, val] : llvm::enumerate(op.getInitValues())) {
- signatureConverter.addInputs(
- /*origInputNo=*/idx,
- // type for the new operand number 'idx' + 'numOperands'.
- typeConverter->convertType(
- val.getType().cast<ShapedType>().getElementType()));
- }
-
- rewriter.applySignatureConversion(®ion, signatureConverter,
- getTypeConverter());
- rewriter.replaceOp(op, linalgOp.getResults());
- return success();
- }
-};
-
-struct ReduceOpToReduceConverter
- : public OpConversionPattern<stablehlo::ReduceOp> {
- using OpConversionPattern<stablehlo::ReduceOp>::OpConversionPattern;
-
- LogicalResult matchAndRewrite(
- stablehlo::ReduceOp op, OpAdaptor adaptor,
- ConversionPatternRewriter& rewriter) const final {
- auto reductionDims =
- llvm::to_vector(op.getDimensions().getValues<int64_t>());
- // mhlo.reduce doesn't specify the order of the reduction dimensions.
- llvm::sort(reductionDims);
-
- auto toRankedTensor = [](Value v) -> RankedTensorType {
- return v.getType().dyn_cast<RankedTensorType>();
- };
-
- SmallVector<Value> outputs;
- SmallVector<RankedTensorType> operandTypes, initTypes;
- SmallVector<Type> resultTypes;
- if (failed(typeConverter->convertTypes(op.getResultTypes(), resultTypes)))
- return failure();
-
- Location loc = op.getLoc();
- for (auto [operand, initValue, resultType] :
- llvm::zip(adaptor.getInputs(), adaptor.getInitValues(), resultTypes)) {
- auto initType = toRankedTensor(initValue);
- if (!initType)
- return rewriter.notifyMatchFailure(op,
- "expects known-rank init values");
- initTypes.push_back(initType);
- auto operandType = toRankedTensor(operand);
- if (!operandType)
- return rewriter.notifyMatchFailure(op, "expects known-rank operands");
- operandTypes.push_back(operandType);
- initValue = rewriter.createOrFold<tensor::ExtractOp>(loc, initValue);
- auto tensorResultType = resultType.cast<RankedTensorType>();
- // For linalg.reduce, the result type's dimensions must match the input's
- // dimensions, whereas MHLO allows replacing static dimensions with
- // dynamic ones.
- SmallVector<int64_t> resultShape;
- SmallVector<Value, 8> dynShape;
- for (auto [index, dim] :
- llvm::enumerate(operand.getType().cast<ShapedType>().getShape())) {
- if (!llvm::is_contained(reductionDims, index)) {
- resultShape.push_back(dim);
- if (ShapedType::isDynamic(dim)) {
- dynShape.push_back(
- rewriter.create<tensor::DimOp>(loc, operand, index));
- }
- }
- }
-
- Value emptyTensor = rewriter.create<tensor::EmptyOp>(
- loc, resultShape, tensorResultType.getElementType(), dynShape);
- Value filledTensor =
- rewriter.create<linalg::FillOp>(loc, initValue, emptyTensor).result();
- outputs.push_back(filledTensor);
- }
-
- auto linalgOp = rewriter.create<linalg::ReduceOp>(
- loc, adaptor.getInputs(), outputs, reductionDims,
- /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(op));
-
- Region& region = linalgOp.getRegion();
- rewriter.inlineRegionBefore(op.getBody(), region, region.end());
-
- // Convert the signature of the body. The reduce op 'computation' region
- // apply function has a signature with tensor types, this is converted to a
- // function with element types. E.g. the signature "(tensor<f32>,
- // tensor<f32>) -> tensor<f32>" will be converted to "(f32, f32) -> f32".
- // Also, we need to swap the operands of the function. The mhlo.reduce op
- // expects the init values to be the first parameters of the apply function,
- // while the linalg.reduction op expects the init values as the last
- // parameters of the 'combiner' region apply function.
- TypeConverter::SignatureConversion signatureConverter(
- linalgOp.getNumDpsInputs() * 2);
- assert(linalgOp.getNumDpsInputs() == linalgOp.getNumDpsInits());
- for (const auto& [idx, val] : llvm::enumerate(operandTypes)) {
- signatureConverter.addInputs(
- /*origInputNo=*/idx + linalgOp.getNumDpsInputs(),
- // type for new operand number 'idx'.
- typeConverter->convertType(val.getElementType()));
- }
- for (const auto& [idx, val] : llvm::enumerate(initTypes)) {
- signatureConverter.addInputs(
- /*origInputNo=*/idx,
- // type for new operand number 'idx' + linalgOp.getNumInputs()
- typeConverter->convertType(val.getElementType()));
- }
- rewriter.applySignatureConversion(®ion, signatureConverter,
- getTypeConverter());
-
- // Cast the result to the correct type.
- SmallVector<Value> results;
- for (auto [result, resultType] :
- llvm::zip(linalgOp.getResults(), resultTypes)) {
- results.push_back(
- rewriter.createOrFold<tensor::CastOp>(loc, resultType, result));
- }
- rewriter.replaceOp(op, results);
- return success();
- }
-};
-
/// Converts xla-hlo.select_and_scatter op to a sequence of linalg.generics ops.
/// The current version computes the scattered index and populates the correct
/// value for each tile. It does not currently handle overlapping tiles.
@@ -2727,8 +2450,7 @@
PadOpConversion,
PadOpNegativePaddingConversion,
TorchIndexSelectOpConversion,
- SelectAndScatterNoOverlapConverter,
- ReduceRegionReturnOpConversion
+ SelectAndScatterNoOverlapConverter
>(typeConverter, context);
detail::populatePointwiseStableHloToLinalgConversionPatterns(
@@ -2742,7 +2464,6 @@
IotaToMapConverter<stablehlo::IotaOp>,
IotaToMapConverter<stablehlo::DynamicIotaOp>,
MapOpToMapConverter,
- ReduceOpToReduceConverter,
TransposeOpToTransposeConverter
>(typeConverter, context);
} else {
@@ -2753,17 +2474,18 @@
HloBroadcastInDimConverter,
HloDynamicBroadcastInDimConverter,
MapOpToGenericConverter,
- ReduceOpToGenericConverter,
TransposeConverter<stablehlo::TransposeOp>
>(typeConverter, context);
}
// clang-format on
- // TODO(#12678): Handle the convolution and reduce_window ops.
+ // TODO(#12678): Handle the convolution ops.
detail::populateStableHloDotProdToLinalgConversionPatterns(
context, typeConverter, patterns);
+ detail::populateStableHloReductionToLinalgConversionPatterns(
+ context, typeConverter, patterns, enablePrimitiveOps);
detail::populateScalarHloToArithConversionPatterns(
context, typeConverter, patterns, isInBodyOfLinalgOps);
linalg::populateEraseUnusedOperandsAndResultsPatterns(*patterns);
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgReduce.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgReduce.cpp
new file mode 100644
index 0000000..5f5b54a
--- /dev/null
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgReduce.cpp
@@ -0,0 +1,698 @@
+// Copyright 2019 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+// Implements logic for lowering StableHLO reduction ops to Linalg dialect.
+// These patterns are separated out to their own file to save on the compilation
+// times.
+
+#include "iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.h"
+#include "iree/compiler/InputConversion/StableHLO/Rewriters.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "stablehlo/dialect/StablehloOps.h"
+
+namespace mlir::iree_compiler::stablehlo {
+namespace {
+namespace stablehlo = mlir::stablehlo;
+
+/// Returns a permutation AffineMap that puts all reduction dimensions to the
+/// last. The order of parallel loops and reduction loops are all sorted. E.g.,
+/// if `rank` is 4 and `reductionDims` is {1, 3}, then
+/// "(d0, d1, d2, d3) -> (d0, d2, d1, d3)" is used. The inverse permutation of
+/// the AffineMap is returned.
+AffineMap getTransposeMapForReduction(MLIRContext *context, int rank,
+ ArrayRef<int64_t> reductionDims) {
+ llvm::SmallSetVector<int, 4> s(reductionDims.begin(), reductionDims.end());
+
+ SmallVector<unsigned, 4> permutation;
+ for (int i = 0; i < rank; ++i) {
+ if (!s.contains(i)) {
+ permutation.push_back(i);
+ }
+ }
+
+ llvm::append_range(permutation, reductionDims);
+ auto map = AffineMap::getPermutationMap(permutation, context);
+ return inversePermutation(map);
+}
+
+SmallVector<Value, 8> getReduceOpEmptyTensorDynSizes(
+ OpBuilder &b, Location loc, Value arg, ShapedType resultType,
+ ArrayRef<int64_t> reductionDims) {
+ llvm::SmallSetVector<int, 4> s(reductionDims.begin(), reductionDims.end());
+
+ SmallVector<unsigned, 4> parallelDims;
+ SmallVector<Value, 8> dynShape;
+ int rank = cast<RankedTensorType>(arg.getType()).getRank();
+ for (int i = 0, j = 0; i < rank; ++i) {
+ if (s.contains(i)) continue;
+ if (!resultType.isDynamicDim(j++)) continue;
+ dynShape.push_back(b.create<tensor::DimOp>(loc, arg, i));
+ }
+
+ return dynShape;
+}
+
+struct ReduceRegionReturnOpConversion final
+ : OpConversionPattern<stablehlo::ReturnOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult matchAndRewrite(
+ stablehlo::ReturnOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (!isInBodyOfLinalgOps(op)) {
+ return failure();
+ }
+
+ SmallVector<Value, 4> operands(adaptor.getOperands());
+ for (Value &operand : operands) {
+ if (isa<ShapedType>(operand.getType())) {
+ Location loc = operand.getLoc();
+ operand = rewriter.create<tensor::ExtractOp>(loc, operand);
+ }
+ }
+ rewriter.replaceOpWithNewOp<linalg::YieldOp>(op, operands);
+ return success();
+ }
+};
+
+struct ReduceOpToGenericConverter final
+ : OpConversionPattern<stablehlo::ReduceOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult matchAndRewrite(
+ stablehlo::ReduceOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+
+ int numOperands = static_cast<int>(adaptor.getInputs().size());
+
+ if (llvm::any_of(adaptor.getInputs(), [](Value v) {
+ return !isa<RankedTensorType>(v.getType());
+ })) {
+ return rewriter.notifyMatchFailure(op, "expects known-rank args");
+ }
+ auto srcRank = cast<ShapedType>(adaptor.getInputs()[0].getType()).getRank();
+
+ SmallVector<int64_t, 4> reductionDims = extract1DVector(op.getDimensions());
+
+ SmallVector<Type> resultTypes;
+ if (failed(typeConverter->convertTypes(op.getResultTypes(), resultTypes)))
+ return failure();
+
+ SmallVector<Value> outputs;
+ SmallVector<AffineMap, 3> indexingMaps;
+ for (auto [operand, initValue, resultType] : llvm::zip_equal(
+ adaptor.getInputs(), adaptor.getInitValues(), resultTypes)) {
+ // Check if init_value is constant. If so, inline the value into the
+ // region.
+ initValue = rewriter.createOrFold<tensor::ExtractOp>(loc, initValue);
+
+ SmallVector<Value, 8> dynShape = getReduceOpEmptyTensorDynSizes(
+ rewriter, loc, operand, resultType, reductionDims);
+ auto emptyTensor = getEmptyTensor(rewriter, loc, resultType, dynShape);
+ Value filledTensor =
+ rewriter.create<linalg::FillOp>(loc, initValue, emptyTensor).result();
+ outputs.push_back(filledTensor);
+ }
+
+ // Prepare indexing maps for linalg generic op. The elements are for src
+ // and dst. Transpose `src` to make the reduction loops be the innermost,
+ // because it's easier to fully utilize processors.
+ indexingMaps.append(
+ numOperands,
+ getTransposeMapForReduction(rewriter.getContext(),
+ static_cast<int>(srcRank), reductionDims));
+
+ // The indexing map of `dst` should drop the reduction loops. Since the
+ // reduction loops now are all in the innermost, drops
+ // `reduction_dims.size()` dimensions. We don't need an inverse
+ // permutation here because they are the same.
+ SmallVector<AffineExpr, 4> exprs;
+ for (int i = 0, e = srcRank - reductionDims.size(); i < e; ++i) {
+ exprs.push_back(rewriter.getAffineDimExpr(i));
+ }
+ indexingMaps.append(numOperands,
+ AffineMap::get(srcRank, /*symbolCount=*/0, exprs,
+ rewriter.getContext()));
+
+ auto linalgOp = rewriter.create<linalg::GenericOp>(
+ loc, /*resultTensorTypes=*/resultTypes, adaptor.getInputs(),
+ /*outputBuffers=*/ValueRange{outputs}, indexingMaps,
+ getParallelAndReductionIterators(srcRank, reductionDims.size()),
+ /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(op));
+
+ // Convert the signature of the body. The reduce op region apply function
+ // has a signature (lhs, rhs) -> output, all of the same tensor type t.
+ // This is converted to a function with the same signature but with
+ // element types. E.g., "(tensor<f32>, tensor<f32>) -> tensor<f32>" will
+ // be converted to "(f32, f32, f32)".
+ Region ®ion = linalgOp.getRegion();
+ rewriter.inlineRegionBefore(op.getBody(), region, region.end());
+ TypeConverter::SignatureConversion signatureConverter(numOperands * 2);
+
+ // The mhlo ReduceOp requires that the seed be used as a LHS operand inside
+ // the region, and the seed is encoded in linalg in the initial out value, so
+ // modify the signature of the block and the value mappings, so the output
+ // args will correlate with the original LHS and the inputs correlate with
+ // the original RHS.
+ for (auto [idx, val] : llvm::enumerate(op.getInputs())) {
+ signatureConverter.addInputs(
+ /*origInputNo=*/idx + numOperands,
+ // type for the new operand number 'idx'.
+ typeConverter->convertType(
+ cast<ShapedType>(val.getType()).getElementType()));
+ }
+ for (auto [idx, val] : llvm::enumerate(op.getInitValues())) {
+ signatureConverter.addInputs(
+ /*origInputNo=*/idx,
+ // type for the new operand number 'idx' + 'numOperands'.
+ typeConverter->convertType(
+ cast<ShapedType>(val.getType()).getElementType()));
+ }
+
+ rewriter.applySignatureConversion(®ion, signatureConverter,
+ getTypeConverter());
+ rewriter.replaceOp(op, linalgOp.getResults());
+ return success();
+ }
+};
+
+struct ReduceOpToReduceConverter final
+ : OpConversionPattern<stablehlo::ReduceOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult matchAndRewrite(
+ stablehlo::ReduceOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto reductionDims =
+ llvm::to_vector(op.getDimensions().getValues<int64_t>());
+ // stablehlo.reduce doesn't specify the order of the reduction dimensions.
+ llvm::sort(reductionDims);
+
+ auto toRankedTensor = [](Value v) -> RankedTensorType {
+ return dyn_cast<RankedTensorType>(v.getType());
+ };
+
+ SmallVector<Value> outputs;
+ SmallVector<RankedTensorType> operandTypes, initTypes;
+ SmallVector<Type> resultTypes;
+ if (failed(typeConverter->convertTypes(op.getResultTypes(), resultTypes)))
+ return failure();
+
+ Location loc = op.getLoc();
+ for (auto [operand, initValue, resultType] : llvm::zip_equal(
+ adaptor.getInputs(), adaptor.getInitValues(), resultTypes)) {
+ auto initType = toRankedTensor(initValue);
+ if (!initType)
+ return rewriter.notifyMatchFailure(op,
+ "expects known-rank init values");
+ initTypes.push_back(initType);
+ auto operandType = toRankedTensor(operand);
+ if (!operandType)
+ return rewriter.notifyMatchFailure(op, "expects known-rank operands");
+ operandTypes.push_back(operandType);
+ initValue = rewriter.createOrFold<tensor::ExtractOp>(loc, initValue);
+ auto tensorResultType = cast<RankedTensorType>(resultType);
+ // For linalg.reduce, the result type's dimensions must match the input's
+ // dimensions, whereas StableHLO allows replacing static dimensions with
+ // dynamic ones.
+ SmallVector<int64_t> resultShape;
+ SmallVector<Value, 8> dynShape;
+ for (auto [index, dim] :
+ llvm::enumerate(cast<ShapedType>(operand.getType()).getShape())) {
+ if (!llvm::is_contained(reductionDims, index)) {
+ resultShape.push_back(dim);
+ if (ShapedType::isDynamic(dim)) {
+ dynShape.push_back(
+ rewriter.create<tensor::DimOp>(loc, operand, index));
+ }
+ }
+ }
+
+ Value emptyTensor = rewriter.create<tensor::EmptyOp>(
+ loc, resultShape, tensorResultType.getElementType(), dynShape);
+ Value filledTensor =
+ rewriter.create<linalg::FillOp>(loc, initValue, emptyTensor).result();
+ outputs.push_back(filledTensor);
+ }
+
+ auto linalgOp = rewriter.create<linalg::ReduceOp>(
+ loc, adaptor.getInputs(), outputs, reductionDims,
+ /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(op));
+
+ Region ®ion = linalgOp.getRegion();
+ rewriter.inlineRegionBefore(op.getBody(), region, region.end());
+
+ // Convert the signature of the body. The reduce op 'computation' region
+ // apply function has a signature with tensor types, this is converted to a
+ // function with element types. E.g. the signature "(tensor<f32>,
+ // tensor<f32>) -> tensor<f32>" will be converted to "(f32, f32) -> f32".
+ // Also, we need to swap the operands of the function. The stablehlo.reduce op
+ // expects the init values to be the first parameters of the apply function,
+ // while the linalg.reduction op expects the init values as the last
+ // parameters of the 'combiner' region apply function.
+ TypeConverter::SignatureConversion signatureConverter(
+ linalgOp.getNumDpsInputs() * 2);
+ assert(linalgOp.getNumDpsInputs() == linalgOp.getNumDpsInits());
+ for (const auto &[idx, val] : llvm::enumerate(operandTypes)) {
+ signatureConverter.addInputs(
+ /*origInputNo=*/idx + linalgOp.getNumDpsInputs(),
+ // type for new operand number 'idx'.
+ typeConverter->convertType(val.getElementType()));
+ }
+ for (const auto &[idx, val] : llvm::enumerate(initTypes)) {
+ signatureConverter.addInputs(
+ /*origInputNo=*/idx,
+ // type for new operand number 'idx' + linalgOp.getNumInputs()
+ typeConverter->convertType(val.getElementType()));
+ }
+ rewriter.applySignatureConversion(®ion, signatureConverter,
+ getTypeConverter());
+
+ // Cast the result to the correct type.
+ SmallVector<Value> results;
+ for (auto [result, resultType] :
+ llvm::zip(linalgOp.getResults(), resultTypes)) {
+ results.push_back(
+ rewriter.createOrFold<tensor::CastOp>(loc, resultType, result));
+ }
+ rewriter.replaceOp(op, results);
+ return success();
+ }
+};
+
+struct ReduceWindowOpOnTensorsGenericConversion final
+ : OpConversionPattern<stablehlo::ReduceWindowOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult matchAndRewrite(
+ stablehlo::ReduceWindowOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ MLIRContext *ctx = op->getContext();
+ Location loc = op.getLoc();
+ llvm::SmallVector<Value> initValues = adaptor.getInitValues();
+ llvm::SmallVector<Type> resultTypes;
+ if (failed(typeConverter->convertTypes(op.getResultTypes(), resultTypes)))
+ return failure();
+ auto numOperands = initValues.size();
+
+ llvm::SmallVector<int64_t> windowDimensions =
+ extract1DVector(op.getWindowDimensions());
+
+ llvm::SmallVector<int64_t> padding;
+ if (op.getPadding()) {
+ padding = extract1DVector(*op.getPadding());
+ }
+
+ llvm::SmallVector<int64_t> baseDilations;
+ if (op.getBaseDilations()) {
+ baseDilations = extract1DVector(*op.getBaseDilations());
+ }
+
+ llvm::SmallVector<int64_t> windowStrides(windowDimensions.size(), 1);
+ if (op.getWindowStrides()) {
+ windowStrides = extract1DVector(*op.getWindowStrides());
+ }
+
+ llvm::SmallVector<int64_t> windowDilations(windowDimensions.size(), 1);
+ if (op.getWindowDilations()) {
+ windowDilations = extract1DVector(*op.getWindowDilations());
+ }
+
+ auto rank = static_cast<int64_t>(windowDimensions.size());
+ SmallVector<AffineExpr, 2> srcExprs;
+ SmallVector<AffineExpr, 2> windowExprs;
+ SmallVector<AffineExpr, 2> dstExprs;
+ SmallVector<int64_t> filteredWindowDims;
+
+ int windowDim = 0;
+ for (int64_t i = 0; i < rank; i++) {
+ AffineExpr srcExpr = mlir::getAffineDimExpr(i, ctx);
+
+ if (windowStrides[i] != 1) srcExpr = srcExpr * windowStrides[i];
+
+ if (windowDimensions[i] != 1) {
+ filteredWindowDims.push_back(windowDimensions[i]);
+ AffineExpr windowExpr = mlir::getAffineDimExpr(rank + windowDim, ctx);
+ windowExprs.push_back(windowExpr);
+
+ if (windowDilations[i] != 1)
+ windowExpr = windowExpr * windowDilations[i];
+
+ srcExpr = srcExpr + windowExpr;
+ windowDim++;
+ }
+
+ srcExprs.push_back(srcExpr);
+ dstExprs.push_back(mlir::getAffineDimExpr(i, ctx));
+ }
+
+ SmallVector<AffineMap, 4> inferredMaps(3, AffineMap::get(ctx));
+ if (rank > 0) {
+ inferredMaps =
+ AffineMap::inferFromExprList({srcExprs, windowExprs, dstExprs});
+ }
+
+ SmallVector<AffineMap, 4> indexingMaps;
+
+ indexingMaps.append(numOperands, inferredMaps[0]);
+ indexingMaps.append(1, inferredMaps[1]);
+ indexingMaps.append(numOperands, inferredMaps[2]);
+
+ // Setup the initial values.
+ llvm::SmallVector<Value> broadcastValues;
+ for (uint64_t i = 0, s = initValues.size(); i < s; i++) {
+ Value initValue = initValues[i];
+ auto resultTy = resultTypes[i].cast<ShapedType>();
+ if (!resultTy.hasStaticShape()) return failure();
+
+ auto broadcastSizes = rewriter.getI64TensorAttr(resultTy.getShape());
+ broadcastValues.push_back(rewriter.create<stablehlo::BroadcastOp>(
+ loc, resultTy, initValue, broadcastSizes));
+ }
+
+ llvm::SmallVector<Value> inputs = llvm::to_vector(adaptor.getInputs());
+
+ // Pad as necessary.
+ if (llvm::any_of(padding, [](int64_t v) { return v != 0; }) ||
+ llvm::any_of(baseDilations, [](int64_t v) { return v != 1; })) {
+ llvm::SmallVector<int64_t> staticLows(rank, 0);
+ llvm::SmallVector<int64_t> staticHighs(rank, 0);
+ for (int64_t i = 0; i < static_cast<int64_t>(padding.size()); i += 2) {
+ staticLows[i / 2] = padding[i];
+ staticHighs[i / 2] = padding[i + 1];
+ }
+ // Translate base dilation into interior padding.
+ llvm::SmallVector<int64_t> staticInteriors(rank, 0);
+ for (auto [idx, dilation] : llvm::enumerate(baseDilations)) {
+ staticInteriors[idx] = dilation - 1;
+ }
+
+ auto padAttrType =
+ RankedTensorType::get({rank}, rewriter.getIntegerType(64));
+ auto padLows = DenseIntElementsAttr::get(padAttrType, staticLows);
+ auto padHighs = DenseIntElementsAttr::get(padAttrType, staticHighs);
+ auto padInteriors =
+ DenseIntElementsAttr::get(padAttrType, staticInteriors);
+
+ for (auto [input, initValue] : llvm::zip(inputs, initValues)) {
+ input = rewriter.create<stablehlo::PadOp>(
+ loc, input, initValue, padLows, padHighs, padInteriors);
+ }
+ }
+
+ // Add the extra input for the reduction dimension.
+ inputs.push_back(rewriter.create<tensor::EmptyOp>(loc, filteredWindowDims,
+ rewriter.getF32Type()));
+
+ auto linalgOp = rewriter.create<linalg::GenericOp>(
+ loc, /*resultTensors=*/resultTypes,
+ /*inputs=*/inputs,
+ /*outputs=*/broadcastValues, indexingMaps,
+ getParallelAndReductionIterators(rank + filteredWindowDims.size(),
+ filteredWindowDims.size()),
+ /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(op));
+
+ // Convert the signature of the body. This includes converting scalar
+ // tensors to their scalar values and inserting an additional block arg for
+ // the window arg.
+ Region &region = linalgOp.getRegion();
+ rewriter.cloneRegionBefore(op.getBody(), region, region.end());
+
+ TypeConverter::SignatureConversion signatureConverter(
+ inputs.size() + op->getNumResults() - 1);
+
+ // ReduceWindow requires that the seed be used as a LHS operand inside the
+ // region, and the seed is encoded in linalg in the initial out value, so
+ // modify the signature of the block and the value mappings, so the output
+ // args will correlate with the LHS and the inputs correlate with the RHS.
+ for (auto [i, type] : llvm::enumerate(resultTypes)) {
+ auto idx = inputs.size() + i - 1;
+ signatureConverter.addInputs(idx,
+ cast<ShapedType>(type).getElementType());
+ }
+
+ signatureConverter.addInputs(
+ cast<ShapedType>(inputs.back().getType()).getElementType());
+
+ for (auto [i, input] :
+ llvm::enumerate(ArrayRef<Value>(inputs).drop_back())) {
+ signatureConverter.addInputs(
+ i, cast<ShapedType>(input.getType()).getElementType());
+ }
+
+ rewriter.applySignatureConversion(&region, signatureConverter,
+ getTypeConverter());
+ rewriter.replaceOp(op, linalgOp.getResults());
+ return success();
+ }
+};
+
+struct ReduceWindowOpConversion final
+ : OpConversionPattern<stablehlo::ReduceWindowOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ /// Get the operation used for reduction applied to `result_index`th result.
+ /// It's expected to be a binary operation that consumes the `result_index`th
+ /// and the `result_index + getInputs().size()`th arguments of the body.
+ static Operation *getReductionOp(stablehlo::ReduceWindowOp op,
+ int resultIndex) {
+ auto returnOp =
+ cast<stablehlo::ReturnOp>(op.getBody().front().getTerminator());
+ Operation *computeOp = returnOp.getResults()[resultIndex].getDefiningOp();
+ if (computeOp->getNumOperands() != 2) return nullptr;
+ auto arg0 = computeOp->getOperand(0).dyn_cast<BlockArgument>();
+ auto arg1 = computeOp->getOperand(1).dyn_cast<BlockArgument>();
+ if (!arg0 || !arg1) return nullptr;
+ int64_t arg0Num = arg0.getArgNumber();
+ int64_t arg1Num = arg1.getArgNumber();
+ int64_t otherArgIndex = resultIndex + op.getInputs().size();
+ if (arg0Num == resultIndex && arg1Num == otherArgIndex) return computeOp;
+ if (arg0Num == otherArgIndex && arg1Num == resultIndex &&
+ computeOp->hasTrait<mlir::OpTrait::IsCommutative>())
+ return computeOp;
+ return nullptr;
+ }
+
+ /// stablehlo.reduce_window is mapped to a linalg.pooling operation. The type
+ /// of the pooling is determined based on the body of the reduce window
+ /// operation. This class enumerates the different variants.
+ enum class PoolingType {
+ kInvalid,
+ k2DMin,
+ k3DMin,
+ k2DMax,
+ k3DMax,
+ k2DAdd,
+ k3DAdd,
+ };
+
+ static PoolingType getPoolingType(stablehlo::ReduceWindowOp reduceOp,
+ int resultIndex) {
+ auto rank =
+ reduceOp.getResultTypes()[resultIndex].cast<ShapedType>().getRank();
+ if (Operation *op = getReductionOp(reduceOp, resultIndex)) {
+ if (isa<stablehlo::MinOp>(*op) && rank == 4) return PoolingType::k2DMin;
+ if (isa<stablehlo::MinOp>(*op) && rank == 5) return PoolingType::k3DMin;
+ if (isa<stablehlo::MaxOp>(*op) && rank == 4) return PoolingType::k2DMax;
+ if (isa<stablehlo::MaxOp>(*op) && rank == 5) return PoolingType::k3DMax;
+ if (isa<stablehlo::AddOp>(*op) && rank == 4) return PoolingType::k2DAdd;
+ if (isa<stablehlo::AddOp>(*op) && rank == 5) return PoolingType::k3DAdd;
+ }
+ return PoolingType::kInvalid;
+ }
+
+ LogicalResult matchAndRewrite(
+ stablehlo::ReduceWindowOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ int rank = op.getResultTypes()[0].cast<ShapedType>().getRank();
+ if (rank != 4 && rank != 5) {
+ return rewriter.notifyMatchFailure(
+ op, "expected NHWC/NDHWC pooling-based op");
+ }
+
+ if (op.getPadding() && !isSplatValue(*op.getPadding(), 0)) {
+ return rewriter.notifyMatchFailure(op, "require paddings are all zero");
+ }
+
+ if (op.getBaseDilations() && !isSplatValue(*op.getBaseDilations(), 1)) {
+ return rewriter.notifyMatchFailure(op, "expected undilated base");
+ }
+
+ int lastDim = rank - 1;
+ SmallVector<int64_t, 2> fakeWindowShapes;
+ for (int i = 1; i < lastDim; ++i) {
+ fakeWindowShapes.push_back(
+ op.getWindowDimensions().getValues<int64_t>()[i]);
+ }
+
+ if (op.getWindowStrides() &&
+ (op.getWindowStrides().value().getValues<int64_t>()[0] != 1 ||
+ op.getWindowStrides().value().getValues<int64_t>()[lastDim] != 1)) {
+ return rewriter.notifyMatchFailure(
+ op, "expected window_strides to be [1,x,y,(z),1]");
+ }
+ if (op.getWindowDimensions() &&
+ (op.getWindowDimensions().getValues<int64_t>()[0] != 1 ||
+ op.getWindowDimensions().getValues<int64_t>()[lastDim] != 1)) {
+ return rewriter.notifyMatchFailure(
+ op, "expected window_dimensions to be [1,x,y,(z),1]");
+ }
+
+ Attribute strides;
+ SmallVector<int64_t> vec;
+ if (op.getWindowStridesAttr()) {
+ for (int i = 1; i < lastDim; ++i) {
+ vec.push_back(op.getWindowStrides().value().getValues<int64_t>()[i]);
+ }
+ } else {
+ vec.assign(rank - 2, 1);
+ }
+ strides = rewriter.getI64VectorAttr(vec);
+
+ Attribute dilations;
+ vec.clear();
+ if (op.getWindowDilations()) {
+ for (int i = 1; i < lastDim; ++i) {
+ vec.push_back(op.getWindowDilations().value().getValues<int64_t>()[i]);
+ }
+ } else {
+ vec.assign(rank - 2, 1);
+ }
+ dilations = rewriter.getI64VectorAttr(vec);
+
+ SmallVector<Value> poolingOps;
+
+ ValueRange operands = adaptor.getInputs();
+ ValueRange initValues = adaptor.getInitValues();
+ for (auto it : llvm::zip(op.getResults(), operands, initValues)) {
+ OpResult result = std::get<0>(it);
+ Value input = std::get<1>(it);
+ Value initValue = std::get<2>(it);
+ auto resultType = cast<ShapedType>(result.getType());
+ if (!cast<ShapedType>(input.getType()).getElementType().isF32()) {
+ return rewriter.notifyMatchFailure(op,
+ "expected element type to be f32");
+ }
+
+ // Create a fake window dimension.
+ auto fakeWindowDims = rewriter.create<tensor::EmptyOp>(
+ loc, fakeWindowShapes, resultType.getElementType());
+
+ SmallVector<Value> resultDynamicDims;
+ for (const auto &en : llvm::enumerate(resultType.getShape())) {
+ if (en.value() != ShapedType::kDynamic) continue;
+ Value dimSize = rewriter.create<tensor::DimOp>(loc, input, en.index());
+ if (en.index() == 0 || static_cast<int64_t>(en.index()) == rank - 1) {
+ // batch dims and channel dims can be derived from input dims
+ // directly.
+ resultDynamicDims.push_back(dimSize);
+ } else {
+ auto i = en.index() - 1;
+ auto stride =
+ strides.cast<DenseIntElementsAttr>().getValues<int64_t>()[i];
+ auto dilation =
+ dilations.cast<DenseIntElementsAttr>().getValues<int64_t>()[i];
+ // let j = i * stride
+ // output[i] = reduce( input[j, j + window_size * dilation) )
+ Value offset = rewriter.create<arith::ConstantIndexOp>(
+ loc, fakeWindowShapes[i] * dilation);
+ dimSize = rewriter.create<arith::SubIOp>(loc, dimSize, offset);
+ dimSize = rewriter.create<arith::DivUIOp>(
+ loc, dimSize,
+ rewriter.create<arith::ConstantIndexOp>(loc, stride));
+ dimSize = rewriter.create<arith::AddIOp>(
+ loc, dimSize, rewriter.create<arith::ConstantIndexOp>(loc, 1));
+ resultDynamicDims.push_back(dimSize);
+ }
+ }
+ Value emptyTensor = rewriter.create<tensor::EmptyOp>(
+ loc, resultType.getShape(), resultType.getElementType(),
+ resultDynamicDims);
+
+ initValue = rewriter.create<tensor::ExtractOp>(loc, initValue);
+ Value filledInitTensor =
+ rewriter.create<linalg::FillOp>(loc, initValue, emptyTensor)
+ .getResult(0);
+ auto createOp = [&](auto *typePtr) -> linalg::LinalgOp {
+ return cast<linalg::LinalgOp>(
+ rewriter
+ .create<std::remove_pointer_t<decltype(typePtr)>>(
+ loc, ArrayRef<Type>{resultType},
+ ValueRange{input, fakeWindowDims.getResult()},
+ filledInitTensor, strides, dilations,
+ linalg::getPrunedAttributeList(op))
+ .getOperation());
+ };
+ linalg::LinalgOp poolingOp;
+ PoolingType poolingType = getPoolingType(op, result.getResultNumber());
+ switch (poolingType) {
+ case PoolingType::k2DMin: {
+ poolingOp =
+ createOp(static_cast<linalg::PoolingNhwcMinOp *>(nullptr));
+ break;
+ }
+ case PoolingType::k3DMin: {
+ poolingOp =
+ createOp(static_cast<linalg::PoolingNdhwcMinOp *>(nullptr));
+ break;
+ }
+ case PoolingType::k2DMax: {
+ poolingOp =
+ createOp(static_cast<linalg::PoolingNhwcMaxOp *>(nullptr));
+ break;
+ }
+ case PoolingType::k3DMax: {
+ poolingOp =
+ createOp(static_cast<linalg::PoolingNdhwcMaxOp *>(nullptr));
+ break;
+ }
+ case PoolingType::k2DAdd: {
+ poolingOp =
+ createOp(static_cast<linalg::PoolingNhwcSumOp *>(nullptr));
+ break;
+ }
+ case PoolingType::k3DAdd: {
+ poolingOp =
+ createOp(static_cast<linalg::PoolingNdhwcSumOp *>(nullptr));
+ break;
+ }
+ case PoolingType::kInvalid:
+ return rewriter.notifyMatchFailure(op, "unknown reduction operation");
+ }
+ poolingOps.push_back(poolingOp->getResult(0));
+ }
+ rewriter.replaceOp(op, poolingOps);
+ return success();
+ }
+};
+
+} // namespace
+
+namespace detail {
+void populateStableHloReductionToLinalgConversionPatterns(
+ MLIRContext *context, TypeConverter &typeConverter,
+ RewritePatternSet *patterns, bool enablePrimitiveOps) {
+ if (enablePrimitiveOps) {
+ patterns->add<ReduceOpToReduceConverter>(typeConverter, context);
+ } else {
+ patterns->add<ReduceOpToGenericConverter>(typeConverter, context);
+ }
+ patterns->add<ReduceRegionReturnOpConversion,
+ ReduceWindowOpOnTensorsGenericConversion>(typeConverter,
+ context);
+
+ // Ensure specialized patterns are higher priority than their generic
+ // versions.
+ patterns->add<ReduceWindowOpConversion>(typeConverter, context,
+ PatternBenefit(2));
+}
+} // namespace detail
+} // namespace mlir::iree_compiler::stablehlo
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/test/stablehlo_to_linalg_reduce.mlir b/compiler/src/iree/compiler/InputConversion/StableHLO/test/stablehlo_to_linalg_reduce.mlir
index 156011c..28bd8bb 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/test/stablehlo_to_linalg_reduce.mlir
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/test/stablehlo_to_linalg_reduce.mlir
@@ -355,3 +355,419 @@
// CHECK-PRIMITIVE-NEXT: %[[RES0:.*]] = arith.select %[[B0]], %[[RHS0]], %[[LHS0]] : f32
// CHECK-PRIMITIVE-NEXT: %[[RES1:.*]] = arith.select %[[B0]], %[[RHS1]], %[[LHS1]] : i32
// CHECK-PRIMITIVE-NEXT: linalg.yield %[[RES0]], %[[RES1]] : f32, i32
+
+// -----
+
+// CHECK-LABEL: func @reduce_window_min_nhwc
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
+func.func @reduce_window_min_nhwc(%arg0: tensor<1x17x17x64xf32>,
+ %arg1: tensor<f32>) -> tensor<1x8x8x64xf32>{
+ %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({
+ ^bb0(%arg2: tensor<f32>, %arg3 : tensor<f32>):
+ %1 = stablehlo.minimum %arg2, %arg3 : tensor<f32>
+ "stablehlo.return"(%1) : (tensor<f32>) -> ()
+ }) {window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
+ window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>,
+ someattr} : (tensor<1x17x17x64xf32>, tensor<f32>) -> tensor<1x8x8x64xf32>
+ func.return %0 : tensor<1x8x8x64xf32>
+}
+// CHECK: %[[WINDOW:.+]] = tensor.empty() : tensor<3x3xf32>
+// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x8x8x64xf32>
+// CHECK: %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[INIT_VAL]] : f32) outs(%[[INIT]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
+// CHECK: %[[RES:.+]] = linalg.pooling_nhwc_min
+// CHECK-SAME: {dilations = dense<1> : vector<2xi64>
+// CHECK-SAME: someattr,
+// CHECK-SAME: strides = dense<2> : vector<2xi64>}
+// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : tensor<1x17x17x64xf32>, tensor<3x3xf32>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
+
+// -----
+
+// CHECK-LABEL: func @reduce_window_max_nhwc
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
+func.func @reduce_window_max_nhwc(%arg0: tensor<1x17x17x64xf32>,
+ %arg1: tensor<f32>) -> tensor<1x8x8x64xf32>{
+ %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({
+ ^bb0(%arg2: tensor<f32>, %arg3 : tensor<f32>):
+ %1 = stablehlo.maximum %arg2, %arg3 : tensor<f32>
+ "stablehlo.return"(%1) : (tensor<f32>) -> ()
+ }) {window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
+ window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<1x17x17x64xf32>, tensor<f32>) -> tensor<1x8x8x64xf32>
+ func.return %0 : tensor<1x8x8x64xf32>
+}
+// CHECK: %[[WINDOW:.+]] = tensor.empty() : tensor<3x3xf32>
+// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x8x8x64xf32>
+// CHECK: %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[INIT_VAL]] : f32) outs(%[[INIT]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
+// CHECK: %[[RES:.+]] = linalg.pooling_nhwc_max
+// CHECK-SAME: {dilations = dense<1> : vector<2xi64>
+// CHECK-SAME: strides = dense<2> : vector<2xi64>}
+// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : tensor<1x17x17x64xf32>, tensor<3x3xf32>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
+
+// -----
+
+// CHECK-LABEL: func @reduce_window_sum_nhwc
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
+func.func @reduce_window_sum_nhwc(%arg0: tensor<1x17x17x64xf32>,
+ %arg1: tensor<f32>) -> tensor<1x8x8x64xf32>{
+ %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({
+ ^bb0(%arg2: tensor<f32>, %arg3 : tensor<f32>):
+ %1 = stablehlo.add %arg2, %arg3 : tensor<f32>
+ "stablehlo.return"(%1) : (tensor<f32>) -> ()
+ }) {window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
+ window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<1x17x17x64xf32>, tensor<f32>) -> tensor<1x8x8x64xf32>
+ func.return %0 : tensor<1x8x8x64xf32>
+}
+// CHECK: %[[WINDOW:.+]] = tensor.empty() : tensor<3x3xf32>
+// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x8x8x64xf32>
+// CHECK: %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[INIT_VAL]] : f32) outs(%[[INIT]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
+// CHECK: %[[RES:.+]] = linalg.pooling_nhwc_sum
+// CHECK-SAME: {dilations = dense<1> : vector<2xi64>
+// CHECK-SAME: strides = dense<2> : vector<2xi64>}
+// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : tensor<1x17x17x64xf32>, tensor<3x3xf32>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
+
+// -----
+
+// CHECK-LABEL: func @reduce_window_max_nhwc_with_cst
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
+func.func @reduce_window_max_nhwc_with_cst(%arg0: tensor<1x17x17x64xf32>) -> tensor<1x8x8x64xf32> {
+ %0 = arith.constant dense<0xFF800000> : tensor<f32>
+ %1 = "stablehlo.reduce_window"(%arg0, %0) ({
+ ^bb0(%arg1: tensor<f32>, %arg2 : tensor<f32>):
+ %2 = stablehlo.maximum %arg1, %arg2 : tensor<f32>
+ "stablehlo.return"(%2) : (tensor<f32>) -> ()
+ }) {window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
+ window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<1x17x17x64xf32>, tensor<f32>) -> tensor<1x8x8x64xf32>
+ func.return %1 : tensor<1x8x8x64xf32>
+}
+
+// CHECK-DAG: %[[CST:.+]] = arith.constant 0xFF800000
+// CHECK: %[[WINDOW:.+]] = tensor.empty() : tensor<3x3xf32>
+// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x8x8x64xf32
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
+// CHECK: %[[RES:.+]] = linalg.pooling_nhwc_max
+// CHECK-SAME: {dilations = dense<1> : vector<2xi64>
+// CHECK-SAME: strides = dense<2> : vector<2xi64>}
+// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : tensor<1x17x17x64xf32>, tensor<3x3xf32>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
+
+// -----
+
+// CHECK-LABEL: func @reduce_window_sum_max_nhwc
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
+func.func @reduce_window_sum_max_nhwc(%arg0: tensor<1x17x17x64xf32>,
+ %arg1: tensor<f32>) -> (tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>) {
+ %0:2 = "stablehlo.reduce_window"(%arg0, %arg0, %arg1, %arg1) ({
+ ^bb0(%arg2: tensor<f32>, %arg3 : tensor<f32>, %arg4: tensor<f32>, %arg5 : tensor<f32>):
+ %1 = stablehlo.add %arg2, %arg4 : tensor<f32>
+ %2 = stablehlo.maximum %arg3, %arg5 : tensor<f32>
+ "stablehlo.return"(%1, %2) : (tensor<f32>, tensor<f32>) -> ()
+ }) {window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
+ window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<1x17x17x64xf32>, tensor<1x17x17x64xf32>, tensor<f32>, tensor<f32>) -> (tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>)
+ func.return %0#0, %0#1 : tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>
+}
+
+// CHECK: %[[WINDOW0:.+]] = tensor.empty() : tensor<3x3xf32>
+// CHECK: %[[INIT0:.+]] = tensor.empty() : tensor<1x8x8x64xf32>
+// CHECK: %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
+// CHECK: %[[FILL0:.+]] = linalg.fill ins(%[[INIT_VAL]] : f32) outs(%[[INIT0]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
+// CHECK: %[[RES0:.+]] = linalg.pooling_nhwc_sum
+// CHECK-SAME: {dilations = dense<1> : vector<2xi64>
+// CHECK-SAME: strides = dense<2> : vector<2xi64>}
+// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW0]] : tensor<1x17x17x64xf32>, tensor<3x3xf32>)
+// CHECK-SAME: outs(%[[FILL0]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
+// CHECK: %[[WINDOW1:.+]] = tensor.empty() : tensor<3x3xf32>
+// CHECK: %[[INIT1:.+]] = tensor.empty() : tensor<1x8x8x64xf32>
+// CHECK: %[[INIT_VAL1:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
+// CHECK: %[[FILL1:.+]] = linalg.fill ins(%[[INIT_VAL1]] : f32) outs(%[[INIT1]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
+// CHECK: %[[RES1:.+]] = linalg.pooling_nhwc_max
+// CHECK-SAME: {dilations = dense<1> : vector<2xi64>
+// CHECK-SAME: strides = dense<2> : vector<2xi64>}
+// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW1]] : tensor<1x17x17x64xf32>, tensor<3x3xf32>)
+// CHECK-SAME: outs(%[[FILL1]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
+// CHECK: return %[[RES0]], %[[RES1]]
+
+// -----
+
+// Just check that this lowers successfully.
+// CHECK-LABEL: func @reduce_window_unsigned
+func.func @reduce_window_unsigned(%arg0: tensor<1x1xui32>) -> tensor<1x1xui32> {
+ %0 = stablehlo.constant dense<0> : tensor<ui32>
+ %1 = "stablehlo.reduce_window"(%arg0, %0) ({
+ ^bb0(%arg1: tensor<ui32>, %arg2: tensor<ui32>):
+ stablehlo.return %arg1 : tensor<ui32>
+ }) {
+ window_dimensions = dense<[1, 1]> : tensor<2xi64>,
+ window_strides = dense<[1, 1]> : tensor<2xi64>
+ } : (tensor<1x1xui32>, tensor<ui32>) -> tensor<1x1xui32>
+ return %1 : tensor<1x1xui32>
+}
+
+// -----
+
+// CHECK-LABEL: func @dynamic_reduce_window_sum_nhwc
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
+func.func @dynamic_reduce_window_sum_nhwc(%arg0: tensor<?x?x?x?xf32>,
+ %arg1: tensor<f32>) -> tensor<?x?x?x?xf32>{
+ %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({
+ ^bb0(%arg2: tensor<f32>, %arg3 : tensor<f32>):
+ %1 = stablehlo.add %arg2, %arg3 : tensor<f32>
+ "stablehlo.return"(%1) : (tensor<f32>) -> ()
+ }) {window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
+ window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?x?xf32>
+ func.return %0 : tensor<?x?x?x?xf32>
+}
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
+// CHECK: %[[WINDOW:.+]] = tensor.empty() : tensor<3x3xf32>
+// CHECK: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?x?xf32>
+// CHECK: %[[T1:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?xf32>
+// CHECK: %[[T2:.+]] = arith.subi %[[T1]], %[[C3]]
+// CHECK: %[[T3:.+]] = arith.divui %[[T2]], %[[C2]]
+// CHECK: %[[D1:.+]] = arith.addi %[[T3]], %[[C1]]
+// CHECK: %[[T1:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?xf32>
+// CHECK: %[[T2:.+]] = arith.subi %[[T1]], %[[C3]]
+// CHECK: %[[T3:.+]] = arith.divui %[[T2]], %[[C2]]
+// CHECK: %[[D2:.+]] = arith.addi %[[T3]], %[[C1]]
+// CHECK: %[[D3:.+]] = tensor.dim %[[ARG0]], %[[C3]] : tensor<?x?x?x?xf32>
+// CHECK: %[[INIT:.+]] = tensor.empty(%[[D0]], %[[D1]], %[[D2]], %[[D3]]) : tensor<?x?x?x?xf32>
+// CHECK: %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[INIT_VAL]] : f32) outs(%[[INIT]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+// CHECK: %[[RES:.+]] = linalg.pooling_nhwc_sum
+// CHECK-SAME: {dilations = dense<1> : vector<2xi64>
+// CHECK-SAME: strides = dense<2> : vector<2xi64>}
+// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : tensor<?x?x?x?xf32>, tensor<3x3xf32>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+
+// -----
+
+// CHECK-LABEL: func @reduce_window_min_ndhwc
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
+func.func @reduce_window_min_ndhwc(%arg0: tensor<1x17x17x17x64xf32>,
+ %arg1: tensor<f32>) -> tensor<1x8x8x8x64xf32>{
+ %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({
+ ^bb0(%arg2: tensor<f32>, %arg3 : tensor<f32>):
+ %1 = stablehlo.minimum %arg2, %arg3 : tensor<f32>
+ "stablehlo.return"(%1) : (tensor<f32>) -> ()
+ }) {window_dimensions = dense<[1, 3, 3, 3, 1]> : tensor<5xi64>,
+ window_strides = dense<[1, 2, 2, 2, 1]> : tensor<5xi64>} : (tensor<1x17x17x17x64xf32>, tensor<f32>) -> tensor<1x8x8x8x64xf32>
+ func.return %0 : tensor<1x8x8x8x64xf32>
+}
+// CHECK: %[[WINDOW:.+]] = tensor.empty() : tensor<3x3x3xf32>
+// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x8x8x8x64xf32>
+// CHECK: %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[INIT_VAL]] : f32) outs(%[[INIT]] : tensor<1x8x8x8x64xf32>) -> tensor<1x8x8x8x64xf32>
+// CHECK: %[[RES:.+]] = linalg.pooling_ndhwc_min
+// CHECK-SAME: {dilations = dense<1> : vector<3xi64>
+// CHECK-SAME: strides = dense<2> : vector<3xi64>}
+// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : tensor<1x17x17x17x64xf32>, tensor<3x3x3xf32>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<1x8x8x8x64xf32>) -> tensor<1x8x8x8x64xf32>
+
+// -----
+
+// CHECK-LABEL: func @reduce_window_max_ndhwc
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
+func.func @reduce_window_max_ndhwc(%arg0: tensor<1x17x17x17x64xf32>,
+ %arg1: tensor<f32>) -> tensor<1x8x8x8x64xf32>{
+ %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({
+ ^bb0(%arg2: tensor<f32>, %arg3 : tensor<f32>):
+ %1 = stablehlo.maximum %arg2, %arg3 : tensor<f32>
+ "stablehlo.return"(%1) : (tensor<f32>) -> ()
+ }) {window_dimensions = dense<[1, 3, 3, 3, 1]> : tensor<5xi64>,
+ window_strides = dense<[1, 2, 2, 2, 1]> : tensor<5xi64>} : (tensor<1x17x17x17x64xf32>, tensor<f32>) -> tensor<1x8x8x8x64xf32>
+ func.return %0 : tensor<1x8x8x8x64xf32>
+}
+// CHECK: %[[WINDOW:.+]] = tensor.empty() : tensor<3x3x3xf32>
+// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x8x8x8x64xf32>
+// CHECK: %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[INIT_VAL]] : f32) outs(%[[INIT]] : tensor<1x8x8x8x64xf32>) -> tensor<1x8x8x8x64xf32>
+// CHECK: %[[RES:.+]] = linalg.pooling_ndhwc_max
+// CHECK-SAME: {dilations = dense<1> : vector<3xi64>
+// CHECK-SAME: strides = dense<2> : vector<3xi64>}
+// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : tensor<1x17x17x17x64xf32>, tensor<3x3x3xf32>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<1x8x8x8x64xf32>) -> tensor<1x8x8x8x64xf32>
+
+// -----
+
+// CHECK-LABEL: func @reduce_window_sum_ndhwc
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
+func.func @reduce_window_sum_ndhwc(%arg0: tensor<1x17x17x17x64xf32>,
+ %arg1: tensor<f32>) -> tensor<1x8x8x8x64xf32>{
+ %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({
+ ^bb0(%arg2: tensor<f32>, %arg3 : tensor<f32>):
+ %1 = stablehlo.add %arg2, %arg3 : tensor<f32>
+ "stablehlo.return"(%1) : (tensor<f32>) -> ()
+ }) {window_dimensions = dense<[1, 3, 3, 3, 1]> : tensor<5xi64>,
+ window_strides = dense<[1, 2, 2, 2, 1]> : tensor<5xi64>} : (tensor<1x17x17x17x64xf32>, tensor<f32>) -> tensor<1x8x8x8x64xf32>
+ func.return %0 : tensor<1x8x8x8x64xf32>
+}
+// CHECK: %[[WINDOW:.+]] = tensor.empty() : tensor<3x3x3xf32>
+// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x8x8x8x64xf32>
+// CHECK: %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[INIT_VAL]] : f32) outs(%[[INIT]] : tensor<1x8x8x8x64xf32>) -> tensor<1x8x8x8x64xf32>
+// CHECK: %[[RES:.+]] = linalg.pooling_ndhwc_sum
+// CHECK-SAME: {dilations = dense<1> : vector<3xi64>
+// CHECK-SAME: strides = dense<2> : vector<3xi64>}
+// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : tensor<1x17x17x17x64xf32>, tensor<3x3x3xf32>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<1x8x8x8x64xf32>) -> tensor<1x8x8x8x64xf32>
+
+// -----
+
+// CHECK-LABEL: func @reduce_window_sum_ndhwc_dilated_base
+// CHECK: linalg.generic
+func.func @reduce_window_sum_ndhwc_dilated_base(
+ %arg0: tensor<1x17x17x17x64xf32>,
+ %arg1: tensor<f32>) -> tensor<1x8x8x16x64xf32>{
+ %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({
+ ^bb0(%arg2: tensor<f32>, %arg3 : tensor<f32>):
+ %1 = stablehlo.add %arg2, %arg3 : tensor<f32>
+ "stablehlo.return"(%1) : (tensor<f32>) -> ()
+ }) {base_dilations = dense<[1, 1, 1, 2, 1]> : tensor<5xi64>,
+ window_dimensions = dense<[1, 3, 3, 3, 1]> : tensor<5xi64>,
+ window_strides = dense<[1, 2, 2, 2, 1]> : tensor<5xi64>} : (tensor<1x17x17x17x64xf32>, tensor<f32>) -> tensor<1x8x8x16x64xf32>
+ func.return %0 : tensor<1x8x8x16x64xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0 * 2, d1 + d2 * 2)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK: func @reduce_window_generic
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
+func.func @reduce_window_generic(%arg0: tensor<4x6xf32>, %arg1: tensor<f32>) -> tensor<4x7xf32> {
+ %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({
+ ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+ %1 = stablehlo.add %arg2, %arg3 : tensor<f32>
+ "stablehlo.return"(%1) : (tensor<f32>) -> ()
+ }) {base_dilations = dense<1> : tensor<2xi64>, padding = dense<[[0, 3], [1, 2]]> : tensor<2x2xi64>, window_dilations = dense<[1, 2]> : tensor<2xi64>, window_dimensions = dense<[1, 2]> : tensor<2xi64>, window_strides = dense<[2, 1]> : tensor<2xi64>} : (tensor<4x6xf32>, tensor<f32>) -> tensor<4x7xf32>
+ func.return %0 : tensor<4x7xf32>
+}
+// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<4x7xf32>
+// CHECK: %[[FILL:.+]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg1 : tensor<f32>) outs(%[[INIT]] : tensor<4x7xf32>)
+// CHECK: ^{{[a-z0-9_]*}}
+// CHECK-SAME: %[[IN:[a-zA-Z0-9_]*]]: f32
+// CHECK-SAME: %[[OUT:[a-zA-Z0-9_]*]]: f32
+// CHECK: linalg.yield %[[IN]] : f32
+
+// CHECK: %[[PADVAL:.+]] = tensor.extract %arg1[] : tensor<f32>
+// CHECK: %[[PAD:.+]] = tensor.pad %arg0 low[0, 1] high[3, 2]
+// CHECK: ^{{[a-z0-9_]*}}
+// CHECK-SAME: %{{[a-zA-Z0-9_]*}}: index
+// CHECK-SAME: %{{[a-zA-Z0-9_]*}}: index
+// CHECK: tensor.yield %[[PADVAL]] : f32
+
+// CHECK: %[[WINDOW:.+]] = tensor.empty() : tensor<2xf32>
+// CHECK: %[[REDUCE:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP3]], #[[MAP4]]], iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[PAD]], %[[WINDOW]] : tensor<7x9xf32>, tensor<2xf32>) outs(%[[FILL]] : tensor<4x7xf32>) {
+// CHECK: ^{{[a-z0-9_]*}}
+// CHECK-SAME: %[[IN:[a-zA-Z0-9_]*]]: f32
+// CHECK-SAME: %[[IN2:[a-zA-Z0-9_]*]]: f32
+// CHECK-SAME: %[[OUT:[a-zA-Z0-9_]*]]: f32
+// CHECK: %[[ADD:.+]] = arith.addf %[[OUT]], %[[IN]] : f32
+// CHECK: linalg.yield %[[ADD]]
+
+// CHECK: return %[[REDUCE]]
+// -----
+
+// CHECK-LABEL: func @reduce_window_generic_captured_constant
+func.func @reduce_window_generic_captured_constant(%arg0: tensor<4x6xf32>, %arg1: tensor<f32>) -> tensor<4x7xf32> {
+ %c2 = stablehlo.constant dense<2.0> : tensor<f32>
+ %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({
+ ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+ %1 = stablehlo.add %arg2, %arg3 : tensor<f32>
+ %2 = stablehlo.multiply %1, %c2 : tensor<f32>
+ "stablehlo.return"(%2) : (tensor<f32>) -> ()
+ }) {base_dilations = dense<1> : tensor<2xi64>, padding = dense<[[0, 3], [1, 2]]> : tensor<2x2xi64>, window_dilations = dense<[1, 2]> : tensor<2xi64>, window_dimensions = dense<[1, 2]> : tensor<2xi64>, window_strides = dense<[2, 1]> : tensor<2xi64>} : (tensor<4x6xf32>, tensor<f32>) -> tensor<4x7xf32>
+ func.return %0 : tensor<4x7xf32>
+}
+
+// CHECK: %[[C2:.*]] = arith.constant 2.0
+// CHECK: linalg.generic
+// CHECK: %[[SUM:.*]] = arith.addf
+// CHECK: %[[PROD:.*]] = arith.mulf %[[SUM]], %[[C2]]
+// CHECK: linalg.yield %[[PROD]]
+
+// -----
+
+// CHECK-LABEL: func @reduce_window_generic_padding
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
+func.func @reduce_window_generic_padding(%arg0: tensor<3x6xf32>, %arg1: tensor<f32>) -> tensor<3x7xf32> {
+ %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({
+ ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+ %1 = stablehlo.add %arg2, %arg3 : tensor<f32>
+ "stablehlo.return"(%1) : (tensor<f32>) -> ()
+ }) {padding = dense<[[0, 3], [1, 2]]> : tensor<2x2xi64>, window_dilations = dense<[1, 2]> : tensor<2xi64>, window_dimensions = dense<[1, 2]> : tensor<2xi64>, window_strides = dense<[2, 1]> : tensor<2xi64>} : (tensor<3x6xf32>, tensor<f32>) -> tensor<3x7xf32>
+ func.return %0 : tensor<3x7xf32>
+}
+// CHECK: %[[PADVAL:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
+// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]] low[0, 1] high[3, 2]
+// CHECK: tensor.yield %[[PADVAL]] : f32
+
+// -----
+
+// Tests lowering with `base_dilations` = [2, 1] and no padding: base dilation
+// scatters the input into a larger tensor. The destination is created with
+// tensor.empty, filled with the init value via linalg.fill, and the original
+// elements are written by tensor.insert_slice with the dilation as stride.
+// CHECK-LABEL: func @reduce_window_generic_base_dilation
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
+func.func @reduce_window_generic_base_dilation(%arg0: tensor<3x6xf32>, %arg1: tensor<f32>) -> tensor<3x4xf32> {
+ %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({
+ ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+ %1 = stablehlo.add %arg2, %arg3 : tensor<f32>
+ "stablehlo.return"(%1) : (tensor<f32>) -> ()
+ }) {base_dilations = dense<[2, 1]> : tensor<2xi64>, window_dilations = dense<[1, 2]> : tensor<2xi64>, window_dimensions = dense<[1, 2]> : tensor<2xi64>, window_strides = dense<[2, 1]> : tensor<2xi64>} : (tensor<3x6xf32>, tensor<f32>) -> tensor<3x4xf32>
+ func.return %0 : tensor<3x4xf32>
+}
+// Dilated extents: dim 0 has 3 elements at stride 2 -> 2*(3-1)+1 = 5;
+// dim 1 is undilated (stride 1) -> 6. Hence the 5x6 destination below.
+// CHECK: %[[PADVAL:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
+// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<5x6xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[PADVAL]] : f32) outs(%[[INIT]] : tensor<5x6xf32>) -> tensor<5x6xf32>
+// CHECK: %[[PAD:.+]] = tensor.insert_slice %[[ARG0]] into %[[FILL]][0, 0] [3, 6] [2, 1] : tensor<3x6xf32> into tensor<5x6xf32>
+
+// -----
+
+// Tests `padding` and `base_dilations` combined: a single fill + insert_slice
+// materializes both effects in one destination tensor.
+// CHECK-LABEL: func @reduce_window_generic_padding_base_dilation
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
+func.func @reduce_window_generic_padding_base_dilation(%arg0: tensor<3x6xf32>, %arg1: tensor<f32>) -> tensor<4x7xf32> {
+ %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({
+ ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+ %1 = stablehlo.add %arg2, %arg3 : tensor<f32>
+ "stablehlo.return"(%1) : (tensor<f32>) -> ()
+ }) {base_dilations = dense<[2, 1]> : tensor<2xi64>, padding = dense<[[0, 3], [1, 2]]> : tensor<2x2xi64>, window_dilations = dense<[1, 2]> : tensor<2xi64>, window_dimensions = dense<[1, 2]> : tensor<2xi64>, window_strides = dense<[2, 1]> : tensor<2xi64>} : (tensor<3x6xf32>, tensor<f32>) -> tensor<4x7xf32>
+ func.return %0 : tensor<4x7xf32>
+}
+// Destination extents: dim 0 = dilated 2*(3-1)+1 = 5 plus padding 0+3 -> 8;
+// dim 1 = 6 plus padding 1+2 -> 9. The input lands at offset [0, 1] (the low
+// padding) with stride [2, 1] (the base dilation).
+// CHECK: %[[PADVAL:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
+// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<8x9xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[PADVAL]] : f32) outs(%[[INIT]] : tensor<8x9xf32>) -> tensor<8x9xf32>
+// CHECK: %[[PAD:.+]] = tensor.insert_slice %[[ARG0]] into %[[FILL]][0, 1] [3, 6] [2, 1] : tensor<3x6xf32> into tensor<8x9xf32>
+
+// -----
+
+// Tests the degenerate rank-0 case: every window attribute is an empty
+// tensor because a scalar input has no window dimensions. The op still
+// lowers to a linalg.generic whose indexing maps are all the nullary
+// affine map () -> ().
+// CHECK: #[[MAP:.+]] = affine_map<() -> ()>
+// CHECK: func @reduce_window_generic_scalar
+func.func @reduce_window_generic_scalar(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
+ %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({
+ ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+ %1 = stablehlo.add %arg2, %arg3 : tensor<f32>
+ "stablehlo.return"(%1) : (tensor<f32>) -> ()
+ }) {base_dilations = dense<> : tensor<0xi64>, padding = dense<> : tensor<0x2xi64>, window_dilations = dense<> : tensor<0xi64>, window_dimensions = dense<> : tensor<0xi64>, window_strides = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ func.return %0 : tensor<f32>
+}
+// CHECK: linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]