blob: c6b01d9a6996a1dd635ebc0dbb0a33837ee9647f [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/Passes.h"
using namespace mlir;
namespace IREE = mlir::iree_compiler::IREE;
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace LinalgExt {
// Attribute name under which the generated Linalg rewriting transformations
// keep their marker. The filter code in this file reads, sets and removes the
// attribute with exactly this name.
const StringLiteral LinalgTransforms::kLinalgTransformMarker(
    "__internal_linalg_transform__");
/// Builds a filter that carries no filter callbacks.
///
/// The entries of `matchDisjunction` are copied into the filter's own storage,
/// `replacement` is stored as given, and match-by-default starts out disabled.
LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<StringAttr> matchDisjunction, Optional<StringAttr> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement),
      matchByDefault(false) {
}
/// Builds a filter from an optional callback plus the marker configuration.
///
/// The marker-related state (`matchDisjunction`, `replacement`, and the
/// disabled match-by-default flag) is set up by the callback-free constructor.
/// The callback `f` is then recorded in `filters`, but only if it is non-empty.
LinalgTransformationFilter::LinalgTransformationFilter(
    const FilterFunction &f, ArrayRef<StringAttr> matchDisjunction,
    Optional<StringAttr> replacement)
    : LinalgTransformationFilter(matchDisjunction, replacement) {
  if (!f)
    return;
  filters.push_back(f);
}
/// Decides whether `op` passes this filter, reporting a match failure through
/// `rewriter` when the marker attribute is the reason for the rejection.
///
/// Order of checks:
///   1. Any callback in `filters` that fails on `op` rejects it right away
///      (plain `failure()`, no diagnostic).
///   2. If `op` carries no string marker attribute, it is accepted when
///      `matchDisjunction` is empty or `matchByDefault` is set, and rejected
///      with a diagnostic otherwise.
///   3. If `op` carries a marker, it is accepted only when the marker's value
///      equals one of the entries of `matchDisjunction`; otherwise it is
///      rejected with the same diagnostic.
LogicalResult
LinalgTransformationFilter::checkAndNotify(PatternRewriter &rewriter,
                                           Operation *op) const {
  // Step 1: every registered callback has to accept the op.
  for (const FilterFunction &callback : filters) {
    if (failed(callback(op)))
      return failure();
  }

  // Both marker-related rejections emit the same diagnostic.
  auto rejectWithMarkerList = [&]() {
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << " does not have any filter from list: ";
      interleaveComma(matchDisjunction, diag);
    });
  };

  StringAttr marker =
      op->getAttrOfType<StringAttr>(LinalgTransforms::kLinalgTransformMarker);

  // Step 2: the op has no marker at all.
  if (!marker) {
    const bool acceptsUnmarked = matchDisjunction.empty() || matchByDefault;
    if (acceptsUnmarked)
      return success();
    return rejectWithMarkerList();
  }

  // Step 3: the op has a marker; it must equal one of the expected ones.
  const bool markerIsExpected =
      llvm::any_of(matchDisjunction, [&](const auto &expected) {
        return marker.getValue() == expected;
      });
  if (markerIsExpected)
    return success();
  return rejectWithMarkerList();
}
/// Updates the marker attribute on `op` according to `replacement`: when a
/// replacement is configured the marker is set to it, otherwise the marker
/// attribute is removed from `op`.
void LinalgTransformationFilter::replaceLinalgTransformationFilter(
    PatternRewriter &rewriter, Operation *op) const {
  // No replacement configured: drop the marker.
  if (!replacement) {
    op->removeAttr(
        rewriter.getStringAttr(LinalgTransforms::kLinalgTransformMarker));
    return;
  }

  op->setAttr(LinalgTransforms::kLinalgTransformMarker, *replacement);
}
/// Returns true when `op` already carries the marker that this filter would
/// install as its replacement.
///
/// Returns false when no replacement is configured, when `op` has no marker
/// attribute, or when the marker is not a `StringAttr` equal to the
/// replacement.
bool LinalgTransformationFilter::hasReplacementFilter(Operation *op) const {
  if (!replacement)
    return false;
  // Use getAttrOfType rather than getAttr(...).dyn_cast<StringAttr>(): the
  // marker is frequently absent, and casting the resulting null Attribute is
  // not permitted. getAttrOfType yields a null StringAttr in that case, which
  // the check below handles. This is also how checkAndNotify reads the marker.
  auto attr =
      op->getAttrOfType<StringAttr>(LinalgTransforms::kLinalgTransformMarker);
  return attr && attr == *replacement;
}
// Pulls the generated Passes.h.inc into a `detail` namespace.
// NOTE(review): GEN_PASS_REGISTRATION presumably selects the registration
// section of the generated file, and that section is expected to provide the
// detail::registerPasses() that the public registerPasses() wrapper in this
// file calls — confirm against the generated file.
namespace detail {
#define GEN_PASS_REGISTRATION
#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h.inc" // IWYU pragma: export
} // namespace detail
} // namespace LinalgExt
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir
void IREE::LinalgExt::registerPasses() {
IREE::LinalgExt::detail::registerPasses();
}