Remove IREE apply_patterns op (#14054)

Switch entirely to the upstream `transform.apply_patterns` op. Add two
new IREE ops for CSE and LICM (to be upstreamed when listener support is
available for those transforms).
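
For example, a transform spec that previously wrote

  transform.iree.apply_patterns %func { canonicalization, licm, cse }
    : (!transform.any_op) -> ()

is now expressed with the upstream op plus the two new IREE ops. This is a
sketch distilled from the test updates in this diff; the `%func` handle name
is illustrative:

  transform.apply_patterns to %func {
    transform.apply_patterns.canonicalization
  } : !transform.any_op
  transform.iree.apply_licm %func : !transform.any_op
  transform.iree.apply_cse %func : !transform.any_op
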
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index 9d0468f..29f1dd1 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -130,7 +130,42 @@
 }
 
 //===---------------------------------------------------------------------===//
-// Apply...PatternsOp
+// ApplyIreeLinalgElementwiseGreedyFusionPatternsOp
+//===---------------------------------------------------------------------===//
+
+static void addOperands(Operation *op, SetVector<Value> &operandSet) {
+  if (!op) return;
+  TypeSwitch<Operation *, void>(op)
+      .Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) {
+        SmallVector<Value> inputOperands{linalgOp.getDpsInputOperands()};
+        operandSet.insert(inputOperands.begin(), inputOperands.end());
+      })
+      .Default([&](Operation *operation) {
+        operandSet.insert(operation->operand_begin(), operation->operand_end());
+      });
+}
+
+template <int limit = 3>
+static bool setFusedOpOperandLimit(OpOperand *fusedOperand) {
+  Operation *producer = fusedOperand->get().getDefiningOp();
+  if (!producer) return false;
+  Operation *consumer = fusedOperand->getOwner();
+  SetVector<Value> fusedOpOperands;
+  if (producer->getNumResults() != 1) return false;
+  addOperands(consumer, fusedOpOperands);
+  fusedOpOperands.remove(producer->getResult(0));
+  addOperands(producer, fusedOpOperands);
+  return fusedOpOperands.size() <= limit;
+}
+
+void transform_dialect::ApplyIreeLinalgElementwiseGreedyFusionPatternsOp::
+    populatePatterns(RewritePatternSet &patterns) {
+  linalg::populateElementwiseOpsFusionPatterns(patterns,
+                                               setFusedOpOperandLimit<3>);
+}
+
+//===---------------------------------------------------------------------===//
+// ApplyFoldFillIntoPadPatternsOp
 //===---------------------------------------------------------------------===//
 
 namespace {
@@ -184,120 +219,17 @@
   patterns.insert<FoldFillIntoPad>(patterns.getContext());
 }
 
-void transform_dialect::ApplyFoldReshapeIntoTensorHalInterfacePatternsOp::
-    populatePatterns(RewritePatternSet &patterns) {
-  populateReshapeToInterfaceTensorPatterns(patterns);
-}
-
 //===---------------------------------------------------------------------===//
-// ApplyPatternsOp
+// ApplyUnrollVectorsGpuMmaSyncPatternsOp
 //===---------------------------------------------------------------------===//
 
-void transform_dialect::ApplyPatternsOp::build(
-    OpBuilder &builder, OperationState &result, Value target,
-    const ApplyPatternsOpPatterns &patterns) {
-  result.addOperands(target);
-
-  auto unitAttr = builder.getUnitAttr();
-
-#define ADD_PATTERN(NAME, ATTR) \
-  if (patterns.NAME)            \
-    result.addAttribute(ApplyPatternsOp::ATTR(result.name), unitAttr);
-  ///
-  /// When touching something here, do not forget to update CommonExtensions.h.
-  ///
-  ADD_PATTERN(bubbleCollapse, getBubbleCollapseAttrName)
-  ADD_PATTERN(bubbleExpand, getBubbleExpandAttrName)
-  ADD_PATTERN(bubblePackUnPack, getBubblePackUnPackAttrName)
-  ADD_PATTERN(canonicalization, getCanonicalizationAttrName)
-  ADD_PATTERN(cse, getCseAttrName)
-  ADD_PATTERN(eraseUnnecessaryTensorOperands,
-              getEraseUnnecessaryTensorOperandsAttrName)
-  ADD_PATTERN(expandMemrefStridedMetadata,
-              getExpandMemrefStridedMetadataAttrName)
-  ADD_PATTERN(extractAddressComputations, getExtractAddressComputationsAttrName)
-  ADD_PATTERN(foldReassociativeReshapes, getFoldReassociativeReshapesAttrName)
-  ADD_PATTERN(licm, getLicmAttrName)
-  ADD_PATTERN(linalgElementwiseGreedyFusion,
-              getLinalgElementwiseGreedyFusionAttrName)
-  ADD_PATTERN(lowerTransferOpPermutations,
-              getLowerTransferOpPermutationsAttrName)
-  ADD_PATTERN(lowerVectorMasks, getLowerVectorMasksAttrName)
-  ADD_PATTERN(prepareVectorToMma, getPrepareVectorToMmaAttrName)
-  ADD_PATTERN(swapPaddingElideConditional,
-              getSwapPaddingElideConditionalAttrName)
-  ADD_PATTERN(swappingPatterns, getSwappingPatternsAttrName)
-  ADD_PATTERN(unrollVectorsGpuMmaSync, getUnrollVectorsGpuMmaSyncAttrName)
-  ADD_PATTERN(unrollVectorsGpuWmma, getUnrollVectorsGpuWmmaAttrName)
-#undef ADD_PATTERN
-}
-
-static void addOperands(Operation *op, SetVector<Value> &operandSet) {
-  if (!op) return;
-  TypeSwitch<Operation *, void>(op)
-      .Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) {
-        SmallVector<Value> inputOperands{linalgOp.getDpsInputOperands()};
-        operandSet.insert(inputOperands.begin(), inputOperands.end());
-      })
-      .Default([&](Operation *operation) {
-        operandSet.insert(operation->operand_begin(), operation->operand_end());
-      });
-}
-
-template <int limit = 3>
-static bool setFusedOpOperandLimit(OpOperand *fusedOperand) {
-  Operation *producer = fusedOperand->get().getDefiningOp();
-  if (!producer) return false;
-  Operation *consumer = fusedOperand->getOwner();
-  SetVector<Value> fusedOpOperands;
-  if (producer->getNumResults() != 1) return false;
-  addOperands(consumer, fusedOpOperands);
-  fusedOpOperands.remove(producer->getResult(0));
-  addOperands(producer, fusedOpOperands);
-  return fusedOpOperands.size() <= limit;
-}
-
-static void addLowerTransferOpPermutationsPatterns(
-    RewritePatternSet &patterns) {
-  vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
-}
-
-static void addLowerVectorMasksPatterns(RewritePatternSet &patterns) {
-  vector::populateVectorMaskLoweringPatternsForSideEffectingOps(patterns);
-}
-
-static void addExtractAddressComputationsPatterns(RewritePatternSet &patterns) {
-  memref::populateExtractAddressComputationsPatterns(patterns);
-}
-
-static void addReassociativeReshapePatterns(RewritePatternSet &patterns) {
-  tensor::populateReassociativeReshapeFoldingPatterns(patterns);
-}
-
-static void addEraseUnnecessaryTensorOperandsPatterns(
-    RewritePatternSet &patterns) {
-  linalg::populateEraseUnnecessaryInputsPatterns(patterns);
-}
-
-static void addPrepareVectorToMmaPatterns(RewritePatternSet &patterns) {
-  populatePrepareVectorToMMAPatterns(patterns, /*useNvGpu=*/true);
-}
-
-static void addSwappingPatterns(RewritePatternSet &patterns,
-                                bool swapPaddingElideCornerCase) {
-  patterns.add<linalg::ExtractSliceOfPadTensorSwapPattern>(
-      patterns.getContext(),
-      [&](tensor::ExtractSliceOp) -> std::optional<bool> {
-        return !swapPaddingElideCornerCase;
-      });
-}
-
 static std::optional<SmallVector<int64_t>>
 getGPUTensorCoreNativeMmaSyncVectorSize(Operation *op) {
   return getMmaNativeVectorSize(op);
 }
 
-static void addUnrollVectorsGpuMmaSyncPatterns(RewritePatternSet &patterns) {
+void transform_dialect::ApplyUnrollVectorsGpuMmaSyncPatternsOp::
+    populatePatterns(RewritePatternSet &patterns) {
   auto unrollOrder = [](Operation *op) -> std::optional<SmallVector<int64_t>> {
     auto contract = dyn_cast<vector::ContractionOp>(op);
     if (!contract) return std::nullopt;
@@ -309,12 +241,17 @@
                     .setUnrollTraversalOrderFn(unrollOrder));
 }
 
+//===---------------------------------------------------------------------===//
+// ApplyUnrollVectorsGpuWmmaSyncPatternsOp
+//===---------------------------------------------------------------------===//
+
 static std::optional<SmallVector<int64_t>> getGPUTensorCoreNativeWmmaVectorSize(
     Operation *op) {
   return getWmmaNativeVectorSize(op);
 }
 
-static void addUnrollVectorsGpuWmmaPatterns(RewritePatternSet &patterns) {
+void transform_dialect::ApplyUnrollVectorsGpuWmmaSyncPatternsOp::
+    populatePatterns(RewritePatternSet &patterns) {
   auto unrollOrder = [](Operation *op) -> std::optional<SmallVector<int64_t>> {
     auto contract = dyn_cast<vector::ContractionOp>(op);
     if (!contract) return std::nullopt;
@@ -326,120 +263,106 @@
                     .setUnrollTraversalOrderFn(unrollOrder));
 }
 
-static void addAllRegisteredCanonicalizationPatterns(
+//===---------------------------------------------------------------------===//
+// Remaining Apply...PatternsOp
+//===---------------------------------------------------------------------===//
+
+void transform_dialect::ApplyBubbleCollapsePatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
-  MLIRContext *ctx = patterns.getContext();
-  for (Dialect *dialect : ctx->getLoadedDialects())
-    dialect->getCanonicalizationPatterns(patterns);
-  for (RegisteredOperationName op : ctx->getRegisteredOperations())
-    op.getCanonicalizationPatterns(patterns, ctx);
+  linalg::populateFoldReshapeOpsByCollapsingPatterns(
+      patterns, [](OpOperand *) { return true; });
 }
 
-DiagnosedSilenceableFailure transform_dialect::ApplyPatternsOp::applyToOne(
+void transform_dialect::ApplyBubbleExpandPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  linalg::populateFoldReshapeOpsByExpansionPatterns(
+      patterns, [](OpOperand *) { return true; });
+}
+
+void transform_dialect::ApplyBubblePackUnpackPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  linalg::populateDataLayoutPropagationPatterns(
+      patterns, [](Operation *op) { return true; });
+}
+
+void transform_dialect::ApplyFoldReshapeIntoTensorHalInterfacePatternsOp::
+    populatePatterns(RewritePatternSet &patterns) {
+  populateReshapeToInterfaceTensorPatterns(patterns);
+}
+
+void transform_dialect::ApplyPrepareVectorToMMAPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  populatePrepareVectorToMMAPatterns(patterns, getUseNvGpu());
+}
+
+//===---------------------------------------------------------------------===//
+// ApplyCommonSubexpressionEliminationOp
+//===---------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform_dialect::ApplyCommonSubexpressionEliminationOp::applyToOne(
     Operation *target, transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
-  if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
-    return mlir::emitDefiniteFailure(
-        target,
-        "applies only to isolated-from-above targets because it needs to apply "
-        "patterns greedily");
-  }
-  MLIRContext *ctx = target->getContext();
-  RewritePatternSet patterns(ctx);
-  if (getBubbleCollapse()) {
-    linalg::populateFoldReshapeOpsByCollapsingPatterns(
-        patterns, [](OpOperand *) { return true; });
-  }
-  if (getBubbleExpand()) {
-    linalg::populateFoldReshapeOpsByExpansionPatterns(
-        patterns, [](OpOperand *) { return true; });
-  }
-  if (getBubblePackUnPack())
-    linalg::populateDataLayoutPropagationPatterns(
-        patterns, [](Operation *op) { return true; });
-  if (getCanonicalization()) addAllRegisteredCanonicalizationPatterns(patterns);
-  if (getEraseUnnecessaryTensorOperands())
-    addEraseUnnecessaryTensorOperandsPatterns(patterns);
-  if (getExpandMemrefStridedMetadata())
-    memref::populateExpandStridedMetadataPatterns(patterns);
-  if (getExtractAddressComputations())
-    addExtractAddressComputationsPatterns(patterns);
-  if (getFoldReassociativeReshapes()) addReassociativeReshapePatterns(patterns);
-  if (getLinalgElementwiseGreedyFusion())
-    linalg::populateElementwiseOpsFusionPatterns(patterns,
-                                                 setFusedOpOperandLimit<3>);
-  if (getLowerTransferOpPermutations())
-    addLowerTransferOpPermutationsPatterns(patterns);
-  if (getLowerVectorMasks()) addLowerVectorMasksPatterns(patterns);
-  if (getPrepareVectorToMma()) addPrepareVectorToMmaPatterns(patterns);
-  if (getSwappingPatterns())
-    addSwappingPatterns(patterns, getSwapPaddingElideConditional());
-  if (getUnrollVectorsGpuMmaSync())
-    addUnrollVectorsGpuMmaSyncPatterns(patterns);
-  if (getUnrollVectorsGpuWmma()) addUnrollVectorsGpuWmmaPatterns(patterns);
-
   ErrorCheckingTrackingListener listener(state, *this);
-  GreedyRewriteConfig config;
-  config.listener = &listener;
-  // Manually gather list of ops because the other GreedyPatternRewriteDriver
-  // overloads only accepts ops that are isolated from above.
-  SmallVector<Operation *> ops;
-  target->walk([&](Operation *nestedOp) {
-    if (target != nestedOp) ops.push_back(nestedOp);
+  Operation *lastOpVisited = nullptr;
+
+  WalkResult status = target->walk<WalkOrder::PreOrder>([&](Operation *op) {
+    if (op->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
+      lastOpVisited = op;
+      if (failed(eliminateCommonSubexpressions(op, /*domInfo=*/nullptr,
+                                               &listener)))
+        return WalkResult::interrupt();
+      if (listener.failed()) return WalkResult::interrupt();
+      return WalkResult::skip();
+    }
+    return WalkResult::advance();
   });
-  LogicalResult result =
-      applyOpPatternsAndFold(ops, std::move(patterns), config);
-  if (failed(result))
-    return mlir::emitDefiniteFailure(target, "greedy patterns failed");
+
+  if (!status.wasInterrupted()) return DiagnosedSilenceableFailure::success();
 
   if (listener.failed()) return listener.checkAndResetError();
 
-  if (getLicm()) {
-    target->walk([&](func::FuncOp funcOp) {
-      // This assumes LICM never removes operations so we don't need tracking.
-      // TODO: confirm / revisit this assumption and plumb a rewriter through
-      // upstream moveLoopInvariantCode if necessary.
-      funcOp->walk([](LoopLikeOpInterface loopLike) {
-        moveLoopInvariantCode(loopLike);
-      });
-      // For now, put single loop promotion as part of licm. Underlying
-      // implementations perform splice operations which shouldn't need
-      // tracking.
-      // TODO: confirm / revisit this assumption and plumb a rewriter through
-      // upstream moveLoopInvariantCode if necessary.
-      funcOp->walk([](Operation *op) {
-        (void)llvm::TypeSwitch<Operation *, LogicalResult>(op)
-            .Case<affine::AffineForOp, scf::ForOp>(
-                [](auto loop) { return promoteIfSingleIteration(loop); })
-            .Default([](Operation *) { return success(); });
-      });
-    });
-  }
+  return mlir::emitDefiniteFailure(lastOpVisited, "CSE failed");
+}
 
-  if (getCse()) {
-    func::FuncOp lastFuncVisited;
-    auto walkResult = target->walk([&](func::FuncOp funcOp) -> WalkResult {
-      lastFuncVisited = funcOp;
-      result =
-          eliminateCommonSubexpressions(funcOp, /*domInfo=*/nullptr, &listener);
-      if (failed(result)) return WalkResult::interrupt();
-      if (listener.failed()) return WalkResult::interrupt();
-      return WalkResult::advance();
+void transform_dialect::ApplyCommonSubexpressionEliminationOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getTarget(), effects);
+  transform::modifiesPayload(effects);
+}
+
+//===---------------------------------------------------------------------===//
+// ApplyLoopIndependentCodeMotionOp
+//===---------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform_dialect::ApplyLoopIndependentCodeMotionOp::applyToOne(
+    Operation *target, transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  ErrorCheckingTrackingListener listener(state, *this);
+  target->walk([&](func::FuncOp funcOp) {
+    // This assumes LICM never removes operations so we don't need tracking.
+    // TODO: confirm / revisit this assumption and plumb a rewriter through
+    // upstream moveLoopInvariantCode if necessary.
+    funcOp->walk(
+        [](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
+    // For now, put single loop promotion as part of licm. Underlying
+    // implementations perform splice operations which shouldn't need
+    // tracking.
+    // TODO: confirm / revisit this assumption and plumb a rewriter through
+    // upstream moveLoopInvariantCode if necessary.
+    funcOp->walk([](Operation *op) {
+      (void)llvm::TypeSwitch<Operation *, LogicalResult>(op)
+          .Case<affine::AffineForOp, scf::ForOp>(
+              [](auto loop) { return promoteIfSingleIteration(loop); })
+          .Default([](Operation *) { return success(); });
     });
-    if (walkResult.wasInterrupted()) {
-      if (failed(result)) {
-        return mlir::emitDefiniteFailure(lastFuncVisited,
-                                         "greedy patterns failed");
-      }
-      if (listener.failed()) return listener.checkAndResetError();
-      llvm_unreachable("walk was interrupted for unknown reason");
-    }
-  }
+  });
 
   return listener.checkAndResetError();
 }
 
-void transform_dialect::ApplyPatternsOp::getEffects(
+void transform_dialect::ApplyLoopIndependentCodeMotionOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   transform::onlyReadsHandle(getTarget(), effects);
   transform::modifiesPayload(effects);
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h
index da6cada..718da72 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h
@@ -28,34 +28,6 @@
 struct NumThreadsSpec;
 class TransformTypeInterface;
 }  // namespace transform
-
-namespace iree_compiler {
-namespace IREE {
-namespace transform_dialect {
-/// Selected patterns for ApplyPatternOp.
-struct ApplyPatternsOpPatterns {
-  bool bubbleCollapse = false;
-  bool bubbleExpand = false;
-  bool bubblePackUnPack = false;
-  bool canonicalization = false;
-  bool cse = false;
-  bool eraseUnnecessaryTensorOperands = false;
-  bool expandMemrefStridedMetadata = false;
-  bool extractAddressComputations = false;
-  bool foldReassociativeReshapes = false;
-  bool licm = false;
-  bool linalgElementwiseGreedyFusion = false;
-  bool lowerTransferOpPermutations = false;
-  bool lowerVectorMasks = false;
-  bool prepareVectorToMma = false;
-  bool swapPaddingElideConditional = false;
-  bool swappingPatterns = false;
-  bool unrollVectorsGpuMmaSync = false;
-  bool unrollVectorsGpuWmma = false;
-};
-}  // namespace transform_dialect
-}  // namespace IREE
-}  // namespace iree_compiler
 }  // namespace mlir
 
 #define GET_OP_CLASSES
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
index 8103b01..9becc08 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td
@@ -51,6 +51,44 @@
   }];
 }
 
+def ApplyBubbleCollapsePatternsOp : Op<Transform_Dialect,
+    "apply_patterns.iree.bubble_collapse",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Populate patterns to fold an expanding tensor.expand_shape operation with
+    its producer generic operation by collapsing the dimensions of the generic
+    op.
+  }];
+
+  let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+  let assemblyFormat = "attr-dict";
+}
+
+def ApplyBubbleExpandPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.iree.bubble_expand",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Populate patterns to fold an expanding (collapsing) tensor_reshape
+    operation with its producer (consumer) generic operation by expanding
+    the dimensionality of the loop in the generic op.
+  }];
+
+  let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+  let assemblyFormat = "attr-dict";
+}
+
+def ApplyBubblePackUnpackPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.iree.bubble_pack_unpack",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Populate patterns to bubble up or down data layout ops across other
+    operations.
+  }];
+
+  let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplyFoldFillIntoPadPatternsOp : Op<Transform_Dialect,
     "apply_patterns.iree.fold_fill_into_pad",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
@@ -77,118 +115,101 @@
   let assemblyFormat = "attr-dict";
 }
 
-def ApplyPatternsOp : Op<Transform_Dialect, "iree.apply_patterns",
+def ApplyIreeLinalgElementwiseGreedyFusionPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.iree.linalg_elementwise_greedy_fusion",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Populate patterns to fuse `linalg.generic` -> `linalg.generic` operations
+    when both operations are fusable elementwise operations.
+
+    Note: This pattern set is parameterized with an IREE-specific fusion
+    control function, hence the name "iree.linalg_elementwise_greedy_fusion".
+  }];
+
+  let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+  let assemblyFormat = "attr-dict";
+}
+
+def ApplyPrepareVectorToMMAPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.iree.prepare_vector_to_mma",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Populate patterns that transform vector ops into a canonical form to
+    convert to MMA matrix operations. If `useNvGpu` is true, the populated
+    patterns will prepare for conversion to `nvgpu` mma operations rather
+    than the `gpu` dialect WMMA operations.
+  }];
+
+  let arguments = (ins DefaultValuedAttr<BoolAttr, "true">:$useNvGpu);
+  let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+  let assemblyFormat = "attr-dict";
+}
+
+def ApplyUnrollVectorsGpuMmaSyncPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.iree.unroll_vectors_gpu_mma_sync",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Populate patterns that unroll vectors. TODO: better documentation.
+  }];
+
+  let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+  let assemblyFormat = "attr-dict";
+}
+
+def ApplyUnrollVectorsGpuWmmaSyncPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.iree.unroll_vectors_gpu_wmma_sync",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Populate patterns that unroll vectors. TODO: better documentation.
+  }];
+
+  let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+  let assemblyFormat = "attr-dict";
+}
+
+def ApplyCommonSubexpressionEliminationOp : Op<Transform_Dialect, "iree.apply_cse",
     [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
      TransformEachOpTrait,
      TransformOpInterface]> {
   let description = [{
-    Greedily applies patterns as specified by its attributes.
+    Apply common subexpression elimination. This transform is applied to all
+    ops within the target that are isolated from above.
 
-    Must be applied to an op with trait IsolatedFromAbove since the
-    GreedyPatternRewriter asserts those. Internally, uses the tracking rewriter
-    to preserve handles to payload operations nested within operations
-    associated with `target`. Fails if tracking cannot find replacement for a
-    payload operation. This may become controllable with an attribute in the
-    future.
-
-    Returns the IsolatedFromAbove op whose content it has modified for better
-    chaining APIs.
-
-    The following additive attributes can be set, they add patterns in an
-    unspecified order:
-      - bubble_collapse: bubble `collapse_shape` down across Linalg ops. This
-      must be applied separately from `bubble_expand` patterns because of some
-      upstream pattern interference issue atm.
-      - bubble_expand: bubble `expand_shape` down across Linalg ops. This
-      must be applied separately from `bubble_collapse` patterns because of some
-      upstream pattern interference issue atm.
-      - bubble_pack_un_pack: bubble `pack` up and `unpack` down across Linalg
-      ops.
-      - canonicalization: adds all the canonicalization patterns of all
-      registered dialects and ops.
-      - cse: additionally apply common subexpression elimination. This must
-      apply on a funcOp. This is not a set of patterns per se but is still very
-      convenient to apply it close to canonicalization and other greedy pattern
-      applications.
-      - erase_unnecessary_tensor_operands: add patterns that erase unnecessary
-      tensor operands.
-      - expand_memref_strided_metadata: adds patterns that expand memref
-      operations into extract_strided_metadata operations and a materialization
-      of their effect on the metadata (sizes, offset, strides).
-      - extract_address_computations: adds patterns for anchoring subview 
-      accessing operations at [0, ... 0].
-      - fold_reassociative_reshapes: adds patterns that fold insert_slice/
-      extract_slice ops with reassociative reshape ops.
-      - licm: additionally apply loop-independent code motion and single 
-      iteration loop promotion. This is not a set of patterns per se but is still
-      very convenient to apply it close to canonicalization and other greedy
-      pattern applications.
-      - linalg_elementwise_greedy_fusion: add linalg elementwise ops fusion
-      patterns using a naive default heuristic.
-      - lower_transfer_op_permutations: Lower transfer ops to transfer ops
-      with minor identity permutations.
-      - lower_vector_masks: Lower vector.mask ops away.
-      - prepare_vector_to_mma: pre-process vector.contract op to set it in a form
-      that can be mapped to nvgpu.mma operations. 
-      behavior on subset-based linalg operations using insert/extract slices.
-      - swapping_patterns: adds patterns that swap operations for a better outcome.
-      This is a catch all that can be refined further if/when needed.
-      - swap_padding_elide_conditional: refines the tensor.pad +
-      tensor.extract_slice swapping pattern. This injects static information
-      that guarantees padding is smaller than the window size which guarantees
-      we never see a tile comprised of padding-only.
-      - unroll_vectors_gpu_mma_sync: adds patterns that unroll vectors to a native tile
-      size for GPUs with mma operations. The size is currently hardcoded but
-      should be refactored upstream and made pluggable.
-      - unroll_vectors_gpu_wmma: adds patterns that unroll vectors to a native tile
-      size for GPUs with wmma operations. The size is currently hardcoded but
-      should be refactored upstream and made pluggable.
-
-
-    #### Return modes:
-
-    This operation applies a set of patterns specified by attributes. To apply
-    these patterns, this operation must target an operation that is isolated
-    from above, otherwise the transform definitely fails.
-
-    If the pattern application fails, or if the underlying listener fails to
-    capture op handles, the transformation definitely fails.
-
-    Otherwise the transformation is successful.
+    #### Return modes
 
     This operation does not consume the target handle and does not produce any
     handle.
   }];
 
-  let arguments = (ins TransformHandleTypeInterface:$target,
-                       UnitAttr:$bubble_collapse,
-                       UnitAttr:$bubble_expand,
-                       UnitAttr:$bubble_pack_un_pack,
-                       UnitAttr:$canonicalization,
-                       UnitAttr:$cse,
-                       UnitAttr:$erase_unnecessary_tensor_operands,
-                       UnitAttr:$expand_memref_strided_metadata,
-                       UnitAttr:$extract_address_computations,
-                       UnitAttr:$fold_reassociative_reshapes,
-                       UnitAttr:$licm,
-                       UnitAttr:$linalg_elementwise_greedy_fusion,
-                       UnitAttr:$lower_transfer_op_permutations,
-                       UnitAttr:$lower_vector_masks,
-                       UnitAttr:$prepare_vector_to_mma,
-                       UnitAttr:$swap_padding_elide_conditional,
-                       UnitAttr:$swapping_patterns,
-                       UnitAttr:$unroll_vectors_gpu_mma_sync,
-                       UnitAttr:$unroll_vectors_gpu_wmma);
-  let results = (outs);
-
-  let assemblyFormat = "$target attr-dict `:` functional-type($target, results)";
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let assemblyFormat = "$target attr-dict `:` type($target)";
   let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
 
-  let builders = [
-    // TODO: Some bitvector to scale better than n-bools.
-    OpBuilder<(ins "Value":$target,
-                   "const ApplyPatternsOpPatterns &":$patterns)>
-  ];
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
+def ApplyLoopIndependentCodeMotionOp : Op<Transform_Dialect, "iree.apply_licm",
+    [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     TransformEachOpTrait,
+     TransformOpInterface]> {
+  let description = [{
+    Apply loop-invariant code motion (LICM) and single-iteration loop
+    promotion. This transform is applied to all FuncOps within the target.
+
+    #### Return modes
+
+    This operation does not consume the target handle and does not produce any
+    handle.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let assemblyFormat = "$target attr-dict `:` type($target)";
+  let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
 
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/reductions_codegen_spec.mlir b/compiler/src/iree/compiler/Codegen/Common/test/reductions_codegen_spec.mlir
index 6d5a0fa..e004a26 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/reductions_codegen_spec.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/reductions_codegen_spec.mlir
@@ -20,7 +20,9 @@
     : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
   
   %func = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-  transform.iree.apply_patterns %func { bubble_expand } : (!transform.any_op) -> ()
+  transform.apply_patterns to %func {
+    transform.apply_patterns.iree.bubble_expand
+  } : !transform.any_op
 
   // Excessively eager canonicalization results in `fill`s being "fused" due to
   // swapping with `extract_slice`, which confuses the fusion operation below.
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/transform_dialect_apply_pattern_op.mlir b/compiler/src/iree/compiler/Codegen/Common/test/transform_dialect_apply_pattern_op.mlir
index 0e777df..684e863 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/transform_dialect_apply_pattern_op.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/transform_dialect_apply_pattern_op.mlir
@@ -11,7 +11,9 @@
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-  transform.iree.apply_patterns %0 { canonicalization } : (!transform.any_op) -> ()
+  transform.apply_patterns to %0 {
+    transform.apply_patterns.canonicalization
+  } : !transform.any_op
 }
 
 // -----
@@ -84,7 +86,9 @@
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-  transform.iree.apply_patterns %0 { bubble_expand } : (!transform.any_op) -> ()
+  transform.apply_patterns to %0 {
+    transform.apply_patterns.iree.bubble_expand
+  } : !transform.any_op
 }
 
 // -----
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir
index 97b2bd4..f758e27 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir
@@ -67,10 +67,15 @@
       transform.apply_patterns.linalg.tiling_canonicalization
       transform.apply_patterns.scf.for_loop_canonicalization
     } : !transform.any_op
-    transform.iree.apply_patterns %func_3
-      { fold_reassociative_reshapes, canonicalization, cse } : (!transform.any_op) -> ()
+    transform.apply_patterns to %func_3 {
+      transform.apply_patterns.tensor.reassociative_reshape_folding
+      transform.apply_patterns.canonicalization
+    } : !transform.any_op
+    transform.iree.apply_cse %func_3 : !transform.any_op
     transform.iree.eliminate_empty_tensors %variant_op : (!transform.any_op) -> ()
-    transform.iree.apply_patterns %func_3 { erase_unnecessary_tensor_operands } : (!transform.any_op) -> ()
+    transform.apply_patterns to %func_3 {
+      transform.apply_patterns.linalg.erase_unnecessary_inputs
+    } : !transform.any_op
     %variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op : (!transform.any_op) -> (!transform.any_op)
     %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op
     transform.iree.erase_hal_descriptor_type_from_memref %memref_func : (!transform.any_op) -> ()
@@ -83,8 +88,10 @@
 
     %func_8 = transform.structured.hoist_redundant_vector_transfers %memref_func
     : (!transform.any_op) -> !transform.any_op
-    transform.iree.apply_patterns %func_8 { canonicalization } : (!transform.any_op) -> ()
-    transform.iree.apply_patterns %func_8 { cse } : (!transform.any_op) -> ()
+    transform.apply_patterns to %func_8 {
+      transform.apply_patterns.canonicalization
+    } : !transform.any_op
+    transform.iree.apply_cse %func_8 : !transform.any_op
     transform.iree.apply_buffer_optimizations %func_8 : (!transform.any_op) -> ()
 }
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/layout_analysis_and_distribution.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/layout_analysis_and_distribution.mlir
index 5b1ee29..b9b3dcb 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/layout_analysis_and_distribution.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/layout_analysis_and_distribution.mlir
@@ -977,7 +977,7 @@
   ^bb1(%variant_op: !transform.any_op):
     %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
     %reordered_func = transform.iree.reorder_transpose %top_level_func : (!transform.any_op) -> !transform.any_op
-    transform.iree.apply_patterns %reordered_func { cse } : (!transform.any_op) -> ()
+    transform.iree.apply_cse %reordered_func : !transform.any_op
   }
 }
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy.mlir
index 3c5d8ed..c7bf062 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy.mlir
@@ -94,8 +94,12 @@
 // CHECK: apply_patterns to %{{.*}} {
 // CHECK:   transform.apply_patterns.memref.fold_memref_alias_ops
 // CHECK: } : !transform.any_op
-// CHECK: transform.iree.apply_patterns %{{.*}} {extract_address_computations}
-// CHECK: transform.iree.apply_patterns %{{.*}} {unroll_vectors_gpu_mma_sync}
+// CHECK: apply_patterns to %{{.*}} {
+// CHECK:   transform.apply_patterns.memref.extract_address_computations
+// CHECK: } : !transform.any_op
+// CHECK: apply_patterns to %{{.*}} {
+// CHECK:   transform.apply_patterns.iree.unroll_vectors_gpu_mma_sync
+// CHECK: } : !transform.any_op
 // CHECK: transform.structured.match ops{["scf.for"]} in %{{.*}} 
 // CHECK: transform.iree.synchronize_loop %{{.*}}
 // CHECK: transform.structured.hoist_redundant_vector_transfers %{{.*}}
@@ -112,11 +116,12 @@
 // CHECK: transform.apply_patterns.vector.lower_masks
 // CHECK: transform.apply_patterns.vector.materialize_masks
 // CHECK: apply_patterns to %{{.*}} {
-// CHECK:   transform.apply_patterns.linalg.tiling_canonicalization
-// CHECK:   transform.apply_patterns.memref.fold_memref_alias_ops
+// CHECK-DAG:   transform.apply_patterns.linalg.tiling_canonicalization
+// CHECK-DAG:   transform.apply_patterns.memref.fold_memref_alias_ops
+// CHECK-DAG:   transform.apply_patterns.canonicalization
 // CHECK: } : !transform.any_op
-// CHECK: transform.iree.apply_patterns %{{.*}} {canonicalization, cse, licm}
-
+// CHECK: transform.iree.apply_licm
+// CHECK: transform.iree.apply_cse
 
 // WITH_OPTIONS-LABEL: func @matmul
 
@@ -154,10 +159,14 @@
 // WITH_OPTIONS: apply_patterns to %{{.*}} {
 // WITH_OPTIONS:   transform.apply_patterns.memref.fold_memref_alias_ops
 // WITH_OPTIONS: } : !transform.any_op
-// WITH_OPTIONS: transform.iree.apply_patterns %{{.*}} {extract_address_computations}
+// WITH_OPTIONS: apply_patterns to %{{.*}} {
+// WITH_OPTIONS:   transform.apply_patterns.memref.extract_address_computations
+// WITH_OPTIONS: } : !transform.any_op
 // The unroll attribute should match td-matmul-use-mma-sync, for true: mma_sync,
 // for false:_wmma.
-// WITH_OPTIONS: transform.iree.apply_patterns %{{.*}} {unroll_vectors_gpu_mma_sync}
+// WITH_OPTIONS: apply_patterns to %{{.*}} {
+// WITH_OPTIONS:   transform.apply_patterns.iree.unroll_vectors_gpu_mma_sync
+// WITH_OPTIONS: }
 // WITH_OPTIONS: transform.structured.match ops{["scf.for"]} in %{{.*}} 
 // WITH_OPTIONS: transform.iree.synchronize_loop %{{.*}}
 // WITH_OPTIONS: transform.structured.hoist_redundant_vector_transfers %{{.*}}
@@ -182,7 +191,11 @@
 // WITH_OPTIONS:   transform.apply_patterns.linalg.tiling_canonicalization
 // WITH_OPTIONS:   transform.apply_patterns.memref.fold_memref_alias_ops
 // WITH_OPTIONS: } : !transform.any_op
-// WITH_OPTIONS: transform.iree.apply_patterns %{{.*}} {canonicalization, cse, licm}
+// WITH_OPTIONS: apply_patterns to %{{.*}} {
+// WITH_OPTIONS:   transform.apply_patterns.canonicalization
+// WITH_OPTIONS: }
+// WITH_OPTIONS: transform.iree.apply_licm
+// WITH_OPTIONS: transform.iree.apply_cse
 
 
 // WITH_OPTIONS_2-LABEL: func @matmul
@@ -314,7 +327,11 @@
 // CHECK-SAME:   pack_paddings = [1, 1, 1]
 // CHECK-SAME:   padding_dimensions = [0, 1, 2]
 // CHECK-SAME:   padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]
-// CHECK:      transform.iree.apply_patterns %{{.*}} {canonicalization, cse, licm}
+// CHECK:      apply_patterns to %{{.*}} {
+// CHECK:        transform.apply_patterns.canonicalization
+// CHECK:      }
+// CHECK:      transform.iree.apply_licm
+// CHECK:      transform.iree.apply_cse
 // CHECK:      %[[RES_PAD:.+]] = get_producer_of_operand %{{.*}}[2]
 // CHECK:      %[[RES_COPY:.+]] = transform.structured.rewrite_in_destination_passing_style %[[RES_PAD]]
 // CHECK:      %[[LHS_PAD:.+]] = get_producer_of_operand %{{.*}}[0]
@@ -327,7 +344,10 @@
 // CHECK:      transform.scf.take_assumed_branch %{{.*}} take_else_branch
 // CHECK:      transform.structured.tile_to_forall_op %{{.*}}   num_threads [2, 2] tile_sizes [](mapping = [#gpu.warp<x>, #gpu.warp<y>])
 // CHECK:      transform.structured.tile_to_forall_op %{{.*}}   num_threads [2, 2] tile_sizes [](mapping = [#gpu.warp<x>, #gpu.warp<y>])
-// CHECK:      transform.iree.apply_patterns %{{.*}} {canonicalization, cse, licm}
+// CHECK:        transform.apply_patterns.canonicalization
+// CHECK:      }
+// CHECK:      transform.iree.apply_licm
+// CHECK:      transform.iree.apply_cse
 
 // alignLhs
 // CHECK:      transform.structured.masked_vectorize %[[TILED_LHS]] vector_sizes [4, 4]
@@ -380,7 +400,11 @@
 // CHECK-SAME:   padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]
 
 // Canonicalization is currently required here to enable pad to dps to produce linalg.copy ops.
-// CHECK:      transform.iree.apply_patterns %{{.*}} {canonicalization, cse, licm}
+// CHECK:      apply_patterns to %{{.*}} {
+// CHECK:        transform.apply_patterns.canonicalization
+// CHECK:      }
+// CHECK:      transform.iree.apply_licm
+// CHECK:      transform.iree.apply_cse
 // CHECK:      %[[RES_PAD:.+]] = get_producer_of_operand %{{.*}}[2]
 // CHECK:      %[[RES_COPY:.+]] = transform.structured.rewrite_in_destination_passing_style %[[RES_PAD]]
 // CHECK:      %[[LHS_PAD:.+]] = get_producer_of_operand %{{.*}}[0]
@@ -391,7 +415,10 @@
 // CHECK:      transform.structured.tile_to_forall_op %[[RHS_COPY]]   num_threads [4, 32] tile_sizes [](mapping = [#gpu.linear<y>, #gpu.linear<x>])
 // CHECK:      transform.structured.tile_to_forall_op %{{.*}}   num_threads [2, 2] tile_sizes [](mapping = [#gpu.warp<x>, #gpu.warp<y>])
 // CHECK:      transform.structured.tile_to_forall_op %{{.*}}   num_threads [2, 2] tile_sizes [](mapping = [#gpu.warp<x>, #gpu.warp<y>])
-// CHECK:      transform.iree.apply_patterns %{{.*}} {canonicalization, cse, licm}
+// CHECK:        transform.apply_patterns.canonicalization
+// CHECK:      }
+// CHECK:      transform.iree.apply_licm
+// CHECK:      transform.iree.apply_cse
 
 // Verify we don't go down the path without the flag.
 // WITH_OPTIONS-LABEL: func @aligned_matmul
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy_pad.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy_pad.mlir
index bb29693..3d7a00f 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy_pad.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy_pad.mlir
@@ -48,12 +48,20 @@
 //       CHECK:   transform.iree.register_match_callbacks
 //       CHECK:   {{.*}} = transform.iree.match_callback failures(propagate) "pad"({{.*}}) : (!transform.any_op) -> !transform.any_op
 //       CHECK:   transform.structured.tile_to_forall_op {{.*}}   num_threads [] tile_sizes [64, 64](mapping = [#gpu.block<y>, #gpu.block<x>]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-//       CHECK:   transform.iree.apply_patterns {{.*}} {canonicalization, cse, licm} : (!transform.any_op) -> ()
+//       CHECK:   apply_patterns to %{{.*}} {
+//       CHECK:     transform.apply_patterns.canonicalization
+//       CHECK:   }
+//       CHECK:   transform.iree.apply_licm
+//       CHECK:   transform.iree.apply_cse
 //       CHECK:   {{.*}} = transform.structured.match ops{["scf.if"]} in {{.*}} : (!transform.any_op) -> !transform.any_op
 //       CHECK:   transform.scf.take_assumed_branch {{.*}} take_else_branch : (!transform.any_op) -> ()
 //       CHECK:   transform.iree.populate_workgroup_count_region_using_num_threads_slice {{.*}} : (!transform.any_op) -> ()
 //       CHECK:   {{.*}} = transform.structured.tile_to_forall_op {{.*}}   num_threads [16, 16] tile_sizes [](mapping = [#gpu.thread<y>, #gpu.thread<x>]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-//       CHECK:   transform.iree.apply_patterns {{.*}} {canonicalization, cse, licm} : (!transform.any_op) -> ()
+//       CHECK:   apply_patterns to %{{.*}} {
+//       CHECK:     transform.apply_patterns.canonicalization
+//       CHECK:   }
+//       CHECK:   transform.iree.apply_licm
+//       CHECK:   transform.iree.apply_cse
 //       CHECK:   {{.*}} = transform.structured.match ops{["scf.if"]} in {{.*}} : (!transform.any_op) -> !transform.any_op
 //       CHECK:   transform.scf.take_assumed_branch {{.*}} take_else_branch : (!transform.any_op) -> ()
 //       CHECK:   transform.structured.masked_vectorize {{.*}} vector_sizes [4, 4] : !transform.any_op
@@ -65,7 +73,11 @@
 //   CHECK-DAG:     transform.apply_patterns.vector.cast_away_vector_leading_one_dim
 //       CHECK:   } : !transform.any_op
 //       CHECK:   {{.*}} = transform.structured.vectorize {{.*}} : (!transform.any_op) -> !transform.any_op
-//       CHECK:   transform.iree.apply_patterns {{.*}} {canonicalization, cse, licm} : (!transform.any_op) -> ()
+//       CHECK:   apply_patterns to %{{.*}} {
+//       CHECK:     transform.apply_patterns.canonicalization
+//       CHECK:   }
+//       CHECK:   transform.iree.apply_licm
+//       CHECK:   transform.iree.apply_cse
 //       CHECK:   transform.iree.eliminate_empty_tensors {{.*}} : (!transform.any_op) -> ()
 //       CHECK:   {{.*}} = transform.iree.bufferize {target_gpu} {{.*}} : (!transform.any_op) -> !transform.any_op
 //       CHECK:   {{.*}} = transform.structured.match ops{["func.func"]} in {{.*}} : (!transform.any_op) -> !transform.any_op
@@ -77,10 +89,12 @@
 //       CHECK:     transform.apply_patterns.vector.lower_masks
 //       CHECK:     transform.apply_patterns.vector.materialize_masks
 //       CHECK:   apply_patterns to %{{.*}} {
-//       CHECK:     transform.apply_patterns.linalg.tiling_canonicalization
-//       CHECK:     transform.apply_patterns.memref.fold_memref_alias_ops
+//   CHECK-DAG:     transform.apply_patterns.linalg.tiling_canonicalization
+//   CHECK-DAG:     transform.apply_patterns.memref.fold_memref_alias_ops
+//   CHECK-DAG:     transform.apply_patterns.canonicalization
 //       CHECK:   } : !transform.any_op
-//       CHECK:   transform.iree.apply_patterns {{.*}} {canonicalization, cse, licm} : (!transform.any_op) -> ()
+//       CHECK:   transform.iree.apply_licm
+//       CHECK:   transform.iree.apply_cse
 
 // WITH_OPTIONS-LABEL: func @pad
 //       WITH_OPTIONS:   transform.structured.tile_to_forall_op {{.*}}   num_threads [] tile_sizes [32, 16](mapping = [#gpu.block<y>, #gpu.block<x>])
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_foreach_to_gpu_spec.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_foreach_to_gpu_spec.mlir
index a094335..7702693 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_foreach_to_gpu_spec.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_foreach_to_gpu_spec.mlir
@@ -19,8 +19,11 @@
     transform.apply_patterns.linalg.tiling_canonicalization
     transform.apply_patterns.scf.for_loop_canonicalization
   } : !transform.any_op
-  transform.iree.apply_patterns %func 
-    { fold_reassociative_reshapes, canonicalization, cse } : (!transform.any_op) -> ()
+  transform.apply_patterns to %func {
+    transform.apply_patterns.tensor.reassociative_reshape_folding
+    transform.apply_patterns.canonicalization
+  } : !transform.any_op
+  transform.iree.apply_cse %func : !transform.any_op
   transform.iree.eliminate_empty_tensors %variant_op : (!transform.any_op) -> ()
   %variant_op_3 = transform.iree.bufferize %variant_op : (!transform.any_op) -> (!transform.any_op)
   %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op
@@ -33,7 +36,8 @@
     transform.apply_patterns.iree.fold_fill_into_pad
     transform.apply_patterns.linalg.tiling_canonicalization
     transform.apply_patterns.scf.for_loop_canonicalization
+    transform.apply_patterns.canonicalization
   } : !transform.any_op
-  transform.iree.apply_patterns %memref_func
-    { canonicalization, licm, cse } : (!transform.any_op) -> ()
+  transform.iree.apply_licm %memref_func : !transform.any_op
+  transform.iree.apply_cse %memref_func : !transform.any_op
 }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_vector_distribution_spec.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_vector_distribution_spec.mlir
index c5dc594..443d7d2 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_vector_distribution_spec.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_vector_distribution_spec.mlir
@@ -14,7 +14,8 @@
     transform.apply_patterns.iree.fold_fill_into_pad
     transform.apply_patterns.linalg.tiling_canonicalization
     transform.apply_patterns.scf.for_loop_canonicalization
+    transform.apply_patterns.canonicalization
   } : !transform.any_op
-  transform.iree.apply_patterns %variant_op
-    { canonicalization, licm, cse } : (!transform.any_op) -> ()
+  transform.iree.apply_licm %variant_op : !transform.any_op
+  transform.iree.apply_cse %variant_op : !transform.any_op
 }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_vector_warp_execute_on_lane_0_spec.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_vector_warp_execute_on_lane_0_spec.mlir
index 3103b34..43f82f7 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_vector_warp_execute_on_lane_0_spec.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_codegen_vector_warp_execute_on_lane_0_spec.mlir
@@ -6,6 +6,9 @@
     : (!transform.any_op) -> !transform.any_op
 
   // Late canonicalizations to cleanup and pass the checks.
-  transform.iree.apply_patterns %variant_op
-    { canonicalization, tiling_canonicalization, licm, cse } : (!transform.any_op) -> ()
+  transform.apply_patterns to %variant_op {
+    transform.apply_patterns.canonicalization
+  } : !transform.any_op
+  transform.iree.apply_licm %variant_op : !transform.any_op
+  transform.iree.apply_cse %variant_op : !transform.any_op
 }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_promote_operands.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_promote_operands.mlir
index 268bbbd..19cc283 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_promote_operands.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_promote_operands.mlir
@@ -41,9 +41,10 @@
     transform.apply_patterns.iree.fold_fill_into_pad
     transform.apply_patterns.linalg.tiling_canonicalization
     transform.apply_patterns.scf.for_loop_canonicalization
+    transform.apply_patterns.canonicalization
   } : !transform.any_op
-    transform.iree.apply_patterns %variant_op
-      { canonicalization, licm, cse } : (!transform.any_op) -> ()
+  transform.iree.apply_licm %variant_op : !transform.any_op
+  transform.iree.apply_cse %variant_op : !transform.any_op
   }
 }
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_distribute_forall.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_distribute_forall.mlir
index f343b18..e40cca7 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_distribute_forall.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_distribute_forall.mlir
@@ -53,8 +53,11 @@
 
         // Late canonicalizations to cleanup and pass the checks.
         // Needs to occur on the whole variant to perform cse on the workgroup_count region
-        transform.iree.apply_patterns %variant_op
-          { canonicalization, tiling_canonicalization, licm, cse } : (!transform.any_op) -> ()
+        transform.apply_patterns to %variant_op {
+          transform.apply_patterns.canonicalization
+        } : !transform.any_op
+        transform.iree.apply_licm %variant_op : !transform.any_op
+        transform.iree.apply_cse %variant_op : !transform.any_op
       }
     }
   }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_vector_to_mma.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_vector_to_mma.mlir
index 06a50a6..6180d5f 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_vector_to_mma.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_vector_to_mma.mlir
@@ -50,13 +50,17 @@
 transform.sequence failures(propagate) {
 ^bb1(%variant_op: !transform.any_op):
   %func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
-  transform.iree.apply_patterns %func { unroll_vectors_gpu_wmma } : (!transform.any_op) -> ()
+  transform.apply_patterns to %func {
+    transform.apply_patterns.iree.unroll_vectors_gpu_wmma_sync
+  } : !transform.any_op
   transform.iree.vector.vector_to_mma_conversion %func { use_wmma } : (!transform.any_op) -> ()
 
   // Apply canonicalization post-hoc to trigger DCE and pass the test 
   // (i.e. all vector.contract are dead).
   // TODO: consider having the vector_to_mma_conversion do the DCE automatically.
-  transform.iree.apply_patterns %func { canonicalization } : (!transform.any_op) -> ()
+  transform.apply_patterns to %func {
+    transform.apply_patterns.canonicalization
+  } : !transform.any_op
 }
 }
 
@@ -127,8 +131,12 @@
 transform.sequence failures(propagate) {
 ^bb1(%variant_op: !transform.any_op):
   %func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
-  transform.iree.apply_patterns %func { unroll_vectors_gpu_wmma } : (!transform.any_op) -> ()
+  transform.apply_patterns to %func {
+    transform.apply_patterns.iree.unroll_vectors_gpu_wmma_sync
+  } : !transform.any_op
   transform.iree.vector.vector_to_mma_conversion %func { use_wmma } : (!transform.any_op) -> ()
-  transform.iree.apply_patterns %func { canonicalization } : (!transform.any_op) -> ()
+  transform.apply_patterns to %func {
+    transform.apply_patterns.canonicalization
+  } : !transform.any_op
 }
 }
diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/CPU/Common.cpp b/compiler/src/iree/compiler/Codegen/TransformStrategies/CPU/Common.cpp
index 6842e4a..b84f795 100644
--- a/compiler/src/iree/compiler/Codegen/TransformStrategies/CPU/Common.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/CPU/Common.cpp
@@ -32,7 +32,6 @@
 using iree_compiler::cpu::CPUModel;
 using iree_compiler::cpu::ReductionConfig;
 using iree_compiler::cpu::ReductionStrategy;
-using iree_compiler::IREE::transform_dialect::ApplyPatternsOpPatterns;
 using iree_compiler::IREE::transform_dialect::ForallToWorkgroupOp;
 using transform::ApplyLowerContractionPatternsOp;
 using transform::ApplyLowerMultiReductionPatternsOp;
diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/Common/Common.cpp b/compiler/src/iree/compiler/Codegen/TransformStrategies/Common/Common.cpp
index c85bc2d..ca6ac61 100644
--- a/compiler/src/iree/compiler/Codegen/TransformStrategies/Common/Common.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/Common/Common.cpp
@@ -28,8 +28,6 @@
 
 // TODO: significantly better namespacing.
 using iree_compiler::IREE::transform_dialect::ApplyBufferOptimizationsOp;
-using iree_compiler::IREE::transform_dialect::ApplyPatternsOp;
-using iree_compiler::IREE::transform_dialect::ApplyPatternsOpPatterns;
 using iree_compiler::IREE::transform_dialect::ForallToWorkgroupOp;
 using iree_compiler::IREE::transform_dialect::IREEBufferizeOp;
 using iree_compiler::IREE::transform_dialect::IREEEliminateEmptyTensorsOp;
@@ -137,13 +135,12 @@
         b.create<transform::ApplyTilingCanonicalizationPatternsOp>(loc);
         b.create<IREE::transform_dialect::ApplyFoldFillIntoPadPatternsOp>(loc);
         b.create<transform::ApplyForLoopCanonicalizationPatternsOp>(loc);
+        b.create<transform::ApplyCanonicalizationPatternsOp>(loc);
         if (populatePatternsFn) populatePatternsFn(b, loc);
       });
-  ApplyPatternsOpPatterns configuration;
-  configuration.canonicalization = true;
-  configuration.cse = true;
-  configuration.licm = true;
-  b.create<ApplyPatternsOp>(variantH, configuration);
+  b.create<IREE::transform_dialect::ApplyLoopIndependentCodeMotionOp>(variantH);
+  b.create<IREE::transform_dialect::ApplyCommonSubexpressionEliminationOp>(
+      variantH);
 }
 
 /// Dynamically selects the first non-empty handle; i.e. if (h1, h2) is:
@@ -390,9 +387,11 @@
   }
 
   auto funcH = b.create<MatchOp>(variantH, func::FuncOp::getOperationName());
-  ApplyPatternsOpPatterns configuration;
-  configuration.bubbleExpand = true;
-  b.create<ApplyPatternsOp>(funcH, configuration);
+  b.create<transform::ApplyPatternsOp>(funcH, [](OpBuilder &b, Location loc) {
+    b.create<
+        iree_compiler::IREE::transform_dialect::ApplyBubbleExpandPatternsOp>(
+        loc);
+  });
   std::tie(result.originalFillH, result.splitFillH) =
       matchAndUnpack<2>(b, variantH, linalg::FillOp::getOperationName());
   if (hasTrailingEltwise) {
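The builder-callback overload used above constructs the upstream `transform::ApplyPatternsOp` and runs the lambda to fill its region with descriptor ops. Assuming the printed forms match the test specs below, the op built for `funcH` here should round-trip roughly as (a sketch, not verified output):

  transform.apply_patterns to %func {
    transform.apply_patterns.iree.bubble_expand
  } : !transform.any_op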
diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/Common/Common.h b/compiler/src/iree/compiler/Codegen/TransformStrategies/Common/Common.h
index 9643a48..3018566 100644
--- a/compiler/src/iree/compiler/Codegen/TransformStrategies/Common/Common.h
+++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/Common/Common.h
@@ -30,12 +30,6 @@
   return mlir::gpu::GPUBlockMappingAttr::get(ctx, mlir::gpu::Blocks::DimZ);
 }
 
-namespace IREE {
-namespace transform_dialect {
-struct ApplyPatternsOpPatterns;
-}  // namespace transform_dialect
-}  // namespace IREE
-
 struct AbstractReductionStrategy;
 
 //===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Common.cpp b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Common.cpp
index a4b8a67..0d627c3 100644
--- a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Common.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Common.cpp
@@ -34,8 +34,6 @@
 using namespace mlir;
 
 // TODO: significantly better namespacing.
-using iree_compiler::IREE::transform_dialect::ApplyPatternsOp;
-using iree_compiler::IREE::transform_dialect::ApplyPatternsOpPatterns;
 using iree_compiler::IREE::transform_dialect::ForallToWorkgroupOp;
 using iree_compiler::IREE::transform_dialect::MapNestedForallToGpuThreadsOp;
 using iree_compiler::IREE::transform_dialect::VectorToWarpExecuteOnLane0Op;
@@ -535,20 +533,18 @@
   b.create<transform::ApplyPatternsOp>(funcH, [](OpBuilder &b, Location loc) {
     b.create<transform::ApplyFoldMemrefAliasOpsPatternsOp>(loc);
   });
-  {
-    ApplyPatternsOpPatterns config;
-    config.extractAddressComputations = true;
-    b.create<ApplyPatternsOp>(funcH, config);
-  }
+  b.create<transform::ApplyPatternsOp>(funcH, [](OpBuilder &b, Location loc) {
+    b.create<transform::ApplyExtractAddressComputationsPatternsOp>(loc);
+  });
   iree_compiler::buildCanonicalizationAndEnablingTransforms(b, funcH);
-  {
-    ApplyPatternsOpPatterns config;
+  b.create<transform::ApplyPatternsOp>(funcH, [&](OpBuilder &b, Location loc) {
     if (strategy.useMmaSync)
-      config.unrollVectorsGpuMmaSync = true;
+      b.create<iree_compiler::IREE::transform_dialect::
+                   ApplyUnrollVectorsGpuMmaSyncPatternsOp>(loc);
     else
-      config.unrollVectorsGpuWmma = true;
-    b.create<ApplyPatternsOp>(funcH, config);
-  }
+      b.create<iree_compiler::IREE::transform_dialect::
+                   ApplyUnrollVectorsGpuWmmaSyncPatternsOp>(loc);
+  });
 
   Value forH = b.create<transform::MatchOp>(
       transform::OperationType::get(b.getContext(), "scf.for"), funcH,
@@ -654,11 +650,13 @@
 
 Value mlir::iree_compiler::gpu::buildBufferize(ImplicitLocOpBuilder &b,
                                                Value variantH) {
-  ApplyPatternsOpPatterns patterns;
-  patterns.canonicalization = true;
-  patterns.cse = true;
-  patterns.licm = true;
-  b.create<ApplyPatternsOp>(variantH, patterns);
+  b.create<transform::ApplyPatternsOp>(
+      variantH, [](OpBuilder &b, Location loc) {
+        b.create<transform::ApplyCanonicalizationPatternsOp>(loc);
+      });
+  b.create<IREE::transform_dialect::ApplyLoopIndependentCodeMotionOp>(variantH);
+  b.create<IREE::transform_dialect::ApplyCommonSubexpressionEliminationOp>(
+      variantH);
   b.create<IREEEliminateEmptyTensorsOp>(variantH);
   auto bufferizeOp = b.create<IREEBufferizeOp>(variantH, /*targetGpu=*/true);
   bufferizeOp.setTargetGpu(true);
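The former `canonicalization, licm, cse` attribute bundle becomes three explicit steps, in the order the builder emits them: region-based canonicalization, then the two new standalone ops. In transform-dialect syntax this corresponds roughly to (a sketch; `%variant_op` stands in for `variantH`):

  transform.apply_patterns to %variant_op {
    transform.apply_patterns.canonicalization
  } : !transform.any_op
  transform.iree.apply_licm %variant_op : !transform.any_op
  transform.iree.apply_cse %variant_op : !transform.any_op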
diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/MatmulTensorCoreStrategy.cpp b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/MatmulTensorCoreStrategy.cpp
index ad88077..74f2f3a 100644
--- a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/MatmulTensorCoreStrategy.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/MatmulTensorCoreStrategy.cpp
@@ -47,7 +47,6 @@
 using iree_compiler::gpu::kCudaWarpSize;
 using iree_compiler::gpu::MatmulStrategy;
 using iree_compiler::gpu::scaleUpByBitWidth;
-using iree_compiler::IREE::transform_dialect::ApplyPatternsOpPatterns;
 using iree_compiler::IREE::transform_dialect::EliminateGpuBarriersOp;
 using iree_compiler::IREE::transform_dialect::
     IREEPopulateWorkgroupCountRegionUsingNumThreadsSliceOp;
diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/PadStrategy.cpp b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/PadStrategy.cpp
index 01e71f7..be3cead 100644
--- a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/PadStrategy.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/PadStrategy.cpp
@@ -43,8 +43,6 @@
 using iree_compiler::gpu::buildDistributeOnePadOrCopyWithTileSizes;
 using iree_compiler::gpu::kCudaWarpSize;
 using iree_compiler::gpu::PadStrategy;
-using iree_compiler::IREE::transform_dialect::ApplyPatternsOp;
-using iree_compiler::IREE::transform_dialect::ApplyPatternsOpPatterns;
 using iree_compiler::IREE::transform_dialect::
     IREEPopulateWorkgroupCountRegionUsingNumThreadsSliceOp;
 using transform::MatchOp;
diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/SmallReductionStrategy.cpp b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/SmallReductionStrategy.cpp
index b6d3ede..c4cd687 100644
--- a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/SmallReductionStrategy.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/SmallReductionStrategy.cpp
@@ -23,8 +23,6 @@
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
 
 // TODO: significantly better namespacing.
-using iree_compiler::IREE::transform_dialect::ApplyPatternsOp;
-using iree_compiler::IREE::transform_dialect::ApplyPatternsOpPatterns;
 using iree_compiler::IREE::transform_dialect::ForallToWorkgroupOp;
 using iree_compiler::IREE::transform_dialect::VectorToWarpExecuteOnLane0Op;
 using iree_compiler::IREE::transform_dialect::VectorWarpDistributionOp;
diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/StagedReductionStrategy.cpp b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/StagedReductionStrategy.cpp
index 2ea2029..030842c 100644
--- a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/StagedReductionStrategy.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/StagedReductionStrategy.cpp
@@ -24,8 +24,6 @@
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
 
 // TODO: significantly better namespacing.
-using iree_compiler::IREE::transform_dialect::ApplyPatternsOp;
-using iree_compiler::IREE::transform_dialect::ApplyPatternsOpPatterns;
 using iree_compiler::IREE::transform_dialect::ForallToWorkgroupOp;
 using iree_compiler::IREE::transform_dialect::ShareForallOperandsOp;
 using iree_compiler::IREE::transform_dialect::VectorToWarpExecuteOnLane0Op;
diff --git a/tests/transform_dialect/cpu/attention_codegen_spec.mlir b/tests/transform_dialect/cpu/attention_codegen_spec.mlir
index a541e66..b84e3dc 100644
--- a/tests/transform_dialect/cpu/attention_codegen_spec.mlir
+++ b/tests/transform_dialect/cpu/attention_codegen_spec.mlir
@@ -34,14 +34,17 @@
       transform.apply_patterns.iree.fold_fill_into_pad
       transform.apply_patterns.linalg.tiling_canonicalization
       transform.apply_patterns.scf.for_loop_canonicalization
+      transform.apply_patterns.canonicalization
     } : !transform.any_op
-    transform.iree.apply_patterns %variant_op
-        { canonicalization, licm, cse } : (!transform.any_op) -> ()
+    transform.iree.apply_licm %variant_op : !transform.any_op
+    transform.iree.apply_cse %variant_op : !transform.any_op
 
     // Bufferization
     // ==========================================
     transform.iree.eliminate_empty_tensors %variant_op : (!transform.any_op) -> ()
-    transform.iree.apply_patterns %func_3 { erase_unnecessary_tensor_operands } : (!transform.any_op) -> ()
+    transform.apply_patterns to %func_3 {
+      transform.apply_patterns.linalg.erase_unnecessary_inputs
+    } : !transform.any_op
     %variant_op_3 = transform.iree.bufferize %variant_op : (!transform.any_op) -> (!transform.any_op)
     %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op
     transform.iree.erase_hal_descriptor_type_from_memref %memref_func : (!transform.any_op) -> ()
@@ -52,7 +55,9 @@
     transform.iree.forall_to_workgroup %func_7 : (!transform.any_op) -> ()
     %func_8 = transform.structured.hoist_redundant_vector_transfers %memref_func
     : (!transform.any_op) -> !transform.any_op
-    transform.iree.apply_patterns %func_8 { canonicalization } : (!transform.any_op) -> ()
-    transform.iree.apply_patterns %func_8 { cse } : (!transform.any_op) -> ()
+    transform.apply_patterns to %func_8 {
+      transform.apply_patterns.canonicalization
+    } : !transform.any_op
+    transform.iree.apply_cse %func_8 : !transform.any_op
     transform.iree.apply_buffer_optimizations %func_8 : (!transform.any_op) -> ()
 }
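Two of the old attribute names do not map one-to-one onto upstream descriptor names; the correspondence used throughout this patch is shown below (a sketch; `%func` is a placeholder handle):

  // erase_unnecessary_tensor_operands ==> transform.apply_patterns.linalg.erase_unnecessary_inputs
  // fold_reassociative_reshapes       ==> transform.apply_patterns.tensor.reassociative_reshape_folding
  transform.apply_patterns to %func {
    transform.apply_patterns.linalg.erase_unnecessary_inputs
    transform.apply_patterns.tensor.reassociative_reshape_folding
  } : !transform.any_op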
diff --git a/tests/transform_dialect/cpu/matmul_codegen_custom_dispatch_formation_spec.mlir b/tests/transform_dialect/cpu/matmul_codegen_custom_dispatch_formation_spec.mlir
index 61fe536..5c445b9 100644
--- a/tests/transform_dialect/cpu/matmul_codegen_custom_dispatch_formation_spec.mlir
+++ b/tests/transform_dialect/cpu/matmul_codegen_custom_dispatch_formation_spec.mlir
@@ -18,9 +18,9 @@
     transform.apply_patterns.iree.fold_fill_into_pad
     transform.apply_patterns.linalg.tiling_canonicalization
     transform.apply_patterns.scf.for_loop_canonicalization
+    transform.apply_patterns.canonicalization
   } : !transform.any_op
-  transform.iree.apply_patterns %variant_op 
-    { canonicalization, cse } : (!transform.any_op) -> ()
+  transform.iree.apply_cse %variant_op : !transform.any_op
   %variant_op_3 = transform.iree.bufferize %variant_op : (!transform.any_op) -> (!transform.any_op)
   %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3 
     : (!transform.any_op) -> !transform.any_op
@@ -28,5 +28,5 @@
   transform.iree.forall_to_workgroup %memref_func : (!transform.any_op) -> ()
 
   // CSE is needed on the workgroup_count region to pass this particular test.
-  transform.iree.apply_patterns %variant_op_3 { cse } : (!transform.any_op) -> ()
+  transform.iree.apply_cse %variant_op_3 : !transform.any_op
 }
diff --git a/tests/transform_dialect/cuda/double_mma_layout_analysis_codegen_spec.mlir b/tests/transform_dialect/cuda/double_mma_layout_analysis_codegen_spec.mlir
index 09e3cc0..f933624 100644
--- a/tests/transform_dialect/cuda/double_mma_layout_analysis_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/double_mma_layout_analysis_codegen_spec.mlir
@@ -37,10 +37,15 @@
     transform.apply_patterns.linalg.tiling_canonicalization
     transform.apply_patterns.scf.for_loop_canonicalization
   } : !transform.any_op
-  transform.iree.apply_patterns %func_3
-    { fold_reassociative_reshapes, canonicalization, cse } : (!transform.any_op) -> ()
+  transform.apply_patterns to %func_3 {
+    transform.apply_patterns.tensor.reassociative_reshape_folding
+    transform.apply_patterns.canonicalization
+  } : !transform.any_op
+  transform.iree.apply_cse %func_3 : !transform.any_op
   transform.iree.eliminate_empty_tensors %variant_op : (!transform.any_op) -> ()
-  transform.iree.apply_patterns %func_3 { erase_unnecessary_tensor_operands } : (!transform.any_op) -> ()
+  transform.apply_patterns to %func_3 {
+    transform.apply_patterns.linalg.erase_unnecessary_inputs
+  } : !transform.any_op
   %variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op : (!transform.any_op) -> (!transform.any_op)
   %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op
   transform.iree.erase_hal_descriptor_type_from_memref %memref_func : (!transform.any_op) -> ()
@@ -48,7 +53,9 @@
  // Step 5. Pre-process the contract and transfer ops to put them in the right form.
   // ===========================================================================
   %func_2 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op
-  transform.iree.apply_patterns %func_2 {  prepare_vector_to_mma } : (!transform.any_op) -> ()
+  transform.apply_patterns to %func_2 {
+    transform.apply_patterns.iree.prepare_vector_to_mma
+  } : !transform.any_op
 
   // Step 6. Post-bufferization vector distribution
   // ===========================================================================
diff --git a/tests/transform_dialect/cuda/eltwise_reduction_codegen_spec.mlir b/tests/transform_dialect/cuda/eltwise_reduction_codegen_spec.mlir
index d9beb7f..75f63a7 100644
--- a/tests/transform_dialect/cuda/eltwise_reduction_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/eltwise_reduction_codegen_spec.mlir
@@ -27,7 +27,9 @@
   // able to preserve the handles.
   // ===========================================================================
   %func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
-  transform.iree.apply_patterns %func { bubble_expand } : (!transform.any_op) -> ()
+  transform.apply_patterns to %func {
+    transform.apply_patterns.iree.bubble_expand
+  } : !transform.any_op
   %fills = transform.structured.match ops{["linalg.fill"]} in %variant_op : (!transform.any_op) -> !transform.any_op
   %fill_2, %more_parallel_fill_2 = transform.split_handle %fills
     : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
diff --git a/tests/transform_dialect/cuda/eltwise_reduction_eltwise_codegen_spec.mlir b/tests/transform_dialect/cuda/eltwise_reduction_eltwise_codegen_spec.mlir
index bb2598b..31d74ad 100644
--- a/tests/transform_dialect/cuda/eltwise_reduction_eltwise_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/eltwise_reduction_eltwise_codegen_spec.mlir
@@ -29,7 +29,9 @@
   // able to preserve the handles.
   // ===========================================================================
   %func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
-  transform.iree.apply_patterns %func { bubble_expand } : (!transform.any_op) -> ()
+  transform.apply_patterns to %func {
+    transform.apply_patterns.iree.bubble_expand
+  } : !transform.any_op
   %fills = transform.structured.match ops{["linalg.fill"]} in %variant_op : (!transform.any_op) -> !transform.any_op
  %fill_2, %more_parallel_fill_2 = transform.split_handle %fills
     : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
diff --git a/tests/transform_dialect/cuda/mma.mlir b/tests/transform_dialect/cuda/mma.mlir
index 79cd291..92d5cf8 100644
--- a/tests/transform_dialect/cuda/mma.mlir
+++ b/tests/transform_dialect/cuda/mma.mlir
@@ -31,13 +31,17 @@
 ^bb1(%module: !transform.any_op):
   %func = transform.structured.match ops{["func.func"]} in %module
     : (!transform.any_op) -> !transform.any_op
-  transform.iree.apply_patterns %func { unroll_vectors_gpu_wmma } : (!transform.any_op) -> ()
+  transform.apply_patterns to %func {
+    transform.apply_patterns.iree.unroll_vectors_gpu_wmma_sync
+  } : !transform.any_op
   transform.iree.vector.vector_to_mma_conversion %func { use_wmma } : (!transform.any_op) -> ()
 
   // Apply canonicalization post-hoc to trigger DCE and pass the test 
   // (i.e. all vector.contract are dead).
   // TODO: consider having the vector_to_mma_conversion do the DCE automatically.
-  transform.iree.apply_patterns %func { canonicalization } : (!transform.any_op) -> ()
+  transform.apply_patterns to %func {
+    transform.apply_patterns.canonicalization
+  } : !transform.any_op
 }
 
 // -----
@@ -71,11 +75,16 @@
 ^bb1(%module: !transform.any_op):
   %func = transform.structured.match ops{["func.func"]} in %module
     : (!transform.any_op) -> !transform.any_op
-  transform.iree.apply_patterns %func { unroll_vectors_gpu_mma_sync } : (!transform.any_op) -> ()
+  transform.apply_patterns to %func {
+    transform.apply_patterns.iree.unroll_vectors_gpu_mma_sync
+  } : !transform.any_op
   transform.iree.vector.vector_to_mma_conversion %func { use_mma_sync } : (!transform.any_op) -> ()
 
   // Apply canonicalization post-hoc to trigger DCE and pass the test 
   // (i.e. all vector.contract are dead).
   // TODO: consider having the vector_to_mma_conversion do the DCE automatically.
-  transform.iree.apply_patterns %func { canonicalization } : (!transform.any_op) -> ()
+  transform.apply_patterns to %func {
+    transform.apply_patterns.canonicalization
+  } : !transform.any_op
 }
+
diff --git a/tests/transform_dialect/cuda/mma_elemwise_layout_analysis_codegen_spec.mlir b/tests/transform_dialect/cuda/mma_elemwise_layout_analysis_codegen_spec.mlir
index 5d06e43..f9c93d6 100644
--- a/tests/transform_dialect/cuda/mma_elemwise_layout_analysis_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/mma_elemwise_layout_analysis_codegen_spec.mlir
@@ -35,10 +35,15 @@
     transform.apply_patterns.linalg.tiling_canonicalization
     transform.apply_patterns.scf.for_loop_canonicalization
   } : !transform.any_op
-  transform.iree.apply_patterns %func_3
-    { fold_reassociative_reshapes, canonicalization, cse } : (!transform.any_op) -> ()
+  transform.apply_patterns to %func_3 {
+    transform.apply_patterns.tensor.reassociative_reshape_folding
+    transform.apply_patterns.canonicalization
+  } : !transform.any_op
+  transform.iree.apply_cse %func_3 : !transform.any_op
   transform.iree.eliminate_empty_tensors %variant_op : (!transform.any_op) -> ()
-  transform.iree.apply_patterns %func_3 { erase_unnecessary_tensor_operands } : (!transform.any_op) -> ()
+  transform.apply_patterns to %func_3 {
+    transform.apply_patterns.linalg.erase_unnecessary_inputs
+  } : !transform.any_op
   %variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op : (!transform.any_op) -> (!transform.any_op)
   %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op
   transform.iree.erase_hal_descriptor_type_from_memref %memref_func : (!transform.any_op) -> ()
diff --git a/tests/transform_dialect/cuda/mma_reduction_layout_analysis_codegen_spec.mlir b/tests/transform_dialect/cuda/mma_reduction_layout_analysis_codegen_spec.mlir
index 84a6ce0..c591e43 100644
--- a/tests/transform_dialect/cuda/mma_reduction_layout_analysis_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/mma_reduction_layout_analysis_codegen_spec.mlir
@@ -36,10 +36,15 @@
     transform.apply_patterns.linalg.tiling_canonicalization
     transform.apply_patterns.scf.for_loop_canonicalization
   } : !transform.any_op
-  transform.iree.apply_patterns %func_3
-    { fold_reassociative_reshapes, canonicalization, cse } : (!transform.any_op) -> ()
+  transform.apply_patterns to %func_3 {
+    transform.apply_patterns.tensor.reassociative_reshape_folding
+    transform.apply_patterns.canonicalization
+  } : !transform.any_op
+  transform.iree.apply_cse %func_3 : !transform.any_op
   transform.iree.eliminate_empty_tensors %variant_op : (!transform.any_op) -> ()
-  transform.iree.apply_patterns %func_3 { erase_unnecessary_tensor_operands } : (!transform.any_op) -> ()
+  transform.apply_patterns to %func_3 {
+    transform.apply_patterns.linalg.erase_unnecessary_inputs
+  } : !transform.any_op
   %variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op : (!transform.any_op) -> (!transform.any_op)
   %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op
   transform.iree.erase_hal_descriptor_type_from_memref %memref_func : (!transform.any_op) -> ()
diff --git a/tests/transform_dialect/cuda/mma_using_layout_analysis_codegen_spec.mlir b/tests/transform_dialect/cuda/mma_using_layout_analysis_codegen_spec.mlir
index c647bca..bfbe204 100644
--- a/tests/transform_dialect/cuda/mma_using_layout_analysis_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/mma_using_layout_analysis_codegen_spec.mlir
@@ -40,10 +40,15 @@
     transform.apply_patterns.linalg.tiling_canonicalization
     transform.apply_patterns.scf.for_loop_canonicalization
   } : !transform.any_op
-  transform.iree.apply_patterns %func_3
-    { fold_reassociative_reshapes, canonicalization, cse } : (!transform.any_op) -> ()
+  transform.apply_patterns to %func_3 {
+    transform.apply_patterns.tensor.reassociative_reshape_folding
+    transform.apply_patterns.canonicalization
+  } : !transform.any_op
+  transform.iree.apply_cse %func_3 : !transform.any_op
   transform.iree.eliminate_empty_tensors %variant_op : (!transform.any_op) -> ()
-  transform.iree.apply_patterns %func_3 { erase_unnecessary_tensor_operands } : (!transform.any_op) -> ()
+  transform.apply_patterns to %func_3 {
+    transform.apply_patterns.linalg.erase_unnecessary_inputs
+  } : !transform.any_op
   %variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op : (!transform.any_op) -> (!transform.any_op)
   %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op
   transform.iree.erase_hal_descriptor_type_from_memref %memref_func : (!transform.any_op) -> ()
@@ -51,7 +56,9 @@
  // Step 5. Pre-process the contract and transfer ops to put them in the right form.
   // ===========================================================================
   %func_2 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op
-  transform.iree.apply_patterns %func_2 {  prepare_vector_to_mma } : (!transform.any_op) -> ()
+  transform.apply_patterns to %func_2 {
+    transform.apply_patterns.iree.prepare_vector_to_mma
+  } : !transform.any_op
 
   // Step 6. Post-bufferization vector distribution
   // ===========================================================================
diff --git a/tests/transform_dialect/cuda/reduction_codegen_spec.mlir b/tests/transform_dialect/cuda/reduction_codegen_spec.mlir
index 6902576..3cb09ed 100644
--- a/tests/transform_dialect/cuda/reduction_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/reduction_codegen_spec.mlir
@@ -38,8 +38,8 @@
     transform.apply_patterns.linalg.tiling_canonicalization
     transform.apply_patterns.scf.for_loop_canonicalization
   } : !transform.any_op
-  transform.iree.apply_patterns %variant_op
-    { canonicalization, licm, cse } : (!transform.any_op) -> ()
+  transform.iree.apply_licm %variant_op : !transform.any_op
+  transform.iree.apply_cse %variant_op : !transform.any_op
 
   %fill_2d = transform.structured.match ops{["linalg.fill"]} filter_result_type = tensor<1x2xf32> in %variant_op 
     : (!transform.any_op) -> !transform.any_op
@@ -65,7 +65,9 @@
 
  // Step 5. Bufferize and drop HAL descriptor from memref ops.
   // ===========================================================================
-  transform.iree.apply_patterns %func_3 { fold_reassociative_reshapes } : (!transform.any_op) -> ()
+  transform.apply_patterns to %func_3 {
+    transform.apply_patterns.tensor.reassociative_reshape_folding
+  } : !transform.any_op
   transform.iree.eliminate_empty_tensors %variant_op : (!transform.any_op) -> ()
   %variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op : (!transform.any_op) -> !transform.any_op
   %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3 
@@ -105,6 +107,6 @@
     transform.apply_patterns.linalg.tiling_canonicalization
     transform.apply_patterns.scf.for_loop_canonicalization
   } : !transform.any_op
-  transform.iree.apply_patterns %variant_op_3
-    { canonicalization, licm, cse } : (!transform.any_op) -> ()
+  transform.iree.apply_licm %variant_op_3 : !transform.any_op
+  transform.iree.apply_cse %variant_op_3 : !transform.any_op
 }
diff --git a/tests/transform_dialect/cuda/reduction_eltwise_codegen_spec.mlir b/tests/transform_dialect/cuda/reduction_eltwise_codegen_spec.mlir
index 3bd414b..186fd08 100644
--- a/tests/transform_dialect/cuda/reduction_eltwise_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/reduction_eltwise_codegen_spec.mlir
@@ -21,9 +21,10 @@
     transform.apply_patterns.iree.fold_fill_into_pad
     transform.apply_patterns.linalg.tiling_canonicalization
     transform.apply_patterns.scf.for_loop_canonicalization
+    transform.apply_patterns.canonicalization
   } : !transform.any_op
-  transform.iree.apply_patterns %variant_op
-    { canonicalization, licm, cse } : (!transform.any_op) -> ()
+  transform.iree.apply_licm %variant_op : !transform.any_op
+  transform.iree.apply_cse %variant_op : !transform.any_op
 
   // Step 2. First level of tiling + fusion parallelizes to blocks. Tile the
   // trailing elementwise the same way we want to tile the reduction.
@@ -41,9 +42,10 @@
     transform.apply_patterns.iree.fold_fill_into_pad
     transform.apply_patterns.linalg.tiling_canonicalization
     transform.apply_patterns.scf.for_loop_canonicalization
+    transform.apply_patterns.canonicalization
   } : !transform.any_op
-  transform.iree.apply_patterns %variant_op
-    { canonicalization, licm, cse } : (!transform.any_op) -> ()
+  transform.iree.apply_licm %variant_op : !transform.any_op
+  transform.iree.apply_cse %variant_op : !transform.any_op
 
   // Step 3. Second level of tiling + fusion parallelizes to threads.
   // ===========================================================================
@@ -64,9 +66,10 @@
     transform.apply_patterns.iree.fold_fill_into_pad
     transform.apply_patterns.linalg.tiling_canonicalization
     transform.apply_patterns.scf.for_loop_canonicalization
+    transform.apply_patterns.canonicalization
   } : !transform.any_op
-  transform.iree.apply_patterns %variant_op
-    { canonicalization, licm, cse } : (!transform.any_op) -> ()
+  transform.iree.apply_licm %variant_op : !transform.any_op
+  transform.iree.apply_cse %variant_op : !transform.any_op
 
   %fill_2d = transform.structured.match ops{["linalg.fill"]} filter_result_type = tensor<1x2xf32> in %variant_op
     : (!transform.any_op) -> !transform.any_op
@@ -84,9 +87,10 @@
     transform.apply_patterns.iree.fold_fill_into_pad
     transform.apply_patterns.linalg.tiling_canonicalization
     transform.apply_patterns.scf.for_loop_canonicalization
+    transform.apply_patterns.canonicalization
   } : !transform.any_op
-  transform.iree.apply_patterns %variant_op
-    { canonicalization, licm, cse } : (!transform.any_op) -> ()
+  transform.iree.apply_licm %variant_op : !transform.any_op
+  transform.iree.apply_cse %variant_op : !transform.any_op
 
   // Step 4. Rank-reduce and vectorize.
   // ===========================================================================
@@ -100,7 +104,9 @@
 
  // Step 5. Bufferize and drop HAL descriptor from memref ops.
   // ===========================================================================
-  transform.iree.apply_patterns %func_3 { fold_reassociative_reshapes } : (!transform.any_op) -> ()
+  transform.apply_patterns to %func_3 {
+    transform.apply_patterns.tensor.reassociative_reshape_folding
+  } : !transform.any_op
  transform.iree.eliminate_empty_tensors %variant_op : (!transform.any_op) -> ()
   %variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op : (!transform.any_op) -> (!transform.any_op)
   %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3
@@ -139,7 +145,8 @@
     transform.apply_patterns.iree.fold_fill_into_pad
     transform.apply_patterns.linalg.tiling_canonicalization
     transform.apply_patterns.scf.for_loop_canonicalization
+    transform.apply_patterns.canonicalization
   } : !transform.any_op
-  transform.iree.apply_patterns %variant_op_3
-    { canonicalization, licm, cse } : (!transform.any_op) -> ()
+  transform.iree.apply_licm %variant_op_3 : !transform.any_op
+  transform.iree.apply_cse %variant_op_3 : !transform.any_op
 }
diff --git a/tests/transform_dialect/cuda/reduction_v2_codegen_spec.mlir b/tests/transform_dialect/cuda/reduction_v2_codegen_spec.mlir
index 00958bb..501b724 100644
--- a/tests/transform_dialect/cuda/reduction_v2_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/reduction_v2_codegen_spec.mlir
@@ -57,11 +57,16 @@
     transform.apply_patterns.linalg.tiling_canonicalization
     transform.apply_patterns.scf.for_loop_canonicalization
   } : !transform.any_op
-  transform.iree.apply_patterns %func_3
-    { fold_reassociative_reshapes, canonicalization, cse } : (!transform.any_op) -> ()
+  transform.apply_patterns to %func_3 {
+    transform.apply_patterns.tensor.reassociative_reshape_folding
+    transform.apply_patterns.canonicalization
+  } : !transform.any_op
+  transform.iree.apply_cse %func_3 : !transform.any_op
   transform.iree.eliminate_empty_tensors %variant_op : (!transform.any_op) -> ()
   %func_5 = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
-  transform.iree.apply_patterns %func_5 { erase_unnecessary_tensor_operands } : (!transform.any_op) -> ()
+  transform.apply_patterns to %func_5 {
+    transform.apply_patterns.linalg.erase_unnecessary_inputs
+  } : !transform.any_op
   %variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op : (!transform.any_op) -> (!transform.any_op)
   %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op
   transform.iree.erase_hal_descriptor_type_from_memref %memref_func : (!transform.any_op) -> ()
@@ -92,7 +97,8 @@
     transform.apply_patterns.iree.fold_fill_into_pad
     transform.apply_patterns.linalg.tiling_canonicalization
     transform.apply_patterns.scf.for_loop_canonicalization
+    transform.apply_patterns.canonicalization
   } : !transform.any_op
-  transform.iree.apply_patterns %func_7
-    { canonicalization, licm, cse } : (!transform.any_op) -> ()
+  transform.iree.apply_licm %func_7 : !transform.any_op
+  transform.iree.apply_cse %func_7 : !transform.any_op
 }
diff --git a/tests/transform_dialect/cuda/reduction_v3_codegen_spec.mlir b/tests/transform_dialect/cuda/reduction_v3_codegen_spec.mlir
index 588c089..055a551 100644
--- a/tests/transform_dialect/cuda/reduction_v3_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/reduction_v3_codegen_spec.mlir
@@ -22,9 +22,10 @@
     transform.apply_patterns.iree.fold_fill_into_pad
     transform.apply_patterns.linalg.tiling_canonicalization
     transform.apply_patterns.scf.for_loop_canonicalization
+    transform.apply_patterns.canonicalization
   } : !transform.any_op
-  transform.iree.apply_patterns %variant_op
-    { canonicalization, licm, cse } : (!transform.any_op) -> ()
+  transform.iree.apply_licm %variant_op : !transform.any_op
+  transform.iree.apply_cse %variant_op : !transform.any_op
 
   // Step 2. Split the reduction to get meatier parallelism.
   // This also parallelizes to threads.
@@ -50,9 +51,10 @@
     transform.apply_patterns.iree.fold_fill_into_pad
     transform.apply_patterns.linalg.tiling_canonicalization
     transform.apply_patterns.scf.for_loop_canonicalization
+    transform.apply_patterns.canonicalization
   } : !transform.any_op
-  transform.iree.apply_patterns %variant_op
-    { canonicalization, licm, cse } : (!transform.any_op) -> ()
+  transform.iree.apply_licm %variant_op : !transform.any_op
+  transform.iree.apply_cse %variant_op : !transform.any_op
 
   // Step 3. Rank-reduce and vectorize.
   // ===========================================================================
@@ -73,9 +75,10 @@
     transform.apply_patterns.iree.fold_fill_into_pad
     transform.apply_patterns.linalg.tiling_canonicalization
     transform.apply_patterns.scf.for_loop_canonicalization
+    transform.apply_patterns.canonicalization
   } : !transform.any_op
-  transform.iree.apply_patterns %variant_op
-    { canonicalization, licm, cse } : (!transform.any_op) -> ()
+  transform.iree.apply_licm %variant_op : !transform.any_op
+  transform.iree.apply_cse %variant_op : !transform.any_op
   transform.structured.hoist_redundant_tensor_subsets %func_3
     : (!transform.any_op) -> ()
 
@@ -87,14 +90,19 @@
     transform.apply_patterns.iree.fold_fill_into_pad
     transform.apply_patterns.linalg.tiling_canonicalization
     transform.apply_patterns.scf.for_loop_canonicalization
+    transform.apply_patterns.canonicalization
   } : !transform.any_op
-  transform.iree.apply_patterns %variant_op
-    { canonicalization, licm, cse } : (!transform.any_op) -> ()
-  transform.iree.apply_patterns %func_3 { fold_reassociative_reshapes } : (!transform.any_op) -> ()
+  transform.iree.apply_licm %variant_op : !transform.any_op
+  transform.iree.apply_cse %variant_op : !transform.any_op
+  transform.apply_patterns to %func_3 {
+    transform.apply_patterns.tensor.reassociative_reshape_folding
+  } : !transform.any_op
   transform.iree.eliminate_empty_tensors %variant_op : (!transform.any_op) -> ()
   %func_6 = transform.structured.match ops{["func.func"]} in %variant_op
     : (!transform.any_op) -> !transform.any_op
-  transform.iree.apply_patterns %func_6 { erase_unnecessary_tensor_operands } : (!transform.any_op) -> ()
+  transform.apply_patterns to %func_6 {
+    transform.apply_patterns.linalg.erase_unnecessary_inputs
+  } : !transform.any_op
   %variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op
     : (!transform.any_op) -> !transform.any_op
   %memref_func = transform.structured.match ops{["func.func"]} in %variant_op_3
@@ -128,7 +136,8 @@
     transform.apply_patterns.iree.fold_fill_into_pad
     transform.apply_patterns.linalg.tiling_canonicalization
     transform.apply_patterns.scf.for_loop_canonicalization
+    transform.apply_patterns.canonicalization
   } : !transform.any_op
-  transform.iree.apply_patterns %variant_op_3
-    { canonicalization, licm, cse } : (!transform.any_op) -> ()
+  transform.iree.apply_licm %variant_op_3 : !transform.any_op
+  transform.iree.apply_cse %variant_op_3 : !transform.any_op
 }
diff --git a/tests/transform_dialect/cuda/softmax_v2_codegen_spec.mlir b/tests/transform_dialect/cuda/softmax_v2_codegen_spec.mlir
index a062ebf..6254ea2 100644
--- a/tests/transform_dialect/cuda/softmax_v2_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/softmax_v2_codegen_spec.mlir
@@ -39,9 +39,10 @@
     transform.apply_patterns.iree.fold_fill_into_pad
     transform.apply_patterns.linalg.tiling_canonicalization
     transform.apply_patterns.scf.for_loop_canonicalization
+    transform.apply_patterns.canonicalization
   } : !transform.any_op
-  transform.iree.apply_patterns %variant_op
-    { canonicalization, licm, cse } : (!transform.any_op) -> ()
+  transform.iree.apply_licm %variant_op : !transform.any_op
+  transform.iree.apply_cse %variant_op : !transform.any_op
 
 
   // Step 2. Second level of tiling + fusion parallelizes to threads.
@@ -79,9 +80,10 @@
     transform.apply_patterns.iree.fold_fill_into_pad
     transform.apply_patterns.linalg.tiling_canonicalization
     transform.apply_patterns.scf.for_loop_canonicalization
+    transform.apply_patterns.canonicalization
   } : !transform.any_op
-  transform.iree.apply_patterns %variant_op
-    { canonicalization, licm, cse } : (!transform.any_op) -> ()
+  transform.iree.apply_licm %variant_op : !transform.any_op
+  transform.iree.apply_cse %variant_op : !transform.any_op
 
   // Step 3. Rank-reduce and vectorize.
   // ==================================
@@ -127,7 +129,8 @@
     transform.apply_patterns.iree.fold_fill_into_pad
     transform.apply_patterns.linalg.tiling_canonicalization
     transform.apply_patterns.scf.for_loop_canonicalization
+    transform.apply_patterns.canonicalization
   } : !transform.any_op
-  transform.iree.apply_patterns %variant_op_3
-    { canonicalization, licm, cse } : (!transform.any_op) -> ()
+  transform.iree.apply_licm %variant_op_3 : !transform.any_op
+  transform.iree.apply_cse %variant_op_3 : !transform.any_op
 }
diff --git a/tests/transform_dialect/cuda/vecadd2d_codegen_spec_partial_tile.mlir b/tests/transform_dialect/cuda/vecadd2d_codegen_spec_partial_tile.mlir
index 2fff8d8..12b6ef1 100644
--- a/tests/transform_dialect/cuda/vecadd2d_codegen_spec_partial_tile.mlir
+++ b/tests/transform_dialect/cuda/vecadd2d_codegen_spec_partial_tile.mlir
@@ -15,7 +15,8 @@
     transform.apply_patterns.iree.fold_fill_into_pad
     transform.apply_patterns.linalg.tiling_canonicalization
     transform.apply_patterns.scf.for_loop_canonicalization
+    transform.apply_patterns.canonicalization
   } : !transform.any_op
-  transform.iree.apply_patterns %variant_op
-    { canonicalization, licm, cse } : (!transform.any_op) -> ()
+  transform.iree.apply_licm %variant_op : !transform.any_op
+  transform.iree.apply_cse %variant_op : !transform.any_op
 }