Clean up CapturingOpMatcher and derived classes, NFC (#14022)

* Use CRTP ConcreteOpMatcher and make it forward base class methods.
* Make StructuredOpMatcher derive ConcreteOpMatcher
* Remove duplicated functionality
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h
index c59d659..8f98e30 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h
@@ -354,11 +354,6 @@
   // Constraints on adjacent ops.
   //===-------------------------------------------------------------------===//
 
-  /// Checks that `matchers` captured all tilable ops nested in `parent` except
-  /// for `linalgOp`. This is an implementation detail of allTilableOpsCaptured.
-  static bool checkAllTilableMatched(Operation *parent, Operation *op,
-                                     ArrayRef<CapturingOpMatcher *> matchers);
-
   /// Adds a predicate checking that all ops implementing TilingInterface in the
   /// parent of the given type (e.g., a function or a module) were matched by
   /// this or nested matchers. This is useful to ensure that the matcher covered
@@ -422,6 +417,11 @@
   /// A list of additional conditions for the operation to match.
   SmallVector<PredicateFn> predicates;
 
+  /// Checks that `matchers` captured all tilable ops nested in `parent` except
+  /// for `linalgOp`. This is an implementation detail of allTilableOpsCaptured.
+  static bool checkAllTilableMatched(Operation *parent, Operation *op,
+                                     ArrayRef<CapturingOpMatcher *> matchers);
+
   /// Creates a matcher for an operation with one of the given types.
   template <typename... OpType>
   static CapturingOpMatcher create() {
@@ -448,51 +448,111 @@
 /// with debug includes, and ConcreteOpMatcher is a class template that can only
 /// reside in the header.
 void debugOutputForConcreteOpMatcherConstructor(StringRef name);
-
-/// Specialization hook that returns an op description used in ConcreteOpMatcher
-/// debug output. By default, returns the operation name. This class template
-/// can be specialized for a better name and must be specialized for interfaces
-/// that don't have an operation name.
-template <typename OpTy>
-struct DebugOpKindDescription {
-  static StringRef get() { return OpTy::getOperationName(); }
-};
 } // namespace detail
 
 /// Base class for matchers that match a specific op. Adds an initial predicate
 /// checking if the op is indeed of the specified kind.
 /// Derived classes specializing this for op interfaces MUST also define a
 /// specialization of DebugOpKindDescription.
-// TODO: use traits instead of inheritance to inject behavior and avoid
-// unintended upcasting when calling parent methods in the fluent API.
-template <typename OpTy>
+template <typename Derived, typename OpTy>
 class ConcreteOpMatcher : public CapturingOpMatcher {
 protected:
+  using Base = ConcreteOpMatcher;
+
+  static StringRef getConcreteOpDescription() {
+    return OpTy::getOperationName();
+  }
+
   /// Adds a predicate checking if the op is of the OpTy kind.
   ConcreteOpMatcher() {
     CapturingOpMatcher::addPredicate([](Operation *op) {
       detail::debugOutputForConcreteOpMatcherConstructor(
-          detail::DebugOpKindDescription<OpTy>::get());
+          Derived::getConcreteOpDescription());
       return isa<OpTy>(op);
     });
   }
 
   /// Adds a predicate for the matched operation to satisfy.
   template <typename FnTy>
-  void addPredicate(FnTy &&predicate) {
+  Derived &addPredicate(FnTy &&predicate) {
     // Dispatch to the callback.
     CapturingOpMatcher::addPredicate(
         [inner = std::move(predicate)](Operation *op) {
           return inner(cast<OpTy>(op));
         });
+    return static_cast<Derived &>(*this);
+  }
+
+public:
+  /// Adds alternative paths for predicates. In practice, this is just a
+  /// predicate that is satisfied when either the first or the second matcher is
+  /// satisfied. The alternative satisfaction is eager and short-cutting, i.e.,
+  /// the second alternative will not be processed, and therefore will not
+  /// capture values, if the first alternative succeeded.
+  Derived &alternatives(CapturingOpMatcher &first, CapturingOpMatcher &second) {
+    return static_cast<Derived &>(
+        CapturingOpMatcher::alternatives(first, second));
+  }
+
+  /// Adds a predicate checking that all ops implementing TilingInterface in the
+  /// parent of the given type (e.g., a function or a module) were matched by
+  /// this or nested matchers. This is useful to ensure that the matcher covered
+  /// the entire parent region, not just a parent of it. This predicate **must**
+  /// be added *after* all the other predicates that capture.
+  template <typename ParentTy>
+  Derived &allTilableOpsCaptured() {
+    return static_cast<Derived &>(
+        CapturingOpMatcher::allTilableOpsCaptured<ParentTy>());
+  }
+
+  //-------------------------------------------------------------------------//
+  // Predicates for operands and results.
+  //-------------------------------------------------------------------------//
+
+  /// Adds a predicate checking that the operation has exactly the given number
+  /// of operands.
+  Derived &operand(NumEqualsTo num) {
+    return static_cast<Derived &>(CapturingOpMatcher::operand(num));
+  }
+
+  /// Adds a predicate checking that the `pos`-th operand of the operation is
+  /// defined by an operation that satisfies the given matcher.
+  Derived &operand(int64_t pos, CapturingOpMatcher &nested) {
+    return static_cast<Derived &>(CapturingOpMatcher::operand(pos, nested));
+  }
+
+  /// Adds a predicate checking that the `pos`-th operand of the operation
+  /// satisfies the given value matcher.
+  Derived &operand(int64_t pos, CapturingValueMatcher &nested) {
+    return static_cast<Derived &>(CapturingOpMatcher::operand(pos, nested));
+  }
+
+  /// Adds a predicate checking that the `pos`-th operand of the operation is
+  /// defined by `arith.constant` with the value 1.0.
+  // TODO: better matching for attributes.
+  Derived &operand(int64_t pos, ConstantFloatOne c) {
+    return static_cast<Derived &>(CapturingOpMatcher::operand(pos, c));
+  }
+
+  /// Adds a predicate checking that the operation has exactly the given number
+  /// of results.
+  Derived &result(NumEqualsTo num) {
+    return static_cast<Derived &>(CapturingOpMatcher::result(num));
+  }
+
+  /// Adds a predicate checking that the `pos`-th result of the operation
+  /// satisfies the given value matcher.
+  Derived &result(int64_t pos, CapturingValueMatcher &nested) {
+    return static_cast<Derived &>(CapturingOpMatcher::result(pos, nested));
   }
 };
 
 /// Matcher for the `tensor.pad` operation.
-class TensorPadOpMatcher : public ConcreteOpMatcher<tensor::PadOp> {
+class TensorPadOpMatcher
+    : public ConcreteOpMatcher<TensorPadOpMatcher, tensor::PadOp> {
   friend class MatcherContext;
 
-  TensorPadOpMatcher() : ConcreteOpMatcher<tensor::PadOp>() {}
+  TensorPadOpMatcher() = default;
 
 public:
   /// Adds a predicate checking that the low padding sizes are exactly the given
@@ -534,14 +594,18 @@
       CapturingOpMatcher::create<OpTy...>());
 }
 
-/// Matcher for structured aka Linalg operations. Extensions must follow the
-/// same conditions as the base class.
-class StructuredOpMatcher : public CapturingOpMatcher {
+/// Matcher for structured aka Linalg operations.
+class StructuredOpMatcher
+    : public ConcreteOpMatcher<StructuredOpMatcher, linalg::LinalgOp> {
   friend class MatcherContext;
 
-  StructuredOpMatcher();
+  StructuredOpMatcher() = default;
 
 public:
+  static StringRef getConcreteOpDescription() {
+    return "linalg interface implementation";
+  }
+
   /// Creates a matcher for a structured operation with one of the given types.
   template <typename... OpType>
   static StructuredOpMatcher create() {
@@ -561,7 +625,6 @@
   //===-------------------------------------------------------------------===//
   /// Adds a predicate checking that the given rank must be greater than some
   /// constant value.
-  // TODO: Base class, derived class and proper API.
   StructuredOpMatcher &rank(NumGreaterEqualTo minRank);
   StructuredOpMatcher &rank(NumLowerEqualTo maxRank);
 
@@ -688,27 +751,6 @@
   StructuredOpMatcher &input(int64_t position, ConstantFloatZero);
 
   //===-------------------------------------------------------------------===//
-  // Constraints on adjacent ops.
-  //===-------------------------------------------------------------------===//
-
-  /// Adds a predicate checking that all ops implementing TilingInterface in the
-  /// parent of the given type (e.g., a function or a module) were matched by
-  /// this or nested matchers. This is useful to ensure that the matcher covered
-  /// the entire parent region, not just a parent of it. This predicate **must**
-  /// be added *after* all the other predicates that capture.
-  template <typename OpTy>
-  StructuredOpMatcher &allTilableOpsCaptured() {
-    SmallVector<CapturingOpMatcher *> copy;
-    copy.push_back(this);
-    getAllNested(copy);
-    addPredicate([copy = std::move(copy)](linalg::LinalgOp linalgOp) {
-      Operation *parent = linalgOp->getParentOfType<OpTy>();
-      return checkAllTilableMatched(parent, linalgOp, copy);
-    });
-    return *this;
-  }
-
-  //===-------------------------------------------------------------------===//
   // Constraints on output operands.
   //===-------------------------------------------------------------------===//
 
@@ -822,17 +864,6 @@
   StructuredOpMatcher &passThroughOp();
 
 private:
-  /// Adds a predicate for the matched operation to satisfy.
-  void addPredicate(std::function<bool(linalg::LinalgOp)> predicate) {
-    // Check that the operation implements the LinalgOp interface and dispatch
-    // to the predicate.
-    CapturingOpMatcher::addPredicate(
-        [inner = std::move(predicate)](Operation *op) {
-          auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
-          return linalgOp && inner(linalgOp);
-        });
-  }
-
   /// Non-template implementations of nested predicate builders for inputs,
   /// outputs and results. Should not be called directly.
   void addInputMatcher(int64_t position,
diff --git a/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp b/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp
index 4995395..56de39a 100644
--- a/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp
@@ -406,36 +406,23 @@
 }
 
 //===---------------------------------------------------------------------===//
-// StructuredOpMatcher and friends.
-//===---------------------------------------------------------------------===//
-
-transform_ext::StructuredOpMatcher::StructuredOpMatcher() {
-  addPredicate([](Operation *op) {
-    LLVM_DEBUG(DBGS() << "is a structured op");
-    return isa<linalg::LinalgOp>(op);
-  });
-}
-
-//===---------------------------------------------------------------------===//
 // Constraints on op rank and dims.
 //===---------------------------------------------------------------------===//
 
 transform_ext::StructuredOpMatcher &
 transform_ext::StructuredOpMatcher::rank(NumGreaterEqualTo minRank) {
-  addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+  return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
     LLVM_DEBUG(DBGS() << "rank >= " << minRank.value);
     return linalgOp.getNumLoops() >= minRank.value;
   });
-  return *this;
 }
 
 transform_ext::StructuredOpMatcher &
 transform_ext::StructuredOpMatcher::rank(NumLowerEqualTo maxRank) {
-  addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+  return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
     LLVM_DEBUG(DBGS() << "rank <= " << maxRank.value);
     return linalgOp.getNumLoops() <= maxRank.value;
   });
-  return *this;
 }
 
 StringRef stringifyShapeKind(transform_ext::ShapeKind kind) {
@@ -451,8 +438,8 @@
 transform_ext::StructuredOpMatcher &
 transform_ext::StructuredOpMatcher::dim(SmallVector<int64_t> &&dimensions,
                                         ShapeKind kind) {
-  addPredicate([dimensions = std::move(dimensions),
-                kind](linalg::LinalgOp linalgOp) -> bool {
+  return addPredicate([dimensions = std::move(dimensions),
+                       kind](linalg::LinalgOp linalgOp) -> bool {
     LLVM_DEBUG(DBGS() << "dimensions [";
                llvm::interleaveComma(dimensions, llvm::dbgs());
                llvm::dbgs() << "] are " << stringifyShapeKind(kind));
@@ -469,26 +456,24 @@
     }
     return true;
   });
-  return *this;
 }
 
 transform_ext::StructuredOpMatcher &
 transform_ext::StructuredOpMatcher::dim(AllDims tag, ShapeKind kind) {
-  addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+  return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
     LLVM_DEBUG(DBGS() << "all dimensions are " << stringifyShapeKind(kind));
     SmallVector<int64_t> shape = linalgOp.getStaticLoopRanges();
     return llvm::all_of(shape, [=](int64_t dimension) {
       return ShapedType::isDynamic(dimension) ^ (kind == ShapeKind::Static);
     });
   });
-  return *this;
 }
 
 transform_ext::StructuredOpMatcher &
 transform_ext::StructuredOpMatcher::dim(SmallVector<int64_t> &&dimensions,
                                         utils::IteratorType kind) {
-  addPredicate([dimensions = std::move(dimensions),
-                kind](linalg::LinalgOp linalgOp) -> bool {
+  return addPredicate([dimensions = std::move(dimensions),
+                       kind](linalg::LinalgOp linalgOp) -> bool {
     LLVM_DEBUG(DBGS() << "dimensions [";
                llvm::interleaveComma(dimensions, llvm::dbgs());
                llvm::dbgs() << "] are " << utils::stringifyIteratorType(kind));
@@ -506,7 +491,6 @@
     }
     return true;
   });
-  return *this;
 }
 transform_ext::StructuredOpMatcher &
 transform_ext::StructuredOpMatcher::dim(AllDims tag, utils::IteratorType kind) {
@@ -516,8 +500,8 @@
 transform_ext::StructuredOpMatcher &
 transform_ext::StructuredOpMatcher::dim(AllDimsExcept &&dims,
                                         utils::IteratorType kind) {
-  addPredicate([dimensions = std::move(dims),
-                kind](linalg::LinalgOp linalgOp) -> bool {
+  return addPredicate([dimensions = std::move(dims),
+                       kind](linalg::LinalgOp linalgOp) -> bool {
     LLVM_DEBUG(DBGS() << "all dimensions except [";
                llvm::interleaveComma(dimensions.getExcluded(), llvm::dbgs());
                llvm::dbgs() << "] are " << utils::stringifyIteratorType(kind));
@@ -537,13 +521,12 @@
     }
     return true;
   });
-  return *this;
 }
 
 transform_ext::StructuredOpMatcher &
 transform_ext::StructuredOpMatcher::dim(int64_t dimension,
                                         DivisibleBy divisibleBy) {
-  addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+  return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
     LLVM_DEBUG(DBGS() << "dimension " << dimension << " is divisible by "
                       << divisibleBy.value);
     int64_t rank = linalgOp.getNumLoops();
@@ -555,7 +538,6 @@
     int64_t size = linalgOp.getStaticLoopRanges()[transformedDimension];
     return !ShapedType::isDynamic(size) && (size % divisibleBy.value == 0);
   });
-  return *this;
 }
 
 //===---------------------------------------------------------------------===//
@@ -563,17 +545,16 @@
 //===---------------------------------------------------------------------===//
 transform_ext::StructuredOpMatcher &
 transform_ext::StructuredOpMatcher::rank(CaptureRank capture) {
-  addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+  return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
     LLVM_DEBUG(DBGS() << "capture rank");
     capture.value = linalgOp.getNumLoops();
     return true;
   });
