[Cleanup] Retire filter-based vectorization patterns. (#15185)
The last usage is in GPUDistributeSharedMemoryCopy. The revision
replaces it with calling vectorize method in function walk.
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeSharedMemoryCopy.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeSharedMemoryCopy.cpp
index 3de3216..10d4265 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeSharedMemoryCopy.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeSharedMemoryCopy.cpp
@@ -7,7 +7,6 @@
#include <algorithm>
#include <numeric>
-#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Common/GPU/PassDetail.h"
#include "iree/compiler/Codegen/Common/GPU/Passes.h"
@@ -29,9 +28,6 @@
#define DEBUG_TYPE "iree-codegen-gpu-distribute-shared-memory-copy"
-using mlir::iree_compiler::IREE::LinalgExt::LinalgVectorizationPattern;
-using mlir::iree_compiler::IREE::LinalgExt::VectorizationPatterns;
-
/// Prints the given `funcOp` after a leading `step` comment header.
void debugPrint(mlir::func::FuncOp funcOp, const char *step) {
LLVM_DEBUG({
@@ -274,14 +270,17 @@
StringAttr::get(patterns.getContext(), kCopyDistributed)));
}
-static void populateVectorizationPatterns(RewritePatternSet &patterns) {
- VectorizationPatterns<linalg::GenericOp>::insert(
- patterns, IREE::LinalgExt::LinalgVectorizationOptions(),
- IREE::LinalgExt::LinalgTransformationFilter(
- {StringAttr::get(patterns.getContext(),
- getCopyToWorkgroupMemoryMarker()),
- StringAttr::get(patterns.getContext(), kCopyDistributed)},
- std::nullopt));
+static void vectorizeDistributedCopies(func::FuncOp funcOp) {
+ IRRewriter rewriter(funcOp.getContext());
+ SmallVector<linalg::GenericOp> candidates;
+ funcOp.walk([&](linalg::GenericOp op) { candidates.push_back(op); });
+ for (auto op : candidates) {
+ SmallVector<int64_t> vectorSizes;
+ SmallVector<bool> scalableVecDims;
+ scalableVecDims.resize(vectorSizes.size());
+ (void)linalg::vectorize(rewriter, op, vectorSizes, scalableVecDims,
+ /*vectorizeGatherAccesses=*/true);
+ };
}
/// Return a flattened Id Value by combining the 3D gpu thread IDs.
@@ -436,12 +435,7 @@
debugPrint(funcOp, "After step 2: thread distribution");
// Step 3. Vectorize the distributed copies.
- RewritePatternSet vectorizationPatterns(context);
- populateVectorizationPatterns(vectorizationPatterns);
- if (failed(applyPatternsAndFoldGreedily(
- funcOp, std::move(vectorizationPatterns)))) {
- return signalPassFailure();
- }
+ vectorizeDistributedCopies(funcOp);
debugPrint(funcOp, "After step 3: vectorization");
// Step4. Finally unroll all the loop created
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h
index 70aba5c..d6ec010 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h
@@ -170,59 +170,6 @@
};
///
-/// Linalg vectorization patterns.
-///
-/// `filter` controls LinalgTransformMarker matching and update when specified.
-/// See `vectorizeLinalgOp` for more details.
-struct LinalgVectorizationPattern
- : public OpInterfaceRewritePattern<linalg::LinalgOp> {
- /// Construct a generic pattern applied to all LinalgOp that verify `filter`.
- LinalgVectorizationPattern(
- MLIRContext *context,
- LinalgVectorizationOptions opts = LinalgVectorizationOptions(),
- LinalgTransformationFilter f = LinalgTransformationFilter(),
- PatternBenefit benefit = 1);
-
- /// Construct a pattern specifically applied to `opName`.
- LinalgVectorizationPattern(
- StringRef opName, MLIRContext *context,
- LinalgVectorizationOptions opts = LinalgVectorizationOptions(),
- LinalgTransformationFilter f = LinalgTransformationFilter(),
- PatternBenefit benefit = 1);
-
- LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
- PatternRewriter &rewriter) const override;
-
-private:
- /// LinalgTransformMarker handles special attribute manipulations.
- LinalgVectorizationOptions options;
- LinalgTransformationFilter filter;
-};
-
-template <typename... OpTypes>
-class VectorizationPatterns;
-
-template <>
-class VectorizationPatterns<> {
-public:
- static void insert(RewritePatternSet &patterns,
- const LinalgVectorizationOptions &opts,
- const LinalgTransformationFilter &f) {}
-};
-
-template <typename OpTy, typename... OpTypes>
-class VectorizationPatterns<OpTy, OpTypes...> {
-public:
- static void insert(RewritePatternSet &patterns,
- const LinalgVectorizationOptions &opts,
- const LinalgTransformationFilter &f) {
- patterns.add<LinalgVectorizationPattern>(OpTy::getOperationName(),
- patterns.getContext(), opts, f);
- VectorizationPatterns<OpTypes...>::insert(patterns, opts, f);
- }
-};
-
-///
/// Linalg promotion patterns.
///
/// Apply the `promoteSubViews` transformation as a pattern.
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp
index a344d54..cd63756 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp
@@ -100,32 +100,6 @@
return res;
}
-LinalgVectorizationPattern::LinalgVectorizationPattern(
- MLIRContext *context, LinalgVectorizationOptions opts,
- LinalgExt::LinalgTransformationFilter f, PatternBenefit benefit)
- : OpInterfaceRewritePattern<linalg::LinalgOp>(context, benefit),
- options(std::move(opts)), filter(std::move(f)) {}
-
-LinalgVectorizationPattern::LinalgVectorizationPattern(
- StringRef opName, MLIRContext *context, LinalgVectorizationOptions opts,
- LinalgExt::LinalgTransformationFilter f, PatternBenefit benefit)
- : OpInterfaceRewritePattern<linalg::LinalgOp>(context, benefit),
- options(std::move(opts)), filter(f.addOpNameFilter(opName)) {}
-
-LogicalResult
-LinalgVectorizationPattern::matchAndRewrite(linalg::LinalgOp linalgOp,
- PatternRewriter &rewriter) const {
- if (failed(filter.checkAndNotify(rewriter, linalgOp)))
- return failure();
- SmallVector<int64_t> vectorSizes;
- if (options.enableVectorMasking)
- vectorSizes.append(options.vectorSizeComputationFunction(
- linalgOp, options.canonicalVectorSizes));
- SmallVector<bool> scalableVecDims(vectorSizes.size(), false);
- return vectorize(rewriter, linalgOp, vectorSizes, scalableVecDims,
- options.vectorizeGatherAccesses);
-}
-
} // namespace LinalgExt
} // namespace IREE
} // namespace iree_compiler