| // Copyright 2019 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| // Implements logic for lowering the StableHLO general dot op to the dot op. |
| |
| #include "compiler/plugins/input/StableHLO/Conversion/Preprocessing/Passes.h" |
| #include "compiler/plugins/input/StableHLO/Conversion/Preprocessing/Rewriters.h" |
| #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/IR/ImplicitLocOpBuilder.h" |
| #include "mlir/IR/Value.h" |
| #include "mlir/Interfaces/FunctionInterfaces.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| #include "stablehlo/dialect/StablehloOps.h" |
| |
| namespace mlir::iree_compiler::stablehlo { |
| |
| #define GEN_PASS_DEF_DOTGENERALTODOT |
| #include "compiler/plugins/input/StableHLO/Conversion/Preprocessing/Passes.h.inc" |
| |
| namespace { |
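// Transposes `arg` so that the dimensions listed in `leftDims` precede those
// listed in `rightDims`, then reshapes the result into a rank-2 tensor whose
// first dimension is the product of the `leftDims` sizes and whose second
// dimension is the product of the `rightDims` sizes. A dynamic reshape is
// emitted when any of the involved dimensions is dynamic.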
| Value transposeReshape(Value arg, Location loc, |
| llvm::ArrayRef<int64_t> leftDims, |
| llvm::ArrayRef<int64_t> rightDims, |
| llvm::ArrayRef<int64_t> argShape, |
| PatternRewriter &rewriter) { |
| Type elementType = getElementTypeOrSelf(arg.getType()); |
| |
| int64_t leftSize = 1; |
| for (int64_t dim : leftDims) { |
| leftSize = (ShapedType::isDynamic(argShape[dim]) || leftSize < 0) |
| ? ShapedType::kDynamic |
| : leftSize * argShape[dim]; |
| } |
| |
| int64_t rightSize = 1; |
| for (int64_t dim : rightDims) { |
| rightSize = (ShapedType::isDynamic(argShape[dim]) || rightSize < 0) |
| ? ShapedType::kDynamic |
| : rightSize * argShape[dim]; |
| } |
| |
| // Generate the transpose permutation attribute. |
| auto transposePermutation = |
| llvm::to_vector<5>(llvm::concat<const int64_t>(leftDims, rightDims)); |
| |
| auto transposePermutationAttr = |
| rewriter.getDenseI64ArrayAttr(transposePermutation); |
| |
| // Compute the resulting shape. |
| llvm::SmallVector<int64_t, 5> transposedShape; |
| for (int64_t val : transposePermutation) { |
| transposedShape.push_back(argShape[val]); |
| } |
| |
  // If there is only a single pair of contracting dimensions and the output
  // rank is two, we can skip a needless reshape.
| bool noReshape = transposedShape.size() == 2 && leftDims.size() == 1 && |
| rightDims.size() == 1; |
| |
| // Construct transpose. If no reshape is needed, we are done. |
| auto transposeType = RankedTensorType::get(transposedShape, elementType); |
| Value transposeResult = mlir::stablehlo::TransposeOp::create( |
| rewriter, loc, transposeType, arg, transposePermutationAttr); |
| if (noReshape) |
| return transposeResult; |
| |
  // Otherwise, reshape the transposed result into the final rank-2 shape.
| auto reshapedType = RankedTensorType::get({leftSize, rightSize}, elementType); |
| |
| if (reshapedType.hasStaticShape()) { |
| return mlir::stablehlo::ReshapeOp::create(rewriter, loc, reshapedType, |
| transposeResult); |
| } |
| |
| SmallVector<Value> reshapeDims; |
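  // Computes the product of the runtime sizes of `dims` as a 1-element i32
  // tensor.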
| auto multiplyDynamicDims = [&](llvm::ArrayRef<int64_t> dims) -> Value { |
| Value dynamicSize = mlir::stablehlo::GetDimensionSizeOp::create( |
| rewriter, loc, arg, rewriter.getI64IntegerAttr(dims.front())); |
| Value dynamicSizeReshaped = mlir::stablehlo::ReshapeOp::create( |
| rewriter, loc, RankedTensorType::get({1}, rewriter.getI32Type()), |
| dynamicSize); |
| for (auto idx : dims.drop_front()) { |
| Value dim = mlir::stablehlo::GetDimensionSizeOp::create( |
| rewriter, loc, arg, rewriter.getI64IntegerAttr(idx)); |
| Value dimReshaped = mlir::stablehlo::ReshapeOp::create( |
| rewriter, loc, RankedTensorType::get({1}, rewriter.getI32Type()), |
| dim); |
| dynamicSizeReshaped = mlir::stablehlo::MulOp::create( |
| rewriter, loc, dynamicSizeReshaped, dimReshaped); |
| } |
| return dynamicSizeReshaped; |
| }; |
| |
| if (leftSize < 0) { |
| reshapeDims.push_back(multiplyDynamicDims(leftDims)); |
| } else { |
| reshapeDims.push_back(mlir::stablehlo::ConstantOp::create( |
| rewriter, loc, rewriter.getI32TensorAttr(leftSize))); |
| } |
| |
| if (rightSize < 0) { |
| reshapeDims.push_back(multiplyDynamicDims(rightDims)); |
| } else { |
| reshapeDims.push_back(mlir::stablehlo::ConstantOp::create( |
| rewriter, loc, rewriter.getI32TensorAttr(rightSize))); |
| } |
| |
| Value reshapeDimsTensor = mlir::stablehlo::ConcatenateOp::create( |
| rewriter, loc, RankedTensorType::get({2}, rewriter.getI32Type()), |
| reshapeDims, rewriter.getI64IntegerAttr(0)); |
| return mlir::stablehlo::DynamicReshapeOp::create( |
| rewriter, loc, reshapedType, transposeResult, reshapeDimsTensor); |
| } |
| |
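// Transposes and reshapes a dot operand into a rank-2 tensor. The outer
// (non-contracting) dimensions are placed first when `outerDimsFirst` is true
// and the contracting dimensions are placed first otherwise.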
| Value processDotArg(Value arg, Location loc, ArrayRef<int64_t> contractDimsAttr, |
| bool outerDimsFirst, PatternRewriter &rewriter) { |
| auto shape = llvm::cast<ShapedType>(arg.getType()).getShape(); |
| |
| llvm::SmallVector<bool, 5> isOuterDim; |
| isOuterDim.resize(shape.size(), true); |
| |
  // Compute the contracting dimension ordering.
| llvm::SmallVector<int64_t, 5> contractDims; |
| for (auto dim : contractDimsAttr) { |
| contractDims.push_back(dim); |
| isOuterDim[dim] = false; |
| } |
| |
| // Compute the outer dimension orderings. |
| llvm::SmallVector<int64_t, 5> outerDims; |
| for (const auto &it : llvm::enumerate(isOuterDim)) { |
| if (it.value()) { |
| outerDims.push_back(it.index()); |
| } |
| } |
| |
| if (outerDimsFirst) { |
| return transposeReshape(arg, loc, outerDims, contractDims, shape, rewriter); |
| } |
| |
| return transposeReshape(arg, loc, contractDims, outerDims, shape, rewriter); |
| } |
| |
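// Removes a unit batch dimension from a dot_general op: when both operands
// have a single, leading batch dimension of size 1, the operands are reshaped
// to drop that dimension, a batch-free dot_general is emitted, and the result
// is reshaped back to the original type.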
| struct GeneralDotRemoveBatch final |
| : OpRewritePattern<mlir::stablehlo::DotGeneralOp> { |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(mlir::stablehlo::DotGeneralOp op, |
| PatternRewriter &rewriter) const override { |
| auto lhsTy = cast<ShapedType>(op.getLhs().getType()); |
| auto rhsTy = cast<ShapedType>(op.getRhs().getType()); |
| auto ty = cast<ShapedType>(op.getType()); |
| |
| if (!ty.hasStaticShape()) { |
| return rewriter.notifyMatchFailure(op, "does not have static shape"); |
| } |
| |
| auto dimNumbers = op.getDotDimensionNumbers(); |
    if (dimNumbers.getLhsBatchingDimensions().size() != 1 ||
        dimNumbers.getRhsBatchingDimensions().size() != 1) {
      return rewriter.notifyMatchFailure(op, "expected one batch dimension");
| } |
| |
| if (dimNumbers.getLhsBatchingDimensions().front() != 0 || |
| dimNumbers.getRhsBatchingDimensions().front() != 0) { |
| return rewriter.notifyMatchFailure(op, "not first dim on lhs/rhs"); |
| } |
| |
| if (lhsTy.getDimSize(0) != 1 || rhsTy.getDimSize(0) != 1) { |
| return rewriter.notifyMatchFailure(op, "not unary batch size"); |
| } |
| |
    // The unit batch dimension is dropped, so the contracting dimensions
    // shift down by one.
| llvm::SmallVector<int64_t> newLhsContractingDims; |
| for (auto dim : dimNumbers.getLhsContractingDimensions()) |
| newLhsContractingDims.push_back(dim - 1); |
| |
| llvm::SmallVector<int64_t> newRhsContractingDims; |
| for (auto dim : dimNumbers.getRhsContractingDimensions()) |
| newRhsContractingDims.push_back(dim - 1); |
| |
| auto lhs = mlir::stablehlo::ReshapeOp::create( |
| rewriter, op.getLoc(), lhsTy.clone(lhsTy.getShape().drop_front()), |
| op.getLhs()); |
| |
| auto rhs = mlir::stablehlo::ReshapeOp::create( |
| rewriter, op.getLoc(), rhsTy.clone(rhsTy.getShape().drop_front()), |
| op.getRhs()); |
| |
| auto newDimNumbers = mlir::stablehlo::DotDimensionNumbersAttr::get( |
| rewriter.getContext(), |
| /*lhsBatchingDimensions=*/{}, |
| /*rhsBatchingDimensions=*/{}, |
| /*lhsContractingDimensions=*/ |
| newLhsContractingDims, |
| /*rhsContractingDimensions=*/ |
| newRhsContractingDims); |
| |
| auto dot = mlir::stablehlo::DotGeneralOp::create( |
| rewriter, op.getLoc(), ty.clone(ty.getShape().drop_front()), lhs, rhs, |
| newDimNumbers, op.getPrecisionConfigAttr(), op.getAlgorithmAttr()); |
| rewriter.replaceOpWithNewOp<mlir::stablehlo::ReshapeOp>(op, ty, |
| dot.getResult()); |
| return success(); |
| } |
| }; |
| |
| struct GeneralDotConvert final |
| : OpRewritePattern<mlir::stablehlo::DotGeneralOp> { |
| using OpRewritePattern::OpRewritePattern; |
  // Attempts to lower a General Dot operator to a standard Dot operator.
  // General dots include batching dimensions and can have contracting
  // dimensions along any axis. Inserting correctly arranged transpose and
  // reshape operators organizes the tensors and allows the General Dot to be
  // replaced with the standard Dot operator.
| // |
| // Note: This requires an empty list of batch dimensions. |
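  //
  // For example (shapes are illustrative), a dot_general contracting the last
  // LHS dimension with the first RHS dimension,
  //   lhs : tensor<2x3x4xf32>, rhs : tensor<4x5xf32> -> tensor<2x3x5xf32>,
  // is rewritten into
  //   reshape(lhs) : tensor<6x4xf32>
  //   dot          : tensor<6x4xf32> x tensor<4x5xf32> -> tensor<6x5xf32>
  //   reshape      : tensor<6x5xf32> -> tensor<2x3x5xf32>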
| LogicalResult matchAndRewrite(mlir::stablehlo::DotGeneralOp op, |
| PatternRewriter &rewriter) const override { |
| Location loc = op.getLoc(); |
| |
| auto dotNumbers = op.getDotDimensionNumbers(); |
| if (!dotNumbers.getLhsBatchingDimensions().empty() || |
| !dotNumbers.getRhsBatchingDimensions().empty()) { |
| return failure(); |
| } |
| |
| ArrayAttr precisionConfig; |
| auto opPrecisionConfig = op.getPrecisionConfig(); |
| if (opPrecisionConfig.has_value()) |
| precisionConfig = *opPrecisionConfig; |
| |
| auto resultTy = cast<ShapedType>(op.getType()); |
| |
| ArrayRef<int64_t> lhsContractingDims = |
| dotNumbers.getLhsContractingDimensions(); |
| ArrayRef<int64_t> rhsContractingDims = |
| dotNumbers.getRhsContractingDimensions(); |
| |
| TypedValue<RankedTensorType> lhs = op.getLhs(); |
| TypedValue<RankedTensorType> rhs = op.getRhs(); |
| |
| RankedTensorType lhsTy = dyn_cast<RankedTensorType>(lhs.getType()); |
| RankedTensorType rhsTy = dyn_cast<RankedTensorType>(rhs.getType()); |
| if (!lhsTy || !rhsTy) |
| return failure(); |
| |
| // The StableHLO dot operator directly supports a vector dot product |
| // (two vectors reduce into a scalar) as well as a matrix vector |
| // product (a matrix and vector reduce into a vector) without any |
| // need for reshaping. We handle those special cases first, before |
| // entering the general logic that reduces into a matrix. |
| if (lhsTy.hasStaticShape() && rhsTy.hasStaticShape() && |
| lhsContractingDims.size() == 1 && rhsContractingDims.size() == 1) { |
| if (lhsTy.getRank() == 1 && rhsTy.getRank() == 1) { |
| // Vector-vector, reduces into scalar. |
| assert(lhsContractingDims[0] == 0 && rhsContractingDims[0] == 0); |
| ShapedType newTy = RankedTensorType::get({}, resultTy.getElementType()); |
| rewriter.replaceOpWithNewOp<mlir::stablehlo::DotOp>(op, newTy, lhs, rhs, |
| precisionConfig); |
| return success(); |
| } |
| if (lhsTy.getRank() == 2 && rhsTy.getRank() == 1 && |
| lhsContractingDims[0] == 1) { |
| // Matrix-vector, reduces into vector. |
| assert(rhsContractingDims[0] == 0); |
| ShapedType newTy = RankedTensorType::get({lhsTy.getShape()[0]}, |
| resultTy.getElementType()); |
| rewriter.replaceOpWithNewOp<mlir::stablehlo::DotOp>(op, newTy, lhs, rhs, |
| precisionConfig); |
| return success(); |
| } |
| if (lhsTy.getRank() == 2 && rhsTy.getRank() == 2 && |
| lhsContractingDims[0] == 1 && rhsContractingDims[0] == 0) { |
| // Matrix-matrix, reduces into matrix. Note that for dense cases, this |
| // rewriting rule simply provides a shortcut for what is to follow |
| // (modulo optimizing the trivial transpose/reshape operations). For |
| // sparse cases, however, this rewriting preserves the output sparsity |
| // that was explicitly given for the general dot operation. |
| Value newDotOp = mlir::stablehlo::DotOp::create( |
| rewriter, loc, resultTy, lhs, rhs, precisionConfig); |
| if (auto enc = sparse_tensor::getSparseTensorEncoding(resultTy)) { |
| newDotOp.setType(RankedTensorType::get( |
| resultTy.getShape(), resultTy.getElementType(), enc)); |
| } |
| rewriter.replaceOp(op, newDotOp); |
| return success(); |
| } |
| } |
| |
    // For any sparse situation, don't use any of the following rules, since
    // transposing and reshaping are not without cost. Instead, rely on the
    // default linalg lowering that follows later in the pipeline.
| if (sparse_tensor::hasAnySparseOperandOrResult(op)) |
| return failure(); |
| |
    // Compute the (possibly transposed and reshaped) operands.
| lhs = cast<mlir::TypedValue<mlir::RankedTensorType>>(processDotArg( |
| lhs, loc, lhsContractingDims, /*outerDimsFirst=*/true, rewriter)); |
| rhs = cast<mlir::TypedValue<mlir::RankedTensorType>>(processDotArg( |
| rhs, loc, rhsContractingDims, /*outerDimsFirst=*/false, rewriter)); |
| |
    // Bail out if either operand is not a shaped type.
| auto lhsShapeType = dyn_cast_or_null<ShapedType>(lhs.getType()); |
| auto rhsShapeType = dyn_cast_or_null<ShapedType>(rhs.getType()); |
| if (!lhsShapeType || !rhsShapeType) |
| return failure(); |
| |
    // Generate the new dot operator on the flattened rank-2 operands.
| ShapedType newTy = RankedTensorType::get( |
| {lhsShapeType.getShape()[0], rhsShapeType.getShape()[1]}, |
| resultTy.getElementType()); |
| Value newDotOp = mlir::stablehlo::DotOp::create(rewriter, loc, newTy, lhs, |
| rhs, precisionConfig); |
| if (static_cast<int64_t>(lhsContractingDims.size()) == |
| lhsTy.getRank() - 1 && |
| static_cast<int64_t>(rhsContractingDims.size()) == |
| rhsTy.getRank() - 1) { |
| rewriter.replaceOp(op, newDotOp); |
| return success(); |
| } |
| |
    // We can avoid all the computation below if the result shape is static.
| if (resultTy.hasStaticShape()) { |
| rewriter.replaceOpWithNewOp<mlir::stablehlo::ReshapeOp>(op, resultTy, |
| newDotOp); |
| return success(); |
| } |
| |
| llvm::SmallVector<int64_t> staticDims; |
| llvm::SmallVector<Value> dynDims; |
| |
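    // For each non-contracting dimension of `arg`, records both its static
    // size and its runtime size as a 1-element i32 tensor.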
| auto getDynamicDims = [&](Value arg, |
| llvm::ArrayRef<int64_t> contractingDims) { |
| RankedTensorType ty = llvm::cast<RankedTensorType>(arg.getType()); |
| int index = 0; |
| for (int64_t contractingDim : contractingDims) { |
| for (; index < contractingDim; ++index) { |
| staticDims.push_back(ty.getDimSize(index)); |
| Value dynDim = mlir::stablehlo::GetDimensionSizeOp::create( |
| rewriter, loc, arg, rewriter.getI64IntegerAttr(index)); |
| Value dynDimReshaped = mlir::stablehlo::ReshapeOp::create( |
| rewriter, loc, RankedTensorType::get({1}, rewriter.getI32Type()), |
| dynDim); |
| dynDims.push_back(dynDimReshaped); |
| } |
| index++; |
| } |
| |
| for (; index < ty.getRank(); ++index) { |
| staticDims.push_back(ty.getDimSize(index)); |
| Value dynDim = mlir::stablehlo::GetDimensionSizeOp::create( |
| rewriter, loc, arg, rewriter.getI64IntegerAttr(index)); |
| Value dynDimReshaped = mlir::stablehlo::ReshapeOp::create( |
| rewriter, loc, RankedTensorType::get({1}, rewriter.getI32Type()), |
| dynDim); |
| dynDims.push_back(dynDimReshaped); |
| } |
| }; |
| |
| getDynamicDims(op.getLhs(), lhsContractingDims); |
| getDynamicDims(op.getRhs(), rhsContractingDims); |
| |
| Value reshapeDimsTensor = mlir::stablehlo::ConcatenateOp::create( |
| rewriter, loc, |
| RankedTensorType::get({static_cast<int64_t>(dynDims.size())}, |
| rewriter.getI32Type()), |
| dynDims, rewriter.getI64IntegerAttr(0)); |
| |
| Value result = mlir::stablehlo::DynamicReshapeOp::create( |
| rewriter, loc, |
| RankedTensorType::get(staticDims, resultTy.getElementType()), newDotOp, |
| reshapeDimsTensor); |
| |
| rewriter.replaceOp(op, result); |
| return success(); |
| } |
| }; |
| |
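// Collapses unit dimensions on dot operands: a rank-2 LHS of shape <1xN> or
// a rank-2 RHS of shape <Nx1> is reshaped to a vector so that the dot reduces
// over fewer dimensions, and the result is reshaped back to the original type.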
| struct DotVectorOptimization final : OpRewritePattern<mlir::stablehlo::DotOp> { |
| using OpRewritePattern::OpRewritePattern; |
| LogicalResult matchAndRewrite(mlir::stablehlo::DotOp op, |
| PatternRewriter &rewriter) const override { |
| ImplicitLocOpBuilder b(op.getLoc(), rewriter); |
| Value lhs = op.getLhs(); |
| Value rhs = op.getRhs(); |
| |
| ShapedType lhsTy = cast<ShapedType>(lhs.getType()); |
| ShapedType rhsTy = cast<ShapedType>(rhs.getType()); |
| ShapedType resultTy = cast<ShapedType>(op.getType()); |
| |
| llvm::SmallVector<int64_t> dotShape; |
| if (lhsTy.getRank() == 2 && lhsTy.getDimSize(0) == 1) { |
| lhs = mlir::stablehlo::ReshapeOp::create( |
| b, lhsTy.clone({lhsTy.getDimSize(1)}), lhs); |
| } else if (lhsTy.getRank() == 2) { |
| dotShape.push_back(lhsTy.getDimSize(0)); |
| } |
| |
| if (rhsTy.getRank() == 2 && rhsTy.getDimSize(1) == 1) { |
| rhs = mlir::stablehlo::ReshapeOp::create( |
| b, rhsTy.clone({rhsTy.getDimSize(0)}), rhs); |
| } else if (rhsTy.getRank() == 2) { |
| dotShape.push_back(rhsTy.getDimSize(1)); |
| } |
| |
| if (lhs == op.getLhs() && rhs == op.getRhs()) { |
| return rewriter.notifyMatchFailure(op, "no vector reform available."); |
| } |
| |
| auto newDot = mlir::stablehlo::DotOp::create( |
| b, resultTy.clone(dotShape), lhs, rhs, op.getPrecisionConfigAttr()); |
| auto resultReshape = |
| mlir::stablehlo::ReshapeOp::create(b, resultTy, newDot); |
| |
| rewriter.replaceOp(op, resultReshape); |
| return success(); |
| } |
| }; |
| |
| struct DotGeneralToDot final : impl::DotGeneralToDotBase<DotGeneralToDot> { |
| |
| void runOnOperation() override { |
| RewritePatternSet patterns(&getContext()); |
| populatePreprocessingDotGeneralToDotPatterns(&getContext(), &patterns); |
| if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { |
| return signalPassFailure(); |
| } |
| } |
| }; |
| |
| } // namespace |
| |
| void populatePreprocessingDotGeneralToDotPatterns(mlir::MLIRContext *context, |
| RewritePatternSet *patterns, |
| PatternBenefit benefit) { |
| patterns |
| ->add<GeneralDotConvert, GeneralDotRemoveBatch, DotVectorOptimization>( |
| context, benefit); |
| } |
| |
| } // namespace mlir::iree_compiler::stablehlo |