-  return *this;
 }
 
 transform_ext::StructuredOpMatcher &
 transform_ext::StructuredOpMatcher::dim(int64_t dimension, CaptureDim capture) {
-  addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+  return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
     LLVM_DEBUG(DBGS() << "capture dimension");
     int64_t rank = linalgOp.getNumLoops();
     int64_t transformedDimension =
@@ -584,22 +565,20 @@
     capture.value = linalgOp.getStaticLoopRanges()[transformedDimension];
     return true;
   });
-  return *this;
 }
 
 transform_ext::StructuredOpMatcher &
 transform_ext::StructuredOpMatcher::dim(AllDims tag, CaptureDims captures) {
-  addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+  return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
     LLVM_DEBUG(DBGS() << "capture all dimensions");
     captures.value = linalgOp.getStaticLoopRanges();
     return true;
   });
-  return *this;
 }
 
 transform_ext::StructuredOpMatcher &
 transform_ext::StructuredOpMatcher::convolutionDims(CaptureConvDims convDims) {
-  addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+  return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
     LLVM_DEBUG(DBGS() << "capture convolution dimensions\n");
     StringRef convMessage = linalg::detail::getMatchConvolutionMessage(
         mlir::linalg::detail::isConvolutionInterfaceImpl(linalgOp,
@@ -610,7 +589,6 @@
                       << convMessage << "\n");
     return false;
   });
