Remove padding preprocessing for mhlo.reduce_window. (#8503)
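Drop the ExtractReduceWindowOpPaddingAttributes pattern from the
MHLO-to-MHLO preprocessing pass. The pattern rewrote a padded
mhlo.reduce_window into an explicit mhlo.pad on each input (using the
init value as the pad value) followed by an unpadded
mhlo.reduce_window; the padding attribute is now kept on the op for
downstream lowerings to handle, and a new e2e test covers that path.
A sketch of the rewrite that no longer runs (illustrative MLIR,
adapted from the deleted preprocessing test):

    // Before preprocessing:
    %r = "mhlo.reduce_window"(%input, %init) ({ ... }) {
      window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
      window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>,
      padding = dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>
    } : (tensor<1x16x16x64xf32>, tensor<f32>) -> tensor<1x8x8x64xf32>

    // After preprocessing (no longer produced):
    %pad = "mhlo.pad"(%input, %init) {
      edge_padding_low = dense<[0, 1, 1, 0]> : tensor<4xi64>,
      edge_padding_high = dense<[0, 1, 1, 0]> : tensor<4xi64>,
      interior_padding = dense<0> : tensor<4xi64>
    } : (tensor<1x16x16x64xf32>, tensor<f32>) -> tensor<1x18x18x64xf32>
    %r = "mhlo.reduce_window"(%pad, %init) ({ ... }) {
      window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
      window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>
    } : (tensor<1x18x18x64xf32>, tensor<f32>) -> tensor<1x8x8x64xf32>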
diff --git a/iree/compiler/InputConversion/MHLO/MHLOToMHLOPreprocessing.cpp b/iree/compiler/InputConversion/MHLO/MHLOToMHLOPreprocessing.cpp
index b20d1b5..ceff92c 100644
--- a/iree/compiler/InputConversion/MHLO/MHLOToMHLOPreprocessing.cpp
+++ b/iree/compiler/InputConversion/MHLO/MHLOToMHLOPreprocessing.cpp
@@ -272,70 +272,6 @@
}
};
-class ExtractReduceWindowOpPaddingAttributes
- : public OpRewritePattern<mhlo::ReduceWindowOp> {
- public:
- using OpRewritePattern<mhlo::ReduceWindowOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(mhlo::ReduceWindowOp op,
- PatternRewriter &rewriter) const override {
- if (!op.padding()) return failure();
-
- if ((op.base_dilations() && !isSplatValue(*op.base_dilations(), 1)) ||
- (op.window_dilations() && !isSplatValue(*op.window_dilations(), 1))) {
- return failure();
- }
- if (isAllZero(op.paddingAttr())) return failure();
-
- // All inputs must be of the same static shape, since
- // mhlo.pad doesn't support dynamic shape.
- for (Type inputType : op.inputs().getType()) {
- if (!inputType.cast<ShapedType>().hasStaticShape()) return failure();
- }
- ArrayRef<int64_t> inputShape =
- op.inputs()[0].getType().cast<ShapedType>().getShape();
-
- int rank = inputShape.size();
- SmallVector<int64_t, 4> paddingLow, paddingHigh, interiorPadding, shape;
- for (unsigned i = 0; i < rank; ++i) {
- interiorPadding.push_back(0);
- paddingLow.push_back(op.paddingAttr().getValues<int64_t>()[{i, 0}]);
- paddingHigh.push_back(op.paddingAttr().getValues<int64_t>()[{i, 1}]);
- int size = inputShape[i];
- shape.push_back(size + paddingLow.back() + paddingHigh.back());
- }
-
- auto toDenseAttr = [&rewriter](ArrayRef<int64_t> elements) {
- return DenseIntElementsAttr::get(
- RankedTensorType::get(elements.size(), rewriter.getIntegerType(64)),
- elements);
- };
-
- SmallVector<Value> padOps;
- padOps.reserve(op.inputs().size());
- auto loc = op.getLoc();
- for (auto it : llvm::zip(op.inputs(), op.init_values())) {
- Value input = std::get<0>(it);
- Value initValue = std::get<1>(it);
- auto inputType = input.getType().cast<ShapedType>();
- auto padResultType =
- RankedTensorType::get(shape, inputType.getElementType());
- auto padOp = rewriter.create<mhlo::PadOp>(
- loc, padResultType, input, initValue, toDenseAttr(paddingLow),
- toDenseAttr(paddingHigh), toDenseAttr(interiorPadding));
- padOps.push_back(padOp);
- }
- auto newOp = rewriter.create<mhlo::ReduceWindowOp>(
- loc, op.getResultTypes(), padOps, op.init_values(),
- op.window_dimensions(), op.window_stridesAttr(),
- op.base_dilationsAttr(), op.window_dilationsAttr(),
- /*padding=*/nullptr);
- rewriter.inlineRegionBefore(op.body(), newOp.body(), newOp.body().begin());
- rewriter.replaceOp(op, newOp.getResults());
- return success();
- }
-};
-
// Adjust the shape of the depthwise_conv filter as applied by mhlo.
class AdjustDepthwiseFilterShape : public OpRewritePattern<mhlo::ConvOp> {
public:
@@ -871,8 +807,7 @@
mhlo::PopulateUnfuseBatchNormPatterns(context, &patterns);
mhlo::PopulateComplexLoweringPatterns(context, &patterns);
mhlo::PopulateGatherToTorchIndexSelectPatterns(context, &patterns);
- patterns.insert<ExtractReduceWindowOpPaddingAttributes,
- AdjustDepthwiseFilterShape, ScatterRank0Value,
+ patterns.insert<AdjustDepthwiseFilterShape, ScatterRank0Value,
ExpandRngNormal, MulCastOfBool>(context);
// dot_general canonicalization patterns.
diff --git a/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing.mlir b/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing.mlir
index 75a1b3b..e72a219 100644
--- a/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing.mlir
+++ b/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing.mlir
@@ -58,59 +58,6 @@
// -----
-// CHECK-LABEL: @reduce_window
-func @reduce_window(%input: tensor<1x16x16x64xf32>) -> tensor<1x8x8x64xf32> {
- // CHECK: %[[INITVAL:.+]] = mhlo.constant dense<0xFF800000> : tensor<f32>
- %initval = mhlo.constant dense<0xFF800000> : tensor<f32>
- // CHECK: %[[PAD:.+]] = "mhlo.pad"(%{{.+}}, %[[INITVAL]])
- // CHECK-SAME: edge_padding_high = dense<[0, 1, 1, 0]> : tensor<4xi64>
- // CHECK-SAME: edge_padding_low = dense<[0, 1, 1, 0]> : tensor<4xi64>
- // CHECK: "mhlo.reduce_window"(%[[PAD]], %[[INITVAL]])
- // CHECK-NOT: padding
- %0 = "mhlo.reduce_window"(%input, %initval) ( {
- ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): // no predecessors
- %3 = mhlo.maximum %arg1, %arg2 : tensor<f32>
- "mhlo.return"(%3) : (tensor<f32>) -> ()
- }) {window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
- window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>,
- window_dilations = dense<1> : tensor<4xi64>,
- base_dilations = dense<1> : tensor<4xi64>,
- padding = dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>
- } : (tensor<1x16x16x64xf32>, tensor<f32>) -> tensor<1x8x8x64xf32>
- return %0 : tensor<1x8x8x64xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @reduce_window_variadic
-func @reduce_window_variadic(%input0: tensor<1x16x16x64xf32>, %input1: tensor<1x16x16x64xi32>) -> (tensor<1x8x8x64xf32>, tensor<1x8x8x64xi32>) {
- // CHECK-DAG: %[[INITVAL0:.+]] = mhlo.constant dense<0xFF800000> : tensor<f32>
- // CHECK-DAG: %[[INITVAL1:.+]] = mhlo.constant dense<3> : tensor<i32>
- %initval0 = mhlo.constant dense<0xFF800000> : tensor<f32>
- %initval1 = mhlo.constant dense<3> : tensor<i32>
-
- // CHECK: %[[PAD0:.+]] = "mhlo.pad"(%{{.+}}, %[[INITVAL0]])
- // CHECK-SAME: edge_padding_high = dense<[0, 1, 1, 0]> : tensor<4xi64>
- // CHECK-SAME: edge_padding_low = dense<[0, 1, 1, 0]> : tensor<4xi64>
- // CHECK: %[[PAD1:.+]] = "mhlo.pad"(%{{.+}}, %[[INITVAL1]])
- // CHECK-SAME: edge_padding_high = dense<[0, 1, 1, 0]> : tensor<4xi64>
- // CHECK-SAME: edge_padding_low = dense<[0, 1, 1, 0]> : tensor<4xi64>
- // CHECK: "mhlo.reduce_window"(%[[PAD0]], %[[PAD1]], %[[INITVAL0]], %[[INITVAL1]])
- // CHECK-NOT: padding
- %0:2 = "mhlo.reduce_window"(%input0, %input1, %initval0, %initval1) ( {
- ^bb0(%arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<i32>): // no predecessors
- %3 = mhlo.maximum %arg1, %arg3 : tensor<f32>
- %4 = mhlo.add %arg2, %arg4 : tensor<i32>
- "mhlo.return"(%3, %4) : (tensor<f32>, tensor<i32>) -> ()
- }) {window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
- window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>,
- padding = dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>
- } : (tensor<1x16x16x64xf32>, tensor<1x16x16x64xi32>, tensor<f32>, tensor<i32>) -> (tensor<1x8x8x64xf32>, tensor<1x8x8x64xi32>)
- return %0#0, %0#1 : tensor<1x8x8x64xf32>, tensor<1x8x8x64xi32>
-}
-
-// -----
-
// CHECK: @reorder_broadcast_in_dim_scalar_binary(%[[ARG0:.*]]: tensor<f32>, %[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<i32>, %[[ARG3:.*]]: tensor<i32>)
func @reorder_broadcast_in_dim_scalar_binary(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>) -> (tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xi32>, tensor<1x8x8x64xi32>, tensor<1x8x8x64xi32>) {
// CHECK: %[[ADD:.*]] = mhlo.add %[[ARG0]], %[[ARG1]] : tensor<f32>
diff --git a/iree/test/e2e/xla_ops/reduce_window.mlir b/iree/test/e2e/xla_ops/reduce_window.mlir
index 08efbc2..24923d6 100644
--- a/iree/test/e2e/xla_ops/reduce_window.mlir
+++ b/iree/test/e2e/xla_ops/reduce_window.mlir
@@ -64,3 +64,21 @@
check.expect_almost_eq_const(%res, dense<[[[[1.0], [4.0]], [[13.0], [14.0]]]]> : tensor<1x2x2x1xf32>) : tensor<1x2x2x1xf32>
return
}
+
+func @reduce_window_max_with_padding_4x6xf32() {
+ %0 = util.unfoldable_constant dense<[[[[ 1.0], [ 2.0], [ 3.0], [ 4.0], [ 5.0], [ 6.0]],
+ [[ 7.0], [ 8.0], [ 9.0], [10.0], [11.0], [12.0]],
+ [[13.0], [14.0], [15.0], [16.0], [17.0], [18.0]],
+ [[19.0], [20.0], [21.0], [22.0], [23.0], [24.0]]]]> : tensor<1x4x6x1xf32>
+ %1 = util.unfoldable_constant dense<0.0> : tensor<f32>
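+  // Padded cells take the init value (0.0); e.g. the first window covers the zero pad row plus [1, 2, 3], so the first output element is 3.0.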
+ %res = "mhlo.reduce_window"(%0, %1) ( {
+ ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>): // no predecessors
+ %3 = "mhlo.maximum"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ "mhlo.return"(%3) : (tensor<f32>) -> ()
+ }) {window_dimensions = dense<[1, 2, 3, 1]> : tensor<4xi64>,
+ window_strides = dense<[1, 2, 3, 1]> : tensor<4xi64>,
+ padding = dense<[[0, 0], [1, 1], [0, 0], [0, 0]]> : tensor<4x2xi64>} : (tensor<1x4x6x1xf32>, tensor<f32>) -> tensor<1x3x2x1xf32>
+ check.expect_almost_eq_const(%res, dense<[[[[3.0], [6.0]], [[15.0], [18.0]], [[21.0], [24.0]]]]> : tensor<1x3x2x1xf32>) : tensor<1x3x2x1xf32>
+ return
+}