blob: 2c8f24a1a42d1647c0ede977f698f0c6cdbcbdfa [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/SMLoc.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/STLExtras.h"
namespace mlir {
namespace iree_compiler {
namespace Shape {
//===----------------------------------------------------------------------===//
// Canonicalization
//===----------------------------------------------------------------------===//
class ElideTiedGetRankedShapePattern
: public OpRewritePattern<GetRankedShapeOp> {
using OpRewritePattern::OpRewritePattern;
PatternMatchResult matchAndRewrite(GetRankedShapeOp op,
PatternRewriter &rewriter) const override {
// If the immediate predecessor is a TieShapeOp, then this op can be
// erased in favor of the input to the tie op.
if (!matchPattern(op.operand(), m_Op<TieShapeOp>())) {
return matchFailure();
}
auto tieOp = cast<TieShapeOp>(op.operand().getDefiningOp());
rewriter.replaceOp(op, tieOp.shape(), op.operand());
return matchSuccess();
}
};
//===----------------------------------------------------------------------===//
// iree.tie_shape
//===----------------------------------------------------------------------===//
static ParseResult parseTieShapeOp(OpAsmParser &parser, OperationState &state) {
SmallVector<OpAsmParser::OperandType, 2> operands;
SmallVector<Type, 2> operandTypes;
if (parser.parseOperandList(operands) ||
parser.parseColonTypeList(operandTypes) ||
parser.parseOptionalAttrDict(state.attributes) ||
parser.resolveOperands(operands, operandTypes, parser.getNameLoc(),
state.operands)) {
return failure();
}
// The result type is the same as the first operand.
if (state.operands.empty()) return failure();
state.types.push_back(state.operands.front().getType());
return success();
}
static void printTieShapeOp(OpAsmPrinter &p, TieShapeOp op) {
p << op.getOperationName() << " ";
p.printOperands(op.getOperands());
p << " : ";
interleaveComma(op.getOperandTypes(), p);
p.printOptionalAttrDict(op.getOperation()->getAttrs());
}
static LogicalResult verifyTieShapeOp(TieShapeOp op) {
if (op.operand().getType() != op.result().getType()) {
return op.emitOpError("operand and result must be the same type");
}
// tie_shape currently only supports ranked tensors.
auto rankedTensorType = op.operand().getType().dyn_cast<RankedTensorType>();
auto rsType = op.shape().getType().dyn_cast<RankedShapeType>();
if (!rankedTensorType || !rsType) {
return op.emitOpError("currently only ranked tensors are supported");
}
SmallVector<int64_t, 4> rsDims;
rsType.getAllDims(rsDims);
if (!rankedTensorType.getShape().equals(rsDims)) {
return op.emitOpError("dims must match between tensor and shape");
}
return success();
}
//===----------------------------------------------------------------------===//
// iree.get_ranked_shape
//===----------------------------------------------------------------------===//
static ParseResult parseGetRankedShapeOp(OpAsmParser &parser,
OperationState &state) {
OpAsmParser::OperandType operandType;
Type resultType;
return failure(
parser.parseOperand(operandType) || parser.parseColonType(resultType) ||
parser.parseOptionalArrowTypeList(state.types) ||
parser.resolveOperand(operandType, resultType, state.operands));
}
static void printGetRankedShapeOp(OpAsmPrinter &p, GetRankedShapeOp op) {
p << op.getOperationName() << " ";
p.printOperand(op.operand());
p << " : ";
p.printType(op.operand().getType());
p << " -> ";
p.printType(op.shape().getType());
}
static LogicalResult verifyGetRankedShapeOp(GetRankedShapeOp op) {
auto tensorType = op.operand().getType().cast<TensorType>();
auto rsType = op.shape().getType().cast<RankedShapeType>();
if (tensorType.getRank() != rsType.getRank()) {
return op.emitOpError("operand and result must be of same rank");
}
SmallVector<int64_t, 4> rsDims;
rsType.getAllDims(rsDims);
if (!std::equal(rsDims.begin(), rsDims.end(),
tensorType.getShape().begin())) {
return op.emitOpError("operand tensor and result shape must be equal");
}
return success();
}
void GetRankedShapeOp::getCanonicalizationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
patterns.insert<ElideTiedGetRankedShapePattern>(context);
}
#define GET_OP_CLASSES
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.cpp.inc"
} // namespace Shape
} // namespace iree_compiler
} // namespace mlir