Adapt pooling ops lowering to Linalg on tensors. (#5024)
This also introduces the concept of bufferization for shape-only operands.
If an init_tensor op cannot be mapped to a buffer, a fake/temporary memref
is allocated, and we rely on later passes to delete the alloc op. This is
similar to what we did to lower mhlo.reduce_window to linalg.pooling_*.
diff --git a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
index 800768c..da5d8a0 100644
--- a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
@@ -340,146 +340,6 @@
} // namespace
//===----------------------------------------------------------------------===//
-// mhlo.reduce_window conversion patterns and utility functions.
-//===----------------------------------------------------------------------===//
-
-namespace {
-
-/// Returns the constant value associated with the init value if the defining
-/// operation is a constant.
-static Attribute GetInitValueAsConst(Value init) {
- DenseElementsAttr attr;
- if (!matchPattern(init, m_Constant(&attr))) return {};
- auto type = attr.getType().dyn_cast<ShapedType>();
- if (!type || type.getRank() != 0) return {};
- return attr.getValue({});
-}
-
-/// mhlo.reduce_window is mapped to a linalg.pooling operation. The type of
-/// the pooling is determined based on the body of the reduce window
-/// operation. This class enumerates the different variants.
-enum class PoolingType {
- kMin,
- kMax,
- kAdd,
-};
-
-struct ReduceWindowOpConversion
- : public ConvertToLinalgBufferOp<ReduceWindowOpConversion,
- mhlo::ReduceWindowOp> {
- using ConvertToLinalgBufferOp<ReduceWindowOpConversion,
- mhlo::ReduceWindowOp>::ConvertToLinalgBufferOp;
-
- LogicalResult apply(mhlo::ReduceWindowOp op, ArrayRef<Value> inputBuffers,
- ArrayRef<Value> resultBuffers,
- ConversionPatternRewriter &rewriter) const;
-};
-} // namespace
-
-static PoolingType getPoolingType(Region ®ion) {
- assert(region.getBlocks().size() == 1 &&
- "expected the region has exactlly one block");
- Block &block = region.front();
- assert(block.getOperations().size() == 2 &&
- "expected the block has exactlly two operations");
- auto op = block.begin();
- if (isa<mhlo::MinOp>(op)) return PoolingType::kMin;
- if (isa<mhlo::MaxOp>(op)) return PoolingType::kMax;
- if (isa<mhlo::AddOp>(op)) return PoolingType::kAdd;
-
- llvm_unreachable("unknown pooling type");
-}
-
-LogicalResult ReduceWindowOpConversion::apply(
- mhlo::ReduceWindowOp op, ArrayRef<Value> inputBuffers,
- ArrayRef<Value> resultBuffers, ConversionPatternRewriter &rewriter) const {
- auto loc = op.getLoc();
- auto resultType = op.getResult().getType().cast<ShapedType>();
- if (resultType.getRank() != 4) {
- return rewriter.notifyMatchFailure(op, "expected NHWC pooling-based op");
- }
-
- // Create a fake window dimension.
- SmallVector<int64_t, 4> shapes;
- shapes.push_back(op.window_dimensions().getValue<int64_t>(1));
- shapes.push_back(op.window_dimensions().getValue<int64_t>(2));
- auto memrefType = MemRefType::get(shapes, rewriter.getF32Type());
- auto fakeWindowDims = rewriter.create<AllocOp>(loc, memrefType);
-
- if (op.window_strides() &&
- (op.window_strides().getValue().getValue<int64_t>(0) != 1 ||
- op.window_strides().getValue().getValue<int64_t>(3) != 1)) {
- return rewriter.notifyMatchFailure(
- op, "expected window_strides to be [1,x,y,1]");
- }
- if (op.window_dimensions() &&
- (op.window_dimensions().getValue<int64_t>(0) != 1 ||
- op.window_dimensions().getValue<int64_t>(3) != 1)) {
- return rewriter.notifyMatchFailure(
- op, "expected window_dimensions to be [1,x,y,1]");
- }
-
- if (!inputBuffers[0].getType().cast<ShapedType>().getElementType().isF32()) {
- return rewriter.notifyMatchFailure(op, "expected element type to be f32");
- }
-
- Attribute strides;
- if (op.window_stridesAttr()) {
- strides = rewriter.getI64VectorAttr(
- {op.window_strides().getValue().getValue<int64_t>(1),
- op.window_strides().getValue().getValue<int64_t>(2)});
- } else {
- strides = rewriter.getI64VectorAttr({1, 1});
- }
- Attribute dilations;
- if (op.window_dilations()) {
- dilations = rewriter.getI64VectorAttr(
- {op.window_dilations().getValue().getValue<int64_t>(1),
- op.window_dilations().getValue().getValue<int64_t>(2)});
- } else {
- dilations = rewriter.getI64VectorAttr({1, 1});
- }
- auto createOp = [&](auto *type_ptr) -> linalg::LinalgOp {
- return cast<linalg::LinalgOp>(
- rewriter
- .create<std::remove_pointer_t<decltype(type_ptr)>>(
- loc, ArrayRef<Type>{},
- ValueRange{inputBuffers[0], fakeWindowDims.getResult()},
- resultBuffers[0], dilations, strides)
- .getOperation());
- };
- linalg::LinalgOp poolingOp;
- PoolingType poolingType = getPoolingType(op.body());
-
- Value initValue = inputBuffers[1];
- Attribute initConstVal = GetInitValueAsConst(initValue);
- if (initConstVal) {
- initValue = rewriter.create<ConstantOp>(initValue.getDefiningOp()->getLoc(),
- initConstVal);
- } else {
- initValue = rewriter.create<LoadOp>(loc, initValue);
- }
- rewriter.create<linalg::FillOp>(loc, resultBuffers[0], initValue);
-
- switch (poolingType) {
- case PoolingType::kMin: {
- poolingOp = createOp(static_cast<linalg::PoolingNHWCMinOp *>(nullptr));
- break;
- }
- case PoolingType::kMax: {
- poolingOp = createOp(static_cast<linalg::PoolingNHWCMaxOp *>(nullptr));
- break;
- }
- case PoolingType::kAdd: {
- poolingOp = createOp(static_cast<linalg::PoolingNHWCSumOp *>(nullptr));
- break;
- }
- }
- rewriter.create<DeallocOp>(loc, fakeWindowDims);
- return success();
-}
-
-//===----------------------------------------------------------------------===//
// Linalg op on tensors to linalg op on buffers conversion base class.
//===----------------------------------------------------------------------===//
@@ -635,8 +495,15 @@
linalg::InitTensorOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Value outputBuffer = resultTensorToBufferMap.lookup(op.result());
- if (!outputBuffer) return failure();
- rewriter.replaceOp(op, outputBuffer);
+ if (!outputBuffer) {
+ // If the outputBuffer does not exist, this is a shape-only operand.
+      // Allocate a temp buffer; it will be deleted after lowering to loops.
+ RankedTensorType type = op.getType();
+ auto memrefType = MemRefType::get(type.getShape(), type.getElementType());
+ rewriter.replaceOpWithNewOp<AllocOp>(op, memrefType);
+ } else {
+ rewriter.replaceOp(op, outputBuffer);
+ }
return success();
}
@@ -1007,8 +874,10 @@
NamedOpConversion<linalg::MatmulI16I16I32Op>,
NamedOpConversion<linalg::MatmulI32I32I32Op>,
NamedOpConversion<linalg::BatchMatmulOp>,
+ NamedOpConversion<linalg::PoolingNHWCMaxOp>,
+ NamedOpConversion<linalg::PoolingNHWCMinOp>,
+ NamedOpConversion<linalg::PoolingNHWCSumOp>,
PadTensorOpConversion,
- ReduceWindowOpConversion,
SubTensorOpConversion,
SubTensorInsertOpConversion,
TensorReshapeOpConversion
diff --git a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp
index ac4ea51..2c5b278 100644
--- a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp
@@ -152,6 +152,127 @@
} // namespace
//===----------------------------------------------------------------------===//
+// mhlo.reduce_window conversion patterns.
+//===----------------------------------------------------------------------===//
+
+/// mhlo.reduce_window is mapped to a linalg.pooling operation. The type of
+/// the pooling is determined based on the body of the reduce window
+/// operation. This class enumerates the different variants.
+enum class PoolingType {
+ kMin,
+ kMax,
+ kAdd,
+};
+
+struct ReduceWindowOpConversion
+ : public OpConversionPattern<mhlo::ReduceWindowOp> {
+ using OpConversionPattern<mhlo::ReduceWindowOp>::OpConversionPattern;
+
+ LogicalResult matchAndRewrite(
+ mhlo::ReduceWindowOp op, ArrayRef<Value> args,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
+static PoolingType getPoolingType(Region ®ion) {
+ assert(region.getBlocks().size() == 1 &&
+         "expected the region to have exactly one block");
+ Block &block = region.front();
+ assert(block.getOperations().size() == 2 &&
+         "expected the block to have exactly two operations");
+ auto op = block.begin();
+ if (isa<mhlo::MinOp>(op)) return PoolingType::kMin;
+ if (isa<mhlo::MaxOp>(op)) return PoolingType::kMax;
+ if (isa<mhlo::AddOp>(op)) return PoolingType::kAdd;
+
+ llvm_unreachable("unknown pooling type");
+}
+
+LogicalResult ReduceWindowOpConversion::matchAndRewrite(
+ mhlo::ReduceWindowOp op, ArrayRef<Value> args,
+ ConversionPatternRewriter &rewriter) const {
+ auto loc = op.getLoc();
+ auto resultType = op.getResult().getType().cast<ShapedType>();
+ if (resultType.getRank() != 4) {
+ return rewriter.notifyMatchFailure(op, "expected NHWC pooling-based op");
+ }
+
+ // Create a fake window dimension.
+ SmallVector<int64_t, 4> shapes;
+ shapes.push_back(op.window_dimensions().getValue<int64_t>(1));
+ shapes.push_back(op.window_dimensions().getValue<int64_t>(2));
+ auto fakeWindowDims = rewriter.create<linalg::InitTensorOp>(
+ loc, shapes, resultType.getElementType());
+
+ if (op.window_strides() &&
+ (op.window_strides().getValue().getValue<int64_t>(0) != 1 ||
+ op.window_strides().getValue().getValue<int64_t>(3) != 1)) {
+ return rewriter.notifyMatchFailure(
+ op, "expected window_strides to be [1,x,y,1]");
+ }
+ if (op.window_dimensions() &&
+ (op.window_dimensions().getValue<int64_t>(0) != 1 ||
+ op.window_dimensions().getValue<int64_t>(3) != 1)) {
+ return rewriter.notifyMatchFailure(
+ op, "expected window_dimensions to be [1,x,y,1]");
+ }
+
+ if (!args[0].getType().cast<ShapedType>().getElementType().isF32()) {
+ return rewriter.notifyMatchFailure(op, "expected element type to be f32");
+ }
+
+ Attribute strides;
+ if (op.window_stridesAttr()) {
+ strides = rewriter.getI64VectorAttr(
+ {op.window_strides().getValue().getValue<int64_t>(1),
+ op.window_strides().getValue().getValue<int64_t>(2)});
+ } else {
+ strides = rewriter.getI64VectorAttr({1, 1});
+ }
+ Attribute dilations;
+ if (op.window_dilations()) {
+ dilations = rewriter.getI64VectorAttr(
+ {op.window_dilations().getValue().getValue<int64_t>(1),
+ op.window_dilations().getValue().getValue<int64_t>(2)});
+ } else {
+ dilations = rewriter.getI64VectorAttr({1, 1});
+ }
+ linalg::LinalgOp poolingOp;
+ PoolingType poolingType = getPoolingType(op.body());
+
+ Value initTensor = rewriter.create<linalg::InitTensorOp>(
+ loc, resultType.getShape(), resultType.getElementType());
+ Value initValue = args[1];
+ initValue = rewriter.create<tensor::ExtractOp>(loc, initValue);
+ Value filledInitTensor =
+ rewriter.create<linalg::FillOp>(loc, initTensor, initValue).getResult(0);
+ auto createOp = [&](auto *type_ptr) -> linalg::LinalgOp {
+ return cast<linalg::LinalgOp>(
+ rewriter
+ .create<std::remove_pointer_t<decltype(type_ptr)>>(
+ loc, ArrayRef<Type>{resultType},
+ ValueRange{args[0], fakeWindowDims.getResult()},
+ filledInitTensor, dilations, strides)
+ .getOperation());
+ };
+ switch (poolingType) {
+ case PoolingType::kMin: {
+ poolingOp = createOp(static_cast<linalg::PoolingNHWCMinOp *>(nullptr));
+ break;
+ }
+ case PoolingType::kMax: {
+ poolingOp = createOp(static_cast<linalg::PoolingNHWCMaxOp *>(nullptr));
+ break;
+ }
+ case PoolingType::kAdd: {
+ poolingOp = createOp(static_cast<linalg::PoolingNHWCSumOp *>(nullptr));
+ break;
+ }
+ }
+ rewriter.replaceOp(op, poolingOp->getResult(0));
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// mhlo.conv conversion patterns.
//===----------------------------------------------------------------------===//
@@ -506,13 +627,7 @@
}
ConversionTarget target(getContext());
- // Don't convert the body of reduction ops.
- target.addDynamicallyLegalDialect<mhlo::MhloDialect>(
- Optional<ConversionTarget::DynamicLegalityCallbackFn>(
- [](Operation *op) {
- auto parentOp = op->getParentRegion()->getParentOp();
- return isa<mhlo::ReduceWindowOp>(parentOp);
- }));
+ target.addIllegalDialect<mhlo::MhloDialect>();
// Let the rest fall through.
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
if (useLinalgOnTensorsPath) {
@@ -554,7 +669,8 @@
MLIRContext *context, OwningRewritePatternList &patterns) {
mhlo::populateHLOToLinalgConversionPattern(context, &patterns);
patterns.insert<TorchIndexSelectOpConversion, ConstOpConversion,
- ConcatenateOpConversion, DepthwiseConvOpConversion>(context);
+ ReduceWindowOpConversion, ConcatenateOpConversion,
+ DepthwiseConvOpConversion>(context);
}
static llvm::cl::opt<bool> clUseLinalgOnTensorsPath(
diff --git a/iree/compiler/Conversion/HLOToLinalg/test/linalg_tensor_to_buffer.mlir b/iree/compiler/Conversion/HLOToLinalg/test/linalg_tensor_to_buffer.mlir
index 0d8ab1e..9e9e46b 100644
--- a/iree/compiler/Conversion/HLOToLinalg/test/linalg_tensor_to_buffer.mlir
+++ b/iree/compiler/Conversion/HLOToLinalg/test/linalg_tensor_to_buffer.mlir
@@ -589,3 +589,40 @@
// CHECK: linalg.batch_matmul
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : memref<2x2x3xf32>, memref<2x3x4xf32>)
// CHECK-SAME: outs(%[[RET]] : memref<2x2x4xf32>)
+
+// -----
+
+module {
+ func @reduce_window_sum_nhwc() {
+ %c0 = constant 0 : index
+ %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<1x18x18x64xf32>
+ %1 = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 : tensor<f32>
+ %2 = linalg.init_tensor [3, 3] : tensor<3x3xf32>
+ %3 = linalg.init_tensor [1, 8, 8, 64] : tensor<1x8x8x64xf32>
+ %4 = tensor.extract %1[] : tensor<f32>
+ %5 = linalg.fill(%3, %4) : tensor<1x8x8x64xf32>, f32 -> tensor<1x8x8x64xf32>
+ %6 = linalg.pooling_nhwc_sum {
+ dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
+ ins(%0, %2 : tensor<1x18x18x64xf32>, tensor<3x3xf32>)
+ outs(%5 : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
+ hal.interface.store.tensor %6, @legacy_io::@ret0, offset = %c0 : tensor<1x8x8x64xf32>
+ return
+ }
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+ }
+}
+// CHECK-LABEL: func @reduce_window_sum_nhwc
+// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1x18x18x64xf32>
+// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<f32>
+// CHECK-DAG: %[[RES:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1x8x8x64xf32>
+// CHECK: %[[WINDOW:.+]] = alloc() : memref<3x3xf32>
+// CHECK: %[[INIT:.+]] = load %[[ARG1]][] : memref<f32>
+// CHECK: linalg.fill(%[[RES]], %[[INIT]]) : memref<1x8x8x64xf32>, f32
+// CHECK: linalg.pooling_nhwc_sum
+// CHECK-SAME: {dilations = dense<1> : vector<2xi64>
+// CHECK-SAME: strides = dense<2> : vector<2xi64>}
+// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : memref<1x18x18x64xf32>, memref<3x3xf32>)
+// CHECK-SAME: outs(%[[RES]] : memref<1x8x8x64xf32>)
diff --git a/iree/compiler/Conversion/HLOToLinalg/test/reduce_window.mlir b/iree/compiler/Conversion/HLOToLinalg/test/reduce_window.mlir
index 5d3e6cc..2784636 100644
--- a/iree/compiler/Conversion/HLOToLinalg/test/reduce_window.mlir
+++ b/iree/compiler/Conversion/HLOToLinalg/test/reduce_window.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file -iree-codegen-hlo-to-linalg-on-buffers %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -iree-codegen-hlo-to-linalg-on-tensors %s | IreeFileCheck %s
module {
func @reduce_window_min_nhwc() {
@@ -21,17 +21,17 @@
}
}
// CHECK-LABEL: func @reduce_window_min_nhwc
-// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1x18x18x64xf32>
-// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<f32>
-// CHECK-DAG: %[[RES:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1x8x8x64xf32>
-// CHECK: %[[WINDOW:.+]] = alloc() : memref<3x3xf32>
-// CHECK: %[[INIT:.+]] = load %[[ARG1]][] : memref<f32>
-// CHECK: linalg.fill(%[[RES]], %[[INIT]]) : memref<1x8x8x64xf32>, f32
-// CHECK: linalg.pooling_nhwc_min
+// CHECK-DAG: %[[ARG0:.+]] = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<1x18x18x64xf32>
+// CHECK-DAG: %[[ARG1:.+]] = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 : tensor<f32>
+// CHECK: %[[WINDOW:.+]] = linalg.init_tensor [3, 3] : tensor<3x3xf32>
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 8, 8, 64] : tensor<1x8x8x64xf32>
+// CHECK: %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
+// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[INIT_VAL]]) : tensor<1x8x8x64xf32>, f32 -> tensor<1x8x8x64xf32>
+// CHECK: %[[RES:.+]] = linalg.pooling_nhwc_min
// CHECK-SAME: {dilations = dense<1> : vector<2xi64>
// CHECK-SAME: strides = dense<2> : vector<2xi64>}
-// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : memref<1x18x18x64xf32>, memref<3x3xf32>)
-// CHECK-SAME: outs(%[[RES]] : memref<1x8x8x64xf32>)
+// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : tensor<1x18x18x64xf32>, tensor<3x3xf32>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
// -----
@@ -56,22 +56,22 @@
}
}
// CHECK-LABEL: func @reduce_window_max_nhwc
-// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1x18x18x64xf32>
-// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<f32>
-// CHECK-DAG: %[[RES:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1x8x8x64xf32>
-// CHECK: %[[WINDOW:.+]] = alloc() : memref<3x3xf32>
-// CHECK: %[[INIT:.+]] = load %[[ARG1]][] : memref<f32>
-// CHECK: linalg.fill(%[[RES]], %[[INIT]]) : memref<1x8x8x64xf32>, f32
-// CHECK: linalg.pooling_nhwc_max
+// CHECK-DAG: %[[ARG0:.+]] = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<1x18x18x64xf32>
+// CHECK-DAG: %[[ARG1:.+]] = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 : tensor<f32>
+// CHECK: %[[WINDOW:.+]] = linalg.init_tensor [3, 3] : tensor<3x3xf32>
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 8, 8, 64] : tensor<1x8x8x64xf32>
+// CHECK: %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
+// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[INIT_VAL]]) : tensor<1x8x8x64xf32>, f32 -> tensor<1x8x8x64xf32>
+// CHECK: %[[RES:.+]] = linalg.pooling_nhwc_max
// CHECK-SAME: {dilations = dense<1> : vector<2xi64>
// CHECK-SAME: strides = dense<2> : vector<2xi64>}
-// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : memref<1x18x18x64xf32>, memref<3x3xf32>)
-// CHECK-SAME: outs(%[[RES]] : memref<1x8x8x64xf32>)
+// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : tensor<1x18x18x64xf32>, tensor<3x3xf32>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
// -----
module {
- func @reduce_window_add_nhwc() {
+ func @reduce_window_sum_nhwc() {
%c0 = constant 0 : index
%0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<1x18x18x64xf32>
%1 = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 : tensor<f32>
@@ -90,18 +90,18 @@
hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
}
}
-// CHECK-LABEL: func @reduce_window_add_nhwc
-// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1x18x18x64xf32>
-// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<f32>
-// CHECK-DAG: %[[RES:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1x8x8x64xf32>
-// CHECK: %[[WINDOW:.+]] = alloc() : memref<3x3xf32>
-// CHECK: %[[INIT:.+]] = load %[[ARG1]][] : memref<f32>
-// CHECK: linalg.fill(%[[RES]], %[[INIT]]) : memref<1x8x8x64xf32>, f32
-// CHECK: linalg.pooling_nhwc_sum
+// CHECK-LABEL: func @reduce_window_sum_nhwc
+// CHECK-DAG: %[[ARG0:.+]] = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<1x18x18x64xf32>
+// CHECK-DAG: %[[ARG1:.+]] = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 : tensor<f32>
+// CHECK: %[[WINDOW:.+]] = linalg.init_tensor [3, 3] : tensor<3x3xf32>
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 8, 8, 64] : tensor<1x8x8x64xf32>
+// CHECK: %[[INIT_VAL:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
+// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[INIT_VAL]]) : tensor<1x8x8x64xf32>, f32 -> tensor<1x8x8x64xf32>
+// CHECK: %[[RES:.+]] = linalg.pooling_nhwc_sum
// CHECK-SAME: {dilations = dense<1> : vector<2xi64>
// CHECK-SAME: strides = dense<2> : vector<2xi64>}
-// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : memref<1x18x18x64xf32>, memref<3x3xf32>)
-// CHECK-SAME: outs(%[[RES]] : memref<1x8x8x64xf32>)
+// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : tensor<1x18x18x64xf32>, tensor<3x3xf32>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
// -----
@@ -125,13 +125,14 @@
}
}
// CHECK-LABEL: func @reduce_window_max_nhwc
-// CHECK-DAG: %[[INIT:.+]] = constant 0xFF800000 : f32
-// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1x18x18x64xf32>
-// CHECK-DAG: %[[RES:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1x8x8x64xf32>
-// CHECK: %[[WINDOW:.+]] = alloc() : memref<3x3xf32>
-// CHECK: linalg.fill(%[[RES]], %[[INIT]]) : memref<1x8x8x64xf32>, f32
-// CHECK: linalg.pooling_nhwc_max
+// CHECK-DAG: %[[CST:.+]] = constant dense<0xFF800000> : tensor<f32>
+// CHECK-DAG: %[[ARG0:.+]] = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<1x18x18x64xf32>
+// CHECK: %[[WINDOW:.+]] = linalg.init_tensor [3, 3] : tensor<3x3xf32>
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 8, 8, 64] : tensor<1x8x8x64xf32
+// CHECK: %[[INIT_VAL:.+]] = tensor.extract %[[CST]][] : tensor<f32>
+// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[INIT_VAL]]) : tensor<1x8x8x64xf32>, f32 -> tensor<1x8x8x64xf32>
+// CHECK: %[[RES:.+]] = linalg.pooling_nhwc_max
// CHECK-SAME: {dilations = dense<1> : vector<2xi64>
// CHECK-SAME: strides = dense<2> : vector<2xi64>}
-// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : memref<1x18x18x64xf32>, memref<3x3xf32>)
-// CHECK-SAME: outs(%[[RES]] : memref<1x8x8x64xf32>)
+// CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : tensor<1x18x18x64xf32>, tensor<3x3xf32>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>