[NFC] Retire filter based vectorization patterns. (#15865)
This revision deletes filter based vectorization patterns. All the
vectorization goes through upstream API directly.
ci-extra: test_a100
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeSharedMemoryCopy.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeSharedMemoryCopy.cpp
index c8ade3b..04c8c95 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeSharedMemoryCopy.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeSharedMemoryCopy.cpp
@@ -29,9 +29,6 @@
#define DEBUG_TYPE "iree-codegen-gpu-distribute-shared-memory-copy"
-using mlir::iree_compiler::IREE::LinalgExt::LinalgVectorizationPattern;
-using mlir::iree_compiler::IREE::LinalgExt::VectorizationPatterns;
-
/// Prints the given `funcOp` after a leading `step` comment header.
void debugPrint(mlir::func::FuncOp funcOp, const char *step) {
LLVM_DEBUG({
@@ -273,14 +270,21 @@
StringAttr::get(patterns.getContext(), kCopyDistributed)));
}
-static void populateVectorizationPatterns(RewritePatternSet &patterns) {
- VectorizationPatterns<linalg::GenericOp>::insert(
- patterns, IREE::LinalgExt::LinalgVectorizationOptions(),
- IREE::LinalgExt::LinalgTransformationFilter(
- {StringAttr::get(patterns.getContext(),
- getCopyToWorkgroupMemoryMarker()),
- StringAttr::get(patterns.getContext(), kCopyDistributed)},
- std::nullopt));
+/// Vectorizes generic ops that have CopyToWorkgroupMemoryMarker or
+// `kCopyDistributed` marker.
+static void vectorizeCopyToWorkgroupMemoryOps(func::FuncOp funcOp) {
+ MLIRContext *context = funcOp.getContext();
+ IRRewriter rewriter(context);
+ auto filter = IREE::LinalgExt::LinalgTransformationFilter(
+ {StringAttr::get(context, getCopyToWorkgroupMemoryMarker()),
+ StringAttr::get(context, kCopyDistributed)},
+ std::nullopt);
+
+ funcOp.walk([&](linalg::GenericOp op) {
+ if (succeeded(filter.checkAndNotify(rewriter, op))) {
+ (void)linalg::vectorize(rewriter, op);
+ }
+ });
}
/// Return a flattened Id Value by combining the 3D gpu thread IDs.
@@ -435,12 +439,7 @@
debugPrint(funcOp, "After step 2: thread distribution");
// Step 3. Vectorize the distributed copies.
- RewritePatternSet vectorizationPatterns(context);
- populateVectorizationPatterns(vectorizationPatterns);
- if (failed(applyPatternsAndFoldGreedily(
- funcOp, std::move(vectorizationPatterns)))) {
- return signalPassFailure();
- }
+ vectorizeCopyToWorkgroupMemoryOps(funcOp);
debugPrint(funcOp, "After step 3: vectorization");
// Step4. Finally unroll all the loop created
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorCoreVectorization.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorCoreVectorization.cpp
index 845a4b8..144a99e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorCoreVectorization.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorCoreVectorization.cpp
@@ -15,6 +15,7 @@
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "llvm/Support/Debug.h"
#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
@@ -24,26 +25,27 @@
#define DEBUG_TYPE "iree-codegen-gpu-tensorcore-vectorization"
-using mlir::iree_compiler::IREE::LinalgExt::LinalgVectorizationPattern;
-using mlir::iree_compiler::IREE::LinalgExt::VectorizationPatterns;
-
namespace mlir::iree_compiler {
//====---------------------------------------------------------------------===//
// Patterns for vectorization
//====---------------------------------------------------------------------===//
-static void populateVectorizationPatterns(RewritePatternSet &patterns) {
+static void vectorizeLinalgOps(func::FuncOp funcOp) {
+ MLIRContext *context = funcOp.getContext();
+ IRRewriter rewriter(context);
IREE::LinalgExt::LinalgTransformationFilter f(
- StringAttr::get(patterns.getContext(), getVectorizeMarker()));
- IREE::LinalgExt::LinalgVectorizationOptions vectorizationOptions;
- VectorizationPatterns<linalg::FillOp, linalg::GenericOp>::insert(
- patterns, vectorizationOptions, f);
- patterns.add<LinalgVectorizationPattern>(
- patterns.getContext(), vectorizationOptions,
- f.addOpFilter<linalg::ContractionOpInterface>());
- vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
- vector::populateVectorReductionToContractPatterns(patterns);
+ StringAttr::get(context, getVectorizeMarker()));
+
+ funcOp.walk([&](Operation *op) {
+ if (failed(f.checkAndNotify(rewriter, op)) ||
+ !isa<linalg::FillOp, linalg::GenericOp, linalg::ContractionOpInterface>(
+ op)) {
+ return WalkResult::advance();
+ }
+ (void)linalg::vectorize(rewriter, op);
+ return WalkResult::advance();
+ });
}
static void populateVectorUnrollPatterns(RewritePatternSet &patterns,
@@ -84,10 +86,13 @@
MLIRContext *context = &getContext();
{
// Step 1(a). Vectorize (linalg to vector).
- RewritePatternSet vectorizationPatterns(context);
- populateVectorizationPatterns(vectorizationPatterns);
+ vectorizeLinalgOps(funcOp);
+ RewritePatternSet contractionPatterns(context);
+ vector::populateVectorTransferPermutationMapLoweringPatterns(
+ contractionPatterns);
+ vector::populateVectorReductionToContractPatterns(contractionPatterns);
if (failed(applyPatternsAndFoldGreedily(
- funcOp, std::move(vectorizationPatterns)))) {
+ funcOp, std::move(contractionPatterns)))) {
return signalPassFailure();
}
LLVM_DEBUG({
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h
index 3ccacb8..9627486 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h
@@ -47,8 +47,8 @@
LinalgTransformationFilter(LinalgTransformationFilter &&) = default;
LinalgTransformationFilter(const LinalgTransformationFilter &) = default;
- LogicalResult checkAndNotify(PatternRewriter &rewriter, Operation *op) const;
- void replaceLinalgTransformationFilter(PatternRewriter &rewriter,
+ LogicalResult checkAndNotify(RewriterBase &rewriter, Operation *op) const;
+ void replaceLinalgTransformationFilter(RewriterBase &rewriter,
Operation *op) const;
bool hasReplacementFilter(Operation *op) const;
@@ -140,55 +140,6 @@
//===---------------------------------------------------------------------===//
// Codegen Strategy passes that are moved into IREE.
//===---------------------------------------------------------------------===//
-using VectorSizeComputationFunction =
- std::function<SmallVector<int64_t>(linalg::LinalgOp, ArrayRef<int64_t>)>;
-struct LinalgVectorizationOptions {
- /// Enable vector masking during vectorization.
- bool enableVectorMasking = false;
-
- LinalgVectorizationOptions &setEnableVectorMasking(bool val) {
- enableVectorMasking = val;
- return *this;
- }
-
- /// Canonical vector sizes for the vector iteration space (i.e., vectorization
- /// factors). They are optional for input code with full static shapes.
- SmallVector<int64_t> canonicalVectorSizes;
-
- LinalgVectorizationOptions &
- setCanonicalVectorSizes(ArrayRef<int64_t> vecSizes) {
- assert(canonicalVectorSizes.empty() &&
- "Canonical vector sizes are already set");
- canonicalVectorSizes.append(vecSizes.begin(), vecSizes.end());
- return *this;
- }
-
- /// Computation function that returns the vector sizes to vectorize a given
- /// Linalg operation and the canonical vector sizes of the iteration space.
- VectorSizeComputationFunction vectorSizeComputationFunction = nullptr;
-
- LinalgVectorizationOptions &
- setVectorSizeComputationFunction(VectorSizeComputationFunction fun) {
- vectorSizeComputationFunction = std::move(fun);
- return *this;
- }
-
- /// Enable vectorization of padding operations.
- bool vectorizePadding = false;
-
- LinalgVectorizationOptions &setVectorizePadding(bool vecPad) {
- vectorizePadding = vecPad;
- return *this;
- }
-
- /// Enable vectorization of gather accesses.
- bool vectorizeGatherAccesses = false;
-
- LinalgVectorizationOptions &setVectorizeGatherAccesses(bool vecGather) {
- vectorizeGatherAccesses = vecGather;
- return *this;
- }
-};
void registerPasses();
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h
index 70aba5c..d6ec010 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h
@@ -170,59 +170,6 @@
};
///
-/// Linalg vectorization patterns.
-///
-/// `filter` controls LinalgTransformMarker matching and update when specified.
-/// See `vectorizeLinalgOp` for more details.
-struct LinalgVectorizationPattern
- : public OpInterfaceRewritePattern<linalg::LinalgOp> {
- /// Construct a generic pattern applied to all LinalgOp that verify `filter`.
- LinalgVectorizationPattern(
- MLIRContext *context,
- LinalgVectorizationOptions opts = LinalgVectorizationOptions(),
- LinalgTransformationFilter f = LinalgTransformationFilter(),
- PatternBenefit benefit = 1);
-
- /// Construct a pattern specifically applied to `opName`.
- LinalgVectorizationPattern(
- StringRef opName, MLIRContext *context,
- LinalgVectorizationOptions opts = LinalgVectorizationOptions(),
- LinalgTransformationFilter f = LinalgTransformationFilter(),
- PatternBenefit benefit = 1);
-
- LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
- PatternRewriter &rewriter) const override;
-
-private:
- /// LinalgTransformMarker handles special attribute manipulations.
- LinalgVectorizationOptions options;
- LinalgTransformationFilter filter;
-};
-
-template <typename... OpTypes>
-class VectorizationPatterns;
-
-template <>
-class VectorizationPatterns<> {
-public:
- static void insert(RewritePatternSet &patterns,
- const LinalgVectorizationOptions &opts,
- const LinalgTransformationFilter &f) {}
-};
-
-template <typename OpTy, typename... OpTypes>
-class VectorizationPatterns<OpTy, OpTypes...> {
-public:
- static void insert(RewritePatternSet &patterns,
- const LinalgVectorizationOptions &opts,
- const LinalgTransformationFilter &f) {
- patterns.add<LinalgVectorizationPattern>(OpTy::getOperationName(),
- patterns.getContext(), opts, f);
- VectorizationPatterns<OpTypes...>::insert(patterns, opts, f);
- }
-};
-
-///
/// Linalg promotion patterns.
///
/// Apply the `promoteSubViews` transformation as a pattern.
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/Passes.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/Passes.cpp
index ba7408a..116f317 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/Passes.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/Passes.cpp
@@ -38,9 +38,8 @@
filters.push_back(f);
}
-LogicalResult
-LinalgTransformationFilter::checkAndNotify(PatternRewriter &rewriter,
- Operation *op) const {
+LogicalResult LinalgTransformationFilter::checkAndNotify(RewriterBase &rewriter,
+ Operation *op) const {
if (llvm::any_of(filters,
[&](const FilterFunction &f) { return failed(f(op)); }))
return failure();
@@ -73,7 +72,7 @@
}
void LinalgTransformationFilter::replaceLinalgTransformationFilter(
- PatternRewriter &rewriter, Operation *op) const {
+ RewriterBase &rewriter, Operation *op) const {
if (replacement.has_value())
op->setAttr(LinalgTransforms::kLinalgTransformMarker, replacement.value());
else
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp
index a344d54..cd63756 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp
@@ -100,32 +100,6 @@
return res;
}
-LinalgVectorizationPattern::LinalgVectorizationPattern(
- MLIRContext *context, LinalgVectorizationOptions opts,
- LinalgExt::LinalgTransformationFilter f, PatternBenefit benefit)
- : OpInterfaceRewritePattern<linalg::LinalgOp>(context, benefit),
- options(std::move(opts)), filter(std::move(f)) {}
-
-LinalgVectorizationPattern::LinalgVectorizationPattern(
- StringRef opName, MLIRContext *context, LinalgVectorizationOptions opts,
- LinalgExt::LinalgTransformationFilter f, PatternBenefit benefit)
- : OpInterfaceRewritePattern<linalg::LinalgOp>(context, benefit),
- options(std::move(opts)), filter(f.addOpNameFilter(opName)) {}
-
-LogicalResult
-LinalgVectorizationPattern::matchAndRewrite(linalg::LinalgOp linalgOp,
- PatternRewriter &rewriter) const {
- if (failed(filter.checkAndNotify(rewriter, linalgOp)))
- return failure();
- SmallVector<int64_t> vectorSizes;
- if (options.enableVectorMasking)
- vectorSizes.append(options.vectorSizeComputationFunction(
- linalgOp, options.canonicalVectorSizes));
- SmallVector<bool> scalableVecDims(vectorSizes.size(), false);
- return vectorize(rewriter, linalgOp, vectorSizes, scalableVecDims,
- options.vectorizeGatherAccesses);
-}
-
} // namespace LinalgExt
} // namespace IREE
} // namespace iree_compiler