// Copyright 2020 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Dialect/Shape/Plugins/XLA/XlaHloShapeBuilder.h"
#include "iree/compiler/Dialect/Shape/IR/Builders.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeInterface.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/Optional.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
using llvm::None;
using llvm::Optional;
using llvm::SmallVector;
using namespace mlir::iree_compiler::Shape;

namespace mlir {
namespace mhlo {
namespace {
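
// Builds the result shape for an n-ary elementwise HLO op by casting its
// inputs to the result shape via buildCastInputsToResultShape.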
template <typename HloOp>
Value rewriteXlaNaryElementwiseOpShape(RankedShapeType resultShape, HloOp op,
OpBuilder &builder) {
if (!op) return nullptr;
SmallVector<Value, 4> inputOperands(op.getOperands());
return buildCastInputsToResultShape(op.getLoc(), resultShape, inputOperands,
builder);
}
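
// Builds the result shape of mhlo.dot. Dynamic result extents are read from
// the operands, e.g. for [?xK] dot [Kx?] -> [?x?] the extents come from lhs
// dim 0 and rhs dim 1.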
Value rewriteXlaDotOpShape(RankedShapeType resultRs, DotOp dotOp,
OpBuilder &builder) {
auto lhsType = dotOp.lhs().getType().dyn_cast<RankedTensorType>();
auto rhsType = dotOp.rhs().getType().dyn_cast<RankedTensorType>();
auto resultType = dotOp.getResult().getType().dyn_cast<RankedTensorType>();
if (!lhsType || !rhsType || !resultType) return nullptr;
// Shape transfer function:
// [n] dot [n] -> scalar
// [m x k] dot [k] -> [m]
// [m x k] dot [k x n] -> [m x n]
auto lhsRank = lhsType.getRank();
auto rhsRank = rhsType.getRank();
auto resultRank = resultType.getRank();
auto loc = dotOp.getLoc();
if (lhsRank == 1 && rhsRank == 1 && resultRank == 0) {
auto scalarShape = RankedShapeType::getChecked({}, loc);
return builder.create<ConstRankedShapeOp>(dotOp.getLoc(), scalarShape);
} else if (lhsRank == 2 && rhsRank == 1 && resultRank == 1) {
SmallVector<Value, 1> dynamicDims;
if (resultRs.isDimDynamic(0)) {
auto lhsGetShape = builder.create<GetRankedShapeOp>(
loc, RankedShapeType::get(lhsType.getShape(), builder.getContext()),
dotOp.lhs());
auto getMDim =
builder.create<RankedDimOp>(loc, builder.getIndexType(), lhsGetShape,
builder.getI64IntegerAttr(0));
dynamicDims.push_back(getMDim);
}
return builder.create<MakeRankedShapeOp>(loc, resultRs, dynamicDims);
} else if (lhsRank == 2 && rhsRank == 2 && resultRank == 2) {
SmallVector<Value, 2> dynamicDims;
if (resultRs.isDimDynamic(0)) {
auto lhsGetShape = builder.create<GetRankedShapeOp>(
loc, RankedShapeType::get(lhsType.getShape(), builder.getContext()),
dotOp.lhs());
auto getMDim =
builder.create<RankedDimOp>(loc, builder.getIndexType(), lhsGetShape,
builder.getI64IntegerAttr(0));
dynamicDims.push_back(getMDim);
}
if (resultRs.isDimDynamic(1)) {
auto rhsGetShape = builder.create<GetRankedShapeOp>(
loc, RankedShapeType::get(rhsType.getShape(), builder.getContext()),
dotOp.rhs());
auto getNDim =
builder.create<RankedDimOp>(loc, builder.getIndexType(), rhsGetShape,
builder.getI64IntegerAttr(1));
dynamicDims.push_back(getNDim);
}
return builder.create<MakeRankedShapeOp>(loc, resultRs, dynamicDims);
} else {
return nullptr;
}
}
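
// Builds the result shape of mhlo.reduce by dropping the reduced dimensions
// from a shape common to all operands.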
Value rewriteReduce(RankedShapeType resultShape, ReduceOp reduceOp,
OpBuilder &builder) {
Location loc = reduceOp.getLoc();
// Get a common operand shape.
Value operandShape;
SmallVector<Value, 4> operandShapes;
RankedShapeType operandRs;
for (auto operand : reduceOp.inputs()) {
auto shape = builder.create<GetRankedShapeOp>(loc, operand);
operandRs = shape.getRankedShape();
operandShapes.push_back(shape);
}
assert(!operandShapes.empty());
if (operandShapes.size() == 1) {
// Single operand.
operandShape = operandShapes.front();
} else {
// Multiple operands must be compatible.
operandShape =
builder.create<CastCompatibleShapeOp>(loc, operandRs, operandShapes);
}
// Map reduction dims onto operand dimensions.
SmallVector<bool, 4> isDimReduced;
isDimReduced.resize(operandRs.getRank());
for (const auto &apIntValue : reduceOp.dimensions().getIntValues()) {
auto intValue = apIntValue.getZExtValue();
assert(intValue < isDimReduced.size());
isDimReduced[intValue] = true;
}
// Map operand -> result dynamic dims.
assert(resultShape.getRank() ==
(operandRs.getRank() - reduceOp.dimensions().getNumElements()));
SmallVector<Value, 4> resultDims;
for (unsigned operandDimIndex = 0, e = isDimReduced.size();
operandDimIndex < e; ++operandDimIndex) {
unsigned resultDimIndex = resultDims.size();
// Skip reduced operand indices and non-dynamic result indices.
if (isDimReduced[operandDimIndex] ||
!resultShape.isDimDynamic(resultDimIndex))
continue;
resultDims.push_back(
builder.create<RankedDimOp>(loc, operandShape, operandDimIndex));
}
return builder.create<MakeRankedShapeOp>(loc, resultShape, resultDims);
}

// NOTE: This op is an interloper from the Shape dialect and is only handled
// here until a corresponding HLO op is created. As such, it is included in
// this file even though it is not currently an HLO op.
Value rewriteShapexRankedBroadcastInDim(RankedShapeType resultShape,
RankedBroadcastInDimOp bidOp,
OpBuilder &builder) {
if (!bidOp) return nullptr;
return bidOp.result_shape();
}
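
// shapex.iota carries its result shape explicitly, so it is forwarded
// directly.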
Value rewriteShapexIota(RankedShapeType resultShape,
iree_compiler::Shape::IotaOp iotaOp,
OpBuilder &builder) {
if (!iotaOp) return nullptr;
return iotaOp.result_shape();
}
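
// Builds the result shape of mhlo.transpose. Result dim i corresponds to
// operand dim permutation[i], so dynamic result extents are read from the
// permuted operand dimensions.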
Value rewriteTranspose(RankedShapeType resultShape, TransposeOp transposeOp,
OpBuilder &builder) {
if (!transposeOp) return nullptr;
auto loc = transposeOp.getLoc();
auto operandType =
transposeOp.operand().getType().dyn_cast<RankedTensorType>();
if (!operandType) return nullptr;
auto operandShapeValue = builder.create<GetRankedShapeOp>(
loc, RankedShapeType::get(operandType.getShape(), builder.getContext()),
transposeOp.operand());
SmallVector<int64_t, 4> perm;
for (const auto &permValue : transposeOp.permutation().getIntValues()) {
perm.push_back(permValue.getSExtValue());
}
assert(perm.size() == resultShape.getRank());
// Map the dynamic dims.
SmallVector<Value, 4> dynamicDims;
for (int64_t i = 0, e = resultShape.getRank(); i < e; ++i) {
if (!resultShape.isDimDynamic(i)) continue;
int64_t operandDim = perm[i];
auto dimValue = builder.create<RankedDimOp>(
loc, builder.getIndexType(), operandShapeValue,
builder.getI64IntegerAttr(operandDim));
dynamicDims.push_back(dimValue);
}
return builder.create<MakeRankedShapeOp>(loc, resultShape, dynamicDims);
}

// Returns a value of type `!shapex.ranked_shape` for the input value, or
// nullptr if the value is not a ranked tensor.
static Value getRankedShapeAsValue(Value v, OpBuilder &builder, Location loc) {
assert(v.getType().isa<TensorType>());
auto type = v.getType().dyn_cast<RankedTensorType>();
if (!type) {
return nullptr;
}
return builder.create<GetRankedShapeOp>(
loc, RankedShapeType::get(type.getShape(), builder.getContext()), v);
}

// Returns a value representing the extent of dimension `dim`.
static Value getExtent(Value v, int64_t dim, OpBuilder &builder, Location loc) {
return builder.create<RankedDimOp>(loc, v, dim);
}
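
// Builds the result shape of mhlo.dot_general. Result dims are emitted in
// order: the batch dims (read from lhs), then the lhs free (non-batch,
// non-contracting) dims, then the rhs free dims; only dynamic extents are
// materialized. For example, [Bx?xK] dot [BxKx?] with batch dims {0, 0} and
// contracting dims {2, 1} yields [Bx?x?].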
Value rewriteDotGeneral(RankedShapeType resultShape, DotGeneralOp op,
OpBuilder &builder) {
Location loc = op.getLoc();
auto lhsShape = getRankedShapeAsValue(op.lhs(), builder, loc);
auto rhsShape = getRankedShapeAsValue(op.rhs(), builder, loc);
if (!lhsShape || !rhsShape) {
return nullptr;
}
auto getFreeDims = [&](ArrayRef<int64_t> batchDims,
ArrayRef<int64_t> contractingDims, int64_t rank) {
llvm::BitVector freeDims(rank, true);
for (int64_t dim : batchDims) {
freeDims.reset(dim);
}
for (int64_t dim : contractingDims) {
freeDims.reset(dim);
}
SmallVector<int64_t, 4> result;
for (auto bitIndex : freeDims.set_bits()) {
result.push_back(bitIndex);
}
return result;
};
auto lhsRankedShape = lhsShape.getType().cast<RankedShapeType>();
auto rhsRankedShape = rhsShape.getType().cast<RankedShapeType>();
auto dotDimensions = op.dot_dimension_numbers();
auto lhsFreeDims = getFreeDims(
llvm::to_vector<4>(
dotDimensions.lhs_batching_dimensions().getValues<int64_t>()),
llvm::to_vector<4>(
dotDimensions.lhs_contracting_dimensions().getValues<int64_t>()),
lhsRankedShape.getRank());
auto rhsFreeDims = getFreeDims(
llvm::to_vector<4>(
dotDimensions.rhs_batching_dimensions().getValues<int64_t>()),
llvm::to_vector<4>(
dotDimensions.rhs_contracting_dimensions().getValues<int64_t>()),
rhsRankedShape.getRank());
SmallVector<Value, 6> outputExtents;
for (int64_t dim :
dotDimensions.lhs_batching_dimensions().getValues<int64_t>()) {
// TODO(silvasean): Add a version of MakeRankedShapeOp that takes
// all dimensions. Having to constantly check if a dim is dynamic
// upon construction is a waste of time, more testing burden, etc.
// We can easily canonicalize to the more constrained one.
if (lhsRankedShape.isDimDynamic(dim)) {
outputExtents.push_back(getExtent(lhsShape, dim, builder, loc));
}
}
for (int64_t dim : lhsFreeDims) {
if (lhsRankedShape.isDimDynamic(dim)) {
outputExtents.push_back(getExtent(lhsShape, dim, builder, loc));
}
}
for (int64_t dim : rhsFreeDims) {
if (rhsRankedShape.isDimDynamic(dim)) {
outputExtents.push_back(getExtent(rhsShape, dim, builder, loc));
}
}
return builder.create<MakeRankedShapeOp>(loc, resultShape, outputExtents);
}
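
// mhlo.dynamic_reshape carries its output shape as an extent tensor, which is
// converted directly into a ranked_shape.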
Value rewriteDynamicReshape(RankedShapeType resultShape, DynamicReshapeOp op,
OpBuilder &builder) {
return builder.create<FromExtentTensorOp>(op.getLoc(), resultShape,
op.output_shape());
}
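
// Builds the result shape of mhlo.select, assuming pred, on_true, and on_false
// all match the result rank; dynamic extents are read from on_true.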
Value rewriteSelectOp(RankedShapeType resultShape, SelectOp selectOp,
OpBuilder &builder) {
// TODO: broadcast case
if (!selectOp) return nullptr;
auto loc = selectOp.getLoc();
auto trueType = selectOp.on_true().getType().dyn_cast<RankedTensorType>();
auto falseType = selectOp.on_false().getType().dyn_cast<RankedTensorType>();
auto predType = selectOp.pred().getType().dyn_cast<RankedTensorType>();
if (!trueType || !falseType || !predType) return nullptr;
assert(trueType.getRank() == resultShape.getRank() &&
trueType.getRank() == falseType.getRank() &&
trueType.getRank() == predType.getRank());
auto operandShapeValue = builder.create<GetRankedShapeOp>(
loc, RankedShapeType::get(trueType.getShape(), builder.getContext()),
selectOp.on_true());
// Map the dynamic dims.
SmallVector<Value, 4> dynamicDims;
for (int64_t i = 0, e = resultShape.getRank(); i < e; ++i) {
if (!resultShape.isDimDynamic(i)) continue;
auto dimValue = builder.create<RankedDimOp>(loc, builder.getIndexType(),
operandShapeValue,
builder.getI64IntegerAttr(i));
dynamicDims.push_back(dimValue);
}
return builder.create<MakeRankedShapeOp>(loc, resultShape, dynamicDims);
}
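
// Builds the result shape of mhlo.torch_index_select. The general shape
// transfer function is:
//   params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis+1:]
// with special handling for scalar (rank 0) and vector (rank 1) indices.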
Value rewriteTorchIndexSelect(RankedShapeType resultShape,
TorchIndexSelectOp torchIndexSelectOp,
OpBuilder &builder) {
if (!torchIndexSelectOp) return nullptr;
auto loc = torchIndexSelectOp.getLoc();
int64_t resultShapeRank = resultShape.getRank();
auto paramsType =
torchIndexSelectOp.input().getType().dyn_cast<RankedTensorType>();
auto indicesType =
torchIndexSelectOp.index().getType().dyn_cast<RankedTensorType>();
if (!paramsType || !indicesType) {
return nullptr;
}
auto axis = torchIndexSelectOp.dim();
auto batchDim = torchIndexSelectOp.batch_dims();
int64_t paramsRank = paramsType.getRank();
int64_t indicesRank = indicesType.getRank();
std::vector<int64_t> shape(paramsType.getShape());
int64_t axisValue = axis;
int64_t batchDimValue = batchDim;
// Negative axis/batch_dims values wrap around,
// e.g. axis = -1 refers to the last dimension of params.
if (axisValue < 0) {
axisValue += paramsRank;
}
if (batchDimValue < 0) {
batchDimValue += indicesRank;
}
// params must have rank at least axis + 1.
if (paramsRank < axisValue + 1) {
return nullptr;
}
auto paramsShapeValue = builder.create<GetRankedShapeOp>(
loc, RankedShapeType::get(paramsType.getShape(), builder.getContext()),
torchIndexSelectOp.input());
auto indicesShapeValue = builder.create<GetRankedShapeOp>(
loc, RankedShapeType::get(indicesType.getShape(), builder.getContext()),
torchIndexSelectOp.index());
SmallVector<Value, 4> dynamicDims;
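
// Emits a RankedDimOp for dimension `index` of shape `value` and appends it to
// dynamicDims.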
#define GENERATE_RANKED_DIM_OP(value, index) \
do { \
auto dimValue = builder.create<RankedDimOp>( \
loc, builder.getIndexType(), value, builder.getI64IntegerAttr(index)); \
dynamicDims.push_back(dimValue); \
} while (0)
if (indicesRank == 0) {
// Scalar indices (output is rank(params) - 1).
if (resultShapeRank != paramsRank - 1) {
return nullptr;
}
// params.shape[:axis] + params.shape[axis+1:]
for (int64_t i = 0; i < paramsRank; ++i) {
if ((i == axisValue) || (i < axisValue && !resultShape.isDimDynamic(i)) ||
(i > axisValue && !resultShape.isDimDynamic(i - 1)))
continue;
GENERATE_RANKED_DIM_OP(paramsShapeValue, i);
}
} else if (indicesRank == 1) {
// Vector indices (output is rank(params)).
// Copy indices.shape into params.shape[axis]
if (resultShapeRank != paramsRank) {
return nullptr;
}
// params.shape[:axis] + indices.shape[batch_dims:]
// + params.shape[indicesRank-batchDim+axisValue:]
int resultShapeIndex = 0;
// params.shape[:axis]
for (int64_t i = 0; i < axisValue; ++i) {
if (!resultShape.isDimDynamic(resultShapeIndex++)) continue;
GENERATE_RANKED_DIM_OP(paramsShapeValue, i);
}
// indices.shape[:batchDim]
for (int64_t i = batchDimValue;
i < indicesRank && resultShapeIndex < resultShapeRank; ++i) {
if (!resultShape.isDimDynamic(resultShapeIndex++)) continue;
GENERATE_RANKED_DIM_OP(indicesShapeValue, i);
}
// params.shape[indicesRank-batchDim+axisValue:]
// resultShapeIndex == indicesRank-batchDim+axisValue
for (int64_t i = resultShapeIndex; i < resultShapeRank; ++i) {
if (!resultShape.isDimDynamic(resultShapeIndex++)) continue;
GENERATE_RANKED_DIM_OP(paramsShapeValue, i);
}
} else {
// params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis+1:]
// The expected rank is (paramsRank - 1) + (indicesRank - batchDim).
auto expectedRank = paramsRank - 1 + indicesRank - batchDimValue;
if (resultShapeRank != expectedRank) {
return nullptr;
}
int resultShapeIndex = 0;
for (int64_t i = 0; i < axisValue; ++i) {
if (!resultShape.isDimDynamic(resultShapeIndex++)) continue;
GENERATE_RANKED_DIM_OP(paramsShapeValue, i);
}
for (int64_t i = batchDimValue; i < indicesRank; ++i) {
if (!resultShape.isDimDynamic(resultShapeIndex++)) continue;
GENERATE_RANKED_DIM_OP(indicesShapeValue, i);
}
for (int64_t i = axisValue + 1;
i < paramsRank && resultShapeIndex < resultShapeRank; ++i) {
if (!resultShape.isDimDynamic(resultShapeIndex++)) continue;
GENERATE_RANKED_DIM_OP(paramsShapeValue, i);
}
}
#undef GENERATE_RANKED_DIM_OP
return builder.create<MakeRankedShapeOp>(loc, resultShape, dynamicDims);
}

} // namespace

// Creates a custom op shape builder for XLA-HLO ops that are not otherwise
// supported through traits or other declarative means.
void populateXlaHloCustomOpShapeBuilder(CustomOpShapeBuilderList &builders) {
auto &b = builders.make<CallbackCustomOpShapeBuilder>();
// NOTE: Most of these *should not* be "custom ops". They should be coming
// from declarative shape information, but that doesn't exist yet.
#define INSERT_EW_OP(OpTy) \
b.insertOpRankedShapeBuilder<OpTy>(rewriteXlaNaryElementwiseOpShape<OpTy>);
INSERT_EW_OP(AddOp);
INSERT_EW_OP(Atan2Op);
INSERT_EW_OP(DivOp);
INSERT_EW_OP(MaxOp);
INSERT_EW_OP(MinOp);
INSERT_EW_OP(MulOp);
INSERT_EW_OP(PowOp);
INSERT_EW_OP(RemOp);
INSERT_EW_OP(ShiftLeftOp);
INSERT_EW_OP(ShiftRightArithmeticOp);
INSERT_EW_OP(ShiftRightLogicalOp);
INSERT_EW_OP(SubOp);
INSERT_EW_OP(CompareOp);
INSERT_EW_OP(ClampOp);
b.insertOpRankedShapeBuilder<SelectOp>(rewriteSelectOp);
b.insertOpRankedShapeBuilder<DotOp>(rewriteXlaDotOpShape);
b.insertOpRankedShapeBuilder<RankedBroadcastInDimOp>(
rewriteShapexRankedBroadcastInDim);
b.insertOpRankedShapeBuilder<iree_compiler::Shape::IotaOp>(rewriteShapexIota);
b.insertOpRankedShapeBuilder<ReduceOp>(rewriteReduce);
b.insertOpRankedShapeBuilder<TransposeOp>(rewriteTranspose);
b.insertOpRankedShapeBuilder<mhlo::DotGeneralOp>(rewriteDotGeneral);
b.insertOpRankedShapeBuilder<mhlo::DynamicReshapeOp>(rewriteDynamicReshape);
b.insertOpRankedShapeBuilder<mhlo::TorchIndexSelectOp>(
rewriteTorchIndexSelect);
}

} // namespace mhlo
} // namespace mlir