[Codegen] Allow memref type propagation through collapse_shape (#19400)
This PR adds support for propagating memref type changes through
memref.collapse_shape ops in the `replaceMemrefUsesAndPropagateType`
util function. This propagation is used in allocation padding, since the
strides of the memref type change after padding.
Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/pad_dynamic_alloc.mlir b/compiler/src/iree/compiler/Codegen/Common/test/pad_dynamic_alloc.mlir
index e9d4d7b..4c08ebd 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/pad_dynamic_alloc.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/pad_dynamic_alloc.mlir
@@ -48,3 +48,22 @@
}
// CHECK-LABEL: func @dynamic_bound_alloca(
// CHECK: memref.alloca() : memref<4088xf32, 3>
+
+// -----
+
+func.func @dynamic_alloc_collapse_consumer(%id : index) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = util.assume.int %id<umin = 0, umax = 32> : index
+ %1 = memref.alloc(%0, %0) : memref<?x?xf32, 3>
+ %2 = memref.collapse_shape %1 [[0, 1]] : memref<?x?xf32, 3> into memref<?xf32, 3>
+ memref.store %cst, %2[%c0] : memref<?xf32, 3>
+ return
+}
+// CHECK-LABEL: func @dynamic_alloc_collapse_consumer(
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<32x32xf32, 3>
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]]
+// CHECK-SAME: [0, 0] [{{.*}}] [1, 1] : memref<32x32xf32, 3> to memref<?x?xf32, strided<[32, 1]>, 3>
+// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SUBVIEW]] {{\[}}[0, 1]]
+// CHECK-SAME: : memref<?x?xf32, strided<[32, 1]>, 3> into memref<?xf32, strided<[?]>, 3>
+// CHECK: memref.store {{.*}} %[[COLLAPSE]]{{.*}} : memref<?xf32, strided<[?]>, 3>
diff --git a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
index 37d061a..ddc0b9a 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
@@ -1342,6 +1342,24 @@
});
return llvm::to_vector_of<Value>(newExpandOp->getResults());
}
+ if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(user)) {
+ auto newSourceType = llvm::cast<MemRefType>(replacement.getType());
+ FailureOr<MemRefType> newResultType =
+ memref::CollapseShapeOp::computeCollapsedType(
+ newSourceType, collapseOp.getReassociationIndices());
+ if (failed(newResultType)) {
+ return std::nullopt;
+ }
+
+ auto newCollapseOp = rewriter.create<memref::CollapseShapeOp>(
+ loc, *newResultType, replacement, collapseOp.getReassociation());
+ LLVM_DEBUG({
+ llvm::dbgs() << "\t\tNew user : ";
+ newCollapseOp->print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
+ llvm::dbgs() << "\n";
+ });
+ return llvm::to_vector_of<Value>(newCollapseOp->getResults());
+ }
return std::nullopt;
}