// Source blob: c599c40017bd6c5e869d56baf37b21ea40cd0993
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Patterns that have a direct lowering from TF to Linalg and IREE. For these,
// we use a high benefit and lower them directly. Some of these are temporary
// while additional work lands upstream.
#include "iree_tf_compiler/TF/Passes.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/PatternMatch.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
namespace TFOps = mlir::TF;
namespace mlir {
namespace iree_integrations {
namespace TF {
namespace {
// Benefit value handed to every pattern registered by
// populateDirectLoweringPatterns. It is deliberately high so that these direct
// lowerings are preferred. Declared const because it is only ever read.
static const PatternBenefit OVERRIDE_BENEFIT = 1000;
// Converts a tf.squeeze op whose operand and result are both ranked tensors
// directly to a tensor::CollapseShapeOp.
//
// The match fails (and the op is left untouched) when either type is not a
// RankedTensorType or when no reassociation between the two types can be
// computed.
struct ConvertExplicitSqueezePattern
    : public OpRewritePattern<TFOps::SqueezeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TFOps::SqueezeOp op,
                                PatternRewriter &rewriter) const override {
    RankedTensorType inputType =
        op.input().getType().dyn_cast<RankedTensorType>();
    if (!inputType) {
      // dyn_cast yields a null type for anything that is not a ranked tensor;
      // bail out here so a null type is never handed to
      // getReassociationIndicesForReshape below.
      return rewriter.notifyMatchFailure(op, "not ranked input");
    }

    RankedTensorType resultType = op.getType().dyn_cast<RankedTensorType>();
    if (!resultType) {
      // This will happen if shape inference could not determine a rank,
      // which we do not support.
      return rewriter.notifyMatchFailure(op, "not ranked result");
    }

    // Group the input dimensions so that each group maps onto one result
    // dimension; an empty optional means the helper found no such grouping.
    auto reassociationIndices =
        mlir::getReassociationIndicesForReshape(inputType, resultType);
    if (!reassociationIndices) {
      return rewriter.notifyMatchFailure(
          op, "could not compute reassociation indices");
    }

    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
        op, resultType, op.input(), *reassociationIndices);
    return success();
  }
};
// Converts a tf.expand_dims op with a constant dim directly to a
// tensor::ExpandShapeOp.
//
// The match fails (and the op is left untouched) when:
//   * the operand or the result is not a RankedTensorType,
//   * the dim operand is not a constant, or the constant holds no elements,
//   * the (normalized) dim falls outside [0, result rank),
//   * the result type does not already have extent 1 at the insertion dim,
//   * no reassociation between the operand and result types can be computed.
struct ConvertConstExpandDimsPattern
    : public OpRewritePattern<TFOps::ExpandDimsOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TFOps::ExpandDimsOp op,
                                PatternRewriter &rewriter) const override {
    RankedTensorType inputType =
        op.input().getType().dyn_cast<RankedTensorType>();
    RankedTensorType resultType = op.getType().dyn_cast<RankedTensorType>();
    // dyn_cast yields a null type for anything that is not a ranked tensor.
    // Both types are consumed below, so both must be checked here.
    if (!inputType || !resultType) {
      return rewriter.notifyMatchFailure(op, "not ranked");
    }

    DenseIntElementsAttr dimAttr;
    if (!matchPattern(op.dim(), m_Constant(&dimAttr))) {
      return rewriter.notifyMatchFailure(op, "not constant dim");
    }
    if (dimAttr.empty()) {
      // Guard the dereference of the first element below.
      return rewriter.notifyMatchFailure(op, "empty constant dim");
    }

    // Keep the full 64-bit value so an out-of-range constant is rejected by
    // the bounds checks instead of being silently truncated.
    int64_t expandDim = (*dimAttr.value_begin<APInt>()).getSExtValue();
    const int64_t resultRank = resultType.getRank();

    // A negative dim counts from the end of the result shape.
    if (expandDim < 0) {
      expandDim += resultRank;
    }
    if (expandDim < 0 || expandDim >= resultRank) {
      return rewriter.notifyMatchFailure(op, "illegal insertion dim");
    }

    // Rebuild the result type with a unit extent forced at the insertion
    // dim and require it to equal the declared result type.
    auto dims = llvm::to_vector<6>(resultType.getShape());
    dims[expandDim] = 1;
    RankedTensorType expandedType =
        RankedTensorType::get(dims, resultType.getElementType());
    if (expandedType != resultType) {
      return rewriter.notifyMatchFailure(
          op, "inferred expanded type not equal to result type");
    }

    // Group the result dimensions so that each group maps onto one input
    // dimension; an empty optional means the helper found no such grouping.
    auto reassociationIndices =
        mlir::getReassociationIndicesForReshape(inputType, expandedType);
    if (!reassociationIndices) {
      return rewriter.notifyMatchFailure(
          op, "could not compute reassociation indices");
    }

    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        op, expandedType, op.input(), *reassociationIndices);
    return success();
  }
};
} // namespace
// Adds the direct TF lowering patterns defined in this file to `patterns`,
// each constructed with `context` and OVERRIDE_BENEFIT.
void populateDirectLoweringPatterns(MLIRContext *context,
                                    RewritePatternSet &patterns) {
  // The variadic form registers the patterns in the order listed.
  patterns.insert<ConvertConstExpandDimsPattern,
                  ConvertExplicitSqueezePattern>(context, OVERRIDE_BENEFIT);
}
} // namespace TF
} // namespace iree_integrations
} // namespace mlir