Adapt pooling ops lowering to Linalg on tensors. (#5024)

This also introduces the concept of bufferization for shape-only operands.
If an init_tensor op cannot be mapped to a buffer, a fake/temporary memref
is allocated, and we rely on later passes to delete the alloc op. This is
similar to what we did to lower mhlo.reduce_window to linalg.pooling_*.
diff --git a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
index 800768c..da5d8a0 100644
--- a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
@@ -340,146 +340,6 @@
 }  // namespace
 
 //===----------------------------------------------------------------------===//
-// mhlo.reduce_window conversion patterns and utility functions.
-//===----------------------------------------------------------------------===//
-
-namespace {
-
-/// Returns the constant value associated with the init value if the defining
-/// operation is a constant.
-static Attribute GetInitValueAsConst(Value init) {
-  DenseElementsAttr attr;
-  if (!matchPattern(init, m_Constant(&attr))) return {};
-  auto type = attr.getType().dyn_cast<ShapedType>();
-  if (!type || type.getRank() != 0) return {};
-  return attr.getValue({});
-}
-
-/// mhlo.reduce_window is mapped to a linalg.pooling operation. The type of
-/// the pooling is determined based on the body of the reduce window
-/// operation. This class enumerates the different variants.
-enum class PoolingType {
-  kMin,
-  kMax,
-  kAdd,
-};
-
-struct ReduceWindowOpConversion
-    : public ConvertToLinalgBufferOp<ReduceWindowOpConversion,
-                                     mhlo::ReduceWindowOp> {
-  using ConvertToLinalgBufferOp<ReduceWindowOpConversion,
-                                mhlo::ReduceWindowOp>::ConvertToLinalgBufferOp;
-
-  LogicalResult apply(mhlo::ReduceWindowOp op, ArrayRef<Value> inputBuffers,
-                      ArrayRef<Value> resultBuffers,
-                      ConversionPatternRewriter &rewriter) const;
-};
-}  // namespace
-
-static PoolingType getPoolingType(Region &region) {
-  assert(region.getBlocks().size() == 1 &&
-         "expected the region has exactlly one block");
-  Block &block = region.front();
-  assert(block.getOperations().size() == 2 &&
-         "expected the block has exactlly two operations");
-  auto op = block.begin();
-  if (isa<mhlo::MinOp>(op)) return PoolingType::kMin;
-  if (isa<mhlo::MaxOp>(op)) return PoolingType::kMax;
-  if (isa<mhlo::AddOp>(op)) return PoolingType::kAdd;
-
-  llvm_unreachable("unknown pooling type");
-}
-
-LogicalResult ReduceWindowOpConversion::apply(
-    mhlo::ReduceWindowOp op, ArrayRef<Value> inputBuffers,
-    ArrayRef<Value> resultBuffers, ConversionPatternRewriter &rewriter) const {
-  auto loc = op.getLoc();
-  auto resultType = op.getResult().getType().cast<ShapedType>();
-  if (resultType.getRank() != 4) {
-    return rewriter.notifyMatchFailure(op, "expected NHWC pooling-based op");
-  }
-
-  // Create a fake window dimension.
-  SmallVector<int64_t, 4> shapes;
-  shapes.push_back(op.window_dimensions().getValue<int64_t>(1));
-  shapes.push_back(op.window_dimensions().getValue<int64_t>(2));
-  auto memrefType = MemRefType::get(shapes, rewriter.getF32Type());
-  auto fakeWindowDims = rewriter.create<AllocOp>(loc, memrefType);
-
-  if (op.window_strides() &&
-      (op.window_strides().getValue().getValue<int64_t>(0) != 1 ||
-       op.window_strides().getValue().getValue<int64_t>(3) != 1)) {
-    return rewriter.notifyMatchFailure(
-        op, "expected window_strides to be [1,x,y,1]");
-  }
-  if (op.window_dimensions() &&
-      (op.window_dimensions().getValue<int64_t>(0) != 1 ||
-       op.window_dimensions().getValue<int64_t>(3) != 1)) {
-    return rewriter.notifyMatchFailure(
-        op, "expected window_dimensions to be [1,x,y,1]");
-  }
-
-  if (!inputBuffers[0].getType().cast<ShapedType>().getElementType().isF32()) {
-    return rewriter.notifyMatchFailure(op, "expected element type to be f32");
-  }
-
-  Attribute strides;
-  if (op.window_stridesAttr()) {
-    strides = rewriter.getI64VectorAttr(
-        {op.window_strides().getValue().getValue<int64_t>(1),
-         op.window_strides().getValue().getValue<int64_t>(2)});
-  } else {
-    strides = rewriter.getI64VectorAttr({1, 1});
-  }
-  Attribute dilations;
-  if (op.window_dilations()) {
-    dilations = rewriter.getI64VectorAttr(
-        {op.window_dilations().getValue().getValue<int64_t>(1),
-         op.window_dilations().getValue().getValue<int64_t>(2)});
-  } else {
-    dilations = rewriter.getI64VectorAttr({1, 1});
-  }
-  auto createOp = [&](auto *type_ptr) -> linalg::LinalgOp {
-    return cast<linalg::LinalgOp>(
-        rewriter
-            .create<std::remove_pointer_t<decltype(type_ptr)>>(
-                loc, ArrayRef<Type>{},
-                ValueRange{inputBuffers[0], fakeWindowDims.getResult()},
-                resultBuffers[0], dilations, strides)
-            .getOperation());
-  };
-  linalg::LinalgOp poolingOp;
-  PoolingType poolingType = getPoolingType(op.body());
-
-  Value initValue = inputBuffers[1];
-  Attribute initConstVal = GetInitValueAsConst(initValue);
-  if (initConstVal) {
-    initValue = rewriter.create<ConstantOp>(initValue.getDefiningOp()->getLoc(),
-                                            initConstVal);
-  } else {
-    initValue = rewriter.create<LoadOp>(loc, initValue);
-  }
-  rewriter.create<linalg::FillOp>(loc, resultBuffers[0], initValue);
-
-  switch (poolingType) {
-    case PoolingType::kMin: {
-      poolingOp = createOp(static_cast<linalg::PoolingNHWCMinOp *>(nullptr));
-      break;
-    }
-    case PoolingType::kMax: {
-      poolingOp = createOp(static_cast<linalg::PoolingNHWCMaxOp *>(nullptr));
-      break;
-    }
-    case PoolingType::kAdd: {
-      poolingOp = createOp(static_cast<linalg::PoolingNHWCSumOp *>(nullptr));
-      break;
-    }
-  }
-  rewriter.create<DeallocOp>(loc, fakeWindowDims);
-  return success();
-}
-
-//===----------------------------------------------------------------------===//
 // Linalg op on tensors to linalg op on buffers conversion base class.
 //===----------------------------------------------------------------------===//
 
