[bufferize] Check that the memref types are the same for subspans (#9652)
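Subspans with the same set, binding, type, byte offset, dynamic dims, and alignment are still different when their result types are different, so an existing buffer subspan is reused only if its memref type also matches.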
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir b/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir
index 2bb7438..ea984b8 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir
@@ -2325,3 +2325,39 @@
// CHECK: iree_linalg_ext.topk
// CHECK-SAME: ins(%[[INPUT_VALUES]], %[[INPUT_INDICES]]
// CHECK-SAME: outs(%[[OUTPUT_VALUES]], %[[OUTPUT_INDICES]]
+
+// -----
+module {
+ func.func @reduction_ew() {
+ %c5120 = arith.constant 5120 : index
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %cst_0 = arith.constant 1.000000e+00 : f32
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c5120) alignment(64) : !flow.dispatch.tensor<readonly:1001xf32>
+ %1 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c5120) alignment(64) : !flow.dispatch.tensor<readonly:1x1001xf32>
+ %2 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:1x1001xf32>
+ %3 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [1, 1001], strides = [1, 1] : !flow.dispatch.tensor<writeonly:1x1001xf32> -> tensor<1x1001xf32>
+ %4 = flow.dispatch.tensor.load %0, offsets = [0], sizes = [1001], strides = [1] : !flow.dispatch.tensor<readonly:1001xf32> -> tensor<1001xf32>
+ %5 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [1, 1001], strides = [1, 1] : !flow.dispatch.tensor<readonly:1x1001xf32> -> tensor<1x1001xf32>
+ %6 = bufferization.alloc_tensor() : tensor<f32>
+ %7 = linalg.fill ins(%cst : f32) outs(%6 : tensor<f32>) -> tensor<f32>
+ %8 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>], iterator_types = ["reduction"]} ins(%4 : tensor<1001xf32>) outs(%7 : tensor<f32>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[0]]>} {
+ ^bb0(%arg0: f32, %arg1: f32):
+ %10 = arith.addf %arg0, %arg1 : f32
+ linalg.yield %10 : f32
+ } -> tensor<f32>
+ %9 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%5, %8 : tensor<1x1001xf32>, tensor<f32>) outs(%3 : tensor<1x1001xf32>) {
+ ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
+ %10 = arith.divf %cst_0, %arg1 : f32
+ %11 = arith.mulf %arg0, %10 : f32
+ linalg.yield %11 : f32
+ } -> tensor<1x1001xf32>
+ flow.dispatch.tensor.store %9, %2, offsets = [0, 0], sizes = [1, 1001], strides = [1, 1] : tensor<1x1001xf32> -> !flow.dispatch.tensor<writeonly:1x1001xf32>
+ return
+ }
+}
+
+// CHECK: func.func @reduction_ew
+// CHECK: hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c5120) alignment(64) : memref<1001xf32>
+// CHECK: hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c5120) alignment(64) : memref<1x1001xf32>
+// CHECK: hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : memref<1x1001xf32>
diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp b/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp
index 706948b..cab413a 100644
--- a/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp
+++ b/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp
@@ -60,19 +60,27 @@
.dyn_cast<IREE::Flow::DispatchTensorType>();
assert(shapedType && shapedType.hasRank());
+ auto memRefType = getMemrefTypeForTensor(shapedType);
+
// Look for an existing op.
Block *block = subspanOp->getBlock();
for (Operation &op : *block) {
if (&op == subspanOp.getOperation()) break;
auto bufferSubspanOp = dyn_cast<IREE::HAL::InterfaceBindingSubspanOp>(&op);
if (!bufferSubspanOp) continue;
+
+ auto bufferMemrefType =
+ bufferSubspanOp.getResult().getType().dyn_cast<MemRefType>();
+ if (!bufferMemrefType) continue;
+
if (bufferSubspanOp.set() != subspanOp.set() ||
bufferSubspanOp.binding() != subspanOp.binding() ||
bufferSubspanOp.type() != subspanOp.type() ||
bufferSubspanOp.byte_offset() != subspanOp.byte_offset() ||
!llvm::equal(bufferSubspanOp.dynamic_dims(),
subspanOp.dynamic_dims()) ||
- bufferSubspanOp.alignment() != subspanOp.alignment())
+ bufferSubspanOp.alignment() != subspanOp.alignment() ||
+ memRefType != bufferMemrefType)
continue;
return bufferSubspanOp.getResult();
}
@@ -81,7 +89,6 @@
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(subspanOp);
// Just change the result type of the InterfaceBindingSubspanOp.
- auto memRefType = getMemrefTypeForTensor(shapedType);
Value buffer = b.create<IREE::HAL::InterfaceBindingSubspanOp>(
subspanOp->getLoc(), memRefType, subspanOp.set(), subspanOp.binding(),
subspanOp.type(), subspanOp.byte_offset(), subspanOp.dynamic_dims(),