-  return *this;
 }
 
 transform_ext::StructuredOpMatcher::StructuredOpMatcher(
@@ -684,7 +662,7 @@
 
 transform_ext::StructuredOpMatcher &
 transform_ext::StructuredOpMatcher::input(AllOperands tag, IsPermutation) {
-  addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+  return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
     LLVM_DEBUG(DBGS() << "all input operands have permutation maps");
     // all_of with a lambda requires const-casting dance, so using a loop.
     for (OpOperand *operand : linalgOp.getDpsInputOperands()) {
@@ -693,13 +671,12 @@
     }
     return true;
   });
-  return *this;
 }
 
 transform_ext::StructuredOpMatcher &
 transform_ext::StructuredOpMatcher::input(AllOperands tag,
                                           IsProjectedPermutation) {
-  addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+  return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
     LLVM_DEBUG(DBGS() << "all input operands have projected permutation maps");
     // all_of with a lambda requires const-casting dance, so using a loop.
     for (OpOperand *operand : linalgOp.getDpsInputOperands()) {
@@ -708,7 +685,6 @@
     }
     return true;
   });
-  return *this;
 }
 
 /// Helper to check if the map is an identity map with a projected dim.
@@ -742,7 +718,7 @@
 transform_ext::StructuredOpMatcher &
 transform_ext::StructuredOpMatcher::input(SmallVector<int64_t> &&positions,
                                           IsProjected dim) {
-  addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+  return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
     LLVM_DEBUG(DBGS() << "operands ";
                llvm::interleaveComma(positions, llvm::dbgs());
                llvm::dbgs() << " have a permutation maps with " << dim.value
@@ -760,12 +736,11 @@
     }
     return true;
   });
