blob: 8f98e30d264716365265cc8b44a69483cf0eb97b [file] [log] [blame]
// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#ifndef IREE_COMPILER_CODEGEN_COMMON_TRANSFORMEXTENSIONS_TRANSFORMMATCHERS_H_
#define IREE_COMPILER_CODEGEN_COMMON_TRANSFORMEXTENSIONS_TRANSFORMMATCHERS_H_
#include <cstddef>
#include <cstdint>
#include <functional>
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/IR/Matchers.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringMap.h"
namespace mlir {
namespace transform_ext {
//===---------------------------------------------------------------------===//
// StructuredOpMatcher and predicates.
//===---------------------------------------------------------------------===//
class StructuredOpMatcher;
class MatcherContext;
StructuredOpMatcher &m_StructuredOp(MatcherContext &);
/// A tag indicating the shape being static or dynamic, for use with the
/// structured op matcher.
enum class ShapeKind { Static, Dynamic };
/// A placeholder indicating the structured op matcher to check the predicate
/// for all dimensions.
struct AllDims {};
/// A predicate indicating the structured op matcher to check the predicate for
/// all dimensions except the specified ones.
struct AllDimsExcept {
explicit AllDimsExcept(std::initializer_list<int64_t> range) {
llvm::append_range(exceptions, range);
}
ArrayRef<int64_t> getExcluded() const { return llvm::ArrayRef(exceptions); }
private:
SmallVector<int64_t> exceptions;
};
/// A placeholder indicating the structured op matcher to check the predicate
/// for all operands of the relevant kind.
struct AllOperands {};
/// Base class for single-value captures. Concrete captures should inherit this
/// and forward the constructor via `using Base::Base`.
template <typename T>
struct CaptureStaticValue {
using Base = CaptureStaticValue<T>;
explicit CaptureStaticValue(T &value) : value(value) {}
T &value;
};
/// Captures the (static) size of the dimension.
struct CaptureDim : public CaptureStaticValue<int64_t> {
using Base::Base;
};
/// Captures the (static) sizes of multiple dimensions.
struct CaptureDims : public CaptureStaticValue<SmallVector<int64_t>> {
using Base::Base;
};
/// Captures the convolution dimensions of the target operation.
struct CaptureConvDims
: public CaptureStaticValue<mlir::linalg::detail::ConvolutionDimensions> {
using Base::Base;
};
/// Captures the rank of the operation.
struct CaptureRank : public CaptureStaticValue<int64_t> {
using Base::Base;
};
/// Captures the bitwidth of an element type.
struct CaptureElementTypeBitWidth : public CaptureStaticValue<int64_t> {
using Base::Base;
};
/// Captures element element type.
struct CaptureElementType : public CaptureStaticValue<Type> {
using Base::Base;
};
template <typename T = Attribute>
struct CaptureAttribute : public CaptureStaticValue<T> {
static_assert(std::is_base_of_v<Attribute, T>,
"can only capture a subclass of Attribute");
using CaptureStaticValue<T>::CaptureStaticValue;
};
/// A tag indicating to look for any user of the operation's result that would
/// satisfy the predicate.
struct HasAnyUse {};
/// Base class for predicate parameters that can be described with the single
/// value. Concrete predicate parameters should inherit this and forward the
/// constructor via `using Base::Base`.
template <typename T>
struct SingleValuePredicateParam {
using Base = SingleValuePredicateParam<T>;
explicit SingleValuePredicateParam(T value) : value(value) {}
const T value;
};
/// Indicates that the dimension must be divisible by the given value.
struct DivisibleBy : public SingleValuePredicateParam<int64_t> {
using Base::Base;
};
/// Indicates that the number of entities must be equal to the given value.
struct NumEqualsTo : public SingleValuePredicateParam<size_t> {
using Base::Base;
};
/// Indicates that the number of entities must be greater than the given value.
struct NumGreaterEqualTo : public SingleValuePredicateParam<size_t> {
using Base::Base;
};
/// Indicates that the number of entities must be greater than the given value.
struct NumLowerEqualTo : public SingleValuePredicateParam<size_t> {
using Base::Base;
};
/// Indicates that the bit width of the elemental type must be equal to the give
/// value.
struct ElementTypeBitWidth : public SingleValuePredicateParam<size_t> {
using Base::Base;
};
/// Predicate tag indicating that the affine map is a permutation.
struct IsPermutation {};
/// Predicate tag indicating that the affine map is a projected permutation.
struct IsProjectedPermutation {};
/// Predicate tag indicating that the affine map is a projection of given
/// dimension.
struct IsProjected : public SingleValuePredicateParam<int64_t> {
using Base::Base;
};
/// Predicate tag indicating that the affine map is an identity.
struct IsIdentity {};
/// Predicate tag indicating that the operand is a special float constant.
struct ConstantFloatMinOrMinusInf {};
struct ConstantFloatZero {};
struct ConstantFloatOne {};
/// Indicates that the match optional. The matcher is still expected to run and
/// capture if successful. The parameter can be set to false
struct OptionalMatch : public SingleValuePredicateParam<bool> {
OptionalMatch() : Base(true) {}
explicit OptionalMatch(bool set) : Base(set) {}
};
/// Predicate tag indicating that the reduction is produced by a single combiner
/// operation.
struct SingleCombinerReduction {};
class CapturingOpMatcher;
class CapturingValueMatcher;
/// Base class for capturing matchers that can be owned by the context.
class CapturingMatcherBase {
public:
// Virtual destructor so unique pointers are deallocated correctly.
// TODO: if efficiency is a problem, consider disallowing non-trivial
// destructors for subclasses.
virtual ~CapturingMatcherBase() = default;
protected:
/// Informs the matcher that it has another, nested matcher. Derived classes
/// must call this to keep track of nested matchers for capture resetting
/// purposes.
template <typename T>
void recordNestedMatcher(T &nested) {
if constexpr (std::is_base_of_v<CapturingOpMatcher, T>)
nestedCapturingMatchers.push_back(&nested);
if constexpr (std::is_base_of_v<CapturingValueMatcher, T>)
nestedCapturingValueMatchers.push_back(&nested);
}
/// Appends all nested capturing matchers of a certain kind, excluding this
/// one, to `nested`.
void getAllNested(SmallVectorImpl<CapturingOpMatcher *> &nested);
void
getAllNestedValueMatchers(SmallVectorImpl<CapturingValueMatcher *> &nested);
/// Resets nested capturing matchers but does NOT reset the current one.
void resetCapture();
private:
/// A list of (recursively) nested capturing matchers that should be reset
/// when the current matcher is.
SmallVector<CapturingOpMatcher *, 2> nestedCapturingMatchers;
SmallVector<CapturingValueMatcher *, 2> nestedCapturingValueMatchers;
};
/// A context object holding capturing matchers, must outlive any individual
/// matcher. When matching complex subgraphs, the caller often doesn't care
/// about all intermediate nodes (operations) in the graph and shouldn't need to
/// hold matcher objects for those. These matchers can be created in this
/// context.
class MatcherContext {
public:
/// Create a new matcher of the specified type owned by this context.
template <typename T, typename... Args>
std::enable_if_t<std::is_base_of_v<CapturingMatcherBase, T>, T> &
allocate(Args &&...args) {
// Need to call "new" explicitly as make_unique wouldn't have access to the
// private constructor when this class would.
ownedMatchers.emplace_back(
std::unique_ptr<T>(new T(std::forward<Args>(args)...)));
return *static_cast<T *>(ownedMatchers.back().get());
}
private:
/// Owning list of matchers.
// TODO: If this becomes inefficient, consider something like BumpPtrAllocator
// that derived classes can use to store their members as well.
SmallVector<std::unique_ptr<CapturingMatcherBase>> ownedMatchers;
};
/// Base class for value matchers that capture the matched value. Stores a list
/// of predicates and requires all of them to match for the value to match. Once
/// a value matched, any repeated use just verifies that equality of the value.
class CapturingValueMatcher : public CapturingMatcherBase {
friend class CapturingMatcherBase;
friend class MatcherContext;
using PredicateFn = std::function<bool(Value)>;
public:
/// Resets the captured value to null. This should be called if the same
/// pattern needs to be applied more than once as it may keep captured values
/// for optional nested predicates from the previous application.
void resetCapture() {
captured = nullptr;
CapturingMatcherBase::resetCapture();
}
/// Returns the matched value if the match was successful.
Value getCaptured() const { return captured; }
/// Matches the given value, hook for `matchPattern`.
bool match(Value value);
protected:
CapturingValueMatcher() = default;
/// Adds a predicate to the end of the predicate list for this value matcher.
template <typename Fn>
void addPredicate(Fn &&predicate) {
predicates.emplace_back(std::forward<Fn>(predicate));
}
/// The captured value.
Value captured = nullptr;
private:
/// Additional predicates to be checked on the value.
SmallVector<PredicateFn> predicates;
};
/// Creates a matcher of an arbitrary value.
inline CapturingValueMatcher &m_Value(MatcherContext &context) {
return context.allocate<CapturingValueMatcher>();
}
/// Matcher for typed values whose type implements the `ShapedType` interface.
/// Allows for matching the components of the shaped type such as rank and
/// dimensions.
class ShapedValueMatcher : public CapturingValueMatcher {
friend class MatcherContext;
ShapedValueMatcher();
public:
/// Add an always-succeeding matcher predicate capturing the rank.
ShapedValueMatcher &rank(CaptureRank capture);
/// Add an always-succeeding matcher predicate capturing the size of the
/// dimension identified by the first argument.
ShapedValueMatcher &dim(int64_t dimension, CaptureDim capture);
/// Add an always-succeeding matcher predicate capturing the sizes of all
/// dimensions in order of appearance.
ShapedValueMatcher &dim(AllDims tag, CaptureDims captures);
/// Add an always-succeeding matcher predicate capturing the element type of
/// the value.
ShapedValueMatcher &elementType(CaptureElementType captures);
};
/// Construct a new matcher of a value whose type is a `ShapedType`, owned by
/// the given context.
inline ShapedValueMatcher &m_ShapedValue(MatcherContext &context) {
return context.allocate<ShapedValueMatcher>();
}
/// Matcher for operations with additional predicates attachable through the
/// fluent, a.k.a. chainable, API. Note that public API must *not* accept
/// additional callbacks even; new predicates should be added instead when
/// necessary. Not only this decreases the depth of the callback stack and
/// increases readability, it also allows us to port the matcher to a
/// declarative format using PDL and/or Transform dialect in the future. The
/// latter will become impossible with arbitrary C++ callbacks.
class CapturingOpMatcher : public CapturingMatcherBase {
friend class CapturingMatcherBase;
friend class MatcherContext;
template <typename... OpTy>
friend CapturingOpMatcher &m_Operation(MatcherContext &matcherContext);
public:
using PredicateFn = std::function<bool(Operation *)>;
/// Matches the given operation, hook for `matchPattern`.
bool match(Operation *op);
/// Resets the captured value to null. This should be called if the same
/// pattern needs to be applied more than once as it may keep captured values
/// for optional nested predicates from the previous application.
void resetCapture() {
captured = nullptr;
CapturingMatcherBase::resetCapture();
}
/// Returns the matched operation if the match was successful.
Operation *getCaptured() const { return captured; }
/// Adds alternative paths for predicates. In practice, this is just a
/// predicate that is satisfied when either the first or the second matcher is
/// satisfied. The alternative satisfaction is eager and short-cutting, i.e.,
/// the second alternative will not be processed, and therefore will not
/// capture values, if the first alternative succeeded.
CapturingOpMatcher &alternatives(CapturingOpMatcher &first,
CapturingOpMatcher &second);
//===-------------------------------------------------------------------===//
// Constraints on adjacent ops.
//===-------------------------------------------------------------------===//
/// Adds a predicate checking that all ops implementing TilingInterface in the
/// parent of the given type (e.g., a function or a module) were matched by
/// this or nested matchers. This is useful to ensure that the matcher covered
/// the entire parent region, not just a parent of it. This predicate **must**
/// be added *after* all the other predicates that capture.
template <typename OpTy>
CapturingOpMatcher &allTilableOpsCaptured() {
SmallVector<CapturingOpMatcher *> copy;
copy.push_back(this);
getAllNested(copy);
addPredicate([copy = std::move(copy)](Operation *op) {
Operation *parent = op->getParentOfType<OpTy>();
return checkAllTilableMatched(parent, op, copy);
});
return *this;
}
//-------------------------------------------------------------------------//
// Predicates for operands and results.
//-------------------------------------------------------------------------//
/// Adds a predicate checking that the operation has exactly the given number
/// of operands.
CapturingOpMatcher &operand(NumEqualsTo num);
/// Adds a predicate checking that the `pos`-th operand of the operation is
/// defined by an operation that satisfies the given matcher.
CapturingOpMatcher &operand(int64_t pos, CapturingOpMatcher &nested);
/// Adds a predicate checking that the `pos`-th operand of the operation
/// satisfies the given value matcher.
CapturingOpMatcher &operand(int64_t pos, CapturingValueMatcher &nested);
/// Adds a predicate checking that the `pos`-th operand of the operation is
/// defined by `arith.constant` with the value 1.0.
// TODO: better matching for attributes.
CapturingOpMatcher &operand(int64_t pos, ConstantFloatOne);
/// Adds a predicate checking that the operation has exactly the given number
/// of results.
CapturingOpMatcher &result(NumEqualsTo num);
/// Adds a predicate checking that the `pos`-th result of the operation
/// satisfies the given value matcher.
CapturingOpMatcher &result(int64_t pos, CapturingValueMatcher &nested);
protected:
/// Constructs a default operation matcher accepting any operation.
CapturingOpMatcher() = default;
/// Adds a predicate for the matched operation to satisfy.
template <typename Fn>
void addPredicate(Fn &&predicate) {
predicates.emplace_back(std::forward<Fn>(predicate));
}
/// Produce the debug output for `create` method in a non-templated way.
static void debugOutputForCreate(ArrayRef<StringRef> opNames);
private:
/// A list of additional conditions for the operation to match.
SmallVector<PredicateFn> predicates;
/// Checks that `matchers` captured all tilable ops nested in `parent` except
/// for `linalgOp`. This is an implementation detail of allTilableOpsCaptured.
static bool checkAllTilableMatched(Operation *parent, Operation *op,
ArrayRef<CapturingOpMatcher *> matchers);
/// Creates a matcher for an operation with one of the given types.
template <typename... OpType>
static CapturingOpMatcher create() {
CapturingOpMatcher matcher;
matcher.addPredicate([](Operation *op) {
debugOutputForCreate(ArrayRef<StringRef>{OpType::getOperationName()...});
return isa<OpType...>(op);
});
return matcher;
}
/// Common util for constant matcher.
CapturingOpMatcher &operand(int64_t position,
std::function<bool(llvm::APFloat)> floatValueFn);
protected:
/// Matched value.
Operation *captured = nullptr;
};
namespace detail {
/// Prints the debug output from the ConcreteOpMatcher constructor. The
/// implementation must reside in the C++ file so we don't pollute the header
/// with debug includes, and ConcreteOpMatcher is a class template that can only
/// reside in the header.
void debugOutputForConcreteOpMatcherConstructor(StringRef name);
} // namespace detail
/// Base class for matchers that match a specific op. Adds an initial predicate
/// checking if the op is indeed of the specified kind.
/// Derived classes specializing this for op interfaces MUST also define a
/// specialization of DebugOpKindDescription.
template <typename Derived, typename OpTy>
class ConcreteOpMatcher : public CapturingOpMatcher {
protected:
using Base = ConcreteOpMatcher;
static StringRef getConcreteOpDescription() {
return OpTy::getOperationName();
}
/// Adds a predicate checking if the op is of the OpTy kind.
ConcreteOpMatcher() {
CapturingOpMatcher::addPredicate([](Operation *op) {
detail::debugOutputForConcreteOpMatcherConstructor(
Derived::getConcreteOpDescription());
return isa<OpTy>(op);
});
}
/// Adds a predicate for the matched operation to satisfy.
template <typename FnTy>
Derived &addPredicate(FnTy &&predicate) {
// Dispatch to the callback.
CapturingOpMatcher::addPredicate(
[inner = std::move(predicate)](Operation *op) {
return inner(cast<OpTy>(op));
});
return static_cast<Derived &>(*this);
}
public:
/// Adds alternative paths for predicates. In practice, this is just a
/// predicate that is satisfied when either the first or the second matcher is
/// satisfied. The alternative satisfaction is eager and short-cutting, i.e.,
/// the second alternative will not be processed, and therefore will not
/// capture values, if the first alternative succeeded.
Derived &alternatives(CapturingOpMatcher &first, CapturingOpMatcher &second) {
return static_cast<Derived &>(
CapturingOpMatcher::alternatives(first, second));
}
/// Adds a predicate checking that all ops implementing TilingInterface in the
/// parent of the given type (e.g., a function or a module) were matched by
/// this or nested matchers. This is useful to ensure that the matcher covered
/// the entire parent region, not just a parent of it. This predicate **must**
/// be added *after* all the other predicates that capture.
template <typename ParentTy>
Derived &allTilableOpsCaptured() {
return static_cast<Derived &>(
CapturingOpMatcher::allTilableOpsCaptured<ParentTy>());
}
//-------------------------------------------------------------------------//
// Predicates for operands and results.
//-------------------------------------------------------------------------//
/// Adds a predicate checking that the operation has exactly the given number
/// of operands.
Derived &operand(NumEqualsTo num) {
return static_cast<Derived &>(CapturingOpMatcher::operand(num));
}
/// Adds a predicate checking that the `pos`-th operand of the operation is
/// defined by an operation that satisfies the given matcher.
Derived &operand(int64_t pos, CapturingOpMatcher &nested) {
return static_cast<Derived &>(CapturingOpMatcher::operand(pos, nested));
}
/// Adds a predicate checking that the `pos`-th operand of the operation
/// satisfies the given value matcher.
Derived &operand(int64_t pos, CapturingValueMatcher &nested) {
return static_cast<Derived &>(CapturingOpMatcher::operand(pos, nested));
}
/// Adds a predicate checking that the `pos`-th operand of the operation is
/// defined by `arith.constant` with the value 1.0.
// TODO: better matching for attributes.
Derived &operand(int64_t pos, ConstantFloatOne c) {
return static_cast<Derived &>(CapturingOpMatcher::operand(pos, c));
}
/// Adds a predicate checking that the operation has exactly the given number
/// of results.
Derived &result(NumEqualsTo num) {
return static_cast<Derived &>(CapturingOpMatcher::result(num));
}
/// Adds a predicate checking that the `pos`-th result of the operation
/// satisfies the given value matcher.
Derived &result(int64_t pos, CapturingValueMatcher &nested) {
return static_cast<Derived &>(CapturingOpMatcher::result(pos, nested));
}
};
/// Matcher for the `tensor.pad` operation.
class TensorPadOpMatcher
: public ConcreteOpMatcher<TensorPadOpMatcher, tensor::PadOp> {
friend class MatcherContext;
TensorPadOpMatcher() = default;
public:
/// Adds a predicate checking that the low padding sizes are exactly the given
/// values.
TensorPadOpMatcher &low(ArrayRef<int64_t> sizes);
/// Adds a predicate checking that the low padding sizes for all dimensions
/// are exactly the same given value.
TensorPadOpMatcher &low(AllDims tag, int64_t size);
/// Adds a predicate checking that the high padding sizes for all dimensions
/// are exactly the same given value.
TensorPadOpMatcher &high(ArrayRef<int64_t> sizes);
/// Adds a predicate checking that the high padding sizes for all dimensions
/// are exactly the same given value.
TensorPadOpMatcher &high(AllDims tag, int64_t size);
/// Adds a predicate checking that the body of the pad only yields values
/// defined outside the pad region.
TensorPadOpMatcher &yieldsExternalValue();
};
inline TensorPadOpMatcher &m_tensorPad(MatcherContext &matcherContext) {
return matcherContext.allocate<TensorPadOpMatcher>();
}
/// Creates a default operation matcher in the given context that accepts any
/// operation.
inline CapturingOpMatcher &m_Operation(MatcherContext &matcherContext) {
return matcherContext.allocate<CapturingOpMatcher>();
}
/// Creates an operation matcher in the given context that accepts only
/// operations of the kinds provided as template arguments.
template <typename... OpTy>
inline CapturingOpMatcher &m_Operation(MatcherContext &matcherContext) {
return matcherContext.allocate<CapturingOpMatcher>(
CapturingOpMatcher::create<OpTy...>());
}
/// Matcher for structured aka Linalg operations.
class StructuredOpMatcher
: public ConcreteOpMatcher<StructuredOpMatcher, linalg::LinalgOp> {
friend class MatcherContext;
StructuredOpMatcher() = default;
public:
static StringRef getConcreteOpDescription() {
return "linalg interface implementation";
}
/// Creates a matcher for a structured operation with one of the given types.
template <typename... OpType>
static StructuredOpMatcher create() {
StructuredOpMatcher matcher;
matcher.addPredicate([](Operation *op) {
debugOutputForCreate(ArrayRef<StringRef>{OpType::getOperationName()...});
return isa<linalg::LinalgOp>(op) && isa<OpType...>(op);
});
return matcher;
}
/// Matches a structured operation if either patterns A or B match.
StructuredOpMatcher(StructuredOpMatcher &A, StructuredOpMatcher &B);
//===-------------------------------------------------------------------===//
// Constraints on op rank and dims.
//===-------------------------------------------------------------------===//
/// Adds a predicate checking that the given rank must be greater than some
/// constant value.
StructuredOpMatcher &rank(NumGreaterEqualTo minRank);
StructuredOpMatcher &rank(NumLowerEqualTo maxRank);
/// Adds a predicate checking that the given iteration space dimension is
/// static/dynamic. The dimension index may be negative, in which case
/// dimensions are counted from the last one (i.e. Python-style), or be an
/// AllDims tag, in which case all dimensions are checked. This may be
/// eventually extended to slices and/or lists of dimensions.
StructuredOpMatcher &dim(int64_t dimension, ShapeKind kind) {
return dim(SmallVector<int64_t>{dimension}, kind);
}
StructuredOpMatcher &dim(SmallVector<int64_t> &&dimensions, ShapeKind kind);
StructuredOpMatcher &dim(AllDims tag, ShapeKind kind);
/// Adds a predicate checking that the given iteration space dimension has the
/// given iterator type, e.g., parallel or reduction. The dimension index may
/// be negative, in which case dimensions are counted from the last one
/// (i.e. Python-style), or be an AllDims tag, in which case all dimensions
/// are checked. This may be eventually extended to slices and/or lists of
/// dimensions.
StructuredOpMatcher &dim(int64_t dimension, utils::IteratorType kind) {
return dim(SmallVector<int64_t>{dimension}, kind);
}
// Ownership may get tricky here so we wrap in an explicit vector.
StructuredOpMatcher &dim(SmallVector<int64_t> &&dimensions,
utils::IteratorType kind);
StructuredOpMatcher &dim(AllDims tag, utils::IteratorType kind);
StructuredOpMatcher &dim(AllDimsExcept &&dimensions,
utils::IteratorType kind);
/// Adds a predicate checking that the given iteration space dimension is
/// statically known to be divisible by the given value. The dimension index
/// may be negative, in which case dimensions are counted from the last one
/// (i.e. Python-style).
StructuredOpMatcher &dim(int64_t dimension, DivisibleBy divisibleBy);
//===-------------------------------------------------------------------===//
// Capture directives.
//===-------------------------------------------------------------------===//
StructuredOpMatcher &rank(CaptureRank capture);
StructuredOpMatcher &dim(int64_t dimension, CaptureDim capture);
StructuredOpMatcher &dim(AllDims tag, CaptureDims captures);
StructuredOpMatcher &convolutionDims(CaptureConvDims convDims);
//===-------------------------------------------------------------------===//
// Constraints on input operands.
//===-------------------------------------------------------------------===//
/// Adds a predicate checking that the structured op has the given number of
/// inputs.
StructuredOpMatcher &input(NumEqualsTo num);
/// Adds a predicate that recursively applies other predicates to the
/// operation defining the `position`-th operand. The position may be
/// negative, in which case positions are counted from the last one
/// (i.e. Python-style). When the match is optional, the predicate check
/// succeeds as long as the `position` is in bounds. The matcher is executed
/// if there is a defining operation for the input operand.
template <typename T>
std::enable_if_t<
llvm::is_detected<::mlir::detail::has_operation_or_value_matcher_t, T,
Operation *>::value,
StructuredOpMatcher &>
input(int64_t position, T &operandMatcher,
OptionalMatch optional = OptionalMatch(false)) {
addInputMatcher(
position,
[&operandMatcher](Operation *op) { return operandMatcher.match(op); },
optional);
recordNestedMatcher(operandMatcher);
return *this;
}
template <typename T>
std::enable_if_t<
llvm::is_detected<::mlir::detail::has_operation_or_value_matcher_t, T,
Value>::value,
StructuredOpMatcher &>
input(int64_t position, T &operandMatcher,
OptionalMatch optional = OptionalMatch(false)) {
addInputMatcher(
position,
[&operandMatcher](Value v) { return operandMatcher.match(v); },
optional);
recordNestedMatcher(operandMatcher);
return *this;
}
/// Adds a predicate checking that all input operands of the structured op
/// have a permutation indexing map.
StructuredOpMatcher &input(AllOperands tag, IsPermutation);
/// Adds a predicate checking that all input operands of the structured op
/// have a projected permutation indexing map.
StructuredOpMatcher &input(AllOperands tag, IsProjectedPermutation);
/// Adds a predicate checking that all input operands of the structured op
/// are projected along the given dimension.
StructuredOpMatcher &input(SmallVector<int64_t> &&positions, IsProjected dim);
StructuredOpMatcher &input(int64_t position, IsProjected dim) {
return input(SmallVector<int64_t>{position}, dim);
}
/// Adds a predicate checking that all input operands of the structured op
/// have identity indexing map.
StructuredOpMatcher &input(AllOperands tag, IsIdentity);
StructuredOpMatcher &input(SmallVector<int64_t> &&positions, IsIdentity);
StructuredOpMatcher &input(int64_t position, IsIdentity) {
return input(SmallVector<int64_t>{position}, IsIdentity());
}
/// Adds a predicate checking that the bit width of the elemental type of the
/// structured op input at the given position is equal to the given value.
StructuredOpMatcher &input(int64_t position, ElementTypeBitWidth width);
/// Capture the elemental type bitwidth of input operand `position`.
StructuredOpMatcher &input(int64_t position,
CaptureElementTypeBitWidth width);
/// Capture the elemental type of input operand `position`.
StructuredOpMatcher &input(int64_t position, CaptureElementType elem);
/// Check if input is equal to a known constant.
// TODO: Support matching for constant ops.
StructuredOpMatcher &input(int64_t position, ConstantFloatMinOrMinusInf);
StructuredOpMatcher &input(int64_t position, ConstantFloatZero);
//===-------------------------------------------------------------------===//
// Constraints on output operands.
//===-------------------------------------------------------------------===//
/// Adds a predicate checking that the structured op has the given number of
/// outputs.
StructuredOpMatcher &output(NumEqualsTo num);
/// Adds a predicate checking that all output operands of the structured op
/// have a permutation indexing map.
StructuredOpMatcher &output(AllOperands tag, IsPermutation);
/// Adds a predicate checking that all output operands of the structured op
/// have a projected permutation indexing map.
StructuredOpMatcher &output(AllOperands tag, IsProjectedPermutation);
/// Adds a predicate checking that all output operands of the structured op
/// have a
StructuredOpMatcher &output(AllOperands tag, IsProjected dim);
/// Adds a predicate checking that all output operands of the structured op
/// have identity indexing map.
StructuredOpMatcher &output(AllOperands tag, IsIdentity);
/// Adds a predicate checking that the bit width of the elemental type of the
/// structured op output at the given position is equal to the given value.
StructuredOpMatcher &output(int64_t position, ElementTypeBitWidth width);
/// Capture the elemental type bitwidth of output operand `position`.
StructuredOpMatcher &output(int64_t position,
CaptureElementTypeBitWidth width);
/// Capture the elemental type of output operand `position`.
StructuredOpMatcher &output(int64_t position, CaptureElementType elem);
/// Adds a predicate checking that the output of the structured op is produced
/// by a reduction with a single-operation combinator (such as addf or mulf,
/// but not a compare+select pair).
StructuredOpMatcher &output(int64_t position, SingleCombinerReduction tag);
/// Adds a predicate that recursively applies other predicates to the
/// operation defining the init/out operand corresponding to `position`-th
/// output. The position may be negative, in which case positions are counted
/// from the last one (i.e. Python-style). When the match is optional, the
/// predicate check succeeds as long as the `position` is in bounds. The
/// matcher executed if there is a defining operation for the output operand.
template <typename T>
std::enable_if_t<
llvm::is_detected<::mlir::detail::has_operation_or_value_matcher_t, T,
Operation *>::value,
StructuredOpMatcher &>
output(int64_t position, T &operandMatcher,
OptionalMatch optional = OptionalMatch(false)) {
addOutputMatcher(
position,
[&operandMatcher](Operation *op) { return operandMatcher.match(op); },
optional);
recordNestedMatcher(operandMatcher);
return *this;
}
//===-------------------------------------------------------------------===//
// Constraints on results.
//===-------------------------------------------------------------------===//
/// Adds a predicate that recursively applies to users of the `position`-th
/// result of the structured op. Succeeds if any user matches the predicate.
/// When the match is optional, the predicate check succeeds as long as the
/// `position` is in bounds, after running the given matcher.
template <typename T>
std::enable_if_t<
llvm::is_detected<::mlir::detail::has_operation_or_value_matcher_t, T,
Operation *>::value,
StructuredOpMatcher &>
result(int64_t position, HasAnyUse tag, T &resultUserMatcher,
OptionalMatch optional = OptionalMatch(false)) {
addResultMatcher(
position, tag,
[&resultUserMatcher](Operation *op) {
return resultUserMatcher.match(op);
},
optional);
recordNestedMatcher(resultUserMatcher);
return *this;
}
//===-------------------------------------------------------------------===//
// Constraints on op region.
//===-------------------------------------------------------------------===//
/// Return true if the linalg op only contains a single ops and the arguments
/// of the operation match the order of the linalg operand.
/// Example:
/// linalg.generic
/// ins(%0, %1 : tensor<?x?x?xf32>, tensor<?x?xf32>)
/// outs(%2 : tensor<?x?x?xf32>) {
/// ^bb0(%arg0: f32, %arg1: f32):
/// %3 = arith.maxf %arg0, %arg1 : f32
/// linalg.yield %3 : f32
/// } -> tensor<?x?xf32>
/// If commutative is set binary operations can have their operands swapped.
template <typename OpType>
StructuredOpMatcher &singleOpWithCanonicaleArgs(bool commutative = false) {
return singleOpWithCanonicaleArgs(OpType::getOperationName(), commutative);
}
StructuredOpMatcher &singleOpWithCanonicaleArgs(StringRef opname,
bool commutative);
/// Check if the op is a linalg of with a single float reciprocal op.
StructuredOpMatcher &isFloatReciprocal();
/// Check if the op is a linalg of with a region containing only a yield op
/// using block arguments in order.
StructuredOpMatcher &passThroughOp();
private:
/// Non-template implementations of nested predicate builders for inputs,
/// outputs and results. Should not be called directly.
void addInputMatcher(int64_t position,
std::function<bool(Operation *)> matcher,
OptionalMatch optional);
void addInputMatcher(int64_t position, std::function<bool(Value)> matcher,
OptionalMatch optional);
void addOutputMatcher(int64_t position,
std::function<bool(Operation *)> matcher,
OptionalMatch optional);
void addResultMatcher(int64_t position, HasAnyUse tag,
std::function<bool(Operation *)> matcher,
OptionalMatch optional);
// Common util for constant matcher.
StructuredOpMatcher &input(int64_t position,
std::function<bool(llvm::APFloat)> floatValueFn);
};
/// Creates a matcher of an arbitrary structured op.
inline StructuredOpMatcher &m_StructuredOp(MatcherContext &matcherContext) {
return matcherContext.allocate<StructuredOpMatcher>();
}
/// Creates a matcher that is a copy of the given matcher.
inline StructuredOpMatcher &m_StructuredOp(MatcherContext &matcherContext,
const StructuredOpMatcher &other) {
return matcherContext.allocate<StructuredOpMatcher>(other);
}
/// Creates a matcher that accepts as disjunction of the two given matchers.
inline StructuredOpMatcher &m_StructuredOp_Or(MatcherContext &matcherContext,
StructuredOpMatcher &A,
StructuredOpMatcher &B) {
return matcherContext.allocate<StructuredOpMatcher>(A, B);
}
/// Creates a matcher of a structured op with kinds provided as template
/// arguments.
template <typename... OpType>
inline StructuredOpMatcher &m_StructuredOp(MatcherContext &matcherContext) {
return matcherContext.allocate<StructuredOpMatcher>(
StructuredOpMatcher::create<OpType...>());
}
//===---------------------------------------------------------------------===//
// MatchCallback functionality.
//===---------------------------------------------------------------------===//
/// Additional results of the C++ callback usable in the `match_callback`
/// transform operation. Conceptually, a list of lists of payload operations to
/// be associated with each result handle.
class MatchCallbackResult {
public:
/// Returns the number of lists of payload operations.
int64_t getNumPayloadGroups() const { return payloadGroupLengths.size(); }
/// Returns the `position`-th list of payload operations.
ArrayRef<Operation *> getPayloadGroup(int64_t position) const;
/// Adds a new list of payload operations to the list of lists. The new list
/// must not contain null operations.
template <typename Range>
int64_t addPayloadGroup(Range operations) {
int64_t originalLength = payloadOperations.size();
assert(llvm::all_of(operations, [](Operation *op) -> bool { return op; }) &&
"null operation");
llvm::append_range(payloadOperations, operations);
payloadGroupLengths.push_back(payloadOperations.size() - originalLength);
return payloadGroupLengths.size() - 1;
}
void addPayloadGroup(ArrayRef<Operation *> operations) {
addPayloadGroup<ArrayRef<Operation *>>(operations);
}
/// Adds a new singleton list of payload operation to the list of lists if the
/// operation is non-null, adds an empty list otherwise. Useful for results of
/// optional matches.
void addPotentiallyEmptyPayloadGroup(Operation *op) {
if (!op)
addPayloadGroup(ArrayRef<Operation *>());
else
addPayloadGroup(ArrayRef<Operation *>(op));
}
private:
/// The flat list of all payload opreations. `payloadGroupLengths` can be used
/// to compute the sublist that corresponds to one nested list.
// TODO: if somebody implements such a flattened vector generically, use it.
SmallVector<Operation *> payloadOperations;
SmallVector<int64_t> payloadGroupLengths;
};
/// A transform state extension that maintains the mapping between callback
/// names as strings usable in `match_callback` and their implementations.
class MatchCallbacksRegistry : public transform::TransformState::Extension {
public:
using MatchCallbackFn = std::function<DiagnosedSilenceableFailure(
MatchCallbackResult &, Location, const transform::TransformState &,
ValueRange)>;
/// Constructs the extension.
MatchCallbacksRegistry(transform::TransformState &state)
: transform::TransformState::Extension(state) {}
/// Registers the given function as a callback with the given name. The name
/// must not be already present in the registry. The callback must be
/// convertible to MatchCallbackFn.
template <typename Fn>
void registerCallback(StringRef name, Fn &&fn) {
bool succeeded = callbacks.try_emplace(name, std::forward<Fn>(fn)).second;
(void)succeeded;
assert(succeeded && "adding a callback with a repeated name");
}
/// Returns a pointer to the implementation of the callback with the given
/// name, or null if it is not present in the registry.
const MatchCallbackFn *get(StringRef name) const {
auto iter = callbacks.find(name);
if (iter == callbacks.end())
return nullptr;
return &iter->getValue();
}
private:
llvm::StringMap<MatchCallbackFn> callbacks;
};
//===---------------------------------------------------------------------===//
// Case-specific matcher builders.
//===---------------------------------------------------------------------===//
struct MatchedReductionCaptures {
int64_t reductionRank = 0;
int64_t maybeLeadingRank = 0;
int64_t maybeTrailingRank = 0;
SmallVector<int64_t> leadingOpSizes = {};
SmallVector<int64_t> reductionOpSizes = {};
SmallVector<int64_t> trailingOpSizes = {};
int64_t reductionOutputElementalTypeBitWidth = 0;
int64_t maybeLeadingOutputElementalTypeBitWidth = 0;
int64_t maybeTrailingOutputElementalTypeBitWidth = 0;
};
struct MatchedMatmulCaptures {
Type lhsElementType, rhsElementType, outputElementType;
SmallVector<int64_t> matmulOpSizes = {};
};
/// Creates a group of matchers for:
///
/// trailing(reduction(leading(), fill()))
///
/// where trailing and leading are elementwise operations whose presence is
/// optional. Each matcher will capture the corresponding operation. If
/// `mustMatchEntireFunc` is set, the matcher additionally checks if all
/// tileable operations in the functions are captured.
void makeReductionMatcher(MatcherContext &context,
StructuredOpMatcher *&reductionCapture,
StructuredOpMatcher *&fillCapture,
StructuredOpMatcher *&leadingCapture,
StructuredOpMatcher *&trailingCapture,
MatchedReductionCaptures &captures,
bool mustMatchEntireFunc);
void makeReductionMatcher(MatcherContext &context,
StructuredOpMatcher *&reductionCapture,
MatchedReductionCaptures &captures,
bool mustMatchEntireFunc);
///
/// trailing(matmul(*, *, fill()))
///
/// where trailing and leading are elementwise operations whose presence is
/// optional. Each matcher will capture the corresponding operation. If
/// `mustMatchEntireFunc` is set, the matcher additionally checks if all
/// tileable operations in the functions are captured.
void makeMatmulMatcher(MatcherContext &matcherContext,
StructuredOpMatcher *&matmulCapture,
StructuredOpMatcher *&fillCapture,
StructuredOpMatcher *&trailingCapture,
MatchedMatmulCaptures &captures,
bool mustMatchEntireFunc);
/// Create a group of matchers for a different code sequence of operations
/// matching exactly a softmax operation.
///
/// %red = reduce_max(%0)
/// %sub = sub(%0, %red)
/// %exp = exp(%sub)
/// %sum = reduce_sum(%exp)
/// %mul = div(%exp, %%sum)
void makeSoftmaxMatcher(MatcherContext &context,
StructuredOpMatcher *&maxReductionCapture,
StructuredOpMatcher *&softmaxRootCapture);
struct MatchedConvolutionCaptures {
mlir::linalg::detail::ConvolutionDimensions convolutionDims = {};
SmallVector<int64_t> convolutionOpSizes = {};
SmallVector<int64_t> trailingOpSizes = {};
int64_t convolutionOutputElementalTypeBitWidth = 0;
int64_t maybeTrailingOutputElementalTypeBitWidth = 0;
int64_t maybeFillElementalTypeBitWidth = 0;
};
/// Creates a group of matchers for:
///
/// trailing(convolution(input, filter, fill()))
///
/// where fill is a FillOp and trailing is an elementwise operation, both of
/// which is optional. Each matcher will capture the corresponding operation. If
/// `mustMatchEntireFunc` is set, the matcher additionally checks if all
/// tileable operations in the functions are captured.
void makeConvolutionMatcher(MatcherContext &context,
StructuredOpMatcher *&convolutionCapture,
StructuredOpMatcher *&fillCapture,
StructuredOpMatcher *&trailingCapture,
MatchedConvolutionCaptures &captures,
bool mustMatchEntireFunc);
void makeConvolutionMatcher(MatcherContext &context,
StructuredOpMatcher *&convolutionCapture,
MatchedConvolutionCaptures &captures,
bool mustMatchEntireFunc);
struct MatchedPadCaptures {
int64_t rank = 0;
Type elementType;
SmallVector<int64_t> dims = {};
};
/// Create a matcher for tensor.pad(*) without leading or trailing ops atm.
/// If `mustMatchEntireFunc` is set, the matcher additionally checks if all
/// tileable operations in the functions are captured.
void makePadMatcher(MatcherContext &context, CapturingOpMatcher *&padCapture,
MatchedPadCaptures &captures, bool mustMatchEntireFunc);
/// Wraps the given matcher callback to indicate that it must capture all
/// tilable ops in the parent function. Expects the callback to accept the same
/// arguments as what is expected by MatchCallbacksRegistry::register, followed
/// by a bool.
template <typename Fn>
auto wrapAsEntireFuncMatch(Fn &&fn) {
return [fn = std::move(fn)](MatchCallbackResult &res, Location loc,
const mlir::transform::TransformState &state,
ValueRange handles) {
return fn(res, loc, state, handles, true);
};
}
/// Wraps the given matcher callback to indicate that it can match subgraphs.
/// Expects the callback to accept the same arguments as what is expected by
/// MatchCallbacksRegistry::register, followed by a bool.
template <typename Fn>
auto wrapAsPartialMatch(Fn &&fn) {
return [fn = std::move(fn)](MatchCallbackResult &res, Location loc,
const mlir::transform::TransformState &state,
ValueRange handles) {
return fn(res, loc, state, handles, false);
};
}
} // namespace transform_ext
} // namespace mlir
#endif // IREE_COMPILER_CODEGEN_COMMON_TRANSFORMEXTENSIONS_TRANSFORMMATCHERS_H_