Remove ListenerGreedyPatternRewriteDriver (#12358)
The upstream GreedyPatternRewriteDriver now supports listeners.
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index 9983bd5..3ffb549 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -9,7 +9,6 @@
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree-dialects/Dialect/LinalgTransform/SimplePatternRewriter.h"
#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h"
-#include "iree-dialects/Transforms/ListenerGreedyPatternRewriteDriver.h"
#include "iree-dialects/Transforms/TransformMatchers.h"
#include "iree/compiler/Codegen/Common/Transforms.h"
#include "iree/compiler/Codegen/Interfaces/BufferizationInterfaces.h"
@@ -37,6 +36,7 @@
#include "mlir/IR/Diagnostics.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
using namespace mlir::iree_compiler;
@@ -352,8 +352,15 @@
TrackingListener listener(state);
GreedyRewriteConfig config;
- LogicalResult result = applyPatternsAndFoldGreedily(
- target, std::move(patterns), config, &listener);
+ config.listener = &listener;
+ // Manually gather list of ops because the other GreedyPatternRewriteDriver
+ // overloads only accepts ops that are isolated from above.
+ SmallVector<Operation *> ops;
+ target->walk([&](Operation *nestedOp) {
+ if (target != nestedOp) ops.push_back(nestedOp);
+ });
+ LogicalResult result =
+ applyOpPatternsAndFold(ops, std::move(patterns), config);
LogicalResult listenerResult = listener.checkErrorState();
if (failed(result)) {
return mlir::emitDefiniteFailure(target,
@@ -1131,8 +1138,15 @@
patterns.add<EmptyTensorLoweringPattern>(patterns.getContext());
TrackingListener listener(state);
GreedyRewriteConfig config;
- LogicalResult result = applyPatternsAndFoldGreedily(
- state.getTopLevel(), std::move(patterns), config, &listener);
+ config.listener = &listener;
+ // Manually gather list of ops because the other GreedyPatternRewriteDriver
+ // overloads only accepts ops that are isolated from above.
+ SmallVector<Operation *> ops;
+ state.getTopLevel()->walk([&](Operation *nestedOp) {
+ if (state.getTopLevel() != nestedOp) ops.push_back(nestedOp);
+ });
+ LogicalResult result =
+ applyOpPatternsAndFold(ops, std::move(patterns), config);
LogicalResult listenerResult = listener.checkErrorState();
if (failed(result)) {
return mlir::emitDefiniteFailure(state.getTopLevel(),
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h
index 3151aa3..bb711cd 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h
@@ -7,7 +7,6 @@
#ifndef IREE_DIALECTS_DIALECT_LINALG_TRANSFORM_STRUCTUREDTRANSFORMOPSEXT_H
#define IREE_DIALECTS_DIALECT_LINALG_TRANSFORM_STRUCTUREDTRANSFORMOPSEXT_H
-#include "iree-dialects/Transforms/Listener.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
@@ -46,7 +45,7 @@
return std::tuple_cat(a);
}
-class TrackingListener : public RewriteListener,
+class TrackingListener : public RewriterBase::Listener,
public transform::TransformState::Extension {
public:
explicit TrackingListener(transform::TransformState &state)
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/Listener.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/Listener.h
deleted file mode 100644
index 562c99d..0000000
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/Listener.h
+++ /dev/null
@@ -1,82 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_LLVM_SANDBOX_TRANSFORMS_LISTENER_H
-#define IREE_LLVM_SANDBOX_TRANSFORMS_LISTENER_H
-
-#include "mlir/IR/PatternMatch.h"
-
-namespace mlir {
-
-using RewriteListener = RewriterBase::Listener;
-
-//===----------------------------------------------------------------------===//
-// ListenerList
-//===----------------------------------------------------------------------===//
-
-/// This class contains multiple listeners to which rewrite events can be sent.
-class ListenerList : public RewriteListener {
-public:
- /// Add a listener to the list.
- void addListener(RewriteListener *listener) { listeners.push_back(listener); }
-
- /// Send notification of an operation being inserted to all listeners.
- void notifyOperationInserted(Operation *op) override;
- /// Send notification of a block being created to all listeners.
- void notifyBlockCreated(Block *block) override;
- /// Send notification that an operation has been replaced to all listeners.
- void notifyOperationReplaced(Operation *op, ValueRange newValues) override;
- /// Send notification that an operation was modified in-place.
- void notifyOperationModified(Operation *op) override;
- /// Send notification that an operation is about to be deleted to all
- /// listeners.
- void notifyOperationRemoved(Operation *op) override;
- /// Notify all listeners that a pattern match failed.
- LogicalResult
- notifyMatchFailure(Location loc,
- function_ref<void(Diagnostic &)> reasonCallback) override;
-
-private:
- /// The list of listeners to send events to.
- SmallVector<RewriteListener *, 1> listeners;
-};
-
-//===----------------------------------------------------------------------===//
-// PatternRewriterListener
-//===----------------------------------------------------------------------===//
-
-/// This class implements a pattern rewriter with a rewrite listener. Rewrite
-/// events are forwarded to the provided rewrite listener.
-class PatternRewriterListener : public PatternRewriter, public ListenerList {
-public:
- PatternRewriterListener(MLIRContext *context) : PatternRewriter(context) {
- setListener(this);
- }
-
- /// When an operation is about to be replaced, send out an event to all
- /// attached listeners.
- void replaceOp(Operation *op, ValueRange newValues) override {
- ListenerList::notifyOperationReplaced(op, newValues);
- PatternRewriter::replaceOp(op, newValues);
- }
-
- void notifyOperationModified(Operation *op) override {
- ListenerList::notifyOperationModified(op);
- }
- void notifyOperationInserted(Operation *op) override {
- ListenerList::notifyOperationInserted(op);
- }
- void notifyBlockCreated(Block *block) override {
- ListenerList::notifyBlockCreated(block);
- }
- void notifyOperationRemoved(Operation *op) override {
- ListenerList::notifyOperationRemoved(op);
- }
-};
-
-} // namespace mlir
-
-#endif // IREE_LLVM_SANDBOX_TRANSFORMS_LISTENER_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/ListenerCSE.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/ListenerCSE.h
index a062f2a..8b81236 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/ListenerCSE.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/ListenerCSE.h
@@ -7,7 +7,7 @@
#ifndef LLVM_IREE_SANDBOX_TRANSFORMS_LISTENERCSE_H
#define LLVM_IREE_SANDBOX_TRANSFORMS_LISTENERCSE_H
-#include "iree-dialects/Transforms/Listener.h"
+#include "mlir/IR/PatternMatch.h"
namespace mlir {
class DominanceInfo;
@@ -15,7 +15,7 @@
LogicalResult eliminateCommonSubexpressions(Operation *op,
DominanceInfo *domInfo,
- RewriteListener *listener);
+ RewriterBase::Listener *listener);
} // namespace mlir
#endif // LLVM_IREE_SANDBOX_TRANSFORMS_LISTENERCSE_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/ListenerGreedyPatternRewriteDriver.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/ListenerGreedyPatternRewriteDriver.h
deleted file mode 100644
index a8c0912..0000000
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/ListenerGreedyPatternRewriteDriver.h
+++ /dev/null
@@ -1,39 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree-dialects/Transforms/Listener.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Rewrite/FrozenRewritePatternSet.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-// The following are iree-dialects extensions to MLIR.
-namespace mlir {
-struct GreedyRewriteConfig;
-
-/// Applies the specified patterns on `op` alone while also trying to fold it,
-/// by selecting the highest benefits patterns in a greedy manner. Returns
-/// success if no more patterns can be matched. `erased` is set to true if `op`
-/// was folded away or erased as a result of becoming dead. Note: This does not
-/// apply any patterns recursively to the regions of `op`. Accepts a listener
-/// so the caller can be notified of rewrite events.
-LogicalResult applyPatternsAndFoldGreedily(
- Operation *op, const FrozenRewritePatternSet &patterns,
- const GreedyRewriteConfig &config, RewriteListener *listener);
-
-/// Apply the given list of transformations to the regions of the
-/// isolated-from-above operation `root` greedily until convergence. Update
-/// Linalg operations in values of `trackedOperations` if they are replaced by
-/// other Linalg operations during the rewriting process. Tracked operations
-/// must be replaced with Linalg operations and must not be erased in the
-/// patterns.
-static inline LogicalResult applyPatternsTrackAndFoldGreedily(
- Operation *root, RewriteListener &listener,
- const FrozenRewritePatternSet &patterns,
- GreedyRewriteConfig config = GreedyRewriteConfig()) {
- return applyPatternsAndFoldGreedily(root, patterns, config, &listener);
-}
-
-} // namespace mlir
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/LinalgTransformOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/LinalgTransformOps.cpp
index 95e83c8..a64f64e 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/LinalgTransformOps.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/LinalgTransformOps.cpp
@@ -11,8 +11,6 @@
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h"
#include "iree-dialects/Dialect/LinalgTransform/ScopedTransform.h"
-#include "iree-dialects/Transforms/Listener.h"
-#include "iree-dialects/Transforms/ListenerGreedyPatternRewriteDriver.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp
index 59057a0..9a6d66b 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp
@@ -9,9 +9,7 @@
#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h"
#include "iree-dialects/Dialect/LinalgTransform/ScopedTransform.h"
-#include "iree-dialects/Transforms/Listener.h"
#include "iree-dialects/Transforms/ListenerCSE.h"
-#include "iree-dialects/Transforms/ListenerGreedyPatternRewriteDriver.h"
#include "iree-dialects/Transforms/TransformMatchers.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
@@ -409,37 +407,45 @@
}
/// Find the op that defines all values in the range.
-static Operation *findSingleOpDefiningAll(ValueRange range) {
+static FailureOr<Operation *> findSingleOpDefiningAll(ValueRange range) {
Operation *op = nullptr;
for (Value value : range) {
- if (auto currentSourceOp = value.getDefiningOp()) {
- if (!op || op == currentSourceOp) {
- op = currentSourceOp;
- continue;
- }
- LLVM_DEBUG(DBGS() << "different source op when replacing one op\n");
- return nullptr;
+ // Block arguments are just dropped.
+ auto currentSourceOp = value.getDefiningOp();
+ if (!currentSourceOp) {
+ LLVM_DEBUG(DBGS() << "replacing tracked op with bbarg\n");
+ continue;
}
- LLVM_DEBUG(
- DBGS() << "could not find a source op when replacing another op\n");
- return nullptr;
+ if (!op || op == currentSourceOp) {
+ op = currentSourceOp;
+ continue;
+ }
+
+ LLVM_DEBUG(DBGS() << "different source op when replacing one op\n");
+ return failure();
}
return op;
}
// Find a single op that defines all values in the range, optionally
// transitively through other operations in an op-specific way.
-static Operation *findSingleDefiningOp(Operation *replacedOp,
- ValueRange range) {
- return llvm::TypeSwitch<Operation *, Operation *>(replacedOp)
- .Case<linalg::LinalgOp>([&](linalg::LinalgOp) -> Operation * {
- return findSingleLinalgOpDefiningAll(range);
+static FailureOr<Operation *> findSingleDefiningOp(Operation *replacedOp,
+ ValueRange range) {
+ return llvm::TypeSwitch<Operation *, FailureOr<Operation *>>(replacedOp)
+ .Case<linalg::LinalgOp>([&](linalg::LinalgOp) -> FailureOr<Operation *> {
+ auto op = findSingleLinalgOpDefiningAll(range);
+ if (!op)
+ return failure();
+ return op.getOperation();
})
- .Case<scf::ForOp>([&](scf::ForOp) -> Operation * {
- return findSingleForOpDefiningAll(range);
+ .Case<scf::ForOp>([&](scf::ForOp) -> FailureOr<Operation *> {
+ auto op = findSingleForOpDefiningAll(range);
+ if (!op)
+ return failure();
+ return op.getOperation();
})
- .Default([&](Operation *) -> Operation * {
+ .Default([&](Operation *) -> FailureOr<Operation *> {
return findSingleOpDefiningAll(range);
});
}
@@ -455,15 +461,21 @@
if (failed(getTransformState().getHandlesForPayloadOp(op, handles)))
return;
- Operation *replacement = findSingleDefiningOp(op, newValues);
- if (!replacement) {
+ FailureOr<Operation *> replacement = findSingleDefiningOp(op, newValues);
+ if (failed(replacement)) {
emitError(op) << "could not find replacement for tracked op";
return;
}
- LLVM_DEBUG(DBGS() << "replacing tracked @" << op << " : " << *op << " with "
- << *replacement << "\n");
- mayFail(replacePayloadOp(op, replacement));
+ if (*replacement == nullptr) {
+ // TODO: Check if the handle is dead. Otherwise, the op should not be
+ // dropped. This needs a change in the transform dialect interpreter.
+ LLVM_DEBUG(DBGS() << "removing tracked @" << op << " : " << *op << "\n");
+ } else {
+ LLVM_DEBUG(DBGS() << "replacing tracked @" << op << " : " << *op << " with "
+ << **replacement << "\n");
+ }
+ mayFail(replacePayloadOp(op, *replacement));
}
void mlir::TrackingListener::notifyOperationRemoved(Operation *op) {
@@ -516,14 +528,15 @@
/// Run enabling transformations (LICM and its variants, single-iteration loop
/// removal, CSE) on the given function.
static LogicalResult performEnablerTransformations(
- func::FuncOp func, RewriteListener &listener,
+ func::FuncOp func, RewriterBase::Listener &listener,
LinalgEnablingOptions options = LinalgEnablingOptions()) {
MLIRContext *ctx = func->getContext();
RewritePatternSet patterns(ctx);
linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
scf::populateSCFForLoopCanonicalizationPatterns(patterns);
- if (failed(applyPatternsTrackAndFoldGreedily(func, listener,
- std::move(patterns))))
+ GreedyRewriteConfig config;
+ config.listener = &listener;
+ if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns), config)))
return failure();
// This assumes LICM never removes operations so we don't need tracking.
@@ -550,7 +563,7 @@
/// Run enabling transformations on the given `containerOp` while preserving the
/// operation tracking information.
static LogicalResult performEnablerTransformations(
- Operation *containerOp, RewriteListener &listener,
+ Operation *containerOp, RewriterBase::Listener &listener,
LinalgEnablingOptions options = LinalgEnablingOptions()) {
auto res = containerOp->walk([&](func::FuncOp func) {
if (failed(performEnablerTransformations(func, listener, options)))
@@ -661,7 +674,7 @@
return DiagnosedSilenceableFailure::definiteFailure();
auto checkedListenerTransform =
- [&](function_ref<LogicalResult(Operation *, RewriteListener &)>
+ [&](function_ref<LogicalResult(Operation *, RewriterBase::Listener &)>
transform) {
SmallVector<Operation *> roots;
if (Value root = getRoot())
@@ -684,7 +697,7 @@
return success();
};
- auto performCSE = [](Operation *root, RewriteListener &listener) {
+ auto performCSE = [](Operation *root, RewriterBase::Listener &listener) {
LogicalResult result =
eliminateCommonSubexpressions(root, /*domInfo=*/nullptr, &listener);
LLVM_DEBUG(
@@ -692,7 +705,7 @@
<< " CSE\n");
return result;
};
- auto performEnabler = [](Operation *root, RewriteListener &listener) {
+ auto performEnabler = [](Operation *root, RewriterBase::Listener &listener) {
LogicalResult result = performEnablerTransformations(root, listener);
LLVM_DEBUG(
DBGS() << (succeeded(result) ? "successfully performed" : "failed")
@@ -700,9 +713,17 @@
return result;
};
auto performCanonicalization = [&patterns](Operation *root,
- RewriteListener &listener) {
- LogicalResult result =
- applyPatternsTrackAndFoldGreedily(root, listener, patterns);
+ RewriterBase::Listener &listener) {
+ GreedyRewriteConfig config;
+ config.listener = &listener;
+ // Manually gather list of ops because the other GreedyPatternRewriteDriver
+ // overloads only accepts ops that are isolated from above.
+ SmallVector<Operation *> ops;
+ root->walk([&](Operation *op) {
+ if (op != root)
+ ops.push_back(op);
+ });
+ LogicalResult result = applyOpPatternsAndFold(ops, patterns, config);
LLVM_DEBUG(
DBGS() << (succeeded(result) ? "successfully performed" : "failed")
<< " canonicalization\n");
diff --git a/llvm-external-projects/iree-dialects/lib/Transforms/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Transforms/CMakeLists.txt
index 5b233cf..44a3159 100644
--- a/llvm-external-projects/iree-dialects/lib/Transforms/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/lib/Transforms/CMakeLists.txt
@@ -1,8 +1,6 @@
add_mlir_library(IREEDialectsTransforms
- Listener.cpp
ListenerCSE.cpp
- ListenerGreedyPatternRewriteDriver.cpp
TransformMatchers.cpp
LINK_LIBS PRIVATE
diff --git a/llvm-external-projects/iree-dialects/lib/Transforms/Listener.cpp b/llvm-external-projects/iree-dialects/lib/Transforms/Listener.cpp
deleted file mode 100644
index b71ef2e..0000000
--- a/llvm-external-projects/iree-dialects/lib/Transforms/Listener.cpp
+++ /dev/null
@@ -1,49 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree-dialects/Transforms/Listener.h"
-
-namespace mlir {
-
-//===----------------------------------------------------------------------===//
-// ListenerList
-//===----------------------------------------------------------------------===//
-
-void ListenerList::notifyOperationInserted(Operation *op) {
- for (RewriteListener *listener : listeners)
- listener->notifyOperationInserted(op);
-}
-
-void ListenerList::notifyBlockCreated(Block *block) {
- for (RewriteListener *listener : listeners)
- listener->notifyBlockCreated(block);
-}
-
-void ListenerList::notifyOperationReplaced(Operation *op,
- ValueRange newValues) {
- for (RewriteListener *listener : listeners)
- listener->notifyOperationReplaced(op, newValues);
-}
-
-void ListenerList::notifyOperationModified(Operation *op) {
- for (RewriteListener *listener : listeners)
- listener->notifyOperationModified(op);
-}
-
-void ListenerList::notifyOperationRemoved(Operation *op) {
- for (RewriteListener *listener : listeners)
- listener->notifyOperationRemoved(op);
-}
-
-LogicalResult ListenerList::notifyMatchFailure(
- Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
- bool failed = false;
- for (RewriteListener *listener : listeners)
- failed |= listener->notifyMatchFailure(loc, reasonCallback).failed();
- return failure(failed);
-}
-
-} // namespace mlir
diff --git a/llvm-external-projects/iree-dialects/lib/Transforms/ListenerCSE.cpp b/llvm-external-projects/iree-dialects/lib/Transforms/ListenerCSE.cpp
index b8fc3e1..e8e6943 100644
--- a/llvm-external-projects/iree-dialects/lib/Transforms/ListenerCSE.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Transforms/ListenerCSE.cpp
@@ -155,7 +155,7 @@
// void runOnOperation() override;
void doItOnOperation(Operation *rootOp, DominanceInfo *domInfo,
- RewriteListener *listener);
+ RewriterBase::Listener *listener);
private:
void replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
@@ -173,7 +173,7 @@
// END copied from mlir/lib/Transforms/CSE.cpp
//===----------------------------------------------------------------------===//
/// An optional listener to notify of replaced or erased operations.
- RewriteListener *listener;
+ RewriterBase::Listener *listener;
int64_t numDCE = 0, numCSE = 0;
//===----------------------------------------------------------------------===//
@@ -412,7 +412,7 @@
/// Copy of CSE::runOnOperation, without the pass baggage.
void CSE::doItOnOperation(Operation *rootOp, DominanceInfo *domInfo,
- RewriteListener *listener) {
+ RewriterBase::Listener *listener) {
/// A scoped hash table of defining operations within a region.
ScopedMapTy knownValues;
this->domInfo = domInfo;
@@ -431,9 +431,9 @@
}
/// Run CSE on the provided operation
-LogicalResult mlir::eliminateCommonSubexpressions(Operation *op,
- DominanceInfo *domInfo,
- RewriteListener *listener) {
+LogicalResult
+mlir::eliminateCommonSubexpressions(Operation *op, DominanceInfo *domInfo,
+ RewriterBase::Listener *listener) {
assert(op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
"can only do CSE on isolated-from-above ops");
Optional<DominanceInfo> defaultDomInfo;
diff --git a/llvm-external-projects/iree-dialects/lib/Transforms/ListenerGreedyPatternRewriteDriver.cpp b/llvm-external-projects/iree-dialects/lib/Transforms/ListenerGreedyPatternRewriteDriver.cpp
deleted file mode 100644
index a3055a7..0000000
--- a/llvm-external-projects/iree-dialects/lib/Transforms/ListenerGreedyPatternRewriteDriver.cpp
+++ /dev/null
@@ -1,452 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree-dialects/Transforms/ListenerGreedyPatternRewriteDriver.h"
-
-#include "iree-dialects/Transforms/Listener.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "mlir/Rewrite/PatternApplicator.h"
-#include "mlir/Transforms/FoldUtils.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "mlir/Transforms/RegionUtils.h"
-#include "llvm/ADT/DenseMap.h"
-#include "llvm/Support/CommandLine.h"
-#include "llvm/Support/Debug.h"
-#include "llvm/Support/ScopedPrinter.h"
-#include "llvm/Support/raw_ostream.h"
-
-using namespace mlir;
-
-#define DEBUG_TYPE "listener-greedy-rewriter"
-
-//===----------------------------------------------------------------------===//
-// GreedyPatternRewriteDriver
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// This is a worklist-driven driver for the PatternMatcher, which repeatedly
-/// applies the locally optimal patterns in a roughly "bottom up" way.
-class GreedyPatternRewriteDriver : public RewriteListener {
-public:
- explicit GreedyPatternRewriteDriver(Operation *rootOp,
- const FrozenRewritePatternSet &patterns,
- const GreedyRewriteConfig &config,
- RewriteListener *listener);
- //===--------------------------------------------------------------------===//
- // BEGIN copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
- //===--------------------------------------------------------------------===//
-
- /// Simplify the operations within the given regions.
- bool simplify(MutableArrayRef<Region> regions);
-
- /// Add the given operation to the worklist.
- void addToWorklist(Operation *op);
-
- /// Pop the next operation from the worklist.
- Operation *popFromWorklist();
-
- /// If the specified operation is in the worklist, remove it.
- void removeFromWorklist(Operation *op);
-
-protected:
- // Implement the hook for inserting operations, and make sure that newly
- // inserted ops are added to the worklist for processing.
- void notifyOperationInserted(Operation *op) override;
-
- void notifyOperationModified(Operation *op) override;
-
- // Look over the provided operands for any defining operations that should
- // be re-added to the worklist. This function should be called when an
- // operation is modified or removed, as it may trigger further
- // simplifications.
- void addOperandsToWorklist(ValueRange operands);
-
- // If an operation is about to be removed, make sure it is not in our
- // worklist anymore because we'd get dangling references to it.
- void notifyOperationRemoved(Operation *op) override;
-
- // When the root of a pattern is about to be replaced, it can trigger
- // simplifications to its users - make sure to add them to the worklist
- // before the root is changed.
- void notifyOperationReplaced(Operation *op, ValueRange replacement) override;
-
- //===--------------------------------------------------------------------===//
- // END copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
- //===--------------------------------------------------------------------===//
- // This seems unused
- /// PatternRewriter hook for erasing a dead operation.
- // void eraseOp(Operation *op) override;
- // //===-----------------------------------------------------------------===//
- // BEGIN copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
- //===--------------------------------------------------------------------===//
-
- /// PatternRewriter hook for notifying match failure reasons.
- LogicalResult
- notifyMatchFailure(Location loc,
- function_ref<void(Diagnostic &)> reasonCallback) override;
-
- /// The low-level pattern applicator.
- PatternApplicator matcher;
-
- /// The worklist for this transformation keeps track of the operations that
- /// need to be revisited, plus their index in the worklist. This allows us to
- /// efficiently remove operations from the worklist when they are erased, even
- /// if they aren't the root of a pattern.
- std::vector<Operation *> worklist;
- DenseMap<Operation *, unsigned> worklistMap;
-
- /// Non-pattern based folder for operations.
- OperationFolder folder;
-
-private:
- /// Configuration information for how to simplify.
- GreedyRewriteConfig config;
-
- //===--------------------------------------------------------------------===//
- // END copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
- //===--------------------------------------------------------------------===//
- /// The pattern rewriter to use.
- PatternRewriterListener rewriter;
- /// The operation under which all processed ops must be nested.
- Operation *rootOp;
- //===--------------------------------------------------------------------===//
- // BEGIN copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
- //===--------------------------------------------------------------------===//
-
-#ifndef NDEBUG
- /// A logger used to emit information during the application process.
- llvm::ScopedPrinter logger{llvm::dbgs()};
-#endif
-};
-} // namespace
-
-//===----------------------------------------------------------------------===//
-// END copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
-//===----------------------------------------------------------------------===//
-GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
- Operation *rootOp, const FrozenRewritePatternSet &patterns,
- const GreedyRewriteConfig &config, RewriteListener *listener)
- : matcher(patterns), folder(rootOp->getContext(), this), config(config),
- rewriter(rootOp->getContext()), rootOp(rootOp) {
- // Add self as a listener and the user-provided listener.
- rewriter.addListener(this);
- if (listener)
- rewriter.addListener(listener);
- //===--------------------------------------------------------------------===//
- // BEGIN copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
- //===--------------------------------------------------------------------===//
-
- worklist.reserve(64);
-
- // Apply a simple cost model based solely on pattern benefit.
- matcher.applyDefaultCostModel();
-}
-
-bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
-#ifndef NDEBUG
- const char *logLineComment =
- "//===-------------------------------------------===//\n";
-
- /// A utility function to log a process result for the given reason.
- auto logResult = [&](StringRef result, const llvm::Twine &msg = {}) {
- logger.unindent();
- logger.startLine() << "} -> " << result;
- if (!msg.isTriviallyEmpty())
- logger.getOStream() << " : " << msg;
- logger.getOStream() << "\n";
- };
- auto logResultWithLine = [&](StringRef result, const llvm::Twine &msg = {}) {
- logResult(result, msg);
- logger.startLine() << logLineComment;
- };
-#endif
-
- auto insertKnownConstant = [&](Operation *op) {
- // Check for existing constants when populating the worklist. This avoids
- // accidentally reversing the constant order during processing.
- Attribute constValue;
- if (matchPattern(op, m_Constant(&constValue)))
- if (!folder.insertKnownConstant(op, constValue))
- return true;
- return false;
- };
-
- bool changed = false;
- unsigned iteration = 0;
- do {
- worklist.clear();
- worklistMap.clear();
-
- if (!config.useTopDownTraversal) {
- // Add operations to the worklist in postorder.
- for (auto ®ion : regions) {
- region.walk([&](Operation *op) {
- if (!insertKnownConstant(op))
- addToWorklist(op);
- });
- }
- } else {
- // Add all nested operations to the worklist in preorder.
- for (auto ®ion : regions) {
- region.walk<WalkOrder::PreOrder>([&](Operation *op) {
- if (!insertKnownConstant(op)) {
- addToWorklist(op);
- return WalkResult::advance();
- }
- return WalkResult::skip();
- });
- }
-
- // Reverse the list so our pop-back loop processes them in-order.
- std::reverse(worklist.begin(), worklist.end());
- // Remember the reverse index.
- for (size_t i = 0, e = worklist.size(); i != e; ++i)
- worklistMap[worklist[i]] = i;
- }
-
- changed = false;
- while (!worklist.empty()) {
- auto *op = popFromWorklist();
-
- // Nulls get added to the worklist when operations are removed, ignore
- // them.
- if (op == nullptr)
- continue;
-
- LLVM_DEBUG({
- logger.getOStream() << "\n";
- logger.startLine() << logLineComment;
- logger.startLine() << "Processing operation : '" << op->getName()
- << "'(" << op << ") {\n";
- logger.indent();
-
- // If the operation has no regions, just print it here.
- if (op->getNumRegions() == 0) {
- op->print(
- logger.startLine(),
- OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs());
- logger.getOStream() << "\n\n";
- }
- });
-
- // If the operation is trivially dead - remove it.
- if (isOpTriviallyDead(op)) {
- notifyOperationRemoved(op);
- op->erase();
- changed = true;
-
- LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
- continue;
- }
-
- // Try to fold this op.
- if (succeeded(folder.tryToFold(op))) {
- LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
- changed = true;
- continue;
- }
-
- // Try to match one of the patterns. The rewriter is automatically
- // notified of any necessary changes, so there is nothing else to do
- // here.
-#ifndef NDEBUG
- auto canApply = [&](const Pattern &pattern) {
- LLVM_DEBUG({
- logger.getOStream() << "\n";
- logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '"
- << op->getName() << " -> (";
- llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream());
- logger.getOStream() << ")' {\n";
- logger.indent();
- });
- return true;
- };
- auto onFailure = [&](const Pattern &pattern) {
- LLVM_DEBUG(logResult("failure", "pattern failed to match"));
- };
- auto onSuccess = [&](const Pattern &pattern) {
- LLVM_DEBUG(logResult("success", "pattern applied successfully"));
- return success();
- };
-
- LogicalResult matchResult =
- //===------------------------------------------------------------===//
- // BEGIN single line change from
- // mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
- // END
- //===------------------------------------------------------------===//
- matcher.matchAndRewrite(op, rewriter, canApply, onFailure, onSuccess);
- if (succeeded(matchResult))
- LLVM_DEBUG(logResultWithLine("success", "pattern matched"));
- else
- LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match"));
-#else
- //===----------------------------------------------------------------===//
- // BEGIN single line change from
- // mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
- // END
- //===----------------------------------------------------------------===//
- LogicalResult matchResult = matcher.matchAndRewrite(op, rewriter);
-#endif
- changed |= succeeded(matchResult);
- }
-
- // After applying patterns, make sure that the CFG of each of the regions
- // is kept up to date.
- if (config.enableRegionSimplification)
- //===----------------------------------------------------------------===//
- // BEGIN single line change from
- // mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
- // END
- //===----------------------------------------------------------------===//
- changed |= succeeded(simplifyRegions(rewriter, regions));
- } while (changed && (iteration++ < config.maxIterations ||
- config.maxIterations == GreedyRewriteConfig::kNoLimit));
-
- // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
- return !changed;
-}
-
-void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
- // Check to see if the worklist already contains this op.
- if (worklistMap.count(op))
- return;
- //===--------------------------------------------------------------------===//
- // END copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
- //===--------------------------------------------------------------------===//
- // Enforce nested under constraint before adding to worklist.
- if (!rootOp->isProperAncestor(op))
- return;
- //===--------------------------------------------------------------------===//
- // BEGIN copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
- //===--------------------------------------------------------------------===//
-
- worklistMap[op] = worklist.size();
- worklist.push_back(op);
-}
-
-Operation *GreedyPatternRewriteDriver::popFromWorklist() {
- auto *op = worklist.back();
- worklist.pop_back();
-
- // This operation is no longer in the worklist, keep worklistMap up to date.
- if (op)
- worklistMap.erase(op);
- return op;
-}
-
-void GreedyPatternRewriteDriver::removeFromWorklist(Operation *op) {
- auto it = worklistMap.find(op);
- if (it != worklistMap.end()) {
- assert(worklist[it->second] == op && "malformed worklist data structure");
- worklist[it->second] = nullptr;
- worklistMap.erase(it);
- }
-}
-
-void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
- LLVM_DEBUG({
- logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
- << ")\n";
- });
- addToWorklist(op);
-}
-
-void GreedyPatternRewriteDriver::notifyOperationModified(Operation *op) {
- LLVM_DEBUG({
- logger.startLine() << "** Modified: '" << op->getName() << "'(" << op
- << ")\n";
- });
- addToWorklist(op);
-}
-
-void GreedyPatternRewriteDriver::addOperandsToWorklist(ValueRange operands) {
- for (Value operand : operands) {
- // If the use count of this operand is now < 2, we re-add the defining
- // operation to the worklist.
- // TODO: This is based on the fact that zero use operations
- // may be deleted, and that single use values often have more
- // canonicalization opportunities.
- if (!operand || (!operand.use_empty() && !operand.hasOneUse()))
- continue;
- if (auto *defOp = operand.getDefiningOp())
- addToWorklist(defOp);
- }
-}
-
-void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
- addOperandsToWorklist(op->getOperands());
- op->walk([this](Operation *operation) {
- removeFromWorklist(operation);
- folder.notifyRemoval(operation);
- });
-}
-
-void GreedyPatternRewriteDriver::notifyOperationReplaced(
- Operation *op, ValueRange replacement) {
- LLVM_DEBUG({
- logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
- << ")\n";
- });
- for (auto result : op->getResults())
- for (auto *user : result.getUsers())
- addToWorklist(user);
-}
-
-//===----------------------------------------------------------------------===//
-// END copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
-//===----------------------------------------------------------------------===//
-// This seems unused
-// void GreedyPatternRewriteDriver::eraseOp(Operation *op) {
-// LLVM_DEBUG({
-// logger.startLine() << "** Erase : '" << op->getName() << "'(" << op
-// << ")\n";
-// });
-// PatternRewriter::eraseOp(op);
-// }
-//===----------------------------------------------------------------------===//
-// BEGIN copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
-//===----------------------------------------------------------------------===//
-
-LogicalResult GreedyPatternRewriteDriver::notifyMatchFailure(
- Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
- LLVM_DEBUG({
- Diagnostic diag(loc, DiagnosticSeverity::Remark);
- reasonCallback(diag);
- logger.startLine() << "** Failure : " << diag.str() << "\n";
- });
- return failure();
-}
-
-/// Rewrite the regions of the specified operation, which must be isolated from
-/// above, by repeatedly applying the highest benefit patterns in a greedy
-/// work-list driven manner. Return success if no more patterns can be matched
-/// in the result operation regions. Note: This does not apply patterns to the
-/// top-level operation itself.
-///
-//===----------------------------------------------------------------------===//
-// END copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
-//===----------------------------------------------------------------------===//
-LogicalResult mlir::applyPatternsAndFoldGreedily(
- Operation *op, const FrozenRewritePatternSet &patterns,
- const GreedyRewriteConfig &config, RewriteListener *listener) {
- if (op->getRegions().empty())
- return success();
-
- // Start the pattern driver.
- GreedyPatternRewriteDriver driver(op, patterns, config, listener);
- auto regions = op->getRegions();
- //===--------------------------------------------------------------------===//
- // BEGIN copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
- //===--------------------------------------------------------------------===//
- bool converged = driver.simplify(regions);
- LLVM_DEBUG(if (!converged) {
- llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
- << config.maxIterations << " times\n";
- });
- return success(converged);
-}
diff --git a/llvm-external-projects/iree-dialects/test/lib/Transforms/TestListenerPasses.cpp b/llvm-external-projects/iree-dialects/test/lib/Transforms/TestListenerPasses.cpp
index 04fc13a..debf3c0 100644
--- a/llvm-external-projects/iree-dialects/test/lib/Transforms/TestListenerPasses.cpp
+++ b/llvm-external-projects/iree-dialects/test/lib/Transforms/TestListenerPasses.cpp
@@ -4,9 +4,7 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree-dialects/Transforms/Listener.h"
#include "iree-dialects/Transforms/ListenerCSE.h"
-#include "iree-dialects/Transforms/ListenerGreedyPatternRewriteDriver.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -16,7 +14,7 @@
/// The test listener prints stuff to `stdout` so that it can be checked by lit
/// tests.
-struct TestListener : public RewriteListener {
+struct TestListener : public RewriterBase::Listener {
void notifyOperationReplaced(Operation *op, ValueRange newValues) override {
llvm::outs() << "REPLACED " << op->getName() << "\n";
}
@@ -41,7 +39,7 @@
void runOnOperation() override {
TestListener listener;
- RewriteListener *listenerToUse = nullptr;
+ RewriterBase::Listener *listenerToUse = nullptr;
if (withListener)
listenerToUse = &listener;
@@ -51,9 +49,16 @@
for (RegisteredOperationName op : getContext().getRegisteredOperations())
op.getCanonicalizationPatterns(patterns, &getContext());
- if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
- GreedyRewriteConfig(),
- listenerToUse)))
+ GreedyRewriteConfig config;
+ config.listener = listenerToUse;
+ // Manually gather list of ops because the other GreedyPatternRewriteDriver
+ // overloads only accepts ops that are isolated from above.
+ SmallVector<Operation *> ops;
+ getOperation()->walk([&](Operation *nestedOp) {
+ if (this->getOperation() != nestedOp)
+ ops.push_back(nestedOp);
+ });
+ if (failed(applyOpPatternsAndFold(ops, std::move(patterns), config)))
signalPassFailure();
}
@@ -76,7 +81,7 @@
void runOnOperation() override {
TestListener listener;
- RewriteListener *listenerToUse = nullptr;
+ RewriterBase::Listener *listenerToUse = nullptr;
if (withListener)
listenerToUse = &listener;