Clean up CapturingOpMatcher and derived classes, NFC (#14022)
* Use CRTP ConcreteOpMatcher and make it forward base class methods.
* Make StructuredOpMatcher derive ConcreteOpMatcher
* Remove duplicated functionality
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h
index c59d659..8f98e30 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h
@@ -354,11 +354,6 @@
// Constraints on adjacent ops.
//===-------------------------------------------------------------------===//
- /// Checks that `matchers` captured all tilable ops nested in `parent` except
- /// for `linalgOp`. This is an implementation detail of allTilableOpsCaptured.
- static bool checkAllTilableMatched(Operation *parent, Operation *op,
- ArrayRef<CapturingOpMatcher *> matchers);
-
/// Adds a predicate checking that all ops implementing TilingInterface in the
/// parent of the given type (e.g., a function or a module) were matched by
/// this or nested matchers. This is useful to ensure that the matcher covered
@@ -422,6 +417,11 @@
/// A list of additional conditions for the operation to match.
SmallVector<PredicateFn> predicates;
+ /// Checks that `matchers` captured all tilable ops nested in `parent` except
+ /// for `linalgOp`. This is an implementation detail of allTilableOpsCaptured.
+ static bool checkAllTilableMatched(Operation *parent, Operation *op,
+ ArrayRef<CapturingOpMatcher *> matchers);
+
/// Creates a matcher for an operation with one of the given types.
template <typename... OpType>
static CapturingOpMatcher create() {
@@ -448,51 +448,111 @@
/// with debug includes, and ConcreteOpMatcher is a class template that can only
/// reside in the header.
void debugOutputForConcreteOpMatcherConstructor(StringRef name);
-
-/// Specialization hook that returns an op description used in ConcreteOpMatcher
-/// debug output. By default, returns the operation name. This class template
-/// can be specialized for a better name and must be specialized for interfaces
-/// that don't have an operation name.
-template <typename OpTy>
-struct DebugOpKindDescription {
- static StringRef get() { return OpTy::getOperationName(); }
-};
} // namespace detail
/// Base class for matchers that match a specific op. Adds an initial predicate
/// checking if the op is indeed of the specified kind.
/// Derived classes specializing this for op interfaces MUST also define a
/// specialization of DebugOpKindDescription.
-// TODO: use traits instead of inheritance to inject behavior and avoid
-// unintended upcasting when calling parent methods in the fluent API.
-template <typename OpTy>
+template <typename Derived, typename OpTy>
class ConcreteOpMatcher : public CapturingOpMatcher {
protected:
+ using Base = ConcreteOpMatcher;
+
+ static StringRef getConcreteOpDescription() {
+ return OpTy::getOperationName();
+ }
+
/// Adds a predicate checking if the op is of the OpTy kind.
ConcreteOpMatcher() {
CapturingOpMatcher::addPredicate([](Operation *op) {
detail::debugOutputForConcreteOpMatcherConstructor(
- detail::DebugOpKindDescription<OpTy>::get());
+ Derived::getConcreteOpDescription());
return isa<OpTy>(op);
});
}
/// Adds a predicate for the matched operation to satisfy.
template <typename FnTy>
- void addPredicate(FnTy &&predicate) {
+ Derived &addPredicate(FnTy &&predicate) {
// Dispatch to the callback.
CapturingOpMatcher::addPredicate(
[inner = std::move(predicate)](Operation *op) {
return inner(cast<OpTy>(op));
});
+ return static_cast<Derived &>(*this);
+ }
+
+public:
+ /// Adds alternative paths for predicates. In practice, this is just a
+ /// predicate that is satisfied when either the first or the second matcher is
+ /// satisfied. The alternative satisfaction is eager and short-cutting, i.e.,
+ /// the second alternative will not be processed, and therefore will not
+ /// capture values, if the first alternative succeeded.
+ Derived &alternatives(CapturingOpMatcher &first, CapturingOpMatcher &second) {
+ return static_cast<Derived &>(
+ CapturingOpMatcher::alternatives(first, second));
+ }
+
+ /// Adds a predicate checking that all ops implementing TilingInterface in the
+ /// parent of the given type (e.g., a function or a module) were matched by
+ /// this or nested matchers. This is useful to ensure that the matcher covered
+ /// the entire parent region, not just a parent of it. This predicate **must**
+ /// be added *after* all the other predicates that capture.
+ template <typename ParentTy>
+ Derived &allTilableOpsCaptured() {
+ return static_cast<Derived &>(
+ CapturingOpMatcher::allTilableOpsCaptured<ParentTy>());
+ }
+
+ //-------------------------------------------------------------------------//
+ // Predicates for operands and results.
+ //-------------------------------------------------------------------------//
+
+ /// Adds a predicate checking that the operation has exactly the given number
+ /// of operands.
+ Derived &operand(NumEqualsTo num) {
+ return static_cast<Derived &>(CapturingOpMatcher::operand(num));
+ }
+
+ /// Adds a predicate checking that the `pos`-th operand of the operation is
+ /// defined by an operation that satisfies the given matcher.
+ Derived &operand(int64_t pos, CapturingOpMatcher &nested) {
+ return static_cast<Derived &>(CapturingOpMatcher::operand(pos, nested));
+ }
+
+ /// Adds a predicate checking that the `pos`-th operand of the operation
+ /// satisfies the given value matcher.
+ Derived &operand(int64_t pos, CapturingValueMatcher &nested) {
+ return static_cast<Derived &>(CapturingOpMatcher::operand(pos, nested));
+ }
+
+ /// Adds a predicate checking that the `pos`-th operand of the operation is
+ /// defined by `arith.constant` with the value 1.0.
+ // TODO: better matching for attributes.
+ Derived &operand(int64_t pos, ConstantFloatOne c) {
+ return static_cast<Derived &>(CapturingOpMatcher::operand(pos, c));
+ }
+
+ /// Adds a predicate checking that the operation has exactly the given number
+ /// of results.
+ Derived &result(NumEqualsTo num) {
+ return static_cast<Derived &>(CapturingOpMatcher::result(num));
+ }
+
+ /// Adds a predicate checking that the `pos`-th result of the operation
+ /// satisfies the given value matcher.
+ Derived &result(int64_t pos, CapturingValueMatcher &nested) {
+ return static_cast<Derived &>(CapturingOpMatcher::result(pos, nested));
}
};
/// Matcher for the `tensor.pad` operation.
-class TensorPadOpMatcher : public ConcreteOpMatcher<tensor::PadOp> {
+class TensorPadOpMatcher
+ : public ConcreteOpMatcher<TensorPadOpMatcher, tensor::PadOp> {
friend class MatcherContext;
- TensorPadOpMatcher() : ConcreteOpMatcher<tensor::PadOp>() {}
+ TensorPadOpMatcher() = default;
public:
/// Adds a predicate checking that the low padding sizes are exactly the given
@@ -534,14 +594,18 @@
CapturingOpMatcher::create<OpTy...>());
}
-/// Matcher for structured aka Linalg operations. Extensions must follow the
-/// same conditions as the base class.
-class StructuredOpMatcher : public CapturingOpMatcher {
+/// Matcher for structured aka Linalg operations.
+class StructuredOpMatcher
+ : public ConcreteOpMatcher<StructuredOpMatcher, linalg::LinalgOp> {
friend class MatcherContext;
- StructuredOpMatcher();
+ StructuredOpMatcher() = default;
public:
+ static StringRef getConcreteOpDescription() {
+ return "linalg interface implementation";
+ }
+
/// Creates a matcher for a structured operation with one of the given types.
template <typename... OpType>
static StructuredOpMatcher create() {
@@ -561,7 +625,6 @@
//===-------------------------------------------------------------------===//
/// Adds a predicate checking that the given rank must be greater than some
/// constant value.
- // TODO: Base class, derived class and proper API.
StructuredOpMatcher &rank(NumGreaterEqualTo minRank);
StructuredOpMatcher &rank(NumLowerEqualTo maxRank);
@@ -688,27 +751,6 @@
StructuredOpMatcher &input(int64_t position, ConstantFloatZero);
//===-------------------------------------------------------------------===//
- // Constraints on adjacent ops.
- //===-------------------------------------------------------------------===//
-
- /// Adds a predicate checking that all ops implementing TilingInterface in the
- /// parent of the given type (e.g., a function or a module) were matched by
- /// this or nested matchers. This is useful to ensure that the matcher covered
- /// the entire parent region, not just a parent of it. This predicate **must**
- /// be added *after* all the other predicates that capture.
- template <typename OpTy>
- StructuredOpMatcher &allTilableOpsCaptured() {
- SmallVector<CapturingOpMatcher *> copy;
- copy.push_back(this);
- getAllNested(copy);
- addPredicate([copy = std::move(copy)](linalg::LinalgOp linalgOp) {
- Operation *parent = linalgOp->getParentOfType<OpTy>();
- return checkAllTilableMatched(parent, linalgOp, copy);
- });
- return *this;
- }
-
- //===-------------------------------------------------------------------===//
// Constraints on output operands.
//===-------------------------------------------------------------------===//
@@ -822,17 +864,6 @@
StructuredOpMatcher &passThroughOp();
private:
- /// Adds a predicate for the matched operation to satisfy.
- void addPredicate(std::function<bool(linalg::LinalgOp)> predicate) {
- // Check that the operation implements the LinalgOp interface and dispatch
- // to the predicate.
- CapturingOpMatcher::addPredicate(
- [inner = std::move(predicate)](Operation *op) {
- auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
- return linalgOp && inner(linalgOp);
- });
- }
-
/// Non-template implementations of nested predicate builders for inputs,
/// outputs and results. Should not be called directly.
void addInputMatcher(int64_t position,
diff --git a/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp b/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp
index 4995395..56de39a 100644
--- a/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp
@@ -406,36 +406,23 @@
}
//===---------------------------------------------------------------------===//
-// StructuredOpMatcher and friends.
-//===---------------------------------------------------------------------===//
-
-transform_ext::StructuredOpMatcher::StructuredOpMatcher() {
- addPredicate([](Operation *op) {
- LLVM_DEBUG(DBGS() << "is a structured op");
- return isa<linalg::LinalgOp>(op);
- });
-}
-
-//===---------------------------------------------------------------------===//
// Constraints on op rank and dims.
//===---------------------------------------------------------------------===//
transform_ext::StructuredOpMatcher &
transform_ext::StructuredOpMatcher::rank(NumGreaterEqualTo minRank) {
- addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+ return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
LLVM_DEBUG(DBGS() << "rank >= " << minRank.value);
return linalgOp.getNumLoops() >= minRank.value;
});
- return *this;
}
transform_ext::StructuredOpMatcher &
transform_ext::StructuredOpMatcher::rank(NumLowerEqualTo maxRank) {
- addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+ return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
LLVM_DEBUG(DBGS() << "rank <= " << maxRank.value);
return linalgOp.getNumLoops() <= maxRank.value;
});
- return *this;
}
StringRef stringifyShapeKind(transform_ext::ShapeKind kind) {
@@ -451,8 +438,8 @@
transform_ext::StructuredOpMatcher &
transform_ext::StructuredOpMatcher::dim(SmallVector<int64_t> &&dimensions,
ShapeKind kind) {
- addPredicate([dimensions = std::move(dimensions),
- kind](linalg::LinalgOp linalgOp) -> bool {
+ return addPredicate([dimensions = std::move(dimensions),
+ kind](linalg::LinalgOp linalgOp) -> bool {
LLVM_DEBUG(DBGS() << "dimensions [";
llvm::interleaveComma(dimensions, llvm::dbgs());
llvm::dbgs() << "] are " << stringifyShapeKind(kind));
@@ -469,26 +456,24 @@
}
return true;
});
- return *this;
}
transform_ext::StructuredOpMatcher &
transform_ext::StructuredOpMatcher::dim(AllDims tag, ShapeKind kind) {
- addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+ return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
LLVM_DEBUG(DBGS() << "all dimensions are " << stringifyShapeKind(kind));
SmallVector<int64_t> shape = linalgOp.getStaticLoopRanges();
return llvm::all_of(shape, [=](int64_t dimension) {
return ShapedType::isDynamic(dimension) ^ (kind == ShapeKind::Static);
});
});
- return *this;
}
transform_ext::StructuredOpMatcher &
transform_ext::StructuredOpMatcher::dim(SmallVector<int64_t> &&dimensions,
utils::IteratorType kind) {
- addPredicate([dimensions = std::move(dimensions),
- kind](linalg::LinalgOp linalgOp) -> bool {
+ return addPredicate([dimensions = std::move(dimensions),
+ kind](linalg::LinalgOp linalgOp) -> bool {
LLVM_DEBUG(DBGS() << "dimensions [";
llvm::interleaveComma(dimensions, llvm::dbgs());
llvm::dbgs() << "] are " << utils::stringifyIteratorType(kind));
@@ -506,7 +491,6 @@
}
return true;
});
- return *this;
}
transform_ext::StructuredOpMatcher &
transform_ext::StructuredOpMatcher::dim(AllDims tag, utils::IteratorType kind) {
@@ -516,8 +500,8 @@
transform_ext::StructuredOpMatcher &
transform_ext::StructuredOpMatcher::dim(AllDimsExcept &&dims,
utils::IteratorType kind) {
- addPredicate([dimensions = std::move(dims),
- kind](linalg::LinalgOp linalgOp) -> bool {
+ return addPredicate([dimensions = std::move(dims),
+ kind](linalg::LinalgOp linalgOp) -> bool {
LLVM_DEBUG(DBGS() << "all dimensions except [";
llvm::interleaveComma(dimensions.getExcluded(), llvm::dbgs());
llvm::dbgs() << "] are " << utils::stringifyIteratorType(kind));
@@ -537,13 +521,12 @@
}
return true;
});
- return *this;
}
transform_ext::StructuredOpMatcher &
transform_ext::StructuredOpMatcher::dim(int64_t dimension,
DivisibleBy divisibleBy) {
- addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+ return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
LLVM_DEBUG(DBGS() << "dimension " << dimension << " is divisible by "
<< divisibleBy.value);
int64_t rank = linalgOp.getNumLoops();
@@ -555,7 +538,6 @@
int64_t size = linalgOp.getStaticLoopRanges()[transformedDimension];
return !ShapedType::isDynamic(size) && (size % divisibleBy.value == 0);
});
- return *this;
}
//===---------------------------------------------------------------------===//
@@ -563,17 +545,16 @@
//===---------------------------------------------------------------------===//
transform_ext::StructuredOpMatcher &
transform_ext::StructuredOpMatcher::rank(CaptureRank capture) {
- addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+ return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
LLVM_DEBUG(DBGS() << "capture rank");
capture.value = linalgOp.getNumLoops();
return true;
});
- return *this;
}
transform_ext::StructuredOpMatcher &
transform_ext::StructuredOpMatcher::dim(int64_t dimension, CaptureDim capture) {
- addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+ return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
LLVM_DEBUG(DBGS() << "capture dimension");
int64_t rank = linalgOp.getNumLoops();
int64_t transformedDimension =
@@ -584,22 +565,20 @@
capture.value = linalgOp.getStaticLoopRanges()[transformedDimension];
return true;
});
- return *this;
}
transform_ext::StructuredOpMatcher &
transform_ext::StructuredOpMatcher::dim(AllDims tag, CaptureDims captures) {
- addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+ return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
LLVM_DEBUG(DBGS() << "capture all dimensions");
captures.value = linalgOp.getStaticLoopRanges();
return true;
});
- return *this;
}
transform_ext::StructuredOpMatcher &
transform_ext::StructuredOpMatcher::convolutionDims(CaptureConvDims convDims) {
- addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+ return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
LLVM_DEBUG(DBGS() << "capture convolution dimensions\n");
StringRef convMessage = linalg::detail::getMatchConvolutionMessage(
mlir::linalg::detail::isConvolutionInterfaceImpl(linalgOp,
@@ -610,7 +589,6 @@
<< convMessage << "\n");
return false;
});
- return *this;
}
transform_ext::StructuredOpMatcher::StructuredOpMatcher(
@@ -684,7 +662,7 @@
transform_ext::StructuredOpMatcher &
transform_ext::StructuredOpMatcher::input(AllOperands tag, IsPermutation) {
- addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+ return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
LLVM_DEBUG(DBGS() << "all input operands have permutation maps");
// all_of with a lambda requires const-casting dance, so using a loop.
for (OpOperand *operand : linalgOp.getDpsInputOperands()) {
@@ -693,13 +671,12 @@
}
return true;
});
- return *this;
}
transform_ext::StructuredOpMatcher &
transform_ext::StructuredOpMatcher::input(AllOperands tag,
IsProjectedPermutation) {
- addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+ return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
LLVM_DEBUG(DBGS() << "all input operands have projected permutation maps");
// all_of with a lambda requires const-casting dance, so using a loop.
for (OpOperand *operand : linalgOp.getDpsInputOperands()) {
@@ -708,7 +685,6 @@
}
return true;
});
- return *this;
}
/// Helper to check if the map is an identity map with a projected dim.
@@ -742,7 +718,7 @@
transform_ext::StructuredOpMatcher &
transform_ext::StructuredOpMatcher::input(SmallVector<int64_t> &&positions,
IsProjected dim) {
- addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+ return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
LLVM_DEBUG(DBGS() << "operands ";
llvm::interleaveComma(positions, llvm::dbgs());
llvm::dbgs() << " have a permutation maps with " << dim.value
@@ -760,12 +736,11 @@
}
return true;
});
- return *this;
}
transform_ext::StructuredOpMatcher &
transform_ext::StructuredOpMatcher::input(AllOperands tag, IsIdentity) {
- addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+ return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
LLVM_DEBUG(DBGS() << "all input operands have identity maps");
// all_of with a lambda requires const-casting dance, so using a loop.
for (OpOperand *operand : linalgOp.getDpsInputOperands()) {
@@ -774,13 +749,12 @@
}
return true;
});
- return *this;
}
transform_ext::StructuredOpMatcher &
transform_ext::StructuredOpMatcher::input(SmallVector<int64_t> &&positions,
IsIdentity) {
- addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+ return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
LLVM_DEBUG(DBGS() << "input operands ";
llvm::interleaveComma(positions, llvm::dbgs());
llvm::dbgs() << " have identity maps");
@@ -795,13 +769,12 @@
}
return true;
});
- return *this;
}
transform_ext::StructuredOpMatcher &
transform_ext::StructuredOpMatcher::input(int64_t position,
ElementTypeBitWidth width) {
- addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+ return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
LLVM_DEBUG(DBGS() << "input operand #" << position
<< " has elemental type with bit width " << width.value);
int64_t updatedPosition = position;
@@ -814,13 +787,12 @@
return shapedType && shapedType.getElementType().isIntOrFloat() &&
shapedType.getElementType().getIntOrFloatBitWidth() == width.value;
});
- return *this;
}
transform_ext::StructuredOpMatcher &
transform_ext::StructuredOpMatcher::input(int64_t position,
CaptureElementTypeBitWidth width) {
- addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+ return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
LLVM_DEBUG(DBGS() << "input operand #" << position << " capture bitwidth");
int64_t updatedPosition = position;
if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInputs()))
@@ -834,13 +806,12 @@
width.value = shapedType.getElementType().getIntOrFloatBitWidth();
return true;
});
- return *this;
}
transform_ext::StructuredOpMatcher &
transform_ext::StructuredOpMatcher::input(int64_t position,
CaptureElementType elem) {
- addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+ return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
LLVM_DEBUG(DBGS() << "input operand #" << position
<< " capture element type");
int64_t updatedPosition = position;
@@ -857,16 +828,14 @@
elem.value = shapedType.getElementType();
return true;
});
- return *this;
}
transform_ext::StructuredOpMatcher &
transform_ext::StructuredOpMatcher::input(NumEqualsTo num) {
- addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+ return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
LLVM_DEBUG(DBGS() << "number of input operands == " << num.value);
return linalgOp.getNumDpsInputs() == num.value;
});
- return *this;
}
transform_ext::StructuredOpMatcher &
@@ -884,7 +853,7 @@
transform_ext::StructuredOpMatcher &transform_ext::StructuredOpMatcher::input(
int64_t position, std::function<bool(llvm::APFloat)> floatValueFn) {
- addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+ return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
LLVM_DEBUG(DBGS() << "input operand #" << position
<< " is a special floating point constant");
int64_t updatedPosition = position;
@@ -897,7 +866,6 @@
return false;
return floatValueFn(cstOp.value());
});
- return *this;
}
//===---------------------------------------------------------------------===//
@@ -933,7 +901,7 @@
transform_ext::StructuredOpMatcher &
transform_ext::StructuredOpMatcher::output(AllOperands tag, IsPermutation) {
- addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+ return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
LLVM_DEBUG(DBGS() << "all output operands have permutation maps");
for (OpOperand *operand : linalgOp.getDpsInitOperands()) {
if (!linalgOp.getMatchingIndexingMap(operand).isPermutation())
@@ -941,13 +909,12 @@
}
return true;
});
- return *this;
}
transform_ext::StructuredOpMatcher &
transform_ext::StructuredOpMatcher::output(AllOperands tag,
IsProjectedPermutation) {
- addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+ return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
LLVM_DEBUG(DBGS() << "all output operands have projected permutation maps");
for (OpOperand *operand : linalgOp.getDpsInitOperands()) {
if (!linalgOp.getMatchingIndexingMap(operand).isProjectedPermutation())
@@ -955,12 +922,11 @@
}
return true;
});
- return *this;
}
transform_ext::StructuredOpMatcher &
transform_ext::StructuredOpMatcher::output(AllOperands tag, IsProjected dim) {
- addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+ return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
LLVM_DEBUG(DBGS() << "all output operands have a maps with projected");
int64_t updatedDim = dim.value;
if (!makeValidPositiveIndex(updatedDim, linalgOp.getNumLoops()))
@@ -972,12 +938,11 @@
}
return true;
});
- return *this;
}
transform_ext::StructuredOpMatcher &
transform_ext::StructuredOpMatcher::output(AllOperands tag, IsIdentity) {
- addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+ return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
LLVM_DEBUG(DBGS() << "all output operands have identity permutation maps");
for (OpOperand *operand : linalgOp.getDpsInitOperands()) {
if (!linalgOp.getMatchingIndexingMap(operand).isIdentity())
@@ -985,13 +950,12 @@
}
return true;
});
- return *this;
}
transform_ext::StructuredOpMatcher &
transform_ext::StructuredOpMatcher::output(int64_t position,
ElementTypeBitWidth width) {
- addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+ return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
LLVM_DEBUG(DBGS() << "output operand #" << position
<< " has elemental type with bit width " << width.value);
int64_t updatedPosition = position;
@@ -1004,13 +968,12 @@
return shapedType && shapedType.getElementType().isIntOrFloat() &&
shapedType.getElementType().getIntOrFloatBitWidth() == width.value;
});
- return *this;
}
transform_ext::StructuredOpMatcher &
transform_ext::StructuredOpMatcher::output(int64_t position,
CaptureElementTypeBitWidth width) {
- addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+ return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
LLVM_DEBUG(DBGS() << "output operand #" << position << " capture bitwidth");
int64_t updatedPosition = position;
if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInits()))
@@ -1026,13 +989,12 @@
width.value = shapedType.getElementType().getIntOrFloatBitWidth();
return true;
});
- return *this;
}
transform_ext::StructuredOpMatcher &
transform_ext::StructuredOpMatcher::output(int64_t position,
CaptureElementType elem) {
- addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+ return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
LLVM_DEBUG(DBGS() << "output operand #" << position
<< " capture element type");
int64_t updatedPosition = position;
@@ -1049,13 +1011,12 @@
elem.value = shapedType.getElementType();
return true;
});
- return *this;
}
transform_ext::StructuredOpMatcher &
transform_ext::StructuredOpMatcher::output(int64_t position,
SingleCombinerReduction tag) {
- addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+ return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
LLVM_DEBUG(DBGS() << "output operand #" << position
<< " is populated by a single-combiner reduction");
int64_t updatedPosition = position;
@@ -1066,16 +1027,14 @@
combinerOps) &&
llvm::hasSingleElement(combinerOps);
});
- return *this;
}
transform_ext::StructuredOpMatcher &
transform_ext::StructuredOpMatcher::output(NumEqualsTo num) {
- addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+ return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
LLVM_DEBUG(DBGS() << "number of output operands == " << num.value);
return linalgOp.getNumDpsInits() == num.value;
});
- return *this;
}
//===---------------------------------------------------------------------===//
@@ -1115,7 +1074,7 @@
transform_ext::StructuredOpMatcher &
transform_ext::StructuredOpMatcher::singleOpWithCanonicaleArgs(
StringRef opcode, bool commutative) {
- addPredicate([=](linalg::LinalgOp linalgOp) {
+ return addPredicate([=](linalg::LinalgOp linalgOp) {
if (linalgOp.getBlock()->getOperations().size() != 2)
return false;
Operation *innerOp = &(*linalgOp.getBlock()->getOperations().begin());
@@ -1148,12 +1107,11 @@
}
return true;
});
- return *this;
}
transform_ext::StructuredOpMatcher &
transform_ext::StructuredOpMatcher::isFloatReciprocal() {
- addPredicate([=](linalg::LinalgOp linalgOp) {
+ return addPredicate([=](linalg::LinalgOp linalgOp) {
LLVM_DEBUG(DBGS() << "op region represents a reciprocal operation");
if (linalgOp.getBlock()->getOperations().size() != 2)
return false;
@@ -1174,12 +1132,11 @@
return false;
return true;
});
- return *this;
}
transform_ext::StructuredOpMatcher &
transform_ext::StructuredOpMatcher::passThroughOp() {
- addPredicate([=](linalg::LinalgOp linalgOp) {
+ return addPredicate([=](linalg::LinalgOp linalgOp) {
if (linalgOp.getBlock()->getOperations().size() != 1)
return false;
Operation *yieldOp = linalgOp.getBlock()->getTerminator();
@@ -1191,7 +1148,6 @@
}
return true;
});
- return *this;
}
void transform_ext::detail::debugOutputForConcreteOpMatcherConstructor(
@@ -1205,51 +1161,47 @@
transform_ext::TensorPadOpMatcher &
transform_ext::TensorPadOpMatcher::low(ArrayRef<int64_t> sizes) {
- addPredicate([=](tensor::PadOp tensorPad) {
+ return addPredicate([=](tensor::PadOp tensorPad) {
LLVM_DEBUG({
DBGS() << "low pad sizes are ";
llvm::interleaveComma(sizes, llvm::dbgs());
});
return tensorPad.getStaticLow() == sizes;
});
- return *this;
}
transform_ext::TensorPadOpMatcher &
transform_ext::TensorPadOpMatcher::low(AllDims tag, int64_t size) {
- addPredicate([=](tensor::PadOp tensorPad) {
+ return addPredicate([=](tensor::PadOp tensorPad) {
LLVM_DEBUG(DBGS() << "all low pad sizes are " << size);
return llvm::all_of(tensorPad.getStaticLow(),
[&](int64_t v) { return v == size; });
});
- return *this;
}
transform_ext::TensorPadOpMatcher &
transform_ext::TensorPadOpMatcher::high(ArrayRef<int64_t> sizes) {
- addPredicate([=](tensor::PadOp tensorPad) {
+ return addPredicate([=](tensor::PadOp tensorPad) {
LLVM_DEBUG({
DBGS() << "high pad sizes are ";
llvm::interleaveComma(sizes, llvm::dbgs());
});
return tensorPad.getStaticHigh() == sizes;
});
- return *this;
}
transform_ext::TensorPadOpMatcher &
transform_ext::TensorPadOpMatcher::high(AllDims tag, int64_t size) {
- addPredicate([=](tensor::PadOp tensorPad) {
+ return addPredicate([=](tensor::PadOp tensorPad) {
LLVM_DEBUG(DBGS() << "all high pad sizes are " << size);
return llvm::all_of(tensorPad.getStaticHigh(),
[&](int64_t v) { return v == size; });
});
- return *this;
}
transform_ext::TensorPadOpMatcher &
transform_ext::TensorPadOpMatcher::yieldsExternalValue() {
- addPredicate([=](tensor::PadOp tensorPad) {
+ return addPredicate([=](tensor::PadOp tensorPad) {
LLVM_DEBUG(DBGS() << "pad body yields an externally-defined value");
Block *body = tensorPad.getBody();
if (!llvm::hasSingleElement(*body))
@@ -1260,7 +1212,6 @@
return !arg || arg.getOwner() != body;
});
});
- return *this;
}
//===---------------------------------------------------------------------===//
@@ -1678,9 +1629,9 @@
.dim(transform_ext::AllDims(), transform_ext::CaptureDims(captures.dims))
.elementType(CaptureElementType(captures.elementType));
auto &opMatcher = transform_ext::m_tensorPad(context)
+ .result(0, value)
.low(AllDims(), 0)
- .yieldsExternalValue()
- .result(0, value);
+ .yieldsExternalValue();
if (mustMatchEntireFunc)
opMatcher = opMatcher.allTilableOpsCaptured<func::FuncOp>();
padCapture = &opMatcher;