[spirv] NFC: Restructure TileAndVectorizeToCooperativeOps pass (#15138)

This commit lifts the general vectorization and hoisting steps
out of the `TileAndVectorizeToCooperativeOps` pass, switching to
the common `GenericVectorization` and `HoistRedundantVectorTransfers`
passes for that functionality.

Progress towards https://github.com/openxla/iree/issues/15083
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
index f3efc4d..35c8aa0 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
@@ -353,9 +353,10 @@
   nestedModulePM.addPass(createCSEPass());
 
   // Multi-buffer depending on pipeline depth and distribute to shared memory.
-  if (pipelineDepth > 0)
+  if (pipelineDepth > 0) {
     nestedModulePM.addNestedPass<func::FuncOp>(
         createGPUMultiBuffering(pipelineDepth + 1));
+  }
   nestedModulePM.addNestedPass<func::FuncOp>(createMemrefCopyToLinalgPass());
   nestedModulePM.addNestedPass<func::FuncOp>(
       createGPUDistributeSharedMemoryCopy());
@@ -365,10 +366,20 @@
       createGPUReduceSharedMemoryBankConflicts(
           detail::bankConflictReductionPaddingBits));
 
+  // Performs high-level n-D mechanical vectorization. This does not perform
+  // unrolling or lowering, which is done later.
+  {
+    GenericVectorizationPassOptions options;
+    nestedModulePM.addNestedPass<func::FuncOp>(
+        createGenericVectorizationPass(options));
+  }
+
   // Vectorize to cooperative ops.
   nestedModulePM.addNestedPass<func::FuncOp>(
       createSPIRVVectorizeToCooperativeOpsPass());
   nestedModulePM.addNestedPass<func::FuncOp>(
+      createHoistRedundantVectorTransfersPass());
+  nestedModulePM.addNestedPass<func::FuncOp>(
       createRemoveSingleIterationLoopPass());
 
   // Run canonicalization patterns to propagate constant shape sizes after
@@ -420,7 +431,8 @@
   nestedPM.addNestedPass<func::FuncOp>(
       createGPUTensorTile(/*distributeToWarp=*/false));
 
-  // High-level n-D vectorization.
+  // Performs high-level n-D mechanical vectorization. This does not perform
+  // unrolling or lowering, which is done later.
   {
     GenericVectorizationPassOptions options;
     options.vectorizePadding = true;
@@ -531,8 +543,8 @@
   nestedModulePM.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   nestedModulePM.addNestedPass<func::FuncOp>(createCSEPass());
 
-  // Performs mechanical vectorization. This does not perform unrolling or
-  // lowering, which is done later.
+  // Performs high-level n-D mechanical vectorization. This does not perform
+  // unrolling or lowering, which is done later.
   {
     GenericVectorizationPassOptions options;
     options.vectorizePadding = true;
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp
index 3f3e52d..58fc4ef 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp
@@ -54,8 +54,16 @@
 namespace iree_compiler {
 namespace {
 
+void debugPrint(func::FuncOp funcOp, const char *message) {
+  LLVM_DEBUG({
+    llvm::dbgs() << "//--- " << message << " ---//\n";
+    funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+    llvm::dbgs() << "\n\n";
+  });
+}
+
 //===----------------------------------------------------------------------===//
-// Subgroup tiling patterns
+// Cooperative matrix shape utilities
 //===----------------------------------------------------------------------===//
 
 /// Gets the chosen hardware cooperative op size attached to the given `op`
@@ -64,6 +72,35 @@
   return getTileSizes(op, 3); // For native vector sizes
 }
 
+constexpr char coopMatShapeAttrName[] = "iree.spirv.coop_mat_shape";
+
+/// Sets the chosen cooperative matrix shape for CodeGen onto the
+/// hal.executable.export op for the given `funcOp`.
+void setSPIRVCooperativeMatrixShape(func::FuncOp funcOp,
+                                    ArrayRef<int64_t> shape) {
+  auto moduleOp = funcOp->getParentOfType<ModuleOp>();
+  auto exportOp = getAllEntryPoints(moduleOp).lookup(funcOp.getName());
+
+  Builder b(funcOp.getContext());
+  exportOp->setAttr(coopMatShapeAttrName, b.getDenseI64ArrayAttr(shape));
+}
+
+/// Returns the chosen cooperative matrix shape for CodeGen from the
+/// hal.executable.export op for the given `funcOp`. Returns an empty
+/// ArrayRef if cannot query.
+ArrayRef<int64_t> getSPIRVCooperativeMatrixShape(func::FuncOp funcOp) {
+  auto moduleOp = funcOp->getParentOfType<ModuleOp>();
+  auto exportOp = getAllEntryPoints(moduleOp).lookup(funcOp.getName());
+  auto attr = exportOp->getAttrOfType<DenseI64ArrayAttr>(coopMatShapeAttrName);
+  if (!attr)
+    return {};
+  return attr.asArrayRef();
+}
+
+//===----------------------------------------------------------------------===//
+// Subgroup tiling patterns
+//===----------------------------------------------------------------------===//
+
 /// Deduces required subgroup counts along all workgroup tiled dimensions.
 ///
 /// `op` should be an operation with a `lowering_config` attribute to specify
@@ -335,7 +372,12 @@
       return signalPassFailure();
     }
 
+    // Transfer the cooperative matrix shape to an attribute on the export op,
+    // given that after tiling and vectorization we won't have the root Linalg
+    // op anymore.
     SmallVector<int64_t> cooperativeOpSize = getTargetCooperativeOpSize(rootOp);
+    setSPIRVCooperativeMatrixShape(funcOp, cooperativeOpSize);
+
     SmallVector<int64_t> subgroupCounts = deduceSubgroupCounts(rootOp);
 
     // Then tile and distribute to subgroups.
@@ -367,11 +409,7 @@
       }
     }
 
