[spirv] NFC: Restructure TileAndVectorizeToCooperativeOps pass (#15138)
This commit lifts the general vectorization and hoisting steps out
of the `TileAndVectorizeToCooperativeOps` pass and switches to the
common `GenericVectorization` and `HoistRedundantVectorTransfers`
passes for that functionality.
Progress towards https://github.com/openxla/iree/issues/15083
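
Roughly, the cooperative matrix portion of the pipeline now sequences
the common and SPIR-V-specific passes as follows (a simplified sketch
based on the diff below; the surrounding canonicalization, CSE, and
loop-cleanup passes are omitted):

  GenericVectorizationPassOptions options;
  nestedModulePM.addNestedPass<func::FuncOp>(
      createGenericVectorizationPass(options));
  nestedModulePM.addNestedPass<func::FuncOp>(
      createSPIRVVectorizeToCooperativeOpsPass());
  nestedModulePM.addNestedPass<func::FuncOp>(
      createHoistRedundantVectorTransfersPass());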
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
index f3efc4d..35c8aa0 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
@@ -353,9 +353,10 @@
nestedModulePM.addPass(createCSEPass());
// Multi-buffer depending on pipeline depth and distribute to shared memory.
- if (pipelineDepth > 0)
+ if (pipelineDepth > 0) {
nestedModulePM.addNestedPass<func::FuncOp>(
createGPUMultiBuffering(pipelineDepth + 1));
+ }
nestedModulePM.addNestedPass<func::FuncOp>(createMemrefCopyToLinalgPass());
nestedModulePM.addNestedPass<func::FuncOp>(
createGPUDistributeSharedMemoryCopy());
@@ -365,10 +366,20 @@
createGPUReduceSharedMemoryBankConflicts(
detail::bankConflictReductionPaddingBits));
+ // Performs high-level n-D mechanical vectorization. This does not perform
+ // unrolling or lowering, which is done later.
+ {
+ GenericVectorizationPassOptions options;
+ nestedModulePM.addNestedPass<func::FuncOp>(
+ createGenericVectorizationPass(options));
+ }
+
// Vectorize to cooperative ops.
nestedModulePM.addNestedPass<func::FuncOp>(
createSPIRVVectorizeToCooperativeOpsPass());
nestedModulePM.addNestedPass<func::FuncOp>(
+ createHoistRedundantVectorTransfersPass());
+ nestedModulePM.addNestedPass<func::FuncOp>(
createRemoveSingleIterationLoopPass());
// Run canonicalization patterns to propagate constant shape sizes after
@@ -420,7 +431,8 @@
nestedPM.addNestedPass<func::FuncOp>(
createGPUTensorTile(/*distributeToWarp=*/false));
- // High-level n-D vectorization.
+ // Performs high-level n-D mechanical vectorization. This does not perform
+ // unrolling or lowering, which is done later.
{
GenericVectorizationPassOptions options;
options.vectorizePadding = true;
@@ -531,8 +543,8 @@
nestedModulePM.addNestedPass<func::FuncOp>(createCanonicalizerPass());
nestedModulePM.addNestedPass<func::FuncOp>(createCSEPass());
- // Performs mechanical vectorization. This does not perform unrolling or
- // lowering, which is done later.
+ // Performs high-level n-D mechanical vectorization. This does not perform
+ // unrolling or lowering, which is done later.
{
GenericVectorizationPassOptions options;
options.vectorizePadding = true;
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp
index 3f3e52d..58fc4ef 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp
@@ -54,8 +54,16 @@
namespace iree_compiler {
namespace {
+void debugPrint(func::FuncOp funcOp, const char *message) {
+ LLVM_DEBUG({
+ llvm::dbgs() << "//--- " << message << " ---//\n";
+ funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+ llvm::dbgs() << "\n\n";
+ });
+}
+
//===----------------------------------------------------------------------===//
-// Subgroup tiling patterns
+// Cooperative matrix shape utilities
//===----------------------------------------------------------------------===//
/// Gets the chosen hardware cooperative op size attached to the given `op`
@@ -64,6 +72,35 @@
return getTileSizes(op, 3); // For native vector sizes
}
+constexpr char coopMatShapeAttrName[] = "iree.spirv.coop_mat_shape";
+
+/// Sets the chosen cooperative matrix shape for CodeGen onto the
+/// hal.executable.export op for the given `funcOp`.
+void setSPIRVCooperativeMatrixShape(func::FuncOp funcOp,
+ ArrayRef<int64_t> shape) {
+ auto moduleOp = funcOp->getParentOfType<ModuleOp>();
+ auto exportOp = getAllEntryPoints(moduleOp).lookup(funcOp.getName());
+
+ Builder b(funcOp.getContext());
+ exportOp->setAttr(coopMatShapeAttrName, b.getDenseI64ArrayAttr(shape));
+}
+
+/// Returns the chosen cooperative matrix shape for CodeGen from the
+/// hal.executable.export op for the given `funcOp`. Returns an empty
+/// ArrayRef if the shape cannot be queried.
+ArrayRef<int64_t> getSPIRVCooperativeMatrixShape(func::FuncOp funcOp) {
+ auto moduleOp = funcOp->getParentOfType<ModuleOp>();
+ auto exportOp = getAllEntryPoints(moduleOp).lookup(funcOp.getName());
+ auto attr = exportOp->getAttrOfType<DenseI64ArrayAttr>(coopMatShapeAttrName);
+ if (!attr)
+ return {};
+ return attr.asArrayRef();
+}
+
+//===----------------------------------------------------------------------===//
+// Subgroup tiling patterns
+//===----------------------------------------------------------------------===//
+
/// Deduces required subgroup counts along all workgroup tiled dimensions.
///
/// `op` should be an operation with a `lowering_config` attribute to specify
@@ -335,7 +372,12 @@
return signalPassFailure();
}
+ // Transfer the cooperative matrix shape to an attribute on the export op,
+ // given that after tiling and vectorization we won't have the root Linalg
+ // op anymore.
SmallVector<int64_t> cooperativeOpSize = getTargetCooperativeOpSize(rootOp);
+ setSPIRVCooperativeMatrixShape(funcOp, cooperativeOpSize);
+
SmallVector<int64_t> subgroupCounts = deduceSubgroupCounts(rootOp);
// Then tile and distribute to subgroups.
@@ -367,11 +409,7 @@
}
}
- LLVM_DEBUG({
- llvm::dbgs() << "--- After tiling to subgroups ---\n";
- funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
- llvm::dbgs() << "\n\n";
- });
+ debugPrint(funcOp, "after tiling to subgroups");
}
};
@@ -388,94 +426,53 @@
MLIRContext *context = &getContext();
func::FuncOp funcOp = getOperation();
- // First we need to discover the CodeGen lowering configuration. It was
- // decided earlier and attached to a linalg op as an attribute.
-
- linalg::LinalgOp rootOp;
- funcOp.walk([&](linalg::LinalgOp linalgOp) {
- if (isMatmulOrBatchMatmul(linalgOp) && getLoweringConfig(linalgOp)) {
- rootOp = linalgOp;
- return WalkResult::interrupt();
- }
- return WalkResult::advance();
- });
- if (!rootOp) {
- funcOp.emitError("expected lowering confg on a (batch) matmul op");
+ // First discover the chosen cooperative matrix shape. It was decided
+ // earlier and attached to the export op as an attribute.
+ ArrayRef<int64_t> cooperativeOpSize =
+ getSPIRVCooperativeMatrixShape(funcOp);
+ if (cooperativeOpSize.empty()) {
+ funcOp->emitError(
+ "expected attribute for chosen cooperative matrix shape");
return signalPassFailure();
}
- SmallVector<int64_t> cooperativeOpSize = getTargetCooperativeOpSize(rootOp);
- SmallVector<int64_t> subgroupCounts = deduceSubgroupCounts(rootOp);
-
- // Now vectorize and unroll to native cooperative sizes.
+  // Now prepare vector ops and unroll them to native cooperative sizes.
{
- RewritePatternSet vectorizationPatterns(context);
- populateVectorizationPatterns(context, vectorizationPatterns);
- if (failed(applyPatternsAndFoldGreedily(
- funcOp, std::move(vectorizationPatterns)))) {
- return signalPassFailure();
- }
-
- RewritePatternSet canonicalizationPatterns(context);
- vector::ContractionOp::getCanonicalizationPatterns(
- canonicalizationPatterns, context);
- populateCombineVectorTransferReadBroadcastPatterns(
- canonicalizationPatterns);
- populatePrepareVectorToMMAPatterns(canonicalizationPatterns,
+ RewritePatternSet patterns(context);
+ vector::ContractionOp::getCanonicalizationPatterns(patterns, context);
+ populateCombineVectorTransferReadBroadcastPatterns(patterns);
+ populatePrepareVectorToMMAPatterns(patterns,
/*useNvGPU=*/false);
- if (failed(applyPatternsAndFoldGreedily(
- funcOp, std::move(canonicalizationPatterns)))) {
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
}
- LLVM_DEBUG({
- llvm::dbgs() << "--- After vectorization ---\n";
- funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
- llvm::dbgs() << "\n\n";
- });
+ debugPrint(funcOp, "after preparing vector ops");
{
- RewritePatternSet vectorUnrollPatterns(context);
- populateVectorUnrollPatterns(cooperativeOpSize, vectorUnrollPatterns);
- if (failed(applyPatternsAndFoldGreedily(
- funcOp, std::move(vectorUnrollPatterns)))) {
+ RewritePatternSet patterns(context);
+ populateVectorUnrollPatterns(cooperativeOpSize, patterns);
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
}
- LLVM_DEBUG({
- llvm::dbgs() << "--- After unrolling vector ---\n";
- funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
- llvm::dbgs() << "\n\n";
- });
-
- // At the last perform various canonicalization and cleanups.
-
- linalg::hoistRedundantVectorTransfers(funcOp);
-
- LLVM_DEBUG({
- llvm::dbgs() << "--- After hoisting vector transfers ---\n";
- funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
- llvm::dbgs() << "\n\n";
- });
+ debugPrint(funcOp, "after unrolling vector ops");
// When using cooperative matrix we don't want to lower the contract,
// instead we want to merge contract and transpose so that they can be
// converted to cooperative matrix matmul op.
- RewritePatternSet combineTransposePatterns(context);
- combineTransposePatterns.add<CombineContractTranspose>(context);
- if (failed(applyPatternsAndFoldGreedily(
- funcOp, std::move(combineTransposePatterns)))) {
- return signalPassFailure();
+ {
+ RewritePatternSet patterns(context);
+ patterns.add<CombineContractTranspose>(context);
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+ return signalPassFailure();
+ }
}
- LLVM_DEBUG({
- llvm::dbgs() << "--- After handling transposes ---\n";
- funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
- llvm::dbgs() << "\n\n";
- });
+ debugPrint(funcOp, "after combining transpose ops");
}
};
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Utils.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Utils.cpp
index 32f010b..39c2050 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Utils.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Utils.cpp
@@ -4,11 +4,7 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//===- Utils.cpp - Utility functions used in Linalg to SPIR-V lowering ----===//
-//
-// Implementaiton of utility functions used while lowering from Linalg to SPIRV.
-//
-//===----------------------------------------------------------------------===//
+//===- Utils.cpp - Utility functions for SPIR-V CodeGen -------------------===//
#include "iree/compiler/Codegen/SPIRV/Utils.h"
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Utils.h b/compiler/src/iree/compiler/Codegen/SPIRV/Utils.h
index 1c57a2a..2bbeda7 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Utils.h
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Utils.h
@@ -21,6 +21,9 @@
namespace mlir {
namespace iree_compiler {
+/// Returns the attribute name carrying information about distribution.
+const char *getSPIRVDistributeAttrName();
+
/// Given an operation, returns the `spirv.target_env` attribute.
spirv::TargetEnvAttr getSPIRVTargetEnvAttr(Operation *op);
@@ -29,9 +32,6 @@
/// environment. Returns std::nullopt on failures.
std::optional<int> getSPIRVSubgroupSize(func::FuncOp funcOp);
-/// Returns the attribute name carrying information about distribution.
-const char *getSPIRVDistributeAttrName();
-
/// Returns the tile sizes at the given `tilingLevel` for compute ops in
/// `funcOp`.
FailureOr<SmallVector<int64_t>> getSPIRVTileSize(func::FuncOp funcOp,
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_to_cooperative_ops.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_to_cooperative_ops.mlir
index 4299982..e62e05a 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_to_cooperative_ops.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_to_cooperative_ops.mlir
@@ -1,5 +1,5 @@
// RUN: iree-opt --split-input-file \
-// RUN: --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-tile-to-cooperative-ops, iree-spirv-vectorize-to-cooperative-ops, canonicalize, cse)))))' \
+// RUN: --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-tile-to-cooperative-ops, iree-codegen-generic-vectorization, iree-spirv-vectorize-to-cooperative-ops, iree-codegen-hoist-redundant-vector-transfers, canonicalize, cse)))))' \
// RUN: %s | FileCheck %s
#config = #iree_codegen.lowering_config<tile_sizes = [[32, 32], [16, 16], [0, 0, 32], [16, 16, 16]]>