// blob: bbc8ae31307262e077a48642af544d14be1242ab [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/STLExtras.h"
#include <type_traits>
#include <utility>
namespace mlir {
namespace linalg {
// Pure C++ functional patterns require some type plumbing.
namespace detail {
/// Conversion helper between the `LinalgOp` interface and a concrete op type
/// `OpT`. Both directions go through the underlying operation pointer and
/// `cast<>`, so the op handed in must actually be of the requested type.
template <typename OpT>
struct ConvertOrForward {
  /// Reinterprets `op` as an `OpT` by casting its underlying operation.
  static OpT to(LinalgOp op) {
    auto *underlying = op.getOperation();
    return cast<OpT>(underlying);
  }

  /// Reinterprets `op` as a `LinalgOp` by casting its underlying operation.
  static LinalgOp from(OpT op) {
    auto *underlying = op.getOperation();
    return cast<LinalgOp>(underlying);
  }
};
/// Identity specialization: when the requested type already is `LinalgOp`
/// there is nothing to convert, so both directions hand the op back unchanged.
template <>
struct ConvertOrForward<LinalgOp> {
  /// Returns `linalgOp` as is.
  static LinalgOp to(LinalgOp linalgOp) {
    return linalgOp;
  }

  /// Returns `linalgOp` as is.
  static LinalgOp from(LinalgOp linalgOp) {
    return linalgOp;
  }
};
} // namespace detail
/// Wrap a call to a Linalg pattern where the input is a `LinalgOp` and the
/// output is a `LinalgOp`.
template <typename FunctionalLinalgPattern, typename... Args>
auto callLinalgPattern(Args &&... args) {
FunctionalLinalgPattern pattern(std::forward<Args>(args)...);
using Traits = llvm::function_traits<decltype(
&FunctionalLinalgPattern::returningMatchAndRewrite)>;
using OpT = typename Traits::template arg_t<0>;
return
[pattern = std::move(pattern)](
LinalgOp linalgOp, PatternRewriter &rewriter) -> FailureOr<LinalgOp> {
OpT op = detail::ConvertOrForward<OpT>::to(linalgOp);
auto result = pattern.returningMatchAndRewrite(op, rewriter);
if (failed(result))
return failure();
return detail::ConvertOrForward<decltype(*result)>::from(*result);
};
}
} // namespace linalg
} // namespace mlir