@@ -635,8 +495,15 @@
       linalg::InitTensorOp op, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const override {
     Value outputBuffer = resultTensorToBufferMap.lookup(op.result());
-    if (!outputBuffer) return failure();
-    rewriter.replaceOp(op, outputBuffer);
+    if (!outputBuffer) {
+      // If the outputBuffer does not exist, this is a shape-only operand.
+      // Allocate a temp buffer and it will get deleted after lowering to loops.
+      RankedTensorType type = op.getType();
+      auto memrefType = MemRefType::get(type.getShape(), type.getElementType());
+      rewriter.replaceOpWithNewOp<AllocOp>(op, memrefType);
+    } else {
+      rewriter.replaceOp(op, outputBuffer);
+    }
     return success();
   }
 
@@ -1007,8 +874,10 @@
       NamedOpConversion<linalg::MatmulI16I16I32Op>,
       NamedOpConversion<linalg::MatmulI32I32I32Op>,
       NamedOpConversion<linalg::BatchMatmulOp>,
+      NamedOpConversion<linalg::PoolingNHWCMaxOp>,
+      NamedOpConversion<linalg::PoolingNHWCMinOp>,
+      NamedOpConversion<linalg::PoolingNHWCSumOp>,
       PadTensorOpConversion,
-      ReduceWindowOpConversion,
       SubTensorOpConversion,
       SubTensorInsertOpConversion,
       TensorReshapeOpConversion