-  return *this;
 }
 
 transform_ext::StructuredOpMatcher &
 transform_ext::StructuredOpMatcher::input(AllOperands tag, IsIdentity) {
-  addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+  return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
     LLVM_DEBUG(DBGS() << "all input operands have identity maps");
     // all_of with a lambda requires const-casting dance, so using a loop.
     for (OpOperand *operand : linalgOp.getDpsInputOperands()) {
@@ -774,13 +749,12 @@
     }
     return true;
   });
-  return *this;
 }
 
 transform_ext::StructuredOpMatcher &
 transform_ext::StructuredOpMatcher::input(SmallVector<int64_t> &&positions,
                                           IsIdentity) {
-  addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+  return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
     LLVM_DEBUG(DBGS() << "input operands ";
                llvm::interleaveComma(positions, llvm::dbgs());
                llvm::dbgs() << " have identity maps");
@@ -795,13 +769,12 @@
     }
     return true;
   });
-  return *this;
 }
 
 transform_ext::StructuredOpMatcher &
 transform_ext::StructuredOpMatcher::input(int64_t position,
                                           ElementTypeBitWidth width) {
-  addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+  return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
     LLVM_DEBUG(DBGS() << "input operand #" << position
                       << " has elemental type with bit width " << width.value);
     int64_t updatedPosition = position;
@@ -814,13 +787,12 @@
     return shapedType && shapedType.getElementType().isIntOrFloat() &&
            shapedType.getElementType().getIntOrFloatBitWidth() == width.value;
   });
