blob: e8d94774862b7c13ae298db9fd586854f851fbf5 [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "PDL.h"
#include "iree-dialects/Transforms/Functional.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/PDL/IR/PDLOps.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Rewrite/PatternApplicator.h"
#include "llvm/ADT/ScopeExit.h"
namespace mlir {
namespace linalg {
/// Return ops that match any of the patterns.
static SmallVector<Operation *>
getMatchingOps(Operation *parent, const FrozenRewritePatternSet &patterns) {
PatternApplicator applicator(patterns);
applicator.applyDefaultCostModel();
// TODO: The C++ functional API needs better interoperability with PDL.
return functional::applyForEachIn(
parent,
[&](Operation *op, PatternRewriter &rewriter) -> FailureOr<Operation *> {
if (succeeded(applicator.matchAndRewrite(op, rewriter)))
return op;
return failure();
});
}
/// Hook for PDL driver to check if an operation (`value`) is directly nested in
/// a function with the name provided as constant parameter.
/// TODO: PDL needs user-defined "questions".
static LogicalResult nestedInFunc(PDLValue value, ArrayAttr constantParams,
PatternRewriter &rewriter) {
auto *operation = value.cast<Operation *>();
auto func = operation->getParentOfType<FuncOp>();
assert(constantParams.size() == 1 &&
"expected a constant param with function name");
auto functionSymbol = constantParams[0].dyn_cast<SymbolRefAttr>();
assert(functionSymbol && "expected a function name");
if (!func)
return rewriter.notifyMatchFailure(operation, "not nested in a function");
return success(functionSymbol.getLeafReference() == func.getName());
}
/// PDL rewrite hook that does nothing.
static void noOpRewriter(ArrayRef<PDLValue> args, ArrayAttr constantParams,
PatternRewriter &rewriter, PDLResultList &results) {
assert(args.size() == 1 && "expected one argument");
#ifndef NDEBUG
args.front().cast<Operation *>()->setAttr("iree_linalg_transform.matched",
rewriter.getUnitAttr());
#endif
}
/// Construct a BlockAndValueMapping from `linalgOp` to `genericLinalgModelOp`.
/// Walk both ops and check whether all subops are the same.
static LogicalResult haveIdenticalBodiesImpl(LinalgOp linalgOp,
LinalgOp genericLinalgModelOp) {
BlockAndValueMapping bvm;
bvm.map(linalgOp.getBlock()->getArguments(),
genericLinalgModelOp.getBlock()->getArguments());
SmallVector<Operation *> linalgBodyOps;
linalgOp.getBlock()->walk(
[&](Operation *op) { linalgBodyOps.push_back(op); });
unsigned idx = 0;
WalkResult res = genericLinalgModelOp.getBlock()->walk([&](Operation *op) {
Operation *linalgSubOp = linalgBodyOps[idx++];
if (op->getName() != linalgSubOp->getName())
return WalkResult::interrupt();
if (op->getAttrs() != linalgSubOp->getAttrs())
return WalkResult::interrupt();
for (auto it : llvm::zip(op->getOperands(), linalgSubOp->getOperands()))
if (std::get<0>(it) != bvm.lookupOrNull(std::get<1>(it)))
return WalkResult::interrupt();
bvm.map(linalgSubOp->getResults(), op->getResults());
return WalkResult::advance();
});
return success(!res.wasInterrupted());
}
/// Dispatch body equivalence check depending on case.
static LogicalResult haveEquivalentBodies(LinalgOp linalgOp,
LinalgOp genericLinalgModelOp,
PatternRewriter &rewriter) {
if (succeeded(haveIdenticalBodiesImpl(linalgOp, genericLinalgModelOp)))
return success();
// TODO: haveEquivalentBodiesImpl, see e.g.
// https://gist.github.com/nicolasvasilache/39e89e18c46e02335c16db6ec20a07e3
return failure();
}
/// Succeed when `linalgOp` and `linalgModelOp` are deemed equivalent.
static LogicalResult isEquivalentToOpImpl(LinalgOp linalgOp,
LinalgOp linalgModelOp,
PatternRewriter &rewriter) {
// If basic properties do not match, return failure.
if (linalgOp.inputs() != linalgModelOp.inputs() ||
linalgOp.outputs() != linalgModelOp.outputs() ||
linalgOp.indexing_maps() != linalgModelOp.indexing_maps() ||
linalgOp.iterator_types() != linalgModelOp.iterator_types())
return failure();
// Build the block and go perform a body comparison.
{
// createBlock moves the insertion point, scope it in an RAII block.
OpBuilder::InsertionGuard guard(rewriter);
Region &r = linalgModelOp->getRegion(0);
Block *bodyBlock = rewriter.createBlock(
&r, r.end(), linalgOp.getBlock()->getArgumentTypes(),
llvm::to_vector<4>(
llvm::map_range(linalgOp.getBlock()->getArguments(),
[](Value v) { return v.getLoc(); })));
ImplicitLocOpBuilder b(linalgModelOp.getLoc(), rewriter);
auto regionBuilder = linalgModelOp.getRegionBuilder();
llvm::ArrayRef<mlir::NamedAttribute> attrs = {};
regionBuilder(b, *bodyBlock, attrs);
}
return haveEquivalentBodies(linalgOp, linalgModelOp, rewriter);
}
/// Check whether the unique Operation* `op` stored in `value` (assumed) is
/// equivalent to the unique StringRefAttr operation name passed in
/// `constantParams`.
/// Equivalence is achieved when either:
/// 1. `op` has the name stored in `constantParams`.
/// 2. `op` and `constantParams` are both linalg ops and their structured
/// interfaces as well as their bodies are equivalent.
/// Structured interfaces equivalence is a simple attribute level check.
/// Body equivalence is more involved and currently limited:
/// a. the current impl constructs an instance of the op whose name is
/// specified in `constantParams` and checks for exact body equality.
/// b. a more advanced version would "subtract" the bodies and fold, cse
/// and canonicalize to fixed point. If the result is "all zeros",
/// then the bodies would be equivalent (really isomorphic).
/// 3. other cases TBD (e.g. vector.generic when available).
static LogicalResult isEquivalentToOp(PDLValue value, ArrayAttr constantParams,
PatternRewriter &rewriter) {
auto *maybeOp = value.dyn_cast<Operation *>();
if (!maybeOp)
return failure(); // TODO: notifyMatchFailure needs an Operation* handle.
Operation *op = maybeOp;
ArrayRef<Attribute> attrs = constantParams.getValue();
if (attrs.size() != 1)
return failure(); // TODO: notifyMatchFailure needs an Operation* handle.
auto modelOpNameAttr = attrs.front().dyn_cast<StringAttr>();
if (!modelOpNameAttr)
return failure(); // TODO: notifyMatchFailure needs an Operation* handle.
auto modelOpName = modelOpNameAttr.strref();
// 1. If op has name `modelOpName`, the match is trivial.
if (op->getName().getStringRef() == modelOpName)
return success();
// 2. Linalg vs Linalg.
// Create op from `constantParams`.
OperationState modelOpState(op->getLoc(), modelOpName, op->getOperands(),
op->getResultTypes(), op->getAttrs());
modelOpState.addRegion();
Operation *modelOp = rewriter.createOperation(modelOpState);
auto g1 = llvm::make_scope_exit([&]() { rewriter.eraseOp(modelOp); });
LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
LinalgOp linalgModelOp = dyn_cast<LinalgOp>(modelOp);
if (linalgOp && linalgModelOp)
return isEquivalentToOpImpl(linalgOp, linalgModelOp, rewriter);
// 3. TBD
return failure();
}
/// Assume that:
/// 1. `value` is an operands range
/// 2. `constantParams` contains a DictAttr with `operand_number`, `dim` and
/// `divisor` IntegerAttr entries.
/// Succeed if `value`[`operand_number`] is a ranked type whose `dim` is a
/// multiple of `divisor`.
/// Note: 0 is the convention to express "do not tile", it is considered to
/// divide everything.
static LogicalResult isDimMultipleOf(PDLValue value, ArrayAttr constantParams,
PatternRewriter &rewriter) {
auto maybeOperands = value.dyn_cast<ValueRange>();
if (!maybeOperands)
return failure(); // TODO: notifyMatchFailure needs an Operation* handle.
auto operands = *maybeOperands;
auto dict = constantParams.begin()->dyn_cast<DictionaryAttr>();
if (!dict)
return failure(); // TODO: notifyMatchFailure needs an Operation* handle.
int64_t dim;
auto dimAttr = dict.getAs<IntegerAttr>("dim");
if (!dimAttr)
return failure(); // TODO: notifyMatchFailure needs an Operation* handle.
dim = dimAttr.getInt();
int64_t divisor;
auto divisorAttr = dict.getAs<IntegerAttr>("divisor");
if (!divisorAttr)
return failure(); // TODO: notifyMatchFailure needs an Operation* handle.
divisor = divisorAttr.getInt();
int64_t operandNumber;
auto operandNumberAttr = dict.getAs<IntegerAttr>("operand_number");
if (!operandNumberAttr)
return failure(); // TODO: notifyMatchFailure needs an Operation* handle.
operandNumber = operandNumberAttr.getInt();
ShapedType shapedType;
if (static_cast<int64_t>(operands.size()) > operandNumber)
shapedType = operands[operandNumber].getType().dyn_cast<ShapedType>();
if (!shapedType || shapedType.getRank() <= dim)
return failure();
return success(divisor == 0 || (shapedType.getShape()[dim] > 0 &&
shapedType.getShape()[dim] % divisor == 0));
}
/// Assume that:
/// 1. `value` is an operands range
/// 2. `constantParams` contains a DictAttr with `operand_number` and `dim`
/// IntegerAttr entries.
/// Succeed if `value`[`operand_number`] is a ranked type whose `dim` is
/// dynamic.
static LogicalResult isDimStatic(PDLValue value, ArrayAttr constantParams,
PatternRewriter &rewriter) {
auto maybeOperands = value.dyn_cast<ValueRange>();
if (!maybeOperands)
return failure(); // TODO: notifyMatchFailure needs an Operation* handle.
auto operands = *maybeOperands;
auto dict = constantParams.begin()->dyn_cast<DictionaryAttr>();
if (!dict)
return failure(); // TODO: notifyMatchFailure needs an Operation* handle.
int64_t dim;
auto dimAttr = dict.getAs<IntegerAttr>("dim");
if (!dimAttr)
return failure(); // TODO: notifyMatchFailure needs an Operation* handle.
dim = dimAttr.getInt();
int64_t operandNumber;
auto operandNumberAttr = dict.getAs<IntegerAttr>("operand_number");
if (!operandNumberAttr)
return failure(); // TODO: notifyMatchFailure needs an Operation* handle.
operandNumber = operandNumberAttr.getInt();
ShapedType shapedType;
if (static_cast<int64_t>(operands.size()) > operandNumber)
shapedType = operands[operandNumber].getType().dyn_cast<ShapedType>();
return success(shapedType && !shapedType.isDynamicDim(dim));
}
/// Assume that:
/// 1. `value` is an operands range
/// 2. `constantParams` contains a DictAttr with `operand_number` and `dim`
/// IntegerAttr entries.
/// Succeed if `value`[`operand_number`] is a ranked type whose `dim` is
/// dynamic.
static LogicalResult isDimDynamic(PDLValue value, ArrayAttr constantParams,
PatternRewriter &rewriter) {
auto maybeOperands = value.dyn_cast<ValueRange>();
if (!maybeOperands)
return failure(); // TODO: notifyMatchFailure needs an Operation* handle.
auto operands = *maybeOperands;
auto dict = constantParams.begin()->dyn_cast<DictionaryAttr>();
if (!dict)
return failure(); // TODO: notifyMatchFailure needs an Operation* handle.
int64_t dim;
auto dimAttr = dict.getAs<IntegerAttr>("dim");
if (!dimAttr)
return failure(); // TODO: notifyMatchFailure needs an Operation* handle.
dim = dimAttr.getInt();
int64_t operandNumber;
auto operandNumberAttr = dict.getAs<IntegerAttr>("operand_number");
if (!operandNumberAttr)
return failure(); // TODO: notifyMatchFailure needs an Operation* handle.
operandNumber = operandNumberAttr.getInt();
ShapedType shapedType;
if (static_cast<int64_t>(operands.size()) > operandNumber)
shapedType = operands[operandNumber].getType().dyn_cast<ShapedType>();
return success(shapedType && shapedType.isDynamicDim(dim));
}
FailureOr<SmallVector<Operation *>> findMatchingOps(transform::MatchOp matchOp,
SymbolRefAttr pattern,
Operation *containerOp) {
auto symbolTableOp = matchOp->getParentWithTrait<OpTrait::SymbolTable>();
if (!symbolTableOp)
return {symbolTableOp->emitError("no parent op with a SymbolTable")};
auto patternOp = dyn_cast_or_null<pdl::PatternOp>(
SymbolTable::lookupSymbolIn(symbolTableOp, pattern));
if (!patternOp)
return {symbolTableOp->emitError("could not find a pattern named: ")
<< pattern};
// Clone the pattern operation into the temporary module used by the driver
// as it might be referenced multiple times.
OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(patternOp.getLoc());
OpBuilder::atBlockBegin(pdlModuleOp->getBody()).clone(*patternOp);
// Build the PDL module.
PDLPatternModule pdlModule(std::move(pdlModuleOp));
pdlModule.registerConstraintFunction("nestedInFunc", nestedInFunc);
pdlModule.registerConstraintFunction("isDimDynamic", isDimDynamic);
pdlModule.registerConstraintFunction("isDimMultipleOf", isDimMultipleOf);
pdlModule.registerConstraintFunction("isDimStatic", isDimStatic);
pdlModule.registerConstraintFunction("isEquivalentToOp", isEquivalentToOp);
pdlModule.registerRewriteFunction("iree_linalg_transform.apply",
noOpRewriter);
RewritePatternSet patterns(std::move(pdlModule));
return getMatchingOps(containerOp, std::move(patterns));
}
} // namespace linalg
} // namespace mlir