Fix handling of no-op subviews for rank-reducing cases. (#8449)
The current check to avoid generating no-op subviews would also skip
the subview when the rank of the result was less than the rank of the
source. A rank-reducing subview is not a no-op and still needs to be
generated.
Also add some lit tests.
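
For illustration, an MLIR snippet in the spirit of the new rank-reducing
lit test (`%src` and `%d0` are stand-ins for an arbitrary source memref
and dynamic size): the offsets, sizes, and strides cover the whole
source, so the old check treated the subview as a no-op, but the result
drops the unit dimension and therefore has a different type than the
source:

    // Covers the full source, yet rank-reduces 1x? to ?, so it must
    // still be materialized as a memref.subview.
    %0 = memref.subview %src[0, 0] [1, %d0] [1, 1]
        : memref<1x?xf32> to memref<?xf32>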
diff --git a/iree/compiler/Codegen/Common/LinalgBufferizePass.cpp b/iree/compiler/Codegen/Common/LinalgBufferizePass.cpp
index 9eed987..bcb1b79 100644
--- a/iree/compiler/Codegen/Common/LinalgBufferizePass.cpp
+++ b/iree/compiler/Codegen/Common/LinalgBufferizePass.cpp
@@ -159,6 +159,10 @@
/// that implements the `ShapeAwareOpInterface` (like
/// `hal.interface.binding.subspan`) then we can use that to check dynamic
/// equality.
+/// Note: This could be written as a canonicalizer, but the subview formed
+/// when dynamic shapes are involved will have affine maps that shouldn't
+/// be there. Resolving that is a pain, so don't generate the subview to
+/// begin with.
static bool generatesNoOpSubView(Value src, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides) {
@@ -223,11 +227,12 @@
Value src, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides) {
- if (generatesNoOpSubView(src, offsets, sizes, strides)) {
+ MemRefType srcType = src.getType().cast<MemRefType>();
+ if (srcType.getRank() == resultRank &&
+ generatesNoOpSubView(src, offsets, sizes, strides)) {
return src;
}
MemRefType resultType;
- MemRefType srcType = src.getType().cast<MemRefType>();
if (srcType.getRank() != resultRank) {
resultType = memref::SubViewOp::inferRankReducedResultType(
resultRank, srcType, offsets, sizes, strides)
diff --git a/iree/compiler/Codegen/Common/test/linalg_bufferize.mlir b/iree/compiler/Codegen/Common/test/linalg_bufferize.mlir
index 130f705..ae53a94 100644
--- a/iree/compiler/Codegen/Common/test/linalg_bufferize.mlir
+++ b/iree/compiler/Codegen/Common/test/linalg_bufferize.mlir
@@ -2654,3 +2654,47 @@
flow.dispatch.tensor.store %5, %2, offsets = [0], sizes = [4], strides = [1] : tensor<?xf32> -> !flow.dispatch.tensor<writeonly:4xf32>
return
}
+
+// -----
+
+func @no_op_subview() {
+ %c0 = arith.constant 0 : index
+ %d0 = hal.interface.constant.load[0] : index
+ %d1 = hal.interface.constant.load[1] : index
+ %src_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:?x?xf32>{%d0, %d1}
+ %dest_binding = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<writeonly:?x?xf32>{%d0, %d1}
+ %src = flow.dispatch.tensor.load %src_binding, offsets = [0, 0], sizes = [%d0, %d1], strides = [1, 1]
+ : !flow.dispatch.tensor<readonly:?x?xf32>{%d0, %d1} -> tensor<?x?xf32>
+ %slice = tensor.extract_slice %src[0, 0] [%d0, %d1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+ flow.dispatch.tensor.store %slice, %dest_binding, offsets = [0, 0], sizes = [%d0, %d1], strides = [1, 1]
+ : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:?x?xf32>{%d0, %d1}
+ return
+}
+// CHECK-LABEL: func @no_op_subview()
+// CHECK-DAG: %[[SRC:.+]] = hal.interface.binding.subspan set(0) binding(0)
+// CHECK-DAG: %[[DEST:.+]] = hal.interface.binding.subspan set(0) binding(1)
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%[[SRC]] :
+// CHECK-SAME: outs(%[[DEST]] :
+
+// -----
+
+func @rank_reducing_no_op_subview() {
+ %c0 = arith.constant 0 : index
+ %d0 = hal.interface.constant.load[0] : index
+ %src_binding = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<readonly:1x?xf32>{%d0}
+ %dest_binding = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(32) : !flow.dispatch.tensor<writeonly:?xf32>{%d0}
+ %src = flow.dispatch.tensor.load %src_binding, offsets = [0, 0], sizes = [1, %d0], strides = [1, 1]
+ : !flow.dispatch.tensor<readonly:1x?xf32>{%d0} -> tensor<1x?xf32>
+ %slice = tensor.extract_slice %src[0, 0] [1, %d0] [1, 1] : tensor<1x?xf32> to tensor<?xf32>
+ flow.dispatch.tensor.store %slice, %dest_binding, offsets = [0], sizes = [%d0], strides = [1]
+ : tensor<?xf32> -> !flow.dispatch.tensor<writeonly:?xf32>{%d0}
+ return
+}
+// CHECK-LABEL: func @rank_reducing_no_op_subview()
+// CHECK-DAG: %[[SRC:.+]] = hal.interface.binding.subspan set(0) binding(0)
+// CHECK-DAG: %[[DEST:.+]] = hal.interface.binding.subspan set(0) binding(1)
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][0, 0] [1, %{{.+}}]
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%[[SUBVIEW]] :
+// CHECK-SAME: outs(%[[DEST]] :