-  return *this;
 }
 
 transform_ext::StructuredOpMatcher &
 transform_ext::StructuredOpMatcher::input(int64_t position,
                                           CaptureElementTypeBitWidth width) {
-  addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+  return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
     LLVM_DEBUG(DBGS() << "input operand #" << position << " capture bitwidth");
     int64_t updatedPosition = position;
     if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInputs()))
@@ -834,13 +806,12 @@
     width.value = shapedType.getElementType().getIntOrFloatBitWidth();
     return true;
   });
-  return *this;
 }
 
 transform_ext::StructuredOpMatcher &
 transform_ext::StructuredOpMatcher::input(int64_t position,
                                           CaptureElementType elem) {
-  addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+  return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
     LLVM_DEBUG(DBGS() << "input operand #" << position
                       << " capture element type");
     int64_t updatedPosition = position;
@@ -857,16 +828,14 @@
     elem.value = shapedType.getElementType();
     return true;
   });
-  return *this;
 }
 
 transform_ext::StructuredOpMatcher &
 transform_ext::StructuredOpMatcher::input(NumEqualsTo num) {
-  addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+  return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
     LLVM_DEBUG(DBGS() << "number of input operands == " << num.value);
     return linalgOp.getNumDpsInputs() == num.value;
   });
-  return *this;
 }
 
 transform_ext::StructuredOpMatcher &
