[Codegen] Add support for memref.expand_shape to propagation util (#18202)

Similar to `memref.subview`, `memref.expand_shape` needs to have its
type updated when propagating type changes. This adds support for expand
shape to the propagation util so that passes like GPUReduceBankConflicts
can handle `memref.expand_shape`.
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/reduce_bank_conflicts.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/reduce_bank_conflicts.mlir
index 1e9d647..befb244 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/reduce_bank_conflicts.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/reduce_bank_conflicts.mlir
@@ -25,6 +25,30 @@
 
 // -----
 
+// CHECK-LABEL: func.func @pad_alloc_expand_shape
+// CHECK:         %[[A:.*]] = memref.alloc() : memref<4x32x66xf32, #gpu.address_space<workgroup>>
+// CHECK:         %[[S1:.*]] = memref.subview %[[A]][0, 0, 0] [4, 32, 64] [1, 1, 1] :
+// CHECK-SAME:      memref<4x32x66xf32, #gpu.address_space<workgroup>> to memref<4x32x64xf32, strided<[2112, 66, 1]>, #gpu.address_space<workgroup>>
+// CHECK:         %[[E:.*]] = memref.expand_shape %[[S1]] {{\[}}[0], [1, 2], [3, 4]] output_shape [4, 2, 16, 8, 8]
+// CHECK-SAME:      memref<4x32x64xf32, strided<[2112, 66, 1]>, #gpu.address_space<workgroup>> into
+// CHECK-SAME:      memref<4x2x16x8x8xf32, strided<[2112, 1056, 66, 8, 1]>, #gpu.address_space<workgroup>>
+// CHECK:           vector.transfer_write %{{.*}}, %[[E]][%{{.*}}, %{{.*}}, %{{.*}}] {in_bounds = [true]} :
+// CHECK-SAME:      vector<4xf32>, memref<4x2x16x8x8xf32, strided<[2112, 1056, 66, 8, 1]>, #gpu.address_space<workgroup>
+func.func @pad_alloc_expand_shape(%a: memref<1024x1024xf32>) {
+  %0 = memref.alloc() : memref<4x32x64xf32, #gpu.address_space<workgroup>>
+  %1 = memref.expand_shape %0 [[0], [1, 2], [3, 4]] output_shape [4, 2, 16, 8, 8]
+    : memref<4x32x64xf32, #gpu.address_space<workgroup>> into memref<4x2x16x8x8xf32, #gpu.address_space<workgroup>>
+  %c0 = arith.constant 0 : index
+  %cst_0 = arith.constant 0.000000e+00 : f32
+  %3 = vector.transfer_read %a[%c0, %c0], %cst_0 {in_bounds = [true]} :
+    memref<1024x1024xf32>, vector<4xf32>
+  vector.transfer_write %3, %1[%c0, %c0, %c0, %c0, %c0] {in_bounds = [true]} :
+    vector<4xf32>, memref<4x2x16x8x8xf32, #gpu.address_space<workgroup>>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func.func @pad_alloc_negative
 // CHECK:         memref.alloc(%{{.*}}) : memref<?x32x64xf32, #gpu.address_space<workgroup>
 func.func @pad_alloc_negative(%a: memref<1024x1024xf32>, %i: index, %v: vector<4xf32>) {
diff --git a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
index 18ba760..2812fc4 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
@@ -988,8 +988,30 @@
       newSubviewOp->print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
       llvm::dbgs() << "\n";
     });
-    return SmallVector<Value>(newSubviewOp->result_begin(),
-                              newSubviewOp->result_end());
+    return llvm::to_vector_of<Value>(newSubviewOp->getResults());
+  }
+  if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(user)) {
+    auto currResultType =
+        llvm::cast<MemRefType>(expandOp.getResult().getType());
+    auto newSourceType = llvm::cast<MemRefType>(replacement.getType());
+
+    FailureOr<MemRefType> newResultType =
+        memref::ExpandShapeOp::computeExpandedType(
+            newSourceType, currResultType.getShape(),
+            expandOp.getReassociationIndices());
+    if (failed(newResultType)) {
+      return std::nullopt;
+    }
+
+    auto newExpandOp = rewriter.create<memref::ExpandShapeOp>(
+        loc, *newResultType, replacement, expandOp.getReassociation(),
+        expandOp.getOutputShape(), expandOp.getStaticOutputShape());
+    LLVM_DEBUG({
+      llvm::dbgs() << "\t\tNew user : ";
+      newExpandOp->print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
+      llvm::dbgs() << "\n";
+    });
+    return llvm::to_vector_of<Value>(newExpandOp->getResults());
   }
   return std::nullopt;
 }