|  | // Copyright 2022 The IREE Authors | 
|  | // | 
|  | // Licensed under the Apache License v2.0 with LLVM Exceptions. | 
|  | // See https://llvm.org/LICENSE.txt for license information. | 
|  | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | 
|  |  | 
|  | #include "iree-dialects/Transforms/TransformMatchers.h" | 
|  |  | 
|  | #include "mlir/Analysis/SliceAnalysis.h" | 
|  | #include "mlir/Dialect/Arith/IR/Arith.h" | 
|  | #include "mlir/Dialect/Func/IR/FuncOps.h" | 
|  | #include "mlir/Dialect/Linalg/IR/Linalg.h" | 
|  | #include "mlir/Dialect/Math/IR/Math.h" | 
|  | #include "mlir/Dialect/Utils/StructuredOpsUtils.h" | 
|  | #include "llvm/ADT/STLExtras.h" | 
|  | #include "llvm/ADT/ScopeExit.h" | 
|  | #include "llvm/Support/Debug.h" | 
|  |  | 
|  | using namespace mlir; | 
|  |  | 
|  | #define DEBUG_TYPE "transform-matchers" | 
|  | #define DBGS() llvm::dbgs() << "[" DEBUG_TYPE "] " | 
|  | #define DBGSNL() llvm::dbgs() << "\n[" DEBUG_TYPE "] " | 
|  |  | 
|  | //===---------------------------------------------------------------------===// | 
|  | // CapturingMatcherBase | 
|  | //===---------------------------------------------------------------------===// | 
|  |  | 
|  | void transform_ext::CapturingMatcherBase::getAllNested( | 
|  | SmallVectorImpl<CapturingOpMatcher *> &nested) { | 
|  |  | 
|  | SetVector<CapturingOpMatcher *> found; | 
|  | found.insert(nested.begin(), nested.end()); | 
|  | int64_t start = found.size(); | 
|  |  | 
|  | auto appendOne = [&found](CapturingMatcherBase &one) { | 
|  | found.insert(one.nestedCapturingMatchers.begin(), | 
|  | one.nestedCapturingMatchers.end()); | 
|  | for (CapturingValueMatcher *valueMatcher : | 
|  | one.nestedCapturingValueMatchers) { | 
|  | found.insert(valueMatcher->nestedCapturingMatchers.begin(), | 
|  | valueMatcher->nestedCapturingMatchers.end()); | 
|  | } | 
|  | }; | 
|  |  | 
|  | appendOne(*this); | 
|  | for (int64_t position = start; position < found.size(); ++position) { | 
|  | appendOne(*found[position]); | 
|  | } | 
|  |  | 
|  | llvm::append_range(nested, found.getArrayRef()); | 
|  | } | 
|  |  | 
|  | void transform_ext::CapturingMatcherBase::getAllNestedValueMatchers( | 
|  | SmallVectorImpl<CapturingValueMatcher *> &nested) { | 
|  |  | 
|  | SetVector<CapturingValueMatcher *> found; | 
|  | found.insert(nested.begin(), nested.end()); | 
|  | int64_t start = found.size(); | 
|  |  | 
|  | auto appendOne = [&found](CapturingMatcherBase &one) { | 
|  | found.insert(one.nestedCapturingValueMatchers.begin(), | 
|  | one.nestedCapturingValueMatchers.end()); | 
|  | for (CapturingOpMatcher *opMatcher : one.nestedCapturingMatchers) { | 
|  | found.insert(opMatcher->nestedCapturingValueMatchers.begin(), | 
|  | opMatcher->nestedCapturingValueMatchers.end()); | 
|  | } | 
|  | }; | 
|  |  | 
|  | appendOne(*this); | 
|  | for (int64_t position = start; position < found.size(); ++position) { | 
|  | appendOne(*found[position]); | 
|  | } | 
|  |  | 
|  | llvm::append_range(nested, found.getArrayRef()); | 
|  | } | 
|  |  | 
|  | void transform_ext::CapturingMatcherBase::resetCapture() { | 
|  | SmallVector<CapturingOpMatcher *> nested; | 
|  | getAllNested(nested); | 
|  | for (CapturingOpMatcher *matcher : nested) { | 
|  | matcher->captured = nullptr; | 
|  | } | 
|  | SmallVector<CapturingValueMatcher *> nestedValue; | 
|  | getAllNestedValueMatchers(nestedValue); | 
|  | for (CapturingValueMatcher *matcher : nestedValue) { | 
|  | matcher->captured = nullptr; | 
|  | } | 
|  | } | 
|  |  | 
|  | //===---------------------------------------------------------------------===// | 
|  | // CapturingOpMatcher | 
|  | //===---------------------------------------------------------------------===// | 
|  |  | 
|  | bool transform_ext::CapturingOpMatcher::checkAllTilableMatched( | 
|  | Operation *parent, Operation *op, | 
|  | ArrayRef<transform_ext::CapturingOpMatcher *> matchers) { | 
|  | LLVM_DEBUG(DBGS() << "all tilable ops captured"); | 
|  | int64_t numTilableOps = 0; | 
|  | if (!parent) | 
|  | return false; | 
|  | parent->walk([&](TilingInterface Op) { ++numTilableOps; }); | 
|  |  | 
|  | llvm::SmallPtrSet<Operation *, 6> matched; | 
|  | for (CapturingOpMatcher *nested : matchers) { | 
|  | if (Operation *captured = nested->getCaptured()) { | 
|  | matched.insert(captured); | 
|  | } | 
|  | } | 
|  |  | 
|  | // Don't forget to include the root matcher. | 
|  | matched.insert(op); | 
|  | return numTilableOps == matched.size(); | 
|  | } | 
|  |  | 
|  | bool transform_ext::CapturingOpMatcher::match(Operation *op) { | 
|  | auto debugRAII = | 
|  | llvm::make_scope_exit([] { LLVM_DEBUG(DBGS() << "-------\n"); }); | 
|  | LLVM_DEBUG(DBGS() << "matching: " << *op << "\n"); | 
|  |  | 
|  | if (getCaptured()) { | 
|  | LLVM_DEBUG(DBGS() << "found an already captured op: "); | 
|  | if (getCaptured() == op) { | 
|  | LLVM_DEBUG(llvm::dbgs() << "same\n"); | 
|  | return true; | 
|  | } else { | 
|  | LLVM_DEBUG(llvm::dbgs() << "different\n"); | 
|  | return false; | 
|  | } | 
|  | } | 
|  |  | 
|  | if (!llvm::all_of(predicates, [op](const PredicateFn &fn) { | 
|  | bool result = fn(op); | 
|  | LLVM_DEBUG(llvm::dbgs() << ": " << result << "\n"); | 
|  | return result; | 
|  | })) { | 
|  | return false; | 
|  | } | 
|  |  | 
|  | captured = op; | 
|  | return true; | 
|  | } | 
|  |  | 
|  | void transform_ext::CapturingOpMatcher::debugOutputForCreate( | 
|  | ArrayRef<StringRef> opNames) { | 
|  | LLVM_DEBUG(DBGS() << "operation type is one of {"; | 
|  | llvm::interleaveComma(opNames, llvm::dbgs()); llvm::dbgs() << "}"); | 
|  | } | 
|  |  | 
|  | /// Apply the given matcher to the given object, produce debug messages. | 
|  | template <typename Matcher, typename Object = typename llvm::function_traits< | 
|  | typename Matcher::match>::template args<0>> | 
|  | static bool recursiveMatch(Matcher &matcher, Object &object, | 
|  | StringRef extraMessage = "") { | 
|  | LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "] " | 
|  | << "start recursive match (" << extraMessage | 
|  | << ") {\n"); | 
|  | bool result = matcher.match(object); | 
|  | LLVM_DEBUG(DBGS() << "} end recursive match"); | 
|  | return result; | 
|  | } | 
|  |  | 
|  | transform_ext::CapturingOpMatcher & | 
|  | transform_ext::CapturingOpMatcher::alternatives( | 
|  | transform_ext::CapturingOpMatcher &first, | 
|  | transform_ext::CapturingOpMatcher &second) { | 
|  | addPredicate([&first, &second](Operation *op) { | 
|  | LLVM_DEBUG(DBGS() << "matching alternatives\n"); | 
|  | return recursiveMatch(first, op, "alternative 1") || | 
|  | recursiveMatch(second, op, "alternative 2"); | 
|  | }); | 
|  | return *this; | 
|  | } | 
|  |  | 
|  | //---------------------------------------------------------------------------// | 
|  | // Predicates for operands and results. | 
|  | //---------------------------------------------------------------------------// | 
|  |  | 
|  | transform_ext::CapturingOpMatcher & | 
|  | transform_ext::CapturingOpMatcher::operand(transform_ext::NumEqualsTo num) { | 
|  | addPredicate([=](Operation *op) { | 
|  | LLVM_DEBUG(DBGS() << "operation has exactly " << num.value << " operands"); | 
|  | return num.value == op->getNumOperands(); | 
|  | }); | 
|  | return *this; | 
|  | } | 
|  |  | 
|  | /// If `pos` is negative, returns the number of the operand in op starting from | 
|  | /// the last. For example, -1 means the last operand, -2 means the | 
|  | /// second-to-last, etc. Returns nullopt if pos is out-of-bounds, both positive | 
|  | /// and negative. | 
|  | static std::optional<int64_t> remapNegativeOperandNumber(int64_t pos, | 
|  | Operation *op) { | 
|  | int64_t updated = pos < 0 ? op->getNumOperands() + pos : pos; | 
|  | if (updated < 0 || updated >= op->getNumOperands()) { | 
|  | LLVM_DEBUG(DBGS() << "match operand #" << pos | 
|  | << "that does not exist in the operation"); | 
|  | return std::nullopt; | 
|  | } | 
|  | return updated; | 
|  | } | 
|  |  | 
|  | transform_ext::CapturingOpMatcher & | 
|  | transform_ext::CapturingOpMatcher::operand(int64_t pos, | 
|  | CapturingOpMatcher &nested) { | 
|  | addPredicate([pos, &nested](Operation *op) { | 
|  | std::optional<int64_t> operandNo = remapNegativeOperandNumber(pos, op); | 
|  | if (!operandNo) | 
|  | return false; | 
|  | LLVM_DEBUG(DBGS() << "operand #" << pos << " is defined by an operation"); | 
|  | Operation *definingOp = op->getOperand(*operandNo).getDefiningOp(); | 
|  | if (!definingOp) | 
|  | return false; | 
|  | return recursiveMatch(nested, definingOp); | 
|  | }); | 
|  | recordNestedMatcher(nested); | 
|  | return *this; | 
|  | } | 
|  |  | 
|  | transform_ext::CapturingOpMatcher & | 
|  | transform_ext::CapturingOpMatcher::operand(int64_t pos, | 
|  | CapturingValueMatcher &nested) { | 
|  | addPredicate([pos, &nested](Operation *op) { | 
|  | std::optional<int64_t> operandNo = remapNegativeOperandNumber(pos, op); | 
|  | if (!operandNo) | 
|  | return false; | 
|  | LLVM_DEBUG(DBGS() << "operand #" << pos << " is"); | 
|  | Value operand = op->getOperand(*operandNo); | 
|  | return recursiveMatch(nested, operand); | 
|  | }); | 
|  | recordNestedMatcher(nested); | 
|  | return *this; | 
|  | } | 
|  |  | 
|  | transform_ext::CapturingOpMatcher &transform_ext::CapturingOpMatcher::operand( | 
|  | int64_t position, std::function<bool(llvm::APFloat)> floatValueFn) { | 
|  | addPredicate([position, | 
|  | floatValueFn = std::move(floatValueFn)](Operation *op) -> bool { | 
|  | std::optional<int64_t> operandNo = remapNegativeOperandNumber(position, op); | 
|  | if (!operandNo) | 
|  | return false; | 
|  |  | 
|  | LLVM_DEBUG(DBGS() << "operand #" << *operandNo | 
|  | << " is a special floating point constant"); | 
|  | auto cstOp = | 
|  | op->getOperand(*operandNo).getDefiningOp<arith::ConstantFloatOp>(); | 
|  | if (!cstOp) | 
|  | return false; | 
|  | return floatValueFn(cstOp.value()); | 
|  | }); | 
|  |  | 
|  | return *this; | 
|  | } | 
|  |  | 
|  | transform_ext::CapturingOpMatcher & | 
|  | transform_ext::CapturingOpMatcher::operand(int64_t position, ConstantFloatOne) { | 
|  | return operand(position, | 
|  | [](llvm::APFloat value) { return value.isExactlyValue(1.0); }); | 
|  | } | 
|  |  | 
|  | transform_ext::CapturingOpMatcher & | 
|  | transform_ext::CapturingOpMatcher::result(transform_ext::NumEqualsTo num) { | 
|  | addPredicate([=](Operation *op) { | 
|  | LLVM_DEBUG(DBGS() << "operation has exactly " << num.value << " results"); | 
|  | return num.value == op->getNumResults(); | 
|  | }); | 
|  | return *this; | 
|  | } | 
|  |  | 
|  | transform_ext::CapturingOpMatcher & | 
|  | transform_ext::CapturingOpMatcher::result(int64_t pos, | 
|  | CapturingValueMatcher &nested) { | 
|  | addPredicate([pos, &nested](Operation *op) { | 
|  | int64_t updated = pos < 0 ? op->getNumResults() + pos : pos; | 
|  | if (updated < 0 || updated >= op->getNumResults()) { | 
|  | LLVM_DEBUG(DBGS() << "matching result #" << pos | 
|  | << " that does not exist in the operation"); | 
|  | return false; | 
|  | } | 
|  | LLVM_DEBUG(DBGS() << "result #" << pos << " is"); | 
|  | Value result = op->getResult(updated); | 
|  | return recursiveMatch(nested, result); | 
|  | }); | 
|  | recordNestedMatcher(nested); | 
|  | return *this; | 
|  | } | 
|  |  | 
|  | //===---------------------------------------------------------------------===// | 
|  | // CapturingValueMatcher | 
|  | //===---------------------------------------------------------------------===// | 
|  |  | 
|  | namespace { | 
|  | struct DebugPrintValueWrapper { | 
|  | Value value; | 
|  | }; | 
|  |  | 
|  | llvm::raw_ostream &operator<<(llvm::raw_ostream &os, | 
|  | const DebugPrintValueWrapper &wrapper) { | 
|  | if (auto opResult = wrapper.value.dyn_cast<OpResult>()) { | 
|  | return os << "op result #" << opResult.getResultNumber() << " in " | 
|  | << wrapper.value; | 
|  | } | 
|  |  | 
|  | auto blockArg = wrapper.value.cast<BlockArgument>(); | 
|  | os << "block argument #" << blockArg.getArgNumber(); | 
|  | Block *parentBlock = blockArg.getParentBlock(); | 
|  | Region *parentRegion = parentBlock->getParent(); | 
|  | if (!parentRegion) { | 
|  | os << " of a detached block:\n"; | 
|  | parentBlock->print(os); | 
|  | return os; | 
|  | } | 
|  |  | 
|  | os << " of block #" | 
|  | << std::distance(parentRegion->begin(), parentBlock->getIterator()); | 
|  | Operation *parentOp = parentRegion->getParentOp(); | 
|  | if (!parentOp) { | 
|  | os << " of a detached region:\n"; | 
|  | for (Block &b : *parentRegion) | 
|  | b.print(os); | 
|  | return os; | 
|  | } | 
|  |  | 
|  | os << " in region #" << parentRegion->getRegionNumber() << " of " | 
|  | << *parentOp; | 
|  | return os; | 
|  | } | 
|  | } // namespace | 
|  |  | 
|  | bool transform_ext::CapturingValueMatcher::match(Value value) { | 
|  | auto debugRAII = | 
|  | llvm::make_scope_exit([] { LLVM_DEBUG(DBGS() << "-------\n"); }); | 
|  | LLVM_DEBUG(DBGS() << "matching " << DebugPrintValueWrapper{value} << "\n"); | 
|  |  | 
|  | if (getCaptured()) { | 
|  | LLVM_DEBUG(DBGS() << "found an already captured value: "); | 
|  | if (getCaptured() == value) { | 
|  | LLVM_DEBUG(llvm::dbgs() << "same\n"); | 
|  | return true; | 
|  | } else { | 
|  | LLVM_DEBUG(llvm::dbgs() << "different\n"); | 
|  | return false; | 
|  | } | 
|  | } | 
|  |  | 
|  | for (const PredicateFn &fn : predicates) { | 
|  | bool result = fn(value); | 
|  | LLVM_DEBUG(llvm::dbgs() << ": " << result << "\n"); | 
|  | if (!result) | 
|  | return false; | 
|  | } | 
|  |  | 
|  | captured = value; | 
|  | return true; | 
|  | } | 
|  |  | 
|  | transform_ext::ShapedValueMatcher::ShapedValueMatcher() | 
|  | : CapturingValueMatcher() { | 
|  | addPredicate([](Value value) { | 
|  | LLVM_DEBUG(DBGS() << "value is of shaped type"); | 
|  | return value && value.getType().isa<ShapedType>(); | 
|  | }); | 
|  | } | 
|  |  | 
|  | transform_ext::ShapedValueMatcher & | 
|  | transform_ext::ShapedValueMatcher::rank(transform_ext::CaptureRank capture) { | 
|  | addPredicate([=](Value value) { | 
|  | LLVM_DEBUG(DBGS() << "capturing shaped value rank"); | 
|  | capture.value = value.getType().cast<ShapedType>().getRank(); | 
|  | return true; | 
|  | }); | 
|  | return *this; | 
|  | } | 
|  |  | 
|  | transform_ext::ShapedValueMatcher & | 
|  | transform_ext::ShapedValueMatcher::dim(int64_t dimension, CaptureDim capture) { | 
|  | addPredicate([=](Value value) { | 
|  | LLVM_DEBUG(DBGS() << "capturing shaped value dimension " << dimension); | 
|  | capture.value = value.getType().cast<ShapedType>().getDimSize(dimension); | 
|  | return true; | 
|  | }); | 
|  | return *this; | 
|  | } | 
|  |  | 
|  | transform_ext::ShapedValueMatcher & | 
|  | transform_ext::ShapedValueMatcher::dim(AllDims tag, CaptureDims captures) { | 
|  | (void)tag; | 
|  | addPredicate([=](Value value) { | 
|  | LLVM_DEBUG(DBGS() << "capturing all shaped value dimensions"); | 
|  | ArrayRef<int64_t> shape = value.getType().cast<ShapedType>().getShape(); | 
|  | captures.value.assign(shape.begin(), shape.end()); | 
|  | return true; | 
|  | }); | 
|  | return *this; | 
|  | } | 
|  |  | 
|  | transform_ext::ShapedValueMatcher & | 
|  | transform_ext::ShapedValueMatcher::elementType(CaptureElementType captures) { | 
|  | addPredicate([=](Value value) { | 
|  | LLVM_DEBUG(DBGS() << "capturing elementType"); | 
|  | captures.value = value.getType().cast<ShapedType>().getElementType(); | 
|  | return true; | 
|  | }); | 
|  | return *this; | 
|  | } | 
|  |  | 
|  | //===---------------------------------------------------------------------===// | 
|  | // Constraints on op rank and dims. | 
|  | //===---------------------------------------------------------------------===// | 
|  |  | 
|  | transform_ext::StructuredOpMatcher & | 
|  | transform_ext::StructuredOpMatcher::rank(NumGreaterEqualTo minRank) { | 
|  | return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { | 
|  | LLVM_DEBUG(DBGS() << "rank >= " << minRank.value); | 
|  | return linalgOp.getNumLoops() >= minRank.value; | 
|  | }); | 
|  | } | 
|  |  | 
|  | transform_ext::StructuredOpMatcher & | 
|  | transform_ext::StructuredOpMatcher::rank(NumLowerEqualTo maxRank) { | 
|  | return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { | 
|  | LLVM_DEBUG(DBGS() << "rank <= " << maxRank.value); | 
|  | return linalgOp.getNumLoops() <= maxRank.value; | 
|  | }); | 
|  | } | 
|  |  | 
|  | transform_ext::StructuredOpMatcher & | 
|  | transform_ext::StructuredOpMatcher::rank(NumEqualsTo exactRank) { | 
|  | return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { | 
|  | LLVM_DEBUG(DBGS() << "rank == " << exactRank.value); | 
|  | return linalgOp.getNumLoops() == exactRank.value; | 
|  | }); | 
|  | } | 
|  |  | 
|  | StringRef stringifyShapeKind(transform_ext::ShapeKind kind) { | 
|  | switch (kind) { | 
|  | case transform_ext::ShapeKind::Static: | 
|  | return "static"; | 
|  | case transform_ext::ShapeKind::Dynamic: | 
|  | return "dynamic"; | 
|  | } | 
|  | llvm_unreachable("unhandled shape kind"); | 
|  | } | 
|  |  | 
|  | transform_ext::StructuredOpMatcher & | 
|  | transform_ext::StructuredOpMatcher::dim(SmallVector<int64_t> &&dimensions, | 
|  | ShapeKind kind) { | 
|  | return addPredicate([dimensions = std::move(dimensions), | 
|  | kind](linalg::LinalgOp linalgOp) -> bool { | 
|  | LLVM_DEBUG(DBGS() << "dimensions ["; | 
|  | llvm::interleaveComma(dimensions, llvm::dbgs()); | 
|  | llvm::dbgs() << "] are " << stringifyShapeKind(kind)); | 
|  | SmallVector<int64_t> shape = linalgOp.getStaticLoopRanges(); | 
|  | for (auto dimension : dimensions) { | 
|  | int64_t transformedDimension = | 
|  | dimension >= 0 ? dimension : shape.size() + dimension; | 
|  | if (transformedDimension < 0 || transformedDimension >= shape.size()) | 
|  | return false; | 
|  | if (ShapedType::isDynamic(shape[transformedDimension]) ^ | 
|  | (kind == ShapeKind::Static)) | 
|  | continue; | 
|  | return false; | 
|  | } | 
|  | return true; | 
|  | }); | 
|  | } | 
|  |  | 
|  | transform_ext::StructuredOpMatcher & | 
|  | transform_ext::StructuredOpMatcher::dim(AllDims tag, ShapeKind kind) { | 
|  | return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { | 
|  | LLVM_DEBUG(DBGS() << "all dimensions are " << stringifyShapeKind(kind)); | 
|  | SmallVector<int64_t> shape = linalgOp.getStaticLoopRanges(); | 
|  | return llvm::all_of(shape, [=](int64_t dimension) { | 
|  | return ShapedType::isDynamic(dimension) ^ (kind == ShapeKind::Static); | 
|  | }); | 
|  | }); | 
|  | } | 
|  |  | 
|  | transform_ext::StructuredOpMatcher & | 
|  | transform_ext::StructuredOpMatcher::dim(SmallVector<int64_t> &&dimensions, | 
|  | utils::IteratorType kind) { | 
|  | return addPredicate([dimensions = std::move(dimensions), | 
|  | kind](linalg::LinalgOp linalgOp) -> bool { | 
|  | LLVM_DEBUG(DBGS() << "dimensions ["; | 
|  | llvm::interleaveComma(dimensions, llvm::dbgs()); | 
|  | llvm::dbgs() << "] are " << utils::stringifyIteratorType(kind)); | 
|  | int64_t rank = linalgOp.getNumLoops(); | 
|  | for (auto dimension : dimensions) { | 
|  | int64_t transformedDimension = | 
|  | dimension >= 0 ? dimension : rank + dimension; | 
|  | if (transformedDimension < 0 || transformedDimension >= rank) | 
|  | return false; | 
|  | utils::IteratorType iteratorKind = | 
|  | linalgOp.getIteratorTypesArray()[transformedDimension]; | 
|  | if (iteratorKind == kind) | 
|  | continue; | 
|  | return false; | 
|  | } | 
|  | return true; | 
|  | }); | 
|  | } | 
|  | transform_ext::StructuredOpMatcher & | 
|  | transform_ext::StructuredOpMatcher::dim(AllDims tag, utils::IteratorType kind) { | 
|  | return dim(AllDimsExcept({}), kind); | 
|  | } | 
|  |  | 
|  | transform_ext::StructuredOpMatcher & | 
|  | transform_ext::StructuredOpMatcher::dim(AllDimsExcept &&dims, | 
|  | utils::IteratorType kind) { | 
|  | return addPredicate([dimensions = std::move(dims), | 
|  | kind](linalg::LinalgOp linalgOp) -> bool { | 
|  | LLVM_DEBUG(DBGS() << "all dimensions except ["; | 
|  | llvm::interleaveComma(dimensions.getExcluded(), llvm::dbgs()); | 
|  | llvm::dbgs() << "] are " << utils::stringifyIteratorType(kind)); | 
|  | int64_t rank = linalgOp.getNumLoops(); | 
|  | llvm::SmallDenseSet<int64_t> excludedDims; | 
|  | for (int64_t dim : dimensions.getExcluded()) { | 
|  | excludedDims.insert(dim >= 0 ? dim : rank + dim); | 
|  | } | 
|  |  | 
|  | for (auto [index, type] : | 
|  | llvm::enumerate(linalgOp.getIteratorTypesArray())) { | 
|  | if (excludedDims.contains(index)) | 
|  | continue; | 
|  | if (type == kind) | 
|  | continue; | 
|  | return false; | 
|  | } | 
|  | return true; | 
|  | }); | 
|  | } | 
|  |  | 
|  | transform_ext::StructuredOpMatcher & | 
|  | transform_ext::StructuredOpMatcher::dim(int64_t dimension, | 
|  | DivisibleBy divisibleBy) { | 
|  | return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { | 
|  | LLVM_DEBUG(DBGS() << "dimension " << dimension << " is divisible by " | 
|  | << divisibleBy.value); | 
|  | int64_t rank = linalgOp.getNumLoops(); | 
|  | int64_t transformedDimension = | 
|  | dimension >= 0 ? dimension : rank + dimension; | 
|  | if (transformedDimension >= rank) | 
|  | return false; | 
|  |  | 
|  | int64_t size = linalgOp.getStaticLoopRanges()[transformedDimension]; | 
|  | return !ShapedType::isDynamic(size) && (size % divisibleBy.value == 0); | 
|  | }); | 
|  | } | 
|  |  | 
|  | //===---------------------------------------------------------------------===// | 
|  | // Capture directives. | 
|  | //===---------------------------------------------------------------------===// | 
|  | transform_ext::StructuredOpMatcher & | 
|  | transform_ext::StructuredOpMatcher::rank(CaptureRank capture) { | 
|  | return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { | 
|  | LLVM_DEBUG(DBGS() << "capture rank"); | 
|  | capture.value = linalgOp.getNumLoops(); | 
|  | return true; | 
|  | }); | 
|  | } | 
|  |  | 
|  | transform_ext::StructuredOpMatcher & | 
|  | transform_ext::StructuredOpMatcher::dim(int64_t dimension, CaptureDim capture) { | 
|  | return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { | 
|  | LLVM_DEBUG(DBGS() << "capture dimension"); | 
|  | int64_t rank = linalgOp.getNumLoops(); | 
|  | int64_t transformedDimension = | 
|  | dimension >= 0 ? dimension : rank + dimension; | 
|  | if (transformedDimension >= rank) | 
|  | return false; | 
|  |  | 
|  | capture.value = linalgOp.getStaticLoopRanges()[transformedDimension]; | 
|  | return true; | 
|  | }); | 
|  | } | 
|  |  | 
|  | transform_ext::StructuredOpMatcher & | 
|  | transform_ext::StructuredOpMatcher::dim(AllDims tag, CaptureDims captures) { | 
|  | return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { | 
|  | LLVM_DEBUG(DBGS() << "capture all dimensions"); | 
|  | captures.value = linalgOp.getStaticLoopRanges(); | 
|  | return true; | 
|  | }); | 
|  | } | 
|  |  | 
|  | transform_ext::StructuredOpMatcher & | 
|  | transform_ext::StructuredOpMatcher::indexingMaps( | 
|  | CaptureIndexingMaps indexingMaps) { | 
|  | return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { | 
|  | LLVM_DEBUG(DBGS() << "capture indexing maps"); | 
|  | indexingMaps.value = linalgOp.getIndexingMapsArray(); | 
|  | return true; | 
|  | }); | 
|  | } | 
|  |  | 
|  | transform_ext::StructuredOpMatcher & | 
|  | transform_ext::StructuredOpMatcher::contractionDims( | 
|  | CaptureContractionDims contractionDims) { | 
|  | return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { | 
|  | LLVM_DEBUG(DBGS() << "capture contraction dimensions"); | 
|  | StringRef convMessage = linalg::detail::getMatchContractionMessage( | 
|  | mlir::linalg::detail::isContractionInterfaceImpl( | 
|  | linalgOp, &contractionDims.value)); | 
|  | if (convMessage.empty()) | 
|  | return true; | 
|  | LLVM_DEBUG(llvm::dbgs() << " (" << convMessage << ")"); | 
|  | return false; | 
|  | }); | 
|  | } | 
|  |  | 
|  | transform_ext::StructuredOpMatcher & | 
|  | transform_ext::StructuredOpMatcher::convolutionDims(CaptureConvDims convDims) { | 
|  | return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { | 
|  | LLVM_DEBUG(DBGS() << "capture convolution dimensions"); | 
|  | StringRef convMessage = linalg::detail::getMatchConvolutionMessage( | 
|  | mlir::linalg::detail::isConvolutionInterfaceImpl(linalgOp, | 
|  | &convDims.value)); | 
|  | if (convMessage.empty()) | 
|  | return true; | 
|  | LLVM_DEBUG(llvm::dbgs() << " (" << convMessage << ")"); | 
|  | return false; | 
|  | }); | 
|  | } | 
|  |  | 
|  | transform_ext::StructuredOpMatcher::StructuredOpMatcher( | 
|  | StructuredOpMatcher &A, StructuredOpMatcher &B) { | 
|  |  | 
|  | addPredicate([&A, &B](linalg::LinalgOp linalgOp) -> bool { | 
|  | LLVM_DEBUG(DBGS() << "start recursive lhs OR match {\n"); | 
|  | { | 
|  | auto debugRAII = llvm::make_scope_exit( | 
|  | [] { LLVM_DEBUG(DBGS() << "} end recursive match"); }); | 
|  | if (A.match(linalgOp)) | 
|  | return true; | 
|  | } | 
|  | LLVM_DEBUG(DBGS() << "start recursive rhs OR match {\n"); | 
|  | { | 
|  | auto debugRAII = llvm::make_scope_exit( | 
|  | [] { LLVM_DEBUG(DBGS() << "} end recursive match"); }); | 
|  | if (B.match(linalgOp)) | 
|  | return true; | 
|  | } | 
|  | return false; | 
|  | }); | 
|  | recordNestedMatcher(A); | 
|  | recordNestedMatcher(B); | 
|  | } | 
|  |  | 
|  | //===---------------------------------------------------------------------===// | 
|  | // Constraints on input operands. | 
|  | //===---------------------------------------------------------------------===// | 
|  |  | 
|  | void transform_ext::StructuredOpMatcher::addInputMatcher( | 
|  | int64_t position, std::function<bool(Operation *)> matcher, | 
|  | OptionalMatch optional) { | 
|  | addInputMatcher( | 
|  | position, | 
|  | // No need to handle optional inside the lambda, the wrapper will do that. | 
|  | [matcher = std::move(matcher)](Value value) { | 
|  | Operation *definingOp = value.getDefiningOp(); | 
|  | return definingOp && matcher(definingOp); | 
|  | }, | 
|  | optional); | 
|  | } | 
|  |  | 
|  | void transform_ext::StructuredOpMatcher::addInputMatcher( | 
|  | int64_t position, std::function<bool(Value)> matcher, | 
|  | OptionalMatch optional) { | 
|  | addPredicate([position, optional, matcher = std::move(matcher)]( | 
|  | linalg::LinalgOp linalgOp) -> bool { | 
|  | int64_t transformedPosition = | 
|  | position >= 0 ? position : linalgOp.getNumDpsInputs() + position; | 
|  | if (transformedPosition >= linalgOp.getNumDpsInputs()) { | 
|  | LLVM_DEBUG(DBGS() << "input operand #" << position | 
|  | << " does not exist but match required"); | 
|  | return false; | 
|  | } | 
|  |  | 
|  | LLVM_DEBUG(DBGS() << "input operand #" << position | 
|  | << (optional.value ? " (optional match) " : " ") | 
|  | << "is\n"); | 
|  |  | 
|  | // We MUST run the matcher at this point, even if the match is optional, | 
|  | // to allow for capture. | 
|  | LLVM_DEBUG(DBGS() << "start recursive match {\n"); | 
|  | auto debugRAII = llvm::make_scope_exit( | 
|  | [] { LLVM_DEBUG(DBGS() << "} end recursive match"); }); | 
|  | if (matcher(linalgOp.getDpsInputOperand(transformedPosition)->get())) | 
|  | return true; | 
|  | return optional.value; | 
|  | }); | 
|  | } | 
|  |  | 
|  | transform_ext::StructuredOpMatcher & | 
|  | transform_ext::StructuredOpMatcher::input(AllOperands tag, IsPermutation) { | 
|  | return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { | 
|  | LLVM_DEBUG(DBGS() << "all input operands have permutation maps"); | 
|  | // all_of with a lambda requires const-casting dance, so using a loop. | 
|  | for (OpOperand *operand : linalgOp.getDpsInputOperands()) { | 
|  | if (!linalgOp.getMatchingIndexingMap(operand).isPermutation()) | 
|  | return false; | 
|  | } | 
|  | return true; | 
|  | }); | 
|  | } | 
|  |  | 
|  | transform_ext::StructuredOpMatcher & | 
|  | transform_ext::StructuredOpMatcher::input(AllOperands tag, | 
|  | IsProjectedPermutation) { | 
|  | return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { | 
|  | LLVM_DEBUG(DBGS() << "all input operands have projected permutation maps"); | 
|  | // all_of with a lambda requires const-casting dance, so using a loop. | 
|  | for (OpOperand *operand : linalgOp.getDpsInputOperands()) { | 
|  | if (!linalgOp.getMatchingIndexingMap(operand).isProjectedPermutation()) | 
|  | return false; | 
|  | } | 
|  | return true; | 
|  | }); | 
|  | } | 
|  |  | 
|  | /// Helper to check if the map is an identity map with a projected dim. | 
|  | static bool isProjectedMap(AffineMap map, int64_t projectedDim) { | 
|  | if (!map.isProjectedPermutation()) | 
|  | return false; | 
|  | int64_t dimCounter = 0; | 
|  | for (unsigned i = 0, e = map.getNumResults(); i < e; i++) { | 
|  | // Skip the project dim. | 
|  | if (dimCounter == projectedDim) | 
|  | dimCounter++; | 
|  | if (map.getDimPosition(i) != dimCounter++) { | 
|  | return false; | 
|  | } | 
|  | } | 
|  | return true; | 
|  | } | 
|  |  | 
|  | /// Helper to turn a potentially negative index to positive within the range | 
|  | /// [0, ub) and indicate whether the transformed index is in bounds. | 
|  | static bool makeValidPositiveIndex(int64_t &index, int64_t ub) { | 
|  | int64_t positiveIndex = index >= 0 ? index : ub + index; | 
|  | if (positiveIndex < 0 || ub < positiveIndex) { | 
|  | LLVM_DEBUG(DBGSNL() << "  index out of range"); | 
|  | return false; | 
|  | } | 
|  | index = positiveIndex; | 
|  | return true; | 
|  | } | 
|  |  | 
|  | transform_ext::StructuredOpMatcher & | 
|  | transform_ext::StructuredOpMatcher::input(SmallVector<int64_t> &&positions, | 
|  | IsProjected dim) { | 
|  | return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { | 
|  | LLVM_DEBUG(DBGS() << "operands "; | 
|  | llvm::interleaveComma(positions, llvm::dbgs()); | 
|  | llvm::dbgs() << " have a permutation maps with " << dim.value | 
|  | << " projected"); | 
|  | int64_t updatedDim = dim.value; | 
|  | if (!makeValidPositiveIndex(updatedDim, linalgOp.getNumLoops())) | 
|  | return false; | 
|  | for (int64_t position : positions) { | 
|  | int64_t updatedPosition = position; | 
|  | if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInputs())) | 
|  | return false; | 
|  | OpOperand *operand = linalgOp.getDpsInputOperand(updatedPosition); | 
|  | if (!isProjectedMap(linalgOp.getMatchingIndexingMap(operand), updatedDim)) | 
|  | return false; | 
|  | } | 
|  | return true; | 
|  | }); | 
|  | } | 
|  |  | 
|  | transform_ext::StructuredOpMatcher & | 
|  | transform_ext::StructuredOpMatcher::input(AllOperands tag, IsIdentity) { | 
|  | return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { | 
|  | LLVM_DEBUG(DBGS() << "all input operands have identity maps"); | 
|  | // all_of with a lambda requires const-casting dance, so using a loop. | 
|  | for (OpOperand *operand : linalgOp.getDpsInputOperands()) { | 
|  | if (!linalgOp.getMatchingIndexingMap(operand).isIdentity()) | 
|  | return false; | 
|  | } | 
|  | return true; | 
|  | }); | 
|  | } | 
|  |  | 
|  | transform_ext::StructuredOpMatcher & | 
|  | transform_ext::StructuredOpMatcher::input(SmallVector<int64_t> &&positions, | 
|  | IsIdentity) { | 
|  | return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { | 
|  | LLVM_DEBUG(DBGS() << "input operands "; | 
|  | llvm::interleaveComma(positions, llvm::dbgs()); | 
|  | llvm::dbgs() << " have identity maps"); | 
|  | // all_of with a lambda requires const-casting dance, so using a loop. | 
|  | for (int64_t position : positions) { | 
|  | int64_t updatedPosition = position; | 
|  | if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInputs())) | 
|  | return false; | 
|  | OpOperand *operand = linalgOp.getDpsInputOperand(updatedPosition); | 
|  | if (!linalgOp.getMatchingIndexingMap(operand).isIdentity()) | 
|  | return false; | 
|  | } | 
|  | return true; | 
|  | }); | 
|  | } | 
|  |  | 
|  | transform_ext::StructuredOpMatcher & | 
|  | transform_ext::StructuredOpMatcher::input(int64_t position, | 
|  | ElementTypeBitWidth width) { | 
|  | return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { | 
|  | LLVM_DEBUG(DBGS() << "input operand #" << position | 
|  | << " has elemental type with bit width " << width.value); | 
|  | int64_t updatedPosition = position; | 
|  | if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInputs())) | 
|  | return false; | 
|  | auto shapedType = linalgOp.getDpsInputOperand(updatedPosition) | 
|  | ->get() | 
|  | .getType() | 
|  | .dyn_cast<ShapedType>(); | 
|  | return shapedType && shapedType.getElementType().isIntOrFloat() && | 
|  | shapedType.getElementType().getIntOrFloatBitWidth() == width.value; | 
|  | }); | 
|  | } | 
|  |  | 
|  | transform_ext::StructuredOpMatcher & | 
|  | transform_ext::StructuredOpMatcher::input(int64_t position, | 
|  | CaptureElementTypeBitWidth width) { | 
|  | return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { | 
|  | LLVM_DEBUG(DBGS() << "input operand #" << position << " capture bitwidth"); | 
|  | int64_t updatedPosition = position; | 
|  | if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInputs())) | 
|  | return false; | 
|  | auto shapedType = linalgOp.getDpsInputOperand(updatedPosition) | 
|  | ->get() | 
|  | .getType() | 
|  | .dyn_cast<ShapedType>(); | 
|  | if (!shapedType || !shapedType.getElementType().isIntOrFloat()) | 
|  | return false; | 
|  | width.value = shapedType.getElementType().getIntOrFloatBitWidth(); | 
|  | return true; | 
|  | }); | 
|  | } | 
|  |  | 
|  | transform_ext::StructuredOpMatcher & | 
|  | transform_ext::StructuredOpMatcher::input(int64_t position, | 
|  | CaptureElementType elem) { | 
|  | return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { | 
|  | LLVM_DEBUG(DBGS() << "input operand #" << position | 
|  | << " capture element type"); | 
|  | int64_t updatedPosition = position; | 
|  | if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInputs())) | 
|  | return false; | 
|  | auto shapedType = linalgOp.getDpsInputOperand(updatedPosition) | 
|  | ->get() | 
|  | .getType() | 
|  | .dyn_cast<ShapedType>(); | 
|  | if (!shapedType) { | 
|  | LLVM_DEBUG(DBGSNL() << "  not a shaped type"); | 
|  | return false; | 
|  | } | 
|  | elem.value = shapedType.getElementType(); | 
|  | return true; | 
|  | }); | 
|  | } | 
|  |  | 
|  | transform_ext::StructuredOpMatcher & | 
|  | transform_ext::StructuredOpMatcher::input(NumEqualsTo num) { | 
|  | return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { | 
|  | LLVM_DEBUG(DBGS() << "number of input operands == " << num.value); | 
|  | return linalgOp.getNumDpsInputs() == num.value; | 
|  | }); | 
|  | } | 
|  |  | 
|  | transform_ext::StructuredOpMatcher & | 
|  | transform_ext::StructuredOpMatcher::input(int64_t position, | 
|  | ConstantFloatMinOrMinusInf) { | 
|  | return input(position, [](llvm::APFloat f) { | 
|  | return (f.isLargest() || f.isInfinity()) && f.isNegative(); | 
|  | }); | 
|  | } | 
|  |  | 
|  | transform_ext::StructuredOpMatcher & | 
|  | transform_ext::StructuredOpMatcher::input(int64_t position, ConstantFloatZero) { | 
|  | return input(position, [](llvm::APFloat f) { return f.isZero(); }); | 
|  | } | 
|  |  | 
|  | transform_ext::StructuredOpMatcher &transform_ext::StructuredOpMatcher::input( | 
|  | int64_t position, std::function<bool(llvm::APFloat)> floatValueFn) { | 
|  | return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { | 
|  | LLVM_DEBUG(DBGS() << "input operand #" << position | 
|  | << " is a special floating point constant"); | 
|  | int64_t updatedPosition = position; | 
|  | if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInputs())) | 
|  | return false; | 
|  | auto cstOp = linalgOp.getDpsInputOperand(updatedPosition) | 
|  | ->get() | 
|  | .getDefiningOp<arith::ConstantFloatOp>(); | 
|  | if (!cstOp) | 
|  | return false; | 
|  | return floatValueFn(cstOp.value()); | 
|  | }); | 
|  | } | 
|  |  | 
|  | //===---------------------------------------------------------------------===// | 
|  | // Constraints on output operands. | 
|  | //===---------------------------------------------------------------------===// | 
|  |  | 
|  | void transform_ext::StructuredOpMatcher::addOutputMatcher( | 
|  | int64_t position, std::function<bool(Operation *)> matcher, | 
|  | OptionalMatch optional) { | 
|  | addPredicate([position, optional, matcher = std::move(matcher)]( | 
|  | linalg::LinalgOp linalgOp) -> bool { | 
|  | LLVM_DEBUG(DBGS() << "output operand #" << position | 
|  | << (optional.value ? " (optional match) " | 
|  | : " (mandatory match) ") | 
|  | << "is produced by\n"); | 
|  | int64_t updatedPosition = position; | 
|  | if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInits())) | 
|  | return false; | 
|  | Operation *definingOp = | 
|  | linalgOp.getDpsInitOperand(updatedPosition)->get().getDefiningOp(); | 
|  | if (!definingOp) | 
|  | return optional.value; | 
|  | // We MUST run the matcher at this point, even if the match is optional, | 
|  | // to allow for capture. | 
|  | LLVM_DEBUG(DBGS() << "start recursive match {\n"); | 
|  | auto debugRAII = llvm::make_scope_exit( | 
|  | [] { LLVM_DEBUG(DBGS() << "} end recursive match"); }); | 
|  | if (matcher(definingOp)) | 
|  | return true; | 
|  | return optional.value; | 
|  | }); | 
|  | } | 
|  |  | 
|  | transform_ext::StructuredOpMatcher & | 
|  | transform_ext::StructuredOpMatcher::output(AllOperands tag, IsPermutation) { | 
|  | return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { | 
|  | LLVM_DEBUG(DBGS() << "all output operands have permutation maps"); | 
|  | for (OpOperand &operand : linalgOp.getDpsInitsMutable()) { | 
|  | if (!linalgOp.getMatchingIndexingMap(&operand).isPermutation()) | 
|  | return false; | 
|  | } | 
|  | return true; | 
|  | }); | 
|  | } | 
|  |  | 
|  | transform_ext::StructuredOpMatcher & | 
|  | transform_ext::StructuredOpMatcher::output(AllOperands tag, | 
|  | IsProjectedPermutation) { | 
|  | return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { | 
|  | LLVM_DEBUG(DBGS() << "all output operands have projected permutation maps"); | 
|  | for (OpOperand &operand : linalgOp.getDpsInitsMutable()) { | 
|  | if (!linalgOp.getMatchingIndexingMap(&operand).isProjectedPermutation()) | 
|  | return false; | 
|  | } | 
|  | return true; | 
|  | }); | 
|  | } | 
|  |  | 
|  | transform_ext::StructuredOpMatcher & | 
|  | transform_ext::StructuredOpMatcher::output(AllOperands tag, IsProjected dim) { | 
|  | return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { | 
|  | LLVM_DEBUG(DBGS() << "all output operands have a maps with projected"); | 
|  | int64_t updatedDim = dim.value; | 
|  | if (!makeValidPositiveIndex(updatedDim, linalgOp.getNumLoops())) | 
|  | return false; | 
|  | // all_of with a lambda requires const-casting dance, so using a loop. | 
|  | for (OpOperand &operand : linalgOp.getDpsInitsMutable()) { | 
|  | if (!isProjectedMap(linalgOp.getMatchingIndexingMap(&operand), | 
|  | updatedDim)) | 
|  | return false; | 
|  | } | 
|  | return true; | 
|  | }); | 
|  | } | 
|  |  | 
|  | transform_ext::StructuredOpMatcher & | 
|  | transform_ext::StructuredOpMatcher::output(AllOperands tag, IsIdentity) { | 
|  | return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { | 
|  | LLVM_DEBUG(DBGS() << "all output operands have identity permutation maps"); | 
|  | for (OpOperand &operand : linalgOp.getDpsInitsMutable()) { | 
|  | if (!linalgOp.getMatchingIndexingMap(&operand).isIdentity()) | 
|  | return false; | 
|  | } | 
|  | return true; | 
|  | }); | 
|  | } | 
|  |  | 
|  | transform_ext::StructuredOpMatcher & | 
|  | transform_ext::StructuredOpMatcher::output(int64_t position, | 
|  | ElementTypeBitWidth width) { | 
|  | return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { | 
|  | LLVM_DEBUG(DBGS() << "output operand #" << position | 
|  | << " has elemental type with bit width " << width.value); | 
|  | int64_t updatedPosition = position; | 
|  | if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInits())) | 
|  | return false; | 
|  | auto shapedType = linalgOp.getDpsInitOperand(updatedPosition) | 
|  | ->get() | 
|  | .getType() | 
|  | .dyn_cast<ShapedType>(); | 
|  | return shapedType && shapedType.getElementType().isIntOrFloat() && | 
|  | shapedType.getElementType().getIntOrFloatBitWidth() == width.value; | 
|  | }); | 
|  | } | 
|  |  | 
|  | transform_ext::StructuredOpMatcher & | 
|  | transform_ext::StructuredOpMatcher::output(int64_t position, | 
|  | CaptureElementTypeBitWidth width) { | 
|  | return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { | 
|  | LLVM_DEBUG(DBGS() << "output operand #" << position << " capture bitwidth"); | 
|  | int64_t updatedPosition = position; | 
|  | if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInits())) | 
|  | return false; | 
|  | auto shapedType = linalgOp.getDpsInitOperand(updatedPosition) | 
|  | ->get() | 
|  | .getType() | 
|  | .dyn_cast<ShapedType>(); | 
|  | if (!shapedType || !shapedType.getElementType().isIntOrFloat()) { | 
|  | LLVM_DEBUG(DBGSNL() << "  could not infer element type"); | 
|  | return false; | 
|  | } | 
|  | width.value = shapedType.getElementType().getIntOrFloatBitWidth(); | 
|  | return true; | 
|  | }); | 
|  | } | 
|  |  | 
|  | transform_ext::StructuredOpMatcher & | 
|  | transform_ext::StructuredOpMatcher::output(int64_t position, | 
|  | CaptureElementType elem) { | 
|  | return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { | 
|  | LLVM_DEBUG(DBGS() << "output operand #" << position | 
|  | << " capture element type"); | 
|  | int64_t updatedPosition = position; | 
|  | if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInits())) | 
|  | return false; | 
|  | auto shapedType = linalgOp.getDpsInitOperand(updatedPosition) | 
|  | ->get() | 
|  | .getType() | 
|  | .dyn_cast<ShapedType>(); | 
|  | if (!shapedType) { | 
|  | LLVM_DEBUG(DBGSNL() << "  not a shaped type"); | 
|  | return false; | 
|  | } | 
|  | elem.value = shapedType.getElementType(); | 
|  | return true; | 
|  | }); | 
|  | } | 
|  |  | 
|  | transform_ext::StructuredOpMatcher & | 
|  | transform_ext::StructuredOpMatcher::output(int64_t position, | 
|  | SingleCombinerReduction tag) { | 
|  | return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { | 
|  | LLVM_DEBUG(DBGS() << "output operand #" << position | 
|  | << " is populated by a single-combiner reduction"); | 
|  | int64_t updatedPosition = position; | 
|  | if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInits())) | 
|  | return false; | 
|  | SmallVector<Operation *> combinerOps; | 
|  | return matchReduction(linalgOp.getRegionOutputArgs(), updatedPosition, | 
|  | combinerOps) && | 
|  | llvm::hasSingleElement(combinerOps); | 
|  | }); | 
|  | } | 
|  |  | 
|  | transform_ext::StructuredOpMatcher & | 
|  | transform_ext::StructuredOpMatcher::output(NumEqualsTo num) { | 
|  | return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { | 
|  | LLVM_DEBUG(DBGS() << "number of output operands == " << num.value); | 
|  | return linalgOp.getNumDpsInits() == num.value; | 
|  | }); | 
|  | } | 
|  |  | 
|  | //===---------------------------------------------------------------------===// | 
|  | // Constraints on results. | 
|  | //===---------------------------------------------------------------------===// | 
|  |  | 
|  | void transform_ext::StructuredOpMatcher::addResultMatcher( | 
|  | int64_t position, HasAnyUse tag, std::function<bool(Operation *)> matcher, | 
|  | OptionalMatch optional) { | 
|  | addPredicate([matcher = std::move(matcher), optional, | 
|  | position](linalg::LinalgOp linalgOp) -> bool { | 
|  | LLVM_DEBUG(DBGS() << "result #" << position | 
|  | << (optional.value ? " (optional match) " | 
|  | : " (mandatory match) ") | 
|  | << "has a use\n"); | 
|  | int64_t updatedPosition = position; | 
|  | if (!makeValidPositiveIndex(updatedPosition, linalgOp->getNumResults())) | 
|  | return false; | 
|  |  | 
|  | // We MUST run the matcher at this point, even if the match is optional, | 
|  | // to allow for capture. | 
|  | LLVM_DEBUG(DBGS() << "start recursive match {\n"); | 
|  | auto debugRAII = llvm::make_scope_exit( | 
|  | [] { LLVM_DEBUG(DBGS() << "} end recursive match"); }); | 
|  | if (llvm::any_of(linalgOp->getResult(updatedPosition).getUsers(), | 
|  | [&matcher](Operation *op) { return matcher(op); })) { | 
|  | return true; | 
|  | } | 
|  | return optional.value; | 
|  | }); | 
|  | } | 
|  |  | 
|  | //===-------------------------------------------------------------------===// | 
|  | // Constraints on op region. | 
|  | //===-------------------------------------------------------------------===// | 
|  |  | 
|  | transform_ext::StructuredOpMatcher & | 
|  | transform_ext::StructuredOpMatcher::singleOpWithCanonicaleArgs( | 
|  | StringRef opcode, bool commutative) { | 
|  | return addPredicate([=](linalg::LinalgOp linalgOp) { | 
|  | if (linalgOp.getBlock()->getOperations().size() != 2) | 
|  | return false; | 
|  | Operation *innerOp = &(*linalgOp.getBlock()->getOperations().begin()); | 
|  | if (innerOp->getName().getStringRef() != opcode || | 
|  | innerOp->getNumResults() != 1) | 
|  | return false; | 
|  | Operation *yieldOp = linalgOp.getBlock()->getTerminator(); | 
|  | if (yieldOp->getNumOperands() != 1) | 
|  | return false; | 
|  | if (yieldOp->getOperand(0).getDefiningOp() != innerOp) | 
|  | return false; | 
|  | if (commutative && innerOp->getNumOperands() == 2) { | 
|  | auto arg0 = dyn_cast<BlockArgument>(innerOp->getOperand(0)); | 
|  | auto arg1 = dyn_cast<BlockArgument>(innerOp->getOperand(1)); | 
|  | if (!arg0 || !arg1) | 
|  | return false; | 
|  | if (arg0.getParentBlock() != linalgOp.getBlock() || | 
|  | arg1.getParentBlock() != linalgOp.getBlock()) | 
|  | return false; | 
|  | if (!((arg0.getArgNumber() == 0 && arg1.getArgNumber() == 1) || | 
|  | (arg1.getArgNumber() == 0 && arg0.getArgNumber() == 1))) | 
|  | return false; | 
|  | } else { | 
|  | for (auto [index, operand] : llvm::enumerate(innerOp->getOperands())) { | 
|  | auto arg = dyn_cast<BlockArgument>(operand); | 
|  | if (!arg || arg.getParentBlock() != linalgOp.getBlock() || | 
|  | arg.getArgNumber() != index) | 
|  | return false; | 
|  | } | 
|  | } | 
|  | return true; | 
|  | }); | 
|  | } | 
|  |  | 
|  | transform_ext::StructuredOpMatcher & | 
|  | transform_ext::StructuredOpMatcher::isFloatReciprocal() { | 
|  | return addPredicate([=](linalg::LinalgOp linalgOp) { | 
|  | LLVM_DEBUG(DBGS() << "op region represents a reciprocal operation"); | 
|  | if (linalgOp.getBlock()->getOperations().size() != 2) | 
|  | return false; | 
|  | Operation *innerOp = &(*linalgOp.getBlock()->getOperations().begin()); | 
|  | if (!isa<arith::DivFOp>(innerOp) || innerOp->getNumResults() != 1) | 
|  | return false; | 
|  | Operation *yieldOp = linalgOp.getBlock()->getTerminator(); | 
|  | if (yieldOp->getNumOperands() != 1) | 
|  | return false; | 
|  | if (yieldOp->getOperand(0).getDefiningOp() != innerOp) | 
|  | return false; | 
|  | auto cst = innerOp->getOperand(0).getDefiningOp<arith::ConstantFloatOp>(); | 
|  | if (!cst || cst.value().convertToDouble() != 1.0) | 
|  | return false; | 
|  | auto arg = dyn_cast<BlockArgument>(innerOp->getOperand(1)); | 
|  | if (!arg || arg.getParentBlock() != linalgOp.getBlock() || | 
|  | arg.getArgNumber() != 0) | 
|  | return false; | 
|  | return true; | 
|  | }); | 
|  | } | 
|  |  | 
|  | transform_ext::StructuredOpMatcher & | 
|  | transform_ext::StructuredOpMatcher::passThroughOp() { | 
|  | return addPredicate([=](linalg::LinalgOp linalgOp) { | 
|  | if (linalgOp.getBlock()->getOperations().size() != 1) | 
|  | return false; | 
|  | Operation *yieldOp = linalgOp.getBlock()->getTerminator(); | 
|  | for (auto [index, operand] : llvm::enumerate(yieldOp->getOperands())) { | 
|  | auto arg = dyn_cast<BlockArgument>(operand); | 
|  | if (!arg || arg.getParentBlock() != linalgOp.getBlock() || | 
|  | arg.getArgNumber() != index) | 
|  | return false; | 
|  | } | 
|  | return true; | 
|  | }); | 
|  | } | 
|  |  | 
|  | transform_ext::StructuredOpMatcher & | 
|  | transform_ext::StructuredOpMatcher::hasContractionBody( | 
|  | function_ref<bool(Operation *)> isaElemOpTy, | 
|  | function_ref<bool(Operation *)> isaReductionOpTy, StringRef elemOpName, | 
|  | StringRef reductionOpName) { | 
|  | return addPredicate([=](linalg::LinalgOp linalgOp) { | 
|  | LLVM_DEBUG(DBGS() << "op region is a " << elemOpName << "/" | 
|  | << reductionOpName << " contraction ("); | 
|  | auto scopeExitPrinter = llvm::make_scope_exit( | 
|  | [] { LLVM_DEBUG(llvm::dbgs() << " check failed)"); }); | 
|  |  | 
|  | Block *body = linalgOp.getBlock(); | 
|  | if (!llvm::hasNItems(*body, 3)) { | 
|  | LLVM_DEBUG(llvm::dbgs() << "three-operation body"); | 
|  | return false; | 
|  | } | 
|  | if (body->getNumArguments() != 3) { | 
|  | LLVM_DEBUG(llvm::dbgs() << "three-argument block"); | 
|  | return false; | 
|  | } | 
|  |  | 
|  | Operation *elemOp = &(*linalgOp.getBlock()->getOperations().begin()); | 
|  | Operation *reductionOp = elemOp->getNextNode(); | 
|  | Operation *yieldOp = reductionOp->getNextNode(); | 
|  | if (!isaElemOpTy(elemOp)) { | 
|  | LLVM_DEBUG(llvm::dbgs() << "first operation is a " << elemOpName); | 
|  | return false; | 
|  | } | 
|  | if (!isaReductionOpTy(reductionOp)) { | 
|  | LLVM_DEBUG(llvm::dbgs() << "second operation is a " << reductionOpName); | 
|  | return false; | 
|  | } | 
|  | if (yieldOp->getNumOperands() != 1) { | 
|  | LLVM_DEBUG(llvm::dbgs() << "one value yielded"); | 
|  | return false; | 
|  | } | 
|  | if (yieldOp->getOperand(0).getDefiningOp() != reductionOp) { | 
|  | LLVM_DEBUG(llvm::dbgs() << "yielded value produced by the second op"); | 
|  | return false; | 
|  | } | 
|  | if (elemOp->getNumOperands() != 2 || elemOp->getNumResults() != 1) { | 
|  | LLVM_DEBUG(llvm::dbgs() << "first op has two operands and one result"); | 
|  | return false; | 
|  | } | 
|  | if (reductionOp->getNumOperands() != 2 || | 
|  | reductionOp->getNumResults() != 1) { | 
|  | LLVM_DEBUG(llvm::dbgs() << "second op has two operands and one result"); | 
|  | return false; | 
|  | } | 
|  |  | 
|  | SmallVector<Value, 2> expectedReductionOperands = {body->getArgument(2), | 
|  | elemOp->getResult(0)}; | 
|  | if (!llvm::equal(expectedReductionOperands, reductionOp->getOperands()) && | 
|  | !llvm::equal(llvm::reverse(expectedReductionOperands), | 
|  | reductionOp->getOperands())) { | 
|  | LLVM_DEBUG(llvm::dbgs() << "operands of the second op"); | 
|  | return false; | 
|  | } | 
|  |  | 
|  | ValueRange expectedElemOperands = body->getArguments().take_front(2); | 
|  | if (!llvm::equal(expectedElemOperands, elemOp->getOperands()) && | 
|  | !llvm::equal(llvm::reverse(expectedElemOperands), | 
|  | elemOp->getOperands())) { | 
|  | LLVM_DEBUG(llvm::dbgs() << "operands of the first op"); | 
|  | return false; | 
|  | } | 
|  |  | 
|  | scopeExitPrinter.release(); | 
|  | LLVM_DEBUG(llvm::dbgs() << "success)"); | 
|  | return true; | 
|  | }); | 
|  | } | 
|  |  | 
|  | void transform_ext::detail::debugOutputForConcreteOpMatcherConstructor( | 
|  | StringRef name) { | 
|  | LLVM_DEBUG(DBGS() << "op is a " << name << "'"); | 
|  | } | 
|  |  | 
|  | //===---------------------------------------------------------------------===// | 
|  | // TensorPadOpMatcher | 
|  | //===---------------------------------------------------------------------===// | 
|  |  | 
|  | transform_ext::TensorPadOpMatcher & | 
|  | transform_ext::TensorPadOpMatcher::low(ArrayRef<int64_t> sizes) { | 
|  | return addPredicate([=](tensor::PadOp tensorPad) { | 
|  | LLVM_DEBUG({ | 
|  | DBGS() << "low pad sizes are "; | 
|  | llvm::interleaveComma(sizes, llvm::dbgs()); | 
|  | }); | 
|  | for (auto [ofr, sz] : llvm::zip(tensorPad.getMixedLowPad(), sizes)) { | 
|  | if (isConstantIntValue(ofr, sz)) | 
|  | return false; | 
|  | } | 
|  | return true; | 
|  | }); | 
|  | } | 
|  |  | 
|  | transform_ext::TensorPadOpMatcher & | 
|  | transform_ext::TensorPadOpMatcher::low(AllDims tag, int64_t size) { | 
|  | return addPredicate([=](tensor::PadOp tensorPad) { | 
|  | LLVM_DEBUG(DBGS() << "all low pad sizes are " << size); | 
|  | return llvm::all_of(tensorPad.getMixedLowPad(), [&](OpFoldResult ofr) { | 
|  | return isConstantIntValue(ofr, size); | 
|  | }); | 
|  | }); | 
|  | } | 
|  |  | 
|  | transform_ext::TensorPadOpMatcher & | 
|  | transform_ext::TensorPadOpMatcher::high(ArrayRef<int64_t> sizes) { | 
|  | return addPredicate([=](tensor::PadOp tensorPad) { | 
|  | LLVM_DEBUG({ | 
|  | DBGS() << "high pad sizes are "; | 
|  | llvm::interleaveComma(sizes, llvm::dbgs()); | 
|  | }); | 
|  | for (auto [ofr, sz] : llvm::zip(tensorPad.getMixedHighPad(), sizes)) { | 
|  | if (isConstantIntValue(ofr, sz)) | 
|  | return false; | 
|  | } | 
|  | return true; | 
|  | }); | 
|  | } | 
|  |  | 
|  | transform_ext::TensorPadOpMatcher & | 
|  | transform_ext::TensorPadOpMatcher::high(AllDims tag, int64_t size) { | 
|  | return addPredicate([=](tensor::PadOp tensorPad) { | 
|  | LLVM_DEBUG(DBGS() << "all high pad sizes are " << size); | 
|  | return llvm::all_of(tensorPad.getMixedHighPad(), [&](OpFoldResult ofr) { | 
|  | return isConstantIntValue(ofr, size); | 
|  | }); | 
|  | }); | 
|  | } | 
|  |  | 
|  | transform_ext::TensorPadOpMatcher & | 
|  | transform_ext::TensorPadOpMatcher::yieldsExternalValue() { | 
|  | return addPredicate([=](tensor::PadOp tensorPad) { | 
|  | LLVM_DEBUG(DBGS() << "pad body yields an externally-defined value"); | 
|  | Block *body = tensorPad.getBody(); | 
|  | if (!llvm::hasSingleElement(*body)) | 
|  | return false; | 
|  | return llvm::all_of(body->getTerminator()->getOperands(), | 
|  | [body](Value operand) { | 
|  | auto arg = dyn_cast<BlockArgument>(operand); | 
|  | return !arg || arg.getOwner() != body; | 
|  | }); | 
|  | }); | 
|  | } | 
|  |  | 
|  | //===---------------------------------------------------------------------===// | 
|  | // MatchCallbackResult. | 
|  | //===---------------------------------------------------------------------===// | 
|  |  | 
|  | ArrayRef<Operation *> | 
|  | transform_ext::MatchCallbackResult::getPayloadGroup(int64_t position) const { | 
|  | assert(position < payloadGroupLengths.size()); | 
|  | int64_t start = 0; | 
|  | for (int64_t i = 0; i < position; ++i) { | 
|  | start += payloadGroupLengths[i]; | 
|  | } | 
|  | return llvm::ArrayRef(payloadOperations) | 
|  | .slice(start, payloadGroupLengths[position]); | 
|  | } | 
|  |  | 
|  | //===---------------------------------------------------------------------===// | 
|  | // Case-specific matcher builders. | 
|  | //===---------------------------------------------------------------------===// | 
|  |  | 
// Warp size constant; presumably the number of threads in a CUDA warp, going
// by the name - confirm against its uses.
static constexpr int64_t kCudaWarpSize{32};
|  |  | 
|  | void transform_ext::makeReductionMatcher( | 
|  | transform_ext::MatcherContext &matcherContext, | 
|  | transform_ext::StructuredOpMatcher *&reductionCapture, | 
|  | transform_ext::StructuredOpMatcher *&fillCapture, | 
|  | transform_ext::StructuredOpMatcher *&leadingCapture, | 
|  | transform_ext::StructuredOpMatcher *&trailingCapture, | 
|  | MatchedReductionCaptures &captures, bool mustMatchEntireFunc) { | 
|  | // The core part of the matcher is anchored on a particular reduction op. | 
|  | auto &reduction = | 
|  | m_StructuredOp(matcherContext) | 
|  | // Op has at least a parallel a reduction dimension and at | 
|  | // most 3 parallel dimensions. | 
|  | // TODO: relax once we have global collapse/expand_shape. | 
|  | // | 
|  | .rank(NumGreaterEqualTo(2)) | 
|  | .rank(NumLowerEqualTo(4)) | 
|  | .rank(CaptureRank(captures.reductionRank)) | 
|  | // Op has a single most-minor reduction. | 
|  | .dim(-1, utils::IteratorType::reduction) | 
|  | // Capture op sizes. | 
|  | .dim(AllDims(), CaptureDims(captures.reductionOpSizes)) | 
|  | // All other dimensions are parallel. | 
|  | .dim(AllDimsExcept({-1}), utils::IteratorType::parallel) | 
|  | // Single input for now, can be arbitrary projected permutations. | 
|  | // TODO: Multiple inputs, can be arbitrary projected permutations. | 
|  | // TODO: Watch out for multiple inputs though as a reduction turns | 
|  | //       into a contraction when mixed with projected | 
|  | //       permutations. A reduction is often bandwidth bound but | 
|  | //       contraction is a different beast that is compute bound | 
|  | //       and has a very different schedule. | 
|  | // | 
|  | .input(NumEqualsTo(1)) | 
|  | .input(AllOperands(), IsProjectedPermutation()) | 
|  | // Single output supported atm. | 
|  | // TODO: Multiple outputs. | 
|  | // | 
|  | .output(NumEqualsTo(1)) | 
|  | // A reduction output must be a projected permutation, match it but we | 
|  | // could also drop this technically. | 
|  | .output(AllOperands(), IsProjectedPermutation()) | 
|  | // Only single combiner for now due to reduction warp | 
|  | // distribution. | 
|  | // TODO: relax this once reduction distribution is more powerful. | 
|  | // | 
|  | .output(0, CaptureElementTypeBitWidth( | 
|  | captures.reductionOutputElementalTypeBitWidth)) | 
|  | .output(0, SingleCombinerReduction()); | 
|  | reductionCapture = &reduction; | 
|  |  | 
|  | // Mandatory FillOp must create the unique output of the reduction. | 
|  | // TODO: Relax this, as any map, broadcast, transpose should also work. | 
|  | // | 
|  | auto &fill = m_StructuredOp<linalg::FillOp>(matcherContext); | 
|  | reduction = reduction.output(NumEqualsTo(1)).output(0, fill); | 
|  | fillCapture = &fill; | 
|  |  | 
|  | // Optional leading or trailing op can be any map, transpose, broadcast but | 
|  | // not reduce or windowing operation for now. | 
|  | // It must create the unique input for the reduction. | 
|  | // TODO: match more optional leading ops, one per input of the reduction. | 
|  | // TODO: careful about multi-output and turning into a contraction. | 
|  | // | 
|  | transform_ext::StructuredOpMatcher commonLeadingOrTrailing = | 
|  | m_StructuredOp<linalg::GenericOp>(matcherContext) | 
|  | // All parallel dimensions. | 
|  | .dim(AllDims(), utils::IteratorType::parallel) | 
|  | // All inputs are any projected permutation. | 
|  | .input(AllOperands(), IsProjectedPermutation()) | 
|  | .output(AllOperands(), IsPermutation()) | 
|  | // leading and trailing may have 0, 1 or more input as long as they do | 
|  | // not come from unmatched ops. This extra constraint is taken care of | 
|  | // separately. This is also a noop but we document it. | 
|  | // TODO: Base and derived classes, atm this does not compile. | 
|  | // .input(NumGreaterEqualTo(0)) | 
|  | // Single output supported atm. | 
|  | // TODO: extend this. | 
|  | // | 
|  | .output(NumEqualsTo(1)); | 
|  | // TODO: match more optional leading ops, one per input of the reduction. | 
|  | // TODO: careful about multi-output and turning into a contraction. | 
|  | // | 
|  | auto &leading = | 
|  | m_StructuredOp(matcherContext, commonLeadingOrTrailing) | 
|  | .rank(CaptureRank(captures.maybeLeadingRank)) | 
|  | // Capture op sizes. | 
|  | .dim(AllDims(), CaptureDims(captures.leadingOpSizes)) | 
|  | // Capture output elemental type. | 
|  | .output(0, CaptureElementTypeBitWidth( | 
|  | captures.maybeLeadingOutputElementalTypeBitWidth)); | 
|  | reduction = reduction.input(0, leading, OptionalMatch()); | 
|  | leadingCapture = &leading; | 
|  |  | 
|  | // Optional trailing can be any map, transpose, broadcast but not reduce or | 
|  | // windowing operation for now. | 
|  | // It must be fed by the unique input for the reduction. | 
|  | // TODO: match more optional leading ops, one per input of the reduction. | 
|  | // TODO: careful about multi-output and turning into a contraction. | 
|  | // | 
|  | auto &trailing = | 
|  | m_StructuredOp(matcherContext, commonLeadingOrTrailing) | 
|  | .rank(CaptureRank(captures.maybeTrailingRank)) | 
|  | // Capture op sizes. | 
|  | .dim(AllDims(), CaptureDims(captures.trailingOpSizes)) | 
|  | // Capture output elemental type. | 
|  | .output(0, CaptureElementTypeBitWidth( | 
|  | captures.maybeTrailingOutputElementalTypeBitWidth)); | 
|  | reduction = reduction.result(0, HasAnyUse(), trailing, OptionalMatch()); | 
|  | if (mustMatchEntireFunc) | 
|  | reduction = reduction.allTilableOpsCaptured<func::FuncOp>(); | 
|  | trailingCapture = &trailing; | 
|  | } | 
|  |  | 
|  | void transform_ext::makeReductionMatcher(transform_ext::MatcherContext &context, | 
|  | StructuredOpMatcher *&reductionCapture, | 
|  | MatchedReductionCaptures &captures, | 
|  | bool mustMatchEntireFunc) { | 
|  | StructuredOpMatcher *fill; | 
|  | StructuredOpMatcher *leading; | 
|  | StructuredOpMatcher *trailing; | 
|  | makeReductionMatcher(context, reductionCapture, fill, leading, trailing, | 
|  | captures, mustMatchEntireFunc); | 
|  | } | 
|  |  | 
|  | void transform_ext::makeMatmulMatcher( | 
|  | transform_ext::MatcherContext &matcherContext, | 
|  | transform_ext::StructuredOpMatcher *&matmulCapture, | 
|  | transform_ext::StructuredOpMatcher *&fillCapture, | 
|  | transform_ext::StructuredOpMatcher *&trailingCapture, | 
|  | transform_ext::MatchedMatmulCaptures &captures, bool mustMatchEntireFunc) { | 
|  | auto &matmul = transform_ext::m_StructuredOp<linalg::MatmulOp>(matcherContext) | 
|  | // Capture op sizes. | 
|  | .dim(AllDims(), CaptureDims(captures.matmulOpSizes)) | 
|  | // Capture input/output element types. | 
|  | .input(0, CaptureElementType(captures.lhsElementType)) | 
|  | .input(1, CaptureElementType(captures.rhsElementType)) | 
|  | .output(0, CaptureElementType(captures.outputElementType)); | 
|  | matmulCapture = &matmul; | 
|  | // Mandatory FillOp must create the unique output of the reduction. | 
|  | auto &fill = transform_ext::m_StructuredOp<linalg::FillOp>(matcherContext); | 
|  | matmul = matmul.output(transform_ext::NumEqualsTo(1)).output(0, fill); | 
|  | fillCapture = &fill; | 
|  |  | 
|  | auto &trailing = m_StructuredOp<linalg::GenericOp>(matcherContext); | 
|  | matmul = matmul.result(0, HasAnyUse(), trailing, OptionalMatch()); | 
|  | if (mustMatchEntireFunc) | 
|  | matmul = matmul.allTilableOpsCaptured<func::FuncOp>(); | 
|  | trailingCapture = &trailing; | 
|  | } | 
|  |  | 
|  | void transform_ext::makeBatchMatmulMatcher( | 
|  | transform_ext::MatcherContext &matcherContext, | 
|  | transform_ext::StructuredOpMatcher *&bmmCapture, | 
|  | transform_ext::StructuredOpMatcher *&fillCapture, | 
|  | transform_ext::MatchedMatmulCaptures &captures, bool mustMatchEntireFunc) { | 
|  | auto &bmm = | 
|  | transform_ext::m_StructuredOp<linalg::BatchMatmulOp, linalg::GenericOp>( | 
|  | matcherContext) | 
|  | .hasContractionBody<arith::MulFOp, arith::AddFOp>() | 
|  | .rank(NumEqualsTo(4)) | 
|  | .dim(AllDims(), CaptureDims(captures.matmulOpSizes)) | 
|  | .dim(AllDimsExcept({-1}), utils::IteratorType::parallel) | 
|  | .dim(-1, utils::IteratorType::reduction) | 
|  | .contractionDims(CaptureContractionDims(captures.contractionDims)) | 
|  | .input(NumEqualsTo(2)) | 
|  | .input(0, CaptureElementType(captures.lhsElementType)) | 
|  | .input(1, CaptureElementType(captures.rhsElementType)) | 
|  | .output(0, CaptureElementType(captures.outputElementType)); | 
|  | bmmCapture = &bmm; | 
|  |  | 
|  | auto &fill = transform_ext::m_StructuredOp<linalg::FillOp>(matcherContext); | 
|  | bmm = bmm.output(0, fill); | 
|  | fillCapture = &fill; | 
|  |  | 
|  | if (mustMatchEntireFunc) | 
|  | bmm = bmm.allTilableOpsCaptured<func::FuncOp>(); | 
|  | } | 
|  |  | 
|  | /// Match sum(%src, broadcast(%reduction)) | 
|  | static void | 
|  | matchSubBroadcast(transform_ext::MatcherContext &matcherContext, | 
|  | transform_ext::StructuredOpMatcher &maxReduction, | 
|  | transform_ext::CapturingValueMatcher &softmaxSourceOperand, | 
|  | transform_ext::StructuredOpMatcher *&sub) { | 
|  | using namespace transform_ext; | 
|  |  | 
|  | auto &broadcast = | 
|  | transform_ext::m_StructuredOp<linalg::GenericOp>(matcherContext) | 
|  | .passThroughOp() | 
|  | .dim(AllDims(), utils::IteratorType::parallel) | 
|  | .input(NumEqualsTo(1)) | 
|  | .input(0, IsProjected(-1)) | 
|  | .output(NumEqualsTo(1)) | 
|  | .output(AllOperands(), IsIdentity()); | 
|  | broadcast = broadcast.input(0, maxReduction); | 
|  |  | 
|  | auto &subParallel = | 
|  | transform_ext::m_StructuredOp<linalg::GenericOp>(matcherContext) | 
|  | .singleOpWithCanonicaleArgs<arith::SubFOp>() | 
|  | .dim(AllDims(), utils::IteratorType::parallel) | 
|  | .input(NumEqualsTo(2)) | 
|  | .input(0, IsIdentity()) | 
|  | .input(1, IsIdentity()) | 
|  | .output(NumEqualsTo(1)) | 
|  | .output(AllOperands(), IsIdentity()); | 
|  | subParallel = subParallel.input(0, softmaxSourceOperand); | 
|  | subParallel = subParallel.input(1, broadcast); | 
|  |  | 
|  | auto &subBroadcast = | 
|  | transform_ext::m_StructuredOp<linalg::GenericOp>(matcherContext) | 
|  | .singleOpWithCanonicaleArgs<arith::SubFOp>() | 
|  | .dim(AllDims(), utils::IteratorType::parallel) | 
|  | .input(NumEqualsTo(2)) | 
|  | .input(0, IsIdentity()) | 
|  | .input(1, IsProjected(-1)) | 
|  | .output(NumEqualsTo(1)) | 
|  | .output(AllOperands(), IsIdentity()); | 
|  | subBroadcast = subBroadcast.input(0, softmaxSourceOperand); | 
|  | subBroadcast = subBroadcast.input(1, maxReduction); | 
|  | auto &subOr = transform_ext::m_StructuredOp_Or(matcherContext, subBroadcast, | 
|  | subParallel); | 
|  | sub = &subOr; | 
|  | } | 
|  |  | 
|  | /// Match sum(%exp, broadcast(%sum)) | 
|  | static void matchdivBroadcast(transform_ext::MatcherContext &matcherContext, | 
|  | transform_ext::StructuredOpMatcher &expOperand, | 
|  | transform_ext::StructuredOpMatcher &sum, | 
|  | transform_ext::StructuredOpMatcher *&div) { | 
|  | using namespace transform_ext; | 
|  |  | 
|  | auto &broadcast = | 
|  | transform_ext::m_StructuredOp<linalg::GenericOp>(matcherContext) | 
|  | .passThroughOp() | 
|  | .dim(AllDims(), utils::IteratorType::parallel) | 
|  | .input(NumEqualsTo(1)) | 
|  | .input(0, IsProjected(-1)) | 
|  | .output(NumEqualsTo(1)) | 
|  | .output(AllOperands(), IsIdentity()); | 
|  | broadcast = broadcast.input(0, sum); | 
|  |  | 
|  | auto &divNoBroadcast = | 
|  | transform_ext::m_StructuredOp<linalg::GenericOp>(matcherContext) | 
|  | .singleOpWithCanonicaleArgs<arith::DivFOp>() | 
|  | .dim(AllDims(), utils::IteratorType::parallel) | 
|  | .input(NumEqualsTo(2)) | 
|  | .input(0, IsIdentity()) | 
|  | .input(1, IsIdentity()) | 
|  | .output(NumEqualsTo(1)) | 
|  | .output(AllOperands(), IsIdentity()); | 
|  |  | 
|  | divNoBroadcast = divNoBroadcast.input(0, expOperand); | 
|  | divNoBroadcast = divNoBroadcast.input(1, broadcast); | 
|  |  | 
|  | auto &divBroadcast = | 
|  | transform_ext::m_StructuredOp<linalg::GenericOp>(matcherContext) | 
|  | .singleOpWithCanonicaleArgs<arith::DivFOp>() | 
|  | .dim(AllDims(), utils::IteratorType::parallel) | 
|  | .input(NumEqualsTo(2)) | 
|  | .input(0, IsIdentity()) | 
|  | .input(1, IsProjected(-1)) | 
|  | .output(NumEqualsTo(1)) | 
|  | .output(AllOperands(), IsIdentity()); | 
|  |  | 
|  | divBroadcast = divBroadcast.input(0, expOperand); | 
|  | divBroadcast = divBroadcast.input(1, sum); | 
|  |  | 
|  | auto &divMerge = transform_ext::m_StructuredOp_Or( | 
|  | matcherContext, divNoBroadcast, divBroadcast); | 
|  | div = &divMerge; | 
|  | } | 
|  |  | 
|  | void transform_ext::makeSoftmaxMatcher( | 
|  | transform_ext::MatcherContext &matcherContext, | 
|  | transform_ext::StructuredOpMatcher *&maxReductionCapture, | 
|  | transform_ext::StructuredOpMatcher *&softmaxRootCapture) { | 
|  | auto &softmaxSourceOperand = m_Value(matcherContext); | 
|  |  | 
|  | auto &fillMinusInf = m_StructuredOp<linalg::FillOp>(matcherContext) | 
|  | .input(0, ConstantFloatMinOrMinusInf()); | 
|  | auto &maxReduction = | 
|  | transform_ext::m_StructuredOp<linalg::GenericOp>(matcherContext) | 
|  | .singleOpWithCanonicaleArgs<arith::MaximumFOp>(/*commutative=*/true) | 
|  | // Only handle most inner reduction for now. | 
|  | .dim(-1, utils::IteratorType::reduction) | 
|  | .dim(AllDimsExcept({-1}), utils::IteratorType::parallel) | 
|  | .input(NumEqualsTo(1)) | 
|  | .input(AllOperands(), IsIdentity()) | 
|  | .output(NumEqualsTo(1)) | 
|  | .output(AllOperands(), IsProjected(-1)); | 
|  | maxReduction = maxReduction.input(0, softmaxSourceOperand); | 
|  | maxReduction = maxReduction.output(0, fillMinusInf); | 
|  | maxReductionCapture = &maxReduction; | 
|  |  | 
|  | transform_ext::StructuredOpMatcher *subOperand; | 
|  | matchSubBroadcast(matcherContext, maxReduction, softmaxSourceOperand, | 
|  | subOperand); | 
|  |  | 
|  | auto &expOperand = m_StructuredOp<linalg::GenericOp>(matcherContext) | 
|  | .singleOpWithCanonicaleArgs<math::ExpOp>() | 
|  | .dim(AllDims(), utils::IteratorType::parallel) | 
|  | .input(NumEqualsTo(1)) | 
|  | .input(AllOperands(), IsIdentity()) | 
|  | .output(AllOperands(), IsIdentity()) | 
|  | .output(NumEqualsTo(1)); | 
|  | expOperand = expOperand.input(0, *subOperand); | 
|  |  | 
|  | auto &fillZero = m_StructuredOp<linalg::FillOp>(matcherContext) | 
|  | .input(0, ConstantFloatZero()); | 
|  | auto &sum = | 
|  | m_StructuredOp<linalg::GenericOp>(matcherContext) | 
|  | .singleOpWithCanonicaleArgs<arith::AddFOp>(/*commutative=*/true) | 
|  | // Only handle most inner reduction for now. | 
|  | .dim(-1, utils::IteratorType::reduction) | 
|  | .dim(AllDimsExcept({-1}), utils::IteratorType::parallel) | 
|  | .input(NumEqualsTo(1)) | 
|  | .input(AllOperands(), IsIdentity()) | 
|  | .output(AllOperands(), IsProjected(-1)) | 
|  | .output(NumEqualsTo(1)); | 
|  | sum = sum.input(0, expOperand); | 
|  | sum = sum.output(0, fillZero); | 
|  |  | 
|  | auto &rcpOperand = m_StructuredOp<linalg::GenericOp>(matcherContext) | 
|  | .isFloatReciprocal() | 
|  | .dim(AllDims(), utils::IteratorType::parallel) | 
|  | .input(NumEqualsTo(1)) | 
|  | .input(AllOperands(), IsIdentity()) | 
|  | .output(AllOperands(), IsIdentity()) | 
|  | .output(NumEqualsTo(1)); | 
|  | rcpOperand = rcpOperand.input(0, sum); | 
|  |  | 
|  | auto &mulOperand = | 
|  | transform_ext::m_StructuredOp<linalg::GenericOp>(matcherContext) | 
|  | .singleOpWithCanonicaleArgs<arith::MulFOp>(/*commutative=*/true) | 
|  | .dim(AllDims(), utils::IteratorType::parallel) | 
|  | .input(NumEqualsTo(2)) | 
|  | .input(0, IsIdentity()) | 
|  | .input(1, IsProjected(-1)) | 
|  | .output(NumEqualsTo(1)) | 
|  | .output(AllOperands(), IsIdentity()); | 
|  |  | 
|  | mulOperand = mulOperand.input(0, expOperand); | 
|  | mulOperand = mulOperand.input(1, rcpOperand); | 
|  |  | 
|  | transform_ext::StructuredOpMatcher *divOperand; | 
|  | matchdivBroadcast(matcherContext, expOperand, sum, divOperand); | 
|  |  | 
|  | auto &softmaxRoot = | 
|  | transform_ext::m_StructuredOp_Or(matcherContext, mulOperand, *divOperand); | 
|  | softmaxRootCapture = &softmaxRoot; | 
|  | } | 
|  |  | 
|  | /// Matcher for convolutions. | 
|  | void transform_ext::makeConvolutionMatcher( | 
|  | transform_ext::MatcherContext &matcherContext, | 
|  | transform_ext::StructuredOpMatcher *&convolutionCapture, | 
|  | transform_ext::StructuredOpMatcher *&fillCapture, | 
|  | transform_ext::StructuredOpMatcher *&trailingCapture, | 
|  | MatchedConvolutionCaptures &captures, bool mustMatchEntireFunc) { | 
|  | // The core part of the matcher is anchored on a particular convolution op. | 
|  | auto &convolution = | 
|  | m_StructuredOp<linalg::Conv2DNchwFchwOp, linalg::Conv2DNhwcHwcfOp>( | 
|  | matcherContext) | 
|  | // Capture convolution dim classifications. | 
|  | .convolutionDims(CaptureConvDims(captures.convolutionDims)) | 
|  | // Capture op sizes. | 
|  | .dim(AllDims(), CaptureDims(captures.convolutionOpSizes)) | 
|  | // Capture convolution element types. | 
|  | .input(0, CaptureElementType(captures.inputElementType)) | 
|  | .input(1, CaptureElementType(captures.filterElementType)) | 
|  | .output(0, CaptureElementType(captures.outputElementType)); | 
|  | convolutionCapture = &convolution; | 
|  |  | 
|  | // Optional FillOp to create the unique output of the convolution. | 
|  | auto &fill = m_StructuredOp<linalg::FillOp>(matcherContext) | 
|  | .output(0, CaptureElementTypeBitWidth( | 
|  | captures.maybeFillElementalTypeBitWidth)); | 
|  | convolution = | 
|  | convolution.output(NumEqualsTo(1)).output(0, fill, OptionalMatch()); | 
|  | fillCapture = &fill; | 
|  |  | 
|  | // Optional trailing op can be any map, transpose, broadcast but | 
|  | // not reduce or windowing operation for now. | 
|  | // It must create the unique input for the reduction. | 
|  | auto &trailing = | 
|  | m_StructuredOp<linalg::GenericOp>(matcherContext) | 
|  | // All parallel dimensions. | 
|  | .dim(AllDims(), utils::IteratorType::parallel) | 
|  | // All inputs are any projected permutation. | 
|  | .input(AllOperands(), IsProjectedPermutation()) | 
|  | .output(AllOperands(), IsPermutation()) | 
|  | .output(NumEqualsTo(1)) | 
|  | .dim(AllDims(), CaptureDims(captures.trailingOpSizes)) | 
|  | // Capture output elemental type. | 
|  | .output(0, CaptureElementTypeBitWidth( | 
|  | captures.maybeTrailingOutputElementalTypeBitWidth)); | 
|  |  | 
|  | // Optional trailing can be any map, transpose, broadcast but not reduce or | 
|  | // windowing operation for now. | 
|  | convolution = convolution.result(0, HasAnyUse(), trailing, OptionalMatch()); | 
|  | if (mustMatchEntireFunc) | 
|  | convolution = convolution.allTilableOpsCaptured<func::FuncOp>(); | 
|  | trailingCapture = &trailing; | 
|  | } | 
|  |  | 
|  | void transform_ext::makeConvolutionMatcher( | 
|  | transform_ext::MatcherContext &context, | 
|  | StructuredOpMatcher *&convolutionCapture, | 
|  | MatchedConvolutionCaptures &captures, bool mustMatchEntireFunc) { | 
|  | StructuredOpMatcher *fill; | 
|  | StructuredOpMatcher *trailing; | 
|  | makeConvolutionMatcher(context, convolutionCapture, fill, trailing, captures, | 
|  | mustMatchEntireFunc); | 
|  | } | 
|  |  | 
|  | void transform_ext::makePadMatcher(MatcherContext &context, | 
|  | CapturingOpMatcher *&padCapture, | 
|  | MatchedPadCaptures &captures, | 
|  | bool mustMatchEntireFunc) { | 
|  | auto &value = transform_ext::m_ShapedValue(context); | 
|  | value.rank(transform_ext::CaptureRank(captures.rank)) | 
|  | .dim(transform_ext::AllDims(), transform_ext::CaptureDims(captures.dims)) | 
|  | .elementType(CaptureElementType(captures.elementType)); | 
|  | auto &opMatcher = transform_ext::m_tensorPad(context) | 
|  | .result(0, value) | 
|  | .low(AllDims(), 0) | 
|  | .yieldsExternalValue(); | 
|  | if (mustMatchEntireFunc) | 
|  | opMatcher = opMatcher.allTilableOpsCaptured<func::FuncOp>(); | 
|  | padCapture = &opMatcher; | 
|  | } |