@@ -884,7 +853,7 @@
 
 transform_ext::StructuredOpMatcher &transform_ext::StructuredOpMatcher::input(
     int64_t position, std::function<bool(llvm::APFloat)> floatValueFn) {
-  addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+  return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
     LLVM_DEBUG(DBGS() << "input operand #" << position
                       << " is a special floating point constant");
     int64_t updatedPosition = position;
@@ -897,7 +866,6 @@
       return false;
     return floatValueFn(cstOp.value());
   });
-  return *this;
 }
 
 //===---------------------------------------------------------------------===//
@@ -933,7 +901,7 @@
 
 transform_ext::StructuredOpMatcher &
 transform_ext::StructuredOpMatcher::output(AllOperands tag, IsPermutation) {
-  addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+  return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
     LLVM_DEBUG(DBGS() << "all output operands have permutation maps");
     for (OpOperand *operand : linalgOp.getDpsInitOperands()) {
       if (!linalgOp.getMatchingIndexingMap(operand).isPermutation())
@@ -941,13 +909,12 @@
     }
     return true;
   });
-  return *this;
 }
 
 transform_ext::StructuredOpMatcher &
 transform_ext::StructuredOpMatcher::output(AllOperands tag,
                                            IsProjectedPermutation) {
-  addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+  return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
     LLVM_DEBUG(DBGS() << "all output operands have projected permutation maps");
     for (OpOperand *operand : linalgOp.getDpsInitOperands()) {
       if (!linalgOp.getMatchingIndexingMap(operand).isProjectedPermutation())
@@ -955,12 +922,11 @@
     }
     return true;
   });
-  return *this;
 }
 
 transform_ext::StructuredOpMatcher &
 transform_ext::StructuredOpMatcher::output(AllOperands tag, IsProjected dim) {
-  addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+  return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
     LLVM_DEBUG(DBGS() << "all output operands have a maps with projected");
     int64_t updatedDim = dim.value;
     if (!makeValidPositiveIndex(updatedDim, linalgOp.getNumLoops()))
@@ -972,12 +938,11 @@
     }
     return true;
   });
