blob: abe9c60c7f7e487e3bf9ad4b8c7280e449ecafc1 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/SMLoc.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
namespace mlir {
namespace iree_compiler {
namespace Shape {
//===----------------------------------------------------------------------===//
// shapex.tie_shape
//===----------------------------------------------------------------------===//
static LogicalResult verifyTieShapeOp(TieShapeOp op) {
// Validate shapedType and ranked_shape_type conservatively in this
// case (tie_shape supports arbitrary operand() but we constrain it if
// it is specific enough.
auto shapedType = op.operand().getType().dyn_cast<ShapedType>();
auto rsType = op.shape().getType().dyn_cast<RankedShapeType>();
if (shapedType && shapedType.hasRank() && rsType) {
for (auto it : llvm::zip(shapedType.getShape(), rsType.getAllDims())) {
if ((std::get<0>(it) != -1 && std::get<1>(it) != -1) &&
std::get<0>(it) != std::get<1>(it)) {
return op.emitOpError("dims must match between tensor and shape");
}
}
}
return success();
}
Value TieShapeOp::getViewSource() { return operand(); }
//===----------------------------------------------------------------------===//
// shapex.cast_compatible_shape
//===----------------------------------------------------------------------===//
static LogicalResult verifyCastCompatibleShapeOp(CastCompatibleShapeOp op) {
if (op.operands().empty()) {
return op.emitOpError() << "Must have at least one operand";
}
auto resultRs = op.result().getType().dyn_cast<RankedShapeType>();
if (resultRs) {
// TODO(laurenzo): Expand this to check true compatibility instead of
// just equality.
// Casting to a ranked shape.
for (auto operandType : op.getOperandTypes()) {
auto operandRs = operandType.dyn_cast<RankedShapeType>();
if (!operandRs || operandRs != resultRs) {
return op.emitOpError()
<< "Incompatible static shape cast: " << operandRs << " -> "
<< resultRs;
}
}
return success();
}
return failure();
}
//===----------------------------------------------------------------------===//
// shapex.get_ranked_shape
//===----------------------------------------------------------------------===//
void GetRankedShapeOp::build(OpBuilder &builder, OperationState &result,
Value operand) {
auto rankedOperandType = operand.getType().dyn_cast<RankedTensorType>();
if (rankedOperandType) {
result.types.push_back(RankedShapeType::get(rankedOperandType.getShape(),
builder.getContext()));
}
result.addOperands(operand);
}
static LogicalResult verifyGetRankedShapeOp(GetRankedShapeOp op) {
auto tensorType = op.operand().getType().cast<TensorType>();
auto rsType = op.shape().getType().cast<RankedShapeType>();
if (tensorType.getRank() != rsType.getRank()) {
return op.emitOpError("operand and result must be of same rank");
}
auto rsDims = rsType.getAllDims();
if (!std::equal(rsDims.begin(), rsDims.end(),
tensorType.getShape().begin())) {
return op.emitOpError("operand tensor and result shape must be equal");
}
return success();
}
//===----------------------------------------------------------------------===//
// shapex.const_ranked_shape
//===----------------------------------------------------------------------===//
void ConstRankedShapeOp::build(OpBuilder &builder, OperationState &result,
Type type) {
assert(type.cast<RankedShapeType>().isFullyStatic());
result.types.push_back(type);
}
static LogicalResult verifyConstRankedShapeOp(ConstRankedShapeOp op) {
auto rsType = op.result().getType().dyn_cast<RankedShapeType>();
if (!rsType || !rsType.isFullyStatic()) {
return op.emitOpError("must be a fully static ranked_shape");
}
return success();
}
void ConstRankedShapeOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
auto rankedShape = result().getType().cast<RankedShapeType>();
SmallString<32> buffer;
llvm::raw_svector_ostream os(buffer);
os << "rs";
interleave(
rankedShape.getAllDims(), os, [&](int64_t dim) { os << dim; }, "_");
setNameFn(getResult(), os.str());
}
//===----------------------------------------------------------------------===//
// shapex.make_ranked_shape
//===----------------------------------------------------------------------===//
static LogicalResult verifyMakeRankedShapeOp(MakeRankedShapeOp op) {
if (op.getRankedShapeType().getNumDynamicDims() != op.getNumOperands()) {
return op.emitError()
<< "number of dynamic dims doesn't match number of operands";
}
return success();
}
//===----------------------------------------------------------------------===//
// shapex.ranked_dim
//===----------------------------------------------------------------------===//
void RankedDimOp::build(OpBuilder &builder, OperationState &result,
Type dimType, Value shape, int index) {
result.addOperands(shape);
result.addAttribute("index",
builder.getIntegerAttr(builder.getIndexType(), index));
result.addTypes(dimType);
}
void RankedDimOp::build(OpBuilder &builder, OperationState &result, Value shape,
int index) {
RankedDimOp::build(builder, result, builder.getIndexType(), shape, index);
}
ParseResult parseRankedDimOp(OpAsmParser &parser, OperationState &state) {
OpAsmParser::OperandType operand;
Type operandType;
IntegerAttr indexAttr;
Type indexType = parser.getBuilder().getIndexType();
SmallVector<Type, 1> resultTypes;
if (parser.parseOperand(operand) || parser.parseLSquare() ||
parser.parseAttribute(indexAttr, indexType, "index", state.attributes) ||
parser.parseRSquare() || parser.parseColonType(operandType) ||
parser.parseArrowTypeList(resultTypes) || resultTypes.empty() ||
parser.resolveOperand(operand, operandType, state.operands)) {
return failure();
}
auto rsType = operandType.dyn_cast<RankedShapeType>();
if (!rsType) {
return parser.emitError(parser.getNameLoc());
}
state.types.push_back(resultTypes[0]);
return success();
}
static void printRankedDimOp(OpAsmPrinter &p, RankedDimOp op) {
p << op.getOperationName() << " ";
p.printOperand(op.shape());
p << "[" << op.getIndex() << "]";
p << " : ";
p.printType(op.shape().getType());
p << " -> ";
p.printType(op.getType());
}
static LogicalResult verifyRankedDimOp(RankedDimOp op) {
auto rsType = op.shape().getType().dyn_cast<RankedShapeType>();
auto index = static_cast<int64_t>(op.getIndex());
if (index < 0 || index >= rsType.getRank()) {
return op.emitOpError() << "index out of bounds of shape";
}
return success();
}
//===----------------------------------------------------------------------===//
// shapex.ranked_dims
//===----------------------------------------------------------------------===//
void RankedDimsOp::build(OpBuilder &builder, OperationState &result,
Type dimType, Value shape) {
result.addOperands(shape);
auto rankedShapeType = shape.getType().cast<RankedShapeType>();
for (int i = 0; i < rankedShapeType.getRank(); ++i) {
result.types.push_back(dimType);
}
}
void RankedDimsOp::build(OpBuilder &builder, OperationState &result,
Value shape) {
RankedDimsOp::build(builder, result, builder.getIndexType(), shape);
}
//===----------------------------------------------------------------------===//
// shapex.gather_extents
//===----------------------------------------------------------------------===//
// Helper for accessing attributes for inferReturnTypes callback.
// That helper gives the attributes as an `ArrayRef<NamedAttribute>` which isn't
// the nicest form.
template <typename Attr>
static Attr getRequiredAttr(DictionaryAttr attributes, StringRef name) {
return attributes.get(name).cast<Attr>();
}
/*static*/ SmallVector<int64_t, 6> GatherExtentsOp::getConcatenatedExtents(
ValueRange values) {
SmallVector<int64_t, 6> ret;
for (auto type : values.getTypes()) {
auto rankedShape = type.cast<RankedShapeType>();
ret.append(rankedShape.getAllDims().begin(),
rankedShape.getAllDims().end());
}
return ret;
}
static LogicalResult verifyGatherExtentsOp(GatherExtentsOp op) {
int64_t totalExtents = 0;
for (Type type : op.shapes().getTypes()) {
totalExtents += type.cast<RankedShapeType>().getRank();
}
for (int64_t index : op.indices().getValues<int64_t>()) {
if (index >= totalExtents) {
return op.emitError() << "index " << index
<< " exceeds total number of extents of operands ("
<< totalExtents << ")";
}
}
return success();
}
LogicalResult GatherExtentsOp::inferReturnTypes(
MLIRContext *context, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
// We can't infer the DimType of the result if there are no operands.
// If a user requires this, then they should manually specify the return type.
// We could in theory use an index type here (the default).
assert(!operands.empty() && "inferring return type for empty operands");
auto indices = getRequiredAttr<DenseIntElementsAttr>(attributes, "indices")
.getValues<int64_t>();
auto inputExtents = getConcatenatedExtents(operands);
SmallVector<int64_t, 6> resultExtents;
for (auto index : indices) {
resultExtents.push_back(inputExtents[index]);
}
inferredReturnTypes.push_back(RankedShapeType::get(resultExtents, context));
return success();
}
//===----------------------------------------------------------------------===//
// shapex.to_extent_tensor
//===----------------------------------------------------------------------===//
LogicalResult ToExtentTensorOp::inferReturnTypes(
MLIRContext *context, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
auto inputType = operands[0].getType().cast<RankedShapeType>();
inferredReturnTypes.push_back(
RankedTensorType::get({inputType.getRank()}, IndexType::get(context)));
return success();
}
//===----------------------------------------------------------------------===//
// shapex.from_extent_tensor
//===----------------------------------------------------------------------===//
static bool isValidTensorOfExtents(RankedTensorType type) {
// If the tensor of extents is not static shapes, that would imply that the
// tensor whose shape it is describing is unranked.
return type.getRank() == 1 && type.hasStaticShape();
}
LogicalResult FromExtentTensorOp::inferReturnTypes(
MLIRContext *context, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
auto inputType = operands[0].getType().dyn_cast<RankedTensorType>();
if (!inputType || !isValidTensorOfExtents(inputType)) {
return emitOptionalError(location, "Invalid input type, ",
operands[0].getType(),
", for from_extent_tensor op");
}
SmallVector<int64_t, 6> extents(inputType.getDimSize(0),
static_cast<int64_t>(-1));
inferredReturnTypes.push_back(RankedShapeType::get(extents, context));
return success();
}
bool FromExtentTensorOp::isCompatibleReturnTypes(ArrayRef<Type> lhs,
ArrayRef<Type> rhs) {
auto lhsRs = lhs[0].cast<RankedShapeType>();
auto rhsRs = rhs[0].cast<RankedShapeType>();
return succeeded(
verifyCompatibleShape(lhsRs.getAllDims(), rhsRs.getAllDims()));
}
} // namespace Shape
} // namespace iree_compiler
} // namespace mlir
#define GET_OP_CLASSES
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.cpp.inc"