diff --git a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp
index ac4ea51..2c5b278 100644
--- a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp
@@ -152,6 +152,127 @@
 }  // namespace
 
 //===----------------------------------------------------------------------===//
+// mhlo.reduce_window conversion patterns.
+//===----------------------------------------------------------------------===//
+
+/// mhlo.reduce_window is mapped to a linalg.pooling operation. The type of
+/// the pooling is determined based on the body of the reduce window
+/// operation. This class enumerates the different variants.
+enum class PoolingType {
+  kMin,
+  kMax,
+  kAdd,
+};
+
+struct ReduceWindowOpConversion
+    : public OpConversionPattern<mhlo::ReduceWindowOp> {
+  using OpConversionPattern<mhlo::ReduceWindowOp>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      mhlo::ReduceWindowOp op, ArrayRef<Value> args,
+      ConversionPatternRewriter &rewriter) const override;
+};
+
+static PoolingType getPoolingType(Region &region) {
+  assert(region.getBlocks().size() == 1 &&
+         "expected the region to have exactly one block");
+  Block &block = region.front();
+  assert(block.getOperations().size() == 2 &&
+         "expected the block to have exactly two operations");
+  auto op = block.begin();
+  if (isa<mhlo::MinOp>(op)) return PoolingType::kMin;
+  if (isa<mhlo::MaxOp>(op)) return PoolingType::kMax;
+  if (isa<mhlo::AddOp>(op)) return PoolingType::kAdd;
+
+  llvm_unreachable("unknown pooling type");
+}
+
+LogicalResult ReduceWindowOpConversion::matchAndRewrite(
+    mhlo::ReduceWindowOp op, ArrayRef<Value> args,
+    ConversionPatternRewriter &rewriter) const {
+  auto loc = op.getLoc();
+  auto resultType = op.getResult().getType().cast<ShapedType>();
+  if (resultType.getRank() != 4) {
+    return rewriter.notifyMatchFailure(op, "expected NHWC pooling-based op");
+  }
+
+  // Create a fake window dimension.
+  SmallVector<int64_t, 4> shapes;
+  shapes.push_back(op.window_dimensions().getValue<int64_t>(1));
+  shapes.push_back(op.window_dimensions().getValue<int64_t>(2));
+  auto fakeWindowDims = rewriter.create<linalg::InitTensorOp>(
+      loc, shapes, resultType.getElementType());
+
+  if (op.window_strides() &&
+      (op.window_strides().getValue().getValue<int64_t>(0) != 1 ||
+       op.window_strides().getValue().getValue<int64_t>(3) != 1)) {
+    return rewriter.notifyMatchFailure(
+        op, "expected window_strides to be [1,x,y,1]");
+  }
+  if (op.window_dimensions() &&
+      (op.window_dimensions().getValue<int64_t>(0) != 1 ||
+       op.window_dimensions().getValue<int64_t>(3) != 1)) {
+    return rewriter.notifyMatchFailure(
+        op, "expected window_dimensions to be [1,x,y,1]");
+  }
+
+  if (!args[0].getType().cast<ShapedType>().getElementType().isF32()) {
+    return rewriter.notifyMatchFailure(op, "expected element type to be f32");
+  }
+
+  Attribute strides;
+  if (op.window_stridesAttr()) {
+    strides = rewriter.getI64VectorAttr(
+        {op.window_strides().getValue().getValue<int64_t>(1),
+         op.window_strides().getValue().getValue<int64_t>(2)});
+  } else {
+    strides = rewriter.getI64VectorAttr({1, 1});
+  }
+  Attribute dilations;
+  if (op.window_dilations()) {
+    dilations = rewriter.getI64VectorAttr(
+        {op.window_dilations().getValue().getValue<int64_t>(1),
+         op.window_dilations().getValue().getValue<int64_t>(2)});
+  } else {
+    dilations = rewriter.getI64VectorAttr({1, 1});
+  }
+  linalg::LinalgOp poolingOp;
+  PoolingType poolingType = getPoolingType(op.body());
+
+  Value initTensor = rewriter.create<linalg::InitTensorOp>(
+      loc, resultType.getShape(), resultType.getElementType());
+  Value initValue = args[1];
+  initValue = rewriter.create<tensor::ExtractOp>(loc, initValue);
+  Value filledInitTensor =
+      rewriter.create<linalg::FillOp>(loc, initTensor, initValue).getResult(0);
+  auto createOp = [&](auto *type_ptr) -> linalg::LinalgOp {
+    return cast<linalg::LinalgOp>(
+        rewriter
+            .create<std::remove_pointer_t<decltype(type_ptr)>>(
+                loc, ArrayRef<Type>{resultType},
+                ValueRange{args[0], fakeWindowDims.getResult()},
+                filledInitTensor, dilations, strides)
+            .getOperation());
+  };
+  switch (poolingType) {
+    case PoolingType::kMin: {
+      poolingOp = createOp(static_cast<linalg::PoolingNHWCMinOp *>(nullptr));
+      break;
+    }
+    case PoolingType::kMax: {
+      poolingOp = createOp(static_cast<linalg::PoolingNHWCMaxOp *>(nullptr));
+      break;
+    }
+    case PoolingType::kAdd: {
+      poolingOp = createOp(static_cast<linalg::PoolingNHWCSumOp *>(nullptr));
+      break;
+    }
+  }
+  rewriter.replaceOp(op, poolingOp->getResult(0));
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // mhlo.conv conversion patterns.
 //===----------------------------------------------------------------------===//
 
@@ -506,13 +627,7 @@
     }
 
     ConversionTarget target(getContext());