-    LLVM_DEBUG({
-      llvm::dbgs() << "--- After tiling to subgroups ---\n";
-      funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
-      llvm::dbgs() << "\n\n";
-    });
+    debugPrint(funcOp, "after tiling to subgroups");
   }
 };
 
@@ -388,94 +426,53 @@
     MLIRContext *context = &getContext();
     func::FuncOp funcOp = getOperation();
 
-    // First we need to discover the CodeGen lowering configuration. It was
-    // decided earlier and attached to a linalg op as an attribute.
-
-    linalg::LinalgOp rootOp;
-    funcOp.walk([&](linalg::LinalgOp linalgOp) {
-      if (isMatmulOrBatchMatmul(linalgOp) && getLoweringConfig(linalgOp)) {
-        rootOp = linalgOp;
-        return WalkResult::interrupt();
-      }
-      return WalkResult::advance();
-    });
-    if (!rootOp) {
-      funcOp.emitError("expected lowering confg on a (batch) matmul op");
+    // First discover the chosen cooperative matrix shape. It was decided
+    // earlier and attached to the export op as an attribute.
+    ArrayRef<int64_t> cooperativeOpSize =
+        getSPIRVCooperativeMatrixShape(funcOp);
+    if (cooperativeOpSize.empty()) {
+      funcOp->emitError(
+          "expected attribute for chosen cooperative matrix shape");
       return signalPassFailure();
     }
 
-    SmallVector<int64_t> cooperativeOpSize = getTargetCooperativeOpSize(rootOp);
-    SmallVector<int64_t> subgroupCounts = deduceSubgroupCounts(rootOp);
-
-    // Now vectorize and unroll to native cooperative sizes.
+    // Now prepare and unroll to native cooperative sizes.
 
     {
-      RewritePatternSet vectorizationPatterns(context);
-      populateVectorizationPatterns(context, vectorizationPatterns);
-      if (failed(applyPatternsAndFoldGreedily(
-              funcOp, std::move(vectorizationPatterns)))) {
-        return signalPassFailure();
-      }
-
-      RewritePatternSet canonicalizationPatterns(context);
-      vector::ContractionOp::getCanonicalizationPatterns(
-          canonicalizationPatterns, context);
-      populateCombineVectorTransferReadBroadcastPatterns(
-          canonicalizationPatterns);
-      populatePrepareVectorToMMAPatterns(canonicalizationPatterns,
+      RewritePatternSet patterns(context);
+      vector::ContractionOp::getCanonicalizationPatterns(patterns, context);
+      populateCombineVectorTransferReadBroadcastPatterns(patterns);
+      populatePrepareVectorToMMAPatterns(patterns,
                                          /*useNvGPU=*/false);
-      if (failed(applyPatternsAndFoldGreedily(
-              funcOp, std::move(canonicalizationPatterns)))) {
+      if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
         return signalPassFailure();
       }
     }
 
-    LLVM_DEBUG({
-      llvm::dbgs() << "--- After vectorization ---\n";
-      funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
-      llvm::dbgs() << "\n\n";
-    });
+    debugPrint(funcOp, "after preparing vector ops");
 
     {
-      RewritePatternSet vectorUnrollPatterns(context);
-      populateVectorUnrollPatterns(cooperativeOpSize, vectorUnrollPatterns);
-      if (failed(applyPatternsAndFoldGreedily(
-              funcOp, std::move(vectorUnrollPatterns)))) {
+      RewritePatternSet patterns(context);
+      populateVectorUnrollPatterns(cooperativeOpSize, patterns);
+      if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
         return signalPassFailure();
       }
     }
 
-    LLVM_DEBUG({
-      llvm::dbgs() << "--- After unrolling vector ---\n";
-      funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
-      llvm::dbgs() << "\n\n";
-    });
-
-    // At the last perform various canonicalization and cleanups.
-
-    linalg::hoistRedundantVectorTransfers(funcOp);
-
-    LLVM_DEBUG({
-      llvm::dbgs() << "--- After hoisting vector transfers ---\n";
-      funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
-      llvm::dbgs() << "\n\n";
-    });
+    debugPrint(funcOp, "after unrolling vector ops");
 
     // When using cooperative matrix we don't want to lower the contract,
     // instead we want to merge contract and transpose so that they can be
     // converted to cooperative matrix matmul op.
-    RewritePatternSet combineTransposePatterns(context);
-    combineTransposePatterns.add<CombineContractTranspose>(context);
-    if (failed(applyPatternsAndFoldGreedily(
-            funcOp, std::move(combineTransposePatterns)))) {
-      return signalPassFailure();
+    {
+      RewritePatternSet patterns(context);
+      patterns.add<CombineContractTranspose>(context);
+      if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+        return signalPassFailure();
+      }
     }
 
-    LLVM_DEBUG({
-      llvm::dbgs() << "--- After handling transposes ---\n";
-      funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
-      llvm::dbgs() << "\n\n";
-    });
+    debugPrint(funcOp, "after combining transpose ops");
   }
 };
 
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Utils.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Utils.cpp
index 32f010b..39c2050 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Utils.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Utils.cpp
@@ -4,11 +4,7 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-//===- Utils.cpp - Utility functions used in Linalg to SPIR-V lowering ----===//
-//
-// Implementaiton of utility functions used while lowering from Linalg to SPIRV.
-//
-//===----------------------------------------------------------------------===//
+//===- Utils.cpp - Utility functions for SPIR-V CodeGen -------------------===//
 
 #include "iree/compiler/Codegen/SPIRV/Utils.h"
 
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Utils.h b/compiler/src/iree/compiler/Codegen/SPIRV/Utils.h
index 1c57a2a..2bbeda7 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Utils.h
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Utils.h
@@ -21,6 +21,9 @@
 namespace mlir {
 namespace iree_compiler {
 
+/// Returns the attribute name carrying information about distribution.
+const char *getSPIRVDistributeAttrName();
+
 /// Given an operation, returns the `spirv.target_env` attribute.
 spirv::TargetEnvAttr getSPIRVTargetEnvAttr(Operation *op);
 
@@ -29,9 +32,6 @@
 /// environment. Returns std::nullopt on failures.
 std::optional<int> getSPIRVSubgroupSize(func::FuncOp funcOp);
 
-/// Returns the attribute name carrying information about distribution.
-const char *getSPIRVDistributeAttrName();
-
 /// Returns the tile sizes at the given `tilingLevel` for compute ops in
 /// `funcOp`.
 FailureOr<SmallVector<int64_t>> getSPIRVTileSize(func::FuncOp funcOp,
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_to_cooperative_ops.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_to_cooperative_ops.mlir
index 4299982..e62e05a 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_to_cooperative_ops.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_to_cooperative_ops.mlir
@@ -1,5 +1,5 @@
 // RUN: iree-opt --split-input-file \
-// RUN:   --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-tile-to-cooperative-ops, iree-spirv-vectorize-to-cooperative-ops, canonicalize, cse)))))' \
+// RUN:   --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-tile-to-cooperative-ops, iree-codegen-generic-vectorization, iree-spirv-vectorize-to-cooperative-ops, iree-codegen-hoist-redundant-vector-transfers, canonicalize, cse)))))' \
 // RUN:   %s | FileCheck %s
 
 #config = #iree_codegen.lowering_config<tile_sizes = [[32, 32], [16, 16], [0, 0, 32], [16, 16, 16]]>