blob: 59d3251441c09e2e19da1b82fb26b1ec050bf0cd [file] [log] [blame]
// Copyright 2020 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
#include "iree/compiler/Utils/PatternUtils.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace iree_compiler {
namespace Shape {
//===----------------------------------------------------------------------===//
// Canonicalization
//===----------------------------------------------------------------------===//
// Matches a make_ranked_shape op whose dynamic dimensions are all produced by
// ranked_dim ops that extract, in increasing order, the dynamic dims of one
// identical ranked_shape value; such an op is an identity of that shape.
static LogicalResult identityMakeRankedShapePattern(
    MakeRankedShapeOp op, MakeRankedShapeOp::Adaptor operands,
    PatternRewriter &rewriter) {
  // Fully static shapes never match.
  if (operands.dynamic_dimensions().empty()) return failure();

  auto resultShapeType = op.getRankedShapeType();
  RankedDimOp providerOp;  // Common ranked_dim provider seen so far.
  unsigned lastIndex = 0;  // Index extracted by the previous provider.
  for (auto dimValue : operands.dynamic_dimensions()) {
    auto dimOp = llvm::dyn_cast_or_null<RankedDimOp>(dimValue.getDefiningOp());
    if (!dimOp) return failure();
    // Each provider must extract a dynamic dim of an identically-typed shape.
    unsigned extractedIndex = dimOp.getIndex();
    if (dimOp.getRankedShapeType() != resultShapeType ||
        !resultShapeType.isDimDynamic(extractedIndex)) {
      return failure();
    }
    // After the first dim: all providers must read the same shape value and
    // extracted indices must be strictly increasing.
    if (providerOp && (dimOp.shape() != providerOp.shape() ||
                       extractedIndex <= lastIndex)) {
      return failure();
    }
    providerOp = dimOp;
    lastIndex = extractedIndex;
  }
  // Fall-through: every dynamic dim comes from the same shape, so this op
  // produces a value identical to that shape.
  assert(providerOp && "dynamic ranked_shape did not find a common provider");
  rewriter.replaceOp(op, providerOp.shape());
  return success();
}
// TODO(silvasean): Better handling of "erase unused ops for legality".
// Currently, the way that we legalize !shapex.ranked_shape into individual SSA
// values per dimension is to iteratively reduce other ops to
// shapex.ranked_dim/shapex.ranked_dims and shapex.make_ranked_shape and then
// have patterns that know how to resolve the
// shapex.ranked_dim/shapex.ranked_dims to scalar values by looking through the
// shapex.make_ranked_shape ops, with the eventual goal of not having any uses
// of the shapex.make_ranked_shape op itself, instead the main computation flow
// using the individual SSA values. This naturally produces a lot of unused
// shapex.make_ranked_shape ops which we need to delete for legality reasons.
// This pattern allows conversions to erase those ops.
// Deletes a make_ranked_shape op once nothing consumes its result (see the
// TODO above: conversion naturally strands many dead make_ranked_shape ops,
// and this pattern lets the conversion framework erase them for legality).
static LogicalResult eraseUnusedMakeRankedShapeOp(
    MakeRankedShapeOp op, MakeRankedShapeOp::Adaptor operands,
    PatternRewriter &rewriter) {
  // Only dead ops may be erased.
  if (!op.getResult().use_empty()) {
    return rewriter.notifyMatchFailure(op, "op has uses");
  }
  rewriter.eraseOp(op);
  return success();
}
// Folds ranked_dim(make_ranked_shape(...), i), for a dynamic dim i, to the
// corresponding dynamic-dimension operand of the make_ranked_shape op.
static LogicalResult dynamicMakeRankedShapeDimPattern(
    RankedDimOp op, RankedDimOp::Adaptor operands, PatternRewriter &rewriter) {
  // Only applies when the shape is produced directly by make_ranked_shape.
  auto shapeValue = operands.shape();
  auto makeShapeOp =
      dyn_cast_or_null<MakeRankedShapeOp>(shapeValue.getDefiningOp());
  if (!makeShapeOp) return failure();

  auto shapeType = shapeValue.getType().cast<RankedShapeType>();
  unsigned dimIndex = op.getIndex();
  auto dims = shapeType.getAllDims();
  assert(dimIndex < dims.size());
  // Static dims are handled by folding, not by this pattern.
  if (dims[dimIndex] >= 0) return failure();

  // The k'th dynamic-dimension operand corresponds to the k'th dynamic entry
  // of the shape, so count the dynamic dims that precede dimIndex.
  unsigned dynamicOrdinal = 0;
  for (unsigned i = 0; i < dimIndex; ++i) {
    if (dims[i] < 0) ++dynamicOrdinal;
  }
  assert(dynamicOrdinal < makeShapeOp.dynamic_dimensions().size());
  rewriter.replaceOp(op, makeShapeOp.dynamic_dimensions()[dynamicOrdinal]);
  return success();
}
// Collapses tie_shape(tie_shape(x, s), s) into the inner tie_shape result.
// Chained tie_shape ops commonly appear when function/block tie_shape
// placeholders are inserted before later parts of the computation
// materialize.
static LogicalResult elideDuplicateTieShapePattern(TieShapeOp op,
                                                   TieShapeOp::Adaptor operands,
                                                   PatternRewriter &rewriter) {
  auto parentTieOp =
      dyn_cast_or_null<TieShapeOp>(operands.operand().getDefiningOp());
  if (!parentTieOp) return failure();
  // Distinct shape SSA values may still be equivalent mid-conversion (built
  // through different branches before shape math is collapsed); only exact
  // matches are merged here.
  if (operands.shape() != parentTieOp.shape()) return failure();
  rewriter.replaceOp(op, parentTieOp.result());
  return success();
}
// Removes a tie_shape op whose operand is produced by a shape-aware op: such
// a producer already carries its result shape, making the tie redundant.
static LogicalResult elideShapeCarryingOperandTieShapePattern(
    TieShapeOp op, TieShapeOp::Adaptor operands, PatternRewriter &rewriter) {
  auto *producer = operands.operand().getDefiningOp();
  if (!producer) return failure();
  // tie_shape producers are merged by elideDuplicateTieShapePattern instead.
  if (isa<TieShapeOp>(producer)) return failure();
  if (!isa<ShapeCarryingInterface>(producer)) return failure();
  rewriter.replaceOp(op, operands.operand());
  return success();
}
// Reroutes uses of tie_shape ops by ops that are shape-aware or dim ops:
//   - Shape-carrying consumers can recover the shape themselves, so their use
//     of the tie_shape result is rewired to the underlying operand.
//   - tensor.dim consumers with a constant index are rewritten to ranked_dim
//     ops on the tied shape.
// Succeeds iff at least one use was rerouted.
static LogicalResult elideTieShapeUsagePattern(TieShapeOp op,
                                               TieShapeOp::Adaptor operands,
                                               PatternRewriter &rewriter) {
  bool didAnything = false;
  for (auto &use : llvm::make_early_inc_range(op.result().getUses())) {
    if (auto carryingOp = dyn_cast<ShapeCarryingInterface>(use.getOwner())) {
      // In-place mutation of another op must be routed through the rewriter
      // so the pattern driver is notified (required for conversion rollback
      // and greedy worklist tracking); a bare setOperand bypasses it.
      rewriter.updateRootInPlace(carryingOp, [&]() {
        carryingOp->setOperand(use.getOperandNumber(), operands.operand());
      });
      didAnything = true;
    } else if (auto dimOp = dyn_cast<tensor::DimOp>(use.getOwner())) {
      // Only constant-index dims map onto ranked_dim.
      auto index = dimOp.getConstantIndex();
      if (index.hasValue()) {
        rewriter.replaceOpWithNewOp<RankedDimOp>(dimOp, op.shape(),
                                                 index.getValue());
        didAnything = true;
      }
    }
  }
  return didAnything ? success() : failure();
}
//===----------------------------------------------------------------------===//
// shapex.tie_shape
//===----------------------------------------------------------------------===//
void TieShapeOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
MLIRContext *context) {
insertGreedyPattern(patterns, context, elideDuplicateTieShapePattern);
insertGreedyPattern(patterns, context,
elideShapeCarryingOperandTieShapePattern);
insertGreedyPattern(patterns, context, elideTieShapeUsagePattern);
}
//===----------------------------------------------------------------------===//
// shapex.make_ranked_shape
//===----------------------------------------------------------------------===//
void MakeRankedShapeOp::getCanonicalizationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
insertGreedyPattern(patterns, context, identityMakeRankedShapePattern);
}
//===----------------------------------------------------------------------===//
// shapex.ranked_dim
//===----------------------------------------------------------------------===//
// Folds ranked_dim of a statically-known dimension to a constant integer
// attribute of the op's result type; dynamic dims do not fold.
OpFoldResult RankedDimOp::fold(ArrayRef<Attribute> operands) {
  auto shapeType = shape().getType().cast<RankedShapeType>();
  int dimIndex = getIndex();
  if (shapeType.isDimDynamic(dimIndex)) return {};
  return IntegerAttr::get(getType(), shapeType.getStaticDim(dimIndex));
}
void RankedDimOp::getCanonicalizationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
insertGreedyPattern(patterns, context, dynamicMakeRankedShapeDimPattern);
}
//===----------------------------------------------------------------------===//
// Standard folding and canonicalization conversion patterns.
//===----------------------------------------------------------------------===//
// Since tie_shape ops are an identity, a pattern must exist for type
// conversion to properly propagate across the operand->result edge.
struct TieShapeTypeConversionPattern : public OpConversionPattern<TieShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      TieShapeOp srcOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // If the converted operand type already matches the result type there is
    // nothing to propagate.
    Type convertedType = adaptor.operand().getType();
    if (convertedType == srcOp.getType()) return failure();
    // Rebuild the op so its result type follows the converted operand type.
    rewriter.replaceOpWithNewOp<TieShapeOp>(
        srcOp, convertedType, adaptor.operand(), adaptor.shape());
    return success();
  }
};
void populateFoldConversionPatterns(MLIRContext *context,
OwningRewritePatternList &patterns) {
patterns.insert<TieShapeTypeConversionPattern>(context);
insertConversionPattern(patterns, context, eraseUnusedMakeRankedShapeOp);
insertConversionPattern(patterns, context, dynamicMakeRankedShapeDimPattern);
insertConversionPattern(patterns, context, elideDuplicateTieShapePattern);
insertConversionPattern(patterns, context,
elideShapeCarryingOperandTieShapePattern);
insertConversionPattern(patterns, context, elideTieShapeUsagePattern);
insertConversionPattern(patterns, context, identityMakeRankedShapePattern);
}
} // namespace Shape
} // namespace iree_compiler
} // namespace mlir