[Preprocessing] NFC: Finish migrating passes to use new tablegen (#17047)
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/ApplyPDLPatterns.cpp b/compiler/src/iree/compiler/Preprocessing/Common/ApplyPDLPatterns.cpp
index b1a3e0c..8f30e19 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/ApplyPDLPatterns.cpp
+++ b/compiler/src/iree/compiler/Preprocessing/Common/ApplyPDLPatterns.cpp
@@ -31,7 +31,7 @@
namespace mlir::iree_compiler::Preprocessing {
-#define GEN_PASS_DEF_APPLYPDLPATTERNS
+#define GEN_PASS_DEF_APPLYPDLPATTERNSPASS
#include "iree/compiler/Preprocessing/Common/Passes.h.inc" // IWYU pragma: export
} // namespace mlir::iree_compiler::Preprocessing
@@ -451,19 +451,11 @@
namespace {
class ApplyPDLPatternsPass
- : public iree_compiler::Preprocessing::impl::ApplyPDLPatternsBase<
+ : public iree_compiler::Preprocessing::impl::ApplyPDLPatternsPassBase<
ApplyPDLPatternsPass> {
-
public:
- using Base::Base;
-
- void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<arith::ArithDialect, iree_compiler::IREE::Flow::FlowDialect,
- iree_compiler::IREE::Stream::StreamDialect,
- iree_compiler::IREE::Util::UtilDialect,
- memref::MemRefDialect, pdl::PDLDialect,
- pdl_interp::PDLInterpDialect, tensor::TensorDialect>();
- }
+ using iree_compiler::Preprocessing::impl::ApplyPDLPatternsPassBase<
+ ApplyPDLPatternsPass>::ApplyPDLPatternsPassBase;
LogicalResult initialize(MLIRContext *context) override {
if (patternsFile.empty()) {
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
index b5a9090..160ce9d 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
@@ -36,7 +36,6 @@
"InterpreterPass.cpp",
"MakeSingleDispatchForFunction.cpp",
"PadLinalgOps.cpp",
- "PassDetail.h",
"Passes.cpp",
],
hdrs = [
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
index a86d659..7e25d00 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
@@ -32,7 +32,6 @@
"InterpreterPass.cpp"
"MakeSingleDispatchForFunction.cpp"
"PadLinalgOps.cpp"
- "PassDetail.h"
"Passes.cpp"
DEPS
::PassesIncGen
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/ConvertConv2DToImg2Col.cpp b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConv2DToImg2Col.cpp
index 77faa23..045a9b6 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/ConvertConv2DToImg2Col.cpp
+++ b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConv2DToImg2Col.cpp
@@ -4,7 +4,6 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree/compiler/Preprocessing/Common/PassDetail.h"
#include "iree/compiler/Preprocessing/Common/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -21,6 +20,9 @@
namespace mlir::iree_compiler::Preprocessing {
+#define GEN_PASS_DEF_CONVERTCONV2DTOIMG2COLPASS
+#include "iree/compiler/Preprocessing/Common/Passes.h.inc" // IWYU pragma: export
+
static bool hasAllOneValues(DenseIntElementsAttr attr) {
return llvm::all_of(
attr, [](APInt element) { return element.getSExtValue() == 1; });
@@ -550,8 +552,9 @@
}
};
-struct ConvertConv2DToImg2ColPass
- : ConvertConv2DToImg2ColBase<ConvertConv2DToImg2ColPass> {
+class ConvertConv2DToImg2ColPass
+ : public iree_compiler::Preprocessing::impl::ConvertConv2DToImg2ColPassBase<
+ ConvertConv2DToImg2ColPass> {
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(&getContext());
@@ -566,8 +569,4 @@
} // namespace
-std::unique_ptr<Pass> createConvertConv2DToImg2ColPass() {
- return std::make_unique<ConvertConv2DToImg2ColPass>();
-}
-
} // namespace mlir::iree_compiler::Preprocessing
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvToChannelsLast.cpp b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvToChannelsLast.cpp
index 1be1cbd..788deb9 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvToChannelsLast.cpp
+++ b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvToChannelsLast.cpp
@@ -4,7 +4,6 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree/compiler/Preprocessing/Common/PassDetail.h"
#include "iree/compiler/Preprocessing/Common/Passes.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"
@@ -28,6 +27,9 @@
namespace mlir::iree_compiler::Preprocessing {
+#define GEN_PASS_DEF_CONVERTCONVTOCHANNELSLASTPASS
+#include "iree/compiler/Preprocessing/Common/Passes.h.inc" // IWYU pragma: export
+
using ConvBuilderFn = std::function<Value(
OpBuilder &b, Location loc, linalg::LinalgOp srcConv, Value input,
Value filter, Value output, AffineMap inputMap, AffineMap filterMap,
@@ -645,23 +647,12 @@
}
};
-struct ConvertConvToChannelsLastPass
- : public ConvertConvToChannelsLastBase<ConvertConvToChannelsLastPass> {
+class ConvertConvToChannelsLastPass
+ : public iree_compiler::Preprocessing::impl::
+ ConvertConvToChannelsLastPassBase<ConvertConvToChannelsLastPass> {
public:
- void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<linalg::LinalgDialect>();
- registry.insert<tensor::TensorDialect>();
- }
- LogicalResult initializeOptions(
- StringRef options,
- function_ref<LogicalResult(const Twine &)> errorHandler) override {
- if (failed(Pass::initializeOptions(options, errorHandler))) {
- return failure();
- }
- tilingFactor = tileSize;
- return success();
- }
-
+ using iree_compiler::Preprocessing::impl::ConvertConvToChannelsLastPassBase<
+ ConvertConvToChannelsLastPass>::ConvertConvToChannelsLastPassBase;
void runOnOperation() override {
auto op = getOperation();
MLIRContext *context = &getContext();
@@ -722,15 +713,8 @@
LDBG("after generalizing all remaining packs/unpacks\n" << *op);
}
-
-private:
- int64_t tilingFactor;
};
} // namespace
-std::unique_ptr<Pass> createConvertConvToChannelsLastPass() {
- return std::make_unique<ConvertConvToChannelsLastPass>();
-}
-
} // namespace mlir::iree_compiler::Preprocessing
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/InterpreterPass.cpp b/compiler/src/iree/compiler/Preprocessing/Common/InterpreterPass.cpp
index a4cc665..eaf1684 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/InterpreterPass.cpp
+++ b/compiler/src/iree/compiler/Preprocessing/Common/InterpreterPass.cpp
@@ -21,7 +21,8 @@
: public iree_compiler::Preprocessing::impl::InterpreterPassBase<
InterpreterPass> {
public:
- using Base::Base;
+ using iree_compiler::Preprocessing::impl::InterpreterPassBase<
+ InterpreterPass>::InterpreterPassBase;
void runOnOperation() override {
MLIRContext *context = &getContext();
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/MakeSingleDispatchForFunction.cpp b/compiler/src/iree/compiler/Preprocessing/Common/MakeSingleDispatchForFunction.cpp
index 120c464..693b8d6 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/MakeSingleDispatchForFunction.cpp
+++ b/compiler/src/iree/compiler/Preprocessing/Common/MakeSingleDispatchForFunction.cpp
@@ -5,7 +5,6 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
-#include "iree/compiler/Preprocessing/Common/PassDetail.h"
#include "iree/compiler/Preprocessing/Common/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -14,11 +13,15 @@
namespace mlir::iree_compiler::Preprocessing {
+#define GEN_PASS_DEF_MAKESINGLEDISPATCHFORFUNCTIONPASS
+#include "iree/compiler/Preprocessing/Common/Passes.h.inc" // IWYU pragma: export
+
namespace {
struct MakeSingleDispatchForFunctionPass
- : public MakeSingleDispatchForFunctionBase<
- MakeSingleDispatchForFunctionPass> {
+ : public iree_compiler::Preprocessing::impl::
+ MakeSingleDispatchForFunctionPassBase<
+ MakeSingleDispatchForFunctionPass> {
void runOnOperation() override;
};
} // namespace
@@ -83,9 +86,4 @@
rewriter.create<func::ReturnOp>(loc, dispatchRegionOp.getResults());
}
-std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createMakeSingleDispatchForFunctionPass() {
- return std::make_unique<MakeSingleDispatchForFunctionPass>();
-}
-
} // namespace mlir::iree_compiler::Preprocessing
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/PadLinalgOps.cpp b/compiler/src/iree/compiler/Preprocessing/Common/PadLinalgOps.cpp
index 68125a4..b1e5a9c 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/PadLinalgOps.cpp
+++ b/compiler/src/iree/compiler/Preprocessing/Common/PadLinalgOps.cpp
@@ -4,7 +4,6 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree/compiler/Preprocessing/Common/PassDetail.h"
#include "iree/compiler/Preprocessing/Common/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -16,6 +15,9 @@
namespace mlir::iree_compiler::Preprocessing {
+#define GEN_PASS_DEF_PADLINALGOPSPASS
+#include "iree/compiler/Preprocessing/Common/Passes.h.inc" // IWYU pragma: export
+
namespace {
/// A pattern to pad statically shaped matmul operands to the next integer
/// multiple of padSize.
@@ -153,8 +155,12 @@
int paddingSize;
};
-class PadLinalgOpsPass : public PadLinalgOpsBase<PadLinalgOpsPass> {
+class PadLinalgOpsPass
+ : public iree_compiler::Preprocessing::impl::PadLinalgOpsPassBase<
+ PadLinalgOpsPass> {
public:
+ using iree_compiler::Preprocessing::impl::PadLinalgOpsPassBase<
+ PadLinalgOpsPass>::PadLinalgOpsPassBase;
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
@@ -168,8 +174,4 @@
} // namespace
-std::unique_ptr<Pass> createPadLinalgOpsToIntegerMultiplePass() {
- return std::make_unique<PadLinalgOpsPass>();
-}
-
} // namespace mlir::iree_compiler::Preprocessing
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/PassDetail.h b/compiler/src/iree/compiler/Preprocessing/Common/PassDetail.h
deleted file mode 100644
index 447b597..0000000
--- a/compiler/src/iree/compiler/Preprocessing/Common/PassDetail.h
+++ /dev/null
@@ -1,24 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_COMPILER_PREPROCESSING_COMMON_PASS_DETAIL_H_
-#define IREE_COMPILER_PREPROCESSING_COMMON_PASS_DETAIL_H_
-
-#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Transform/IR/TransformDialect.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/Interfaces/FunctionInterfaces.h"
-#include "mlir/Pass/Pass.h"
-
-namespace mlir::iree_compiler::Preprocessing {
-
-#define GEN_PASS_CLASSES
-#include "iree/compiler/Preprocessing/Common/Passes.h.inc" // IWYU pragma: keep
-
-} // namespace mlir::iree_compiler::Preprocessing
-
-#endif // IREE_COMPILER_PREPROCESSING_COMMON_PASS_DETAIL_H_
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.h b/compiler/src/iree/compiler/Preprocessing/Common/Passes.h
index a78cd16..54654f5 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.h
+++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.h
@@ -15,20 +15,6 @@
namespace mlir::iree_compiler::Preprocessing {
-/// Creates a pass to convert linalg convolution ops into linalg.matmul ops
-/// using im2col tranformation.
-std::unique_ptr<Pass> createConvertConv2DToImg2ColPass();
-
-/// A pass to convert convolutions to channels last and propagate.
-std::unique_ptr<Pass> createConvertConvToChannelsLastPass();
-
-/// Moves the body of the entire function into a single dispatch.
-std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createMakeSingleDispatchForFunctionPass();
-
-/// A pass to pad linalg ops to the next integer multiple of `paddingSize`.
-std::unique_ptr<Pass> createPadLinalgOpsToIntegerMultiplePass();
-
//===----------------------------------------------------------------------===//
// Register all Passes
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
index 708d4c7..889632c 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
+++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
@@ -9,34 +9,45 @@
include "mlir/Pass/PassBase.td"
-def ApplyPDLPatterns : Pass<"iree-preprocessing-apply-pdl-patterns", "ModuleOp"> {
+def ApplyPDLPatternsPass : Pass<"iree-preprocessing-apply-pdl-patterns", "ModuleOp"> {
let summary = "Parse an input file containing PDL patterns and apply them as patterns";
let options = [
Option<"patternsFile", "patterns-file", "std::string",
/*default=*/"", "File path to file containing PDL patterns to use.">,
];
-}
-
-def ConvertConv2DToImg2Col :
- Pass<"iree-preprocessing-convert-conv2d-to-img2col", ""> {
- let summary = "Convert linalg convolution ops to matmul img2col based implementation";
- let constructor = "mlir::iree_compiler::Preprocessing::createConvertConv2DToImg2ColPass()";
let dependentDialects = [
- "mlir::linalg::LinalgDialect",
+ "iree_compiler::IREE::Flow::FlowDialect",
+ "iree_compiler::IREE::Stream::StreamDialect",
+ "iree_compiler::IREE::Util::UtilDialect",
+ "mlir::arith::ArithDialect",
+ "mlir::memref::MemRefDialect, pdl::PDLDialect",
+ "mlir::pdl_interp::PDLInterpDialect",
+ "mlir::tensor::TensorDialect",
];
}
-def ConvertConvToChannelsLast :
+def ConvertConv2DToImg2ColPass :
+ Pass<"iree-preprocessing-convert-conv2d-to-img2col", ""> {
+ let summary = "Convert linalg convolution ops to matmul img2col based implementation";
+ let dependentDialects = [
+ "mlir::linalg::LinalgDialect",
+ "mlir::tensor::TensorDialect",
+ ];
+}
+
+def ConvertConvToChannelsLastPass :
Pass<"iree-preprocessing-convert-conv-to-channels-last", ""> {
let summary = "Convert linalg convolutions to channels last.";
- let constructor =
- "mlir::iree_compiler::Preprocessing::createConvertConvToChannelsLastPass()";
let options = [
- Option<"tileSize", "tile-size", "int",
+ Option<"tilingFactor", "tiling-factor", "int",
/*default=*/"0",
"Tiling factor for the channel dimension of NCHW-like convolutions. "
"Defaults to fully transposing all channel-like dimensions">,
];
+ let dependentDialects = [
+ "mlir::linalg::LinalgDialect",
+ "mlir::tensor::TensorDialect",
+ ];
}
def InterpreterPass : Pass<"iree-preprocessing-transform-interpreter"> {
@@ -59,19 +70,17 @@
];
}
-def MakeSingleDispatchForFunction :
+def MakeSingleDispatchForFunctionPass :
InterfacePass<"iree-preprocessing-make-single-dispatch-for-function", "mlir::FunctionOpInterface"> {
let summary = "Convert entire function into a single dispatch";
- let constructor = "mlir::iree_compiler::Preprocessing::createMakeSingleDispatchForFunctionPass()";
let dependentDialects = [
"IREE::Flow::FlowDialect",
];
}
-def PadLinalgOps :
+def PadLinalgOpsPass :
Pass<"iree-preprocessing-pad-linalg-ops", ""> {
let summary = "Pad linalg ops to the next integer multiple of paddingSize.";
- let constructor = "mlir::iree_compiler::Preprocessing::createPadLinalgOpsToIntegerMultiplePass()";
let options = [
Option<"paddingSize", "pad-size", "int",
/*default=*/"4",
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/conv_to_channels_last.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/conv_to_channels_last.mlir
index fffb480..1bee668 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/test/conv_to_channels_last.mlir
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/conv_to_channels_last.mlir
@@ -1,6 +1,6 @@
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(util.func(iree-preprocessing-convert-conv-to-channels-last))" %s | \
// RUN: FileCheck %s
-// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(util.func(iree-preprocessing-convert-conv-to-channels-last{tile-size=16}))" %s | \
+// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(util.func(iree-preprocessing-convert-conv-to-channels-last{tiling-factor=16}))" %s | \
// RUN: FileCheck %s --check-prefix=TILE16
util.func @conv_nhwc_hwcf_no_transpose(%arg0: tensor<1x16x16x256xf32>, %arg1: tensor<3x3x256x128xf32>, %arg2: tensor<1x14x14x128xf32>) -> tensor<1x14x14x128xf32> {
diff --git a/compiler/src/iree/compiler/Preprocessing/Passes.cpp b/compiler/src/iree/compiler/Preprocessing/Passes.cpp
index f3e209d..77f6da5 100644
--- a/compiler/src/iree/compiler/Preprocessing/Passes.cpp
+++ b/compiler/src/iree/compiler/Preprocessing/Passes.cpp
@@ -85,11 +85,11 @@
}
if (!preprocessingOptions.preprocessingPDLSpecFilename.empty()) {
- Preprocessing::ApplyPDLPatternsOptions applyPDLPatternsOptions;
+ Preprocessing::ApplyPDLPatternsPassOptions applyPDLPatternsOptions;
applyPDLPatternsOptions.patternsFile =
preprocessingOptions.preprocessingPDLSpecFilename;
passManager.addPass(
- Preprocessing::createApplyPDLPatterns(applyPDLPatternsOptions));
+ Preprocessing::createApplyPDLPatternsPass(applyPDLPatternsOptions));
passManager.addPass(createCanonicalizerPass());
passManager.addPass(createCSEPass());
}
diff --git a/tests/compiler_driver/preprocessing_flags.mlir b/tests/compiler_driver/preprocessing_flags.mlir
index 6c4ad99..3313988 100644
--- a/tests/compiler_driver/preprocessing_flags.mlir
+++ b/tests/compiler_driver/preprocessing_flags.mlir
@@ -10,8 +10,8 @@
}
// Just check that the pass runs, and that the compilation finishes
-// CHECK: ConvertConv2DToImg2Col (iree-preprocessing-convert-conv2d-to-img2col)
-// CHECK: PadLinalgOps (iree-preprocessing-pad-linalg-ops)
+// CHECK: ConvertConv2DToImg2ColPass (iree-preprocessing-convert-conv2d-to-img2col)
+// CHECK: PadLinalgOpsPass (iree-preprocessing-pad-linalg-ops)
// CHECK-LABEL: module
// CHECK-NEXT: util.func public @test(
// CHECK-DAG: %[[ARG0:.+]] = hal.tensor.import %{{[a-zA-Z0-9]+}} "input0" : !hal.buffer_view -> tensor<10x20xf32>