Cherry-pick and use subview folding pattern (#11582)

This cherry-picks the upstream memref alias folding patterns (via the
llvm-project submodule bump) and exposes them as a fold_memref_aliases
option on transform.iree.apply_patterns. Folding memref.subview ops
into their users removes a self-copy in reduction_v3 and cleans up the
IR in general.
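
In IR terms, this is a minimal sketch of the fold (the function, value
names, and shapes are hypothetical, not taken from the tests below): a
rank-reducing memref.subview feeding a vector transfer is folded away,
and the subview offsets reappear as indices on the base allocation.

  func.func @fold_subview_sketch(%v: vector<f32>) {
    %tidy = gpu.thread_id y
    %tidz = gpu.thread_id z
    %alloc = memref.alloc() : memref<2x32xf32, 3>
    // Before folding: the write goes through a rank-reduced subview.
    %view = memref.subview %alloc[%tidz, %tidy] [1, 1] [1, 1]
        : memref<2x32xf32, 3> to memref<f32, strided<[], offset: ?>, 3>
    vector.transfer_write %v, %view[]
        : vector<f32>, memref<f32, strided<[], offset: ?>, 3>
    // After folding: the subview is gone and the same store is a
    // direct, indexed write on the allocation.
    vector.transfer_write %v, %alloc[%tidz, %tidy]
        : vector<f32>, memref<2x32xf32, 3>
    return
  }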
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategies.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategies.cpp
index 2d6abb8..ec3eee2 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategies.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategies.cpp
@@ -254,6 +254,7 @@
                                                   Value variantH, Value funcH,
                                                   int64_t warpSize) {
   ApplyPatternsOpPatterns patterns;
+  patterns.foldMemrefAliases = true;
   patterns.rankReducing = true;
   funcH = b.create<ApplyPatternsOp>(funcH, patterns);
   Value ifH = b.create<MatchOp>(funcH, scf::IfOp::getOperationName());
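
With foldMemrefAliases set above, the strategy builder now emits
transform IR along these lines (a sketch with illustrative handle
names; the exact sequence is what set_transform_strategy.mlir checks
below):

  %func_2 = transform.iree.apply_patterns %func { fold_memref_aliases, rank_reducing }
  %if_op = transform.structured.match ops{["scf.if"]} in %variant_op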
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index c4ce5e2..dc2da5f 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -70,6 +70,7 @@
   ADD_PATTERN(canonicalization, getCanonicalizationAttrName)
   ADD_PATTERN(eraseUnnecessaryTensorOperands,
               getEraseUnnecessaryTensorOperandsAttrName)
+  ADD_PATTERN(foldMemrefAliases, getFoldMemrefAliasesAttrName)
   ADD_PATTERN(foldReassociativeReshapes, getFoldReassociativeReshapesAttrName)
   ADD_PATTERN(promoteForeachThreadCaptureToShared,
               getPromoteForeachThreadCaptureToSharedAttrName)
@@ -162,6 +163,10 @@
 };
 }  // namespace
 
+static void addFoldMemrefAliasPatterns(RewritePatternSet &patterns) {
+  memref::populateFoldMemRefAliasOpPatterns(patterns);
+}
+
 static void addForeachThreadCapturePromotionPatterns(
     RewritePatternSet &patterns) {
   patterns.add<PromoteCaptureToSharedOut>(patterns.getContext());
@@ -224,6 +229,7 @@
   if (getCanonicalization()) addAllRegisteredCanonicalizationPatterns(patterns);
   if (getEraseUnnecessaryTensorOperands())
     addEraseUnnecessaryTensorOperandsPatterns(patterns);
+  if (getFoldMemrefAliases()) addFoldMemrefAliasPatterns(patterns);
   if (getFoldReassociativeReshapes()) addReassociativeReshapePatterns(patterns);
   if (getPromoteForeachThreadCaptureToShared())
     addForeachThreadCapturePromotionPatterns(patterns);
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h
index b6747bf..9e671f2 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h
@@ -39,6 +39,7 @@
   bool bubbleCollapseExpand = false;
   bool canonicalization = false;
   bool eraseUnnecessaryTensorOperands = false;
+  bool foldMemrefAliases = false;
   bool foldReassociativeReshapes = false;
   bool promoteForeachThreadCaptureToShared = false;
   bool rankReducing = false;
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
index b5edbc9..0bc6db3 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
@@ -40,6 +40,8 @@
       registered dialects and ops.
       - erase_unnecessary_tensor_operands: add patterns that erase unnecessary
       tensor operands.
+      - fold_memref_aliases: adds patterns for folding ops such as
+      memref.subview.
       - fold_reassociative_reshapes: adds patterns that fold insert_slice/
       extract_slice ops with reassociative reshape ops.
       - promote_foreach_thread_capture_to_shared: adds patterns that rewrite
@@ -80,6 +82,7 @@
                        UnitAttr:$bubble_collapse_expand,
                        UnitAttr:$canonicalization,
                        UnitAttr:$erase_unnecessary_tensor_operands,
+                       UnitAttr:$fold_memref_aliases,
                        UnitAttr:$fold_reassociative_reshapes,
                        UnitAttr:$promote_foreach_thread_capture_to_shared,
                        UnitAttr:$rank_reducing,
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_transform.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_transform.mlir
index 8b55d88..7605753 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_transform.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_transform.mlir
@@ -48,7 +48,7 @@
 
 // Distributed reduction: everyone loads then 5 xor + addf expected.
 //         CHECK: %[[TIDY:.]] = gpu.thread_id  y
-//         CHECK: vector.transfer_read %{{.*}}[]
+//         CHECK: vector.transfer_read %{{.*}}[%{{.*}}]
 //         CHECK: vector.transfer_read %{{.*}}[%[[TIDY]], %[[TIDX]]]
 // CHECK-COUNT-5: gpu.shuffle  xor{{.*}}{{[[:space:]].*}}{{.*}} arith.addf
 
@@ -115,7 +115,7 @@
 
 // Distributed reduction: everyone loads then 5 xor + addf expected.
 //         CHECK: %[[TIDY:.]] = gpu.thread_id  y
-//         CHECK: vector.transfer_read %{{.*}}[]
+//         CHECK: vector.transfer_read %{{.*}}[%{{.*}}]
 //         CHECK: vector.transfer_read %{{.*}}[%[[TIDY]], %[[TIDX]]]
 // CHECK-COUNT-5: gpu.shuffle  xor{{.*}}{{[[:space:]].*}}{{.*}} arith.addf
 
@@ -183,7 +183,7 @@
 
 // Distributed reduction: everyone loads then 5 xor + addf expected.
 //         CHECK: %[[TIDY:.]] = gpu.thread_id  y
-//         CHECK: vector.transfer_read %{{.*}}[]
+//         CHECK: vector.transfer_read %{{.*}}[%{{.*}}]
 //         CHECK: vector.transfer_read %{{.*}}[%[[TIDY]], %[[TIDX]]]
 // CHECK-COUNT-5: gpu.shuffle  xor{{.*}}{{[[:space:]].*}}{{.*}} arith.addf
 
@@ -254,7 +254,7 @@
 
 // Distributed reduction: everyone loads then 5 xor + addf expected.
 //         CHECK: %[[TIDY:.]] = gpu.thread_id  y
-//         CHECK: vector.transfer_read %{{.*}}[]
+//         CHECK: vector.transfer_read %{{.*}}[%{{.*}}]
 //         CHECK: vector.transfer_read %{{.*}}[%[[TIDY]], %[[TIDX]]]
 // CHECK-COUNT-5: gpu.shuffle  xor{{.*}}{{[[:space:]].*}}{{.*}} arith.addf
 
@@ -318,7 +318,7 @@
 
 // Distributed reduction: everyone loads then 5 xor + addf expected.
 //         CHECK: %[[TIDY:.]] = gpu.thread_id  y
-//         CHECK: vector.transfer_read %{{.*}}[]
+//         CHECK: vector.transfer_read %{{.*}}[%{{.*}}]
 //         CHECK: %[[IDX:.*]] = affine.apply{{.*}}%[[TIDX]]
 //         CHECK: vector.transfer_read %{{.*}}[%[[TIDY]], %[[IDX]]]
 // CHECK-COUNT-5: gpu.shuffle  xor{{.*}}{{[[:space:]].*}}{{.*}} arith.addf
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy.mlir
index baf04a3..172b805 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy.mlir
@@ -48,7 +48,7 @@
 //         CHECK:   transform.structured.match ops{["func.func"]} in %{{.*}}
 //         CHECK:   transform.iree.foreach_thread_to_workgroup
 //         CHECK:   transform.iree.map_nested_foreach_thread_to_gpu_threads %{{.*}} {workgroup_size = [32, 1, 1]}
-//         CHECK:   transform.iree.apply_patterns %{{.*}} {rank_reducing}
+//         CHECK:   transform.iree.apply_patterns %{{.*}} {fold_memref_aliases, rank_reducing}
 //         CHECK:   transform.structured.match ops{["scf.if"]} in %{{.*}}
 //         CHECK:   sequence {{.*}} failures(suppress) {
 //         CHECK:     transform.iree.vector.to_warp_execute_on_lane_0 %{{.*}} {warp_size = 32 : i64}
diff --git a/tests/transform_dialect/cuda/eltwise_reduction_eltwise_codegen_spec.mlir b/tests/transform_dialect/cuda/eltwise_reduction_eltwise_codegen_spec.mlir
index e4c4358..947a721 100644
--- a/tests/transform_dialect/cuda/eltwise_reduction_eltwise_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/eltwise_reduction_eltwise_codegen_spec.mlir
@@ -85,7 +85,7 @@
 
   // Step 7. Post-bufferization vector distribution with rank-reduction.
   // ===========================================================================
-  %func_7 = transform.iree.apply_patterns %func_6 { rank_reducing }
+  %func_7 = transform.iree.apply_patterns %func_6 { rank_reducing, fold_memref_aliases }
   %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_3
   // Don't complain about unsupported if (threadIdx.x == 0 && threadIdx.y == 0)
   // at this point.
diff --git a/tests/transform_dialect/cuda/reduction.mlir b/tests/transform_dialect/cuda/reduction.mlir
index 29fc449..c45f5c8 100644
--- a/tests/transform_dialect/cuda/reduction.mlir
+++ b/tests/transform_dialect/cuda/reduction.mlir
@@ -50,7 +50,6 @@
   //     CHECK-DAG: %[[TIDY:.]] = gpu.thread_id  y
   //     CHECK-DAG: %[[TIDZ:.]] = gpu.thread_id  z
 
-  //         CHECK: %[[SHMEM_VIEW_EXPANDED:.*]] = memref.subview %[[SHMEM_ALLOC]][%[[TIDZ]], %[[TIDY]]]{{.*}}to memref<f32, {{.*}}, 3>
   //         CHECK: %[[ADDED:.*]] = arith.addi %[[TIDZ]], %[[workgroup_id_x]]
 
   // Distributed reduction: everyone loads then 5 xor + addf expected
@@ -62,7 +61,7 @@
   //         CHECK: %[[RES_VEC:.*]] = vector.broadcast %[[RES]] : f32 to vector<f32>
   //         CHECK: %[[CONDXIS0:.*]] = arith.cmpi eq, %[[TIDX]], %[[C0]] : index
   //         CHECK: scf.if %[[CONDXIS0]]
-  //         CHECK:   vector.transfer_write %[[RES_VEC]], %[[SHMEM_VIEW_EXPANDED]][]
+  //         CHECK:   vector.transfer_write %[[RES_VEC]], %[[SHMEM_ALLOC]][%[[TIDZ]], %[[TIDY]]]
   //         CHECK: gpu.barrier
 
  // Last part is not distributed atm and is only run by threadIdx.x == 0 and threadIdx.y == 0.
diff --git a/tests/transform_dialect/cuda/reduction_codegen_spec.mlir b/tests/transform_dialect/cuda/reduction_codegen_spec.mlir
index 6649bc9..927753c 100644
--- a/tests/transform_dialect/cuda/reduction_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/reduction_codegen_spec.mlir
@@ -58,7 +58,7 @@
 
   // Step 7. Post-bufferization vector distribution with rank-reduction.
   // ===========================================================================
-  %func_8 = transform.iree.apply_patterns %func_7 { rank_reducing }
+  %func_8 = transform.iree.apply_patterns %func_7 { rank_reducing, fold_memref_aliases }
   %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_3
   // Don't complain about unsupported if (threadIdx.x == 0 && threadIdx.y == 0)
   // at this point.
diff --git a/tests/transform_dialect/cuda/reduction_eltwise.mlir b/tests/transform_dialect/cuda/reduction_eltwise.mlir
index 571f8fd..0a32c50 100644
--- a/tests/transform_dialect/cuda/reduction_eltwise.mlir
+++ b/tests/transform_dialect/cuda/reduction_eltwise.mlir
@@ -58,7 +58,6 @@
   //     CHECK-DAG: %[[TIDY:.]] = gpu.thread_id  y
   //     CHECK-DAG: %[[TIDZ:.]] = gpu.thread_id  z
 
-  //         CHECK: %[[SHMEM_VIEW_EXPANDED:.*]] = memref.subview %[[SHMEM_ALLOC]][%[[TIDZ]], %[[TIDY]]]{{.*}}to memref<f32, {{.*}}, 3>
   //         CHECK: %[[ADDED:.*]] = arith.addi %[[TIDZ]], %[[workgroup_id_x]]
 
   // Distributed reduction: everyone loads then 5 xor + addf expected
@@ -70,7 +69,7 @@
   //         CHECK: %[[RES_VEC:.*]] = vector.broadcast %[[RES]] : f32 to vector<f32>
   //         CHECK: %[[CONDXIS0:.*]] = arith.cmpi eq, %[[TIDX]], %[[C0]] : index
   //         CHECK: scf.if %[[CONDXIS0]]
-  //         CHECK:   vector.transfer_write %[[RES_VEC]], %[[SHMEM_VIEW_EXPANDED]][]
+  //         CHECK:   vector.transfer_write %[[RES_VEC]], %[[SHMEM_ALLOC]][%[[TIDZ]], %[[TIDY]]]
   //         CHECK: gpu.barrier
 
  // Last part is not distributed atm and is only run by threadIdx.x == 0 and threadIdx.y == 0.
diff --git a/tests/transform_dialect/cuda/reduction_eltwise_codegen_spec.mlir b/tests/transform_dialect/cuda/reduction_eltwise_codegen_spec.mlir
index 5afa0a8..61dbaaa 100644
--- a/tests/transform_dialect/cuda/reduction_eltwise_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/reduction_eltwise_codegen_spec.mlir
@@ -62,7 +62,7 @@
 
   // Step 7. Post-bufferization vector distribution with rank-reduction.
   // ===========================================================================
-  %func_8 = transform.iree.apply_patterns %func_7 { rank_reducing }
+  %func_8 = transform.iree.apply_patterns %func_7 { rank_reducing, fold_memref_aliases }
   %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_3
   // Don't complain about unsupported if (threadIdx.x == 0 && threadIdx.y == 0)
   // at this point.
diff --git a/tests/transform_dialect/cuda/reduction_v2.mlir b/tests/transform_dialect/cuda/reduction_v2.mlir
index 6372a66..931ddf5 100644
--- a/tests/transform_dialect/cuda/reduction_v2.mlir
+++ b/tests/transform_dialect/cuda/reduction_v2.mlir
@@ -46,14 +46,13 @@
   
   //         CHECK: %[[TIDX:.]] = gpu.thread_id  x
   //         CHECK: %[[IDX:.*]] = affine.apply{{.*}}%[[TIDX]]
-  //         CHECK: %[[SHMEM_VIEW_EXPANDED:.*]] = memref.subview %[[SHMEM_ALLOC]][0, %[[IDX]]]{{.*}}to memref<4xf32, strided<[1], offset: ?>, 3>
   //         CHECK: gpu.barrier
   // Local per-thread scf.for-based reduction.
   //         CHECK: scf.for
   //         CHECK:   vector.transfer_read 
-  //         CHECK:   vector.transfer_read 
+  //         CHECK:   vector.transfer_read %[[SHMEM_ALLOC]][%[[C0]], %[[IDX]]]
   //         CHECK:   arith.addf %{{.*}}, %{{.*}} : vector<4xf32>
-  //         CHECK:   vector.transfer_write
+  //         CHECK:   vector.transfer_write %{{.*}}, %[[SHMEM_ALLOC]][%[[C0]], %[[IDX]]]
  // TODO: remove unnecessary barrier within the loop
   //         CHECK:   gpu.barrier
 
diff --git a/tests/transform_dialect/cuda/reduction_v2_codegen_spec.mlir b/tests/transform_dialect/cuda/reduction_v2_codegen_spec.mlir
index 90c27c8..c1538af 100644
--- a/tests/transform_dialect/cuda/reduction_v2_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/reduction_v2_codegen_spec.mlir
@@ -57,7 +57,7 @@
 
   // Step 7. Post-bufferization vector distribution with rank-reduction.
   // ===========================================================================
-  %func_10 = transform.iree.apply_patterns %func_9 { rank_reducing }
+  %func_10 = transform.iree.apply_patterns %func_9 { rank_reducing, fold_memref_aliases }
   %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_3
   %warp = transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 }
   transform.iree.vector.warp_distribute %func_10
diff --git a/tests/transform_dialect/cuda/reduction_v2_uneven.mlir b/tests/transform_dialect/cuda/reduction_v2_uneven.mlir
index 473ec18..83ed66e 100644
--- a/tests/transform_dialect/cuda/reduction_v2_uneven.mlir
+++ b/tests/transform_dialect/cuda/reduction_v2_uneven.mlir
@@ -40,7 +40,6 @@
   
   //         CHECK: %[[TIDX:.]] = gpu.thread_id  x
   //         CHECK: %[[IDX:.*]] = affine.apply{{.*}}%[[TIDX]]
-  //         CHECK: %[[SHMEM_VIEW_EXPANDED:.*]] = memref.subview %[[SHMEM_ALLOC]][0, %[[IDX]]]{{.*}}to memref<4xf32, strided<[1], offset: ?>, 3>
   //         CHECK: gpu.barrier
   // Local per-thread scf.for-based reduction.
   //         CHECK: scf.for
diff --git a/tests/transform_dialect/cuda/reduction_v3.mlir b/tests/transform_dialect/cuda/reduction_v3.mlir
index b17a8a8..f51818b 100644
--- a/tests/transform_dialect/cuda/reduction_v3.mlir
+++ b/tests/transform_dialect/cuda/reduction_v3.mlir
@@ -44,13 +44,12 @@
   //     CHECK-DAG: %[[SHMEM_ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x1024xf32, 3>
   
   //         CHECK: %[[TIDX:.]] = gpu.thread_id  x
-  //         CHECK: %[[SHMEM_VIEW_EXPANDED:.*]] = memref.subview %[[SHMEM_ALLOC]][0, %[[TIDX]]]{{.*}}to memref<1x1xf32, strided<[1024, 1], offset: ?>, 3>
   // Local per-thread scf.for-based reduction.
   //         CHECK: scf.for
-  //         CHECK:   vector.transfer_read %{{.*}} : memref<f32, strided<[], offset: ?>>, vector<f32>
-  //         CHECK:   vector.transfer_read %{{.*}} vector<f32>
+  //         CHECK:   vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<f32>
+  //         CHECK:   vector.transfer_read %[[SHMEM_ALLOC]][%[[C0]], %[[TIDX]]], %{{.*}} : memref<1x1024xf32, 3>, vector<f32>
   //         CHECK:   arith.addf {{.*}} : f32
-  //         CHECK:   vector.transfer_write {{.*}} vector<f32>
+  //         CHECK:   vector.transfer_write %{{.*}}, %[[SHMEM_ALLOC]][%[[C0]], %[[TIDX]]] : vector<f32>, memref<1x1024xf32, 3>
 
   //         CHECK: %[[TIDY:.]] = gpu.thread_id  y
   // Distributed reduction: everyone loads then 5 xor + addf expected
diff --git a/tests/transform_dialect/cuda/reduction_v3_codegen_spec.mlir b/tests/transform_dialect/cuda/reduction_v3_codegen_spec.mlir
index 1b188b9..8322f3c 100644
--- a/tests/transform_dialect/cuda/reduction_v3_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/reduction_v3_codegen_spec.mlir
@@ -56,7 +56,7 @@
 
   // Step 6. Post-bufferization vector distribution with rank-reduction.
   // ===========================================================================
-  %func_10 = transform.iree.apply_patterns %func_9 { rank_reducing }
+  %func_10 = transform.iree.apply_patterns %func_9 { rank_reducing, fold_memref_aliases }
   %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_3
   %warp = transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 }
   transform.iree.vector.warp_distribute %func_10
diff --git a/tests/transform_dialect/cuda/softmax_codegen_spec.mlir b/tests/transform_dialect/cuda/softmax_codegen_spec.mlir
index 1d849c6..cf5669d 100644
--- a/tests/transform_dialect/cuda/softmax_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/softmax_codegen_spec.mlir
@@ -90,7 +90,7 @@
   // Step 6. Post-bufferization vector distribution with rank-reduction.
   // ===================================================================
   %end_func = transform.structured.match ops{["func.func"]} in %variant_op_3
-  %end_func_2 = transform.iree.apply_patterns %end_func { rank_reducing }
+  %end_func_2 = transform.iree.apply_patterns %end_func { rank_reducing, fold_memref_aliases }
   %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_3
   %warp = transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 }
   transform.iree.vector.warp_distribute %end_func_2
diff --git a/tests/transform_dialect/cuda/softmax_partial_codegen_spec.mlir b/tests/transform_dialect/cuda/softmax_partial_codegen_spec.mlir
index ef712a2..25227b9 100644
--- a/tests/transform_dialect/cuda/softmax_partial_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/softmax_partial_codegen_spec.mlir
@@ -74,7 +74,7 @@
   // Step 6. Post-bufferization vector distribution with rank-reduction.
   // ===================================================================
   %end_func = transform.structured.match ops{["func.func"]} in %variant_op_3
-  %end_func_2 = transform.iree.apply_patterns %end_func { rank_reducing }
+  %end_func_2 = transform.iree.apply_patterns %end_func { rank_reducing, fold_memref_aliases }
   %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_3
   %warp = transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 }
   transform.iree.vector.warp_distribute %end_func_2
diff --git a/tests/transform_dialect/cuda/softmax_v2_codegen_spec.mlir b/tests/transform_dialect/cuda/softmax_v2_codegen_spec.mlir
index 52869f9..a810cb4 100644
--- a/tests/transform_dialect/cuda/softmax_v2_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/softmax_v2_codegen_spec.mlir
@@ -81,7 +81,7 @@
   // Step 6. Post-bufferization vector distribution with rank-reduction.
   // ===================================================================
   %end_func = transform.structured.match ops{["func.func"]} in %variant_op_3
-  %end_func_2 = transform.iree.apply_patterns %end_func { rank_reducing }
+  %end_func_2 = transform.iree.apply_patterns %end_func { rank_reducing, fold_memref_aliases }
   %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_3
   %warp = transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 }
   transform.iree.vector.warp_distribute %end_func_2
diff --git a/third_party/llvm-project b/third_party/llvm-project
index 6425acd..397bba5 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit 6425acd8fb8915ade15583eeaaa22a286b8d10eb
+Subproject commit 397bba5bd6ec168e50f7364ee9c4aa6a78ce17c3