[GlobalOpt] Switch to new pass generation tablegen definitions. (#18163)
This is mostly an NFC change. The revision applies a little cleanups:
- Remove `enableQuantizedMatmulReassociation` option from
FuseDequantizationMatmulPass. It should be controled by pipeline.
- Move testing options to tablegen definitions for
PropagateLinalgTransposePass
- Switch a couple of passes to follow `create.*Pass` naming convention.
- Switch namespaces to the new single-line syntax for
FuseSiluHorizontalMatmulPass
---------
Signed-off-by: hanhanW <hanhan0912@gmail.com>
diff --git a/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel b/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel
index efd7758..bf0aef8 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel
+++ b/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel
@@ -30,7 +30,6 @@
iree_compiler_cc_library(
name = "PassHeaders",
hdrs = [
- "PassDetail.h",
"Passes.h",
"Passes.h.inc",
],
diff --git a/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt b/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt
index 7d6435a..51a1c5d 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt
+++ b/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt
@@ -23,7 +23,6 @@
NAME
PassHeaders
HDRS
- "PassDetail.h"
"Passes.h"
"Passes.h.inc"
DEPS
diff --git a/compiler/src/iree/compiler/GlobalOptimization/CleanupNumericNarrowing.cpp b/compiler/src/iree/compiler/GlobalOptimization/CleanupNumericNarrowing.cpp
index b55dd29..d06d1bd 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/CleanupNumericNarrowing.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/CleanupNumericNarrowing.cpp
@@ -5,15 +5,18 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
-#include "iree/compiler/GlobalOptimization/PassDetail.h"
#include "iree/compiler/GlobalOptimization/Passes.h"
namespace mlir::iree_compiler::GlobalOptimization {
+#define GEN_PASS_DEF_CLEANUPNUMERICNARROWINGPASS
+#include "iree/compiler/GlobalOptimization/Passes.h.inc"
+
namespace {
class CleanupNumericNarrowingPass
- : public CleanupNumericNarrowingBase<CleanupNumericNarrowingPass> {
+ : public impl::CleanupNumericNarrowingPassBase<
+ CleanupNumericNarrowingPass> {
void runOnOperation() override {
getOperation()->walk([](IREE::Util::NumericOptionalNarrowOp op) {
op.getResult().replaceAllUsesWith(op.getOperand());
@@ -23,9 +26,4 @@
};
} // namespace
-
-std::unique_ptr<Pass> createCleanupNumericNarrowingPass() {
- return std::make_unique<CleanupNumericNarrowingPass>();
-}
-
} // namespace mlir::iree_compiler::GlobalOptimization
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Convert1X1FilterConv2DToMatmul.cpp b/compiler/src/iree/compiler/GlobalOptimization/Convert1X1FilterConv2DToMatmul.cpp
index 1d7b3a0..a8b4bec 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/Convert1X1FilterConv2DToMatmul.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/Convert1X1FilterConv2DToMatmul.cpp
@@ -4,7 +4,6 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree/compiler/GlobalOptimization/PassDetail.h"
#include "iree/compiler/GlobalOptimization/Passes.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -14,6 +13,9 @@
namespace mlir::iree_compiler::GlobalOptimization {
+#define GEN_PASS_DEF_CONVERT1X1FILTERCONV2DTOMATMULPASS
+#include "iree/compiler/GlobalOptimization/Passes.h.inc"
+
namespace {
// Converts linalg.conv_2d_input_nhwc_filter_nhwc op to linalg.matmul
@@ -157,7 +159,7 @@
};
struct Convert1X1FilterConv2DToMatmulPass
- : public Convert1X1FilterConv2DToMatmulBase<
+ : public impl::Convert1X1FilterConv2DToMatmulPassBase<
Convert1X1FilterConv2DToMatmulPass> {
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<linalg::LinalgDialect>();
@@ -176,9 +178,4 @@
}
};
} // namespace
-
-std::unique_ptr<Pass> createConvert1X1FilterConv2DToMatmulPass() {
- return std::make_unique<Convert1X1FilterConv2DToMatmulPass>();
-}
-
} // namespace mlir::iree_compiler::GlobalOptimization
diff --git a/compiler/src/iree/compiler/GlobalOptimization/DataLayoutPropagation.cpp b/compiler/src/iree/compiler/GlobalOptimization/DataLayoutPropagation.cpp
index c56e708..44172d3 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/DataLayoutPropagation.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/DataLayoutPropagation.cpp
@@ -4,20 +4,20 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree/compiler/GlobalOptimization/PassDetail.h"
#include "iree/compiler/GlobalOptimization/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-using namespace mlir;
-
namespace mlir::iree_compiler::GlobalOptimization {
+#define GEN_PASS_DEF_DATALAYOUTPROPAGATIONPASS
+#include "iree/compiler/GlobalOptimization/Passes.h.inc"
+
namespace {
struct DataLayoutPropagationPass
- : public DataLayoutPropagationBase<DataLayoutPropagationPass> {
+ : public impl::DataLayoutPropagationPassBase<DataLayoutPropagationPass> {
void runOnOperation() override {
MLIRContext *context = &getContext();
FunctionOpInterface funcOp = getOperation();
@@ -43,9 +43,4 @@
};
} // namespace
-
-std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createDataLayoutPropagationPass() {
- return std::make_unique<DataLayoutPropagationPass>();
-}
} // namespace mlir::iree_compiler::GlobalOptimization
diff --git a/compiler/src/iree/compiler/GlobalOptimization/DecomposeConcat.cpp b/compiler/src/iree/compiler/GlobalOptimization/DecomposeConcat.cpp
index 445ea4f..eae755f 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/DecomposeConcat.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/DecomposeConcat.cpp
@@ -4,7 +4,6 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree/compiler/GlobalOptimization/PassDetail.h"
#include "iree/compiler/GlobalOptimization/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -16,6 +15,9 @@
namespace mlir::iree_compiler::GlobalOptimization {
+#define GEN_PASS_DEF_DECOMPOSECONCATPASS
+#include "iree/compiler/GlobalOptimization/Passes.h.inc"
+
namespace {
static Value createTranspose(OpBuilder &builder, Value source,
@@ -78,13 +80,16 @@
}
};
-struct DecomposeConcatPass : public DecomposeConcatBase<DecomposeConcatPass> {
+struct DecomposeConcatPass
+ : public impl::DecomposeConcatPassBase<DecomposeConcatPass> {
+ using impl::DecomposeConcatPassBase<
+ DecomposeConcatPass>::DecomposeConcatPassBase;
+ explicit DecomposeConcatPass(bool enableConcatTransposition) {
+ this->enableConcatTransposition = enableConcatTransposition;
+ }
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<linalg::LinalgDialect>();
}
- DecomposeConcatPass(bool enableConcatTransposition) {
- this->enableConcatTransposition = enableConcatTransposition;
- }
DecomposeConcatPass(const DecomposeConcatPass &pass)
: DecomposeConcatPass(pass.enableConcatTransposition) {}
diff --git a/compiler/src/iree/compiler/GlobalOptimization/DemoteContractionInputsToBF16.cpp b/compiler/src/iree/compiler/GlobalOptimization/DemoteContractionInputsToBF16.cpp
index 9df9578..5f72ff5 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/DemoteContractionInputsToBF16.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/DemoteContractionInputsToBF16.cpp
@@ -5,7 +5,6 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
-#include "iree/compiler/GlobalOptimization/PassDetail.h"
#include "iree/compiler/GlobalOptimization/Passes.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -19,6 +18,9 @@
namespace mlir::iree_compiler::GlobalOptimization {
+#define GEN_PASS_DEF_DEMOTECONTRACTIONINPUTSTOBF16PASS
+#include "iree/compiler/GlobalOptimization/Passes.h.inc"
+
namespace {
// For narrowable inputs, selects
@@ -133,12 +135,13 @@
};
class DemoteContractionInputsToBF16Pass
- : public DemoteContractionInputsToBF16Base<
+ : public impl::DemoteContractionInputsToBF16PassBase<
DemoteContractionInputsToBF16Pass> {
-
public:
+ using impl::DemoteContractionInputsToBF16PassBase<
+ DemoteContractionInputsToBF16Pass>::DemoteContractionInputsToBF16PassBase;
explicit DemoteContractionInputsToBF16Pass(const DemotionOption &option) {
- this->demoteOnly.setValue(option);
+ this->demoteOnly = option;
}
void runOnOperation() override {
MLIRContext *context = &getContext();
diff --git a/compiler/src/iree/compiler/GlobalOptimization/DetachElementwiseFromNamedOps.cpp b/compiler/src/iree/compiler/GlobalOptimization/DetachElementwiseFromNamedOps.cpp
index c90a0f8..524f111 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/DetachElementwiseFromNamedOps.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/DetachElementwiseFromNamedOps.cpp
@@ -12,7 +12,6 @@
//===----------------------------------------------------------------------===//
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
-#include "iree/compiler/GlobalOptimization/PassDetail.h"
#include "iree/compiler/GlobalOptimization/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -27,6 +26,9 @@
namespace mlir::iree_compiler::GlobalOptimization {
+#define GEN_PASS_DEF_DETACHELEMENTWISEFROMNAMEDOPSPASS
+#include "iree/compiler/GlobalOptimization/Passes.h.inc"
+
namespace {
struct DetachElementwisePattern
@@ -185,7 +187,7 @@
};
struct DetachElementwiseFromNamedOpsPass
- : public DetachElementwiseFromNamedOpsBase<
+ : public impl::DetachElementwiseFromNamedOpsPassBase<
DetachElementwiseFromNamedOpsPass> {
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<arith::ArithDialect, linalg::LinalgDialect,
@@ -206,9 +208,4 @@
};
} // namespace
-
-std::unique_ptr<Pass> createDetachElementwiseFromNamedOpsPass() {
- return std::make_unique<DetachElementwiseFromNamedOpsPass>();
-}
-
} // namespace mlir::iree_compiler::GlobalOptimization
diff --git a/compiler/src/iree/compiler/GlobalOptimization/EraseUnusedLinalgOperands.cpp b/compiler/src/iree/compiler/GlobalOptimization/EraseUnusedLinalgOperands.cpp
index a995382..3396e46 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/EraseUnusedLinalgOperands.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/EraseUnusedLinalgOperands.cpp
@@ -4,7 +4,6 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree/compiler/GlobalOptimization/PassDetail.h"
#include "iree/compiler/GlobalOptimization/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -14,9 +13,13 @@
namespace mlir::iree_compiler::GlobalOptimization {
+#define GEN_PASS_DEF_ERASEUNUSEDLINALGOPERANDSPASS
+#include "iree/compiler/GlobalOptimization/Passes.h.inc"
+
namespace {
struct EraseUnusedLinalgOperandsPass
- : public EraseUnusedLinalgOperandsBase<EraseUnusedLinalgOperandsPass> {
+ : public impl::EraseUnusedLinalgOperandsPassBase<
+ EraseUnusedLinalgOperandsPass> {
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
@@ -28,10 +31,4 @@
}
};
} // namespace
-
-std::unique_ptr<OperationPass<mlir::ModuleOp>>
-createEraseUnusedLinalgOperands() {
- return std::make_unique<EraseUnusedLinalgOperandsPass>();
-}
-
} // namespace mlir::iree_compiler::GlobalOptimization
diff --git a/compiler/src/iree/compiler/GlobalOptimization/ExpandTensorShapes.cpp b/compiler/src/iree/compiler/GlobalOptimization/ExpandTensorShapes.cpp
index 5120999..c9974d1 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/ExpandTensorShapes.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/ExpandTensorShapes.cpp
@@ -11,7 +11,6 @@
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/Transforms/Patterns.h"
-#include "iree/compiler/GlobalOptimization/PassDetail.h"
#include "iree/compiler/GlobalOptimization/Passes.h"
#include "iree/compiler/Utils/IntegerSet.h"
#include "llvm/ADT/BreadthFirstIterator.h"
@@ -29,6 +28,10 @@
#define DEBUG_TYPE "iree-global-opt-expand-tensor-shapes"
namespace mlir::iree_compiler::GlobalOptimization {
+
+#define GEN_PASS_DEF_EXPANDTENSORSHAPESPASS
+#include "iree/compiler/GlobalOptimization/Passes.h.inc"
+
namespace {
// TODO(benvanik): factor out into a generic util pass base that lets us share
@@ -624,10 +627,8 @@
// results are always wrapped in a flow.tensor.tie_shape, with the
// elision/deduplication/etc left until cleanup.
class ExpandTensorShapesPass
- : public ExpandTensorShapesBase<ExpandTensorShapesPass> {
+ : public impl::ExpandTensorShapesPassBase<ExpandTensorShapesPass> {
public:
- ExpandTensorShapesPass() = default;
-
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<mlir::arith::ArithDialect>();
registry.insert<IREE::Flow::FlowDialect>();
@@ -661,9 +662,4 @@
};
} // namespace
-
-std::unique_ptr<OperationPass<mlir::ModuleOp>> createExpandTensorShapesPass() {
- return std::make_unique<ExpandTensorShapesPass>();
-}
-
} // namespace mlir::iree_compiler::GlobalOptimization
diff --git a/compiler/src/iree/compiler/GlobalOptimization/FuseDequantizationMatmul.cpp b/compiler/src/iree/compiler/GlobalOptimization/FuseDequantizationMatmul.cpp
index 832d9fa..af3cff9 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/FuseDequantizationMatmul.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/FuseDequantizationMatmul.cpp
@@ -6,7 +6,6 @@
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
-#include "iree/compiler/GlobalOptimization/PassDetail.h"
#include "iree/compiler/GlobalOptimization/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -24,6 +23,9 @@
namespace mlir::iree_compiler::GlobalOptimization {
+#define GEN_PASS_DEF_FUSEDEQUANTIZATIONMATMULPASS
+#include "iree/compiler/GlobalOptimization/Passes.h.inc"
+
namespace {
//----------------------------------------------------------------------------//
@@ -767,19 +769,12 @@
}
struct FuseDequantizationMatmulPass
- : public FuseDequantizationMatmulBase<FuseDequantizationMatmulPass> {
-
+ : public impl::FuseDequantizationMatmulPassBase<
+ FuseDequantizationMatmulPass> {
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<linalg::LinalgDialect, IREE::Flow::FlowDialect,
math::MathDialect>();
}
- FuseDequantizationMatmulPass(bool enableQuantizedMatmulReassociation) {
- this->enableQuantizedMatmulReassociation =
- enableQuantizedMatmulReassociation;
- }
- FuseDequantizationMatmulPass(const FuseDequantizationMatmulPass &pass)
- : FuseDequantizationMatmulPass(pass.enableQuantizedMatmulReassociation) {}
-
void runOnOperation() override;
};
@@ -789,56 +784,45 @@
MLIRContext *context = &getContext();
auto funcOp = getOperation();
- // Perform reassociation if enabled
- if (this->enableQuantizedMatmulReassociation) {
- int quantizeBitWidth = 16;
- SmallVector<std::pair<linalg::GenericOp, linalg::GenericOp>> candidates;
- for (auto genericOp :
- funcOp.getFunctionBody().getOps<linalg::GenericOp>()) {
- if (failed(isContractionWithTwoReductions(genericOp))) {
- continue;
- }
+ int quantizeBitWidth = 16;
+ SmallVector<std::pair<linalg::GenericOp, linalg::GenericOp>> candidates;
+ for (auto genericOp : funcOp.getFunctionBody().getOps<linalg::GenericOp>()) {
+ if (failed(isContractionWithTwoReductions(genericOp))) {
+ continue;
+ }
- OpOperand *lhs = genericOp.getDpsInputOperand(0);
- OpOperand *rhs = genericOp.getDpsInputOperand(1);
- auto lhsOp = lhs->get().getDefiningOp<linalg::GenericOp>();
- auto rhsOp = rhs->get().getDefiningOp<linalg::GenericOp>();
- if (!llvm::cast<ShapedType>(genericOp.getInputs()[0].getType())
- .hasStaticShape() ||
- !llvm::cast<ShapedType>(genericOp.getInputs()[1].getType())
- .hasStaticShape() ||
- !llvm::cast<ShapedType>(genericOp.getResults()[0].getType())
- .hasStaticShape()) {
- // Codegen can't handle the dynamic case yet.
+ OpOperand *lhs = genericOp.getDpsInputOperand(0);
+ OpOperand *rhs = genericOp.getDpsInputOperand(1);
+ auto lhsOp = lhs->get().getDefiningOp<linalg::GenericOp>();
+ auto rhsOp = rhs->get().getDefiningOp<linalg::GenericOp>();
+ if (!llvm::cast<ShapedType>(genericOp.getInputs()[0].getType())
+ .hasStaticShape() ||
+ !llvm::cast<ShapedType>(genericOp.getInputs()[1].getType())
+ .hasStaticShape() ||
+ !llvm::cast<ShapedType>(genericOp.getResults()[0].getType())
+ .hasStaticShape()) {
+ // Codegen can't handle the dynamic case yet.
+ continue;
+ }
+ if (lhsOp) {
+ if (!failed(isGroupedDequantizationOp(lhsOp))) {
+ candidates.push_back(std::make_pair(lhsOp, genericOp));
continue;
}
- if (lhsOp) {
- if (!failed(isGroupedDequantizationOp(lhsOp))) {
- candidates.push_back(std::make_pair(lhsOp, genericOp));
- continue;
- }
- }
- if (rhsOp) {
- if (!failed(isGroupedDequantizationOp(rhsOp))) {
- candidates.push_back(std::make_pair(rhsOp, genericOp));
- }
- }
}
- IRRewriter rewriter(context);
- for (auto candidate : candidates) {
- rewriter.setInsertionPointAfter(candidate.second);
- if (failed(reassociateDequantMatmul(
- rewriter, candidate.first, candidate.second, quantizeBitWidth))) {
- return signalPassFailure();
+ if (rhsOp) {
+ if (!failed(isGroupedDequantizationOp(rhsOp))) {
+ candidates.push_back(std::make_pair(rhsOp, genericOp));
}
}
}
+ IRRewriter rewriter(context);
+ for (auto candidate : candidates) {
+ rewriter.setInsertionPointAfter(candidate.second);
+ if (failed(reassociateDequantMatmul(rewriter, candidate.first,
+ candidate.second, quantizeBitWidth))) {
+ return signalPassFailure();
+ }
+ }
}
-
-std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createFuseDequantizationMatmulPass(bool enableQuantizedMatmulReassociation) {
- return std::make_unique<FuseDequantizationMatmulPass>(
- enableQuantizedMatmulReassociation);
-}
-
} // namespace mlir::iree_compiler::GlobalOptimization
diff --git a/compiler/src/iree/compiler/GlobalOptimization/FuseHorizontalContractions.cpp b/compiler/src/iree/compiler/GlobalOptimization/FuseHorizontalContractions.cpp
index ca9a490..b71912b 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/FuseHorizontalContractions.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/FuseHorizontalContractions.cpp
@@ -6,11 +6,8 @@
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
-#include "iree/compiler/GlobalOptimization/PassDetail.h"
#include "iree/compiler/GlobalOptimization/Passes.h"
#include "iree/compiler/GlobalOptimization/Utils.h"
-#include "mlir/IR/Dominance.h"
-
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
@@ -21,6 +18,7 @@
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dominance.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
@@ -28,14 +26,18 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#define DEBUG_TYPE "iree-global-opt-fuse-horizontal-contraction"
+#define DEBUG_TYPE "iree-global-opt-fuse-horizontal-contractions"
namespace mlir::iree_compiler::GlobalOptimization {
+#define GEN_PASS_DEF_FUSEHORIZONTALCONTRACTIONSPASS
+#include "iree/compiler/GlobalOptimization/Passes.h.inc"
+
namespace {
struct FuseHorizontalContractionsPass
- : public FuseHorizontalContractionsBase<FuseHorizontalContractionsPass> {
+ : public impl::FuseHorizontalContractionsPassBase<
+ FuseHorizontalContractionsPass> {
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<arith::ArithDialect, tensor::TensorDialect>();
@@ -449,10 +451,4 @@
return signalPassFailure();
}
}
-
-std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createFuseHorizontalContractionsPass() {
- return std::make_unique<FuseHorizontalContractionsPass>();
-}
-
} // namespace mlir::iree_compiler::GlobalOptimization
diff --git a/compiler/src/iree/compiler/GlobalOptimization/FuseSiluHorizontalMatmul.cpp b/compiler/src/iree/compiler/GlobalOptimization/FuseSiluHorizontalMatmul.cpp
index 3da1d32..b4a2711 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/FuseSiluHorizontalMatmul.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/FuseSiluHorizontalMatmul.cpp
@@ -6,7 +6,6 @@
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
-#include "iree/compiler/GlobalOptimization/PassDetail.h"
#include "iree/compiler/GlobalOptimization/Passes.h"
#include "iree/compiler/GlobalOptimization/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -23,9 +22,10 @@
#define DEBUG_TYPE "iree-global-opt-fuse-dequantization-matmul"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
-namespace mlir {
-namespace iree_compiler {
-namespace GlobalOptimization {
+namespace mlir::iree_compiler::GlobalOptimization {
+
+#define GEN_PASS_DEF_FUSESILUHORIZONTALMATMULPASS
+#include "iree/compiler/GlobalOptimization/Passes.h.inc"
namespace {
@@ -169,16 +169,12 @@
};
struct FuseSiluHorizontalMatmulPass
- : public FuseSiluHorizontalMatmulBase<FuseSiluHorizontalMatmulPass> {
-
+ : public impl::FuseSiluHorizontalMatmulPassBase<
+ FuseSiluHorizontalMatmulPass> {
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<linalg::LinalgDialect, IREE::Flow::FlowDialect,
math::MathDialect>();
}
- FuseSiluHorizontalMatmulPass() {}
- FuseSiluHorizontalMatmulPass(const FuseSiluHorizontalMatmulPass &pass)
- : FuseSiluHorizontalMatmulPass() {}
-
void runOnOperation() override;
};
@@ -195,11 +191,4 @@
}
}
-std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createFuseSiluHorizontalMatmulPass() {
- return std::make_unique<FuseSiluHorizontalMatmulPass>();
-}
-
-} // namespace GlobalOptimization
-} // namespace iree_compiler
-} // namespace mlir
+} // namespace mlir::iree_compiler::GlobalOptimization
diff --git a/compiler/src/iree/compiler/GlobalOptimization/GeneralizeLinalgNamedOps.cpp b/compiler/src/iree/compiler/GlobalOptimization/GeneralizeLinalgNamedOps.cpp
index 5f30902..92293bc 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/GeneralizeLinalgNamedOps.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/GeneralizeLinalgNamedOps.cpp
@@ -12,7 +12,6 @@
//===----------------------------------------------------------------------===//
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
-#include "iree/compiler/GlobalOptimization/PassDetail.h"
#include "iree/compiler/GlobalOptimization/Passes.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
@@ -20,10 +19,13 @@
namespace mlir::iree_compiler::GlobalOptimization {
+#define GEN_PASS_DEF_GENERALIZELINALGNAMEDOPSPASS
+#include "iree/compiler/GlobalOptimization/Passes.h.inc"
+
namespace {
struct GeneralizeLinalgNamedOpsPass
- : public GeneralizeLinalgNamedOpsBase<GeneralizeLinalgNamedOpsPass> {
-
+ : public impl::GeneralizeLinalgNamedOpsPassBase<
+ GeneralizeLinalgNamedOpsPass> {
void runOnOperation() override;
};
} // namespace
@@ -59,9 +61,4 @@
}
}
-std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createGeneralizeLinalgNamedOpsPass() {
- return std::make_unique<GeneralizeLinalgNamedOpsPass>();
-}
-
} // namespace mlir::iree_compiler::GlobalOptimization
diff --git a/compiler/src/iree/compiler/GlobalOptimization/GlobalLoopInvariantCodeMotion.cpp b/compiler/src/iree/compiler/GlobalOptimization/GlobalLoopInvariantCodeMotion.cpp
index 9a30e21..0489448 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/GlobalLoopInvariantCodeMotion.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/GlobalLoopInvariantCodeMotion.cpp
@@ -4,7 +4,6 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree/compiler/GlobalOptimization/PassDetail.h"
#include "iree/compiler/GlobalOptimization/Passes.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -101,10 +100,13 @@
namespace mlir::iree_compiler::GlobalOptimization {
+#define GEN_PASS_DEF_GLOBALLOOPINVARIANTCODEMOTIONPASS
+#include "iree/compiler/GlobalOptimization/Passes.h.inc"
+
namespace {
struct GlobalLoopInvariantCodeMotionPass
- : public GlobalLoopInvariantCodeMotionBase<
+ : public impl::GlobalLoopInvariantCodeMotionPassBase<
GlobalLoopInvariantCodeMotionPass> {
void runOnOperation() override {
MLIRContext *context = &getContext();
@@ -129,9 +131,4 @@
};
} // namespace
-
-std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createGlobalLoopInvariantCodeMotionPass() {
- return std::make_unique<GlobalLoopInvariantCodeMotionPass>();
-}
} // namespace mlir::iree_compiler::GlobalOptimization
diff --git a/compiler/src/iree/compiler/GlobalOptimization/InferNumericNarrowing.cpp b/compiler/src/iree/compiler/GlobalOptimization/InferNumericNarrowing.cpp
index 0edb016..c4d48aa 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/InferNumericNarrowing.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/InferNumericNarrowing.cpp
@@ -10,7 +10,6 @@
#include "iree/compiler/Dialect/Util/Analysis/Explorer.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
-#include "iree/compiler/GlobalOptimization/PassDetail.h"
#include "iree/compiler/GlobalOptimization/Passes.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Debug.h"
@@ -20,6 +19,9 @@
namespace mlir::iree_compiler::GlobalOptimization {
+#define GEN_PASS_DEF_INFERNUMERICNARROWINGPASS
+#include "iree/compiler/GlobalOptimization/Passes.h.inc"
+
namespace {
IntegerType deriveIntegerTypeFromRange(MLIRContext *context, int64_t minValue,
@@ -47,7 +49,7 @@
}
class InferNumericNarrowingPass
- : public InferNumericNarrowingBase<InferNumericNarrowingPass> {
+ : public impl::InferNumericNarrowingPassBase<InferNumericNarrowingPass> {
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<IREE::Util::UtilDialect>();
}
@@ -128,9 +130,4 @@
};
} // namespace
-
-std::unique_ptr<Pass> createInferNumericNarrowingPass() {
- return std::make_unique<InferNumericNarrowingPass>();
-}
-
} // namespace mlir::iree_compiler::GlobalOptimization
diff --git a/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp b/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp
index 5c264be..750ae78 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/MaterializeHomogeneousEncodings.cpp
@@ -9,7 +9,6 @@
#include "iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
-#include "iree/compiler/GlobalOptimization/PassDetail.h"
#include "iree/compiler/GlobalOptimization/Passes.h"
#include "iree/compiler/Utils/PassUtils.h"
#include "llvm/ADT/STLExtras.h"
@@ -24,15 +23,17 @@
namespace mlir::iree_compiler::GlobalOptimization {
+#define GEN_PASS_DEF_MATERIALIZEHOMOGENEOUSENCODINGSPASS
+#include "iree/compiler/GlobalOptimization/Passes.h.inc"
+
using FunctionLikeNest =
MultiOpNest<IREE::Util::InitializerOp, IREE::Util::FuncOp>;
+namespace {
class MaterializeHomogeneousEncodingsPass
- : public MaterializeHomogeneousEncodingsBase<
+ : public impl::MaterializeHomogeneousEncodingsPassBase<
MaterializeHomogeneousEncodingsPass> {
public:
- MaterializeHomogeneousEncodingsPass() = default;
-
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<IREE::HAL::HALDialect, tensor::TensorDialect>();
}
@@ -78,10 +79,5 @@
}
}
};
-
-std::unique_ptr<OperationPass<ModuleOp>>
-createMaterializeHomogeneousEncodingsPass() {
- return std::make_unique<MaterializeHomogeneousEncodingsPass>();
-}
-
+} // namespace
} // namespace mlir::iree_compiler::GlobalOptimization
diff --git a/compiler/src/iree/compiler/GlobalOptimization/OptimizeNumerics.cpp b/compiler/src/iree/compiler/GlobalOptimization/OptimizeNumerics.cpp
index 329354e..38b8749 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/OptimizeNumerics.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/OptimizeNumerics.cpp
@@ -5,7 +5,6 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
-#include "iree/compiler/GlobalOptimization/PassDetail.h"
#include "iree/compiler/GlobalOptimization/Passes.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -16,6 +15,9 @@
namespace mlir::iree_compiler::GlobalOptimization {
+#define GEN_PASS_DEF_OPTIMIZENUMERICSPASS
+#include "iree/compiler/GlobalOptimization/Passes.h.inc"
+
namespace {
int getNextPotBitWidth(int bitWidth, int minBitWidth = 8) {
@@ -254,7 +256,8 @@
}
};
-class OptimizeNumericsPass : public OptimizeNumericsBase<OptimizeNumericsPass> {
+class OptimizeNumericsPass
+ : public impl::OptimizeNumericsPassBase<OptimizeNumericsPass> {
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
@@ -274,9 +277,4 @@
};
} // namespace
-
-std::unique_ptr<Pass> createOptimizeNumericsPass() {
- return std::make_unique<OptimizeNumericsPass>();
-}
-
} // namespace mlir::iree_compiler::GlobalOptimization
diff --git a/compiler/src/iree/compiler/GlobalOptimization/PassDetail.h b/compiler/src/iree/compiler/GlobalOptimization/PassDetail.h
deleted file mode 100644
index b0b79e7..0000000
--- a/compiler/src/iree/compiler/GlobalOptimization/PassDetail.h
+++ /dev/null
@@ -1,22 +0,0 @@
-// Copyright 2023 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_COMPILER_GLOBALOPTIMIZATION_PASSDETAIL_H_
-#define IREE_COMPILER_GLOBALOPTIMIZATION_PASSDETAIL_H_
-
-#include "iree/compiler/GlobalOptimization/Passes.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/Interfaces/FunctionInterfaces.h"
-#include "mlir/Pass/Pass.h"
-
-namespace mlir::iree_compiler::GlobalOptimization {
-
-#define GEN_PASS_CLASSES
-#include "iree/compiler/GlobalOptimization/Passes.h.inc"
-
-} // namespace mlir::iree_compiler::GlobalOptimization
-
-#endif // IREE_COMPILER_GLOBALOPTIMIZATION_PASSDETAIL_H_
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
index e095d78..877f976 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
@@ -101,7 +101,7 @@
.addPass(createDetachElementwiseFromNamedOpsPass)
.addPass(mlir::createLinalgNamedOpConversionPass)
.addPass(createConvert1X1FilterConv2DToMatmulPass);
- mainPassManager.addPass(createEraseUnusedLinalgOperands());
+ mainPassManager.addPass(createEraseUnusedLinalgOperandsPass());
// Expand tensor shapes into SSA values and optimize the whole program.
// The more we are able to equate shape dimensions at this level the
@@ -116,7 +116,7 @@
// RaiseSpecialOps, by virtue of implementing various peephole
// optimizations, is sensitive to surrounding IR structure. Thus we run
// this pass both before unit dim folding + consteval, as well as after.
- .addPass(createRaiseSpecialOps)
+ .addPass(createRaiseSpecialOpsPass)
// We decompose and transpose concatenations immediately before folding
// unit extent dims because this allows decoupling unit dims in the
// concatenation from the transposes that are introduced.
@@ -138,10 +138,8 @@
return createDemoteContractionInputsToBF16Pass(
clDemoteContractionInputsToBF16Strategy);
})
- .addPass([&]() {
- return createFuseDequantizationMatmulPass(
- clEnableQuantizedMatmulReassociation);
- })
+ .addPredicatedPass(clEnableQuantizedMatmulReassociation,
+ createFuseDequantizationMatmulPass)
.addPass(IREE::Flow::createCanonicalizerPass)
.addPass(mlir::createCSEPass)
// Propagate transposes immediately before set encoding/data tiling
@@ -226,7 +224,7 @@
FunctionLikeNest(mainPassManager)
// After running const-eval to a fixed point and folding unit extent dims,
// try any new raising opportunities.
- .addPass(createRaiseSpecialOps)
+ .addPass(createRaiseSpecialOpsPass)
// Strip std.assert & co after we perform optimizations; prior to this we
// may use the assertions to derive information during analysis.
.addPredicatedPass(transformOptions.options.stripAssertions,
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.h b/compiler/src/iree/compiler/GlobalOptimization/Passes.h
index 7643737..6225224 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/Passes.h
+++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.h
@@ -36,93 +36,27 @@
void buildGlobalOptimizationPassPipeline(
OpPassManager &mainPassManager, const TransformOptions &transformOptions);
-//===----------------------------------------------------------------------===//
-// Input canonicalization and legalization
-//===----------------------------------------------------------------------===//
+//------------------------------------------------------------------------------
+// Wrappers that not use tablegen options.
+//------------------------------------------------------------------------------
-/// Cleans up any numeric narrowing ops inserted by
-/// iree-global-opt-infer-numeric-narrowing.
-std::unique_ptr<Pass> createCleanupNumericNarrowingPass();
-
-/// Converts linalg convolution ops with 1x1 kernels into linalg.matmul.
-std::unique_ptr<Pass> createConvert1X1FilterConv2DToMatmulPass();
-
-/// Fuses dequantization and matmul linalg.generic ops
-std::unique_ptr<Pass>
-createDecomposeConcatPass(bool enableConcatTransposition = false);
+std::unique_ptr<Pass> createDecomposeConcatPass(bool enableConcatTransposition);
// Used by the demoteContractionInputsToBF16 pass to determine which op inputs
// to demote.
enum class DemotionOption { All, Conv, Matmul, None };
+std::unique_ptr<Pass>
+createDemoteContractionInputsToBF16Pass(DemotionOption option);
-/// Demotes inputs (LHS, RHS) of linalg matmul-like ops from f32 to bf16.
-std::unique_ptr<Pass> createDemoteContractionInputsToBF16Pass(
- DemotionOption option = DemotionOption::None);
-
-/// Detaches elementwise ops from named Linalg ops.
-std::unique_ptr<Pass> createDetachElementwiseFromNamedOpsPass();
-
-/// Applies patterns to erase unused linalg operands and remove dead code
-/// associated.
-std::unique_ptr<OperationPass<mlir::ModuleOp>>
-createEraseUnusedLinalgOperands();
-
-/// Expands tensor shape dimensions into SSA values across the program.
-std::unique_ptr<OperationPass<mlir::ModuleOp>> createExpandTensorShapesPass();
-
-/// Fuses dequantization and matmul linalg.generic ops
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createFuseDequantizationMatmulPass(
- bool enableQuantizedMatmulReassociation = false);
+createPropagateLinalgTransposePass(bool enableAggressivePropagation);
-/// Horizontally fuses multiple contraction ops.
-std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createFuseHorizontalContractionsPass();
+//----------------------------------------------------------------------------//
+// Register GlobalOptimization Passes
+//----------------------------------------------------------------------------//
-/// Fuses two matmul ops and a linalg.generic Silu op
-std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createFuseSiluHorizontalMatmulPass();
-
-/// Generalizes some named Linalg ops into `linalg.generic` operations since the
-/// compiler can handle that better.
-std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createGeneralizeLinalgNamedOpsPass();
-
-/// Infers and inserts util.numeric.optional_narrow ops at points that may be
-/// beneficial.
-std::unique_ptr<Pass> createInferNumericNarrowingPass();
-
-/// Materializes logical encodings to physical encodings if there is a single
-/// device target.
-std::unique_ptr<OperationPass<mlir::ModuleOp>>
-createMaterializeHomogeneousEncodingsPass();
-
-/// Optimizes numerics given annotations added via
-/// iree-global-opt-infer-numeric-narrowing.
-std::unique_ptr<Pass> createOptimizeNumericsPass();
-
-/// Propagates linalg.transpose ops to a restricted set of operations.
-std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createPropagateLinalgTransposePass(bool enableAggressivePropagation = false);
-
-/// Performs specialized raisings of various sequences of ops to a
-/// representation easier for the compiler to handle.
-std::unique_ptr<Pass> createRaiseSpecialOps();
-
-/// Removes tensors that have 0-extents.
-std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createRemoveZeroExtentTensorsPass();
-
-/// Simplifies tensor pack/unpack ops to reshape ops.
-std::unique_ptr<Pass> createSimplifyPackUnpackPass();
-
-/// Hoist loop invariants out of loops with zero-trip-check.
-std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createGlobalLoopInvariantCodeMotionPass();
-
-/// Propagate pack/unpack ops across other ops to improve fusion.
-std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createDataLayoutPropagationPass();
+#define GEN_PASS_DECL
+#include "iree/compiler/GlobalOptimization/Passes.h.inc" // IWYU pragma: keep
void registerGlobalOptimizationPipeline();
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.td b/compiler/src/iree/compiler/GlobalOptimization/Passes.td
index 141fcf3..c919161 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/Passes.td
+++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.td
@@ -9,22 +9,19 @@
include "mlir/Pass/PassBase.td"
-def CleanupNumericNarrowing :
+def CleanupNumericNarrowingPass :
Pass<"iree-global-opt-cleanup-numeric-narrowing", ""> {
let summary = "Cleans up any numeric narrowing ops inserted by iree-global-opt-infer-numeric-narrowing.";
- let constructor = "mlir::iree_compiler::GlobalOptimization::createCleanupNumericNarrowingPass()";
}
-def Convert1X1FilterConv2DToMatmul:
+def Convert1X1FilterConv2DToMatmulPass:
Pass<"iree-global-opt-convert-1x1-filter-conv2d-to-matmul", ""> {
let summary = "Convert linalg convolution ops with 1x1 kernels into linalg matrix multiplication ops.";
- let constructor = "mlir::iree_compiler::GlobalOptimization::createConvert1X1FilterConv2DToMatmulPass()";
}
-def DecomposeConcat :
+def DecomposeConcatPass :
Pass<"iree-global-opt-decompose-concat", ""> {
let summary = "Decomposes concatenations into a destination and a sequence of slice inserts.";
- let constructor = "mlir::iree_compiler::GlobalOptimization::createDecomposeConcatPass()";
let options = [
Option<"enableConcatTransposition", "enable-concat-transposition", "bool",
/*default=*/"false", "Allows transposing concatenations such that "
@@ -32,12 +29,10 @@
];
}
-def DemoteContractionInputsToBF16
+def DemoteContractionInputsToBF16Pass
: Pass<"iree-global-opt-demote-contraction-inputs-to-bf16", ""> {
let summary =
"Demotes inputs (LHS, RHS) of linalg matmul-like ops from f32 to bf16.";
- let constructor = "mlir::iree_compiler::GlobalOptimization::"
- "createDemoteContractionInputsToBF16Pass()";
let options =
[Option<"demoteOnly", "demote-only",
"mlir::iree_compiler::GlobalOptimization::DemotionOption",
@@ -61,106 +56,89 @@
];
}
-def DetachElementwiseFromNamedOps :
+def DetachElementwiseFromNamedOpsPass :
Pass<"iree-global-opt-detach-elementwise-from-named-ops", ""> {
let summary = "Detaches elementwise ops from named Linalg ops.";
- let constructor = "mlir::iree_compiler::GlobalOptimization::createDetachElementwiseFromNamedOpsPass()";
}
-def EraseUnusedLinalgOperands :
+def EraseUnusedLinalgOperandsPass :
Pass<"iree-global-opt-erase-unused-linalg-operands", "mlir::ModuleOp"> {
let summary = "Erases unused linalg operand and remove dead code.";
- let constructor = "mlir::iree_compiler::GlobalOptimization::createEraseUnusedLinalgOperands()";
}
-def ExpandTensorShapes :
+def ExpandTensorShapesPass :
Pass<"iree-global-opt-expand-tensor-shapes", "mlir::ModuleOp"> {
let summary = "Expands tensor shape dimensions into SSA values across the program.";
- let constructor = "mlir::iree_compiler::GlobalOptimization::createExpandTensorShapesPass()";
}
-def FuseDequantizationMatmul:
+def FuseDequantizationMatmulPass:
InterfacePass<"iree-global-opt-fuse-dequantization-matmul", "mlir::FunctionOpInterface"> {
let summary = "Fuses dequantization and matmul linalg.generic ops.";
- let constructor = "mlir::iree_compiler::GlobalOptimization::createFuseDequantizationMatmulPass()";
- let options = [
- Option<"enableQuantizedMatmulReassociation", "enable-quantized-matmul-reassociation", "bool",
- /*default=*/"false", "Allow reassociation of quantized matmuls (experimental).">,
- ];
}
-def FuseHorizontalContractions:
+def FuseHorizontalContractionsPass:
InterfacePass<"iree-global-opt-fuse-horizontal-contractions", "mlir::FunctionOpInterface"> {
let summary = "Fuses horizontal contraction ops without fusions";
- let constructor = "mlir::iree_compiler::GlobalOptimization::createFuseHorizontalContractionsPass()";
}
-def FuseSiluHorizontalMatmul:
+def FuseSiluHorizontalMatmulPass:
InterfacePass<"iree-global-opt-fuse-silu-horizontal-matmul", "mlir::FunctionOpInterface"> {
let summary = "Fuses matmul ops and silu linalg.generic op.";
- let constructor = "mlir::iree_compiler::GlobalOptimization::createFuseSiluHorizontalMatmulPass()";
}
-def GeneralizeLinalgNamedOps :
+def GeneralizeLinalgNamedOpsPass :
InterfacePass<"iree-global-opt-generalize-linalg-named-ops", "mlir::FunctionOpInterface"> {
let summary = "Convert some Linalg named ops into linalg.generics.";
- let constructor = "mlir::iree_compiler::GlobalOptimization::createGeneralizeLinalgNamedOpsPass()";
}
-def InferNumericNarrowing :
+def InferNumericNarrowingPass :
Pass<"iree-global-opt-infer-numeric-narrowing", ""> {
let summary = "Infers and inserts util.numeric.optional_narrow ops at points that may be beneficial.";
- let constructor = "mlir::iree_compiler::GlobalOptimization::createInferNumericNarrowingPass()";
}
-def MaterializeHomogeneousEncodings :
+def MaterializeHomogeneousEncodingsPass :
Pass<"iree-global-opt-materialize-homogeneous-encodings", "mlir::ModuleOp"> {
let summary = "Materializes logical encodings to physical encodings if there is a single device target.";
- let constructor =
- "mlir::iree_compiler::GlobalOptimization::createMaterializeHomogeneousEncodingsPass()";
}
-def OptimizeNumerics :
+def OptimizeNumericsPass :
Pass<"iree-global-opt-optimize-numerics", ""> {
let summary = "Optimizes numerics given annotations added via iree-global-opt-infer-numeric-narrowing.";
- let constructor = "mlir::iree_compiler::GlobalOptimization::createOptimizeNumericsPass()";
}
-def PropagateLinalgTranspose :
+def PropagateLinalgTransposePass :
InterfacePass<"iree-global-opt-propagate-linalg-transpose", "mlir::FunctionOpInterface"> {
let summary = "Propagates linalg.transpose through a restricted set of ops.";
- let constructor = "mlir::iree_compiler::GlobalOptimization::createPropagateLinalgTransposePass()";
let options = [
Option<"enableAggressivePropagation", "enable-aggressive-propagation", "bool",
/*default=*/"false", "Enable aggressive propagation to named ops.">,
+ Option<"testSinkingOnly", "test-sinking-only", "bool", /*default=*/"false",
+ "Flag used for lit-testing sinking patterns only. Not for general usage">,
+ Option<"testBubblingOnly", "test-bubbling-only", "bool", /*default=*/"false",
+ "Flag used for lit-testing bubbling patterns only. Not for general usage">,
];
}
-def RaiseSpecialOps :
+def RaiseSpecialOpsPass :
Pass<"iree-global-opt-raise-special-ops", ""> {
let summary = "Raises special ops like softmax to the high level linalg.ext representation.";
- let constructor = "mlir::iree_compiler::GlobalOptimization::createRaiseSpecialOps()";
}
-def RemoveZeroExtentTensors :
+def RemoveZeroExtentTensorsPass :
InterfacePass<"iree-global-opt-remove-zero-extent-tensors", "mlir::FunctionOpInterface"> {
let summary = "Removes tensors that have 0-extents.";
- let constructor = "mlir::iree_compiler::GlobalOptimization::createRemoveZeroExtentTensorsPass()";
}
-def SimplifyPackUnpack : Pass<"iree-global-opt-simplify-pack-unpack", ""> {
+def SimplifyPackUnpackPass : Pass<"iree-global-opt-simplify-pack-unpack", ""> {
let summary = "Simplifies tensor pack and unpack ops.";
- let constructor = "mlir::iree_compiler::GlobalOptimization::createSimplifyPackUnpackPass()";
}
-def GlobalLoopInvariantCodeMotion : InterfacePass<"iree-global-opt-loop-invariant-code-motion", "mlir::FunctionOpInterface"> {
+def GlobalLoopInvariantCodeMotionPass : InterfacePass<"iree-global-opt-loop-invariant-code-motion", "mlir::FunctionOpInterface"> {
let summary = "Hoist loop invariants out of loops with zero-trip-check.";
- let constructor = "mlir::iree_compiler::GlobalOptimization::createGlobalLoopInvariantCodeMotionPass()";
}
-def DataLayoutPropagation : InterfacePass<"iree-global-opt-data-layout-propagation", "mlir::FunctionOpInterface"> {
+def DataLayoutPropagationPass : InterfacePass<"iree-global-opt-data-layout-propagation", "mlir::FunctionOpInterface"> {
let summary = "Propagate pack/unpack ops across other ops to improve fusion";
- let constructor = "mlir::iree_compiler::GlobalOptimization::createDataLayoutPropagationPass()";
}
#endif // IREE_COMPILER_GLOBALOPTIMIZATION_PASSES
diff --git a/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp b/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp
index ac2c356..2dea4ad 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp
@@ -14,7 +14,6 @@
#include "iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Utils.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
-#include "iree/compiler/GlobalOptimization/PassDetail.h"
#include "iree/compiler/GlobalOptimization/Passes.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -34,6 +33,9 @@
namespace mlir::iree_compiler::GlobalOptimization {
+#define GEN_PASS_DEF_PROPAGATELINALGTRANSPOSEPASS
+#include "iree/compiler/GlobalOptimization/Passes.h.inc"
+
//===----------------------------------------------------------------------===//
// Transpose permutation helpers
//===----------------------------------------------------------------------===//
@@ -801,29 +803,18 @@
namespace {
struct PropagateLinalgTransposePass
- : public PropagateLinalgTransposeBase<PropagateLinalgTransposePass> {
+ : public impl::PropagateLinalgTransposePassBase<
+ PropagateLinalgTransposePass> {
+ using impl::PropagateLinalgTransposePassBase<
+ PropagateLinalgTransposePass>::PropagateLinalgTransposePassBase;
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<linalg::LinalgDialect, tensor::TensorDialect>();
}
- PropagateLinalgTransposePass(bool enableAggressivePropagation) {
+ explicit PropagateLinalgTransposePass(bool enableAggressivePropagation) {
this->enableAggressivePropagation = enableAggressivePropagation;
}
- PropagateLinalgTransposePass(const PropagateLinalgTransposePass &pass)
- : PropagateLinalgTransposePass(pass.enableAggressivePropagation) {}
void runOnOperation() override;
-
-private:
- Option<bool> testSinkingOnly{
- *this, "test-sinking-only",
- llvm::cl::desc("Flag used for lit-testing sinking patterns only. "
- "Not for general usage"),
- llvm::cl::init(false)};
- Option<bool> testBubblingOnly{
- *this, "test-bubbling-only",
- llvm::cl::desc("Flag used for lit-testing bubbling patterns only. "
- "Not for general usage"),
- llvm::cl::init(false)};
};
} // namespace
diff --git a/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp b/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp
index 8c267e2..b700869 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp
@@ -9,7 +9,6 @@
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
-#include "iree/compiler/GlobalOptimization/PassDetail.h"
#include "iree/compiler/GlobalOptimization/Passes.h"
#include "iree/compiler/GlobalOptimization/Utils.h"
#include "llvm/ADT/STLExtras.h"
@@ -32,6 +31,9 @@
namespace mlir::iree_compiler::GlobalOptimization {
+#define GEN_PASS_DEF_RAISESPECIALOPSPASS
+#include "iree/compiler/GlobalOptimization/Passes.h.inc"
+
namespace {
//===----------------------------------------------------------------------===//
@@ -998,7 +1000,8 @@
// Pass Implementation
//===----------------------------------------------------------------------===//
-struct RaiseSpecialOpsPass : public RaiseSpecialOpsBase<RaiseSpecialOpsPass> {
+struct RaiseSpecialOpsPass
+ : public impl::RaiseSpecialOpsPassBase<RaiseSpecialOpsPass> {
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<IREE::LinalgExt::IREELinalgExtDialect>();
}
@@ -1054,9 +1057,4 @@
};
} // namespace
-
-std::unique_ptr<Pass> createRaiseSpecialOps() {
- return std::make_unique<RaiseSpecialOpsPass>();
-}
-
} // namespace mlir::iree_compiler::GlobalOptimization
diff --git a/compiler/src/iree/compiler/GlobalOptimization/RemoveZeroExtentTensors.cpp b/compiler/src/iree/compiler/GlobalOptimization/RemoveZeroExtentTensors.cpp
index 9cb5929..b6d82c7 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/RemoveZeroExtentTensors.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/RemoveZeroExtentTensors.cpp
@@ -4,7 +4,6 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree/compiler/GlobalOptimization/PassDetail.h"
#include "iree/compiler/GlobalOptimization/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -14,6 +13,9 @@
namespace mlir::iree_compiler::GlobalOptimization {
+#define GEN_PASS_DEF_REMOVEZEROEXTENTTENSORSPASS
+#include "iree/compiler/GlobalOptimization/Passes.h.inc"
+
/// Check if a `t` is a `tensor` with zero extents.
static std::optional<RankedTensorType> isZeroExtent(Type t) {
auto operandType = dyn_cast<RankedTensorType>(t);
@@ -77,7 +79,7 @@
namespace {
struct RemoveZeroExtentTensorsPass
- : RemoveZeroExtentTensorsBase<RemoveZeroExtentTensorsPass> {
+ : impl::RemoveZeroExtentTensorsPassBase<RemoveZeroExtentTensorsPass> {
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<tensor::TensorDialect>();
}
@@ -101,9 +103,4 @@
}
}
-std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createRemoveZeroExtentTensorsPass() {
- return std::make_unique<RemoveZeroExtentTensorsPass>();
-}
-
} // namespace mlir::iree_compiler::GlobalOptimization
diff --git a/compiler/src/iree/compiler/GlobalOptimization/SimplifyPackUnpack.cpp b/compiler/src/iree/compiler/GlobalOptimization/SimplifyPackUnpack.cpp
index 9333f9c..3bd113f 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/SimplifyPackUnpack.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/SimplifyPackUnpack.cpp
@@ -4,7 +4,6 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree/compiler/GlobalOptimization/PassDetail.h"
#include "iree/compiler/GlobalOptimization/Passes.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"
@@ -12,9 +11,12 @@
namespace mlir::iree_compiler::GlobalOptimization {
+#define GEN_PASS_DEF_SIMPLIFYPACKUNPACKPASS
+#include "iree/compiler/GlobalOptimization/Passes.h.inc"
+
namespace {
struct SimplifyPackUnpackPass
- : public SimplifyPackUnpackBase<SimplifyPackUnpackPass> {
+ : public impl::SimplifyPackUnpackPassBase<SimplifyPackUnpackPass> {
void runOnOperation() override;
};
@@ -30,8 +32,4 @@
}
}
-std::unique_ptr<Pass> createSimplifyPackUnpackPass() {
- return std::make_unique<SimplifyPackUnpackPass>();
-}
-
} // namespace mlir::iree_compiler::GlobalOptimization
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/fuse_dequantization_matmul.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/fuse_dequantization_matmul.mlir
index 4037295..3ad7805 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/fuse_dequantization_matmul.mlir
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/fuse_dequantization_matmul.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(util.func(iree-global-opt-fuse-dequantization-matmul{enable-quantized-matmul-reassociation=true},iree-flow-canonicalize))" %s | FileCheck %s
+// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(util.func(iree-global-opt-fuse-dequantization-matmul,iree-flow-canonicalize))" %s | FileCheck %s
util.func public @grouped_quantized_matmul_reassociate(%arg0: tensor<11008x32x128xi4>, %arg1: tensor<32x128xf32>, %arg2: tensor<11008x32xf32>, %arg3: tensor<11008x32xf32>) -> tensor<11008xf32> {
%cst = arith.constant 0.000000e+00 : f32