// Copyright 2019 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Implements logic for lowering the StableHLO general dot op to the dot op.
#include "compiler/plugins/input/StableHLO/Conversion/Preprocessing/Passes.h"
#include "compiler/plugins/input/StableHLO/Conversion/Preprocessing/Rewriters.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "stablehlo/dialect/StablehloOps.h"
namespace mlir::iree_compiler::stablehlo {
#define GEN_PASS_DEF_DOTGENERALTODOT
#include "compiler/plugins/input/StableHLO/Conversion/Preprocessing/Passes.h.inc"
namespace {
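// Transposes `arg` so that the dimensions listed in `leftDims` precede those
// in `rightDims`, then collapses the result into a rank-2 tensor of shape
// [product(leftDims sizes), product(rightDims sizes)]. When the transposed
// value is already a 2-D matrix the reshape is skipped; when any collapsed
// size is dynamic, the sizes are materialized with get_dimension_size and a
// dynamic_reshape is emitted instead.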
Value transposeReshape(Value arg, Location loc,
llvm::ArrayRef<int64_t> leftDims,
llvm::ArrayRef<int64_t> rightDims,
llvm::ArrayRef<int64_t> argShape,
PatternRewriter &rewriter) {
Type elementType = getElementTypeOrSelf(arg.getType());
int64_t leftSize = 1;
for (int64_t dim : leftDims) {
leftSize = (ShapedType::isDynamic(argShape[dim]) || leftSize < 0)
? ShapedType::kDynamic
: leftSize * argShape[dim];
}
int64_t rightSize = 1;
for (int64_t dim : rightDims) {
rightSize = (ShapedType::isDynamic(argShape[dim]) || rightSize < 0)
? ShapedType::kDynamic
: rightSize * argShape[dim];
}
// Generate the transpose permutation attribute.
auto transposePermutation =
llvm::to_vector<5>(llvm::concat<const int64_t>(leftDims, rightDims));
auto transposePermutationAttr =
rewriter.getDenseI64ArrayAttr(transposePermutation);
// Compute the resulting shape.
llvm::SmallVector<int64_t, 5> transposedShape;
for (int64_t val : transposePermutation) {
transposedShape.push_back(argShape[val]);
}
// If each side contributes exactly one dimension, the transposed result is
// already rank-2 and we can skip a needless reshape.
bool noReshape = transposedShape.size() == 2 && leftDims.size() == 1 &&
rightDims.size() == 1;
// Construct transpose. If no reshape is needed, we are done.
auto transposeType = RankedTensorType::get(transposedShape, elementType);
Value transposeResult = mlir::stablehlo::TransposeOp::create(
rewriter, loc, transposeType, arg, transposePermutationAttr);
if (noReshape)
return transposeResult;
// Return the final result.
auto reshapedType = RankedTensorType::get({leftSize, rightSize}, elementType);
if (reshapedType.hasStaticShape()) {
return mlir::stablehlo::ReshapeOp::create(rewriter, loc, reshapedType,
transposeResult);
}
SmallVector<Value> reshapeDims;
auto multiplyDynamicDims = [&](llvm::ArrayRef<int64_t> dims) -> Value {
Value dynamicSize = mlir::stablehlo::GetDimensionSizeOp::create(
rewriter, loc, arg, rewriter.getI64IntegerAttr(dims.front()));
Value dynamicSizeReshaped = mlir::stablehlo::ReshapeOp::create(
rewriter, loc, RankedTensorType::get({1}, rewriter.getI32Type()),
dynamicSize);
for (auto idx : dims.drop_front()) {
Value dim = mlir::stablehlo::GetDimensionSizeOp::create(
rewriter, loc, arg, rewriter.getI64IntegerAttr(idx));
Value dimReshaped = mlir::stablehlo::ReshapeOp::create(
rewriter, loc, RankedTensorType::get({1}, rewriter.getI32Type()),
dim);
dynamicSizeReshaped = mlir::stablehlo::MulOp::create(
rewriter, loc, dynamicSizeReshaped, dimReshaped);
}
return dynamicSizeReshaped;
};
if (leftSize < 0) {
reshapeDims.push_back(multiplyDynamicDims(leftDims));
} else {
reshapeDims.push_back(mlir::stablehlo::ConstantOp::create(
rewriter, loc, rewriter.getI32TensorAttr(leftSize)));
}
if (rightSize < 0) {
reshapeDims.push_back(multiplyDynamicDims(rightDims));
} else {
reshapeDims.push_back(mlir::stablehlo::ConstantOp::create(
rewriter, loc, rewriter.getI32TensorAttr(rightSize)));
}
Value reshapeDimsTensor = mlir::stablehlo::ConcatenateOp::create(
rewriter, loc, RankedTensorType::get({2}, rewriter.getI32Type()),
reshapeDims, rewriter.getI64IntegerAttr(0));
return mlir::stablehlo::DynamicReshapeOp::create(
rewriter, loc, reshapedType, transposeResult, reshapeDimsTensor);
}
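// Flattens `arg` into a rank-2 matrix for a standard dot: dimensions are
// partitioned into contracting and outer (non-contracting) groups, and the
// two groups are collapsed via transposeReshape. `outerDimsFirst` selects
// whether the outer dimensions become the rows (LHS operand) or the columns
// (RHS operand) of the flattened matrix.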
Value processDotArg(Value arg, Location loc, ArrayRef<int64_t> contractDimsAttr,
bool outerDimsFirst, PatternRewriter &rewriter) {
auto shape = llvm::cast<ShapedType>(arg.getType()).getShape();
llvm::SmallVector<bool, 5> isOuterDim;
isOuterDim.resize(shape.size(), true);
// Compute the contract dimension ordering.
llvm::SmallVector<int64_t, 5> contractDims;
for (auto dim : contractDimsAttr) {
contractDims.push_back(dim);
isOuterDim[dim] = false;
}
// Compute the outer dimension orderings.
llvm::SmallVector<int64_t, 5> outerDims;
for (const auto &it : llvm::enumerate(isOuterDim)) {
if (it.value()) {
outerDims.push_back(it.index());
}
}
if (outerDimsFirst) {
return transposeReshape(arg, loc, outerDims, contractDims, shape, rewriter);
}
return transposeReshape(arg, loc, contractDims, outerDims, shape, rewriter);
}
struct GeneralDotRemoveBatch final
: OpRewritePattern<mlir::stablehlo::DotGeneralOp> {
using OpRewritePattern::OpRewritePattern;
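  // Removes a leading batching dimension of size 1 from both operands of a
  // dot_general by reshaping it away, so that GeneralDotConvert (which
  // requires an empty list of batching dimensions) can then apply.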
LogicalResult matchAndRewrite(mlir::stablehlo::DotGeneralOp op,
PatternRewriter &rewriter) const override {
auto lhsTy = cast<ShapedType>(op.getLhs().getType());
auto rhsTy = cast<ShapedType>(op.getRhs().getType());
auto ty = cast<ShapedType>(op.getType());
if (!ty.hasStaticShape()) {
return rewriter.notifyMatchFailure(op, "does not have static shape");
}
auto dimNumbers = op.getDotDimensionNumbers();
if (dimNumbers.getLhsBatchingDimensions().size() != 1 ||
    dimNumbers.getRhsBatchingDimensions().size() != 1) {
return rewriter.notifyMatchFailure(op, "non-unary batch dimension");
}
if (dimNumbers.getLhsBatchingDimensions().front() != 0 ||
dimNumbers.getRhsBatchingDimensions().front() != 0) {
return rewriter.notifyMatchFailure(op, "not first dim on lhs/rhs");
}
if (lhsTy.getDimSize(0) != 1 || rhsTy.getDimSize(0) != 1) {
return rewriter.notifyMatchFailure(op, "not unary batch size");
}
// Drop the leading batch dimension of size 1; the contracting dimensions
// shift down by one accordingly.
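// Illustrative example (shapes and attribute syntax abbreviated):
//   dot_general(tensor<1x4x5xf32>, tensor<1x5x6xf32>)
//     with batching_dims = [0] x [0], contracting_dims = [2] x [1]
// becomes
//   reshape the operands to tensor<4x5xf32> / tensor<5x6xf32>,
//   dot_general with contracting_dims = [1] x [0] -> tensor<4x6xf32>,
//   reshape the result back to tensor<1x4x6xf32>.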
llvm::SmallVector<int64_t> newLhsContractingDims;
for (auto dim : dimNumbers.getLhsContractingDimensions())
newLhsContractingDims.push_back(dim - 1);
llvm::SmallVector<int64_t> newRhsContractingDims;
for (auto dim : dimNumbers.getRhsContractingDimensions())
newRhsContractingDims.push_back(dim - 1);
auto lhs = mlir::stablehlo::ReshapeOp::create(
rewriter, op.getLoc(), lhsTy.clone(lhsTy.getShape().drop_front()),
op.getLhs());
auto rhs = mlir::stablehlo::ReshapeOp::create(
rewriter, op.getLoc(), rhsTy.clone(rhsTy.getShape().drop_front()),
op.getRhs());
auto newDimNumbers = mlir::stablehlo::DotDimensionNumbersAttr::get(
rewriter.getContext(),
/*lhsBatchingDimensions=*/{},
/*rhsBatchingDimensions=*/{},
/*lhsContractingDimensions=*/
newLhsContractingDims,
/*rhsContractingDimensions=*/
newRhsContractingDims);
auto dot = mlir::stablehlo::DotGeneralOp::create(
rewriter, op.getLoc(), ty.clone(ty.getShape().drop_front()), lhs, rhs,
newDimNumbers, op.getPrecisionConfigAttr(), op.getAlgorithmAttr());
rewriter.replaceOpWithNewOp<mlir::stablehlo::ReshapeOp>(op, ty,
dot.getResult());
return success();
}
};
struct GeneralDotConvert final
: OpRewritePattern<mlir::stablehlo::DotGeneralOp> {
using OpRewritePattern::OpRewritePattern;
// Attempts to lower a general dot operator to a standard dot operator.
// General dots may include batching dimensions and can contract along any
// axes. Inserting correctly arranged transpose and reshape operators
// collapses each operand into a rank-2 matrix, after which the general dot
// can be replaced with the standard dot operator.
//
// Note: this pattern requires an empty list of batching dimensions.
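//
// Illustrative example (attribute syntax abbreviated): a dot_general such as
//   (tensor<4x5x6xf32>, tensor<7x6xf32>) -> tensor<4x5x7xf32>
//   with contracting_dims = [2] x [1]
// is rewritten roughly as
//   transpose+reshape lhs -> tensor<20x6xf32>  (outer dims first)
//   transpose+reshape rhs -> tensor<6x7xf32>   (contracting dims first)
//   stablehlo.dot         -> tensor<20x7xf32>
//   reshape the result    -> tensor<4x5x7xf32>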
LogicalResult matchAndRewrite(mlir::stablehlo::DotGeneralOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto dotNumbers = op.getDotDimensionNumbers();
if (!dotNumbers.getLhsBatchingDimensions().empty() ||
!dotNumbers.getRhsBatchingDimensions().empty()) {
return failure();
}
ArrayAttr precisionConfig;
auto opPrecisionConfig = op.getPrecisionConfig();
if (opPrecisionConfig.has_value())
precisionConfig = *opPrecisionConfig;
auto resultTy = cast<ShapedType>(op.getType());
ArrayRef<int64_t> lhsContractingDims =
dotNumbers.getLhsContractingDimensions();
ArrayRef<int64_t> rhsContractingDims =
dotNumbers.getRhsContractingDimensions();
TypedValue<RankedTensorType> lhs = op.getLhs();
TypedValue<RankedTensorType> rhs = op.getRhs();
RankedTensorType lhsTy = dyn_cast<RankedTensorType>(lhs.getType());
RankedTensorType rhsTy = dyn_cast<RankedTensorType>(rhs.getType());
if (!lhsTy || !rhsTy)
return failure();
// The StableHLO dot operator directly supports a vector dot product
// (two vectors reduce into a scalar) as well as a matrix vector
// product (a matrix and vector reduce into a vector) without any
// need for reshaping. We handle those special cases first, before
// entering the general logic that reduces into a matrix.
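// For example:
//   vector . vector: (tensor<5xf32>,   tensor<5xf32>) -> tensor<f32>
//   matrix . vector: (tensor<4x5xf32>, tensor<5xf32>) -> tensor<4xf32>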
if (lhsTy.hasStaticShape() && rhsTy.hasStaticShape() &&
lhsContractingDims.size() == 1 && rhsContractingDims.size() == 1) {
if (lhsTy.getRank() == 1 && rhsTy.getRank() == 1) {
// Vector-vector, reduces into scalar.
assert(lhsContractingDims[0] == 0 && rhsContractingDims[0] == 0);
ShapedType newTy = RankedTensorType::get({}, resultTy.getElementType());
rewriter.replaceOpWithNewOp<mlir::stablehlo::DotOp>(op, newTy, lhs, rhs,
precisionConfig);
return success();
}
if (lhsTy.getRank() == 2 && rhsTy.getRank() == 1 &&
lhsContractingDims[0] == 1) {
// Matrix-vector, reduces into vector.
assert(rhsContractingDims[0] == 0);
ShapedType newTy = RankedTensorType::get({lhsTy.getShape()[0]},
resultTy.getElementType());
rewriter.replaceOpWithNewOp<mlir::stablehlo::DotOp>(op, newTy, lhs, rhs,
precisionConfig);
return success();
}
if (lhsTy.getRank() == 2 && rhsTy.getRank() == 2 &&
lhsContractingDims[0] == 1 && rhsContractingDims[0] == 0) {
// Matrix-matrix, reduces into matrix. Note that for dense cases, this
// rewriting rule simply provides a shortcut for what is to follow
// (modulo optimizing the trivial transpose/reshape operations). For
// sparse cases, however, this rewriting preserves the output sparsity
// that was explicitly given for the general dot operation.
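// For example, (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32, #enc>
// keeps the sparse encoding #enc of the original result type (#enc being an
// illustrative encoding attribute).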
Value newDotOp = mlir::stablehlo::DotOp::create(
rewriter, loc, resultTy, lhs, rhs, precisionConfig);
if (auto enc = sparse_tensor::getSparseTensorEncoding(resultTy)) {
newDotOp.setType(RankedTensorType::get(
resultTy.getShape(), resultTy.getElementType(), enc));
}
rewriter.replaceOp(op, newDotOp);
return success();
}
}
// For any sparse situation, don't use any of the following rules, since
// transposing and reshaping is not without cost. Instead, rely on the
// default linalg lowering that follows later in the pipeline.
if (sparse_tensor::hasAnySparseOperandOrResult(op))
return failure();
// Compute the (possibly transposed and reshaped) operands.
lhs = cast<mlir::TypedValue<mlir::RankedTensorType>>(processDotArg(
lhs, loc, lhsContractingDims, /*outerDimsFirst=*/true, rewriter));
rhs = cast<mlir::TypedValue<mlir::RankedTensorType>>(processDotArg(
rhs, loc, rhsContractingDims, /*outerDimsFirst=*/false, rewriter));
// Accept only static shaped types.
auto lhsShapeType = dyn_cast_or_null<ShapedType>(lhs.getType());
auto rhsShapeType = dyn_cast_or_null<ShapedType>(rhs.getType());
if (!lhsShapeType || !rhsShapeType)
return failure();
// Generate new dot operator on expanded types.
ShapedType newTy = RankedTensorType::get(
{lhsShapeType.getShape()[0], rhsShapeType.getShape()[1]},
resultTy.getElementType());
Value newDotOp = mlir::stablehlo::DotOp::create(rewriter, loc, newTy, lhs,
rhs, precisionConfig);
if (static_cast<int64_t>(lhsContractingDims.size()) ==
lhsTy.getRank() - 1 &&
static_cast<int64_t>(rhsContractingDims.size()) ==
rhsTy.getRank() - 1) {
rewriter.replaceOp(op, newDotOp);
return success();
}
// We can avoid all the computation below if we know the static shape.
if (resultTy.hasStaticShape()) {
rewriter.replaceOpWithNewOp<mlir::stablehlo::ReshapeOp>(op, resultTy,
newDotOp);
return success();
}
llvm::SmallVector<int64_t> staticDims;
llvm::SmallVector<Value> dynDims;
auto getDynamicDims = [&](Value arg,
llvm::ArrayRef<int64_t> contractingDims) {
RankedTensorType ty = llvm::cast<RankedTensorType>(arg.getType());
int index = 0;
for (int64_t contractingDim : contractingDims) {
for (; index < contractingDim; ++index) {
staticDims.push_back(ty.getDimSize(index));
Value dynDim = mlir::stablehlo::GetDimensionSizeOp::create(
rewriter, loc, arg, rewriter.getI64IntegerAttr(index));
Value dynDimReshaped = mlir::stablehlo::ReshapeOp::create(
rewriter, loc, RankedTensorType::get({1}, rewriter.getI32Type()),
dynDim);
dynDims.push_back(dynDimReshaped);
}
index++;
}
for (; index < ty.getRank(); ++index) {
staticDims.push_back(ty.getDimSize(index));
Value dynDim = mlir::stablehlo::GetDimensionSizeOp::create(
rewriter, loc, arg, rewriter.getI64IntegerAttr(index));
Value dynDimReshaped = mlir::stablehlo::ReshapeOp::create(
rewriter, loc, RankedTensorType::get({1}, rewriter.getI32Type()),
dynDim);
dynDims.push_back(dynDimReshaped);
}
};
getDynamicDims(op.getLhs(), lhsContractingDims);
getDynamicDims(op.getRhs(), rhsContractingDims);
Value reshapeDimsTensor = mlir::stablehlo::ConcatenateOp::create(
rewriter, loc,
RankedTensorType::get({static_cast<int64_t>(dynDims.size())},
rewriter.getI32Type()),
dynDims, rewriter.getI64IntegerAttr(0));
Value result = mlir::stablehlo::DynamicReshapeOp::create(
rewriter, loc,
RankedTensorType::get(staticDims, resultTy.getElementType()), newDotOp,
reshapeDimsTensor);
rewriter.replaceOp(op, result);
return success();
}
};
struct DotVectorOptimization final : OpRewritePattern<mlir::stablehlo::DotOp> {
using OpRewritePattern::OpRewritePattern;
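  // Simplifies a stablehlo.dot whose LHS is a 1xK matrix and/or whose RHS is
  // a Kx1 matrix: the unit dimension is reshaped away so the dot becomes a
  // smaller matrix-vector or vector-vector product, and the result is
  // reshaped back to the original type.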
LogicalResult matchAndRewrite(mlir::stablehlo::DotOp op,
PatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
Value lhs = op.getLhs();
Value rhs = op.getRhs();
ShapedType lhsTy = cast<ShapedType>(lhs.getType());
ShapedType rhsTy = cast<ShapedType>(rhs.getType());
ShapedType resultTy = cast<ShapedType>(op.getType());
llvm::SmallVector<int64_t> dotShape;
if (lhsTy.getRank() == 2 && lhsTy.getDimSize(0) == 1) {
lhs = mlir::stablehlo::ReshapeOp::create(
b, lhsTy.clone({lhsTy.getDimSize(1)}), lhs);
} else if (lhsTy.getRank() == 2) {
dotShape.push_back(lhsTy.getDimSize(0));
}
if (rhsTy.getRank() == 2 && rhsTy.getDimSize(1) == 1) {
rhs = mlir::stablehlo::ReshapeOp::create(
b, rhsTy.clone({rhsTy.getDimSize(0)}), rhs);
} else if (rhsTy.getRank() == 2) {
dotShape.push_back(rhsTy.getDimSize(1));
}
if (lhs == op.getLhs() && rhs == op.getRhs()) {
return rewriter.notifyMatchFailure(op, "no vector reform available.");
}
auto newDot = mlir::stablehlo::DotOp::create(
b, resultTy.clone(dotShape), lhs, rhs, op.getPrecisionConfigAttr());
auto resultReshape =
mlir::stablehlo::ReshapeOp::create(b, resultTy, newDot);
rewriter.replaceOp(op, resultReshape);
return success();
}
};
struct DotGeneralToDot final : impl::DotGeneralToDotBase<DotGeneralToDot> {
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populatePreprocessingDotGeneralToDotPatterns(&getContext(), &patterns);
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
return signalPassFailure();
}
}
};
} // namespace
void populatePreprocessingDotGeneralToDotPatterns(mlir::MLIRContext *context,
RewritePatternSet *patterns,
PatternBenefit benefit) {
patterns
->add<GeneralDotConvert, GeneralDotRemoveBatch, DotVectorOptimization>(
context, benefit);
}
} // namespace mlir::iree_compiler::stablehlo