-    // Don't convert the body of reduction ops.
-    target.addDynamicallyLegalDialect<mhlo::MhloDialect>(
-        Optional<ConversionTarget::DynamicLegalityCallbackFn>(
-            [](Operation *op) {
-              auto parentOp = op->getParentRegion()->getParentOp();
-              return isa<mhlo::ReduceWindowOp>(parentOp);
-            }));
+    target.addIllegalDialect<mhlo::MhloDialect>();
     // Let the rest fall through.
     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
     if (useLinalgOnTensorsPath) {
@@ -554,7 +669,8 @@
     MLIRContext *context, OwningRewritePatternList &patterns) {
   mhlo::populateHLOToLinalgConversionPattern(context, &patterns);
   patterns.insert<TorchIndexSelectOpConversion, ConstOpConversion,
-                  ConcatenateOpConversion, DepthwiseConvOpConversion>(context);
+                  ReduceWindowOpConversion, ConcatenateOpConversion,
+                  DepthwiseConvOpConversion>(context);
 }
 
 static llvm::cl::opt<bool> clUseLinalgOnTensorsPath(
diff --git a/iree/compiler/Conversion/HLOToLinalg/test/linalg_tensor_to_buffer.mlir b/iree/compiler/Conversion/HLOToLinalg/test/linalg_tensor_to_buffer.mlir
index 0d8ab1e..9e9e46b 100644
--- a/iree/compiler/Conversion/HLOToLinalg/test/linalg_tensor_to_buffer.mlir
+++ b/iree/compiler/Conversion/HLOToLinalg/test/linalg_tensor_to_buffer.mlir
@@ -589,3 +589,40 @@
 //       CHECK:   linalg.batch_matmul
 //  CHECK-SAME:     ins(%[[ARG0]], %[[ARG1]] : memref<2x2x3xf32>, memref<2x3x4xf32>)
 //  CHECK-SAME:    outs(%[[RET]] : memref<2x2x4xf32>)
+
+// -----
+
+module {
+  func @reduce_window_sum_nhwc() {
+    %c0 = constant 0 : index
+    %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<1x18x18x64xf32>
+    %1 = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 : tensor<f32>
+    %2 = linalg.init_tensor [3, 3] : tensor<3x3xf32>
+    %3 = linalg.init_tensor [1, 8, 8, 64] : tensor<1x8x8x64xf32>
+    %4 = tensor.extract %1[] : tensor<f32>
+    %5 = linalg.fill(%3, %4) : tensor<1x8x8x64xf32>, f32 -> tensor<1x8x8x64xf32>
+    %6 = linalg.pooling_nhwc_sum {
+        dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
+      ins(%0, %2 : tensor<1x18x18x64xf32>, tensor<3x3xf32>)
+      outs(%5 : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
+    hal.interface.store.tensor %6, @legacy_io::@ret0, offset = %c0 : tensor<1x8x8x64xf32>
+    return
+  }
+  hal.interface @legacy_io attributes {sym_visibility = "private"} {
+    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+  }
+}
+// CHECK-LABEL: func @reduce_window_sum_nhwc
+// CHECK-DAG:     %[[ARG0:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1x18x18x64xf32>
+// CHECK-DAG:     %[[ARG1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<f32>
+// CHECK-DAG:     %[[RES:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1x8x8x64xf32>
+// CHECK:         %[[WINDOW:.+]] = alloc() : memref<3x3xf32>
+// CHECK:         %[[INIT:.+]] = load %[[ARG1]][] : memref<f32>
+// CHECK:         linalg.fill(%[[RES]], %[[INIT]]) : memref<1x8x8x64xf32>, f32
+// CHECK:         linalg.pooling_nhwc_sum
+// CHECK-SAME:      {dilations = dense<1> : vector<2xi64>
+// CHECK-SAME:       strides = dense<2> : vector<2xi64>}
+// CHECK-SAME:      ins(%[[ARG0]], %[[WINDOW]] : memref<1x18x18x64xf32>, memref<3x3xf32>)
+// CHECK-SAME:      outs(%[[RES]] : memref<1x8x8x64xf32>)
diff --git a/iree/compiler/Conversion/HLOToLinalg/test/reduce_window.mlir b/iree/compiler/Conversion/HLOToLinalg/test/reduce_window.mlir
index 5d3e6cc..2784636 100644
--- a/iree/compiler/Conversion/HLOToLinalg/test/reduce_window.mlir
+++ b/iree/compiler/Conversion/HLOToLinalg/test/reduce_window.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file -iree-codegen-hlo-to-linalg-on-buffers %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -iree-codegen-hlo-to-linalg-on-tensors %s | IreeFileCheck %s
 
 module {
   func @reduce_window_min_nhwc() {
@@ -21,17 +21,17 @@
   }
 }
 // CHECK-LABEL: func @reduce_window_min_nhwc
-// CHECK-DAG:     %[[ARG0:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1x18x18x64xf32>
-// CHECK-DAG:     %[[ARG1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<f32>
-// CHECK-DAG:     %[[RES:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1x8x8x64xf32>
-// CHECK:         %[[WINDOW:.+]] = alloc() : memref<3x3xf32>
-// CHECK:         %[[INIT:.+]] = load %[[ARG1]][] : memref<f32>
-// CHECK:         linalg.fill(%[[RES]], %[[INIT]]) : memref<1x8x8x64xf32>, f32
-// CHECK:         linalg.pooling_nhwc_min
+// CHECK-DAG:     %[[ARG0:.+]] = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<1x18x18x64xf32>
+// CHECK-DAG:     %[[ARG1:.+]] = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 : tensor<f32>
+// CHECK:         %[[WINDOW:.+]] = linalg.init_tensor [3, 3] : tensor<3x3xf32>
+// CHECK:         %[[INIT:.+]] = linalg.init_tensor [1, 8, 8, 64] : tensor<1x8x8x64xf32>
+// CHECK:         %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
+// CHECK:         %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[INIT_VAL]]) : tensor<1x8x8x64xf32>, f32 -> tensor<1x8x8x64xf32>
+// CHECK:         %[[RES:.+]] = linalg.pooling_nhwc_min
 // CHECK-SAME:      {dilations = dense<1> : vector<2xi64>
 // CHECK-SAME:       strides = dense<2> : vector<2xi64>}
-// CHECK-SAME:      ins(%[[ARG0]], %[[WINDOW]] : memref<1x18x18x64xf32>, memref<3x3xf32>)
-// CHECK-SAME:      outs(%[[RES]] : memref<1x8x8x64xf32>)
+// CHECK-SAME:      ins(%[[ARG0]], %[[WINDOW]] : tensor<1x18x18x64xf32>, tensor<3x3xf32>)
+// CHECK-SAME:      outs(%[[FILL]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
 
 // -----
 
@@ -56,22 +56,22 @@
   }
 }
 // CHECK-LABEL: func @reduce_window_max_nhwc
-// CHECK-DAG:     %[[ARG0:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1x18x18x64xf32>
-// CHECK-DAG:     %[[ARG1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<f32>
-// CHECK-DAG:     %[[RES:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1x8x8x64xf32>
-// CHECK:         %[[WINDOW:.+]] = alloc() : memref<3x3xf32>
-// CHECK:         %[[INIT:.+]] = load %[[ARG1]][] : memref<f32>
-// CHECK:         linalg.fill(%[[RES]], %[[INIT]]) : memref<1x8x8x64xf32>, f32
-// CHECK:         linalg.pooling_nhwc_max
+// CHECK-DAG:     %[[ARG0:.+]] = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<1x18x18x64xf32>
+// CHECK-DAG:     %[[ARG1:.+]] = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 : tensor<f32>
+// CHECK:         %[[WINDOW:.+]] = linalg.init_tensor [3, 3] : tensor<3x3xf32>
+// CHECK:         %[[INIT:.+]] = linalg.init_tensor [1, 8, 8, 64] : tensor<1x8x8x64xf32>
+// CHECK:         %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
+// CHECK:         %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[INIT_VAL]]) : tensor<1x8x8x64xf32>, f32 -> tensor<1x8x8x64xf32>
+// CHECK:         %[[RES:.+]] = linalg.pooling_nhwc_max
 // CHECK-SAME:      {dilations = dense<1> : vector<2xi64>
 // CHECK-SAME:       strides = dense<2> : vector<2xi64>}
-// CHECK-SAME:      ins(%[[ARG0]], %[[WINDOW]] : memref<1x18x18x64xf32>, memref<3x3xf32>)
-// CHECK-SAME:      outs(%[[RES]] : memref<1x8x8x64xf32>)
+// CHECK-SAME:      ins(%[[ARG0]], %[[WINDOW]] : tensor<1x18x18x64xf32>, tensor<3x3xf32>)
+// CHECK-SAME:      outs(%[[FILL]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
 
 // -----
 
 module {
-  func @reduce_window_add_nhwc() {
+  func @reduce_window_sum_nhwc() {
     %c0 = constant 0 : index
     %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<1x18x18x64xf32>
     %1 = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 : tensor<f32>
@@ -90,18 +90,18 @@
     hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
   }
 }
-// CHECK-LABEL: func @reduce_window_add_nhwc
-// CHECK-DAG:     %[[ARG0:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1x18x18x64xf32>
-// CHECK-DAG:     %[[ARG1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<f32>
-// CHECK-DAG:     %[[RES:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1x8x8x64xf32>
-// CHECK:         %[[WINDOW:.+]] = alloc() : memref<3x3xf32>
-// CHECK:         %[[INIT:.+]] = load %[[ARG1]][] : memref<f32>
-// CHECK:         linalg.fill(%[[RES]], %[[INIT]]) : memref<1x8x8x64xf32>, f32
-// CHECK:         linalg.pooling_nhwc_sum
+// CHECK-LABEL: func @reduce_window_sum_nhwc
+// CHECK-DAG:     %[[ARG0:.+]] = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<1x18x18x64xf32>
+// CHECK-DAG:     %[[ARG1:.+]] = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 : tensor<f32>
+// CHECK:         %[[WINDOW:.+]] = linalg.init_tensor [3, 3] : tensor<3x3xf32>
+// CHECK:         %[[INIT:.+]] = linalg.init_tensor [1, 8, 8, 64] : tensor<1x8x8x64xf32>
+// CHECK:         %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
+// CHECK:         %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[INIT_VAL]]) : tensor<1x8x8x64xf32>, f32 -> tensor<1x8x8x64xf32>
+// CHECK:         %[[RES:.+]] = linalg.pooling_nhwc_sum
 // CHECK-SAME:      {dilations = dense<1> : vector<2xi64>
 // CHECK-SAME:       strides = dense<2> : vector<2xi64>}
-// CHECK-SAME:      ins(%[[ARG0]], %[[WINDOW]] : memref<1x18x18x64xf32>, memref<3x3xf32>)
-// CHECK-SAME:      outs(%[[RES]] : memref<1x8x8x64xf32>)
+// CHECK-SAME:      ins(%[[ARG0]], %[[WINDOW]] : tensor<1x18x18x64xf32>, tensor<3x3xf32>)
+// CHECK-SAME:      outs(%[[FILL]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
 
 // -----
 
@@ -125,13 +125,14 @@
   }
 }
 // CHECK-LABEL: func @reduce_window_max_nhwc