-  return *this;
 }
 
 transform_ext::StructuredOpMatcher &
 transform_ext::StructuredOpMatcher::output(AllOperands tag, IsIdentity) {
-  addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+  return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
     LLVM_DEBUG(DBGS() << "all output operands have identity permutation maps");
     for (OpOperand *operand : linalgOp.getDpsInitOperands()) {
       if (!linalgOp.getMatchingIndexingMap(operand).isIdentity())
@@ -985,13 +950,12 @@
     }
     return true;
   });
-  return *this;
 }
 
 transform_ext::StructuredOpMatcher &
 transform_ext::StructuredOpMatcher::output(int64_t position,
                                            ElementTypeBitWidth width) {
-  addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+  return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
     LLVM_DEBUG(DBGS() << "output operand #" << position
                       << " has elemental type with bit width " << width.value);
     int64_t updatedPosition = position;
@@ -1004,13 +968,12 @@
     return shapedType && shapedType.getElementType().isIntOrFloat() &&
            shapedType.getElementType().getIntOrFloatBitWidth() == width.value;
   });
-  return *this;
 }
 
 transform_ext::StructuredOpMatcher &
 transform_ext::StructuredOpMatcher::output(int64_t position,
                                            CaptureElementTypeBitWidth width) {
-  addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+  return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
     LLVM_DEBUG(DBGS() << "output operand #" << position << " capture bitwidth");
     int64_t updatedPosition = position;
     if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInits()))
@@ -1026,13 +989,12 @@
     width.value = shapedType.getElementType().getIntOrFloatBitWidth();
     return true;
   });
-  return *this;
 }
 
 transform_ext::StructuredOpMatcher &
 transform_ext::StructuredOpMatcher::output(int64_t position,
                                            CaptureElementType elem) {
-  addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+  return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
     LLVM_DEBUG(DBGS() << "output operand #" << position
                       << " capture element type");
     int64_t updatedPosition = position;
@@ -1049,13 +1011,12 @@
     elem.value = shapedType.getElementType();
     return true;
   });
-  return *this;
 }
 
 transform_ext::StructuredOpMatcher &
 transform_ext::StructuredOpMatcher::output(int64_t position,
                                            SingleCombinerReduction tag) {
-  addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+  return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
     LLVM_DEBUG(DBGS() << "output operand #" << position
                       << " is populated by a single-combiner reduction");
     int64_t updatedPosition = position;
@@ -1066,16 +1027,14 @@
                           combinerOps) &&
            llvm::hasSingleElement(combinerOps);
   });
-  return *this;
 }
 
 transform_ext::StructuredOpMatcher &
 transform_ext::StructuredOpMatcher::output(NumEqualsTo num) {
-  addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+  return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
     LLVM_DEBUG(DBGS() << "number of output operands == " << num.value);
     return linalgOp.getNumDpsInits() == num.value;
   });
-  return *this;
 }
 
 //===---------------------------------------------------------------------===//
