[LinalgExt] Switch to new pass generation tablegen definitions. (#18216)
The revision applies few cleanups:
- Remove outdated pass constructors in DecomposeAttention pass. The
`tileOnly` option is not used at all.
- Add dummy summery to ConvertAttentionToOnlineAttentionPass
- Switch namespaces to the single-line syntax for few passes.
Signed-off-by: hanhanW <hanhan0912@gmail.com>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/BUILD.bazel
index 50d22f5..3bba99a 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/BUILD.bazel
@@ -38,7 +38,6 @@
"DecomposeIm2col.cpp",
"DecomposeWinogradPass.cpp",
"PadContractionToBlockSize.cpp",
- "PassDetail.h",
"Passes.cpp",
"SplitReduction.cpp",
"TileAttention.cpp",
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt
index a1de8ce..aecda4b 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt
@@ -34,7 +34,6 @@
"DecomposeIm2col.cpp"
"DecomposeWinogradPass.cpp"
"PadContractionToBlockSize.cpp"
- "PassDetail.h"
"Passes.cpp"
"SplitReduction.cpp"
"TileAttention.cpp"
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp
index 0e7b3b7..31daf1f 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp
@@ -6,7 +6,6 @@
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
-#include "iree/compiler/Dialect/LinalgExt/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -16,6 +15,9 @@
namespace mlir::iree_compiler::IREE::LinalgExt {
+#define GEN_PASS_DEF_CONVERTCONV2DTOIM2COLOPPASS
+#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h.inc"
+
static bool hasAllOneValues(DenseIntElementsAttr attr) {
return llvm::all_of(
attr, [](APInt element) { return element.getSExtValue() == 1; });
@@ -322,8 +324,8 @@
ControlFnTy controlFn;
};
-struct ConvertConv2DToIm2ColOpPass
- : ConvertConv2DToIm2ColOpBase<ConvertConv2DToIm2ColOpPass> {
+struct ConvertConv2DToIm2ColOpPass final
+ : impl::ConvertConv2DToIm2ColOpPassBase<ConvertConv2DToIm2ColOpPass> {
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<tensor::TensorDialect, IREELinalgExtDialect>();
}
@@ -345,9 +347,4 @@
patterns.getContext(), std::move(controlFn));
}
-std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createConvertConv2DToIm2ColOpPass() {
- return std::make_unique<ConvertConv2DToIm2ColOpPass>();
-}
-
} // namespace mlir::iree_compiler::IREE::LinalgExt
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToWinograd.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToWinograd.cpp
index 4fc370e..e95c7dd 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToWinograd.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToWinograd.cpp
@@ -6,7 +6,6 @@
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
-#include "iree/compiler/Dialect/LinalgExt/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/WinogradConstants.h"
@@ -27,6 +26,9 @@
namespace mlir::iree_compiler::IREE::LinalgExt {
+#define GEN_PASS_DEF_CONVERTCONV2DTOWINOGRADPASS
+#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h.inc"
+
static const char kWinogradAttr[] = "__winograd_conv";
static bool hasAllOneValues(DenseIntElementsAttr attr) {
@@ -403,8 +405,11 @@
/// }
/// }
/// ```
-struct ConvertConv2DToWinogradPass
- : ConvertConv2DToWinogradBase<ConvertConv2DToWinogradPass> {
+struct ConvertConv2DToWinogradPass final
+ : impl::ConvertConv2DToWinogradPassBase<ConvertConv2DToWinogradPass> {
+ using impl::ConvertConv2DToWinogradPassBase<
+ ConvertConv2DToWinogradPass>::ConvertConv2DToWinogradPassBase;
+
void getDependentDialects(DialectRegistry ®istry) const override {
registry
.insert<linalg::LinalgDialect, IREE::LinalgExt::IREELinalgExtDialect>();
@@ -423,10 +428,4 @@
};
} // namespace
-
-std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createConvertConv2DToWinogradPass() {
- return std::make_unique<ConvertConv2DToWinogradPass>();
-}
-
} // namespace mlir::iree_compiler::IREE::LinalgExt
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertToLoops.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertToLoops.cpp
index b134113..c6bc20b 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertToLoops.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertToLoops.cpp
@@ -6,7 +6,6 @@
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
-#include "iree/compiler/Dialect/LinalgExt/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
@@ -23,9 +22,10 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-using namespace mlir;
-namespace IREE = mlir::iree_compiler::IREE;
-using namespace IREE::LinalgExt;
+namespace mlir::iree_compiler::IREE::LinalgExt {
+
+#define GEN_PASS_DEF_LINALGEXTTOLOOPSPASS
+#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h.inc"
/// Recursive method that lowers one dimension of the `TiledOpInterface` to
/// scalar loops at a time.
@@ -100,8 +100,8 @@
//===----------------------------------------------------------------------===//
namespace {
-struct LinalgExtToLoopsPass
- : public LinalgExtToLoopsBase<LinalgExtToLoopsPass> {
+struct LinalgExtToLoopsPass final
+ : impl::LinalgExtToLoopsPassBase<LinalgExtToLoopsPass> {
void getDependentDialects(DialectRegistry ®istry) const override {
registry
.insert<linalg::LinalgDialect, mlir::arith::ArithDialect,
@@ -120,8 +120,4 @@
}
};
} // namespace
-
-std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-IREE::LinalgExt::createLinalgExtToLoopsPass() {
- return std::make_unique<LinalgExtToLoopsPass>();
-}
+} // namespace mlir::iree_compiler::IREE::LinalgExt
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp
index daf50dc..b1ff512 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp
@@ -6,7 +6,6 @@
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
-#include "iree/compiler/Dialect/LinalgExt/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -16,6 +15,9 @@
namespace mlir::iree_compiler::IREE::LinalgExt {
+#define GEN_PASS_DEF_DECOMPOSEATTENTIONPASS
+#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h.inc"
+
namespace {
// Computes a reduction along the rows of a 2d tensor of shape MxN
@@ -337,20 +339,16 @@
}
namespace {
-struct DecomposeAttentionPass
- : public DecomposeAttentionBase<DecomposeAttentionPass> {
+struct DecomposeAttentionPass final
+ : impl::DecomposeAttentionPassBase<DecomposeAttentionPass> {
+ using impl::DecomposeAttentionPassBase<
+ DecomposeAttentionPass>::DecomposeAttentionPassBase;
+
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<
affine::AffineDialect, IREE::LinalgExt::IREELinalgExtDialect,
linalg::LinalgDialect, scf::SCFDialect, tensor::TensorDialect>();
}
- DecomposeAttentionPass() = default;
- DecomposeAttentionPass(bool onlyTile, uint64_t tileSize) {
- this->tileSize = tileSize;
- }
- DecomposeAttentionPass(const DecomposeAttentionPass &pass) {
- tileSize = pass.tileSize;
- }
void runOnOperation() override;
};
} // namespace
@@ -377,9 +375,4 @@
rewriter.replaceOp(onlineAtt, results.value());
});
}
-
-std::unique_ptr<Pass> createDecomposeAttentionPass() {
- return std::make_unique<DecomposeAttentionPass>();
-}
-
} // namespace mlir::iree_compiler::IREE::LinalgExt
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeIm2col.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeIm2col.cpp
index a8fc5f2..cb931ec 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeIm2col.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeIm2col.cpp
@@ -6,7 +6,6 @@
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
-#include "iree/compiler/Dialect/LinalgExt/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -16,6 +15,10 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir::iree_compiler::IREE::LinalgExt {
+
+#define GEN_PASS_DEF_DECOMPOSEIM2COLPASS
+#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h.inc"
+
namespace {
/// Pattern to decompose the tiled im2col op.
@@ -37,7 +40,8 @@
} // namespace
namespace {
-struct DecomposeIm2colPass : public DecomposeIm2colBase<DecomposeIm2colPass> {
+struct DecomposeIm2colPass final
+ : impl::DecomposeIm2colPassBase<DecomposeIm2colPass> {
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<
affine::AffineDialect, IREE::LinalgExt::IREELinalgExtDialect,
@@ -58,10 +62,4 @@
return signalPassFailure();
}
}
-
-std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createDecomposeIm2colPass() {
- return std::make_unique<DecomposeIm2colPass>();
-}
-
} // namespace mlir::iree_compiler::IREE::LinalgExt
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeWinogradPass.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeWinogradPass.cpp
index 01feda5..97d9687 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeWinogradPass.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeWinogradPass.cpp
@@ -6,7 +6,6 @@
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
-#include "iree/compiler/Dialect/LinalgExt/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/WinogradConstants.h"
@@ -24,6 +23,10 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir::iree_compiler::IREE::LinalgExt {
+
+#define GEN_PASS_DEF_DECOMPOSEWINOGRADTRANSFORMPASS
+#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h.inc"
+
namespace {
/// Pattern to remove unit dims from winograd ops after tililng. Tiling is
@@ -333,8 +336,8 @@
} // namespace
namespace {
-struct DecomposeWinogradTransformPass
- : public DecomposeWinogradTransformBase<DecomposeWinogradTransformPass> {
+struct DecomposeWinogradTransformPass final
+ : impl::DecomposeWinogradTransformPassBase<DecomposeWinogradTransformPass> {
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<
affine::AffineDialect, IREE::LinalgExt::IREELinalgExtDialect,
@@ -363,9 +366,4 @@
}
}
-std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createDecomposeWinogradTransformPass() {
- return std::make_unique<DecomposeWinogradTransformPass>();
-}
-
} // namespace mlir::iree_compiler::IREE::LinalgExt
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/PadContractionToBlockSize.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/PadContractionToBlockSize.cpp
index fa015ce..e479d2c 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/PadContractionToBlockSize.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/PadContractionToBlockSize.cpp
@@ -6,7 +6,6 @@
#include "iree-dialects/Dialect/Input/InputDialect.h"
#include "iree-dialects/Dialect/Input/InputOps.h"
-#include "iree/compiler/Dialect/LinalgExt/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -16,9 +15,10 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-using namespace mlir;
-namespace IREE = mlir::iree_compiler::IREE;
-using namespace IREE::LinalgExt;
+namespace mlir::iree_compiler::IREE::LinalgExt {
+
+#define GEN_PASS_DEF_PADCONTRACTIONTOBLOCKSIZEPASS
+#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h.inc"
static Operation *sliceTensor(Location loc, Value expanded, Value original,
OpBuilder &builder) {
@@ -87,8 +87,11 @@
namespace {
-struct PadContractionToBlockSizePass
- : public PadContractionToBlockSizeBase<PadContractionToBlockSizePass> {
+struct PadContractionToBlockSizePass final
+ : impl::PadContractionToBlockSizePassBase<PadContractionToBlockSizePass> {
+ using impl::PadContractionToBlockSizePassBase<
+ PadContractionToBlockSizePass>::PadContractionToBlockSizePassBase;
+
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<IREE::Input::IREEInputDialect>();
}
@@ -126,8 +129,4 @@
}
};
} // namespace
-
-std::unique_ptr<OperationPass<>>
-IREE::LinalgExt::createPadContractionToBlockSizePass() {
- return std::make_unique<PadContractionToBlockSizePass>();
-}
+} // namespace mlir::iree_compiler::IREE::LinalgExt
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/PassDetail.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/PassDetail.h
deleted file mode 100644
index 2d93ecc..0000000
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/PassDetail.h
+++ /dev/null
@@ -1,23 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_COMPILER_DIALECT_LINALGEXT_TRANSFORMS_PASS_DETAIL_H_
-#define IREE_COMPILER_DIALECT_LINALGEXT_TRANSFORMS_PASS_DETAIL_H_
-
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Interfaces/FunctionInterfaces.h"
-#include "mlir/Pass/Pass.h"
-
-namespace mlir::iree_compiler::IREE::LinalgExt {
-
-#define GEN_PASS_CLASSES
-
-#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h.inc" // IWYU pragma: keep
-
-} // namespace mlir::iree_compiler::IREE::LinalgExt
-
-#endif // IREE_COMPILER_DIALECT_LINALGEXT_TRANSFORMS_PASS_DETAIL_H_
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h
index 04c1e8f..af7716c 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h
@@ -16,11 +16,6 @@
namespace mlir::iree_compiler::IREE::LinalgExt {
-std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createLinalgExtToLoopsPass();
-
-std::unique_ptr<OperationPass<>> createPadContractionToBlockSizePass();
-
/// Function signature to control reduction splitting. This returns the split
/// reduction ratio used to split the reduction dimension. The ratio is applied
/// to the reduction dimension of TopK. If the ratio value is less or equal to 1
@@ -33,34 +28,12 @@
splitReduction(RewriterBase &rewriter, LinalgExt::TopkOp topkOp,
const TopkSplitReductionControlFn &splitReductionFn);
-std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createTopkSplitReductionPass();
-
-/// Decompose im2col ops into a serial loop of insert and extract slice ops.
-std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createDecomposeIm2colPass();
-
-/// Decompose the winograd transform ops into a sequence of linalg ops.
-std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createDecomposeWinogradTransformPass();
-
-// Creates a pass to convert linalg convolution ops into a gemm with an im2col
-// op and reshapes on the inputs.
-std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createConvertConv2DToIm2ColOpPass();
-
// Patterns to convert linalg convolution ops into a gemm with an im2col
// op and reshapes on the inputs.
void populateConv2DToIm2colOpPatterns(
RewritePatternSet &patterns,
std::optional<std::function<bool(Operation *)>> controlFn = std::nullopt);
-// Creates a pass to convert linalg convolution ops into a sequence of
-// linalg_ext.winograd.* ops and linalg.batch_matmul ops using the winograd
-// tranformation.
-std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createConvertConv2DToWinogradPass();
-
IREE::LinalgExt::AttentionOp
tileAttention(IREE::LinalgExt::AttentionOp attnOp,
SmallVectorImpl<Operation *> &ops, RewriterBase &rewriter,
@@ -75,18 +48,13 @@
SmallVectorImpl<Operation *> &ops,
RewriterBase &rewriter);
-// Creates a pass to tile the attention op along the reduction dim.
-std::unique_ptr<Pass> createTileAttentionPass();
-
-// Creates a pass to convert the attention op into a sequence of linalg ops.
-std::unique_ptr<Pass> createDecomposeAttentionPass();
-
-std::unique_ptr<Pass> createConvertAttentionToOnlineAttentionPass();
-
//===---------------------------------------------------------------------===//
-// Codegen Strategy passes that are moved into IREE.
+// Register LinalgExt Passes.
//===---------------------------------------------------------------------===//
+#define GEN_PASS_DECL
+#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h.inc" // IWYU pragma: keep
+
void registerPasses();
} // namespace mlir::iree_compiler::IREE::LinalgExt
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td
index 9597cc7..57b484f 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td
@@ -9,13 +9,12 @@
include "mlir/Pass/PassBase.td"
-def LinalgExtToLoops :
+def LinalgExtToLoopsPass :
InterfacePass<"iree-linalg-ext-to-loops", "mlir::FunctionOpInterface"> {
let summary = "Convert LinalgExt ops to loops and Linalg ops.";
- let constructor = "mlir::iree_compiler::IREE::LinalgExt::createLinalgExtToLoopsPass()";
}
-def PadContractionToBlockSize :
+def PadContractionToBlockSizePass :
Pass<"iree-linalg-pad-contraction-to-block-size", ""> {
let summary = "Pads contraction (matmul) ops to next multiple of block size";
let description = [{
@@ -29,7 +28,6 @@
of the dynamic case, applying this pass multiple times can result in
mutation on each run.
}];
- let constructor = "mlir::iree_compiler::IREE::LinalgExt::createPadContractionToBlockSizePass()";
let options = [
Option<"rowAlignment", "rowAlignment", "int", /*default=*/"16",
"The row-wise output block size">,
@@ -38,7 +36,7 @@
];
}
-def TopkSplitReduction:
+def TopkSplitReductionPass:
InterfacePass<"iree-linalg-ext-topk-split-reduction", "mlir::FunctionOpInterface"> {
let summary = "Topk split reduction pass.";
let description = [{
@@ -46,40 +44,33 @@
into two, on containing reducitons in parallel and the other contianing the
combination of the parallel reductions into a final result.
}];
- let constructor = "mlir::iree_compiler::IREE::LinalgExt::createTopkSplitReductionPass()";
let options = [
ListOption<"splitRatios", "split-ratios", "int",
"List of split reduction ratios">,
];
}
-def DecomposeIm2col :
+def DecomposeIm2colPass :
InterfacePass<"iree-linalg-ext-decompose-im2col", "mlir::FunctionOpInterface"> {
let summary =
"Decomposes im2col ops into insert and extract slice ops";
- let constructor = "mlir::iree_compiler::IREE::LinalgExt::"
- "createDecomposeIm2colPass()";
}
-def DecomposeWinogradTransform :
+def DecomposeWinogradTransformPass :
InterfacePass<"iree-linalg-ext-decompose-winograd", "mlir::FunctionOpInterface"> {
let summary =
"Decomposes winograd transform ops into linalg ops";
- let constructor = "mlir::iree_compiler::IREE::LinalgExt::"
- "createDecomposeWinogradTransformPass()";
}
-def ConvertConv2DToIm2ColOp :
+def ConvertConv2DToIm2ColOpPass :
InterfacePass<"iree-linalg-ext-convert-conv2d-to-im2col-op", "mlir::FunctionOpInterface"> {
let summary = "Convert linalg convolution ops to im2col gemm based implementation.";
- let constructor = "mlir::iree_compiler::IREE::LinalgExt::createConvertConv2DToIm2ColOpPass()";
}
-def ConvertConv2DToWinograd :
+def ConvertConv2DToWinogradPass :
InterfacePass<"iree-linalg-ext-convert-conv2d-to-winograd", "mlir::FunctionOpInterface"> {
let summary = "Convert linalg convolution ops to winograd based implementation. By default, "
"only convs annotated with a `__winograd_conv` attribute will be rewritten.";
- let constructor = "mlir::iree_compiler::IREE::LinalgExt::createConvertConv2DToWinogradPass()";
let options = [
Option<"replaceAllConvs", "replace-all-convs", "bool",
/*default=*/"false",
@@ -88,36 +79,30 @@
];
}
-def TileAttention :
+def TileAttentionPass :
InterfacePass<"iree-linalg-ext-tile-attention", "mlir::FunctionOpInterface"> {
let summary =
"Tile the attention op along the reduction dimension";
- let constructor = "mlir::iree_compiler::IREE::LinalgExt::"
- "createTileAttentionPass()";
let options = [
Option<"tileSize", "tileSize", "uint64_t", /*default=*/"",
"Tile size for sequential for loop in attention">,
];
}
-def DecomposeAttention :
+def DecomposeAttentionPass :
InterfacePass<"iree-linalg-ext-decompose-attention", "mlir::FunctionOpInterface"> {
let summary =
"Decomposes attention op into a sequence of linalg ops";
- let constructor = "mlir::iree_compiler::IREE::LinalgExt::"
- "createDecomposeAttentionPass()";
let options = [
Option<"tileSize", "tileSize", "uint64_t", /*default=*/"",
"Tile size for sequential for loop in attention">,
];
}
-def ConvertAttentionToOnlineAttention :
+def ConvertAttentionToOnlineAttentionPass :
InterfacePass<"iree-linalg-ext-convert-attention-to-online-attention",
"mlir::FunctionOpInterface"> {
- let summary = "";
- let constructor = "mlir::iree_compiler::IREE::LinalgExt::"
- "createConvertAttentionToOnlineAttentionPass()";
+ let summary = "Converts attention op to online_attention op";
}
#endif // IREE_DIALECT_LINALGEXT_PASSES
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/SplitReduction.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/SplitReduction.cpp
index ba31883..0ccd7bd 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/SplitReduction.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/SplitReduction.cpp
@@ -6,7 +6,6 @@
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
-#include "iree/compiler/Dialect/LinalgExt/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -25,6 +24,10 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir::iree_compiler::IREE::LinalgExt {
+
+#define GEN_PASS_DEF_TOPKSPLITREDUCTIONPASS
+#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h.inc"
+
namespace {
// Marker used as attribute the depth of the split reduction transformations.
@@ -289,8 +292,11 @@
//===----------------------------------------------------------------------===//
namespace {
-struct TopkSplitReductionPass
- : public TopkSplitReductionBase<TopkSplitReductionPass> {
+struct TopkSplitReductionPass final
+ : impl::TopkSplitReductionPassBase<TopkSplitReductionPass> {
+ using impl::TopkSplitReductionPassBase<
+ TopkSplitReductionPass>::TopkSplitReductionPassBase;
+
void getDependentDialects(DialectRegistry ®istry) const override {
registry
.insert<linalg::LinalgDialect, arith::ArithDialect, math::MathDialect,
@@ -397,8 +403,4 @@
return success();
}
-std::unique_ptr<InterfacePass<FunctionOpInterface>>
-createTopkSplitReductionPass() {
- return std::make_unique<TopkSplitReductionPass>();
-}
} // namespace mlir::iree_compiler::IREE::LinalgExt
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp
index 384ff55..e69c513 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp
@@ -6,7 +6,6 @@
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
-#include "iree/compiler/Dialect/LinalgExt/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -17,6 +16,10 @@
namespace mlir::iree_compiler::IREE::LinalgExt {
+#define GEN_PASS_DEF_TILEATTENTIONPASS
+#define GEN_PASS_DEF_CONVERTATTENTIONTOONLINEATTENTIONPASS
+#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h.inc"
+
namespace {
static Value truncateToF16(Value input, Value output,
@@ -185,7 +188,10 @@
return attentionMaps;
}
-struct TileAttentionPass : public TileAttentionBase<TileAttentionPass> {
+struct TileAttentionPass final
+ : impl::TileAttentionPassBase<TileAttentionPass> {
+ using impl::TileAttentionPassBase<TileAttentionPass>::TileAttentionPassBase;
+
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<
affine::AffineDialect, IREE::LinalgExt::IREELinalgExtDialect,
@@ -200,7 +206,7 @@
};
struct ConvertAttentionToOnlineAttentionPass final
- : ConvertAttentionToOnlineAttentionBase<
+ : impl::ConvertAttentionToOnlineAttentionPassBase<
ConvertAttentionToOnlineAttentionPass> {
void getDependentDialects(DialectRegistry ®istry) const override {
registry
@@ -465,12 +471,4 @@
});
}
-std::unique_ptr<Pass> createTileAttentionPass() {
- return std::make_unique<TileAttentionPass>();
-}
-
-std::unique_ptr<Pass> createConvertAttentionToOnlineAttentionPass() {
- return std::make_unique<ConvertAttentionToOnlineAttentionPass>();
-}
-
} // namespace mlir::iree_compiler::IREE::LinalgExt