[LLVMCPU] Drop unit dims on memory transfers (#13340)

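Adds a dropUnitDims option (default: true) to OptimizeVectorTransferPass and
uses it to drop unit dimensions from vector transfer ops in the LLVMCPU
pipeline. The unit-dim-dropping patterns are split out of the flatten stage so
they can run on their own, and the LLVMCPU pipeline now runs the pass a second
time after bufferization to perform the memref-based transfer_read /
transfer_write optimizations. The SPIR-V pipelines pass dropUnitDims=false
because the SPIR-V lowering cannot yet handle the vector.shape_cast ops this
rewrite introduces (see the TODO referencing #14191). The pad pipeline test is
updated since the vector.store ops now go through an additional rank-reduced
memref.subview.

For context, here is a rough sketch of the kind of rewrite
vector::populateVectorTransferDropUnitDimsPatterns performs on a memref-based
transfer. The value names and shapes are made up for illustration, and the IR
below is not verbatim compiler output:

    // Illustrative only: %m, %v, and the shapes are hypothetical.
    %c0 = arith.constant 0 : index
    %pad = arith.constant 0.0 : f32
    %v = vector.transfer_read %m[%c0, %c0, %c0], %pad
        {in_bounds = [true, true, true]}
        : memref<1x1x8xf32>, vector<1x1x8xf32>

becomes, roughly:

    %sv = memref.subview %m[0, 0, 0] [1, 1, 8] [1, 1, 1]
        : memref<1x1x8xf32> to memref<8xf32>
    %r = vector.transfer_read %sv[%c0], %pad {in_bounds = [true]}
        : memref<8xf32>, vector<8xf32>
    %v = vector.shape_cast %r : vector<8xf32> to vector<1x1x8xf32>
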
diff --git a/compiler/src/iree/compiler/Codegen/Common/CommonPasses.h b/compiler/src/iree/compiler/Codegen/Common/CommonPasses.h
index ac8dc27..f2189b2 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CommonPasses.h
+++ b/compiler/src/iree/compiler/Codegen/Common/CommonPasses.h
@@ -153,7 +153,7 @@
 
 /// Pass to optimize vector transfer_read and transfer_write.
 std::unique_ptr<OperationPass<func::FuncOp>> createOptimizeVectorTransferPass(
-    bool flatten = false);
+    bool flatten = false, bool dropUnitDims = true);
 
 /// Pad dynamic alloc ops to convert them into static ones.
 std::unique_ptr<OperationPass<func::FuncOp>> createPadDynamicAlloc();
diff --git a/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp b/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp
index 3b02ad4..165620b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp
@@ -88,7 +88,8 @@
 
 struct OptimizeVectorTransferPass
     : public OptimizeVectorTransferBase<OptimizeVectorTransferPass> {
-  OptimizeVectorTransferPass(bool flatten) : flatten(flatten) {}
+  OptimizeVectorTransferPass(bool flatten, bool dropUnitDims)
+      : flatten(flatten), dropUnitDims(dropUnitDims) {}
   void runOnOperation() override {
     func::FuncOp funcOp = getOperation();
     // Generate vector.shape_cast for dropping leading one dimensions in vector
@@ -125,10 +126,20 @@
       }
     }
 
+    // TODO(#14191): SPIR-V can't handle the vector.shape_cast ops created when
+    // dropping unit dims, so this option is disabled in the SPIR-V pipelines.
+    // The option should go away once all backend issues have been resolved.
+    if (dropUnitDims) {
+      RewritePatternSet patterns(&getContext());
+      mlir::vector::populateVectorTransferDropUnitDimsPatterns(patterns);
+      if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+        return signalPassFailure();
+      }
+    }
+
     // Second stage of patterns to flatten transfer ops.
     if (flatten) {
       RewritePatternSet patterns(&getContext());
-      mlir::vector::populateVectorTransferDropUnitDimsPatterns(patterns);
       mlir::vector::populateFlattenVectorTransferPatterns(patterns);
       if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
         return signalPassFailure();
@@ -151,13 +162,14 @@
 
  private:
   bool flatten;
+  bool dropUnitDims;
 };
 
 }  // namespace
 
 std::unique_ptr<OperationPass<func::FuncOp>> createOptimizeVectorTransferPass(
-    bool flatten) {
-  return std::make_unique<OptimizeVectorTransferPass>(flatten);
+    bool flatten, bool dropUnitDims) {
+  return std::make_unique<OptimizeVectorTransferPass>(flatten, dropUnitDims);
 }
 
 }  // namespace iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
index 481fc33..22e5601 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -552,12 +552,16 @@
     nestedModulePM.addNestedPass<func::FuncOp>(createCSEPass());
   }
 
-  nestedModulePM.addNestedPass<func::FuncOp>(createCSEPass());
-  nestedModulePM.addNestedPass<func::FuncOp>(createCanonicalizerPass());
+  // Eliminate redundant transfer_read/write to avoid stack allocations.
   nestedModulePM.addNestedPass<func::FuncOp>(
       createOptimizeVectorTransferPass(/*flatten=*/true));
+
   addBufferizePasses(nestedModulePM);
 
+  // Perform memref-based transfer_read/write optimizations.
+  nestedModulePM.addNestedPass<func::FuncOp>(
+      createOptimizeVectorTransferPass(/*flatten=*/false));
+
   // Run IREE specific passes before vector lowering expert.
   nestedModulePM.addNestedPass<func::FuncOp>(
       createRemoveSingleIterationLoopPass());
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pad_pipeline_tests.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pad_pipeline_tests.mlir
index 182021c..1c89a6a 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pad_pipeline_tests.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pad_pipeline_tests.mlir
@@ -217,4 +217,5 @@
 //         CHECK:               scf.yield
 //         CHECK:             scf.yield
 //         CHECK:           scf.yield
-// CHECK-COUNT-7:         vector.store %{{.+}}, %[[OUTPUT_SUBVIEW_0]]
+//         CHECK:         %[[OUTPUT_SUBVIEW_1:.+]] = memref.subview %[[OUTPUT_SUBVIEW_0]]
+// CHECK-COUNT-7:         vector.store %{{.+}}, %[[OUTPUT_SUBVIEW_1]]
diff --git a/compiler/src/iree/compiler/Codegen/Passes.td b/compiler/src/iree/compiler/Codegen/Passes.td
index a4bfa03..08f6c79 100644
--- a/compiler/src/iree/compiler/Codegen/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Passes.td
@@ -236,7 +236,9 @@
   let constructor = "mlir::iree_compiler::createOptimizeVectorTransferPass()";
   let options = [
     Option<"optionFlatten", "flatten", "bool", "false",
-           "Flatten the vector type of vector transfers where possible (contiguous row-major data).">
+           "Flatten the vector type of vector transfers where possible (contiguous row-major data).">,
+    Option<"optionDropUnitDims", "drop-unit-dims", "bool", /*default=*/"true",
+           "Drop unit dims in vector transfers where possible (might generate vector.shape_cast).">,
   ];
   let dependentDialects = [
       "memref::MemRefDialect"
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
index b220e3e..16e2321 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
@@ -208,14 +208,16 @@
   pm.addPass(createSPIRVVectorizeLoadStore());
   // Perform various vector-level cross-op optimizations like load-store
   // forwarding, shape casting and casting op cancelling.
-  pm.addNestedPass<func::FuncOp>(createOptimizeVectorTransferPass());
+  pm.addNestedPass<func::FuncOp>(createOptimizeVectorTransferPass(
+      /*flatten=*/false, /*dropUnitDims=*/false));
   pm.addNestedPass<func::FuncOp>(createSPIRVBreakDownLargeVectorPass());
 
   // Perform optimizations that need to go across the scf.for region boundary.
   pm.addNestedPass<func::FuncOp>(createForOpCanonicalizationPass());
   pm.addPass(createCanonicalizerPass());
   pm.addPass(createCSEPass());
-  pm.addNestedPass<func::FuncOp>(createOptimizeVectorTransferPass());
+  pm.addNestedPass<func::FuncOp>(createOptimizeVectorTransferPass(
+      /*flatten=*/false, /*dropUnitDims=*/false));
 
   // Turn multi-dimension memref into one-dimension. This is needed for SPIR-V
   // because we don't use upstream memref descriptors.
@@ -311,8 +313,8 @@
 
   // Perform various vector-level cross-op optimizations like load-store
   // forwarding, shape casting and casting op cancelling.
-  nestedModulePM.addNestedPass<func::FuncOp>(
-      createOptimizeVectorTransferPass());
+  nestedModulePM.addNestedPass<func::FuncOp>(createOptimizeVectorTransferPass(
+      /*flatten=*/false, /*dropUnitDims=*/false));
 }
 
 void addSPIRVCooperativeMatrixVectorizePassPipeline(OpPassManager &pm,
@@ -370,8 +372,8 @@
 
   // Perform various vector-level cross-op optimizations like load-store
   // forwarding, shape casting and casting op cancelling.
-  nestedModulePM.addNestedPass<func::FuncOp>(
-      createOptimizeVectorTransferPass());
+  nestedModulePM.addNestedPass<func::FuncOp>(createOptimizeVectorTransferPass(
+      /*flatten=*/false, /*dropUnitDims=*/false));
 
   // Folding subview ops is required for converting vector transfer ops into SPIR-V
   // cooperative ops in the next step.
@@ -445,10 +447,12 @@
   // to hoisting. Because this is before folding all memref subview ops away, we
   // still have subview ops using the same indices, which allows for transfer
   // read/write forwarding.
-  nestedPM.addNestedPass<func::FuncOp>(createOptimizeVectorTransferPass());
+  nestedPM.addNestedPass<func::FuncOp>(createOptimizeVectorTransferPass(
+      /*flatten=*/false, /*dropUnitDims=*/false));
 
   nestedPM.addNestedPass<func::FuncOp>(memref::createFoldMemRefAliasOpsPass());
-  nestedPM.addNestedPass<func::FuncOp>(createOptimizeVectorTransferPass());
+  nestedPM.addNestedPass<func::FuncOp>(createOptimizeVectorTransferPass(
+      /*flatten=*/false, /*dropUnitDims=*/false));
 
   // Hoist loop invariant code to avoid pipelining it.
   nestedPM.addNestedPass<func::FuncOp>(createLoopInvariantCodeMotionPass());
@@ -506,8 +510,8 @@
 
   // Perform various vector-level cross-op optimizations like load-store
   // forwarding, shape casting and casting op cancelling.
-  nestedModulePM.addNestedPass<func::FuncOp>(
-      createOptimizeVectorTransferPass());
+  nestedModulePM.addNestedPass<func::FuncOp>(createOptimizeVectorTransferPass(
+      /*flatten=*/false, /*dropUnitDims=*/false));
 
   // Simplify the IR for vector distribution.
   nestedModulePM.addNestedPass<func::FuncOp>(
@@ -567,8 +571,8 @@
 
   // Perform various vector-level cross-op optimizations like load-store
   // forwarding, shape casting and casting op cancelling.
-  nestedModulePM.addNestedPass<func::FuncOp>(
-      createOptimizeVectorTransferPass());
+  nestedModulePM.addNestedPass<func::FuncOp>(createOptimizeVectorTransferPass(
+      /*flatten=*/false, /*dropUnitDims=*/false));
 }
 
 void addSPIRVTransformDialectPassPipeline(OpPassManager &pm) {