@@ -1115,7 +1074,7 @@
 transform_ext::StructuredOpMatcher &
 transform_ext::StructuredOpMatcher::singleOpWithCanonicaleArgs(
     StringRef opcode, bool commutative) {
-  addPredicate([=](linalg::LinalgOp linalgOp) {
+  return addPredicate([=](linalg::LinalgOp linalgOp) {
     if (linalgOp.getBlock()->getOperations().size() != 2)
       return false;
     Operation *innerOp = &(*linalgOp.getBlock()->getOperations().begin());
@@ -1148,12 +1107,11 @@
     }
     return true;
   });
-  return *this;
 }
 
 transform_ext::StructuredOpMatcher &
 transform_ext::StructuredOpMatcher::isFloatReciprocal() {
-  addPredicate([=](linalg::LinalgOp linalgOp) {
+  return addPredicate([=](linalg::LinalgOp linalgOp) {
     LLVM_DEBUG(DBGS() << "op region represents a reciprocal operation");
     if (linalgOp.getBlock()->getOperations().size() != 2)
       return false;
@@ -1174,12 +1132,11 @@
       return false;
     return true;
   });
-  return *this;
 }
 
 transform_ext::StructuredOpMatcher &
 transform_ext::StructuredOpMatcher::passThroughOp() {
-  addPredicate([=](linalg::LinalgOp linalgOp) {
+  return addPredicate([=](linalg::LinalgOp linalgOp) {
     if (linalgOp.getBlock()->getOperations().size() != 1)
       return false;
     Operation *yieldOp = linalgOp.getBlock()->getTerminator();
@@ -1191,7 +1148,6 @@
     }
     return true;
   });
-  return *this;
 }
 
 void transform_ext::detail::debugOutputForConcreteOpMatcherConstructor(
@@ -1205,51 +1161,47 @@
 
 transform_ext::TensorPadOpMatcher &
 transform_ext::TensorPadOpMatcher::low(ArrayRef<int64_t> sizes) {
-  addPredicate([=](tensor::PadOp tensorPad) {
+  return addPredicate([=](tensor::PadOp tensorPad) {
     LLVM_DEBUG({
       DBGS() << "low pad sizes are ";
       llvm::interleaveComma(sizes, llvm::dbgs());
     });
     return tensorPad.getStaticLow() == sizes;
   });
-  return *this;
 }
 
 transform_ext::TensorPadOpMatcher &
 transform_ext::TensorPadOpMatcher::low(AllDims tag, int64_t size) {
-  addPredicate([=](tensor::PadOp tensorPad) {
+  return addPredicate([=](tensor::PadOp tensorPad) {
     LLVM_DEBUG(DBGS() << "all low pad sizes are " << size);
     return llvm::all_of(tensorPad.getStaticLow(),
                         [&](int64_t v) { return v == size; });
   });
-  return *this;
 }
 
 transform_ext::TensorPadOpMatcher &
 transform_ext::TensorPadOpMatcher::high(ArrayRef<int64_t> sizes) {
-  addPredicate([=](tensor::PadOp tensorPad) {
+  return addPredicate([=](tensor::PadOp tensorPad) {
     LLVM_DEBUG({
       DBGS() << "high pad sizes are ";
       llvm::interleaveComma(sizes, llvm::dbgs());
     });
     return tensorPad.getStaticHigh() == sizes;
   });
-  return *this;
 }
 
 transform_ext::TensorPadOpMatcher &
 transform_ext::TensorPadOpMatcher::high(AllDims tag, int64_t size) {
-  addPredicate([=](tensor::PadOp tensorPad) {
+  return addPredicate([=](tensor::PadOp tensorPad) {
     LLVM_DEBUG(DBGS() << "all high pad sizes are " << size);
     return llvm::all_of(tensorPad.getStaticHigh(),
                         [&](int64_t v) { return v == size; });
   });
-  return *this;
 }
 
 transform_ext::TensorPadOpMatcher &
 transform_ext::TensorPadOpMatcher::yieldsExternalValue() {
-  addPredicate([=](tensor::PadOp tensorPad) {
+  return addPredicate([=](tensor::PadOp tensorPad) {
     LLVM_DEBUG(DBGS() << "pad body yields an externally-defined value");
     Block *body = tensorPad.getBody();
     if (!llvm::hasSingleElement(*body))
@@ -1260,7 +1212,6 @@
                           return !arg || arg.getOwner() != body;
                         });
   });
-  return *this;
 }
 
 //===---------------------------------------------------------------------===//
@@ -1678,9 +1629,9 @@
       .dim(transform_ext::AllDims(), transform_ext::CaptureDims(captures.dims))
       .elementType(CaptureElementType(captures.elementType));
   auto &opMatcher = transform_ext::m_tensorPad(context)
+                        .result(0, value)
                         .low(AllDims(), 0)
-                        .yieldsExternalValue()
-                        .result(0, value);
+                        .yieldsExternalValue();
   if (mustMatchEntireFunc)
     opMatcher = opMatcher.allTilableOpsCaptured<func::FuncOp>();
   padCapture = &opMatcher;