-// CHECK-DAG:     %[[INIT:.+]] = constant 0xFF800000 : f32
-// CHECK-DAG:     %[[ARG0:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1x18x18x64xf32>
-// CHECK-DAG:     %[[RES:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1x8x8x64xf32>
-// CHECK:         %[[WINDOW:.+]] = alloc() : memref<3x3xf32>
-// CHECK:         linalg.fill(%[[RES]], %[[INIT]]) : memref<1x8x8x64xf32>, f32
-// CHECK:         linalg.pooling_nhwc_max
+// CHECK-DAG:     %[[CST:.+]] = constant dense<0xFF800000> : tensor<f32>
+// CHECK-DAG:     %[[ARG0:.+]] = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<1x18x18x64xf32>
+// CHECK:         %[[WINDOW:.+]] = linalg.init_tensor [3, 3] : tensor<3x3xf32>
+// CHECK:         %[[INIT:.+]] = linalg.init_tensor [1, 8, 8, 64] : tensor<1x8x8x64xf32
+// CHECK:         %[[INIT_VAL:.+]] = tensor.extract %[[CST]][] : tensor<f32>
+// CHECK:         %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[INIT_VAL]]) : tensor<1x8x8x64xf32>, f32 -> tensor<1x8x8x64xf32>
+// CHECK:         %[[RES:.+]] = linalg.pooling_nhwc_max
 // CHECK-SAME:      {dilations = dense<1> : vector<2xi64>
 // CHECK-SAME:       strides = dense<2> : vector<2xi64>}
-// CHECK-SAME:      ins(%[[ARG0]], %[[WINDOW]] : memref<1x18x18x64xf32>, memref<3x3xf32>)
-// CHECK-SAME:      outs(%[[RES]] : memref<1x8x8x64xf32>)
+// CHECK-SAME:      ins(%[[ARG0]], %[[WINDOW]] : tensor<1x18x18x64xf32>, tensor<3x3xf32>)
+// CHECK-SAME:      outs(%[[FILL]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>