// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
#include "iree/compiler/Utils/PatternUtils.h"
#include "llvm/Support/Debug.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace iree_compiler {
namespace Shape {

//===----------------------------------------------------------------------===//
// Canonicalization
//===----------------------------------------------------------------------===//

LogicalResult safeCastCompatibleShapePattern(
CastCompatibleShapeOp op, CastCompatibleShapeOpOperandAdaptor operands,
PatternRewriter &rewriter) {
  // TODO(laurenzo): This only elides the cast when all operand and result
  // shapes are identical. Make it generic.
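  // Illustrative fold (types elided): with %a and %b of the same
  // RankedShapeType as the result,
  //   %c = shape.cast_compatible_shape %a, %b
  // folds to %a.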
auto resultRs = op.result().getType().dyn_cast<RankedShapeType>();
if (resultRs) {
// Casting to a ranked shape.
for (auto operand : operands.operands()) {
auto operandType = operand.getType();
auto operandRs = operandType.dyn_cast<RankedShapeType>();
if (!operandRs || operandRs != resultRs) {
return failure();
}
}
rewriter.replaceOp(op, operands.operands()[0]);
return success();
}
return failure();
}

LogicalResult elideTiedGetRankedShapePattern(
GetRankedShapeOp op, GetRankedShapeOpOperandAdaptor operands,
PatternRewriter &rewriter) {
  // If the immediate predecessor is a TieShapeOp, then this op can be
  // erased in favor of the shape operand of the tie op.
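  // Illustrative rewrite (types elided):
  //   %t = shape.tie_shape %tensor, %rs
  //   %s = shape.get_ranked_shape %t
  // replaces all uses of %s with %rs.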
auto tieOp = dyn_cast_or_null<TieShapeOp>(operands.operand().getDefiningOp());
if (!tieOp) return failure();
rewriter.replaceOp(op, tieOp.shape());
return success();
}

LogicalResult elideDuplicateGetRankedShapePattern(
GetRankedShapeOp op, GetRankedShapeOpOperandAdaptor operands,
PatternRewriter &rewriter) {
  // If the immediate predecessor is a GetRankedShapeOp, then this op can be
  // erased in favor of that preceding op's result.
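  // Illustrative fold: %s1 = shape.get_ranked_shape %s0, where %s0 is itself
  // produced by a get_ranked_shape, folds to %s0 (such chains can arise
  // transiently during conversion).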
auto precedingGetRankedShapeOp =
dyn_cast_or_null<GetRankedShapeOp>(operands.operand().getDefiningOp());
if (!precedingGetRankedShapeOp) return failure();
rewriter.replaceOp(op, precedingGetRankedShapeOp.shape());
return success();
}

LogicalResult elideStaticGetRankedShapePattern(
GetRankedShapeOp op, GetRankedShapeOpOperandAdaptor operands,
PatternRewriter &rewriter) {
auto operandType = operands.operand().getType().dyn_cast<RankedTensorType>();
auto resultShapeType = op.shape().getType().dyn_cast<RankedShapeType>();
if (!operandType || !resultShapeType || !operandType.hasStaticShape()) {
return failure();
}
rewriter.replaceOpWithNewOp<ConstRankedShapeOp>(op, resultShapeType);
return success();
}

LogicalResult identityMakeRankedShapePattern(
MakeRankedShapeOp op, MakeRankedShapeOpOperandAdaptor operands,
PatternRewriter &rewriter) {
if (operands.dynamic_dimensions().empty()) {
// Do not match static shapes.
return failure();
}
// Detects make_ranked_shape ops whose dynamic dimensions are provided by
// ranked_dim ops that extract dimensions from an identical ranked_shape.
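  // Illustrative match (types elided), for %rs with dims [?,2,?]:
  //   %d0 = shape.ranked_dim %rs[0]
  //   %d2 = shape.ranked_dim %rs[2]
  //   %rs2 = shape.make_ranked_shape %d0, %d2
  // folds %rs2 to %rs.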
auto rankedShape = op.getRankedShapeType();
RankedDimOp commonRankedDimOp;
unsigned previousProvidingIndex = 0;
for (auto providingDim : operands.dynamic_dimensions()) {
auto rankedDimOp =
llvm::dyn_cast_or_null<RankedDimOp>(providingDim.getDefiningOp());
if (!rankedDimOp) return failure();
// Shapes must match and refer to a dynamic index.
unsigned providingIndex = rankedDimOp.getIndex();
if (rankedDimOp.getRankedShapeType() != rankedShape ||
!rankedShape.isDimDynamic(providingIndex)) {
return failure();
}
if (commonRankedDimOp) {
      // Not the first dim: verify the same providing shape and that the
      // index advances past the previous dynamic dim.
if (rankedDimOp.shape() != commonRankedDimOp.shape() ||
providingIndex <= previousProvidingIndex) {
return failure();
}
}
commonRankedDimOp = rankedDimOp;
previousProvidingIndex = rankedDimOp.getIndex();
}
  // Fall-through: this op produces a shape identical to the one that
  // commonRankedDimOp extracts from, so reuse that shape directly.
assert(commonRankedDimOp &&
"dynamic ranked_shape did not find a common provider");
rewriter.replaceOp(op, commonRankedDimOp.shape());
return success();
}

LogicalResult dynamicMakeRankedShapeDimPattern(
RankedDimOp op, RankedDimOpOperandAdaptor operands,
PatternRewriter &rewriter) {
// If the immediate predecessor is a MakeRankedShapeOp, then this op can be
// erased in favor of the corresponding input to that op.
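  // Illustrative rewrite (types elided), for a shape type with dims [?,2,?]:
  //   %rs = shape.make_ranked_shape %a, %b
  //   %d = shape.ranked_dim %rs[2]
  // folds %d to %b, the second dynamic dimension operand.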
auto shapeInput = operands.shape();
auto makeRsOp =
dyn_cast_or_null<MakeRankedShapeOp>(shapeInput.getDefiningOp());
if (!makeRsOp) return failure();
RankedShapeType rsType = shapeInput.getType().cast<RankedShapeType>();
unsigned index = op.getIndex();
auto allDims = rsType.getAllDims();
assert(index < allDims.size());
if (allDims[index] >= 0) {
// Not dynamic.
return failure();
}
// Map the overall index to the dynamic dim index.
int dynamicDimIndex = 0;
for (unsigned i = 0; i < index; ++i) {
if (allDims[i] < 0) dynamicDimIndex++;
}
assert(dynamicDimIndex < makeRsOp.dynamic_dimensions().size());
rewriter.replaceOp(op, makeRsOp.dynamic_dimensions()[dynamicDimIndex]);
return success();
}

LogicalResult expandRankedShapeDimsPattern(RankedDimsOp op,
RankedDimsOpOperandAdaptor operands,
PatternRewriter &rewriter) {
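  // Expands a ranked_dims op into one ranked_dim op per dimension, e.g.
  // (types elided) %d:2 = shape.ranked_dims %rs becomes two shape.ranked_dim
  // ops, which may in turn fold via the patterns above.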
auto shapeInput = operands.shape();
auto rsType = shapeInput.getType().cast<RankedShapeType>();
SmallVector<Value, 4> dims(rsType.getRank());
for (int i = 0; i < rsType.getRank(); ++i) {
dims[i] = rewriter.createOrFold<RankedDimOp>(
op.getLoc(), op.getResult(i).getType(), shapeInput, i);
}
rewriter.replaceOp(op, dims);
return success();
}

LogicalResult elideDuplicateTieShapePattern(TieShapeOp op,
TieShapeOpOperandAdaptor operands,
PatternRewriter &rewriter) {
  // If the immediate predecessor is also a TieShapeOp tying the same shape,
  // the two can be merged. This often happens when function/block tie_shape
  // placeholders are inserted prior to materializing later parts of the
  // computation.
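  // Illustrative rewrite (types elided):
  //   %t0 = shape.tie_shape %tensor, %rs
  //   %t1 = shape.tie_shape %t0, %rs
  // replaces all uses of %t1 with %t0.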
auto precedingTieShapeOp =
dyn_cast_or_null<TieShapeOp>(operands.operand().getDefiningOp());
if (!precedingTieShapeOp) return failure();
if (operands.shape() != precedingTieShapeOp.shape()) {
// This can happen in intermediate states before all shape calculations
// are collapsed (i.e. the shapes may actually be equivalent but
// constructed through different branches).
return failure();
}
rewriter.replaceOp(op, precedingTieShapeOp.result());
return success();
}

//===----------------------------------------------------------------------===//
// shape.tie_shape
//===----------------------------------------------------------------------===//

void TieShapeOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
MLIRContext *context) {
insertGreedyPattern(patterns, context, elideDuplicateTieShapePattern);
}

//===----------------------------------------------------------------------===//
// shape.cast_compatible_shape
//===----------------------------------------------------------------------===//

void CastCompatibleShapeOp::getCanonicalizationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
insertGreedyPattern(patterns, context, safeCastCompatibleShapePattern);
}

//===----------------------------------------------------------------------===//
// shape.get_ranked_shape
//===----------------------------------------------------------------------===//

void GetRankedShapeOp::getCanonicalizationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
insertGreedyPattern(patterns, context, elideTiedGetRankedShapePattern);
insertGreedyPattern(patterns, context, elideDuplicateGetRankedShapePattern);
insertGreedyPattern(patterns, context, elideStaticGetRankedShapePattern);
}

//===----------------------------------------------------------------------===//
// shape.make_ranked_shape
//===----------------------------------------------------------------------===//

void MakeRankedShapeOp::getCanonicalizationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
insertGreedyPattern(patterns, context, identityMakeRankedShapePattern);
}

//===----------------------------------------------------------------------===//
// shape.ranked_dim
//===----------------------------------------------------------------------===//

OpFoldResult RankedDimOp::fold(ArrayRef<Attribute> operands) {
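  // Dims that are static in the shape type fold to a constant attribute;
  // dynamic dims are left to the canonicalization patterns below.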
auto rsType = shape().getType().cast<RankedShapeType>();
int index = getIndex();
if (!rsType.isDimDynamic(index)) {
auto dimSize = rsType.getStaticDim(index);
return IntegerAttr::get(getType(), dimSize);
}
return {};
}

void RankedDimOp::getCanonicalizationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
insertGreedyPattern(patterns, context, dynamicMakeRankedShapeDimPattern);
}

//===----------------------------------------------------------------------===//
// shape.ranked_dims
//===----------------------------------------------------------------------===//

void RankedDimsOp::getCanonicalizationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
insertGreedyPattern(patterns, context, expandRankedShapeDimsPattern);
}

//===----------------------------------------------------------------------===//
// Standard folding and canonicalization conversion patterns.
//===----------------------------------------------------------------------===//

// Since tie_shape ops are an identity on their primary operand, a pattern
// must exist for type conversion to propagate properly across the
// operand->result edge.
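// For example (a sketch): if a type converter has rewritten the operand to a
// different tensor type, the tie_shape op is recreated with its result type
// set to the converted operand type.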
struct TieShapeTypeConversionPattern : public OpConversionPattern<TieShapeOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
TieShapeOp srcOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
TieShapeOpOperandAdaptor adaptor(operands);
Type operandType = adaptor.operand().getType();
if (operandType == srcOp.getType()) {
return failure();
}
rewriter.replaceOpWithNewOp<TieShapeOp>(srcOp, operandType,
adaptor.operand(), adaptor.shape());
return success();
}
};
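
// Hypothetical usage from a conversion pass (a sketch only; the surrounding
// pass and ConversionTarget setup are assumptions, not part of this file):
//   OwningRewritePatternList patterns;
//   populateFoldConversionPatterns(&getContext(), patterns);
//   // ... add other conversion patterns and configure a ConversionTarget ...
//   if (failed(applyPartialConversion(getOperation(), target, patterns)))
//     signalPassFailure();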
void populateFoldConversionPatterns(MLIRContext *context,
OwningRewritePatternList &patterns) {
patterns.insert<TieShapeTypeConversionPattern>(context);
insertConversionPattern(patterns, context, dynamicMakeRankedShapeDimPattern);
insertConversionPattern(patterns, context,
elideDuplicateGetRankedShapePattern);
insertConversionPattern(patterns, context, elideDuplicateTieShapePattern);
insertConversionPattern(patterns, context, elideTiedGetRankedShapePattern);
insertConversionPattern(patterns, context, expandRankedShapeDimsPattern);
insertConversionPattern(patterns, context, identityMakeRankedShapePattern);
insertConversionPattern(patterns, context, elideStaticGetRankedShapePattern);
insertConversionPattern(patterns, context, safeCastCompatibleShapePattern);
}

} // namespace Shape
} // namespace iree_compiler
} // namespace mlir