[spirv] NFC: Restructure the vector lowering pass (#15134)

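This commit splits the `SPIRVVectorLowering` pass into two parts: a
`SPIRVInitialVectorLowering` pass and a `SPIRVFinalVectorLowering` pass.
The vector transfer hoisting that used to run inline between those two
halves is now performed by the common
`iree-codegen-hoist-redundant-vector-transfers` pass, which the SPIR-V
pipelines schedule between the two new passes.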

Progress towards https://github.com/openxla/iree/issues/15083
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel b/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel
index edb009b..23ac1dc 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel
@@ -61,14 +61,15 @@
         "SPIRVDistribute.cpp",
         "SPIRVEmulateI64.cpp",
         "SPIRVEraseStorageBufferStaticShape.cpp",
+        "SPIRVFinalVectorLowering.cpp",
         "SPIRVGeneralizeNamedOps.cpp",
+        "SPIRVInitialVectorLowering.cpp",
         "SPIRVLowerExecutableTargetPass.cpp",
         "SPIRVMapMemRefStorageClass.cpp",
         "SPIRVTile.cpp",
         "SPIRVTileAndDistribute.cpp",
         "SPIRVTileAndPromote.cpp",
         "SPIRVTileAndVectorizeToCooperativeOps.cpp",
-        "SPIRVVectorLowering.cpp",
         "SPIRVVectorToGPUSubgroupMMAOps.cpp",
         "SPIRVVectorizeLoadStore.cpp",
         "Utils.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt
index e212853..e722a23 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt
@@ -60,14 +60,15 @@
     "SPIRVDistribute.cpp"
     "SPIRVEmulateI64.cpp"
     "SPIRVEraseStorageBufferStaticShape.cpp"
+    "SPIRVFinalVectorLowering.cpp"
     "SPIRVGeneralizeNamedOps.cpp"
+    "SPIRVInitialVectorLowering.cpp"
     "SPIRVLowerExecutableTargetPass.cpp"
     "SPIRVMapMemRefStorageClass.cpp"
     "SPIRVTile.cpp"
     "SPIRVTileAndDistribute.cpp"
     "SPIRVTileAndPromote.cpp"
     "SPIRVTileAndVectorizeToCooperativeOps.cpp"
-    "SPIRVVectorLowering.cpp"
     "SPIRVVectorToGPUSubgroupMMAOps.cpp"
     "SPIRVVectorizeLoadStore.cpp"
     "Utils.cpp"
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
index 129e429..f3efc4d 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
@@ -127,6 +127,14 @@
   nestedModulePM.addPass(createCSEPass());
 }
 
+/// Adds passes to lower vector ops to meet SPIR-V requirements.
+static void addSPIRVVectorLoweringPasses(OpPassManager &modulePM) {
+  modulePM.addNestedPass<func::FuncOp>(createSPIRVInitialVectorLoweringPass());
+  modulePM.addNestedPass<func::FuncOp>(
+      createHoistRedundantVectorTransfersPass());
+  modulePM.addNestedPass<func::FuncOp>(createSPIRVFinalVectorLoweringPass());
+}
+
 static void
 addSPIRVBufferizePasses(OpPassManager &passManager,
                         BufferizationOptions::AllocationFn allocationFn) {
@@ -298,7 +306,7 @@
     nestedModulePM.addNestedPass<func::FuncOp>(
         createGenericVectorizationPass(options));
   }
-  nestedModulePM.addNestedPass<func::FuncOp>(createSPIRVVectorLoweringPass());
+  addSPIRVVectorLoweringPasses(nestedModulePM);
   nestedModulePM.addNestedPass<func::FuncOp>(createForOpCanonicalizationPass());
   nestedModulePM.addPass(createCanonicalizerPass());
   nestedModulePM.addPass(createCSEPass());
@@ -384,7 +392,7 @@
       createSPIRVVectorToGPUSubgroupMMAOpsPass());
   nestedModulePM.addPass(createCanonicalizerPass());
   nestedModulePM.addPass(createCSEPass());
-  nestedModulePM.addNestedPass<func::FuncOp>(createSPIRVVectorLoweringPass());
+  addSPIRVVectorLoweringPasses(nestedModulePM);
 
   if (pipelineDepth > 0) {
     PipeliningSchedulingStrategy schedule =
@@ -446,7 +454,7 @@
   nestedPM.addNestedPass<func::FuncOp>(createGPUReduceSharedMemoryBankConflicts(
       detail::bankConflictReductionPaddingBits));
 
-  nestedPM.addNestedPass<func::FuncOp>(createSPIRVVectorLoweringPass());
+  addSPIRVVectorLoweringPasses(nestedPM);
   nestedPM.addNestedPass<func::FuncOp>(createForOpCanonicalizationPass());
   nestedPM.addPass(createCanonicalizerPass());
   nestedPM.addPass(createCSEPass());
@@ -574,7 +582,7 @@
       createConvertVectorReductionToGPUPass(getWarpSize));
   // Perform normal vector unrolling and lowering transformations. This breaks
   // vectors down to native machine size.
-  nestedModulePM.addNestedPass<func::FuncOp>(createSPIRVVectorLoweringPass());
+  addSPIRVVectorLoweringPasses(nestedModulePM);
   nestedModulePM.addPass(createCanonicalizerPass());
   nestedModulePM.addPass(createCSEPass());
 }
@@ -605,7 +613,7 @@
     nestedModulePM.addNestedPass<func::FuncOp>(
         createGenericVectorizationPass(options));
   }
-  nestedModulePM.addNestedPass<func::FuncOp>(createSPIRVVectorLoweringPass());
+  addSPIRVVectorLoweringPasses(nestedModulePM);
   nestedModulePM.addNestedPass<func::FuncOp>(createForOpCanonicalizationPass());
   nestedModulePM.addPass(createCanonicalizerPass());
   nestedModulePM.addPass(createCSEPass());
@@ -629,7 +637,7 @@
   // for SPIR-V.
   auto &nestedModulePM = pm.nest<ModuleOp>();
   nestedModulePM.addNestedPass<func::FuncOp>(createGenericVectorizationPass());
-  nestedModulePM.addNestedPass<func::FuncOp>(createSPIRVVectorLoweringPass());
+  addSPIRVVectorLoweringPasses(nestedModulePM);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h
index faf2d9f..e762e11 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h
@@ -136,8 +136,11 @@
 /// having pointer bitcast.
 std::unique_ptr<OperationPass<ModuleOp>> createSPIRVVectorizeLoadStore();
 
-/// Pass to vectorize Linalg ops with buffer semantics.
-std::unique_ptr<OperationPass<func::FuncOp>> createSPIRVVectorLoweringPass();
+/// Pass to lower vector ops to meet SPIR-V requirements.
+std::unique_ptr<OperationPass<func::FuncOp>>
+createSPIRVInitialVectorLoweringPass();
+std::unique_ptr<OperationPass<func::FuncOp>>
+createSPIRVFinalVectorLoweringPass();
 
 /// Pass to do vectorization suitable for lowering to SPIR-V cooperative ops.
 std::unique_ptr<OperationPass<func::FuncOp>>
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.td b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.td
index 05ed844..e409e8b 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.td
@@ -112,9 +112,16 @@
     "mlir::iree_compiler::createSPIRVTileToCooperativeOpsPass()";
 }
 
-def SPIRVVectorLowering : Pass<"iree-spirv-vector-lowering", "func::FuncOp"> {
-  let summary = "Vectorize Linalg ops with buffer semantics";
-  let constructor = "mlir::iree_compiler::createSPIRVVectorLoweringPass()";
+def SPIRVInitialVectorLowering : Pass<
+    "iree-spirv-initial-vector-lowering", "func::FuncOp"> {
+  let summary = "Perform initial lowering of vectors ops to fit SPIR-V";
+  let constructor = "mlir::iree_compiler::createSPIRVInitialVectorLoweringPass()";
+}
+
+def SPIRVFinalVectorLowering : Pass<
+    "iree-spirv-final-vector-lowering", "func::FuncOp"> {
+  let summary = "Perform final lowering of vectors ops to fit SPIR-V";
+  let constructor = "mlir::iree_compiler::createSPIRVFinalVectorLoweringPass()";
 }
 
 def SPIRVVectorizeLoadStore :
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVFinalVectorLowering.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVFinalVectorLowering.cpp
new file mode 100644
index 0000000..a6e0e81
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVFinalVectorLowering.cpp
@@ -0,0 +1,115 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+//===- SPIRVFinalVectorLowering.cpp ---------------------------------------===//
+//
+// This pass hosts final steps towards lowering vector ops to meet SPIR-V
+// requirements--it applies vector lowering patterns to convert vector ops
+// to more basic forms.
+//
+//===----------------------------------------------------------------------===//
+
+#include "iree/compiler/Codegen/Common/Passes.h"
+#include "iree/compiler/Codegen/SPIRV/PassDetail.h"
+#include "iree/compiler/Codegen/SPIRV/Passes.h"
+#include "iree/compiler/Codegen/Transforms/Transforms.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/VectorInterfaces.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "iree-spirv-final-vector-lowering"
+
+namespace mlir {
+namespace iree_compiler {
+namespace {
+
+void debugPrint(func::FuncOp funcOp, const char *message) {
+  LLVM_DEBUG({
+    llvm::dbgs() << "//--- " << message << " ---//\n";
+    funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+    llvm::dbgs() << "\n\n";
+  });
+}
+
+class SPIRVFinalVectorLoweringPass
+    : public SPIRVFinalVectorLoweringBase<SPIRVFinalVectorLoweringPass> {
+public:
+  void getDependentDialects(DialectRegistry &registry) const override {
+    // vector.gather lowering patterns target scf ops.
+    registry.insert<scf::SCFDialect, vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    func::FuncOp funcOp = getOperation();
+
+    // Lower vector transfer permutation map.
+    {
+      RewritePatternSet patterns(context);
+      vector::ExtractStridedSliceOp::getCanonicalizationPatterns(patterns,
+                                                                 context);
+      vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
+      if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+        return signalPassFailure();
+      }
+    }
+
+    debugPrint(funcOp, "after lowering transfer ops");
+
+    // Lower vector broadcast/transpose and contraction.
+    {
+      RewritePatternSet patterns(context);
+      auto options = vector::VectorTransformsOptions()
+                         .setVectorTransformsOptions(
+                             vector::VectorContractLowering::OuterProduct)
+                         .setVectorTransposeLowering(
+                             vector::VectorTransposeLowering::EltWise);
+      vector::populateVectorBroadcastLoweringPatterns(patterns);
+      vector::populateVectorContractLoweringPatterns(patterns, options);
+      vector::populateVectorMultiReductionLoweringPatterns(
+          patterns, vector::VectorMultiReductionLowering::InnerParallel);
+      vector::populateVectorTransposeLoweringPatterns(patterns, options);
+      vector::populateVectorGatherLoweringPatterns(patterns);
+      if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+        return signalPassFailure();
+      }
+    }
+
+    debugPrint(funcOp, "after lowering various vector ops");
+
+    // Run all sorts of canonicalization patterns to clean up again.
+    {
+      RewritePatternSet patterns(context);
+      vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
+      vector::InsertOp::getCanonicalizationPatterns(patterns, context);
+      vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
+      vector::TransferReadOp::getCanonicalizationPatterns(patterns, context);
+      vector::TransferWriteOp::getCanonicalizationPatterns(patterns, context);
+      populateVectorTransferTensorSliceTransforms(patterns);
+      vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
+      scf::IfOp::getCanonicalizationPatterns(patterns, context);
+      if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+        return signalPassFailure();
+      }
+    }
+  }
+};
+
+} // namespace
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+createSPIRVFinalVectorLoweringPass() {
+  return std::make_unique<SPIRVFinalVectorLoweringPass>();
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorLowering.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVInitialVectorLowering.cpp
similarity index 75%
rename from compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorLowering.cpp
rename to compiler/src/iree/compiler/Codegen/SPIRV/SPIRVInitialVectorLowering.cpp
index 3f6dff6..d25ec01 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorLowering.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVInitialVectorLowering.cpp
@@ -4,29 +4,23 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-//===- SPIRVVectorLoweringPass.cpp
-//-------------------------------------------------===//
+//===- SPIRVInitialVectorLowering.cpp -------------------------------------===//
 //
-// This pass vectorizes Linalg ops with buffer semantics.
+// This pass hosts initial steps towards lowering vector ops to meet SPIR-V
+// requirements--it applies vector lowering patterns to unroll large n-D vectors
+// to 1-D ones that are directly supported in SPIR-V.
 //
 //===----------------------------------------------------------------------===//
 
-#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
-#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h"
 #include "iree/compiler/Codegen/Common/Passes.h"
 #include "iree/compiler/Codegen/SPIRV/PassDetail.h"
 #include "iree/compiler/Codegen/SPIRV/Passes.h"
 #include "iree/compiler/Codegen/SPIRV/Utils.h"
 #include "iree/compiler/Codegen/Transforms/Transforms.h"
-#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
-#include "iree/compiler/Codegen/Utils/Utils.h"
-#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
@@ -35,19 +29,24 @@
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
-#include "mlir/IR/Matchers.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/VectorInterfaces.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
-#define DEBUG_TYPE "iree-spirv-vector-lowering"
+#define DEBUG_TYPE "iree-spirv-initial-vector-lowering"
 
 namespace mlir {
 namespace iree_compiler {
 namespace {
 
+void debugPrint(func::FuncOp funcOp, const char *message) {
+  LLVM_DEBUG({
+    llvm::dbgs() << "//--- " << message << " ---//\n";
+    funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+    llvm::dbgs() << "\n\n";
+  });
+}
+
 int getComputeVectorSize(int64_t size) {
   for (int i : {4, 3, 2}) {
     if (size % i == 0)
@@ -286,13 +285,9 @@
   return true;
 }
 
-/// Vectorizes Linalg ops on buffer semantics.
-class SPIRVVectorLoweringPass
-    : public SPIRVVectorLoweringBase<SPIRVVectorLoweringPass> {
+class SPIRVInitialLoweringPass
+    : public SPIRVInitialVectorLoweringBase<SPIRVInitialLoweringPass> {
 public:
-  SPIRVVectorLoweringPass() = default;
-  SPIRVVectorLoweringPass(const SPIRVVectorLoweringPass &pass) = default;
-
   void getDependentDialects(DialectRegistry &registry) const override {
     // vector.gather lowering patterns target scf ops.
     registry.insert<linalg::LinalgDialect, vector::VectorDialect,
@@ -316,11 +311,7 @@
       }
     }
 
-    LLVM_DEBUG({
-      llvm::dbgs() << "--- After IREE tensor.pad vectorization ---\n";
-      funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
-      llvm::dbgs() << "\n\n";
-    });
+    debugPrint(funcOp, "after vectorizing tensor.pad");
 
     // Special peephole optimizations to clean up IR before further processing.
     {
@@ -342,11 +333,7 @@
       }
     }
 
-    LLVM_DEBUG({
-      llvm::dbgs() << "--- After peephole optimization ---\n";
-      funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
-      llvm::dbgs() << "\n\n";
-    });
+    debugPrint(funcOp, "after peephole optimization");
 
     // High dimension contraction can appear after vectorizing ops like 1-D
     // convolution. Those 1-D convolution ops typically have a leading unit
@@ -362,11 +349,7 @@
       (void)vector::castAwayContractionLeadingOneDim(op, rewriter);
     }
 
-    LLVM_DEBUG({
-      llvm::dbgs() << "--- After trimming contract leading unit dims ---\n";
-      funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
-      llvm::dbgs() << "\n\n";
-    });
+    debugPrint(funcOp, "after trimming contract leading unit dims");
 
     // Fold tensor.extract_slice/insert_slice ops into transfer ops. This helps
     // to remove those tensor slice ops so that we can enable further vector op
@@ -382,11 +365,7 @@
       }
     }
 
-    LLVM_DEBUG({
-      llvm::dbgs() << "--- After folding tensor extract/insert slice ops ---\n";
-      funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
-      llvm::dbgs() << "\n\n";
-    });
+    debugPrint(funcOp, "after folding tensor extract/insert slice ops");
 
     // Lower vector.multi_dimension early if any operand is a transpose op.
     // The lowering itself generates transpose ops. This helps to cancel
@@ -411,11 +390,7 @@
       }
     }
 
-    LLVM_DEBUG({
-      llvm::dbgs() << "--- After lowering multi_reduction ops ---\n";
-      funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
-      llvm::dbgs() << "\n\n";
-    });
+    debugPrint(funcOp, "after lowering multi reduction ops");
 
     // Prepare for SPIR-V integer dot product lowering.
     if (emitIntegerDotProdOps) {
@@ -426,12 +401,7 @@
         return signalPassFailure();
       }
 
-      LLVM_DEBUG({
-        llvm::dbgs()
-            << "--- After prepare for SPIR-V dot product lowering ---\n";
-        funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
-        llvm::dbgs() << "\n\n";
-      });
+      debugPrint(funcOp, "after preparing for SPIR-V dot product lowering");
     }
 
     // Then unroll vectors to native vector size. We try to use 128-bit
@@ -444,11 +414,7 @@
       }
     }
 
-    LLVM_DEBUG({
-      llvm::dbgs() << "--- After unrolling vector ---\n";
-      funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
-      llvm::dbgs() << "\n\n";
-    });
+    debugPrint(funcOp, "after unrolling vector ops");
 
     // Lower reduction-unrolled vector contract ops. Such contract ops have
     // their reduction dimensions all be one, so we can convert them into
@@ -470,11 +436,7 @@
       }
     }
 
-    LLVM_DEBUG({
-      llvm::dbgs() << "--- After lowering size-1 reduction contract ops ---\n";
-      funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
-      llvm::dbgs() << "\n\n";
-    });
+    debugPrint(funcOp, "after lowering size-1 reduction contract ops");
 
     // Now lower vector transpose given we have handled vector patterns that may
     // generate transpose ops in previous steps. This converts transpose ops
@@ -491,11 +453,7 @@
       }
     }
 
-    LLVM_DEBUG({
-      llvm::dbgs() << "--- After lowering transpose ops ---\n";
-      funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
-      llvm::dbgs() << "\n\n";
-    });
+    debugPrint(funcOp, "after lowering transpose ops");
 
     // Next run canonicalization to cast away leading size-1 dimensions. They
     // can be generated from vector unrolling and generally cause issues to
@@ -531,11 +489,7 @@
       }
     }
 
-    LLVM_DEBUG({
-      llvm::dbgs() << "--- After trimming leading unit dims ---\n";
-      funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
-      llvm::dbgs() << "\n\n";
-    });
+    debugPrint(funcOp, "after trimming leading unit dims");
 
     // Lower vector reduction to SPIR-V integer dot product.
     if (emitIntegerDotProdOps) {
@@ -545,89 +499,16 @@
         return signalPassFailure();
       }
 
-      LLVM_DEBUG({
-        llvm::dbgs() << "--- After lowering to SPIR-V dot product ---\n";
-        funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
-        llvm::dbgs() << "\n\n";
-      });
-    }
-
-    // Next perform hoisting. This would analyze transfer read/write ops into
-    // tensors and hoist them out of loop nests. So after it we have
-    // loop-carried vectors, not loop-carried tensors anymore.
-    linalg::hoistRedundantVectorTransfersOnTensor(funcOp);
-    linalg::hoistRedundantVectorTransfers(funcOp);
-
-    LLVM_DEBUG({
-      llvm::dbgs() << "--- After hoisting transfers ---\n";
-      funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
-      llvm::dbgs() << "\n\n";
-    });
-
-    // Lower vector transfer permutation map.
-    {
-      RewritePatternSet patterns(context);
-      vector::ExtractStridedSliceOp::getCanonicalizationPatterns(patterns,
-                                                                 context);
-      vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
-      if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
-        return signalPassFailure();
-      }
-    }
-
-    LLVM_DEBUG({
-      llvm::dbgs() << "--- After lowering transfer ops ---\n";
-      funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
-      llvm::dbgs() << "\n\n";
-    });
-
-    // Lower vector broadcast/transpose and contraction.
-    {
-      RewritePatternSet patterns(context);
-      auto options = vector::VectorTransformsOptions()
-                         .setVectorTransformsOptions(
-                             vector::VectorContractLowering::OuterProduct)
-                         .setVectorTransposeLowering(
-                             vector::VectorTransposeLowering::EltWise);
-      vector::populateVectorBroadcastLoweringPatterns(patterns);
-      vector::populateVectorContractLoweringPatterns(patterns, options);
-      vector::populateVectorMultiReductionLoweringPatterns(
-          patterns, vector::VectorMultiReductionLowering::InnerParallel);
-      vector::populateVectorTransposeLoweringPatterns(patterns, options);
-      vector::populateVectorGatherLoweringPatterns(patterns);
-      if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
-        return signalPassFailure();
-      }
-    }
-
-    LLVM_DEBUG({
-      llvm::dbgs() << "--- After lowering various vector ops ---\n";
-      funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
-      llvm::dbgs() << "\n\n";
-    });
-
-    // Run all sorts of canonicalization patterns to clean up again.
-    {
-      RewritePatternSet patterns(context);
-      vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
-      vector::InsertOp::getCanonicalizationPatterns(patterns, context);
-      vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
-      vector::TransferReadOp::getCanonicalizationPatterns(patterns, context);
-      vector::TransferWriteOp::getCanonicalizationPatterns(patterns, context);
-      populateVectorTransferTensorSliceTransforms(patterns);
-      vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
-      scf::IfOp::getCanonicalizationPatterns(patterns, context);
-      if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
-        return signalPassFailure();
-      }
+      debugPrint(funcOp, "after lowering to SPIR-V dot product");
     }
   }
 };
 
 } // namespace
 
-std::unique_ptr<OperationPass<func::FuncOp>> createSPIRVVectorLoweringPass() {
-  return std::make_unique<SPIRVVectorLoweringPass>();
+std::unique_ptr<OperationPass<func::FuncOp>>
+createSPIRVInitialVectorLoweringPass() {
+  return std::make_unique<SPIRVInitialLoweringPass>();
 }
 
 } // namespace iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_batch_matmul.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_batch_matmul.mlir
index 1fca89c..76efc45 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_batch_matmul.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_batch_matmul.mlir
@@ -1,5 +1,5 @@
 // RUN: iree-opt --split-input-file \
-// RUN:   --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-tile,canonicalize,cse,iree-codegen-generic-vectorization,iree-spirv-vector-lowering,canonicalize,cse)))))' \
+// RUN:   --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-tile,canonicalize,cse,iree-codegen-generic-vectorization,iree-spirv-initial-vector-lowering,iree-codegen-hoist-redundant-vector-transfers,iree-spirv-final-vector-lowering,canonicalize,cse)))))' \
 // RUN:   %s | FileCheck %s
 
 #config = #iree_codegen.lowering_config<tile_sizes = [[1, 8, 64], [1, 8, 4], [0, 0, 0, 4]]>
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir
index f460f90..dcf863b 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir
@@ -1,5 +1,5 @@
 // RUN: iree-opt --split-input-file \
-// RUN:   --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-create-fast-slow-path,iree-spirv-tile,canonicalize,cse,iree-codegen-generic-vectorization,iree-spirv-vector-lowering,canonicalize,cse)))))' \
+// RUN:   --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-create-fast-slow-path,iree-spirv-tile,canonicalize,cse,iree-codegen-generic-vectorization,iree-spirv-initial-vector-lowering,iree-codegen-hoist-redundant-vector-transfers,iree-spirv-final-vector-lowering,canonicalize,cse)))))' \
 // RUN:   %s | FileCheck %s
 
 #config = #iree_codegen.lowering_config<tile_sizes = [[0, 4, 4, 16], [0, 2, 2, 4], [0, 0, 0, 0, 1, 1, 4], [0, 1, 0, 0]]>
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_matmul.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_matmul.mlir
index 81e14fb..e82a5c3 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_matmul.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_matmul.mlir
@@ -1,5 +1,5 @@
 // RUN: iree-opt --split-input-file \
-// RUN:   --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-tile,canonicalize,cse,iree-codegen-generic-vectorization,iree-spirv-vector-lowering,canonicalize,cse)))))' %s | FileCheck %s
+// RUN:   --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-tile,canonicalize,cse,iree-codegen-generic-vectorization,iree-spirv-initial-vector-lowering,iree-codegen-hoist-redundant-vector-transfers,iree-spirv-final-vector-lowering,canonicalize,cse)))))' %s | FileCheck %s
 
 #config = #iree_codegen.lowering_config<tile_sizes = [[8, 64], [8, 4], [0, 0, 4]]>
 #translation = #iree_codegen.translation_info<SPIRVBaseVectorize>
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_pooling.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_pooling.mlir
index dcd40fd..6587c3a 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_pooling.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_pooling.mlir
@@ -1,5 +1,5 @@
 // RUN: iree-opt --split-input-file \
-// RUN:   --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-tile,canonicalize,cse,iree-codegen-generic-vectorization,iree-spirv-vector-lowering,canonicalize,cse)))))' \
+// RUN:   --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-tile,canonicalize,cse,iree-codegen-generic-vectorization,iree-spirv-initial-vector-lowering,iree-codegen-hoist-redundant-vector-transfers,iree-spirv-final-vector-lowering,canonicalize,cse)))))' \
 // RUN:   %s | FileCheck %s
 
 #config = #iree_codegen.lowering_config<tile_sizes = [[0, 2, 2, 8], [0, 1, 1, 4], [0, 0, 0, 0, 1, 1], [0, 1, 0, 0]]>
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_conv.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_conv.mlir
index 2cf3d6b..60e1157 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_conv.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_conv.mlir
@@ -1,5 +1,5 @@
 // RUN: iree-opt --split-input-file \
-// RUN:   --pass-pipeline='builtin.module(func.func(iree-codegen-generic-vectorization,iree-spirv-vector-lowering))' \
+// RUN:   --pass-pipeline='builtin.module(func.func(iree-codegen-generic-vectorization,iree-spirv-initial-vector-lowering,iree-codegen-hoist-redundant-vector-transfers,iree-spirv-final-vector-lowering))' \
 // RUN:   %s | FileCheck %s
 
 func.func @ncw_conv_1d(%input: tensor<2x4x4xf32>, %filter: tensor<4x4x1xf32>, %init: tensor<2x4x4xf32>) -> tensor<2x4x4xf32> {
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_elementwise_ops.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_elementwise_ops.mlir
index e4e3dd2..81ad85d 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_elementwise_ops.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_elementwise_ops.mlir
@@ -1,5 +1,5 @@
 // RUN: iree-opt --split-input-file \
-// RUN:   --pass-pipeline='builtin.module(func.func(iree-codegen-generic-vectorization,iree-spirv-vector-lowering))' \
+// RUN:   --pass-pipeline='builtin.module(func.func(iree-codegen-generic-vectorization,iree-spirv-initial-vector-lowering,iree-codegen-hoist-redundant-vector-transfers,iree-spirv-final-vector-lowering))' \
 // RUN:   %s | FileCheck %s
 
 func.func @add(%lhs: tensor<2x8xf32>, %rhs: tensor<2x8xf32>) -> tensor<2x8xf32> {
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_gather.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_gather.mlir
index a3e3b92..a2336b3 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_gather.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_gather.mlir
@@ -1,5 +1,5 @@
 // RUN: iree-opt --split-input-file \
-// RUN:   --pass-pipeline='builtin.module(func.func(iree-codegen-generic-vectorization{vectorize-gather-accesses},iree-spirv-vector-lowering))' \
+// RUN:   --pass-pipeline='builtin.module(func.func(iree-codegen-generic-vectorization{vectorize-gather-accesses},iree-spirv-initial-vector-lowering,iree-codegen-hoist-redundant-vector-transfers,iree-spirv-final-vector-lowering))' \
 // RUN:   %s | FileCheck %s
 
 func.func @tensor_extract(%arg0: tensor<6x4xf32>, %arg1: tensor<6xi32>, %data: tensor<1x2x512xf32>, %init: tensor<6x4xf32>, %i : index, %j: index) -> tensor<6x4xf32> {
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir
index ab2f8b2..0cd0efd 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir
@@ -1,5 +1,5 @@
 // RUN: iree-opt --split-input-file \
-// RUN:   --pass-pipeline='builtin.module(func.func(iree-codegen-generic-vectorization,iree-spirv-vector-lowering))' \
+// RUN:   --pass-pipeline='builtin.module(func.func(iree-codegen-generic-vectorization,iree-spirv-initial-vector-lowering,iree-codegen-hoist-redundant-vector-transfers,iree-spirv-final-vector-lowering))' \
 // RUN:   %s | FileCheck %s
 
 func.func @matmul_1x4x4(%lhs: tensor<1x4xf32>, %rhs: tensor<4x4xf32>, %init: tensor<1x4xf32>) -> tensor<1x4xf32> {
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_reduction.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_reduction.mlir
index 917560f..5e6a71c 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_reduction.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_reduction.mlir
@@ -1,5 +1,5 @@
 // RUN: iree-opt --split-input-file \
-// RUN:   --pass-pipeline='builtin.module(func.func(iree-codegen-generic-vectorization,iree-spirv-vector-lowering))' \
+// RUN:   --pass-pipeline='builtin.module(func.func(iree-codegen-generic-vectorization,iree-spirv-initial-vector-lowering,iree-codegen-hoist-redundant-vector-transfers,iree-spirv-final-vector-lowering))' \
 // RUN:   %s | FileCheck %s
 
 func.func @reduce_outmost_dim(%input: tensor<4x1x4xf32>, %init: tensor<1x4xf32>) -> tensor<1x4xf32> {