Properly use TrackingListener in transforms that do not consume their operands (#12766)

#12681 simplified the side effects of IREE transform ops so that they no longer always consume their input handle, which makes the ops easier to use.
However, that change lost proper payload tracking: since nested handles are no longer automatically invalidated, they must be kept in sync as the payload IR is rewritten.
This revision fixes the oversight by routing in-place rewrites through a rewriter with a TrackingListener attached and checking the listener's error state on every exit path.

Partially fixes #12759
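
The recurring pattern: transforms that rewrite IR in place now create a rewriter with a `TrackingListener` attached, route all mutations through it, and funnel every exit path through the new `listener.check` helpers (added to `StructuredTransformOpsExt.h` at the end of this diff). A minimal sketch of that pattern, assuming the headers touched here are available; `MyTransformOp` and `someInPlaceRewrite` are hypothetical stand-ins for the real ops and utilities changed below:

```cpp
// Hypothetical transform op illustrating the tracking pattern from this PR.
DiagnosedSilenceableFailure transform_dialect::MyTransformOp::applyToOne(
    func::FuncOp target, transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  Location loc = target->getLoc();
  IRRewriter rewriter(target->getContext());
  // The listener observes replacements/removals made through the rewriter
  // and keeps the transform state's payload mappings up to date, instead of
  // letting nested handles silently dangle.
  TrackingListener listener(state);
  rewriter.setListener(&listener);

  if (failed(someInPlaceRewrite(rewriter, target))) {
    // Wrap the failure: if the listener also recorded an error, check()
    // upgrades the result to a definite failure and attaches the original
    // diagnostic as a note.
    return listener.check(loc, emitDefaultDefiniteFailure(target));
  }
  // On success, still surface any tracking error the listener recorded.
  return listener.check(loc);
}
```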
diff --git a/compiler/src/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp b/compiler/src/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp
index f11cf14..c81298b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp
@@ -156,13 +156,13 @@
 }
 
 LogicalResult eliminateEmptyTensors(
-    Operation *op, const OneShotBufferizationOptions &options) {
+    RewriterBase &rewriter, Operation *op,
+    const OneShotBufferizationOptions &options) {
   // Analyze IR.
   OneShotAnalysisState state(op, options);
   if (failed(analyzeOp(op, state))) return failure();
 
   // Rewrite tensor.empty ops that are anchored on specific ops.
-  IRRewriter rewriter(op->getContext());
   if (failed(bufferization::insertSliceAnchoredEmptyTensorEliminationStep(
           rewriter, op, state)))
     return failure();
@@ -176,7 +176,9 @@
 void EliminateEmptyTensorsPass::runOnOperation() {
   ModuleOp moduleOp = getOperation();
   OneShotBufferizationOptions options = getBufferizationOptions();
-  if (failed(eliminateEmptyTensors(moduleOp, options)))
+
+  IRRewriter rewriter(moduleOp->getContext());
+  if (failed(eliminateEmptyTensors(rewriter, moduleOp, options)))
     return signalPassFailure();
 }
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp b/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp
index 4f1c2a1..1897a6a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/OptimizeVectorTransferPass.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index de9c6c7..e0c347c 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -92,9 +92,7 @@
       opToErase.push_back(op.getOperation());
     }
   });
-  for (Operation *op : opToErase) {
-    op->erase();
-  }
+  for (Operation *op : opToErase) op->erase();
 }
 
 //===---------------------------------------------------------------------===//
@@ -107,7 +105,6 @@
   // Apply store to load forwarding and dead store elimination.
   vector::transferOpflowOpt(target);
   eraseDeadAllocAndStores(target);
-
   return DiagnosedSilenceableFailure::success();
 }
 
@@ -385,6 +382,7 @@
     addUnrollVectorsGpuMmaSyncPatterns(patterns);
   if (getUnrollVectorsGpuWmma()) addUnrollVectorsGpuWmmaPatterns(patterns);
 
+  Location loc = target->getLoc();
   TrackingListener listener(state);
   GreedyRewriteConfig config;
   config.listener = &listener;
@@ -397,11 +395,12 @@
   LogicalResult result =
       applyOpPatternsAndFold(ops, std::move(patterns), config);
   if (failed(result)) {
-    return mlir::emitDefiniteFailure(target, "greedy patterns failed");
+    return listener.check(
+        loc, mlir::emitDefiniteFailure(target, "greedy patterns failed"));
   }
-  LogicalResult listenerResult = listener.checkErrorState();
-  if (failed(listenerResult))
-    return mlir::emitDefiniteFailure(target, "pattern listener tracker fail");
+
+  auto diag = listener.check(loc);
+  if (!diag.succeeded()) return diag;
 
   if (getLicm()) {
     target->walk([&](func::FuncOp funcOp) {
@@ -432,8 +431,7 @@
       result =
           eliminateCommonSubexpressions(funcOp, /*domInfo=*/nullptr, &listener);
       if (failed(result)) return WalkResult::interrupt();
-      listenerResult = listener.checkErrorState();
-      if (failed(listenerResult)) return WalkResult::interrupt();
+      if (failed(listener.checkErrorState())) return WalkResult::interrupt();
       return WalkResult::advance();
     });
     if (walkResult.wasInterrupted()) {
@@ -441,14 +439,13 @@
         return mlir::emitDefiniteFailure(lastFuncVisited,
                                          "greedy patterns failed");
       }
-      LogicalResult listenerResult = listener.checkErrorState();
-      if (failed(listenerResult))
+      if (failed(listener.checkErrorState()))
         return mlir::emitDefiniteFailure(lastFuncVisited,
                                          "pattern listener tracker fail");
     }
   }
 
-  return DiagnosedSilenceableFailure::success();
+  return listener.check(loc);
 }
 
 void transform_dialect::ApplyPatternsOp::getEffects(
@@ -462,12 +459,15 @@
 //===----------------------------------------------------------------------===//
 
 DiagnosedSilenceableFailure transform_dialect::HoistStaticAllocOp::applyToOne(
-    func::FuncOp funcOp, transform::ApplyToEachResultList &results,
+    func::FuncOp target, transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
-  IRRewriter rewriter(funcOp->getContext());
+  Location loc = target->getLoc();
+  IRRewriter rewriter(target->getContext());
+  TrackingListener listener(state);
+  rewriter.setListener(&listener);
   mlir::iree_compiler::hoistStaticallyBoundAllocationsInFunc<memref::AllocOp>(
-      rewriter, funcOp);
-  return DiagnosedSilenceableFailure::success();
+      rewriter, target);
+  return listener.check(loc);
 }
 
 void transform_dialect::HoistStaticAllocOp::getEffects(
@@ -554,7 +554,7 @@
 /// operands. Assumes the HAL::ExecutableExportOp is built with an empty
 /// region.
 static LogicalResult populateWorkgroupCountComputingRegion(
-    PatternRewriter &rewriter, scf::ForallOp forallOp,
+    RewriterBase &rewriter, scf::ForallOp forallOp,
     HAL::ExecutableExportOp exportOp) {
   Location loc = forallOp.getLoc();
   OpBuilder::InsertionGuard g(rewriter);
@@ -587,9 +587,9 @@
 // Patterns for ForallToWorkgroup rewrite.
 //===---------------------------------------------------------------------===//
 
-LogicalResult rewriteForallToWorkgroup(scf::ForallOp forallOp,
-                                       IREE::HAL::ExecutableExportOp exportOp,
-                                       PatternRewriter &rewriter) {
+LogicalResult rewriteForallToWorkgroup(RewriterBase &rewriter,
+                                       scf::ForallOp forallOp,
+                                       IREE::HAL::ExecutableExportOp exportOp) {
   // Step 0. Target-specific verifications. There is no good place to anchor
   // those right now: the ForallOp is target-independent and the
   // transform op does not apply to individual ForallOp.
@@ -754,12 +754,17 @@
         target, "could not find a unique topLevel scf.forall");
   }
 
-  SimplePatternRewriter rewriter(topLevelForallOp);
-  if (failed(rewriteForallToWorkgroup(topLevelForallOp, exportOp, rewriter))) {
-    return mlir::emitDefiniteFailure(target, "rewriteForallToWorkgroup failed");
+  Location loc = target->getLoc();
+  IRRewriter rewriter(topLevelForallOp->getContext());
+  rewriter.setInsertionPoint(topLevelForallOp);
+  TrackingListener listener(state);
+  rewriter.setListener(&listener);
+  if (failed(rewriteForallToWorkgroup(rewriter, topLevelForallOp, exportOp))) {
+    return listener.check(loc, mlir::emitDefiniteFailure(
+                                   target, "rewriteForallToWorkgroup failed"));
   }
 
-  return DiagnosedSilenceableFailure::success();
+  return listener.check(loc);
 }
 
 void transform_dialect::ForallToWorkgroupOp::getEffects(
@@ -1221,11 +1226,13 @@
     memCpyFn = gpuComprehensiveBufferizeCopyFn;
   }
 
+  Operation *target = payload.front();
+  Location loc = target->getLoc();
+  TrackingListener listener(state);
   //   1. Rewrite tensor.empty to tensor.alloc, without the pass baggage.
   {
     RewritePatternSet patterns(getContext());
     patterns.add<EmptyTensorLoweringPattern>(patterns.getContext());
-    TrackingListener listener(state);
     GreedyRewriteConfig config;
     config.listener = &listener;
     // Manually gather list of ops because the other GreedyPatternRewriteDriver
@@ -1238,12 +1245,15 @@
         applyOpPatternsAndFold(ops, std::move(patterns), config);
     LogicalResult listenerResult = listener.checkErrorState();
     if (failed(result)) {
-      return mlir::emitDefiniteFailure(state.getTopLevel(),
-                                       "greedy pattern application failed");
+      return listener.check(
+          loc, mlir::emitDefiniteFailure(state.getTopLevel(),
+                                         "greedy pattern application failed"));
     }
-    if (failed(listenerResult))
-      return mlir::emitDefiniteFailure(state.getTopLevel(),
-                                       "listener tracking failed");
+    if (failed(listenerResult)) {
+      return listener.check(
+          loc, mlir::emitDefiniteFailure(state.getTopLevel(),
+                                         "listener tracking failed"));
+    }
   }
 
   //   2. Run one-shot-bufferize, without the pass baggage.
@@ -1254,12 +1264,12 @@
   options.testAnalysisOnly = getTestAnalysisOnly();
   options.printConflicts = getPrintConflicts();
   if (failed(runIREEOneShotBufferize(state.getTopLevel(), options)))
-    return DiagnosedSilenceableFailure::definiteFailure();
+    return listener.check(loc, emitDefaultDefiniteFailure(target));
 
   // Early exit if test_analysis_only is set.
   if (getTestAnalysisOnly()) {
     results.set(getOperation()->getOpResult(0), payload.front());
-    return DiagnosedSilenceableFailure::success();
+    return listener.check(loc);
   }
 
   //   3. Post-bufferization passes are fine.
@@ -1276,10 +1286,10 @@
     return WalkResult::advance();
   });
   if (res.wasInterrupted())
-    return DiagnosedSilenceableFailure::definiteFailure();
+    return listener.check(loc, emitDefaultDefiniteFailure(target));
 
   results.set(getOperation()->getOpResult(0), payload.front());
-  return DiagnosedSilenceableFailure::success();
+  return listener.check(loc);
 }
 
 //===---------------------------------------------------------------------===//
@@ -1291,11 +1301,16 @@
     ::mlir::Operation *target,
     ::mlir::transform::ApplyToEachResultList &results,
     ::mlir::transform::TransformState &state) {
-  if (failed(eliminateEmptyTensors(target, getBufferizationOptions()))) {
+  Location loc = target->getLoc();
+  IRRewriter rewriter(target->getContext());
+  TrackingListener listener(state);
+  rewriter.setListener(&listener);
+  if (failed(
+          eliminateEmptyTensors(rewriter, target, getBufferizationOptions()))) {
     getOperation()->emitError() << "failed to eliminate tensor.empty ops";
-    return DiagnosedSilenceableFailure::definiteFailure();
+    return listener.check(loc, emitDefaultDefiniteFailure(target));
   }
-  return DiagnosedSilenceableFailure::success();
+  return listener.check(loc);
 }
 
 void transform_dialect::IREEEliminateEmptyTensorsOp::getEffects(
diff --git a/compiler/src/iree/compiler/Codegen/Common/Transforms.h b/compiler/src/iree/compiler/Codegen/Common/Transforms.h
index ca99cb5..ca22522 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Transforms.h
+++ b/compiler/src/iree/compiler/Codegen/Common/Transforms.h
@@ -21,7 +21,8 @@
 
 /// Eliminates tensor.empty ops to avoid buffer allocations.
 LogicalResult eliminateEmptyTensors(
-    Operation *op, const bufferization::OneShotBufferizationOptions &options);
+    RewriterBase &rewriter, Operation *op,
+    const bufferization::OneShotBufferizationOptions &options);
 
 /// Bufferizes the given op with One-Shot Bufferize.
 LogicalResult runIREEOneShotBufferize(
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorToGPU.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorToGPU.cpp
index c555ee5..560c5d6 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorToGPU.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorToGPU.cpp
@@ -86,7 +86,7 @@
         return signalPassFailure();
       }
     }
-    createAsyncGroups(funcOp, targetMmaSync);
+    createAsyncGroups(rewriter, funcOp, targetMmaSync);
 
     if (targetMmaSync) {
       swizzleSharedMemory(funcOp);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
index f52c5a6a..f26fc59 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
@@ -7,6 +7,7 @@
 #include "LLVMGPUExtensions.h"
 
 #include "iree-dialects/Dialect/LinalgTransform/SimplePatternRewriter.h"
+#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h"
 #include "iree/compiler/Codegen/Common/Transforms.h"
 #include "iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h"
 #include "iree/compiler/Codegen/Utils/GPUUtils.h"
@@ -94,21 +95,22 @@
 
   auto transformOp = cast<transform::TransformOpInterface>(getOperation());
 
+  Location loc = target->getLoc();
   IRRewriter rewriter(target->getContext());
+  TrackingListener listener(state);
+  rewriter.setListener(&listener);
   rewriter.setInsertionPointToStart(&target.getBody().front());
   DiagnosedSilenceableFailure diag =
       mlir::transform::gpu::mapNestedForallToThreadsImpl(
           rewriter, transformOp, target, getWorkgroupDims(), getWarpDims(),
           true);
-
   if (diag.succeeded()) {
     auto newAttr = rewriter.getIndexArrayAttr(getWorkgroupDims());
     rewriter.startRootUpdate(exportOp);
     exportOp->setAttr(exportOp.getWorkgroupSizeAttrName(), newAttr);
     rewriter.finalizeRootUpdate(exportOp);
   }
-
-  return diag;
+  return listener.check(loc, std::move(diag));
 }
 
 void transform_dialect::MapNestedForallToGpuThreadsOp::getEffects(
@@ -210,7 +212,7 @@
 };
 
 static FailureOr<VectorDistributionResult> rewriteScfIfAsWarpExecuteOnLane0(
-    PatternRewriter &rewriter, Location loc, scf::IfOp ifOp,
+    RewriterBase &rewriter, Location loc, scf::IfOp ifOp,
     int64_t workgroupSizeX, int64_t warpSize) {
   // Bail if cond is not `if (threadIdx.x == 0)`.
   FailureOr<gpu::ThreadIdOp> maybeThreadIdxxOp =
@@ -336,21 +338,27 @@
            << warpSize << " --- the transform is not applied";
   }
 
-  SimplePatternRewriter rewriter(target);
+  Location loc = target->getLoc();
+  IRRewriter rewriter(target->getContext());
+  rewriter.setInsertionPoint(target);
+  TrackingListener listener(state);
+  rewriter.setListener(&listener);
   FailureOr<VectorDistributionResult> vectorDistributionResult =
-      rewriteScfIfAsWarpExecuteOnLane0(rewriter, target->getLoc(), target,
-                                       workgroupSizeX, warpSize);
+      rewriteScfIfAsWarpExecuteOnLane0(rewriter, loc, target, workgroupSizeX,
+                                       warpSize);
   if (failed(vectorDistributionResult)) {
     // Return a silenceable failure and set the expected 1 result to
     // nullptr.
     results.assign(1, nullptr);
-    return emitDefaultSilenceableFailure(target)
-           << "scf::ifOp needs to be predicated on threadIdx.x == 0 "
-              "--- the "
-              "transform is not applied";
+    return listener.check(
+        loc, emitDefaultSilenceableFailure(target)
+                 << "scf::ifOp needs to be predicated on threadIdx.x == 0 "
+                    "--- the "
+                    "transform is not applied");
   }
+
   results.push_back(vectorDistributionResult->warpOp);
-  return DiagnosedSilenceableFailure::success();
+  return listener.check(loc);
 }
 
 //===---------------------------------------------------------------------===//
@@ -405,9 +413,8 @@
 }
 
 namespace {
-
-/// Pattern to convert InsertElement to broadcast, this is a workaround until
-/// MultiDimReduction distribution is supported.
+/// Pattern to convert InsertElement to broadcast, this is a workaround
+/// until MultiDimReduction distribution is supported.
 class InsertElementToBroadcast final
     : public OpRewritePattern<vector::InsertElementOp> {
  public:
@@ -682,18 +689,20 @@
     return emitDefaultDefiniteFailure(target);
   }
 
-  IRRewriter rewriter(getContext());
+  Location loc = target->getLoc();
+  IRRewriter rewriter(target->getContext());
+  TrackingListener listener(state);
+  rewriter.setListener(&listener);
+  auto diag = DiagnosedSilenceableFailure::success();
   if (getUseWmma()) {
-    if (failed(convertVectorToMMAOps(rewriter, target))) {
-      target->emitOpError("vector to wmma patterns failed to apply");
-      return emitDefaultDefiniteFailure(target);
-    }
-    return DiagnosedSilenceableFailure::success();
+    if (failed(convertVectorToMMAOps(rewriter, target)))
+      diag = emitDefiniteFailure("vector to wmma patterns failed to apply");
+    return listener.check(loc, std::move(diag));
   }
 
   if (failed(convertVectorToNVVMCompatibleMMASync(rewriter, funcOp))) {
     target->emitOpError("vector to mma patterns failed to apply");
-    return emitDefaultDefiniteFailure(target);
+    return listener.check(loc, emitDefaultDefiniteFailure(target));
   }
   // Using TF32 for Float.
   RewritePatternSet f32ToTF32patterns(funcOp.getContext());
@@ -702,9 +711,9 @@
   if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                           std::move(f32ToTF32patterns)))) {
     target->emitOpError("vector to mma F32ToTF32 patterns failed to apply");
-    return emitDefaultDefiniteFailure(target);
+    return listener.check(loc, emitDefaultDefiniteFailure(target));
   }
-  return DiagnosedSilenceableFailure::success();
+  return listener.check(loc, std::move(diag));
 }
 
 //===----------------------------------------------------------------------===//
@@ -714,6 +723,7 @@
 DiagnosedSilenceableFailure transform_dialect::PromoteOperandsOp::applyToOne(
     Operation *target, transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
+  Location loc = target->getLoc();
   IRRewriter rewriter(getContext());
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(target);
@@ -725,8 +735,7 @@
   for (int64_t index : indices) {
     if ((index >= 0) && (index < numOperands)) {
       FailureOr<Value> ret = bufferization::allocateTensorForShapedValue(
-          rewriter, target->getLoc(), target->getOperand(index), false, options,
-          true);
+          rewriter, loc, target->getOperand(index), false, options, true);
       if (failed(ret)) {
         return emitDefaultDefiniteFailure(target)
                << "failed to promote operand";
@@ -771,8 +780,13 @@
 DiagnosedSilenceableFailure transform_dialect::CreateAsyncGroupsOp::applyToOne(
     func::FuncOp target, transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
-  iree_compiler::createAsyncGroups(cast<func::FuncOp>(target), getUseMmaSync());
-  return DiagnosedSilenceableFailure::success();
+  Location loc = target->getLoc();
+  IRRewriter rewriter(target->getContext());
+  TrackingListener listener(state);
+  rewriter.setListener(&listener);
+  iree_compiler::createAsyncGroups(rewriter, cast<func::FuncOp>(target),
+                                   getUseMmaSync());
+  return listener.check(loc);
 }
 
 //===---------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.cpp
index 1a7158f..a3c7460 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.cpp
@@ -87,7 +87,8 @@
   llvm_unreachable("unsupported op type");
 }
 
-void createAsyncGroups(func::FuncOp funcOp, bool useMMASync) {
+void createAsyncGroups(RewriterBase& rewriter, func::FuncOp funcOp,
+                       bool useMMASync) {
   LLVM_DEBUG(DBGS() << "Start asyncGroups: useMMASync=" << useMMASync << "\n");
   llvm::SmallSetVector<Operation*, 16> copyToSharedMem;
   // Look for all the copy that can be converted to async copy ops.
@@ -172,31 +173,30 @@
     }
     // emit the group.
     SmallVector<Value> tokens;
-    OpBuilder builder(funcOp.getContext());
     for (Operation* writeOp : group) {
-      builder.setInsertionPoint(writeOp);
+      rewriter.setInsertionPoint(writeOp);
       Value vectorVal = getValueStored(writeOp);
       Operation* readOp = vectorVal.getDefiningOp();
       Value storeBase = getMemrefOperand(writeOp);
       Value loadBase = getMemrefOperand(readOp);
-      Value token = builder.create<nvgpu::DeviceAsyncCopyOp>(
+      Value token = rewriter.create<nvgpu::DeviceAsyncCopyOp>(
           writeOp->getLoc(),
           nvgpu::DeviceAsyncTokenType::get(funcOp.getContext()), storeBase,
           getIndices(writeOp), loadBase, getIndices(readOp),
-          builder.getIndexAttr(
+          rewriter.getIndexAttr(
               vectorVal.getType().cast<VectorType>().getNumElements()),
           Value(),
-          /*bypassL1=*/useMMASync ? builder.getUnitAttr() : UnitAttr());
+          /*bypassL1=*/useMMASync ? rewriter.getUnitAttr() : UnitAttr());
       tokens.push_back(token);
     }
     // Create the group and wait for it right after.
-    Value groupToken = builder.create<nvgpu::DeviceAsyncCreateGroupOp>(
+    Value groupToken = rewriter.create<nvgpu::DeviceAsyncCreateGroupOp>(
         funcOp.getLoc(), nvgpu::DeviceAsyncTokenType::get(funcOp.getContext()),
         tokens);
-    builder.create<nvgpu::DeviceAsyncWaitOp>(funcOp.getLoc(), groupToken,
-                                             nullptr);
+    rewriter.create<nvgpu::DeviceAsyncWaitOp>(funcOp.getLoc(), groupToken,
+                                              nullptr);
     // Clean up old stores.
-    for (Operation* writeOp : group) writeOp->erase();
+    for (Operation* writeOp : group) rewriter.eraseOp(writeOp);
   }
 }
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h
index 1f04e32..b4244e1 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h
@@ -15,7 +15,8 @@
 
 /// Helper to convert copy to shared memory to async copy. This creates groups
 /// of consecutive copies and emit wait operation right after.
-void createAsyncGroups(func::FuncOp funcOp, bool useMMASync);
+void createAsyncGroups(RewriterBase &rewriter, func::FuncOp funcOp,
+                       bool useMMASync);
 
 /// Function to do layout analysis and distribution.
 void doLayoutAnalysisAndDistribution(IRRewriter &rewriter, func::FuncOp funcOp);
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h
index bb711cd..6b8c51d 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h
@@ -57,6 +57,26 @@
 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
   }
 
+  DiagnosedSilenceableFailure check(Location loc) {
+    if (failed(checkErrorState()))
+      return emitDefiniteFailure(loc, "listener failed");
+    return DiagnosedSilenceableFailure::success();
+  }
+
+  DiagnosedSilenceableFailure check(Location loc,
+                                    DiagnosedSilenceableFailure &&diag) {
+    if (failed(checkErrorState())) {
+      auto definite = emitDefiniteFailure(loc, "listener failed");
+      if (diag.isSilenceableFailure()) {
+        definite.attachNote()
+            << "was propagating silenceable error:" << diag.getMessage();
+        (void)diag.silence();
+      }
+      return definite;
+    }
+    return std::move(diag);
+  }
+
   void notifyOperationReplaced(Operation *op, ValueRange newValues) override;
 
   void notifyOperationRemoved(Operation *op) override;