| // Copyright 2022 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "iree-dialects/Transforms/TransformMatchers.h" |
| |
| #include "mlir/Analysis/SliceAnalysis.h" |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| #include "mlir/Dialect/Math/IR/Math.h" |
| #include "mlir/Dialect/Utils/StructuredOpsUtils.h" |
| #include "mlir/Interfaces/FunctionInterfaces.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/ScopeExit.h" |
| #include "llvm/Support/Debug.h" |
| |
| using namespace mlir; |
| |
| #define DEBUG_TYPE "transform-matchers" |
| #define DBGS() llvm::dbgs() << "[" DEBUG_TYPE "] " |
| #define DBGSNL() llvm::dbgs() << "\n[" DEBUG_TYPE "] " |
| |
| //===---------------------------------------------------------------------===// |
| // CapturingMatcherBase |
| //===---------------------------------------------------------------------===// |
| |
| void transform_ext::CapturingMatcherBase::getAllNested( |
| SmallVectorImpl<CapturingOpMatcher *> &nested) { |
| |
| SetVector<CapturingOpMatcher *> found; |
| found.insert(nested.begin(), nested.end()); |
| int64_t start = found.size(); |
| |
| auto appendOne = [&found](CapturingMatcherBase &one) { |
| found.insert(one.nestedCapturingMatchers.begin(), |
| one.nestedCapturingMatchers.end()); |
| for (CapturingValueMatcher *valueMatcher : |
| one.nestedCapturingValueMatchers) { |
| found.insert(valueMatcher->nestedCapturingMatchers.begin(), |
| valueMatcher->nestedCapturingMatchers.end()); |
| } |
| }; |
| |
| appendOne(*this); |
| for (int64_t position = start; position < found.size(); ++position) { |
| appendOne(*found[position]); |
| } |
| |
| llvm::append_range(nested, found.getArrayRef()); |
| } |
| |
| void transform_ext::CapturingMatcherBase::getAllNestedValueMatchers( |
| SmallVectorImpl<CapturingValueMatcher *> &nested) { |
| |
| SetVector<CapturingValueMatcher *> found; |
| found.insert(nested.begin(), nested.end()); |
| int64_t start = found.size(); |
| |
| auto appendOne = [&found](CapturingMatcherBase &one) { |
| found.insert(one.nestedCapturingValueMatchers.begin(), |
| one.nestedCapturingValueMatchers.end()); |
| for (CapturingOpMatcher *opMatcher : one.nestedCapturingMatchers) { |
| found.insert(opMatcher->nestedCapturingValueMatchers.begin(), |
| opMatcher->nestedCapturingValueMatchers.end()); |
| } |
| }; |
| |
| appendOne(*this); |
| for (int64_t position = start; position < found.size(); ++position) { |
| appendOne(*found[position]); |
| } |
| |
| llvm::append_range(nested, found.getArrayRef()); |
| } |
| |
| void transform_ext::CapturingMatcherBase::resetCapture() { |
| SmallVector<CapturingOpMatcher *> nested; |
| getAllNested(nested); |
| for (CapturingOpMatcher *matcher : nested) { |
| matcher->captured = nullptr; |
| } |
| SmallVector<CapturingValueMatcher *> nestedValue; |
| getAllNestedValueMatchers(nestedValue); |
| for (CapturingValueMatcher *matcher : nestedValue) { |
| matcher->captured = nullptr; |
| } |
| } |
| |
| //===---------------------------------------------------------------------===// |
| // CapturingOpMatcher |
| //===---------------------------------------------------------------------===// |
| |
| bool transform_ext::CapturingOpMatcher::checkAllTilableMatched( |
| Operation *parent, Operation *op, |
| ArrayRef<transform_ext::CapturingOpMatcher *> matchers) { |
| LLVM_DEBUG(DBGS() << "all tilable ops captured"); |
| int64_t numTilableOps = 0; |
| if (!parent) |
| return false; |
| parent->walk([&](TilingInterface Op) { ++numTilableOps; }); |
| |
| llvm::SmallPtrSet<Operation *, 6> matched; |
| for (CapturingOpMatcher *nested : matchers) { |
| if (Operation *captured = nested->getCaptured()) { |
| matched.insert(captured); |
| } |
| } |
| |
| // Don't forget to include the root matcher. |
| matched.insert(op); |
| return numTilableOps == matched.size(); |
| } |
| |
| bool transform_ext::CapturingOpMatcher::match(Operation *op) { |
| auto debugRAII = |
| llvm::make_scope_exit([] { LLVM_DEBUG(DBGS() << "-------\n"); }); |
| LLVM_DEBUG(DBGS() << "matching: " << *op << "\n"); |
| |
| if (getCaptured()) { |
| LLVM_DEBUG(DBGS() << "found an already captured op: "); |
| if (getCaptured() == op) { |
| LLVM_DEBUG(llvm::dbgs() << "same\n"); |
| return true; |
| } else { |
| LLVM_DEBUG(llvm::dbgs() << "different\n"); |
| return false; |
| } |
| } |
| |
| if (!llvm::all_of(predicates, [op](const PredicateFn &fn) { |
| bool result = fn(op); |
| LLVM_DEBUG(llvm::dbgs() << ": " << result << "\n"); |
| return result; |
| })) { |
| return false; |
| } |
| |
| captured = op; |
| return true; |
| } |
| |
| void transform_ext::CapturingOpMatcher::debugOutputForCreate( |
| ArrayRef<StringRef> opNames) { |
| LLVM_DEBUG(DBGS() << "operation type is one of {"; |
| llvm::interleaveComma(opNames, llvm::dbgs()); llvm::dbgs() << "}"); |
| } |
| |
/// Apply the given matcher to the given object, produce debug messages.
/// `Object` is normally deduced from the second argument; the default
/// expression is only looked up when deduction is impossible.
/// NOTE(review): `llvm::function_traits<typename Matcher::match>` treats
/// `match` as a nested *type*, but `match` is a member function on the
/// matchers in this file -- presumably this default is never instantiated;
/// confirm before relying on it.
template <typename Matcher, typename Object = typename llvm::function_traits<
                                typename Matcher::match>::template args<0>>
