[Codegen] Allow memref type propagation through collapse_shape (#19400)

This PR adds support for propagating memref type changes through
memref.collapse_shape ops in the `replaceMemrefUsesAndPropagateType`
util function. This propagation is used in allocation padding, since the
strides of the memref type change after padding.

Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/pad_dynamic_alloc.mlir b/compiler/src/iree/compiler/Codegen/Common/test/pad_dynamic_alloc.mlir
index e9d4d7b..4c08ebd 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/pad_dynamic_alloc.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/pad_dynamic_alloc.mlir
@@ -48,3 +48,22 @@
 }
 // CHECK-LABEL: func @dynamic_bound_alloca(
 //       CHECK:   memref.alloca() : memref<4088xf32, 3>
+
+// -----
+
+func.func @dynamic_alloc_collapse_consumer(%id : index) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = util.assume.int %id<umin = 0, umax =  32> : index
+  %1 = memref.alloc(%0, %0) : memref<?x?xf32, 3>
+  %2 = memref.collapse_shape %1 [[0, 1]] : memref<?x?xf32, 3> into memref<?xf32, 3>
+  memref.store %cst, %2[%c0] : memref<?xf32, 3>
+  return
+}
+// CHECK-LABEL: func @dynamic_alloc_collapse_consumer(
+//       CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<32x32xf32, 3>
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]]
+//  CHECK-SAME:     [0, 0] [{{.*}}] [1, 1] : memref<32x32xf32, 3> to memref<?x?xf32, strided<[32, 1]>, 3>
+//       CHECK:   %[[COLLAPSE:.+]] = memref.collapse_shape %[[SUBVIEW]] {{\[}}[0, 1]]
+//  CHECK-SAME:     : memref<?x?xf32, strided<[32, 1]>, 3> into memref<?xf32, strided<[?]>, 3>
+//       CHECK:   memref.store {{.*}} %[[COLLAPSE]]{{.*}} : memref<?xf32, strided<[?]>, 3>
diff --git a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
index 37d061a..ddc0b9a 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
@@ -1342,6 +1342,24 @@
     });
     return llvm::to_vector_of<Value>(newExpandOp->getResults());
   }
+  if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(user)) {
+    auto newSourceType = llvm::cast<MemRefType>(replacement.getType());
+    FailureOr<MemRefType> newResultType =
+        memref::CollapseShapeOp::computeCollapsedType(
+            newSourceType, collapseOp.getReassociationIndices());
+    if (failed(newResultType)) {
+      return std::nullopt;
+    }
+
+    auto newCollapseOp = rewriter.create<memref::CollapseShapeOp>(
+        loc, *newResultType, replacement, collapseOp.getReassociation());
+    LLVM_DEBUG({
+      llvm::dbgs() << "\t\tNew user : ";
+      newCollapseOp->print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
+      llvm::dbgs() << "\n";
+    });
+    return llvm::to_vector_of<Value>(newCollapseOp->getResults());
+  }
   return std::nullopt;
 }