Remove ListenerGreedyPatternRewriteDriver (#12358)

The upstream GreedyPatternRewriteDriver now supports listeners.
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index 9983bd5..3ffb549 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -9,7 +9,6 @@
 #include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "iree-dialects/Dialect/LinalgTransform/SimplePatternRewriter.h"
 #include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h"
-#include "iree-dialects/Transforms/ListenerGreedyPatternRewriteDriver.h"
 #include "iree-dialects/Transforms/TransformMatchers.h"
 #include "iree/compiler/Codegen/Common/Transforms.h"
 #include "iree/compiler/Codegen/Interfaces/BufferizationInterfaces.h"
@@ -37,6 +36,7 @@
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using namespace mlir;
 using namespace mlir::iree_compiler;
@@ -352,8 +352,15 @@
 
   TrackingListener listener(state);
   GreedyRewriteConfig config;
-  LogicalResult result = applyPatternsAndFoldGreedily(
-      target, std::move(patterns), config, &listener);
+  config.listener = &listener;
+  // Manually gather list of ops because the other GreedyPatternRewriteDriver
+  // overloads only accepts ops that are isolated from above.
+  SmallVector<Operation *> ops;
+  target->walk([&](Operation *nestedOp) {
+    if (target != nestedOp) ops.push_back(nestedOp);
+  });
+  LogicalResult result =
+      applyOpPatternsAndFold(ops, std::move(patterns), config);
   LogicalResult listenerResult = listener.checkErrorState();
   if (failed(result)) {
     return mlir::emitDefiniteFailure(target,
@@ -1131,8 +1138,15 @@
     patterns.add<EmptyTensorLoweringPattern>(patterns.getContext());
     TrackingListener listener(state);
     GreedyRewriteConfig config;
-    LogicalResult result = applyPatternsAndFoldGreedily(
-        state.getTopLevel(), std::move(patterns), config, &listener);
+    config.listener = &listener;
+    // Manually gather list of ops because the other GreedyPatternRewriteDriver
+    // overloads only accepts ops that are isolated from above.
+    SmallVector<Operation *> ops;
+    state.getTopLevel()->walk([&](Operation *nestedOp) {
+      if (state.getTopLevel() != nestedOp) ops.push_back(nestedOp);
+    });
+    LogicalResult result =
+        applyOpPatternsAndFold(ops, std::move(patterns), config);
     LogicalResult listenerResult = listener.checkErrorState();
     if (failed(result)) {
       return mlir::emitDefiniteFailure(state.getTopLevel(),
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h
index 3151aa3..bb711cd 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h
@@ -7,7 +7,6 @@
 #ifndef IREE_DIALECTS_DIALECT_LINALG_TRANSFORM_STRUCTUREDTRANSFORMOPSEXT_H
 #define IREE_DIALECTS_DIALECT_LINALG_TRANSFORM_STRUCTUREDTRANSFORMOPSEXT_H
 
-#include "iree-dialects/Transforms/Listener.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformOps.h"
@@ -46,7 +45,7 @@
   return std::tuple_cat(a);
 }
 
-class TrackingListener : public RewriteListener,
+class TrackingListener : public RewriterBase::Listener,
                          public transform::TransformState::Extension {
 public:
   explicit TrackingListener(transform::TransformState &state)
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/Listener.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/Listener.h
deleted file mode 100644
index 562c99d..0000000
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/Listener.h
+++ /dev/null
@@ -1,82 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_LLVM_SANDBOX_TRANSFORMS_LISTENER_H
-#define IREE_LLVM_SANDBOX_TRANSFORMS_LISTENER_H
-
-#include "mlir/IR/PatternMatch.h"
-
-namespace mlir {
-
-using RewriteListener = RewriterBase::Listener;
-
-//===----------------------------------------------------------------------===//
-// ListenerList
-//===----------------------------------------------------------------------===//
-
-/// This class contains multiple listeners to which rewrite events can be sent.
-class ListenerList : public RewriteListener {
-public:
-  /// Add a listener to the list.
-  void addListener(RewriteListener *listener) { listeners.push_back(listener); }
-
-  /// Send notification of an operation being inserted to all listeners.
-  void notifyOperationInserted(Operation *op) override;
-  /// Send notification of a block being created to all listeners.
-  void notifyBlockCreated(Block *block) override;
-  /// Send notification that an operation has been replaced to all listeners.
-  void notifyOperationReplaced(Operation *op, ValueRange newValues) override;
-  /// Send notification that an operation was modified in-place.
-  void notifyOperationModified(Operation *op) override;
-  /// Send notification that an operation is about to be deleted to all
-  /// listeners.
-  void notifyOperationRemoved(Operation *op) override;
-  /// Notify all listeners that a pattern match failed.
-  LogicalResult
-  notifyMatchFailure(Location loc,
-                     function_ref<void(Diagnostic &)> reasonCallback) override;
-
-private:
-  /// The list of listeners to send events to.
-  SmallVector<RewriteListener *, 1> listeners;
-};
-
-//===----------------------------------------------------------------------===//
-// PatternRewriterListener
-//===----------------------------------------------------------------------===//
-
-/// This class implements a pattern rewriter with a rewrite listener. Rewrite
-/// events are forwarded to the provided rewrite listener.
-class PatternRewriterListener : public PatternRewriter, public ListenerList {
-public:
-  PatternRewriterListener(MLIRContext *context) : PatternRewriter(context) {
-    setListener(this);
-  }
-
-  /// When an operation is about to be replaced, send out an event to all
-  /// attached listeners.
-  void replaceOp(Operation *op, ValueRange newValues) override {
-    ListenerList::notifyOperationReplaced(op, newValues);
-    PatternRewriter::replaceOp(op, newValues);
-  }
-
-  void notifyOperationModified(Operation *op) override {
-    ListenerList::notifyOperationModified(op);
-  }
-  void notifyOperationInserted(Operation *op) override {
-    ListenerList::notifyOperationInserted(op);
-  }
-  void notifyBlockCreated(Block *block) override {
-    ListenerList::notifyBlockCreated(block);
-  }
-  void notifyOperationRemoved(Operation *op) override {
-    ListenerList::notifyOperationRemoved(op);
-  }
-};
-
-} // namespace mlir
-
-#endif // IREE_LLVM_SANDBOX_TRANSFORMS_LISTENER_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/ListenerCSE.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/ListenerCSE.h
index a062f2a..8b81236 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/ListenerCSE.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/ListenerCSE.h
@@ -7,7 +7,7 @@
 #ifndef LLVM_IREE_SANDBOX_TRANSFORMS_LISTENERCSE_H
 #define LLVM_IREE_SANDBOX_TRANSFORMS_LISTENERCSE_H
 
-#include "iree-dialects/Transforms/Listener.h"
+#include "mlir/IR/PatternMatch.h"
 
 namespace mlir {
 class DominanceInfo;
@@ -15,7 +15,7 @@
 
 LogicalResult eliminateCommonSubexpressions(Operation *op,
                                             DominanceInfo *domInfo,
-                                            RewriteListener *listener);
+                                            RewriterBase::Listener *listener);
 } // namespace mlir
 
 #endif // LLVM_IREE_SANDBOX_TRANSFORMS_LISTENERCSE_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/ListenerGreedyPatternRewriteDriver.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/ListenerGreedyPatternRewriteDriver.h
deleted file mode 100644
index a8c0912..0000000
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/ListenerGreedyPatternRewriteDriver.h
+++ /dev/null
@@ -1,39 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree-dialects/Transforms/Listener.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Rewrite/FrozenRewritePatternSet.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-// The following are iree-dialects extensions to MLIR.
-namespace mlir {
-struct GreedyRewriteConfig;
-
-/// Applies the specified patterns on `op` alone while also trying to fold it,
-/// by selecting the highest benefits patterns in a greedy manner. Returns
-/// success if no more patterns can be matched. `erased` is set to true if `op`
-/// was folded away or erased as a result of becoming dead. Note: This does not
-/// apply any patterns recursively to the regions of `op`. Accepts a listener
-/// so the caller can be notified of rewrite events.
-LogicalResult applyPatternsAndFoldGreedily(
-    Operation *op, const FrozenRewritePatternSet &patterns,
-    const GreedyRewriteConfig &config, RewriteListener *listener);
-
-/// Apply the given list of transformations to the regions of the
-/// isolated-from-above operation `root` greedily until convergence. Update
-/// Linalg operations in values of `trackedOperations` if they are replaced by
-/// other Linalg operations during the rewriting process. Tracked operations
-/// must be replaced with Linalg operations and must not be erased in the
-/// patterns.
-static inline LogicalResult applyPatternsTrackAndFoldGreedily(
-    Operation *root, RewriteListener &listener,
-    const FrozenRewritePatternSet &patterns,
-    GreedyRewriteConfig config = GreedyRewriteConfig()) {
-  return applyPatternsAndFoldGreedily(root, patterns, config, &listener);
-}
-
-} // namespace mlir
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/LinalgTransformOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/LinalgTransformOps.cpp
index 95e83c8..a64f64e 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/LinalgTransformOps.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/LinalgTransformOps.cpp
@@ -11,8 +11,6 @@
 #include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h"
 #include "iree-dialects/Dialect/LinalgTransform/ScopedTransform.h"
-#include "iree-dialects/Transforms/Listener.h"
-#include "iree-dialects/Transforms/ListenerGreedyPatternRewriteDriver.h"
 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp
index 59057a0..9a6d66b 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp
@@ -9,9 +9,7 @@
 #include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
 #include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h"
 #include "iree-dialects/Dialect/LinalgTransform/ScopedTransform.h"
-#include "iree-dialects/Transforms/Listener.h"
 #include "iree-dialects/Transforms/ListenerCSE.h"
-#include "iree-dialects/Transforms/ListenerGreedyPatternRewriteDriver.h"
 #include "iree-dialects/Transforms/TransformMatchers.h"
 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
@@ -409,37 +407,45 @@
 }
 
 /// Find the op that defines all values in the range.
-static Operation *findSingleOpDefiningAll(ValueRange range) {
+static FailureOr<Operation *> findSingleOpDefiningAll(ValueRange range) {
   Operation *op = nullptr;
   for (Value value : range) {
-    if (auto currentSourceOp = value.getDefiningOp()) {
-      if (!op || op == currentSourceOp) {
-        op = currentSourceOp;
-        continue;
-      }
-      LLVM_DEBUG(DBGS() << "different source op when replacing one op\n");
-      return nullptr;
+    // Block arguments are just dropped.
+    auto currentSourceOp = value.getDefiningOp();
+    if (!currentSourceOp) {
+      LLVM_DEBUG(DBGS() << "replacing tracked op with bbarg\n");
+      continue;
     }
 
-    LLVM_DEBUG(
-        DBGS() << "could not find a source op when replacing another op\n");
-    return nullptr;
+    if (!op || op == currentSourceOp) {
+      op = currentSourceOp;
+      continue;
+    }
+
+    LLVM_DEBUG(DBGS() << "different source op when replacing one op\n");
+    return failure();
   }
   return op;
 }
 
 // Find a single op that defines all values in the range, optionally
 // transitively through other operations in an op-specific way.
-static Operation *findSingleDefiningOp(Operation *replacedOp,
-                                       ValueRange range) {
-  return llvm::TypeSwitch<Operation *, Operation *>(replacedOp)
-      .Case<linalg::LinalgOp>([&](linalg::LinalgOp) -> Operation * {
-        return findSingleLinalgOpDefiningAll(range);
+static FailureOr<Operation *> findSingleDefiningOp(Operation *replacedOp,
+                                                   ValueRange range) {
+  return llvm::TypeSwitch<Operation *, FailureOr<Operation *>>(replacedOp)
+      .Case<linalg::LinalgOp>([&](linalg::LinalgOp) -> FailureOr<Operation *> {
+        auto op = findSingleLinalgOpDefiningAll(range);
+        if (!op)
+          return failure();
+        return op.getOperation();
       })
-      .Case<scf::ForOp>([&](scf::ForOp) -> Operation * {
-        return findSingleForOpDefiningAll(range);
+      .Case<scf::ForOp>([&](scf::ForOp) -> FailureOr<Operation *> {
+        auto op = findSingleForOpDefiningAll(range);
+        if (!op)
+          return failure();
+        return op.getOperation();
       })
-      .Default([&](Operation *) -> Operation * {
+      .Default([&](Operation *) -> FailureOr<Operation *> {
         return findSingleOpDefiningAll(range);
       });
 }
@@ -455,15 +461,21 @@
   if (failed(getTransformState().getHandlesForPayloadOp(op, handles)))
     return;
 
-  Operation *replacement = findSingleDefiningOp(op, newValues);
-  if (!replacement) {
+  FailureOr<Operation *> replacement = findSingleDefiningOp(op, newValues);
+  if (failed(replacement)) {
     emitError(op) << "could not find replacement for tracked op";
     return;
   }
 
-  LLVM_DEBUG(DBGS() << "replacing tracked @" << op << " : " << *op << " with "
-                    << *replacement << "\n");
-  mayFail(replacePayloadOp(op, replacement));
+  if (*replacement == nullptr) {
+    // TODO: Check if the handle is dead. Otherwise, the op should not be
+    // dropped. This needs a change in the transform dialect interpreter.
+    LLVM_DEBUG(DBGS() << "removing tracked @" << op << " : " << *op << "\n");
+  } else {
+    LLVM_DEBUG(DBGS() << "replacing tracked @" << op << " : " << *op << " with "
+                      << **replacement << "\n");
+  }
+  mayFail(replacePayloadOp(op, *replacement));
 }
 
 void mlir::TrackingListener::notifyOperationRemoved(Operation *op) {
@@ -516,14 +528,15 @@
 /// Run enabling transformations (LICM and its variants, single-iteration loop
 /// removal, CSE) on the given function.
 static LogicalResult performEnablerTransformations(
-    func::FuncOp func, RewriteListener &listener,
+    func::FuncOp func, RewriterBase::Listener &listener,
     LinalgEnablingOptions options = LinalgEnablingOptions()) {
   MLIRContext *ctx = func->getContext();
   RewritePatternSet patterns(ctx);
   linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
   scf::populateSCFForLoopCanonicalizationPatterns(patterns);
-  if (failed(applyPatternsTrackAndFoldGreedily(func, listener,
-                                               std::move(patterns))))
+  GreedyRewriteConfig config;
+  config.listener = &listener;
+  if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns), config)))
     return failure();
 
   // This assumes LICM never removes operations so we don't need tracking.
@@ -550,7 +563,7 @@
 /// Run enabling transformations on the given `containerOp` while preserving the
 /// operation tracking information.
 static LogicalResult performEnablerTransformations(
-    Operation *containerOp, RewriteListener &listener,
+    Operation *containerOp, RewriterBase::Listener &listener,
     LinalgEnablingOptions options = LinalgEnablingOptions()) {
   auto res = containerOp->walk([&](func::FuncOp func) {
     if (failed(performEnablerTransformations(func, listener, options)))
@@ -661,7 +674,7 @@
     return DiagnosedSilenceableFailure::definiteFailure();
 
   auto checkedListenerTransform =
-      [&](function_ref<LogicalResult(Operation *, RewriteListener &)>
+      [&](function_ref<LogicalResult(Operation *, RewriterBase::Listener &)>
               transform) {
         SmallVector<Operation *> roots;
         if (Value root = getRoot())
@@ -684,7 +697,7 @@
         return success();
       };
 
-  auto performCSE = [](Operation *root, RewriteListener &listener) {
+  auto performCSE = [](Operation *root, RewriterBase::Listener &listener) {
     LogicalResult result =
         eliminateCommonSubexpressions(root, /*domInfo=*/nullptr, &listener);
     LLVM_DEBUG(
@@ -692,7 +705,7 @@
                << " CSE\n");
     return result;
   };
-  auto performEnabler = [](Operation *root, RewriteListener &listener) {
+  auto performEnabler = [](Operation *root, RewriterBase::Listener &listener) {
     LogicalResult result = performEnablerTransformations(root, listener);
     LLVM_DEBUG(
         DBGS() << (succeeded(result) ? "successfully performed" : "failed")
@@ -700,9 +713,17 @@
     return result;
   };
   auto performCanonicalization = [&patterns](Operation *root,
-                                             RewriteListener &listener) {
-    LogicalResult result =
-        applyPatternsTrackAndFoldGreedily(root, listener, patterns);
+                                             RewriterBase::Listener &listener) {
+    GreedyRewriteConfig config;
+    config.listener = &listener;
+    // Manually gather list of ops because the other GreedyPatternRewriteDriver
+    // overloads only accepts ops that are isolated from above.
+    SmallVector<Operation *> ops;
+    root->walk([&](Operation *op) {
+      if (op != root)
+        ops.push_back(op);
+    });
+    LogicalResult result = applyOpPatternsAndFold(ops, patterns, config);
     LLVM_DEBUG(
         DBGS() << (succeeded(result) ? "successfully performed" : "failed")
                << " canonicalization\n");
diff --git a/llvm-external-projects/iree-dialects/lib/Transforms/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Transforms/CMakeLists.txt
index 5b233cf..44a3159 100644
--- a/llvm-external-projects/iree-dialects/lib/Transforms/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/lib/Transforms/CMakeLists.txt
@@ -1,8 +1,6 @@
 
 add_mlir_library(IREEDialectsTransforms
-  Listener.cpp
   ListenerCSE.cpp
-  ListenerGreedyPatternRewriteDriver.cpp
   TransformMatchers.cpp
 
   LINK_LIBS PRIVATE
diff --git a/llvm-external-projects/iree-dialects/lib/Transforms/Listener.cpp b/llvm-external-projects/iree-dialects/lib/Transforms/Listener.cpp
deleted file mode 100644
index b71ef2e..0000000
--- a/llvm-external-projects/iree-dialects/lib/Transforms/Listener.cpp
+++ /dev/null
@@ -1,49 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree-dialects/Transforms/Listener.h"
-
-namespace mlir {
-
-//===----------------------------------------------------------------------===//
-// ListenerList
-//===----------------------------------------------------------------------===//
-
-void ListenerList::notifyOperationInserted(Operation *op) {
-  for (RewriteListener *listener : listeners)
-    listener->notifyOperationInserted(op);
-}
-
-void ListenerList::notifyBlockCreated(Block *block) {
-  for (RewriteListener *listener : listeners)
-    listener->notifyBlockCreated(block);
-}
-
-void ListenerList::notifyOperationReplaced(Operation *op,
-                                           ValueRange newValues) {
-  for (RewriteListener *listener : listeners)
-    listener->notifyOperationReplaced(op, newValues);
-}
-
-void ListenerList::notifyOperationModified(Operation *op) {
-  for (RewriteListener *listener : listeners)
-    listener->notifyOperationModified(op);
-}
-
-void ListenerList::notifyOperationRemoved(Operation *op) {
-  for (RewriteListener *listener : listeners)
-    listener->notifyOperationRemoved(op);
-}
-
-LogicalResult ListenerList::notifyMatchFailure(
-    Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
-  bool failed = false;
-  for (RewriteListener *listener : listeners)
-    failed |= listener->notifyMatchFailure(loc, reasonCallback).failed();
-  return failure(failed);
-}
-
-} // namespace mlir
diff --git a/llvm-external-projects/iree-dialects/lib/Transforms/ListenerCSE.cpp b/llvm-external-projects/iree-dialects/lib/Transforms/ListenerCSE.cpp
index b8fc3e1..e8e6943 100644
--- a/llvm-external-projects/iree-dialects/lib/Transforms/ListenerCSE.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Transforms/ListenerCSE.cpp
@@ -155,7 +155,7 @@
 
   // void runOnOperation() override;
   void doItOnOperation(Operation *rootOp, DominanceInfo *domInfo,
-                       RewriteListener *listener);
+                       RewriterBase::Listener *listener);
 
 private:
   void replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
@@ -173,7 +173,7 @@
   // END copied from mlir/lib/Transforms/CSE.cpp
   //===----------------------------------------------------------------------===//
   /// An optional listener to notify of replaced or erased operations.
-  RewriteListener *listener;
+  RewriterBase::Listener *listener;
   int64_t numDCE = 0, numCSE = 0;
 
   //===----------------------------------------------------------------------===//
@@ -412,7 +412,7 @@
 
 /// Copy of CSE::runOnOperation, without the pass baggage.
 void CSE::doItOnOperation(Operation *rootOp, DominanceInfo *domInfo,
-                          RewriteListener *listener) {
+                          RewriterBase::Listener *listener) {
   /// A scoped hash table of defining operations within a region.
   ScopedMapTy knownValues;
   this->domInfo = domInfo;
@@ -431,9 +431,9 @@
 }
 
 /// Run CSE on the provided operation
-LogicalResult mlir::eliminateCommonSubexpressions(Operation *op,
-                                                  DominanceInfo *domInfo,
-                                                  RewriteListener *listener) {
+LogicalResult
+mlir::eliminateCommonSubexpressions(Operation *op, DominanceInfo *domInfo,
+                                    RewriterBase::Listener *listener) {
   assert(op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
          "can only do CSE on isolated-from-above ops");
   Optional<DominanceInfo> defaultDomInfo;
diff --git a/llvm-external-projects/iree-dialects/lib/Transforms/ListenerGreedyPatternRewriteDriver.cpp b/llvm-external-projects/iree-dialects/lib/Transforms/ListenerGreedyPatternRewriteDriver.cpp
deleted file mode 100644
index a3055a7..0000000
--- a/llvm-external-projects/iree-dialects/lib/Transforms/ListenerGreedyPatternRewriteDriver.cpp
+++ /dev/null
@@ -1,452 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree-dialects/Transforms/ListenerGreedyPatternRewriteDriver.h"
-
-#include "iree-dialects/Transforms/Listener.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "mlir/Rewrite/PatternApplicator.h"
-#include "mlir/Transforms/FoldUtils.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "mlir/Transforms/RegionUtils.h"
-#include "llvm/ADT/DenseMap.h"
-#include "llvm/Support/CommandLine.h"
-#include "llvm/Support/Debug.h"
-#include "llvm/Support/ScopedPrinter.h"
-#include "llvm/Support/raw_ostream.h"
-
-using namespace mlir;
-
-#define DEBUG_TYPE "listener-greedy-rewriter"
-
-//===----------------------------------------------------------------------===//
-// GreedyPatternRewriteDriver
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// This is a worklist-driven driver for the PatternMatcher, which repeatedly
-/// applies the locally optimal patterns in a roughly "bottom up" way.
-class GreedyPatternRewriteDriver : public RewriteListener {
-public:
-  explicit GreedyPatternRewriteDriver(Operation *rootOp,
-                                      const FrozenRewritePatternSet &patterns,
-                                      const GreedyRewriteConfig &config,
-                                      RewriteListener *listener);
-  //===--------------------------------------------------------------------===//
-  // BEGIN copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
-  //===--------------------------------------------------------------------===//
-
-  /// Simplify the operations within the given regions.
-  bool simplify(MutableArrayRef<Region> regions);
-
-  /// Add the given operation to the worklist.
-  void addToWorklist(Operation *op);
-
-  /// Pop the next operation from the worklist.
-  Operation *popFromWorklist();
-
-  /// If the specified operation is in the worklist, remove it.
-  void removeFromWorklist(Operation *op);
-
-protected:
-  // Implement the hook for inserting operations, and make sure that newly
-  // inserted ops are added to the worklist for processing.
-  void notifyOperationInserted(Operation *op) override;
-
-  void notifyOperationModified(Operation *op) override;
-
-  // Look over the provided operands for any defining operations that should
-  // be re-added to the worklist. This function should be called when an
-  // operation is modified or removed, as it may trigger further
-  // simplifications.
-  void addOperandsToWorklist(ValueRange operands);
-
-  // If an operation is about to be removed, make sure it is not in our
-  // worklist anymore because we'd get dangling references to it.
-  void notifyOperationRemoved(Operation *op) override;
-
-  // When the root of a pattern is about to be replaced, it can trigger
-  // simplifications to its users - make sure to add them to the worklist
-  // before the root is changed.
-  void notifyOperationReplaced(Operation *op, ValueRange replacement) override;
-
-  //===--------------------------------------------------------------------===//
-  // END copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
-  //===--------------------------------------------------------------------===//
-  // This seems unused
-  /// PatternRewriter hook for erasing a dead operation.
-  // void eraseOp(Operation *op) override;
-  // //===-----------------------------------------------------------------===//
-  // BEGIN copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
-  //===--------------------------------------------------------------------===//
-
-  /// PatternRewriter hook for notifying match failure reasons.
-  LogicalResult
-  notifyMatchFailure(Location loc,
-                     function_ref<void(Diagnostic &)> reasonCallback) override;
-
-  /// The low-level pattern applicator.
-  PatternApplicator matcher;
-
-  /// The worklist for this transformation keeps track of the operations that
-  /// need to be revisited, plus their index in the worklist.  This allows us to
-  /// efficiently remove operations from the worklist when they are erased, even
-  /// if they aren't the root of a pattern.
-  std::vector<Operation *> worklist;
-  DenseMap<Operation *, unsigned> worklistMap;
-
-  /// Non-pattern based folder for operations.
-  OperationFolder folder;
-
-private:
-  /// Configuration information for how to simplify.
-  GreedyRewriteConfig config;
-
-  //===--------------------------------------------------------------------===//
-  // END copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
-  //===--------------------------------------------------------------------===//
-  /// The pattern rewriter to use.
-  PatternRewriterListener rewriter;
-  /// The operation under which all processed ops must be nested.
-  Operation *rootOp;
-  //===--------------------------------------------------------------------===//
-  // BEGIN copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
-  //===--------------------------------------------------------------------===//
-
-#ifndef NDEBUG
-  /// A logger used to emit information during the application process.
-  llvm::ScopedPrinter logger{llvm::dbgs()};
-#endif
-};
-} // namespace
-
-//===----------------------------------------------------------------------===//
-// END copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
-//===----------------------------------------------------------------------===//
-GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
-    Operation *rootOp, const FrozenRewritePatternSet &patterns,
-    const GreedyRewriteConfig &config, RewriteListener *listener)
-    : matcher(patterns), folder(rootOp->getContext(), this), config(config),
-      rewriter(rootOp->getContext()), rootOp(rootOp) {
-  // Add self as a listener and the user-provided listener.
-  rewriter.addListener(this);
-  if (listener)
-    rewriter.addListener(listener);
-  //===--------------------------------------------------------------------===//
-  // BEGIN copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
-  //===--------------------------------------------------------------------===//
-
-  worklist.reserve(64);
-
-  // Apply a simple cost model based solely on pattern benefit.
-  matcher.applyDefaultCostModel();
-}
-
-bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
-#ifndef NDEBUG
-  const char *logLineComment =
-      "//===-------------------------------------------===//\n";
-
-  /// A utility function to log a process result for the given reason.
-  auto logResult = [&](StringRef result, const llvm::Twine &msg = {}) {
-    logger.unindent();
-    logger.startLine() << "} -> " << result;
-    if (!msg.isTriviallyEmpty())
-      logger.getOStream() << " : " << msg;
-    logger.getOStream() << "\n";
-  };
-  auto logResultWithLine = [&](StringRef result, const llvm::Twine &msg = {}) {
-    logResult(result, msg);
-    logger.startLine() << logLineComment;
-  };
-#endif
-
-  auto insertKnownConstant = [&](Operation *op) {
-    // Check for existing constants when populating the worklist. This avoids
-    // accidentally reversing the constant order during processing.
-    Attribute constValue;
-    if (matchPattern(op, m_Constant(&constValue)))
-      if (!folder.insertKnownConstant(op, constValue))
-        return true;
-    return false;
-  };
-
-  bool changed = false;
-  unsigned iteration = 0;
-  do {
-    worklist.clear();
-    worklistMap.clear();
-
-    if (!config.useTopDownTraversal) {
-      // Add operations to the worklist in postorder.
-      for (auto &region : regions) {
-        region.walk([&](Operation *op) {
-          if (!insertKnownConstant(op))
-            addToWorklist(op);
-        });
-      }
-    } else {
-      // Add all nested operations to the worklist in preorder.
-      for (auto &region : regions) {
-        region.walk<WalkOrder::PreOrder>([&](Operation *op) {
-          if (!insertKnownConstant(op)) {
-            addToWorklist(op);
-            return WalkResult::advance();
-          }
-          return WalkResult::skip();
-        });
-      }
-
-      // Reverse the list so our pop-back loop processes them in-order.
-      std::reverse(worklist.begin(), worklist.end());
-      // Remember the reverse index.
-      for (size_t i = 0, e = worklist.size(); i != e; ++i)
-        worklistMap[worklist[i]] = i;
-    }
-
-    changed = false;
-    while (!worklist.empty()) {
-      auto *op = popFromWorklist();
-
-      // Nulls get added to the worklist when operations are removed, ignore
-      // them.
-      if (op == nullptr)
-        continue;
-
-      LLVM_DEBUG({
-        logger.getOStream() << "\n";
-        logger.startLine() << logLineComment;
-        logger.startLine() << "Processing operation : '" << op->getName()
-                           << "'(" << op << ") {\n";
-        logger.indent();
-
-        // If the operation has no regions, just print it here.
-        if (op->getNumRegions() == 0) {
-          op->print(
-              logger.startLine(),
-              OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs());
-          logger.getOStream() << "\n\n";
-        }
-      });
-
-      // If the operation is trivially dead - remove it.
-      if (isOpTriviallyDead(op)) {
-        notifyOperationRemoved(op);
-        op->erase();
-        changed = true;
-
-        LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
-        continue;
-      }
-
-      // Try to fold this op.
-      if (succeeded(folder.tryToFold(op))) {
-        LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
-        changed = true;
-        continue;
-      }
-
-      // Try to match one of the patterns. The rewriter is automatically
-      // notified of any necessary changes, so there is nothing else to do
-      // here.
-#ifndef NDEBUG
-      auto canApply = [&](const Pattern &pattern) {
-        LLVM_DEBUG({
-          logger.getOStream() << "\n";
-          logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '"
-                             << op->getName() << " -> (";
-          llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream());
-          logger.getOStream() << ")' {\n";
-          logger.indent();
-        });
-        return true;
-      };
-      auto onFailure = [&](const Pattern &pattern) {
-        LLVM_DEBUG(logResult("failure", "pattern failed to match"));
-      };
-      auto onSuccess = [&](const Pattern &pattern) {
-        LLVM_DEBUG(logResult("success", "pattern applied successfully"));
-        return success();
-      };
-
-      LogicalResult matchResult =
-          //===------------------------------------------------------------===//
-          // BEGIN single line change from
-          // mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
-          // END
-          //===------------------------------------------------------------===//
-          matcher.matchAndRewrite(op, rewriter, canApply, onFailure, onSuccess);
-      if (succeeded(matchResult))
-        LLVM_DEBUG(logResultWithLine("success", "pattern matched"));
-      else
-        LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match"));
-#else
-      //===----------------------------------------------------------------===//
-      // BEGIN single line change from
-      // mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
-      // END
-      //===----------------------------------------------------------------===//
-      LogicalResult matchResult = matcher.matchAndRewrite(op, rewriter);
-#endif
-      changed |= succeeded(matchResult);
-    }
-
-    // After applying patterns, make sure that the CFG of each of the regions
-    // is kept up to date.
-    if (config.enableRegionSimplification)
-      //===----------------------------------------------------------------===//
-      // BEGIN single line change from
-      // mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
-      // END
-      //===----------------------------------------------------------------===//
-      changed |= succeeded(simplifyRegions(rewriter, regions));
-  } while (changed && (iteration++ < config.maxIterations ||
-                       config.maxIterations == GreedyRewriteConfig::kNoLimit));
-
-  // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
-  return !changed;
-}
-
-void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
-  // Check to see if the worklist already contains this op.
-  if (worklistMap.count(op))
-    return;
-  //===--------------------------------------------------------------------===//
-  // END copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
-  //===--------------------------------------------------------------------===//
-  // Enforce nested under constraint before adding to worklist.
-  if (!rootOp->isProperAncestor(op))
-    return;
-  //===--------------------------------------------------------------------===//
-  // BEGIN copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
-  //===--------------------------------------------------------------------===//
-
-  worklistMap[op] = worklist.size();
-  worklist.push_back(op);
-}
-
-Operation *GreedyPatternRewriteDriver::popFromWorklist() {
-  auto *op = worklist.back();
-  worklist.pop_back();
-
-  // This operation is no longer in the worklist, keep worklistMap up to date.
-  if (op)
-    worklistMap.erase(op);
-  return op;
-}
-
-void GreedyPatternRewriteDriver::removeFromWorklist(Operation *op) {
-  auto it = worklistMap.find(op);
-  if (it != worklistMap.end()) {
-    assert(worklist[it->second] == op && "malformed worklist data structure");
-    worklist[it->second] = nullptr;
-    worklistMap.erase(it);
-  }
-}
-
-void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
-  LLVM_DEBUG({
-    logger.startLine() << "** Insert  : '" << op->getName() << "'(" << op
-                       << ")\n";
-  });
-  addToWorklist(op);
-}
-
-void GreedyPatternRewriteDriver::notifyOperationModified(Operation *op) {
-  LLVM_DEBUG({
-    logger.startLine() << "** Modified: '" << op->getName() << "'(" << op
-                       << ")\n";
-  });
-  addToWorklist(op);
-}
-
-void GreedyPatternRewriteDriver::addOperandsToWorklist(ValueRange operands) {
-  for (Value operand : operands) {
-    // If the use count of this operand is now < 2, we re-add the defining
-    // operation to the worklist.
-    // TODO: This is based on the fact that zero use operations
-    // may be deleted, and that single use values often have more
-    // canonicalization opportunities.
-    if (!operand || (!operand.use_empty() && !operand.hasOneUse()))
-      continue;
-    if (auto *defOp = operand.getDefiningOp())
-      addToWorklist(defOp);
-  }
-}
-
-void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
-  addOperandsToWorklist(op->getOperands());
-  op->walk([this](Operation *operation) {
-    removeFromWorklist(operation);
-    folder.notifyRemoval(operation);
-  });
-}
-
-void GreedyPatternRewriteDriver::notifyOperationReplaced(
-    Operation *op, ValueRange replacement) {
-  LLVM_DEBUG({
-    logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
-                       << ")\n";
-  });
-  for (auto result : op->getResults())
-    for (auto *user : result.getUsers())
-      addToWorklist(user);
-}
-
-//===----------------------------------------------------------------------===//
-// END copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
-//===----------------------------------------------------------------------===//
-// This seems unused
-// void GreedyPatternRewriteDriver::eraseOp(Operation *op) {
-//   LLVM_DEBUG({
-//     logger.startLine() << "** Erase   : '" << op->getName() << "'(" << op
-//                        << ")\n";
-//   });
-//   PatternRewriter::eraseOp(op);
-// }
-//===----------------------------------------------------------------------===//
-// BEGIN copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
-//===----------------------------------------------------------------------===//
-
-LogicalResult GreedyPatternRewriteDriver::notifyMatchFailure(
-    Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
-  LLVM_DEBUG({
-    Diagnostic diag(loc, DiagnosticSeverity::Remark);
-    reasonCallback(diag);
-    logger.startLine() << "** Failure : " << diag.str() << "\n";
-  });
-  return failure();
-}
-
-/// Rewrite the regions of the specified operation, which must be isolated from
-/// above, by repeatedly applying the highest benefit patterns in a greedy
-/// work-list driven manner. Return success if no more patterns can be matched
-/// in the result operation regions. Note: This does not apply patterns to the
-/// top-level operation itself.
-///
-//===----------------------------------------------------------------------===//
-// END copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
-//===----------------------------------------------------------------------===//
-LogicalResult mlir::applyPatternsAndFoldGreedily(
-    Operation *op, const FrozenRewritePatternSet &patterns,
-    const GreedyRewriteConfig &config, RewriteListener *listener) {
-  if (op->getRegions().empty())
-    return success();
-
-  // Start the pattern driver.
-  GreedyPatternRewriteDriver driver(op, patterns, config, listener);
-  auto regions = op->getRegions();
-  //===--------------------------------------------------------------------===//
-  // BEGIN copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
-  //===--------------------------------------------------------------------===//
-  bool converged = driver.simplify(regions);
-  LLVM_DEBUG(if (!converged) {
-    llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
-                 << config.maxIterations << " times\n";
-  });
-  return success(converged);
-}
diff --git a/llvm-external-projects/iree-dialects/test/lib/Transforms/TestListenerPasses.cpp b/llvm-external-projects/iree-dialects/test/lib/Transforms/TestListenerPasses.cpp
index 04fc13a..debf3c0 100644
--- a/llvm-external-projects/iree-dialects/test/lib/Transforms/TestListenerPasses.cpp
+++ b/llvm-external-projects/iree-dialects/test/lib/Transforms/TestListenerPasses.cpp
@@ -4,9 +4,7 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-#include "iree-dialects/Transforms/Listener.h"
 #include "iree-dialects/Transforms/ListenerCSE.h"
-#include "iree-dialects/Transforms/ListenerGreedyPatternRewriteDriver.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
@@ -16,7 +14,7 @@
 
 /// The test listener prints stuff to `stdout` so that it can be checked by lit
 /// tests.
-struct TestListener : public RewriteListener {
+struct TestListener : public RewriterBase::Listener {
   void notifyOperationReplaced(Operation *op, ValueRange newValues) override {
     llvm::outs() << "REPLACED " << op->getName() << "\n";
   }
@@ -41,7 +39,7 @@
 
   void runOnOperation() override {
     TestListener listener;
-    RewriteListener *listenerToUse = nullptr;
+    RewriterBase::Listener *listenerToUse = nullptr;
     if (withListener)
       listenerToUse = &listener;
 
@@ -51,9 +49,16 @@
     for (RegisteredOperationName op : getContext().getRegisteredOperations())
       op.getCanonicalizationPatterns(patterns, &getContext());
 
-    if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
-                                            GreedyRewriteConfig(),
-                                            listenerToUse)))
+    GreedyRewriteConfig config;
+    config.listener = listenerToUse;
+    // Manually gather list of ops because the other GreedyPatternRewriteDriver
+    // overloads only accepts ops that are isolated from above.
+    SmallVector<Operation *> ops;
+    getOperation()->walk([&](Operation *nestedOp) {
+      if (this->getOperation() != nestedOp)
+        ops.push_back(nestedOp);
+    });
+    if (failed(applyOpPatternsAndFold(ops, std::move(patterns), config)))
       signalPassFailure();
   }
 
@@ -76,7 +81,7 @@
 
   void runOnOperation() override {
     TestListener listener;
-    RewriteListener *listenerToUse = nullptr;
+    RewriterBase::Listener *listenerToUse = nullptr;
     if (withListener)
       listenerToUse = &listener;