static bool recursiveMatch(Matcher &matcher, Object &object,
                           StringRef extraMessage = "") {
  // Leading newline opens a fresh debug line before the nested matcher emits
  // its own output between the braces.
  LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "] " << "start recursive match ("
                          << extraMessage << ") {\n");
  bool result = matcher.match(object);
  LLVM_DEBUG(DBGS() << "} end recursive match");
  return result;
}
| |
| transform_ext::CapturingOpMatcher & |
| transform_ext::CapturingOpMatcher::alternatives( |
| transform_ext::CapturingOpMatcher &first, |
| transform_ext::CapturingOpMatcher &second) { |
| addPredicate([&first, &second](Operation *op) { |
| LLVM_DEBUG(DBGS() << "matching alternatives\n"); |
| return recursiveMatch(first, op, "alternative 1") || |
| recursiveMatch(second, op, "alternative 2"); |
| }); |
| return *this; |
| } |
| |
| //---------------------------------------------------------------------------// |
| // Predicates for operands and results. |
| //---------------------------------------------------------------------------// |
| |
| transform_ext::CapturingOpMatcher & |
| transform_ext::CapturingOpMatcher::operand(transform_ext::NumEqualsTo num) { |
| addPredicate([=](Operation *op) { |
| LLVM_DEBUG(DBGS() << "operation has exactly " << num.value << " operands"); |
| return num.value == op->getNumOperands(); |
| }); |
| return *this; |
| } |
| |
| /// If `pos` is negative, returns the number of the operand in op starting from |
| /// the last. For example, -1 means the last operand, -2 means the |
| /// second-to-last, etc. Returns nullopt if pos is out-of-bounds, both positive |
| /// and negative. |
| static std::optional<int64_t> remapNegativeOperandNumber(int64_t pos, |
| Operation *op) { |
| int64_t updated = pos < 0 ? op->getNumOperands() + pos : pos; |
| if (updated < 0 || updated >= op->getNumOperands()) { |
| LLVM_DEBUG(DBGS() << "match operand #" << pos |
| << "that does not exist in the operation"); |
| return std::nullopt; |
| } |
| return updated; |
| } |
| |
| transform_ext::CapturingOpMatcher & |
| transform_ext::CapturingOpMatcher::operand(int64_t pos, |
| CapturingOpMatcher &nested) { |
| addPredicate([pos, &nested](Operation *op) { |
| std::optional<int64_t> operandNo = remapNegativeOperandNumber(pos, op); |
| if (!operandNo) |
| return false; |
| LLVM_DEBUG(DBGS() << "operand #" << pos << " is defined by an operation"); |
| Operation *definingOp = op->getOperand(*operandNo).getDefiningOp(); |
| if (!definingOp) |
| return false; |
| return recursiveMatch(nested, definingOp); |
| }); |
| recordNestedMatcher(nested); |
| return *this; |
| } |
| |
| transform_ext::CapturingOpMatcher & |
| transform_ext::CapturingOpMatcher::operand(int64_t pos, |
| CapturingValueMatcher &nested) { |
| addPredicate([pos, &nested](Operation *op) { |
| std::optional<int64_t> operandNo = remapNegativeOperandNumber(pos, op); |
| if (!operandNo) |
| return false; |
| LLVM_DEBUG(DBGS() << "operand #" << pos << " is"); |
| Value operand = op->getOperand(*operandNo); |
| return recursiveMatch(nested, operand); |
| }); |
| recordNestedMatcher(nested); |
| return *this; |
| } |
| |
| transform_ext::CapturingOpMatcher &transform_ext::CapturingOpMatcher::operand( |
| int64_t position, std::function<bool(llvm::APFloat)> floatValueFn) { |
| addPredicate([position, |
| floatValueFn = std::move(floatValueFn)](Operation *op) -> bool { |
| std::optional<int64_t> operandNo = remapNegativeOperandNumber(position, op); |
| if (!operandNo) |
| return false; |
| |
| LLVM_DEBUG(DBGS() << "operand #" << *operandNo |
| << " is a special floating point constant"); |
| auto cstOp = |
| op->getOperand(*operandNo).getDefiningOp<arith::ConstantFloatOp>(); |
| if (!cstOp) |
| return false; |
| return floatValueFn(cstOp.value()); |
| }); |
| |
| return *this; |
| } |
| |
| transform_ext::CapturingOpMatcher & |
| transform_ext::CapturingOpMatcher::operand(int64_t position, ConstantFloatOne) { |
| return operand(position, |
| [](llvm::APFloat value) { return value.isExactlyValue(1.0); }); |
| } |
| |
| transform_ext::CapturingOpMatcher & |
| transform_ext::CapturingOpMatcher::result(transform_ext::NumEqualsTo num) { |
| addPredicate([=](Operation *op) { |
| LLVM_DEBUG(DBGS() << "operation has exactly " << num.value << " results"); |
| return num.value == op->getNumResults(); |
| }); |
| return *this; |
| } |
| |
| transform_ext::CapturingOpMatcher & |
| transform_ext::CapturingOpMatcher::result(int64_t pos, |
| CapturingValueMatcher &nested) { |
| addPredicate([pos, &nested](Operation *op) { |
| int64_t updated = pos < 0 ? op->getNumResults() + pos : pos; |
| if (updated < 0 || updated >= op->getNumResults()) { |
| LLVM_DEBUG(DBGS() << "matching result #" << pos |
| << " that does not exist in the operation"); |
| return false; |
| } |
| LLVM_DEBUG(DBGS() << "result #" << pos << " is"); |
| Value result = op->getResult(updated); |
| return recursiveMatch(nested, result); |
| }); |
| recordNestedMatcher(nested); |
| return *this; |
| } |
| |
| //===---------------------------------------------------------------------===// |
| // CapturingValueMatcher |
| //===---------------------------------------------------------------------===// |
| |
namespace {
/// Wrapper around a `Value` whose streaming operator prints the value along
/// with where it lives: result number for op results, argument/block/region
/// position for block arguments. Used only for debug output.
struct DebugPrintValueWrapper {
  Value value;
};

llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
                              const DebugPrintValueWrapper &wrapper) {
  // Op results print as their result number plus the defining op.
  if (auto opResult = dyn_cast<OpResult>(wrapper.value)) {
    return os << "op result #" << opResult.getResultNumber() << " in "
              << wrapper.value;
  }

  // A Value is either an OpResult or a BlockArgument, so the cast is safe.
  auto blockArg = cast<BlockArgument>(wrapper.value);
  os << "block argument #" << blockArg.getArgNumber();
  Block *parentBlock = blockArg.getParentBlock();
  Region *parentRegion = parentBlock->getParent();
  // Blocks/regions may be detached mid-transformation; print whatever
  // context exists instead of dereferencing a null parent.
  if (!parentRegion) {
    os << " of a detached block:\n";
    parentBlock->print(os);
    return os;
  }

  os << " of block #"
     << std::distance(parentRegion->begin(), parentBlock->getIterator());
  Operation *parentOp = parentRegion->getParentOp();
  if (!parentOp) {
    os << " of a detached region:\n";
    for (Block &b : *parentRegion)
      b.print(os);
    return os;
  }

  os << " in region #" << parentRegion->getRegionNumber() << " of "
     << *parentOp;
  return os;
}
} // namespace
| |
| bool transform_ext::CapturingValueMatcher::match(Value value) { |
| auto debugRAII = |
| llvm::make_scope_exit([] { LLVM_DEBUG(DBGS() << "-------\n"); }); |
| LLVM_DEBUG(DBGS() << "matching " << DebugPrintValueWrapper{value} << "\n"); |
| |
| if (getCaptured()) { |
| LLVM_DEBUG(DBGS() << "found an already captured value: "); |
| if (getCaptured() == value) { |
| LLVM_DEBUG(llvm::dbgs() << "same\n"); |
| return true; |
| } else { |
| LLVM_DEBUG(llvm::dbgs() << "different\n"); |
| return false; |
| } |
| } |
| |
| for (const PredicateFn &fn : predicates) { |
| bool result = fn(value); |
| LLVM_DEBUG(llvm::dbgs() << ": " << result << "\n"); |
| if (!result) |
| return false; |
| } |
| |
| captured = value; |
| return true; |
| } |
| |
| transform_ext::ShapedValueMatcher::ShapedValueMatcher() |
| : CapturingValueMatcher() { |
| addPredicate([](Value value) { |
| LLVM_DEBUG(DBGS() << "value is of shaped type"); |
| return value && isa<ShapedType>(value.getType()); |
| }); |
| } |
| |
| transform_ext::ShapedValueMatcher & |
| transform_ext::ShapedValueMatcher::rank(transform_ext::CaptureRank capture) { |
| addPredicate([=](Value value) { |
| LLVM_DEBUG(DBGS() << "capturing shaped value rank"); |
| capture.value = cast<ShapedType>(value.getType()).getRank(); |
| return true; |
| }); |
| return *this; |
| } |
| |
| transform_ext::ShapedValueMatcher & |
| transform_ext::ShapedValueMatcher::dim(int64_t dimension, CaptureDim capture) { |
| addPredicate([=](Value value) { |
| LLVM_DEBUG(DBGS() << "capturing shaped value dimension " << dimension); |
| capture.value = cast<ShapedType>(value.getType()).getDimSize(dimension); |
| return true; |
| }); |
| return *this; |
| } |
| |
| transform_ext::ShapedValueMatcher & |
| transform_ext::ShapedValueMatcher::dim(AllDims tag, CaptureDims captures) { |
| (void)tag; |
| addPredicate([=](Value value) { |
| LLVM_DEBUG(DBGS() << "capturing all shaped value dimensions"); |
| ArrayRef<int64_t> shape = cast<ShapedType>(value.getType()).getShape(); |
| captures.value.assign(shape.begin(), shape.end()); |
| return true; |
| }); |
| return *this; |
| } |
| |
| transform_ext::ShapedValueMatcher & |
| transform_ext::ShapedValueMatcher::elementType(CaptureElementType captures) { |
| addPredicate([=](Value value) { |
| LLVM_DEBUG(DBGS() << "capturing elementType"); |
| captures.value = cast<ShapedType>(value.getType()).getElementType(); |
| return true; |
| }); |
| return *this; |
| } |
| |
| //===---------------------------------------------------------------------===// |
| // Constraints on op rank and dims. |
| //===---------------------------------------------------------------------===// |
| |
| transform_ext::StructuredOpMatcher & |
| transform_ext::StructuredOpMatcher::rank(NumGreaterEqualTo minRank) { |
| return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { |
| LLVM_DEBUG(DBGS() << "rank >= " << minRank.value); |
| return linalgOp.getNumLoops() >= minRank.value; |
| }); |
| } |
| |
| transform_ext::StructuredOpMatcher & |
| transform_ext::StructuredOpMatcher::rank(NumLowerEqualTo maxRank) { |
| return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { |
| LLVM_DEBUG(DBGS() << "rank <= " << maxRank.value); |
| return linalgOp.getNumLoops() <= maxRank.value; |
| }); |
| } |
| |
| transform_ext::StructuredOpMatcher & |
| transform_ext::StructuredOpMatcher::rank(NumEqualsTo exactRank) { |
| return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { |
| LLVM_DEBUG(DBGS() << "rank == " << exactRank.value); |
| return linalgOp.getNumLoops() == exactRank.value; |
| }); |
| } |
| |
| StringRef stringifyShapeKind(transform_ext::ShapeKind kind) { |
| switch (kind) { |
| case transform_ext::ShapeKind::Static: |
| return "static"; |
| case transform_ext::ShapeKind::Dynamic: |
| return "dynamic"; |
| } |
| llvm_unreachable("unhandled shape kind"); |
| } |
| |
| transform_ext::StructuredOpMatcher & |
| transform_ext::StructuredOpMatcher::dim(SmallVector<int64_t> &&dimensions, |
| ShapeKind kind) { |
| return addPredicate([dimensions = std::move(dimensions), |
| kind](linalg::LinalgOp linalgOp) -> bool { |
| LLVM_DEBUG(DBGS() << "dimensions ["; |
| llvm::interleaveComma(dimensions, llvm::dbgs()); |
| llvm::dbgs() << "] are " << stringifyShapeKind(kind)); |
| SmallVector<int64_t> shape = linalgOp.getStaticLoopRanges(); |
| for (auto dimension : dimensions) { |
| int64_t transformedDimension = |
| dimension >= 0 ? dimension : shape.size() + dimension; |
| if (transformedDimension < 0 || transformedDimension >= shape.size()) |
| return false; |
| if (ShapedType::isDynamic(shape[transformedDimension]) ^ |
| (kind == ShapeKind::Static)) |
| continue; |
| return false; |
| } |
| return true; |
| }); |
| } |
| |
| transform_ext::StructuredOpMatcher & |
| transform_ext::StructuredOpMatcher::dim(AllDims tag, ShapeKind kind) { |
| return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { |
| LLVM_DEBUG(DBGS() << "all dimensions are " << stringifyShapeKind(kind)); |
| SmallVector<int64_t> shape = linalgOp.getStaticLoopRanges(); |
| return llvm::all_of(shape, [=](int64_t dimension) { |
| return ShapedType::isDynamic(dimension) ^ (kind == ShapeKind::Static); |
| }); |
| }); |
| } |
| |
| transform_ext::StructuredOpMatcher & |
| transform_ext::StructuredOpMatcher::dim(SmallVector<int64_t> &&dimensions, |
| utils::IteratorType kind) { |
| return addPredicate([dimensions = std::move(dimensions), |
| kind](linalg::LinalgOp linalgOp) -> bool { |
| LLVM_DEBUG(DBGS() << "dimensions ["; |
| llvm::interleaveComma(dimensions, llvm::dbgs()); |
| llvm::dbgs() << "] are " << utils::stringifyIteratorType(kind)); |
| int64_t rank = linalgOp.getNumLoops(); |
| for (auto dimension : dimensions) { |
| int64_t transformedDimension = |
| dimension >= 0 ? dimension : rank + dimension; |
| if (transformedDimension < 0 || transformedDimension >= rank) |
| return false; |
| utils::IteratorType iteratorKind = |
| linalgOp.getIteratorTypesArray()[transformedDimension]; |
| if (iteratorKind == kind) |
| continue; |
| return false; |
| } |
| return true; |
| }); |
| } |
transform_ext::StructuredOpMatcher &
transform_ext::StructuredOpMatcher::dim(AllDims tag, utils::IteratorType kind) {
  // "All dims" is the degenerate case of excluding no dims.
  return dim(AllDimsExcept({}), kind);
}
| |
| transform_ext::StructuredOpMatcher & |
| transform_ext::StructuredOpMatcher::dim(AllDimsExcept &&dims, |
| utils::IteratorType kind) { |
| return addPredicate([dimensions = std::move(dims), |
| kind](linalg::LinalgOp linalgOp) -> bool { |
| LLVM_DEBUG(DBGS() << "all dimensions except ["; |
| llvm::interleaveComma(dimensions.getExcluded(), llvm::dbgs()); |
| llvm::dbgs() << "] are " << utils::stringifyIteratorType(kind)); |
| int64_t rank = linalgOp.getNumLoops(); |
| llvm::SmallDenseSet<int64_t> excludedDims; |
| for (int64_t dim : dimensions.getExcluded()) { |
| excludedDims.insert(dim >= 0 ? dim : rank + dim); |
| } |
| |
| for (auto [index, type] : |
| llvm::enumerate(linalgOp.getIteratorTypesArray())) { |
| if (excludedDims.contains(index)) |
| continue; |
| if (type == kind) |
| continue; |
| return false; |
| } |
| return true; |
| }); |
| } |
| |
| transform_ext::StructuredOpMatcher & |
| transform_ext::StructuredOpMatcher::dim(int64_t dimension, |
| DivisibleBy divisibleBy) { |
| return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { |
| LLVM_DEBUG(DBGS() << "dimension " << dimension << " is divisible by " |
| << divisibleBy.value); |
| int64_t rank = linalgOp.getNumLoops(); |
| int64_t transformedDimension = |
| dimension >= 0 ? dimension : rank + dimension; |
| if (transformedDimension >= rank) |
| return false; |
| |
| int64_t size = linalgOp.getStaticLoopRanges()[transformedDimension]; |
| return !ShapedType::isDynamic(size) && (size % divisibleBy.value == 0); |
| }); |
| } |
| |
| //===---------------------------------------------------------------------===// |
| // Capture directives. |
| //===---------------------------------------------------------------------===// |
| transform_ext::StructuredOpMatcher & |
| transform_ext::StructuredOpMatcher::rank(CaptureRank capture) { |
| return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { |
| LLVM_DEBUG(DBGS() << "capture rank"); |
| capture.value = linalgOp.getNumLoops(); |
| return true; |
| }); |
| } |
| |
| transform_ext::StructuredOpMatcher & |
| transform_ext::StructuredOpMatcher::dim(int64_t dimension, CaptureDim capture) { |
| return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { |
| LLVM_DEBUG(DBGS() << "capture dimension"); |
| int64_t rank = linalgOp.getNumLoops(); |
| int64_t transformedDimension = |
| dimension >= 0 ? dimension : rank + dimension; |
| if (transformedDimension >= rank) |
| return false; |
| |
| capture.value = linalgOp.getStaticLoopRanges()[transformedDimension]; |
| return true; |
| }); |
| } |
| |
| transform_ext::StructuredOpMatcher & |
| transform_ext::StructuredOpMatcher::dim(AllDims tag, CaptureDims captures) { |
| return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { |
| LLVM_DEBUG(DBGS() << "capture all dimensions"); |
| captures.value = linalgOp.getStaticLoopRanges(); |
| return true; |
| }); |
| } |
| |
| transform_ext::StructuredOpMatcher & |
| transform_ext::StructuredOpMatcher::indexingMaps( |
| CaptureIndexingMaps indexingMaps) { |
| return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { |
| LLVM_DEBUG(DBGS() << "capture indexing maps"); |
| indexingMaps.value = linalgOp.getIndexingMapsArray(); |
| return true; |
| }); |
| } |
| |
| transform_ext::StructuredOpMatcher & |
| transform_ext::StructuredOpMatcher::contractionDims( |
| CaptureContractionDims contractionDims) { |
| return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { |
| LLVM_DEBUG(DBGS() << "capture contraction dimensions"); |
| StringRef convMessage = linalg::detail::getMatchContractionMessage( |
| mlir::linalg::detail::isContractionInterfaceImpl( |
| linalgOp, &contractionDims.value)); |
| if (convMessage.empty()) |
| return true; |
| LLVM_DEBUG(llvm::dbgs() << " (" << convMessage << ")"); |
| return false; |
| }); |
| } |
| |
| transform_ext::StructuredOpMatcher & |
| transform_ext::StructuredOpMatcher::convolutionDims(CaptureConvDims convDims) { |
| return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { |
| LLVM_DEBUG(DBGS() << "capture convolution dimensions"); |
| StringRef convMessage = linalg::detail::getMatchConvolutionMessage( |
| mlir::linalg::detail::isConvolutionInterfaceImpl(linalgOp, |
| &convDims.value)); |
| if (convMessage.empty()) |
| return true; |
| LLVM_DEBUG(llvm::dbgs() << " (" << convMessage << ")"); |
| return false; |
| }); |
| } |
| |
| transform_ext::StructuredOpMatcher::StructuredOpMatcher( |
| StructuredOpMatcher &A, StructuredOpMatcher &B) { |
| |
| addPredicate([&A, &B](linalg::LinalgOp linalgOp) -> bool { |
| LLVM_DEBUG(DBGS() << "start recursive lhs OR match {\n"); |
| { |
| auto debugRAII = llvm::make_scope_exit( |
| [] { LLVM_DEBUG(DBGS() << "} end recursive match"); }); |
| if (A.match(linalgOp)) |
| return true; |
| } |
| LLVM_DEBUG(DBGS() << "start recursive rhs OR match {\n"); |
| { |
| auto debugRAII = llvm::make_scope_exit( |
| [] { LLVM_DEBUG(DBGS() << "} end recursive match"); }); |
| if (B.match(linalgOp)) |
| return true; |
| } |
| return false; |
| }); |
| recordNestedMatcher(A); |
| recordNestedMatcher(B); |
| } |
| |
| //===---------------------------------------------------------------------===// |
| // Constraints on input operands. |
| //===---------------------------------------------------------------------===// |
| |
| void transform_ext::StructuredOpMatcher::addInputMatcher( |
| int64_t position, std::function<bool(Operation *)> matcher, |
| OptionalMatch optional) { |
| addInputMatcher( |
| position, |
| // No need to handle optional inside the lambda, the wrapper will do that. |
| [matcher = std::move(matcher)](Value value) { |
| Operation *definingOp = value.getDefiningOp(); |
| return definingOp && matcher(definingOp); |
| }, |
| optional); |
| } |
| |
| void transform_ext::StructuredOpMatcher::addInputMatcher( |
| int64_t position, std::function<bool(Value)> matcher, |
| OptionalMatch optional) { |
| addPredicate([position, optional, matcher = std::move(matcher)]( |
| linalg::LinalgOp linalgOp) -> bool { |
| int64_t transformedPosition = |
| position >= 0 ? position : linalgOp.getNumDpsInputs() + position; |
| if (transformedPosition >= linalgOp.getNumDpsInputs()) { |
| LLVM_DEBUG(DBGS() << "input operand #" << position |
| << " does not exist but match required"); |
| return false; |
| } |
| |
| LLVM_DEBUG(DBGS() << "input operand #" << position |
| << (optional.value ? " (optional match) " : " ") |
| << "is\n"); |
| |
| // We MUST run the matcher at this point, even if the match is optional, |
| // to allow for capture. |
| LLVM_DEBUG(DBGS() << "start recursive match {\n"); |
| auto debugRAII = llvm::make_scope_exit( |
| [] { LLVM_DEBUG(DBGS() << "} end recursive match"); }); |
| if (matcher(linalgOp.getDpsInputOperand(transformedPosition)->get())) |
| return true; |
| return optional.value; |
| }); |
| } |
| |
| transform_ext::StructuredOpMatcher & |
| transform_ext::StructuredOpMatcher::input(AllOperands tag, IsPermutation) { |
| return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { |
| LLVM_DEBUG(DBGS() << "all input operands have permutation maps"); |
| // all_of with a lambda requires const-casting dance, so using a loop. |
| for (OpOperand *operand : linalgOp.getDpsInputOperands()) { |
| if (!linalgOp.getMatchingIndexingMap(operand).isPermutation()) |
| return false; |
| } |
| return true; |
| }); |
| } |
| |
| transform_ext::StructuredOpMatcher & |
| transform_ext::StructuredOpMatcher::input(AllOperands tag, |
| IsProjectedPermutation) { |
| return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { |
| LLVM_DEBUG(DBGS() << "all input operands have projected permutation maps"); |
| // all_of with a lambda requires const-casting dance, so using a loop. |
| for (OpOperand *operand : linalgOp.getDpsInputOperands()) { |
| if (!linalgOp.getMatchingIndexingMap(operand).isProjectedPermutation()) |
| return false; |
| } |
| return true; |
| }); |
| } |
| |
/// Helper to check if the map is an identity map with a projected dim.
/// That is, `map` must be a projected permutation whose results are exactly
/// the dimensions 0, 1, 2, ... in order with `projectedDim` skipped, e.g.
/// (d0, d1, d2) -> (d0, d2) for projectedDim == 1.
static bool isProjectedMap(AffineMap map, int64_t projectedDim) {
  if (!map.isProjectedPermutation())
    return false;
  // Next dimension expected at the current result position.
  int64_t dimCounter = 0;
  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
    // Skip the project dim.
    if (dimCounter == projectedDim)
      dimCounter++;
    // Every result must reference the next expected dimension, in order.
    if (map.getDimPosition(i) != dimCounter++) {
      return false;
    }
  }
  return true;
}
| |
| /// Helper to turn a potentially negative index to positive within the range |
| /// [0, ub) and indicate whether the transformed index is in bounds. |
| static bool makeValidPositiveIndex(int64_t &index, int64_t ub) { |
| int64_t positiveIndex = index >= 0 ? index : ub + index; |
| if (positiveIndex < 0 || ub < positiveIndex) { |
| LLVM_DEBUG(DBGSNL() << " index out of range"); |
| return false; |
| } |
| index = positiveIndex; |
| return true; |
| } |
| |
| transform_ext::StructuredOpMatcher & |
| transform_ext::StructuredOpMatcher::input(SmallVector<int64_t> &&positions, |
| IsProjected dim) { |
| return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { |
| LLVM_DEBUG(DBGS() << "operands "; |
| llvm::interleaveComma(positions, llvm::dbgs()); |
| llvm::dbgs() << " have a permutation maps with " << dim.value |
| << " projected"); |
| int64_t updatedDim = dim.value; |
| if (!makeValidPositiveIndex(updatedDim, linalgOp.getNumLoops())) |
| return false; |
| for (int64_t position : positions) { |
| int64_t updatedPosition = position; |
| if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInputs())) |
| return false; |
| OpOperand *operand = linalgOp.getDpsInputOperand(updatedPosition); |
| if (!isProjectedMap(linalgOp.getMatchingIndexingMap(operand), updatedDim)) |
| return false; |
| } |
| return true; |
| }); |
| } |
| |
| transform_ext::StructuredOpMatcher & |
| transform_ext::StructuredOpMatcher::input(AllOperands tag, IsIdentity) { |
| return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { |
| LLVM_DEBUG(DBGS() << "all input operands have identity maps"); |
| // all_of with a lambda requires const-casting dance, so using a loop. |
| for (OpOperand *operand : linalgOp.getDpsInputOperands()) { |
| if (!linalgOp.getMatchingIndexingMap(operand).isIdentity()) |
| return false; |
| } |
| return true; |
| }); |
| } |
| |
| transform_ext::StructuredOpMatcher & |
| transform_ext::StructuredOpMatcher::input(SmallVector<int64_t> &&positions, |
| IsIdentity) { |
| return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { |
| LLVM_DEBUG(DBGS() << "input operands "; |
| llvm::interleaveComma(positions, llvm::dbgs()); |
| llvm::dbgs() << " have identity maps"); |
| // all_of with a lambda requires const-casting dance, so using a loop. |
| for (int64_t position : positions) { |
| int64_t updatedPosition = position; |
| if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInputs())) |
| return false; |
| OpOperand *operand = linalgOp.getDpsInputOperand(updatedPosition); |
| if (!linalgOp.getMatchingIndexingMap(operand).isIdentity()) |
| return false; |
| } |
| return true; |
| }); |
| } |
| |
| transform_ext::StructuredOpMatcher & |
| transform_ext::StructuredOpMatcher::input(int64_t position, |
| ElementTypeBitWidth width) { |
| return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { |
| LLVM_DEBUG(DBGS() << "input operand #" << position |
| << " has elemental type with bit width " << width.value); |
| int64_t updatedPosition = position; |
| if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInputs())) |
| return false; |
| auto shapedType = dyn_cast<ShapedType>( |
| linalgOp.getDpsInputOperand(updatedPosition)->get().getType()); |
| return shapedType && shapedType.getElementType().isIntOrFloat() && |
| shapedType.getElementType().getIntOrFloatBitWidth() == width.value; |
| }); |
| } |
| |
| transform_ext::StructuredOpMatcher & |
| transform_ext::StructuredOpMatcher::input(int64_t position, |
| CaptureElementTypeBitWidth width) { |
| return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { |
| LLVM_DEBUG(DBGS() << "input operand #" << position << " capture bitwidth"); |
| int64_t updatedPosition = position; |
| if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInputs())) |
| return false; |
| auto shapedType = dyn_cast<ShapedType>( |
| linalgOp.getDpsInputOperand(updatedPosition)->get().getType()); |
| if (!shapedType || !shapedType.getElementType().isIntOrFloat()) |
| return false; |
| width.value = shapedType.getElementType().getIntOrFloatBitWidth(); |
| return true; |
| }); |
| } |
| |
| transform_ext::StructuredOpMatcher & |
| transform_ext::StructuredOpMatcher::input(int64_t position, |
| CaptureElementType elem) { |
| return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { |
| LLVM_DEBUG(DBGS() << "input operand #" << position |
| << " capture element type"); |
| int64_t updatedPosition = position; |
| if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInputs())) |
| return false; |
| auto shapedType = dyn_cast<ShapedType>( |
| linalgOp.getDpsInputOperand(updatedPosition)->get().getType()); |
| if (!shapedType) { |
| LLVM_DEBUG(DBGSNL() << " not a shaped type"); |
| return false; |
| } |
| elem.value = shapedType.getElementType(); |
| return true; |
| }); |
| } |
| |
| transform_ext::StructuredOpMatcher & |
| transform_ext::StructuredOpMatcher::input(NumEqualsTo num) { |
| return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { |
| LLVM_DEBUG(DBGS() << "number of input operands == " << num.value); |
| return linalgOp.getNumDpsInputs() == num.value; |
| }); |
| } |
| |
| transform_ext::StructuredOpMatcher & |
| transform_ext::StructuredOpMatcher::input(int64_t position, |
| ConstantFloatMinOrMinusInf) { |
| return input(position, [](llvm::APFloat f) { |
| return (f.isLargest() || f.isInfinity()) && f.isNegative(); |
| }); |
| } |
| |
| transform_ext::StructuredOpMatcher & |
| transform_ext::StructuredOpMatcher::input(int64_t position, ConstantFloatZero) { |
| return input(position, [](llvm::APFloat f) { return f.isZero(); }); |
| } |
| |
| transform_ext::StructuredOpMatcher &transform_ext::StructuredOpMatcher::input( |
| int64_t position, std::function<bool(llvm::APFloat)> floatValueFn) { |
| return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { |
| LLVM_DEBUG(DBGS() << "input operand #" << position |
| << " is a special floating point constant"); |
| int64_t updatedPosition = position; |
| if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInputs())) |
| return false; |
| auto cstOp = linalgOp.getDpsInputOperand(updatedPosition) |
| ->get() |
| .getDefiningOp<arith::ConstantFloatOp>(); |
| if (!cstOp) |
| return false; |
| return floatValueFn(cstOp.value()); |
| }); |
| } |
| |
| //===---------------------------------------------------------------------===// |
| // Constraints on output operands. |
| //===---------------------------------------------------------------------===// |
| |
/// Registers a predicate matching the producer of the `position`-th init
/// (output) operand against `matcher`. When `optional.value` is set, a
/// missing producer or a failed nested match does not fail the overall match;
/// the nested matcher is still run so its captures get populated on success.
void transform_ext::StructuredOpMatcher::addOutputMatcher(
    int64_t position, std::function<bool(Operation *)> matcher,
    OptionalMatch optional) {
  addPredicate([position, optional, matcher = std::move(matcher)](
                   linalg::LinalgOp linalgOp) -> bool {
    LLVM_DEBUG(DBGS() << "output operand #" << position
                      << (optional.value ? " (optional match) "
                                         : " (mandatory match) ")
                      << "is produced by\n");
    // Normalize negative positions; an out-of-range position fails even for
    // optional matches.
    int64_t updatedPosition = position;
    if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInits()))
      return false;
    Operation *definingOp =
        linalgOp.getDpsInitOperand(updatedPosition)->get().getDefiningOp();
    if (!definingOp)
      return optional.value;
    // We MUST run the matcher at this point, even if the match is optional,
    // to allow for capture.
    LLVM_DEBUG(DBGS() << "start recursive match {\n");
    auto debugRAII = llvm::make_scope_exit(
        [] { LLVM_DEBUG(DBGS() << "} end recursive match"); });
    if (matcher(definingOp))
      return true;
    return optional.value;
  });
}
| |
| transform_ext::StructuredOpMatcher & |
| transform_ext::StructuredOpMatcher::output(AllOperands tag, IsPermutation) { |
| return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { |
| LLVM_DEBUG(DBGS() << "all output operands have permutation maps"); |
| for (OpOperand &operand : linalgOp.getDpsInitsMutable()) { |
| if (!linalgOp.getMatchingIndexingMap(&operand).isPermutation()) |
| return false; |
| } |
| return true; |
| }); |
| } |
| |
| transform_ext::StructuredOpMatcher & |
| transform_ext::StructuredOpMatcher::output(AllOperands tag, |
| IsProjectedPermutation) { |
| return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { |
| LLVM_DEBUG(DBGS() << "all output operands have projected permutation maps"); |
| for (OpOperand &operand : linalgOp.getDpsInitsMutable()) { |
| if (!linalgOp.getMatchingIndexingMap(&operand).isProjectedPermutation()) |
| return false; |
| } |
| return true; |
| }); |
| } |
| |
| transform_ext::StructuredOpMatcher & |
| transform_ext::StructuredOpMatcher::output(AllOperands tag, IsProjected dim) { |
| return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { |
| LLVM_DEBUG(DBGS() << "all output operands have a maps with projected"); |
| int64_t updatedDim = dim.value; |
| if (!makeValidPositiveIndex(updatedDim, linalgOp.getNumLoops())) |
| return false; |
| // all_of with a lambda requires const-casting dance, so using a loop. |
| for (OpOperand &operand : linalgOp.getDpsInitsMutable()) { |
| if (!isProjectedMap(linalgOp.getMatchingIndexingMap(&operand), |
| updatedDim)) |
| return false; |
| } |
| return true; |
| }); |
| } |
| |
| transform_ext::StructuredOpMatcher & |
| transform_ext::StructuredOpMatcher::output(AllOperands tag, IsIdentity) { |
| return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { |
| LLVM_DEBUG(DBGS() << "all output operands have identity permutation maps"); |
| for (OpOperand &operand : linalgOp.getDpsInitsMutable()) { |
| if (!linalgOp.getMatchingIndexingMap(&operand).isIdentity()) |
| return false; |
| } |
| return true; |
| }); |
| } |
| |
| transform_ext::StructuredOpMatcher & |
| transform_ext::StructuredOpMatcher::output(int64_t position, |
| ElementTypeBitWidth width) { |
| return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { |
| LLVM_DEBUG(DBGS() << "output operand #" << position |
| << " has elemental type with bit width " << width.value); |
| int64_t updatedPosition = position; |
| if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInits())) |
| return false; |
| auto shapedType = dyn_cast<ShapedType>( |
| linalgOp.getDpsInitOperand(updatedPosition)->get().getType()); |
| return shapedType && shapedType.getElementType().isIntOrFloat() && |
| shapedType.getElementType().getIntOrFloatBitWidth() == width.value; |
| }); |
| } |
| |
| transform_ext::StructuredOpMatcher & |
| transform_ext::StructuredOpMatcher::output(int64_t position, |
| CaptureElementTypeBitWidth width) { |
| return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { |
| LLVM_DEBUG(DBGS() << "output operand #" << position << " capture bitwidth"); |
| int64_t updatedPosition = position; |
| if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInits())) |
| return false; |
| auto shapedType = dyn_cast<ShapedType>( |
| linalgOp.getDpsInitOperand(updatedPosition)->get().getType()); |
| if (!shapedType || !shapedType.getElementType().isIntOrFloat()) { |
| LLVM_DEBUG(DBGSNL() << " could not infer element type"); |
| return false; |
| } |
| width.value = shapedType.getElementType().getIntOrFloatBitWidth(); |
| return true; |
| }); |
| } |
| |
| transform_ext::StructuredOpMatcher & |
| transform_ext::StructuredOpMatcher::output(int64_t position, |
| CaptureElementType elem) { |
| return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { |
| LLVM_DEBUG(DBGS() << "output operand #" << position |
| << " capture element type"); |
| int64_t updatedPosition = position; |
| if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInits())) |
| return false; |
| auto shapedType = dyn_cast<ShapedType>( |
| linalgOp.getDpsInitOperand(updatedPosition)->get().getType()); |
| if (!shapedType) { |
| LLVM_DEBUG(DBGSNL() << " not a shaped type"); |
| return false; |
| } |
| elem.value = shapedType.getElementType(); |
| return true; |
| }); |
| } |
| |
| transform_ext::StructuredOpMatcher & |
| transform_ext::StructuredOpMatcher::output(int64_t position, |
| SingleCombinerReduction tag) { |
| return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { |
| LLVM_DEBUG(DBGS() << "output operand #" << position |
| << " is populated by a single-combiner reduction"); |
| int64_t updatedPosition = position; |
| if (!makeValidPositiveIndex(updatedPosition, linalgOp.getNumDpsInits())) |
| return false; |
| SmallVector<Operation *> combinerOps; |
| return matchReduction(linalgOp.getRegionOutputArgs(), updatedPosition, |
| combinerOps) && |
| llvm::hasSingleElement(combinerOps); |
| }); |
| } |
| |
| transform_ext::StructuredOpMatcher & |
| transform_ext::StructuredOpMatcher::output(NumEqualsTo num) { |
| return addPredicate([=](linalg::LinalgOp linalgOp) -> bool { |
| LLVM_DEBUG(DBGS() << "number of output operands == " << num.value); |
| return linalgOp.getNumDpsInits() == num.value; |
| }); |
| } |
| |
| //===---------------------------------------------------------------------===// |
| // Constraints on results. |
| //===---------------------------------------------------------------------===// |
| |
/// Registers a predicate matching the users of the `position`-th result of
/// the op against `matcher`; the predicate succeeds if any user matches. When
/// `optional.value` is set, the absence of a matching user does not fail the
/// overall match; the nested matcher is still run on the users so its
/// captures get populated on success.
void transform_ext::StructuredOpMatcher::addResultMatcher(
    int64_t position, HasAnyUse tag, std::function<bool(Operation *)> matcher,
    OptionalMatch optional) {
  addPredicate([matcher = std::move(matcher), optional,
                position](linalg::LinalgOp linalgOp) -> bool {
    LLVM_DEBUG(DBGS() << "result #" << position
                      << (optional.value ? " (optional match) "
                                         : " (mandatory match) ")
                      << "has a use\n");
    // Normalize negative positions; an out-of-range position fails even for
    // optional matches.
    int64_t updatedPosition = position;
    if (!makeValidPositiveIndex(updatedPosition, linalgOp->getNumResults()))
      return false;

    // We MUST run the matcher at this point, even if the match is optional,
    // to allow for capture.
    LLVM_DEBUG(DBGS() << "start recursive match {\n");
    auto debugRAII = llvm::make_scope_exit(
        [] { LLVM_DEBUG(DBGS() << "} end recursive match"); });
    if (llvm::any_of(linalgOp->getResult(updatedPosition).getUsers(),
                     [&matcher](Operation *op) { return matcher(op); })) {
      return true;
    }
    return optional.value;
  });
}
| |
| //===-------------------------------------------------------------------===// |
| // Constraints on op region. |
| //===-------------------------------------------------------------------===// |
| |
/// Matches a structured op whose body consists of exactly one payload op with
/// the given `opcode` feeding the terminator, with the op's block arguments
/// passed to the payload op in canonical (positional) order. When
/// `commutative` is set and the payload op is binary, the two block arguments
/// may appear in either order.
transform_ext::StructuredOpMatcher &
transform_ext::StructuredOpMatcher::singleOpWithCanonicaleArgs(
    StringRef opcode, bool commutative) {
  return addPredicate([=](linalg::LinalgOp linalgOp) {
    // Expect exactly the payload op plus the terminator.
    if (linalgOp.getBlock()->getOperations().size() != 2)
      return false;
    Operation *innerOp = &(*linalgOp.getBlock()->getOperations().begin());
    if (innerOp->getName().getStringRef() != opcode ||
        innerOp->getNumResults() != 1)
      return false;
    // The unique payload result must be the unique yielded value.
    Operation *yieldOp = linalgOp.getBlock()->getTerminator();
    if (yieldOp->getNumOperands() != 1)
      return false;
    if (yieldOp->getOperand(0).getDefiningOp() != innerOp)
      return false;
    if (commutative && innerOp->getNumOperands() == 2) {
      // Commutative binary case: accept block arguments 0 and 1 in either
      // order.
      auto arg0 = dyn_cast<BlockArgument>(innerOp->getOperand(0));
      auto arg1 = dyn_cast<BlockArgument>(innerOp->getOperand(1));
      if (!arg0 || !arg1)
        return false;
      if (arg0.getParentBlock() != linalgOp.getBlock() ||
          arg1.getParentBlock() != linalgOp.getBlock())
        return false;
      if (!((arg0.getArgNumber() == 0 && arg1.getArgNumber() == 1) ||
            (arg1.getArgNumber() == 0 && arg0.getArgNumber() == 1)))
        return false;
    } else {
      // General case: operand i must be block argument i of this op's block.
      for (auto [index, operand] : llvm::enumerate(innerOp->getOperands())) {
        auto arg = dyn_cast<BlockArgument>(operand);
        if (!arg || arg.getParentBlock() != linalgOp.getBlock() ||
            arg.getArgNumber() != index)
          return false;
      }
    }
    return true;
  });
}
| |
/// Matches a structured op whose body computes `yield(1.0 / bbarg0)`, i.e. a
/// floating-point reciprocal of the op's first block argument.
transform_ext::StructuredOpMatcher &
transform_ext::StructuredOpMatcher::isFloatReciprocal() {
  return addPredicate([=](linalg::LinalgOp linalgOp) {
    LLVM_DEBUG(DBGS() << "op region represents a reciprocal operation");
    // Expect exactly the divide op plus the terminator.
    if (linalgOp.getBlock()->getOperations().size() != 2)
      return false;
    Operation *innerOp = &(*linalgOp.getBlock()->getOperations().begin());
    if (!isa<arith::DivFOp>(innerOp) || innerOp->getNumResults() != 1)
      return false;
    // The division result must be the unique yielded value.
    Operation *yieldOp = linalgOp.getBlock()->getTerminator();
    if (yieldOp->getNumOperands() != 1)
      return false;
    if (yieldOp->getOperand(0).getDefiningOp() != innerOp)
      return false;
    // The numerator must be the float constant 1.0.
    auto cst = innerOp->getOperand(0).getDefiningOp<arith::ConstantFloatOp>();
    if (!cst || cst.value().convertToDouble() != 1.0)
      return false;
    // The denominator must be the first block argument of this op's block.
    auto arg = dyn_cast<BlockArgument>(innerOp->getOperand(1));
    if (!arg || arg.getParentBlock() != linalgOp.getBlock() ||
        arg.getArgNumber() != 0)
      return false;
    return true;
  });
}
| |
| transform_ext::StructuredOpMatcher & |
| transform_ext::StructuredOpMatcher::passThroughOp() { |
| return addPredicate([=](linalg::LinalgOp linalgOp) { |
| if (linalgOp.getBlock()->getOperations().size() != 1) |
| return false; |
| Operation *yieldOp = linalgOp.getBlock()->getTerminator(); |
| for (auto [index, operand] : llvm::enumerate(yieldOp->getOperands())) { |
| auto arg = dyn_cast<BlockArgument>(operand); |
| if (!arg || arg.getParentBlock() != linalgOp.getBlock() || |
| arg.getArgNumber() != index) |
| return false; |
| } |
| return true; |
| }); |
| } |
| |
/// Matches a structured op whose three-operation body has the canonical
/// contraction shape `yield(reduce(bbarg2, elem(bbarg0, bbarg1)))`.
/// `isaElemOpTy` / `isaReductionOpTy` classify the two payload operations;
/// `elemOpName` / `reductionOpName` are used only for debug output. Operands
/// of each payload op may appear in either order.
transform_ext::StructuredOpMatcher &
transform_ext::StructuredOpMatcher::hasContractionBody(
    function_ref<bool(Operation *)> isaElemOpTy,
    function_ref<bool(Operation *)> isaReductionOpTy, StringRef elemOpName,
    StringRef reductionOpName) {
  return addPredicate([=](linalg::LinalgOp linalgOp) {
    LLVM_DEBUG(DBGS() << "op region is a " << elemOpName << "/"
                      << reductionOpName << " contraction (");
    // Each early exit prints the property that was expected; the scope-exit
    // printer appends " check failed)" unless released on success.
    auto scopeExitPrinter = llvm::make_scope_exit(
        [] { LLVM_DEBUG(llvm::dbgs() << " check failed)"); });

    // Expect elementwise op + reduction op + terminator, and (lhs, rhs, acc)
    // block arguments.
    Block *body = linalgOp.getBlock();
    if (!llvm::hasNItems(*body, 3)) {
      LLVM_DEBUG(llvm::dbgs() << "three-operation body");
      return false;
    }
    if (body->getNumArguments() != 3) {
      LLVM_DEBUG(llvm::dbgs() << "three-argument block");
      return false;
    }

    Operation *elemOp = &(*linalgOp.getBlock()->getOperations().begin());
    Operation *reductionOp = elemOp->getNextNode();
    Operation *yieldOp = reductionOp->getNextNode();
    if (!isaElemOpTy(elemOp)) {
      LLVM_DEBUG(llvm::dbgs() << "first operation is a " << elemOpName);
      return false;
    }
    if (!isaReductionOpTy(reductionOp)) {
      LLVM_DEBUG(llvm::dbgs() << "second operation is a " << reductionOpName);
      return false;
    }
    if (yieldOp->getNumOperands() != 1) {
      LLVM_DEBUG(llvm::dbgs() << "one value yielded");
      return false;
    }
    if (yieldOp->getOperand(0).getDefiningOp() != reductionOp) {
      LLVM_DEBUG(llvm::dbgs() << "yielded value produced by the second op");
      return false;
    }
    if (elemOp->getNumOperands() != 2 || elemOp->getNumResults() != 1) {
      LLVM_DEBUG(llvm::dbgs() << "first op has two operands and one result");
      return false;
    }
    if (reductionOp->getNumOperands() != 2 ||
        reductionOp->getNumResults() != 1) {
      LLVM_DEBUG(llvm::dbgs() << "second op has two operands and one result");
      return false;
    }

    // The reduction combines the accumulator (bbarg2) with the elementwise
    // result, in either operand order.
    SmallVector<Value, 2> expectedReductionOperands = {body->getArgument(2),
                                                       elemOp->getResult(0)};
    if (!llvm::equal(expectedReductionOperands, reductionOp->getOperands()) &&
        !llvm::equal(llvm::reverse(expectedReductionOperands),
                     reductionOp->getOperands())) {
      LLVM_DEBUG(llvm::dbgs() << "operands of the second op");
      return false;
    }

    // The elementwise op combines bbarg0 and bbarg1, in either operand order.
    ValueRange expectedElemOperands = body->getArguments().take_front(2);
    if (!llvm::equal(expectedElemOperands, elemOp->getOperands()) &&
        !llvm::equal(llvm::reverse(expectedElemOperands),
                     elemOp->getOperands())) {
      LLVM_DEBUG(llvm::dbgs() << "operands of the first op");
      return false;
    }

    scopeExitPrinter.release();
    LLVM_DEBUG(llvm::dbgs() << "success)");
    return true;
  });
}
| |
| void transform_ext::detail::debugOutputForConcreteOpMatcherConstructor( |
| StringRef name) { |
| LLVM_DEBUG(DBGS() << "op is a " << name << "'"); |
| } |
| |
| //===---------------------------------------------------------------------===// |
| // TensorPadOpMatcher |
| //===---------------------------------------------------------------------===// |
| |
| transform_ext::TensorPadOpMatcher & |
| transform_ext::TensorPadOpMatcher::low(ArrayRef<int64_t> sizes) { |
| return addPredicate([=](tensor::PadOp tensorPad) { |
| LLVM_DEBUG({ |
| DBGS() << "low pad sizes are "; |
| llvm::interleaveComma(sizes, llvm::dbgs()); |
| }); |
| for (auto [ofr, sz] : llvm::zip(tensorPad.getMixedLowPad(), sizes)) { |
| if (isConstantIntValue(ofr, sz)) |
| return false; |
| } |
| return true; |
| }); |
| } |
| |
| transform_ext::TensorPadOpMatcher & |
| transform_ext::TensorPadOpMatcher::low(AllDims tag, int64_t size) { |
| return addPredicate([=](tensor::PadOp tensorPad) { |
| LLVM_DEBUG(DBGS() << "all low pad sizes are " << size); |
| return llvm::all_of(tensorPad.getMixedLowPad(), [&](OpFoldResult ofr) { |
| return isConstantIntValue(ofr, size); |
| }); |
| }); |
| } |
| |
| transform_ext::TensorPadOpMatcher & |
| transform_ext::TensorPadOpMatcher::high(ArrayRef<int64_t> sizes) { |
| return addPredicate([=](tensor::PadOp tensorPad) { |
| LLVM_DEBUG({ |
| DBGS() << "high pad sizes are "; |
| llvm::interleaveComma(sizes, llvm::dbgs()); |
| }); |
| for (auto [ofr, sz] : llvm::zip(tensorPad.getMixedHighPad(), sizes)) { |
| if (isConstantIntValue(ofr, sz)) |
| return false; |
| } |
| return true; |
| }); |
| } |
| |
| transform_ext::TensorPadOpMatcher & |
| transform_ext::TensorPadOpMatcher::high(AllDims tag, int64_t size) { |
| return addPredicate([=](tensor::PadOp tensorPad) { |
| LLVM_DEBUG(DBGS() << "all high pad sizes are " << size); |
| return llvm::all_of(tensorPad.getMixedHighPad(), [&](OpFoldResult ofr) { |
| return isConstantIntValue(ofr, size); |
| }); |
| }); |
| } |
| |
| transform_ext::TensorPadOpMatcher & |
| transform_ext::TensorPadOpMatcher::yieldsExternalValue() { |
| return addPredicate([=](tensor::PadOp tensorPad) { |
| LLVM_DEBUG(DBGS() << "pad body yields an externally-defined value"); |
| Block *body = tensorPad.getBody(); |
| if (!llvm::hasSingleElement(*body)) |
| return false; |
| return llvm::all_of(body->getTerminator()->getOperands(), |
| [body](Value operand) { |
| auto arg = dyn_cast<BlockArgument>(operand); |
| return !arg || arg.getOwner() != body; |
| }); |
| }); |
| } |
| |
| //===---------------------------------------------------------------------===// |
| // MatchCallbackResult. |
| //===---------------------------------------------------------------------===// |
| |
| ArrayRef<Operation *> |
| transform_ext::MatchCallbackResult::getPayloadGroup(int64_t position) const { |
| assert(position < payloadGroupLengths.size()); |
| int64_t start = 0; |
| for (int64_t i = 0; i < position; ++i) { |
| start += payloadGroupLengths[i]; |
| } |
| return llvm::ArrayRef(payloadOperations) |
| .slice(start, payloadGroupLengths[position]); |
| } |
| |
| //===---------------------------------------------------------------------===// |
| // Case-specific matcher builders. |
| //===---------------------------------------------------------------------===// |
| |
| static constexpr int64_t kCudaWarpSize = 32; |
| |
/// Builds the reduction matcher structure: a single-combiner, most-minor
/// reduction whose init is produced by a mandatory linalg.fill, with optional
/// elementwise leading and trailing computations. The out-pointers are set to
/// matchers owned by `matcherContext`; `captures` is populated while
/// matching. With `mustMatchEntireFunc`, all tilable ops of the enclosing
/// function must be captured by the matcher structure.
void transform_ext::makeReductionMatcher(
    transform_ext::MatcherContext &matcherContext,
    transform_ext::StructuredOpMatcher *&reductionCapture,
    transform_ext::StructuredOpMatcher *&fillCapture,
    transform_ext::StructuredOpMatcher *&leadingCapture,
    transform_ext::StructuredOpMatcher *&trailingCapture,
    MatchedReductionCaptures &captures, bool mustMatchEntireFunc) {
  // The core part of the matcher is anchored on a particular reduction op.
  auto &reduction =
      m_StructuredOp(matcherContext)
          // Op has at least one parallel and one reduction dimension and at
          // most 3 parallel dimensions.
          // TODO: relax once we have global collapse/expand_shape.
          //
          .rank(NumGreaterEqualTo(2))
          .rank(NumLowerEqualTo(4))
          .rank(CaptureRank(captures.reductionRank))
          // Op has a single most-minor reduction.
          .dim(-1, utils::IteratorType::reduction)
          // Capture op sizes.
          .dim(AllDims(), CaptureDims(captures.reductionOpSizes))
          // All other dimensions are parallel.
          .dim(AllDimsExcept({-1}), utils::IteratorType::parallel)
          // Single input for now, can be arbitrary projected permutations.
          // TODO: Multiple inputs, can be arbitrary projected permutations.
          // TODO: Watch out for multiple inputs though as a reduction turns
          //       into a contraction when mixed with projected
          //       permutations. A reduction is often bandwidth bound but
          //       contraction is a different beast that is compute bound
          //       and has a very different schedule.
          //
          .input(NumEqualsTo(1))
          .input(AllOperands(), IsProjectedPermutation())
          // Single output supported atm.
          // TODO: Multiple outputs.
          //
          .output(NumEqualsTo(1))
          // A reduction output must be a projected permutation, match it but we
          // could also drop this technically.
          .output(AllOperands(), IsProjectedPermutation())
          // Only single combiner for now due to reduction warp
          // distribution.
          // TODO: relax this once reduction distribution is more powerful.
          //
          .output(0, CaptureElementTypeBitWidth(
                         captures.reductionOutputElementalTypeBitWidth))
          .output(0, SingleCombinerReduction());
  reductionCapture = &reduction;

  // Mandatory FillOp must create the unique output of the reduction.
  // TODO: Relax this, as any map, broadcast, transpose should also work.
  //
  auto &fill = m_StructuredOp<linalg::FillOp>(matcherContext);
  reduction = reduction.output(NumEqualsTo(1)).output(0, fill);
  fillCapture = &fill;

  // Optional leading or trailing op can be any map, transpose, broadcast but
  // not reduce or windowing operation for now.
  // It must create the unique input for the reduction.
  // TODO: match more optional leading ops, one per input of the reduction.
  // TODO: careful about multi-output and turning into a contraction.
  //
  transform_ext::StructuredOpMatcher commonLeadingOrTrailing =
      m_StructuredOp<linalg::GenericOp>(matcherContext)
          // All parallel dimensions.
          .dim(AllDims(), utils::IteratorType::parallel)
          // All inputs are any projected permutation.
          .input(AllOperands(), IsProjectedPermutation())
          .output(AllOperands(), IsPermutation())
          // leading and trailing may have 0, 1 or more input as long as they do
          // not come from unmatched ops. This extra constraint is taken care of
          // separately. This is also a noop but we document it.
          // TODO: Base and derived classes, atm this does not compile.
          // .input(NumGreaterEqualTo(0))
          // Single output supported atm.
          // TODO: extend this.
          //
          .output(NumEqualsTo(1));
  // TODO: match more optional leading ops, one per input of the reduction.
  // TODO: careful about multi-output and turning into a contraction.
  //
  auto &leading =
      m_StructuredOp(matcherContext, commonLeadingOrTrailing)
          .rank(CaptureRank(captures.maybeLeadingRank))
          // Capture op sizes.
          .dim(AllDims(), CaptureDims(captures.leadingOpSizes))
          // Capture output elemental type.
          .output(0, CaptureElementTypeBitWidth(
                         captures.maybeLeadingOutputElementalTypeBitWidth));
  reduction = reduction.input(0, leading, OptionalMatch());
  leadingCapture = &leading;

  // Optional trailing can be any map, transpose, broadcast but not reduce or
  // windowing operation for now.
  // It must be fed by the unique input for the reduction.
  // TODO: match more optional leading ops, one per input of the reduction.
  // TODO: careful about multi-output and turning into a contraction.
  //
  auto &trailing =
      m_StructuredOp(matcherContext, commonLeadingOrTrailing)
          .rank(CaptureRank(captures.maybeTrailingRank))
          // Capture op sizes.
          .dim(AllDims(), CaptureDims(captures.trailingOpSizes))
          // Capture output elemental type.
          .output(0, CaptureElementTypeBitWidth(
                         captures.maybeTrailingOutputElementalTypeBitWidth));
  reduction = reduction.result(0, HasAnyUse(), trailing, OptionalMatch());
  if (mustMatchEntireFunc)
    reduction = reduction.allTilableOpsCaptured<mlir::FunctionOpInterface>();
  trailingCapture = &trailing;
}
| |
| void transform_ext::makeReductionMatcher(transform_ext::MatcherContext &context, |
| StructuredOpMatcher *&reductionCapture, |
| MatchedReductionCaptures &captures, |
| bool mustMatchEntireFunc) { |
| StructuredOpMatcher *fill; |
| StructuredOpMatcher *leading; |
| StructuredOpMatcher *trailing; |
| makeReductionMatcher(context, reductionCapture, fill, leading, trailing, |
| captures, mustMatchEntireFunc); |
| } |
| |
/// Builds the matmul matcher structure: a linalg.matmul whose init is
/// produced by a mandatory linalg.fill and whose result optionally feeds a
/// trailing linalg.generic. The out-pointers are set to matchers owned by
/// `matcherContext`; `captures` records sizes and element types during
/// matching. With `mustMatchEntireFunc`, all tilable ops of the enclosing
/// function must be captured by the matcher structure.
void transform_ext::makeMatmulMatcher(
    transform_ext::MatcherContext &matcherContext,
    transform_ext::StructuredOpMatcher *&matmulCapture,
    transform_ext::StructuredOpMatcher *&fillCapture,
    transform_ext::StructuredOpMatcher *&trailingCapture,
    transform_ext::MatchedMatmulCaptures &captures, bool mustMatchEntireFunc) {
  auto &matmul = transform_ext::m_StructuredOp<linalg::MatmulOp>(matcherContext)
                     // Capture op sizes.
                     .dim(AllDims(), CaptureDims(captures.matmulOpSizes))
                     // Capture input/output element types.
                     .input(0, CaptureElementType(captures.lhsElementType))
                     .input(1, CaptureElementType(captures.rhsElementType))
                     .output(0, CaptureElementType(captures.outputElementType));
  matmulCapture = &matmul;
  // Mandatory FillOp must create the unique output of the reduction.
  auto &fill = transform_ext::m_StructuredOp<linalg::FillOp>(matcherContext);
  matmul = matmul.output(transform_ext::NumEqualsTo(1)).output(0, fill);
  fillCapture = &fill;

  // Optional trailing elementwise op fed by the matmul result.
  auto &trailing = m_StructuredOp<linalg::GenericOp>(matcherContext);
  matmul = matmul.result(0, HasAnyUse(), trailing, OptionalMatch());
  if (mustMatchEntireFunc)
    matmul = matmul.allTilableOpsCaptured<mlir::FunctionOpInterface>();
  trailingCapture = &trailing;
}
| |
/// Builds the batch-matmul matcher structure: a rank-4 (batch, m, n, k)
/// linalg.batch_matmul or contraction-shaped linalg.generic with a mulf/addf
/// body, whose init is produced by a linalg.fill. The out-pointers are set to
/// matchers owned by `matcherContext`; `captures` records sizes, contraction
/// dimensions, and element types during matching. With `mustMatchEntireFunc`,
/// all tilable ops of the enclosing function must be captured.
void transform_ext::makeBatchMatmulMatcher(
    transform_ext::MatcherContext &matcherContext,
    transform_ext::StructuredOpMatcher *&bmmCapture,
    transform_ext::StructuredOpMatcher *&fillCapture,
    transform_ext::MatchedMatmulCaptures &captures, bool mustMatchEntireFunc) {
  auto &bmm =
      transform_ext::m_StructuredOp<linalg::BatchMatmulOp, linalg::GenericOp>(
          matcherContext)
          .hasContractionBody<arith::MulFOp, arith::AddFOp>()
          // Batch plus (m, n, k).
          .rank(NumEqualsTo(4))
          .dim(AllDims(), CaptureDims(captures.matmulOpSizes))
          // Only the most-minor (contraction) dimension reduces.
          .dim(AllDimsExcept({-1}), utils::IteratorType::parallel)
          .dim(-1, utils::IteratorType::reduction)
          .contractionDims(CaptureContractionDims(captures.contractionDims))
          .input(NumEqualsTo(2))
          .input(0, CaptureElementType(captures.lhsElementType))
          .input(1, CaptureElementType(captures.rhsElementType))
          .output(0, CaptureElementType(captures.outputElementType));
  bmmCapture = &bmm;

  // The init of the batch matmul must come from a linalg.fill.
  auto &fill = transform_ext::m_StructuredOp<linalg::FillOp>(matcherContext);
  bmm = bmm.output(0, fill);
  fillCapture = &fill;

  if (mustMatchEntireFunc)
    bmm = bmm.allTilableOpsCaptured<mlir::FunctionOpInterface>();
}
| |
/// Match sub(%src, broadcast(%reduction)) — either with an explicit broadcast
/// op feeding the subtraction, or with the broadcast folded into the
/// subtraction's indexing map (projected second input). `sub` is set to an
/// OR matcher over both forms, owned by `matcherContext`.
static void
matchSubBroadcast(transform_ext::MatcherContext &matcherContext,
                  transform_ext::StructuredOpMatcher &maxReduction,
                  transform_ext::CapturingValueMatcher &softmaxSourceOperand,
                  transform_ext::StructuredOpMatcher *&sub) {
  using namespace transform_ext;

  // Explicit broadcast of the reduction result back to the source shape.
  auto &broadcast =
      transform_ext::m_StructuredOp<linalg::GenericOp>(matcherContext)
          .passThroughOp()
          .dim(AllDims(), utils::IteratorType::parallel)
          .input(NumEqualsTo(1))
          .input(0, IsProjected(-1))
          .output(NumEqualsTo(1))
          .output(AllOperands(), IsIdentity());
  broadcast = broadcast.input(0, maxReduction);

  // Form 1: elementwise subtraction consuming the explicit broadcast.
  auto &subParallel =
      transform_ext::m_StructuredOp<linalg::GenericOp>(matcherContext)
          .singleOpWithCanonicaleArgs<arith::SubFOp>()
          .dim(AllDims(), utils::IteratorType::parallel)
          .input(NumEqualsTo(2))
          .input(0, IsIdentity())
          .input(1, IsIdentity())
          .output(NumEqualsTo(1))
          .output(AllOperands(), IsIdentity());
  subParallel = subParallel.input(0, softmaxSourceOperand);
  subParallel = subParallel.input(1, broadcast);

  // Form 2: subtraction with the broadcast folded into its second input map.
  auto &subBroadcast =
      transform_ext::m_StructuredOp<linalg::GenericOp>(matcherContext)
          .singleOpWithCanonicaleArgs<arith::SubFOp>()
          .dim(AllDims(), utils::IteratorType::parallel)
          .input(NumEqualsTo(2))
          .input(0, IsIdentity())
          .input(1, IsProjected(-1))
          .output(NumEqualsTo(1))
          .output(AllOperands(), IsIdentity());
  subBroadcast = subBroadcast.input(0, softmaxSourceOperand)
  subBroadcast = subBroadcast.input(1, maxReduction);
  auto &subOr = transform_ext::m_StructuredOp_Or(matcherContext, subBroadcast,
                                                 subParallel);
  sub = &subOr;
}
| |
/// Match div(%exp, broadcast(%sum)) — either with an explicit broadcast op
/// feeding the division, or with the broadcast folded into the division's
/// indexing map (projected second input). `div` is set to an OR matcher over
/// both forms, owned by `matcherContext`.
static void matchdivBroadcast(transform_ext::MatcherContext &matcherContext,
                              transform_ext::StructuredOpMatcher &expOperand,
                              transform_ext::StructuredOpMatcher &sum,
                              transform_ext::StructuredOpMatcher *&div) {
  using namespace transform_ext;

  // Explicit broadcast of the sum back to the source shape.
  auto &broadcast =
      transform_ext::m_StructuredOp<linalg::GenericOp>(matcherContext)
          .passThroughOp()
          .dim(AllDims(), utils::IteratorType::parallel)
          .input(NumEqualsTo(1))
          .input(0, IsProjected(-1))
          .output(NumEqualsTo(1))
          .output(AllOperands(), IsIdentity());
  broadcast = broadcast.input(0, sum);

  // Form 1: elementwise division consuming the explicit broadcast.
  auto &divNoBroadcast =
      transform_ext::m_StructuredOp<linalg::GenericOp>(matcherContext)
          .singleOpWithCanonicaleArgs<arith::DivFOp>()
          .dim(AllDims(), utils::IteratorType::parallel)
          .input(NumEqualsTo(2))
          .input(0, IsIdentity())
          .input(1, IsIdentity())
          .output(NumEqualsTo(1))
          .output(AllOperands(), IsIdentity());

  divNoBroadcast = divNoBroadcast.input(0, expOperand);
  divNoBroadcast = divNoBroadcast.input(1, broadcast);

  // Form 2: division with the broadcast folded into its second input map.
  auto &divBroadcast =
      transform_ext::m_StructuredOp<linalg::GenericOp>(matcherContext)
          .singleOpWithCanonicaleArgs<arith::DivFOp>()
          .dim(AllDims(), utils::IteratorType::parallel)
          .input(NumEqualsTo(2))
          .input(0, IsIdentity())
          .input(1, IsProjected(-1))
          .output(NumEqualsTo(1))
          .output(AllOperands(), IsIdentity());

  divBroadcast = divBroadcast.input(0, expOperand);
  divBroadcast = divBroadcast.input(1, sum);

  auto &divMerge = transform_ext::m_StructuredOp_Or(
      matcherContext, divNoBroadcast, divBroadcast);
  div = &divMerge;
}
| |
/// Builds a matcher tree recognizing the canonical softmax decomposition:
/// max-reduce of the source (with a -inf/min fill), subtract (possibly via an
/// explicit broadcast), exponential, sum-reduce (with a zero fill), and a
/// final normalization expressed either as multiplication by a reciprocal or
/// as a division. On success, `maxReductionCapture` points at the
/// max-reduction matcher and `softmaxRootCapture` at the root (mul-or-div)
/// matcher; both are owned by `matcherContext`.
void transform_ext::makeSoftmaxMatcher(
    transform_ext::MatcherContext &matcherContext,
    transform_ext::StructuredOpMatcher *&maxReductionCapture,
    transform_ext::StructuredOpMatcher *&softmaxRootCapture) {
  auto &softmaxSourceOperand = m_Value(matcherContext);

  // The max accumulator starts at -inf or the most negative finite value.
  auto &fillMinusInf = m_StructuredOp<linalg::FillOp>(matcherContext)
                           .input(0, ConstantFloatMinOrMinusInf());
  auto &maxReduction =
      transform_ext::m_StructuredOp<linalg::GenericOp>(matcherContext)
          .singleOpWithCanonicaleArgs<arith::MaximumFOp>(/*commutative=*/true)
          // Only handle most inner reduction for now.
          .dim(-1, utils::IteratorType::reduction)
          .dim(AllDimsExcept({-1}), utils::IteratorType::parallel)
          .input(NumEqualsTo(1))
          .input(AllOperands(), IsIdentity())
          .output(NumEqualsTo(1))
          .output(AllOperands(), IsProjected(-1));
  maxReduction = maxReduction.input(0, softmaxSourceOperand);
  maxReduction = maxReduction.output(0, fillMinusInf);
  maxReductionCapture = &maxReduction;

  // x - max(x), with or without an explicit broadcast.
  transform_ext::StructuredOpMatcher *subOperand;
  matchSubBroadcast(matcherContext, maxReduction, softmaxSourceOperand,
                    subOperand);

  // exp(x - max(x)).
  auto &expOperand = m_StructuredOp<linalg::GenericOp>(matcherContext)
                         .singleOpWithCanonicaleArgs<math::ExpOp>()
                         .dim(AllDims(), utils::IteratorType::parallel)
                         .input(NumEqualsTo(1))
                         .input(AllOperands(), IsIdentity())
                         .output(AllOperands(), IsIdentity())
                         .output(NumEqualsTo(1));
  expOperand = expOperand.input(0, *subOperand);

  // Sum of the exponentials, accumulating from zero.
  auto &fillZero = m_StructuredOp<linalg::FillOp>(matcherContext)
                       .input(0, ConstantFloatZero());
  auto &sum =
      m_StructuredOp<linalg::GenericOp>(matcherContext)
          .singleOpWithCanonicaleArgs<arith::AddFOp>(/*commutative=*/true)
          // Only handle most inner reduction for now.
          .dim(-1, utils::IteratorType::reduction)
          .dim(AllDimsExcept({-1}), utils::IteratorType::parallel)
          .input(NumEqualsTo(1))
          .input(AllOperands(), IsIdentity())
          .output(AllOperands(), IsProjected(-1))
          .output(NumEqualsTo(1));
  sum = sum.input(0, expOperand);
  sum = sum.output(0, fillZero);

  // Normalization form A: multiply by 1/sum.
  auto &rcpOperand = m_StructuredOp<linalg::GenericOp>(matcherContext)
                         .isFloatReciprocal()
                         .dim(AllDims(), utils::IteratorType::parallel)
                         .input(NumEqualsTo(1))
                         .input(AllOperands(), IsIdentity())
                         .output(AllOperands(), IsIdentity())
                         .output(NumEqualsTo(1));
  rcpOperand = rcpOperand.input(0, sum);

  auto &mulOperand =
      transform_ext::m_StructuredOp<linalg::GenericOp>(matcherContext)
          .singleOpWithCanonicaleArgs<arith::MulFOp>(/*commutative=*/true)
          .dim(AllDims(), utils::IteratorType::parallel)
          .input(NumEqualsTo(2))
          .input(0, IsIdentity())
          .input(1, IsProjected(-1))
          .output(NumEqualsTo(1))
          .output(AllOperands(), IsIdentity());

  mulOperand = mulOperand.input(0, expOperand);
  mulOperand = mulOperand.input(1, rcpOperand);

  // Normalization form B: divide by the sum.
  transform_ext::StructuredOpMatcher *divOperand;
  matchdivBroadcast(matcherContext, expOperand, sum, divOperand);

  // The softmax root is either the multiplication or the division form.
  auto &softmaxRoot =
      transform_ext::m_StructuredOp_Or(matcherContext, mulOperand, *divOperand);
  softmaxRootCapture = &softmaxRoot;
}
| |
| /// Matcher for convolutions. |
| void transform_ext::makeConvolutionMatcher( |
| transform_ext::MatcherContext &matcherContext, |
| transform_ext::StructuredOpMatcher *&convolutionCapture, |
| transform_ext::StructuredOpMatcher *&fillCapture, |
| transform_ext::StructuredOpMatcher *&trailingCapture, |
| MatchedConvolutionCaptures &captures, bool mustMatchEntireFunc) { |
| // The core part of the matcher is anchored on a particular convolution op. |
| auto &convolution = |
| m_StructuredOp<linalg::Conv2DNchwFchwOp, linalg::Conv2DNhwcHwcfOp>( |
| matcherContext) |
| // Capture convolution dim classifications. |
| .convolutionDims(CaptureConvDims(captures.convolutionDims)) |
| // Capture op sizes. |
| .dim(AllDims(), CaptureDims(captures.convolutionOpSizes)) |
| // Capture convolution element types. |
| .input(0, CaptureElementType(captures.inputElementType)) |
| .input(1, CaptureElementType(captures.filterElementType)) |
| .output(0, CaptureElementType(captures.outputElementType)); |
| convolutionCapture = &convolution; |
| |
| // Optional FillOp to create the unique output of the convolution. |
| auto &fill = m_StructuredOp<linalg::FillOp>(matcherContext) |
| .output(0, CaptureElementTypeBitWidth( |
| captures.maybeFillElementalTypeBitWidth)); |
| convolution = |
| convolution.output(NumEqualsTo(1)).output(0, fill, OptionalMatch()); |
| fillCapture = &fill; |
| |
| // Optional trailing op can be any map, transpose, broadcast but |
| // not reduce or windowing operation for now. |
| // It must create the unique input for the reduction. |
| auto &trailing = |
| m_StructuredOp<linalg::GenericOp>(matcherContext) |
| // All parallel dimensions. |
| .dim(AllDims(), utils::IteratorType::parallel) |
| // All inputs are any projected permutation. |
| .input(AllOperands(), IsProjectedPermutation()) |
| .output(AllOperands(), IsPermutation()) |
| .output(NumEqualsTo(1)) |
| .dim(AllDims(), CaptureDims(captures.trailingOpSizes)) |
| // Capture output elemental type. |
| .output(0, CaptureElementTypeBitWidth( |
| captures.maybeTrailingOutputElementalTypeBitWidth)); |
| |
| // Optional trailing can be any map, transpose, broadcast but not reduce or |
| // windowing operation for now. |
| convolution = convolution.result(0, HasAnyUse(), trailing, OptionalMatch()); |
| if (mustMatchEntireFunc) |
| convolution = |
| convolution.allTilableOpsCaptured<mlir::FunctionOpInterface>(); |
| trailingCapture = &trailing; |
| } |
| |
| void transform_ext::makeConvolutionMatcher( |
| transform_ext::MatcherContext &context, |
| StructuredOpMatcher *&convolutionCapture, |
| MatchedConvolutionCaptures &captures, bool mustMatchEntireFunc) { |
| StructuredOpMatcher *fill; |
| StructuredOpMatcher *trailing; |
| makeConvolutionMatcher(context, convolutionCapture, fill, trailing, captures, |
| mustMatchEntireFunc); |
| } |
| |
| void transform_ext::makePadMatcher(MatcherContext &context, |
| CapturingOpMatcher *&padCapture, |
| MatchedPadCaptures &captures, |
| bool mustMatchEntireFunc) { |
| auto &value = transform_ext::m_ShapedValue(context); |
| value.rank(transform_ext::CaptureRank(captures.rank)) |
| .dim(transform_ext::AllDims(), transform_ext::CaptureDims(captures.dims)) |
| .elementType(CaptureElementType(captures.elementType)); |
| auto &opMatcher = transform_ext::m_tensorPad(context) |
| .result(0, value) |
| .low(AllDims(), 0) |
| .yieldsExternalValue(); |
| if (mustMatchEntireFunc) |
| opMatcher = opMatcher.allTilableOpsCaptured<mlir::FunctionOpInterface>(); |
| padCapture = &opMatcher; |
| } |