Convert generic FunctionOpInterface and CallableOpInterface passes to InterfacePass<>. (#8595)
* Convert flow passes to InterfacePass.
* Deprivilege FuncOp in StripSignednessPass
* Convert streams to InterfacePass
diff --git a/iree/compiler/Dialect/Flow/Transforms/ConvertToFlow.cpp b/iree/compiler/Dialect/Flow/Transforms/ConvertToFlow.cpp
index 4473ac2..da943c5 100644
--- a/iree/compiler/Dialect/Flow/Transforms/ConvertToFlow.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/ConvertToFlow.cpp
@@ -104,9 +104,8 @@
mlir::arith::ArithmeticDialect, mlir::math::MathDialect>();
}
void runOnOperation() override {
- auto funcOp = getOperation();
- MLIRContext *context = funcOp->getContext();
- RewritePatternSet patterns(&getContext());
+ MLIRContext *context = &getContext();
+ RewritePatternSet patterns(context);
patterns
.insert<LinalgTensorReshapeToFlowTensorReshape<tensor::CollapseShapeOp>,
@@ -116,7 +115,8 @@
memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
IREE::Flow::TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
- if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+ if (failed(applyPatternsAndFoldGreedily(getOperation(),
+ std::move(patterns)))) {
return signalPassFailure();
}
}
@@ -131,16 +131,16 @@
mlir::arith::ArithmeticDialect, mlir::math::MathDialect>();
}
void runOnOperation() override {
- auto funcOp = getOperation();
- MLIRContext *context = funcOp->getContext();
- RewritePatternSet patterns(&getContext());
+ MLIRContext *context = &getContext();
+ RewritePatternSet patterns(context);
patterns.insert<LinalgFillToFlowTensorSplat>(context);
populateTensorToFlowPatternsAfterDispatchFormation(context, patterns);
memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
IREE::Flow::TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
- if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+ if (failed(applyPatternsAndFoldGreedily(getOperation(),
+ std::move(patterns)))) {
return signalPassFailure();
}
}
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
index 9b92be2..2476b40 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
@@ -993,8 +993,8 @@
}
void DispatchLinalgOnTensorsPass::runOnOperation() {
- auto funcOp = llvm::cast<FunctionOpInterface>(getOperation());
- MLIRContext *context = funcOp->getContext();
+ auto funcOp = getOperation();
+ MLIRContext *context = &getContext();
unsigned numRoots = decideFusableLinalgOps(funcOp);
LLVM_DEBUG({
@@ -1081,7 +1081,8 @@
});
}
-std::unique_ptr<Pass> createDispatchLinalgOnTensorsPass() {
+std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
+createDispatchLinalgOnTensorsPass() {
return std::make_unique<DispatchLinalgOnTensorsPass>();
}
diff --git a/iree/compiler/Dialect/Flow/Transforms/InjectDispatchTracing.cpp b/iree/compiler/Dialect/Flow/Transforms/InjectDispatchTracing.cpp
index d61d3e3..2f660cf 100644
--- a/iree/compiler/Dialect/Flow/Transforms/InjectDispatchTracing.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/InjectDispatchTracing.cpp
@@ -34,7 +34,7 @@
InjectDispatchTracingPass() = default;
void runOnOperation() override {
- auto funcOp = llvm::cast<FunctionOpInterface>(getOperation());
+ auto funcOp = getOperation();
for (auto dispatchOp : funcOp.getBody().getOps<DispatchOp>()) {
std::string entryPointName =
dispatchOp.entry_point().getRootReference().getValue().str();
@@ -60,7 +60,8 @@
}
};
-std::unique_ptr<Pass> createInjectDispatchTracingPass() {
+std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
+createInjectDispatchTracingPass() {
return std::make_unique<InjectDispatchTracingPass>();
}
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.h b/iree/compiler/Dialect/Flow/Transforms/Passes.h
index d45d6d5..810b1db 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.h
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.h
@@ -113,7 +113,8 @@
std::unique_ptr<Pass> createOptimizeNumericsPass();
// Strips the signed/unsigned portion off of tensors.
-std::unique_ptr<OperationPass<mlir::FuncOp>> createStripSignednessPass();
+std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
+createStripSignednessPass();
// Verifies that the input to the Flow transformation pipeline is legal.
// This includes checking for operations from dialects that are expected
@@ -126,7 +127,8 @@
// Pass to perform dispatch of Linalg on tensor ops by tiling and distribution.
// A dispatch region is created for each tiled loop nest.
-std::unique_ptr<Pass> createDispatchLinalgOnTensorsPass();
+std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
+createDispatchLinalgOnTensorsPass();
// Captures dynamic shape dimensions required by dispatch operands.
std::unique_ptr<Pass> createCaptureDispatchDynamicDimsPass();
@@ -136,7 +138,8 @@
createOutlineDispatchRegionsPass();
// Injects tracing markers for dispatch operation tensor inputs and outputs.
-std::unique_ptr<Pass> createInjectDispatchTracingPass();
+std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
+createInjectDispatchTracingPass();
// Exports all functions and dispatch executables as `() -> ()` benchmark funcs.
std::unique_ptr<OperationPass<mlir::ModuleOp>> createExportBenchmarkFuncsPass();
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.td b/iree/compiler/Dialect/Flow/Transforms/Passes.td
index 0760ed9..d88efed 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.td
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.td
@@ -57,7 +57,7 @@
}
def DispatchLinalgOnTensors :
- Pass<"iree-flow-dispatch-linalg-on-tensors-pass", ""> {
+ InterfacePass<"iree-flow-dispatch-linalg-on-tensors-pass", "mlir::FunctionOpInterface"> {
let summary = "Dispatch Linalg operations on tensors by using tile and distribute";
let constructor = "mlir::iree_compiler::IREE::Flow::createDispatchLinalgOnTensorsPass()";
}
@@ -87,7 +87,7 @@
}
def InjectDispatchTracing :
- Pass<"iree-flow-inject-dispatch-tracing", ""> {
+ InterfacePass<"iree-flow-inject-dispatch-tracing", "mlir::FunctionOpInterface"> {
let summary = "Injects dispatch region tracing.";
let constructor = "mlir::iree_compiler::IREE::Flow::createInjectDispatchTracingPass()";
}
@@ -140,7 +140,7 @@
}
def StripSignedness :
- Pass<"iree-flow-strip-signedness", "mlir::FuncOp"> {
+ InterfacePass<"iree-flow-strip-signedness", "mlir::FunctionOpInterface"> {
let summary = "Legalizes ui tensors constants to uis";
let constructor = "mlir::iree_compiler::IREE::Flow::createStripSignednessPass()";
}
diff --git a/iree/compiler/Dialect/Flow/Transforms/StripSignednessPass.cpp b/iree/compiler/Dialect/Flow/Transforms/StripSignednessPass.cpp
index edeb4f3..cee8b77 100644
--- a/iree/compiler/Dialect/Flow/Transforms/StripSignednessPass.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/StripSignednessPass.cpp
@@ -6,8 +6,8 @@
#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -58,7 +58,7 @@
Operation* op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const override {
llvm::SmallVector<Type, 4> newResults;
- if (isa<FuncOp>(op)) {
+ if (isa<FunctionOpInterface>(op)) {
return failure();
}
@@ -93,11 +93,11 @@
// Operations are legal if they don't contain any illegal type.
target.markUnknownOpDynamicallyLegal([](Operation* op) {
- if (auto funcOp = dyn_cast<FuncOp>(op)) {
- for (Type type : funcOp.getType().getInputs()) {
+ if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
+ for (Type type : funcOp.getArgumentTypes()) {
if (isIllegalType(type)) return false;
}
- for (Type type : funcOp.getType().getResults()) {
+ for (Type type : funcOp.getResultTypes()) {
if (isIllegalType(type)) return false;
}
}
@@ -111,20 +111,22 @@
});
auto* ctx = &getContext();
- auto func = getOperation();
RewritePatternSet patterns(&getContext());
patterns.insert<GenericTypeConvert>(ctx, converter);
- populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns, converter);
+ populateFunctionOpInterfaceTypeConversionPattern(
+ getOperation()->getName().getStringRef(), patterns, converter);
- if (failed(applyFullConversion(func, target, std::move(patterns)))) {
+ if (failed(
+ applyFullConversion(getOperation(), target, std::move(patterns)))) {
signalPassFailure();
}
}
} // namespace
-std::unique_ptr<OperationPass<mlir::FuncOp>> createStripSignednessPass() {
+std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
+createStripSignednessPass() {
return std::make_unique<StripSignednessPass>();
}
diff --git a/iree/compiler/Dialect/Stream/Transforms/LayoutSlices.cpp b/iree/compiler/Dialect/Stream/Transforms/LayoutSlices.cpp
index 970ee9f..70b5c0f 100644
--- a/iree/compiler/Dialect/Stream/Transforms/LayoutSlices.cpp
+++ b/iree/compiler/Dialect/Stream/Transforms/LayoutSlices.cpp
@@ -254,8 +254,8 @@
}
void runOnOperation() override {
- auto parentOp = dyn_cast<CallableOpInterface>(getOperation());
- if (!parentOp || !parentOp.getCallableRegion() ||
+ auto parentOp = getOperation();
+ if (!parentOp.getCallableRegion() ||
parentOp.getCallableRegion()->empty()) {
return;
}
@@ -318,7 +318,7 @@
} // namespace
-std::unique_ptr<OperationPass<>> createLayoutSlicesPass() {
+std::unique_ptr<InterfacePass<CallableOpInterface>> createLayoutSlicesPass() {
return std::make_unique<LayoutSlicesPass>();
}
diff --git a/iree/compiler/Dialect/Stream/Transforms/PackAllocations.cpp b/iree/compiler/Dialect/Stream/Transforms/PackAllocations.cpp
index b24bc17..e719a24 100644
--- a/iree/compiler/Dialect/Stream/Transforms/PackAllocations.cpp
+++ b/iree/compiler/Dialect/Stream/Transforms/PackAllocations.cpp
@@ -42,8 +42,8 @@
}
void runOnOperation() override {
- auto parentOp = dyn_cast<CallableOpInterface>(getOperation());
- if (!parentOp || !parentOp.getCallableRegion() ||
+ auto parentOp = getOperation();
+ if (!parentOp.getCallableRegion() ||
parentOp.getCallableRegion()->empty()) {
return;
}
@@ -109,7 +109,8 @@
} // namespace
-std::unique_ptr<OperationPass<>> createPackAllocationsPass() {
+std::unique_ptr<InterfacePass<CallableOpInterface>>
+createPackAllocationsPass() {
return std::make_unique<PackAllocationsPass>();
}
diff --git a/iree/compiler/Dialect/Stream/Transforms/PackConstants.cpp b/iree/compiler/Dialect/Stream/Transforms/PackConstants.cpp
index 1f6fe3d..ef2f689 100644
--- a/iree/compiler/Dialect/Stream/Transforms/PackConstants.cpp
+++ b/iree/compiler/Dialect/Stream/Transforms/PackConstants.cpp
@@ -453,7 +453,7 @@
}
void runOnOperation() override {
- auto parentOp = dyn_cast<CallableOpInterface>(getOperation());
+ auto parentOp = getOperation();
if (!parentOp || !parentOp.getCallableRegion() ||
parentOp.getCallableRegion()->empty()) {
return;
@@ -543,7 +543,7 @@
} // namespace
-std::unique_ptr<OperationPass<>> createPackConstantsPass() {
+std::unique_ptr<InterfacePass<CallableOpInterface>> createPackConstantsPass() {
return std::make_unique<PackConstantsPass>();
}
diff --git a/iree/compiler/Dialect/Stream/Transforms/Passes.h b/iree/compiler/Dialect/Stream/Transforms/Passes.h
index 1d5100b..9b54449 100644
--- a/iree/compiler/Dialect/Stream/Transforms/Passes.h
+++ b/iree/compiler/Dialect/Stream/Transforms/Passes.h
@@ -118,8 +118,10 @@
// Stream formation and scheduling
//===----------------------------------------------------------------------===//
-std::unique_ptr<OperationPass<>> createScheduleExecutionPass();
-std::unique_ptr<OperationPass<>> createScheduleConcurrencyPass();
+std::unique_ptr<InterfacePass<CallableOpInterface>>
+createScheduleExecutionPass();
+std::unique_ptr<InterfacePass<CallableOpInterface>>
+createScheduleConcurrencyPass();
std::unique_ptr<OperationPass<mlir::ModuleOp>> createPropagateTimepointsPass();
@@ -127,11 +129,12 @@
// Allocation and command issuing
//===----------------------------------------------------------------------===//
-std::unique_ptr<OperationPass<>> createScheduleAllocationPass();
+std::unique_ptr<InterfacePass<CallableOpInterface>>
+createScheduleAllocationPass();
-std::unique_ptr<OperationPass<>> createPackConstantsPass();
-std::unique_ptr<OperationPass<>> createPackAllocationsPass();
-std::unique_ptr<OperationPass<>> createLayoutSlicesPass();
+std::unique_ptr<InterfacePass<CallableOpInterface>> createPackConstantsPass();
+std::unique_ptr<InterfacePass<CallableOpInterface>> createPackAllocationsPass();
+std::unique_ptr<InterfacePass<CallableOpInterface>> createLayoutSlicesPass();
std::unique_ptr<OperationPass<mlir::ModuleOp>> createPropagateSubviewsPass();
diff --git a/iree/compiler/Dialect/Stream/Transforms/Passes.td b/iree/compiler/Dialect/Stream/Transforms/Passes.td
index 794041c..d4908f4 100644
--- a/iree/compiler/Dialect/Stream/Transforms/Passes.td
+++ b/iree/compiler/Dialect/Stream/Transforms/Passes.td
@@ -90,7 +90,7 @@
//===----------------------------------------------------------------------===//
def ScheduleExecution :
- Pass<"iree-stream-schedule-execution", ""> {
+ InterfacePass<"iree-stream-schedule-execution", "mlir::CallableOpInterface"> {
let summary = "Identifies and groups asynchronous operations into executable regions within function-like regions.";
let constructor = [{
mlir::iree_compiler::IREE::Stream::createScheduleExecutionPass()
@@ -98,7 +98,7 @@
}
def ScheduleConcurrency :
- Pass<"iree-stream-schedule-concurrency", ""> {
+ InterfacePass<"iree-stream-schedule-concurrency", "mlir::CallableOpInterface"> {
let summary = "Identifies and groups asynchronous operations within executable regions that can run concurrently and groups them into streams.";
let constructor = [{
mlir::iree_compiler::IREE::Stream::createScheduleConcurrencyPass()
@@ -118,7 +118,7 @@
//===----------------------------------------------------------------------===//
def ScheduleAllocation :
- Pass<"iree-stream-schedule-allocation", ""> {
+ InterfacePass<"iree-stream-schedule-allocation", "mlir::CallableOpInterface"> {
let summary = "Allocates resources and converts to explicit stream commands.";
let constructor = [{
mlir::iree_compiler::IREE::Stream::createScheduleAllocationPass()
@@ -126,7 +126,7 @@
}
def PackConstants :
- Pass<"iree-stream-pack-constants", ""> {
+ InterfacePass<"iree-stream-pack-constants", "mlir::CallableOpInterface"> {
let summary = "Packs and allocate backing storage for fused constant resources.";
let constructor = [{
mlir::iree_compiler::IREE::Stream::createPackConstantsPass()
@@ -134,7 +134,7 @@
}
def PackAllocations :
- Pass<"iree-stream-pack-allocations", ""> {
+ InterfacePass<"iree-stream-pack-allocations", "mlir::CallableOpInterface"> {
let summary = "Packs fused allocations based on lifetime.";
let constructor = [{
mlir::iree_compiler::IREE::Stream::createPackAllocationsPass()
@@ -142,7 +142,7 @@
}
def LayoutSlices :
- Pass<"iree-stream-layout-slices", ""> {
+ InterfacePass<"iree-stream-layout-slices", "mlir::CallableOpInterface"> {
let summary = "Lays out packed slices and produces arithmetic required for all offsets.";
let constructor = [{
mlir::iree_compiler::IREE::Stream::createLayoutSlicesPass()
diff --git a/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp b/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp
index 448e170..48687a3 100644
--- a/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp
+++ b/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp
@@ -1406,8 +1406,8 @@
}
void runOnOperation() override {
- auto parentOp = dyn_cast<CallableOpInterface>(getOperation());
- if (!parentOp || !parentOp.getCallableRegion() ||
+ auto parentOp = getOperation();
+ if (!parentOp.getCallableRegion() ||
parentOp.getCallableRegion()->empty()) {
return;
}
@@ -1433,7 +1433,8 @@
} // namespace
-std::unique_ptr<OperationPass<>> createScheduleAllocationPass() {
+std::unique_ptr<InterfacePass<CallableOpInterface>>
+createScheduleAllocationPass() {
return std::make_unique<ScheduleAllocationPass>();
}
diff --git a/iree/compiler/Dialect/Stream/Transforms/ScheduleConcurrency.cpp b/iree/compiler/Dialect/Stream/Transforms/ScheduleConcurrency.cpp
index bf7b1f2..ba52623 100644
--- a/iree/compiler/Dialect/Stream/Transforms/ScheduleConcurrency.cpp
+++ b/iree/compiler/Dialect/Stream/Transforms/ScheduleConcurrency.cpp
@@ -176,8 +176,8 @@
}
void runOnOperation() override {
- auto parentOp = dyn_cast<CallableOpInterface>(getOperation());
- if (!parentOp || !parentOp.getCallableRegion() ||
+ auto parentOp = getOperation();
+ if (!parentOp.getCallableRegion() ||
parentOp.getCallableRegion()->empty()) {
return;
}
@@ -264,7 +264,8 @@
} // namespace
-std::unique_ptr<OperationPass<>> createScheduleConcurrencyPass() {
+std::unique_ptr<InterfacePass<CallableOpInterface>>
+createScheduleConcurrencyPass() {
return std::make_unique<ScheduleConcurrencyPass>();
}
diff --git a/iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp b/iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp
index 7727a49..75aac70 100644
--- a/iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp
+++ b/iree/compiler/Dialect/Stream/Transforms/ScheduleExecution.cpp
@@ -206,8 +206,8 @@
void runOnOperation() override {
auto *context = &getContext();
- auto parentOp = dyn_cast<CallableOpInterface>(getOperation());
- if (!parentOp || !parentOp.getCallableRegion() ||
+ auto parentOp = getOperation();
+ if (!parentOp.getCallableRegion() ||
parentOp.getCallableRegion()->empty()) {
return;
}
@@ -319,7 +319,8 @@
} // namespace
-std::unique_ptr<OperationPass<>> createScheduleExecutionPass() {
+std::unique_ptr<InterfacePass<CallableOpInterface>>
+createScheduleExecutionPass() {
return std::make_unique<ScheduleExecutionPass>();
}
diff --git a/iree/compiler/Dialect/Util/Transforms/SimplifyGlobalAccesses.cpp b/iree/compiler/Dialect/Util/Transforms/SimplifyGlobalAccesses.cpp
index 4bf9ee4..e83286c 100644
--- a/iree/compiler/Dialect/Util/Transforms/SimplifyGlobalAccesses.cpp
+++ b/iree/compiler/Dialect/Util/Transforms/SimplifyGlobalAccesses.cpp
@@ -223,7 +223,8 @@
namespace {
class SimplifyGlobalAccessesPass
- : public PassWrapper<SimplifyGlobalAccessesPass, OperationPass<void>> {
+ : public PassWrapper<SimplifyGlobalAccessesPass,
+ InterfacePass<CallableOpInterface>> {
public:
StringRef getArgument() const override {
return "iree-util-simplify-global-accesses";
@@ -235,8 +236,8 @@
}
void runOnOperation() override {
- auto callableOp = dyn_cast<CallableOpInterface>(getOperation());
- if (!callableOp || !callableOp.getCallableRegion() ||
+ auto callableOp = getOperation();
+ if (!callableOp.getCallableRegion() ||
callableOp.getCallableRegion()->empty()) {
return;
}