| // Copyright 2020 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| // Implements logic for lowering CHLO ops to StableHLO and Shape dialect ops, |
| // taking care of CHLO's broadcasting semantics |
| |
| #include "compiler/plugins/input/StableHLO/Conversion/Passes.h" |
| #include "compiler/plugins/input/StableHLO/Conversion/Preprocessing/Rewriters.h" |
| #include "compiler/plugins/input/StableHLO/Conversion/Rewriters.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "mlir/Dialect/Complex/IR/Complex.h" |
| #include "mlir/Dialect/SCF/IR/SCF.h" |
| #include "mlir/Dialect/Shape/IR/Shape.h" |
| #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/IR/ImplicitLocOpBuilder.h" |
| #include "mlir/IR/TypeUtilities.h" |
| #include "mlir/Support/LogicalResult.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| #include "stablehlo/dialect/BroadcastUtils.h" |
| #include "stablehlo/dialect/ChloOps.h" |
| #include "stablehlo/dialect/StablehloOps.h" |
| |
| namespace mlir::iree_compiler::stablehlo { |
| |
| #define GEN_PASS_DEF_LEGALIZECHLO |
| #include "compiler/plugins/input/StableHLO/Conversion/Passes.h.inc" |
| |
| namespace { |
| |
| //===----------------------------------------------------------------------===// |
| // Helpers. |
| //===----------------------------------------------------------------------===// |
| |
| template <typename FromOpTy, typename ToOpTy> |
| struct HloNaryElementwiseAdaptor { |
| static ToOpTy createOp(FromOpTy fromOp, Type resultType, |
| ValueRange broadcastedOperands, OpBuilder &builder) { |
| return builder.create<ToOpTy>(fromOp.getLoc(), resultType, |
| broadcastedOperands); |
| } |
| }; |
| |
| static std::optional<mlir::stablehlo::ComparisonDirection> |
| toStableHloComparisonDirection(mlir::chlo::ComparisonDirection value) { |
| switch (value) { |
| case mlir::chlo::ComparisonDirection::EQ: |
| return mlir::stablehlo::ComparisonDirection::EQ; |
| case mlir::chlo::ComparisonDirection::NE: |
| return mlir::stablehlo::ComparisonDirection::NE; |
| case mlir::chlo::ComparisonDirection::GE: |
| return mlir::stablehlo::ComparisonDirection::GE; |
| case mlir::chlo::ComparisonDirection::GT: |
| return mlir::stablehlo::ComparisonDirection::GT; |
| case mlir::chlo::ComparisonDirection::LE: |
| return mlir::stablehlo::ComparisonDirection::LE; |
| case mlir::chlo::ComparisonDirection::LT: |
| return mlir::stablehlo::ComparisonDirection::LT; |
| } |
| return {}; |
| } |
| |
| static std::optional<mlir::stablehlo::ComparisonType> |
| toStableHloComparisonType(mlir::chlo::ComparisonType value) { |
| switch (value) { |
| case mlir::chlo::ComparisonType::NOTYPE: |
| return mlir::stablehlo::ComparisonType::NOTYPE; |
| case mlir::chlo::ComparisonType::FLOAT: |
| return mlir::stablehlo::ComparisonType::FLOAT; |
| case mlir::chlo::ComparisonType::TOTALORDER: |
| return mlir::stablehlo::ComparisonType::TOTALORDER; |
| case mlir::chlo::ComparisonType::SIGNED: |
| return mlir::stablehlo::ComparisonType::SIGNED; |
| case mlir::chlo::ComparisonType::UNSIGNED: |
| return mlir::stablehlo::ComparisonType::UNSIGNED; |
| } |
| return {}; |
| } |
| |
| struct HloCompareAdaptor { |
| static mlir::stablehlo::CompareOp |
| createOp(mlir::chlo::BroadcastCompareOp fromOp, Type resultType, |
| ValueRange broadcastedOperands, OpBuilder &builder) { |
| auto chloDirection = fromOp.getComparisonDirection(); |
| auto hloDirection = toStableHloComparisonDirection(chloDirection); |
| if (!hloDirection) |
| return nullptr; |
| auto chloType = |
| fromOp.getCompareType().value_or(mlir::chlo::ComparisonType::NOTYPE); |
| auto hloType = toStableHloComparisonType(chloType); |
| if (!hloType) |
| return nullptr; |
| auto hloTypeAttr = fromOp.getCompareType() |
| ? mlir::stablehlo::ComparisonTypeAttr::get( |
| builder.getContext(), *hloType) |
| : nullptr; |
| return builder.create<mlir::stablehlo::CompareOp>( |
| fromOp.getLoc(), resultType, broadcastedOperands[0], |
| broadcastedOperands[1], *hloDirection, hloTypeAttr); |
| } |
| }; |
| |
| // Populate a pattern for each Broadcasting Chlo op. This requires the pattern |
| // to take a ChloOpTy, NonBroadcastingOpTy, and an Adaptor as templated values. |
| template <template <typename, typename, typename> typename Pattern, |
| typename... ConstructorArgs> |
| static void populateForBroadcastingBinaryOp(MLIRContext *context, |
| RewritePatternSet *patterns, |
| ConstructorArgs &&...args) { |
| #define POPULATE_BCAST(ChloOp, HloOp) \ |
| patterns \ |
| ->add<Pattern<ChloOp, HloOp, HloNaryElementwiseAdaptor<ChloOp, HloOp>>>( \ |
| context, args...); |
| |
| POPULATE_BCAST(mlir::chlo::BroadcastAddOp, mlir::stablehlo::AddOp); |
| POPULATE_BCAST(mlir::chlo::BroadcastAndOp, mlir::stablehlo::AndOp); |
| POPULATE_BCAST(mlir::chlo::BroadcastAtan2Op, mlir::stablehlo::Atan2Op); |
| POPULATE_BCAST(mlir::chlo::BroadcastComplexOp, mlir::stablehlo::ComplexOp); |
| POPULATE_BCAST(mlir::chlo::BroadcastDivOp, mlir::stablehlo::DivOp); |
| POPULATE_BCAST(mlir::chlo::BroadcastMaxOp, mlir::stablehlo::MaxOp); |
| POPULATE_BCAST(mlir::chlo::BroadcastMinOp, mlir::stablehlo::MinOp); |
| POPULATE_BCAST(mlir::chlo::BroadcastMulOp, mlir::stablehlo::MulOp); |
| POPULATE_BCAST(mlir::chlo::BroadcastNextAfterOp, mlir::chlo::NextAfterOp); |
| POPULATE_BCAST(mlir::chlo::BroadcastOrOp, mlir::stablehlo::OrOp); |
| POPULATE_BCAST(mlir::chlo::BroadcastPolygammaOp, mlir::chlo::PolygammaOp); |
| POPULATE_BCAST(mlir::chlo::BroadcastPowOp, mlir::stablehlo::PowOp); |
| POPULATE_BCAST(mlir::chlo::BroadcastRemOp, mlir::stablehlo::RemOp); |
| POPULATE_BCAST(mlir::chlo::BroadcastShiftLeftOp, |
| mlir::stablehlo::ShiftLeftOp); |
| POPULATE_BCAST(mlir::chlo::BroadcastShiftRightArithmeticOp, |
| mlir::stablehlo::ShiftRightArithmeticOp); |
| POPULATE_BCAST(mlir::chlo::BroadcastShiftRightLogicalOp, |
| mlir::stablehlo::ShiftRightLogicalOp); |
| POPULATE_BCAST(mlir::chlo::BroadcastSubOp, mlir::stablehlo::SubtractOp); |
| POPULATE_BCAST(mlir::chlo::BroadcastXorOp, mlir::stablehlo::XorOp); |
| POPULATE_BCAST(mlir::chlo::BroadcastZetaOp, mlir::chlo::ZetaOp); |
| |
| #undef POPULATE_BCAST |
| |
| // Broadcasting ops requiring special construction. |
| patterns->add<Pattern<mlir::chlo::BroadcastCompareOp, |
| mlir::stablehlo::CompareOp, HloCompareAdaptor>>( |
| context, args...); |
| } |
| |
| template <typename T> |
| static Value getConstantLike(OpBuilder &b, Location loc, T constant, |
| Value val) { |
| Type ty = getElementTypeOrSelf(val.getType()); |
| auto getAttr = [&]() -> Attribute { |
| if (isa<IntegerType>(ty)) |
| return b.getIntegerAttr(ty, constant); |
| if (isa<FloatType>(ty)) |
| return b.getFloatAttr(ty, constant); |
| if (auto complexTy = dyn_cast<ComplexType>(ty)) { |
| return complex::NumberAttr::get(complexTy, constant, 0); |
| } |
| llvm_unreachable("unhandled element type"); |
| }; |
| return b.create<mlir::chlo::ConstantLikeOp>(loc, cast<TypedAttr>(getAttr()), |
| val); |
| } |
| |
| static Value getConstantLike(OpBuilder &b, Location loc, |
| const APFloat &constant, Value val) { |
| Type ty = getElementTypeOrSelf(val.getType()); |
| return b.create<mlir::chlo::ConstantLikeOp>(loc, b.getFloatAttr(ty, constant), |
| val); |
| } |
| |
| static Value getConstantLikeMaxFiniteValue(OpBuilder &b, Location loc, |
| Value val) { |
| auto ty = cast<FloatType>(getElementTypeOrSelf(val.getType())); |
| return getConstantLike( |
| b, loc, llvm::APFloat::getLargest(ty.getFloatSemantics()), val); |
| } |
| |
| static Value getConstantLikeInfValue(OpBuilder &b, Location loc, Value val, |
| bool negative) { |
| auto ty = cast<FloatType>(getElementTypeOrSelf(val.getType())); |
| return getConstantLike( |
| b, loc, llvm::APFloat::getInf(ty.getFloatSemantics(), negative), val); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Broadcasting Patterns. |
| //===----------------------------------------------------------------------===// |
| |
| // Converts binary ops that statically are determined to not broadcast directly |
| // to the corresponding stablehlo non-broadcasting op. |
| template <typename ChloOpTy, typename HloOpTy, typename Adaptor> |
| struct ConvertTrivialNonBroadcastBinaryOp final |
| : OpConversionPattern<ChloOpTy> { |
| using OpConversionPattern<ChloOpTy>::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(ChloOpTy op, typename ChloOpTy::Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| // Only rewrite for statically determinable non-broadcasting cases. |
| auto lhsType = dyn_cast<RankedTensorType>(adaptor.getLhs().getType()); |
| auto rhsType = dyn_cast<RankedTensorType>(adaptor.getRhs().getType()); |
| if (!lhsType || !rhsType) |
| return failure(); |
| |
| // Requires rank broadcast. |
| if (lhsType.getRank() != rhsType.getRank()) |
| return failure(); |
| |
| // Any dynamic dimension may require broadcasting and requires more |
| // analysis. |
| if (!lhsType.hasStaticShape() || !rhsType.hasStaticShape()) { |
| return failure(); |
| } |
| |
| if (!llvm::equal(lhsType.getShape(), rhsType.getShape())) { |
| return failure(); |
| } |
| |
| rewriter.replaceOp( |
| op, ValueRange{Adaptor::createOp(op, op.getResult().getType(), |
| adaptor.getOperands(), rewriter)}); |
| return success(); |
| } |
| }; |
| |
| // Converts a binary op with ranked broadcasting operands to explicitly |
| // broadcast and invoke the corresponding stablehlo non-broadcasting op. |
| // Note that dynamic broadcasting supported by this pattern is only valid for |
| // "numpy" broadcasting semantics as defined here: |
| // https://docs.scipy.org/doc/numpy/reference/ufuncs.html |
| // Specifically, this includes the following cases: |
| // - Same rank broadcast (operands have the same static rank). |
| // - Different-rank broadcast, either without a broadcast_dims attribute or |
| // with the broadcast_dims attribute set to map to a prefix padding. |
| // - Legal combinations of degenerate (1-dim) implicit broadcasting. |
| // The restriction on broadcast_dims derives from the definition of the |
| // `shape.broadcast` op, which only supports prefix-padding. |
| template <typename ChloOpTy, typename HloOpTy, typename Adaptor> |
| struct ConvertRankedDynamicBroadcastBinaryOp final |
| : OpConversionPattern<ChloOpTy> { |
| using OpConversionPattern<ChloOpTy>::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(ChloOpTy op, typename ChloOpTy::Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| // Only support ranked operands. |
| Value lhs = adaptor.getLhs(); |
| Value rhs = adaptor.getRhs(); |
| auto lhsType = dyn_cast<RankedTensorType>(lhs.getType()); |
| auto rhsType = dyn_cast<RankedTensorType>(rhs.getType()); |
| auto resultType = dyn_cast<RankedTensorType>(op.getResult().getType()); |
| if (!lhsType || !rhsType || !resultType) |
| return failure(); |
| |
| // Check for "numpy"-style rank broadcast. |
| auto broadcastDimensions = op.getBroadcastDimensions(); |
| if (broadcastDimensions && !mlir::hlo::isLegalNumpyRankedBroadcast( |
| lhs, rhs, *broadcastDimensions)) { |
| // Note: It is unclear whether the general specification of explicit |
| // broadcast_dimensions on binary ops is a feature we want to carry |
| // forward. While it can technically be implemented for ranked-dynamic, |
| // it is incompatible with unranked inputs. If this warning is emitted |
| // in real programs, it is an indication that the feature should be |
| // implemented versus just falling back on the more standard definition |
| // of numpy-like prefix-padding. |
| op.emitWarning() << "unsupported non prefix-padded dynamic rank " |
| << "broadcast_dimensions = " << *broadcastDimensions; |
| return failure(); |
| } |
| |
| // Compute result shape. |
| Location loc = op.getLoc(); |
| |
| // Insert a constraint on the shapes being broadcastable and insert all |
| // future code into an assuming block reliant on the constraint. |
| Value lhsShape = rewriter.create<shape::ShapeOfOp>(loc, lhs); |
| Value rhsShape = rewriter.create<shape::ShapeOfOp>(loc, rhs); |
| auto broadcastableCstr = |
| rewriter.create<shape::CstrBroadcastableOp>(loc, lhsShape, rhsShape); |
| auto assumingOp = rewriter.create<shape::AssumingOp>( |
| loc, ArrayRef<Type>{resultType}, broadcastableCstr.getResult()); |
| |
| OpBuilder::InsertionGuard guard(rewriter); |
| rewriter.createBlock(&assumingOp.getDoRegion()); |
| |
| int64_t resultRank = std::max(lhsType.getRank(), rhsType.getRank()); |
| Value resultExtents = |
| hlo::computeBinaryElementwiseBroadcastingResultExtents(loc, lhs, rhs, |
| rewriter); |
| |
| // Note that we unconditionally emit DynamicBroadcastInDim ops and let |
| // downstream canonicalizations fold them away if possible. This is |
| // because, in the dynamic case, there are many corner cases regarding |
| // when it is safe to omit, and some of them require analysis to prove |
| // properly. |
| auto lhsBroadcastDimensions = llvm::to_vector( |
| llvm::seq<int64_t>(resultRank - lhsType.getRank(), resultRank)); |
| Value broadcastedLhs = |
| rewriter.create<mlir::stablehlo::DynamicBroadcastInDimOp>( |
| loc, |
| RankedTensorType::get(resultType.getShape(), |
| lhsType.getElementType()), |
| lhs, resultExtents, |
| rewriter.getDenseI64ArrayAttr(lhsBroadcastDimensions)); |
| auto rhsBroadcastDimensions = llvm::to_vector( |
| llvm::seq<int64_t>(resultRank - rhsType.getRank(), resultRank)); |
| Value broadcastedRhs = |
| rewriter.create<mlir::stablehlo::DynamicBroadcastInDimOp>( |
| loc, |
| RankedTensorType::get(resultType.getShape(), |
| rhsType.getElementType()), |
| rhs, resultExtents, |
| rewriter.getDenseI64ArrayAttr(rhsBroadcastDimensions)); |
| |
| // And generate the final non-broadcasted binary op. |
| Value finalResult = Adaptor::createOp( |
| op, resultType, {broadcastedLhs, broadcastedRhs}, rewriter); |
| rewriter.create<shape::AssumingYieldOp>(loc, finalResult); |
| rewriter.replaceOp(op, {assumingOp.getResult(0)}); |
| return success(); |
| } |
| }; |
| |
| struct ConvertConstantLikeOp final |
| : OpConversionPattern<mlir::chlo::ConstantLikeOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(mlir::chlo::ConstantLikeOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto resultTy = cast<ShapedType>(op.getType()); |
| |
| // Unranked uses are not supported. |
| if (!resultTy.hasRank()) |
| return failure(); |
| |
| // Lower to HLO constant if statically shaped. |
| if (resultTy.hasStaticShape()) { |
| auto complexAttr = dyn_cast<mlir::complex::NumberAttr>(op.getValue()); |
| auto attr = DenseElementsAttr::get(resultTy, complexAttr ? complexAttr |
| : op.getValue()); |
| rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, attr); |
| return success(); |
| } |
| |
| // Lower to broadcasted constant. |
| Location loc = op.getLoc(); |
| Value constant = |
| rewriter.create<mlir::stablehlo::ConstantOp>(loc, op.getValue()); |
| Value shape = rewriter.create<shape::ShapeOfOp>(loc, adaptor.getOperand()); |
| rewriter.replaceOpWithNewOp<mlir::stablehlo::DynamicBroadcastInDimOp>( |
| op, resultTy, constant, shape, rewriter.getDenseI64ArrayAttr({})); |
| return success(); |
| } |
| }; |
| |
| struct ConvertSelectOp final |
| : OpConversionPattern<mlir::chlo::BroadcastSelectOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(mlir::chlo::BroadcastSelectOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| // Only support ranked operands. |
| Value pred = adaptor.getPred(); |
| Value onTrue = adaptor.getOnTrue(); |
| Value onFalse = adaptor.getOnFalse(); |
| auto predType = dyn_cast<RankedTensorType>(pred.getType()); |
| auto onTrueType = dyn_cast<RankedTensorType>(onTrue.getType()); |
| auto onFalseType = dyn_cast<RankedTensorType>(onFalse.getType()); |
| auto resultType = dyn_cast<RankedTensorType>(op.getResult().getType()); |
| if (!predType || !onTrueType || !onFalseType || !resultType) { |
| return failure(); |
| } |
| |
| Location loc = op.getLoc(); |
| Value predShape = rewriter.createOrFold<shape::ShapeOfOp>(loc, pred); |
| Value onTrueShape = rewriter.createOrFold<shape::ShapeOfOp>(loc, onTrue); |
| Value onFalseShape = rewriter.createOrFold<shape::ShapeOfOp>(loc, onFalse); |
| int64_t resultRank = std::max( |
| {predType.getRank(), onTrueType.getRank(), onFalseType.getRank()}); |
| |
| Value broadcastableCstr = rewriter.createOrFold<shape::CstrBroadcastableOp>( |
| loc, ValueRange{predShape, onTrueShape, onFalseShape}); |
| auto assumingOp = rewriter.create<shape::AssumingOp>( |
| loc, ArrayRef<Type>{resultType}, broadcastableCstr); |
| |
| OpBuilder::InsertionGuard guard(rewriter); |
| rewriter.createBlock(&assumingOp.getDoRegion()); |
| |
| Value resultExtents = rewriter.createOrFold<shape::BroadcastOp>( |
| loc, shape::getExtentTensorType(op.getContext()), |
| ValueRange{predShape, onTrueShape, onFalseShape}, |
| /*error=*/nullptr); |
| auto shapeType = |
| RankedTensorType::get({resultRank}, rewriter.getIndexType()); |
| resultExtents = |
| rewriter.createOrFold<tensor::CastOp>(loc, shapeType, resultExtents); |
| |
| Value broadcastedPred = pred; |
| // Pred has an implicit broadcast for scalars, so use that when convenient. |
| if (predType.getRank() > 0) { |
| auto predBroadcastDimensions = llvm::to_vector( |
| llvm::seq<int64_t>(resultRank - predType.getRank(), resultRank)); |
| broadcastedPred = |
| rewriter.create<mlir::stablehlo::DynamicBroadcastInDimOp>( |
| loc, |
| RankedTensorType::get(resultType.getShape(), |
| predType.getElementType()), |
| pred, resultExtents, |
| rewriter.getDenseI64ArrayAttr(predBroadcastDimensions)); |
| } |
| auto onTrueBroadcastDimensions = llvm::to_vector( |
| llvm::seq<int64_t>(resultRank - onTrueType.getRank(), resultRank)); |
| Value broadcastedOnTrue = |
| rewriter.create<mlir::stablehlo::DynamicBroadcastInDimOp>( |
| loc, |
| RankedTensorType::get(resultType.getShape(), |
| onTrueType.getElementType()), |
| onTrue, resultExtents, |
| rewriter.getDenseI64ArrayAttr(onTrueBroadcastDimensions)); |
| auto onFalseBroadcastDimensions = llvm::to_vector( |
| llvm::seq<int64_t>(resultRank - onFalseType.getRank(), resultRank)); |
| Value broadcastedOnFalse = |
| rewriter.create<mlir::stablehlo::DynamicBroadcastInDimOp>( |
| loc, |
| RankedTensorType::get(resultType.getShape(), |
| onFalseType.getElementType()), |
| onFalse, resultExtents, |
| rewriter.getDenseI64ArrayAttr(onFalseBroadcastDimensions)); |
| |
| // And generate the final non-broadcasted ternary op. |
| Value finalResult = rewriter.create<mlir::stablehlo::SelectOp>( |
| loc, resultType, broadcastedPred, broadcastedOnTrue, |
| broadcastedOnFalse); |
| rewriter.create<shape::AssumingYieldOp>(loc, finalResult); |
| rewriter.replaceOp(op, {assumingOp.getResult(0)}); |
| return success(); |
| } |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // Decomposition Patterns. |
| //===----------------------------------------------------------------------===// |
| |
| struct ConvertConstantOp final : OpConversionPattern<mlir::chlo::ConstantOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(mlir::chlo::ConstantOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, op.getValue()); |
| return success(); |
| } |
| }; |
| |
| template <typename FTy> |
| static Value |
| materializeChebyshevPolynomialApproximation(ConversionPatternRewriter &rewriter, |
| Location loc, Value x, |
| ArrayRef<FTy> coefficients) { |
| Value b0 = getConstantLike(rewriter, loc, 0.0, x); |
| Value b1 = getConstantLike(rewriter, loc, 0.0, x); |
| Value b2 = getConstantLike(rewriter, loc, 0.0, x); |
| for (FTy c : coefficients) { |
| b2 = b1; |
| b1 = b0; |
| b0 = rewriter.create<mlir::stablehlo::MulOp>(loc, x.getType(), x, b1); |
| b0 = rewriter.create<mlir::stablehlo::SubtractOp>(loc, x.getType(), b0, b2); |
| b0 = rewriter.create<mlir::stablehlo::AddOp>( |
| loc, x.getType(), b0, getConstantLike(rewriter, loc, c, x)); |
| } |
| Value result = |
| rewriter.create<mlir::stablehlo::SubtractOp>(loc, x.getType(), b0, b2); |
| result = rewriter.create<mlir::stablehlo::MulOp>( |
| loc, x.getType(), result, getConstantLike(rewriter, loc, 0.5, x)); |
| return result; |
| } |
| |
| template <typename FTy> |
| static Value materializeBesselI1eApproximation( |
| ConversionPatternRewriter &rewriter, Location loc, Value x, |
| ArrayRef<FTy> kI1eCoeffsA, ArrayRef<FTy> kI1eCoeffsB) { |
| Value z = rewriter.create<mlir::stablehlo::AbsOp>(loc, x); |
| Value half = getConstantLike(rewriter, loc, 0.5, x); |
| Value two = getConstantLike(rewriter, loc, 2.0, x); |
| Value thirtyTwo = getConstantLike(rewriter, loc, 32.0, x); |
| Value eight = getConstantLike(rewriter, loc, 8.0, x); |
| |
| Value tmp = rewriter.create<mlir::stablehlo::MulOp>(loc, half, z); |
| tmp = rewriter.create<mlir::stablehlo::SubtractOp>(loc, tmp, two); |
| |
| Value xLe8 = materializeChebyshevPolynomialApproximation(rewriter, loc, tmp, |
| kI1eCoeffsA); |
| xLe8 = rewriter.create<mlir::stablehlo::MulOp>(loc, z, xLe8); |
| |
| tmp = rewriter.create<mlir::stablehlo::DivOp>(loc, thirtyTwo, z); |
| tmp = rewriter.create<mlir::stablehlo::SubtractOp>(loc, tmp, two); |
| Value xGt8 = materializeChebyshevPolynomialApproximation(rewriter, loc, tmp, |
| kI1eCoeffsB); |
| xGt8 = rewriter.create<mlir::stablehlo::DivOp>( |
| loc, xGt8, rewriter.create<mlir::stablehlo::SqrtOp>(loc, z)); |
| |
| Value isLe8 = rewriter.create<mlir::stablehlo::CompareOp>( |
| loc, z, eight, mlir::stablehlo::ComparisonDirection::LE); |
| |
| Value select = |
| rewriter.create<mlir::stablehlo::SelectOp>(loc, isLe8, xLe8, xGt8); |
| return rewriter.create<mlir::stablehlo::MulOp>( |
| loc, rewriter.create<mlir::stablehlo::SignOp>(loc, x), select); |
| } |
| |
| Value materializeBesselI1eApproximationF32(ConversionPatternRewriter &rewriter, |
| Location loc, ValueRange args) { |
| Value x = args.front(); |
| assert(cast<ShapedType>(x.getType()).getElementType().isF32() && |
| "expect f32 element type"); |
| const float kI1eCoeffsA[] = { |
| 9.38153738649577178388E-9f, -4.44505912879632808065E-8f, |
| 2.00329475355213526229E-7f, -8.56872026469545474066E-7f, |
| 3.47025130813767847674E-6f, -1.32731636560394358279E-5f, |
| 4.78156510755005422638E-5f, -1.61760815825896745588E-4f, |
| 5.12285956168575772895E-4f, -1.51357245063125314899E-3f, |
| 4.15642294431288815669E-3f, -1.05640848946261981558E-2f, |
| 2.47264490306265168283E-2f, -5.29459812080949914269E-2f, |
| 1.02643658689847095384E-1f, -1.76416518357834055153E-1f, |
| 2.52587186443633654823E-1f}; |
| |
| const float kI1eCoeffsB[] = { |
| -3.83538038596423702205E-9f, -2.63146884688951950684E-8f, |
| -2.51223623787020892529E-7f, -3.88256480887769039346E-6f, |
| -1.10588938762623716291E-4f, -9.76109749136146840777E-3f, |
| 7.78576235018280120474E-1f}; |
| |
| return materializeBesselI1eApproximation<float>(rewriter, loc, x, kI1eCoeffsA, |
| kI1eCoeffsB); |
| } |
| |
| static Value |
| materializeBesselI1eApproximationF64(ConversionPatternRewriter &rewriter, |
| Location loc, ValueRange args) { |
| Value x = args.front(); |
| assert(cast<ShapedType>(x.getType()).getElementType().isF64() && |
| "expect f64 element type"); |
| |
| const double kI1eCoeffsA[] = { |
| 2.77791411276104639959E-18, -2.11142121435816608115E-17, |
| 1.55363195773620046921E-16, -1.10559694773538630805E-15, |
| 7.60068429473540693410E-15, -5.04218550472791168711E-14, |
| 3.22379336594557470981E-13, -1.98397439776494371520E-12, |
| 1.17361862988909016308E-11, -6.66348972350202774223E-11, |
| 3.62559028155211703701E-10, -1.88724975172282928790E-9, |
| 9.38153738649577178388E-9, -4.44505912879632808065E-8, |
| 2.00329475355213526229E-7, -8.56872026469545474066E-7, |
| 3.47025130813767847674E-6, -1.32731636560394358279E-5, |
| 4.78156510755005422638E-5, -1.61760815825896745588E-4, |
| 5.12285956168575772895E-4, -1.51357245063125314899E-3, |
| 4.15642294431288815669E-3, -1.05640848946261981558E-2, |
| 2.47264490306265168283E-2, -5.29459812080949914269E-2, |
| 1.02643658689847095384E-1, -1.76416518357834055153E-1, |
| 2.52587186443633654823E-1}; |
| |
| const double kI1eCoeffsB[] = { |
| 7.51729631084210481353E-18, 4.41434832307170791151E-18, |
| -4.65030536848935832153E-17, -3.20952592199342395980E-17, |
| 2.96262899764595013876E-16, 3.30820231092092828324E-16, |
| -1.88035477551078244854E-15, -3.81440307243700780478E-15, |
| 1.04202769841288027642E-14, 4.27244001671195135429E-14, |
| -2.10154184277266431302E-14, -4.08355111109219731823E-13, |
| -7.19855177624590851209E-13, 2.03562854414708950722E-12, |
| 1.41258074366137813316E-11, 3.25260358301548823856E-11, |
| -1.89749581235054123450E-11, -5.58974346219658380687E-10, |
| -3.83538038596423702205E-9, -2.63146884688951950684E-8, |
| -2.51223623787020892529E-7, -3.88256480887769039346E-6, |
| -1.10588938762623716291E-4, -9.76109749136146840777E-3, |
| 7.78576235018280120474E-1}; |
| |
| return materializeBesselI1eApproximation<double>(rewriter, loc, x, |
| kI1eCoeffsA, kI1eCoeffsB); |
| } |
| |
| static Value materializeWithUpcast(ConversionPatternRewriter &rewriter, |
| Location loc, ValueRange args, |
| FloatType minPrecisionTy, |
| Value callback(ConversionPatternRewriter &, |
| Location, ValueRange)) { |
| Type originalTy = getElementTypeOrSelf(args.front().getType()); |
| auto floatOriginalTy = dyn_cast<FloatType>(originalTy); |
| bool needsUpcast = |
| floatOriginalTy && floatOriginalTy.getWidth() < minPrecisionTy.getWidth(); |
| |
| // Upcast arguments if necessary. |
| llvm::SmallVector<Value, 2> castedArgs; |
| if (needsUpcast) { |
| for (Value a : args) { |
| castedArgs.push_back( |
| rewriter.create<mlir::stablehlo::ConvertOp>(loc, a, minPrecisionTy)); |
| } |
| args = castedArgs; |
| } |
| |
| Value result = callback(rewriter, loc, args); |
| |
| // Cast back if necessary. |
| if (needsUpcast) { |
| result = |
| rewriter.create<mlir::stablehlo::ConvertOp>(loc, result, originalTy); |
| } |
| |
| return result; |
| } |
| |
| struct ConvertBesselI1eOp final : OpConversionPattern<mlir::chlo::BesselI1eOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(mlir::chlo::BesselI1eOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| Location loc = op.getLoc(); |
| Value x = adaptor.getOperand(); |
| Type ty = cast<ShapedType>(x.getType()).getElementType(); |
| |
| // For now, we support only f64, f32, f16 and bf16. |
| // See https://www.tensorflow.org/api_docs/python/tf/math/bessel_i1e |
| if (!ty.isF64() && !ty.isF32() && !ty.isF16() && !ty.isBF16()) { |
| return failure(); |
| } |
| |
| if (ty.isF64()) { |
| rewriter.replaceOp( |
| op, materializeBesselI1eApproximationF64(rewriter, loc, x)); |
| return success(); |
| } |
| |
| rewriter.replaceOp( |
| op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(), |
| rewriter.getF32Type(), |
| &materializeBesselI1eApproximationF32)); |
| return success(); |
| } |
| }; |
| |
| template <typename FTy> |
| static Value |
| materializePolynomialApproximation(ConversionPatternRewriter &rewriter, |
| Location loc, Value x, |
| ArrayRef<FTy> coefficients) { |
| if (coefficients.empty()) |
| return getConstantLike(rewriter, loc, 0.0, x); |
| |
| Value poly = getConstantLike(rewriter, loc, coefficients[0], x); |
| for (size_t i = 1, e = coefficients.size(); i < e; ++i) { |
| poly = rewriter.create<mlir::stablehlo::MulOp>(loc, x.getType(), poly, x); |
| poly = rewriter.create<mlir::stablehlo::AddOp>( |
| loc, x.getType(), poly, |
| getConstantLike(rewriter, loc, coefficients[i], x)); |
| } |
| return poly; |
| } |
| |
| // Precondition is |x| >= 1. Use erf approximation, otherwise. |
| // |
| // We rely on multiple polynomial approximations for x >= 1. We pass |x| as an |
| // argument and derive the final approximation for all |x| >= 1. |
| // This implementation is based on Cephes. |
| static Value materializeErfcApproximationF64ForMagnituteGeOne( |
| ConversionPatternRewriter &rewriter, Location loc, ValueRange args) { |
| Value x = args.front(); |
| assert(cast<ShapedType>(x.getType()).getElementType().isF64() && |
| "expect f64 element type"); |
| const double kMaxlog = 7.09782712893383996843E2; |
| const double kErfcPCoefficients[] = { |
| 2.46196981473530512524E-10, 5.64189564831068821977E-1, |
| 7.46321056442269912687E0, 4.86371970985681366614E1, |
| 1.96520832956077098242E2, 5.26445194995477358631E2, |
| 9.34528527171957607540E2, 1.02755188689515710272E3, |
| 5.57535335369399327526E2}; |
| const double kErfcQCoefficients[] = { |
| 1.00000000000000000000E0, 1.32281951154744992508E1, |
| 8.67072140885989742329E1, 3.54937778887819891062E2, |
| 9.75708501743205489753E2, 1.82390916687909736289E3, |
| 2.24633760818710981792E3, 1.65666309194161350182E3, |
| 5.57535340817727675546E2}; |
| const double kErfcRCoefficients[] = { |
| 5.64189583547755073984E-1, 1.27536670759978104416E0, |
| 5.01905042251180477414E0, 6.16021097993053585195E0, |
| 7.40974269950448939160E0, 2.97886665372100240670E0}; |
| const double kErfcSCoefficients[] = { |
| 1.00000000000000000000E0, 2.26052863220117276590E0, |
| 9.39603524938001434673E0, 1.20489539808096656605E1, |
| 1.70814450747565897222E1, 9.60896809063285878198E0, |
| 3.36907645100081516050E0}; |
| |
| // Let z = -x^2. |
| Value xSq = rewriter.create<mlir::stablehlo::MulOp>(loc, x, x); |
| Value z = rewriter.create<mlir::stablehlo::NegOp>(loc, xSq); |
| |
| // Materialize polynomial approximation for x in [1, 8) as |
| // erfc(x) = exp(z) P(|x|) / Q(|x|). |
| Value expZ = rewriter.create<mlir::stablehlo::ExpOp>(loc, z); |
| Value absX = rewriter.create<mlir::stablehlo::AbsOp>(loc, x); |
| Value polP = materializePolynomialApproximation( |
| rewriter, loc, absX, llvm::ArrayRef(kErfcPCoefficients)); |
| Value expZMulPolyP = rewriter.create<mlir::stablehlo::MulOp>(loc, expZ, polP); |
| Value polQ = materializePolynomialApproximation( |
| rewriter, loc, absX, llvm::ArrayRef(kErfcQCoefficients)); |
| Value erfcApprox18 = |
| rewriter.create<mlir::stablehlo::DivOp>(loc, expZMulPolyP, polQ); |
| |
| // Materialize polynomial approximation for x in >= 8 as |
| // erfc(x) exp(z) R(|x|) / S(|x|). |
| Value polR = materializePolynomialApproximation( |
| rewriter, loc, absX, llvm::ArrayRef(kErfcRCoefficients)); |
| Value expZMulPolyR = rewriter.create<mlir::stablehlo::MulOp>(loc, expZ, polR); |
| Value polS = materializePolynomialApproximation( |
| rewriter, loc, absX, llvm::ArrayRef(kErfcSCoefficients)); |
| Value erfcApprox8Inf = |
| rewriter.create<mlir::stablehlo::DivOp>(loc, expZMulPolyR, polS); |
| |
| // Combine polynomial approximations for x >= 1. |
| Value eight = getConstantLike(rewriter, loc, 8.0, x); |
| Value absXLt8 = rewriter.create<mlir::stablehlo::CompareOp>( |
| loc, absX, eight, mlir::stablehlo::ComparisonDirection::LT); |
| Value erfcApprox = rewriter.create<mlir::stablehlo::SelectOp>( |
| loc, absXLt8, erfcApprox18, erfcApprox8Inf); |
| |
| // Clamp to prevent overflow and materialize approximation for large x as |
| // erfc(x) = 0. |
| Value zLtNegMaxlog = rewriter.create<mlir::stablehlo::CompareOp>( |
| loc, z, getConstantLike(rewriter, loc, -kMaxlog, x), |
| mlir::stablehlo::ComparisonDirection::LT); |
| Value zero = getConstantLike(rewriter, loc, 0.0, x); |
| Value erfcApproxClamped = rewriter.create<mlir::stablehlo::SelectOp>( |
| loc, zLtNegMaxlog, zero, erfcApprox); |
| |
| // Derive approximation for x <= -1 as |
| // erfc(x) = 2 - erfc(-x). |
| // Reuse previously materialized approximations all of which take |x| as their |
| // argument. |
| Value xLtZero = rewriter.create<mlir::stablehlo::CompareOp>( |
| loc, x, zero, mlir::stablehlo::ComparisonDirection::LT); |
| Value two = getConstantLike(rewriter, loc, 2.0, x); |
| Value twoSubErfcApproxClamped = |
| rewriter.create<mlir::stablehlo::SubtractOp>(loc, two, erfcApproxClamped); |
| return rewriter.create<mlir::stablehlo::SelectOp>( |
| loc, xLtZero, twoSubErfcApproxClamped, erfcApproxClamped); |
| } |
| |
| // Precondition is |x| <= 1. Use erfc approximation, otherwise. |
| // This implementation is based on Cephes. |
| static Value materializeErfApproximationF64ForMagnituteLeOne( |
| ConversionPatternRewriter &rewriter, Location loc, ValueRange args) { |
| Value x = args.front(); |
| assert(cast<ShapedType>(x.getType()).getElementType().isF64() && |
| "expect f64 element type"); |
| const double kErfTCoefficients[] = { |
| 9.60497373987051638749E0, 9.00260197203842689217E1, |
| 2.23200534594684319226E3, 7.00332514112805075473E3, |
| 5.55923013010394962768E4}; |
| const double kErfUCoefficients[] = { |
| 1.00000000000000000000E0, 3.35617141647503099647E1, |
| 5.21357949780152679795E2, 4.59432382970980127987E3, |
| 2.26290000613890934246E4, 4.92673942608635921086E4}; |
| |
| // Materialize polynomial approximation for |x| <= 1 as |
| // erf(x) = x T(x^2) / U(x^2). |
| Value xSq = rewriter.create<mlir::stablehlo::MulOp>(loc, x, x); |
| Value polyT = materializePolynomialApproximation( |
| rewriter, loc, xSq, llvm::ArrayRef(kErfTCoefficients)); |
| Value xMulPolyT = rewriter.create<mlir::stablehlo::MulOp>(loc, x, polyT); |
| Value polyU = materializePolynomialApproximation( |
| rewriter, loc, xSq, llvm::ArrayRef(kErfUCoefficients)); |
| return rewriter.create<mlir::stablehlo::DivOp>(loc, xMulPolyT, polyU); |
| } |
| |
| // This implementation is based on Cephes. |
| static Value materializeErfApproximationF64(ConversionPatternRewriter &rewriter, |
| Location loc, ValueRange args) { |
| Value x = args.front(); |
| assert(cast<ShapedType>(x.getType()).getElementType().isF64() && |
| "expect f64 element type"); |
| |
| // Rely on erf approximation for |x| < 1 |
| // erf(x) = erf_approx(x) |
| Value erfApprox = |
| materializeErfApproximationF64ForMagnituteLeOne(rewriter, loc, x); |
| |
| // Rely on erfc approximation for |x| >= 1 and materialize erf as |
| // erf(x) = 1 - erfc_approx(x) |
| Value one = getConstantLike(rewriter, loc, 1.0, x); |
| Value erfcApprox = |
| materializeErfcApproximationF64ForMagnituteGeOne(rewriter, loc, x); |
| Value erfcBasedApprox = |
| rewriter.create<mlir::stablehlo::SubtractOp>(loc, one, erfcApprox); |
| |
| // Materialize approximation selection based on argument. |
| Value absX = rewriter.create<mlir::stablehlo::AbsOp>(loc, x); |
| Value absXLtOne = rewriter.create<mlir::stablehlo::CompareOp>( |
| loc, absX, one, mlir::stablehlo::ComparisonDirection::LT); |
| return rewriter.create<mlir::stablehlo::SelectOp>(loc, absXLtOne, erfApprox, |
| erfcBasedApprox); |
| } |
| |
| static Value |
| materializeErfcApproximationF64(ConversionPatternRewriter &rewriter, |
| Location loc, ValueRange args) { |
| Value x = args.front(); |
| assert(cast<ShapedType>(x.getType()).getElementType().isF64() && |
| "expect f64 element type"); |
| |
| // Rely on erfc approximation for |x| >= 1 |
| // erfc(x) = erfc_approx(x) |
| Value erfcApprox = |
| materializeErfcApproximationF64ForMagnituteGeOne(rewriter, loc, x); |
| |
| // Rely on erf approximation for |x| < 1 and materialize erfc as |
| // erfc(x) = 1 - erf_approx(x) |
| Value one = getConstantLike(rewriter, loc, 1.0, x); |
| Value erfApprox = |
| materializeErfApproximationF64ForMagnituteLeOne(rewriter, loc, x); |
| Value erfBasedApprox = |
| rewriter.create<mlir::stablehlo::SubtractOp>(loc, one, erfApprox); |
| |
| // Materialize approximation selection based on argument. |
| Value absX = rewriter.create<mlir::stablehlo::AbsOp>(loc, x); |
| Value absXLtOne = rewriter.create<mlir::stablehlo::CompareOp>( |
| loc, absX, one, mlir::stablehlo::ComparisonDirection::LT); |
| return rewriter.create<mlir::stablehlo::SelectOp>(loc, absXLtOne, |
| erfBasedApprox, erfcApprox); |
| } |
| |
| // Precondition is |x| >= 1. Use erf approximation, otherwise. |
| // |
| // We rely on multiple polynomial approximations for x >= 1. We pass |x| as an |
| // argument and derive the final approximation for all |x| >= 1. |
| // This implementation is based on Cephes. |
| static Value materializeErfcApproximationF32ForMagnitudeGeOne( |
| ConversionPatternRewriter &rewriter, Location loc, ValueRange args) { |
| Value x = args.front(); |
| assert(cast<ShapedType>(x.getType()).getElementType().isF32() && |
| "expect f32 element type"); |
| const double kMaxlog = 88.72283905206835; |
| const float kErfcPCoefficients[] = { |
| +2.326819970068386E-2f, -1.387039388740657E-1f, +3.687424674597105E-1f, |
| -5.824733027278666E-1f, +6.210004621745983E-1f, -4.944515323274145E-1f, |
| +3.404879937665872E-1f, -2.741127028184656E-1f, +5.638259427386472E-1f, |
| }; |
| const float kErfcRCoefficients[] = { |
| -1.047766399936249E+1f, +1.297719955372516E+1f, -7.495518717768503E+0f, |
| +2.921019019210786E+0f, -1.015265279202700E+0f, +4.218463358204948E-1f, |
| -2.820767439740514E-1f, +5.641895067754075E-1f, |
| }; |
| |
| // Let z = -x^2. |
| Value xSq = rewriter.create<mlir::stablehlo::MulOp>(loc, x, x); |
| Value z = rewriter.create<mlir::stablehlo::NegOp>(loc, xSq); |
| |
| // Materialize polynomial approximation for x >= 1 as |
| // erfc(x) = exp(z) 1/x P(1/x^2) if x in [1, 2) |
| // erfc(x) = exp(z) 1/x R(1/x^2) if x >= 2 |
| Value absX = rewriter.create<mlir::stablehlo::AbsOp>(loc, x); |
| Value one = getConstantLike(rewriter, loc, 1.0, x); |
| Value reciprocalXSq = rewriter.create<mlir::stablehlo::DivOp>(loc, one, xSq); |
| Value expZ = rewriter.create<mlir::stablehlo::ExpOp>(loc, z); |
| Value oneDivAbsX = rewriter.create<mlir::stablehlo::DivOp>(loc, one, absX); |
| Value expZMulOneDivAbsX = |
| rewriter.create<mlir::stablehlo::MulOp>(loc, expZ, oneDivAbsX); |
| Value two = getConstantLike(rewriter, loc, 2.0, x); |
| Value absXLtTwo = rewriter.create<mlir::stablehlo::CompareOp>( |
| loc, absX, two, mlir::stablehlo::ComparisonDirection::LT); |
| Value polP = materializePolynomialApproximation( |
| rewriter, loc, reciprocalXSq, llvm::ArrayRef(kErfcPCoefficients)); |
| Value polR = materializePolynomialApproximation( |
| rewriter, loc, reciprocalXSq, llvm::ArrayRef(kErfcRCoefficients)); |
| Value poly = |
| rewriter.create<mlir::stablehlo::SelectOp>(loc, absXLtTwo, polP, polR); |
| Value erfcApprox = |
| rewriter.create<mlir::stablehlo::MulOp>(loc, expZMulOneDivAbsX, poly); |
| |
| // Clamp to prevent overflow and materialize approximation for large x as |
| // erfc(x) = 0. |
| Value zLtNeqMaxlog = rewriter.create<mlir::stablehlo::CompareOp>( |
| loc, z, getConstantLike(rewriter, loc, -kMaxlog, x), |
| mlir::stablehlo::ComparisonDirection::LT); |
| Value zero = getConstantLike(rewriter, loc, 0.0, x); |
| Value erfcApproxClamped = rewriter.create<mlir::stablehlo::SelectOp>( |
| loc, zLtNeqMaxlog, zero, erfcApprox); |
| |
| // Derive approximation for x <= -1 as |
| // erfc(x) = 2 - erfc(-x). |
| // Reuse previously materialized approximations all of which take |x| as their |
| // argument. |
| Value xLtZero = rewriter.create<mlir::stablehlo::CompareOp>( |
| loc, x, zero, mlir::stablehlo::ComparisonDirection::LT); |
| Value twoSubErfcApprox = |
| rewriter.create<mlir::stablehlo::SubtractOp>(loc, two, erfcApproxClamped); |
| return rewriter.create<mlir::stablehlo::SelectOp>( |
| loc, xLtZero, twoSubErfcApprox, erfcApproxClamped); |
| } |
| |
| // Precondition is |x| <= 1. Use erfc approximation, otherwise. |
| // This implementation is based on Cephes. |
| static Value materializeErfApproximationF32ForMagnitudeLeOne( |
| ConversionPatternRewriter &rewriter, Location loc, ValueRange args) { |
| Value x = args.front(); |
| assert(cast<ShapedType>(x.getType()).getElementType().isF32() && |
| "expect f32 element type"); |
| const float kErfTCoefficients[] = { |
| +7.853861353153693E-5f, -8.010193625184903E-4f, +5.188327685732524E-3f, |
| -2.685381193529856E-2f, +1.128358514861418E-1f, -3.761262582423300E-1f, |
| +1.128379165726710E+0f, |
| }; |
| |
| // Materialize polynomial approximation for |x| <= 1 as |
| // erf(x) = x T(x^2). |
| Value xSq = rewriter.create<mlir::stablehlo::MulOp>(loc, x, x); |
| Value polyT = materializePolynomialApproximation( |
| rewriter, loc, xSq, llvm::ArrayRef(kErfTCoefficients)); |
| return rewriter.create<mlir::stablehlo::MulOp>(loc, x, polyT); |
| } |
| |
| // This is the same approximation as used in Eigen. |
| static Value materializeErfApproximationF32(ConversionPatternRewriter &rewriter, |
| Location loc, ValueRange args) { |
| Value x = args.front(); |
| assert(cast<ShapedType>(x.getType()).getElementType().isF32() && |
| "expect f32 element type"); |
| const float kAlpha[] = { |
| -2.72614225801306e-10f, 2.77068142495902e-08f, -2.10102402082508e-06f, |
| -5.69250639462346e-05f, -7.34990630326855e-04f, -2.95459980854025e-03f, |
| -1.60960333262415e-02f, |
| }; |
| const float kBeta[] = { |
| -1.45660718464996e-05f, -2.13374055278905e-04f, -1.68282697438203e-03f, |
| -7.37332916720468e-03f, -1.42647390514189e-02f, |
| }; |
| |
| // Clamp argument between -4 and 4. |
| Value lb = getConstantLike(rewriter, loc, -4.0, x); |
| Value ub = getConstantLike(rewriter, loc, 4.0, x); |
| x = rewriter.create<mlir::stablehlo::ClampOp>(loc, x.getType(), lb, x, ub); |
| Value xSq = rewriter.create<mlir::stablehlo::MulOp>(loc, x, x); |
| |
| // Materialize polynomial approximation for x in [-4, 4] as |
| // erf(x) = x * Alpha(x^2) / Beta(x^2). |
| Value alphaPoly = materializePolynomialApproximation(rewriter, loc, xSq, |
| llvm::ArrayRef(kAlpha)); |
| Value betaPoly = materializePolynomialApproximation(rewriter, loc, xSq, |
| llvm::ArrayRef(kBeta)); |
| Value xMulAlphaPoly = |
| rewriter.create<mlir::stablehlo::MulOp>(loc, x, alphaPoly); |
| Value erf = |
| rewriter.create<mlir::stablehlo::DivOp>(loc, xMulAlphaPoly, betaPoly); |
| Value lbErf = getConstantLike(rewriter, loc, -1.0, x); |
| Value ubErf = getConstantLike(rewriter, loc, 1.0, x); |
| return rewriter.create<mlir::stablehlo::ClampOp>(loc, erf.getType(), lbErf, |
| erf, ubErf); |
| } |
| |
| static Value |
| materializeErfcApproximationF32(ConversionPatternRewriter &rewriter, |
| Location loc, ValueRange args) { |
| Value x = args.front(); |
| assert(cast<ShapedType>(x.getType()).getElementType().isF32() && |
| "expect f32 element type"); |
| |
| // Rely on erfc approximation for |x| >= 1 |
| // erfc(x) = erfc_approx(x) |
| Value erfcApprox = |
| materializeErfcApproximationF32ForMagnitudeGeOne(rewriter, loc, x); |
| |
| // Rely on erf approximation for |x| < 1 and materialize erfc as |
| // erfc(x) = 1 - erf_approx(x) |
| Value one = getConstantLike(rewriter, loc, 1.0, x); |
| Value erfApprox = |
| materializeErfApproximationF32ForMagnitudeLeOne(rewriter, loc, x); |
| Value erfBasedApprox = |
| rewriter.create<mlir::stablehlo::SubtractOp>(loc, one, erfApprox); |
| |
| // Materialize approximation selection based on argument. |
| Value absX = rewriter.create<mlir::stablehlo::AbsOp>(loc, x); |
| Value absXLtOne = rewriter.create<mlir::stablehlo::CompareOp>( |
| loc, absX, one, mlir::stablehlo::ComparisonDirection::LT); |
| return rewriter.create<mlir::stablehlo::SelectOp>(loc, absXLtOne, |
| erfBasedApprox, erfcApprox); |
| } |
| |
| struct ConvertErfOp final : OpConversionPattern<mlir::chlo::ErfOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(mlir::chlo::ErfOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| Location loc = op.getLoc(); |
| Value x = adaptor.getOperand(); |
| Type ty = cast<ShapedType>(x.getType()).getElementType(); |
| |
| // For now, we support only f64, f32, f16 and bf16. |
| if (!ty.isF64() && !ty.isF32() && !ty.isF16() && !ty.isBF16()) { |
| return failure(); |
| } |
| |
| if (ty.isF64()) { |
| rewriter.replaceOp(op, materializeErfApproximationF64(rewriter, loc, x)); |
| return success(); |
| } |
| |
| rewriter.replaceOp( |
| op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(), |
| rewriter.getF32Type(), |
| &materializeErfApproximationF32)); |
| return success(); |
| } |
| }; |
| |
| struct ConvertErfcOp final : OpConversionPattern<mlir::chlo::ErfcOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(mlir::chlo::ErfcOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| Location loc = op.getLoc(); |
| Value x = adaptor.getOperand(); |
| Type ty = cast<ShapedType>(x.getType()).getElementType(); |
| |
| // For now, we support only f64, f32, f16 and bf16. |
| if (!ty.isF64() && !ty.isF32() && !ty.isF16() && !ty.isBF16()) { |
| return failure(); |
| } |
| |
| if (ty.isF64()) { |
| rewriter.replaceOp(op, materializeErfcApproximationF64(rewriter, loc, x)); |
| return success(); |
| } |
| |
| rewriter.replaceOp( |
| op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(), |
| rewriter.getF32Type(), |
| &materializeErfcApproximationF32)); |
| return success(); |
| } |
| }; |
| |
| static Value erfInv32(ConversionPatternRewriter &b, Location loc, |
| ValueRange args) { |
| constexpr int kDegree = 9; |
| constexpr std::array<float, 9> wLessThan5Constants = { |
| 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, |
| -4.39150654e-06f, 0.00021858087f, -0.00125372503f, |
| -0.00417768164f, 0.246640727f, 1.50140941f}; |
| constexpr std::array<float, 9> wGreaterThan5Constants = { |
| -0.000200214257f, 0.000100950558f, 0.00134934322f, |
| -0.00367342844f, 0.00573950773f, -0.0076224613f, |
| 0.00943887047f, 1.00167406f, 2.83297682f}; |
| |
| Value x = args[0]; |
| // Compute logarithm of (1+arg) using log1p(arg) which is more precise than |
| // log(1+arg) when arg is close to zero. For more details, see |
| // https://en.cppreference.com/w/cpp/numeric/math/log1p |
| Value minusXSquared = b.create<mlir::stablehlo::MulOp>( |
| loc, x, b.create<mlir::stablehlo::NegOp>(loc, x)); |
| Value w = b.create<mlir::stablehlo::NegOp>( |
| loc, b.create<mlir::stablehlo::Log1pOp>(loc, minusXSquared)); |
| |
| Value lt = b.create<mlir::stablehlo::CompareOp>( |
| loc, w, getConstantLike(b, loc, 5.0, x), |
| mlir::stablehlo::ComparisonDirection::LT); |
| auto coefficient = [&](int i) { |
| return b.create<mlir::stablehlo::SelectOp>( |
| loc, lt, getConstantLike(b, loc, wLessThan5Constants[i], x), |
| getConstantLike(b, loc, wGreaterThan5Constants[i], x)); |
| }; |
| w = b.create<mlir::stablehlo::SelectOp>( |
| loc, lt, |
| b.create<mlir::stablehlo::SubtractOp>(loc, w, |
| getConstantLike(b, loc, 2.5, x)), |
| b.create<mlir::stablehlo::SubtractOp>( |
| loc, b.create<mlir::stablehlo::SqrtOp>(loc, w), |
| getConstantLike(b, loc, 3.0, x))); |
| Value p = coefficient(0); |
| for (int i = 1; i < kDegree; ++i) { |
| p = b.create<mlir::stablehlo::AddOp>( |
| loc, coefficient(i), b.create<mlir::stablehlo::MulOp>(loc, p, w)); |
| } |
| |
| // Result modulo edge cases. |
| Value result = b.create<mlir::stablehlo::MulOp>(loc, p, x); |
| |
| // Handle edge cases, namely erfinv(+/-1) = +/-inf. (The above computation is |
| // indeterminate, and can give nan or -/+inf.) |
| return b.create<mlir::stablehlo::SelectOp>( |
| loc, |
| b.create<mlir::stablehlo::CompareOp>( |
| loc, b.create<mlir::stablehlo::AbsOp>(loc, x), |
| getConstantLike(b, loc, 1, x), |
| mlir::stablehlo::ComparisonDirection::EQ), |
| b.create<mlir::stablehlo::MulOp>( |
| loc, x, getConstantLikeInfValue(b, loc, x, false)), |
| result); |
| } |
| |
| static Value erfInv64(ConversionPatternRewriter &b, Location loc, |
| ValueRange args) { |
| constexpr std::array<double, 23> wLessThan625Constants = { |
| -3.6444120640178196996e-21, -1.685059138182016589e-19, |
| 1.2858480715256400167e-18, 1.115787767802518096e-17, |
| -1.333171662854620906e-16, 2.0972767875968561637e-17, |
| 6.6376381343583238325e-15, -4.0545662729752068639e-14, |
| -8.1519341976054721522e-14, 2.6335093153082322977e-12, |
| -1.2975133253453532498e-11, -5.4154120542946279317e-11, |
| 1.051212273321532285e-09, -4.1126339803469836976e-09, |
| -2.9070369957882005086e-08, 4.2347877827932403518e-07, |
| -1.3654692000834678645e-06, -1.3882523362786468719e-05, |
| 0.0001867342080340571352, -0.00074070253416626697512, |
| -0.0060336708714301490533, 0.24015818242558961693, |
| 1.6536545626831027356}; |
| constexpr std::array<double, 19> wLessThan16Constants = { |
| 2.2137376921775787049e-09, 9.0756561938885390979e-08, |
| -2.7517406297064545428e-07, 1.8239629214389227755e-08, |
| 1.5027403968909827627e-06, -4.013867526981545969e-06, |
| 2.9234449089955446044e-06, 1.2475304481671778723e-05, |
| -4.7318229009055733981e-05, 6.8284851459573175448e-05, |
| 2.4031110387097893999e-05, -0.0003550375203628474796, |
| 0.00095328937973738049703, -0.0016882755560235047313, |
| 0.0024914420961078508066, -0.0037512085075692412107, |
| 0.005370914553590063617, 1.0052589676941592334, |
| 3.0838856104922207635, |
| }; |
| constexpr std::array<double, 17> wGreaterThan16Constants = { |
| -2.7109920616438573243e-11, -2.5556418169965252055e-10, |
| 1.5076572693500548083e-09, -3.7894654401267369937e-09, |
| 7.6157012080783393804e-09, -1.4960026627149240478e-08, |
| 2.9147953450901080826e-08, -6.7711997758452339498e-08, |
| 2.2900482228026654717e-07, -9.9298272942317002539e-07, |
| 4.5260625972231537039e-06, -1.9681778105531670567e-05, |
| 7.5995277030017761139e-05, -0.00021503011930044477347, |
| -0.00013871931833623122026, 1.0103004648645343977, |
| 4.8499064014085844221, |
| }; |
| |
| Value x = args[0]; |
| // Compute logarithm of (1+arg) using log1p(arg) which is more precise than |
| // log(1+arg) when arg is close to zero. For more details, see |
| // https://en.cppreference.com/w/cpp/numeric/math/log1p |
| Value minusXSquared = b.create<mlir::stablehlo::MulOp>( |
| loc, x, b.create<mlir::stablehlo::NegOp>(loc, x)); |
| Value w = b.create<mlir::stablehlo::NegOp>( |
| loc, b.create<mlir::stablehlo::Log1pOp>(loc, minusXSquared)); |
| |
| Value lt625 = b.create<mlir::stablehlo::CompareOp>( |
| loc, w, getConstantLike(b, loc, 6.25, x), |
| mlir::stablehlo::ComparisonDirection::LT); |
| Value lt16 = b.create<mlir::stablehlo::CompareOp>( |
| loc, w, getConstantLike(b, loc, 16, x), |
| mlir::stablehlo::ComparisonDirection::LT); |
| |
| auto coefficient = [&](int i) { |
| Value c = getConstantLike(b, loc, wLessThan625Constants[i], x); |
| if (i < 19) { |
| c = b.create<mlir::stablehlo::SelectOp>( |
| loc, lt625, c, getConstantLike(b, loc, wLessThan16Constants[i], x)); |
| } |
| if (i < 17) { |
| c = b.create<mlir::stablehlo::SelectOp>( |
| loc, lt16, c, getConstantLike(b, loc, wGreaterThan16Constants[i], x)); |
| } |
| return c; |
| }; |
| |
| Value sqrtW = b.create<mlir::stablehlo::SqrtOp>(loc, w); |
| Value wMinus3125 = b.create<mlir::stablehlo::SubtractOp>( |
| loc, w, getConstantLike(b, loc, 3.125, x)); |
| Value select2 = b.create<mlir::stablehlo::SelectOp>( |
| loc, lt16, getConstantLike(b, loc, 3.25, w), |
| getConstantLike(b, loc, 5.0, w)); |
| Value select2Result = |
| b.create<mlir::stablehlo::SubtractOp>(loc, sqrtW, select2); |
| w = b.create<mlir::stablehlo::SelectOp>(loc, lt625, wMinus3125, |
| select2Result); |
| |
| Value p = coefficient(0); |
| for (int i = 1; i < 17; ++i) { |
| p = b.create<mlir::stablehlo::AddOp>( |
| loc, coefficient(i), b.create<mlir::stablehlo::MulOp>(loc, p, w)); |
| } |
| for (int i = 17; i < 19; ++i) { |
| p = b.create<mlir::stablehlo::SelectOp>( |
| loc, lt16, |
| b.create<mlir::stablehlo::AddOp>( |
| loc, coefficient(i), b.create<mlir::stablehlo::MulOp>(loc, p, w)), |
| p); |
| } |
| for (int i = 19; i < 23; ++i) { |
| p = b.create<mlir::stablehlo::SelectOp>( |
| loc, lt625, |
| b.create<mlir::stablehlo::AddOp>( |
| loc, coefficient(i), b.create<mlir::stablehlo::MulOp>(loc, p, w)), |
| p); |
| } |
| |
| // Result modulo edge cases. |
| Value result = b.create<mlir::stablehlo::MulOp>(loc, p, x); |
| |
| // Handle edge cases, namely erfinv(+/-1) = +/-inf. (The above computation is |
| // indeterminate, and can give nan or -/+inf.) |
| return b.create<mlir::stablehlo::SelectOp>( |
| loc, |
| b.create<mlir::stablehlo::CompareOp>( |
| loc, b.create<mlir::stablehlo::AbsOp>(loc, x), |
| getConstantLike(b, loc, 1, x), |
| mlir::stablehlo::ComparisonDirection::EQ), |
| b.create<mlir::stablehlo::MulOp>( |
| loc, x, getConstantLikeInfValue(b, loc, x, false)), |
| result); |
| } |
| |
| struct ConvertErfInvOp final : OpConversionPattern<mlir::chlo::ErfInvOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(mlir::chlo::ErfInvOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| Location loc = op.getLoc(); |
| if (op.getResult().getType().getElementType().isF64()) { |
| rewriter.replaceOp(op, erfInv64(rewriter, loc, adaptor.getOperands())); |
| return success(); |
| } |
| FloatType minPrecisionTy = rewriter.getF32Type(); |
| rewriter.replaceOp(op, materializeWithUpcast(rewriter, loc, |
| adaptor.getOperands(), |
| minPrecisionTy, &erfInv32)); |
| return success(); |
| } |
| }; |
| |
// Parameters of the Lanczos approximation of the gamma function. The
// coefficients are uniquely determined by the choice of g and n, i.e. by
// kLanczosGamma and kLanczosCoefficients.size() + 1; the values below
// correspond to [7, 9].
//
// [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were evaluated and [7, 9]
// seemed to be the least sensitive to the quality of the log function. In
// particular, [5, 7] is the only choice where -1.5e-5 <= lgamma(2) <= 1.5e-5
// for a particularly inaccurate log function.
constexpr double kLanczosGamma = 7.0; // aka g
constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478;
constexpr std::array<double, 8> kLanczosCoefficients{
    676.520368121885098567009190444019,
    -1259.13921672240287047156078755283,
    771.3234287776530788486528258894,
    -176.61502916214059906584551354,
    12.507343278686904814458936853,
    -0.13857109526572011689554707,
    9.984369578019570859563e-6,
    1.50563273514931155834e-7,
};
| |
| // Compute the Lgamma function using Lanczos' approximation from "A Precision |
| // Approximation of the Gamma Function". SIAM Journal on Numerical Analysis |
| // series B. Vol. 1: |
| // lgamma(z + 1) = (log(2) + log(pi)) / 2 |
| // + (z + 1/2) * log(t(z)) |
| // - t(z) + log(a(z)) |
| // with t(z) = z + kLanczosGamma + 1/2 |
| // a(z) = kBaseLanczosCoeff |
| // + sum(k = 1, n, kLanczosCoefficients[i] / (z + k)) |
| static Value materializeLgamma(ConversionPatternRewriter &rewriter, |
| Location loc, ValueRange args) { |
| // If the input is less than 0.5 use Euler's reflection formula. |
| // gamma(x) = pi / (sin(pi * x) * gamma(1 - x)) |
| // Let z be |
| // z = -x if x < 1/2 |
| // z = x - 1 otheriwse |
| Value x = args.front(); |
| Value half = getConstantLike(rewriter, loc, 0.5, x); |
| Value needToReflect = rewriter.create<mlir::stablehlo::CompareOp>( |
| loc, x, half, mlir::stablehlo::ComparisonDirection::LT); |
| Value negX = rewriter.create<mlir::stablehlo::NegOp>(loc, x); |
| Value one = getConstantLike(rewriter, loc, 1, x); |
| Value xSubOne = rewriter.create<mlir::stablehlo::SubtractOp>(loc, x, one); |
| Value z = rewriter.create<mlir::stablehlo::SelectOp>(loc, needToReflect, negX, |
| xSubOne); |
| |
| // Materialize |
| // a(z) = kBaseLanczosCoeff |
| // + sum(k = 1, n, kLanczosCoefficients[i] / (z + k)) |
| Value a = getConstantLike(rewriter, loc, kBaseLanczosCoeff, x); |
| for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) { |
| Value coeff = getConstantLike(rewriter, loc, kLanczosCoefficients[i], x); |
| Value oneBasedIndex = getConstantLike(rewriter, loc, i + 1, x); |
| Value quotient = rewriter.create<mlir::stablehlo::DivOp>( |
| loc, coeff, |
| rewriter.create<mlir::stablehlo::AddOp>(loc, z, oneBasedIndex)); |
| a = rewriter.create<mlir::stablehlo::AddOp>(loc, a, quotient); |
| } |
| |
| // To improve accuracy on platforms with less-precise log implementations, |
| // compute log(kLanczosGamma + 1/2) at compile time and use log1p on the |
| // device. |
| // Materialize as |
| // log(t) = log(kLanczosGamma + 1/2 + z) |
| // = log(kLanczosGamma + 1/2) + log1p(z / (kLanczosGamma + 1/2)). |
| Value lanczosPlusHalf = |
| getConstantLike(rewriter, loc, kLanczosGamma + 0.5, x); |
| Value t = rewriter.create<mlir::stablehlo::AddOp>(loc, lanczosPlusHalf, z); |
| Value logTerm = |
| getConstantLike(rewriter, loc, std::log(kLanczosGamma + 0.5), x); |
| Value log1pTerm = rewriter.create<mlir::stablehlo::Log1pOp>( |
| loc, rewriter.create<mlir::stablehlo::DivOp>(loc, z, lanczosPlusHalf)); |
| Value logT = rewriter.create<mlir::stablehlo::AddOp>(loc, logTerm, log1pTerm); |
| |
| // Note that t(z) may be large and we need to be careful not to overflow to |
| // infinity in the relevant term |
| // r = (z + 1/2) * log(t(z)) - t(z). |
| // Therefore, we compute this as |
| // r = (z + 1/2 - t(z) / log(t(z))) * log(t(z)). |
| Value tDivLogT = rewriter.create<mlir::stablehlo::DivOp>(loc, t, logT); |
| Value sum = rewriter.create<mlir::stablehlo::SubtractOp>( |
| loc, rewriter.create<mlir::stablehlo::AddOp>(loc, z, half), tDivLogT); |
| Value r = rewriter.create<mlir::stablehlo::MulOp>(loc, sum, logT); |
| |
| // Compute the final result (modulo reflection) as |
| // lgamma(z + 1) = (log(2) + log(pi)) / 2 + r + log(a(z)). |
| Value logA = rewriter.create<mlir::stablehlo::LogOp>(loc, a); |
| Value lgamma = rewriter.create<mlir::stablehlo::AddOp>( |
| loc, |
| rewriter.create<mlir::stablehlo::AddOp>( |
| loc, |
| getConstantLike(rewriter, loc, (std::log(2) + std::log(M_PI)) / 2, x), |
| r), |
| logA); |
| |
| // Compute the reflected value for x < 0.5 as |
| // lgamma(x) = log(pi) - lgamma(1-x) - log(abs(sin(pi * x))). |
| // |
| // The abs is needed because lgamma is the log of the absolute value of the |
| // gamma function. |
| // |
| // We have to be careful when computing the final term above. gamma(x) goes |
| // to +/-inf at every integer x < 0, and this is controlled by the sin(pi * x) |
| // term. The slope is large, so precision is particularly important. |
| // |
| // Because abs(sin(pi * x)) has period of 1 we can equivalently use |
| // abs(sin(pi * frac(x))) where frac(x) is the fractional part of x. This is |
| // more numerically accurate: It doesn't overflow to inf like pi * x would and |
| // if x is an integer it evaluates to exactly 0 which is important because we |
| // then take the log of this value, and log(0) is inf. |
| // |
| // We don't have a frac(x) primitive in HLO and computing it is tricky, but |
| // because abs(sin(pi * x)) = abs(sin(pi * abs(x))), it's good enough for our |
| // purposes to use abs(frac(x)) = abs(x) - floor(abs(x)). |
| // |
| // Furthermore, pi * abs(frac(x)) loses precision when abs(frac(x)) is close |
| // to 1. To remedy this, we can use the fact that sin(pi * x) in the domain |
| // [0, 1] is symmetric across the line Y=0.5. |
| // |
| |
| // Convert values of abs_frac > 0.5 to (1 - abs_frac) to improve precision of |
| // pi * abs_frac for values of abs_frac close to 1. |
| Value abs = rewriter.create<mlir::stablehlo::AbsOp>(loc, x); |
| Value absFrac = rewriter.create<mlir::stablehlo::SubtractOp>( |
| loc, abs, rewriter.create<mlir::stablehlo::FloorOp>(loc, abs)); |
| Value reduceAbsFrac = rewriter.create<mlir::stablehlo::CompareOp>( |
| loc, half, absFrac, mlir::stablehlo::ComparisonDirection::LT); |
| absFrac = rewriter.create<mlir::stablehlo::SelectOp>( |
| loc, reduceAbsFrac, |
| rewriter.create<mlir::stablehlo::SubtractOp>(loc, one, absFrac), absFrac); |
| |
| // Materialize reflection. |
| Value reflectionDenom = rewriter.create<mlir::stablehlo::LogOp>( |
| loc, |
| rewriter.create<mlir::stablehlo::SineOp>( |
| loc, rewriter.create<mlir::stablehlo::MulOp>( |
| loc, getConstantLike(rewriter, loc, M_PI, x), absFrac))); |
| Value lgammaReflection = rewriter.create<mlir::stablehlo::SubtractOp>( |
| loc, |
| rewriter.create<mlir::stablehlo::SubtractOp>( |
| loc, getConstantLike(rewriter, loc, std::log(M_PI), x), |
| reflectionDenom), |
| lgamma); |
| |
| // Avoid computing -inf - inf, which is nan. If reflection_denom is +/-inf, |
| // then it "wins" and the result is +/-inf. |
| Value finiteReflectionDenom = |
| rewriter.create<mlir::stablehlo::IsFiniteOp>(loc, reflectionDenom); |
| Value negReflectionDenom = |
| rewriter.create<mlir::stablehlo::NegOp>(loc, reflectionDenom); |
| lgammaReflection = rewriter.create<mlir::stablehlo::SelectOp>( |
| loc, finiteReflectionDenom, lgammaReflection, negReflectionDenom); |
| |
| // Select whether or not to rely on the reflection. |
| lgamma = rewriter.create<mlir::stablehlo::SelectOp>(loc, needToReflect, |
| lgammaReflection, lgamma); |
| |
| // Materialize +/-inf behavior as |
| // lgamma(+/-inf) = +inf. |
| Value xIsInf = rewriter.create<chlo::IsInfOp>(loc, x); |
| return rewriter.create<mlir::stablehlo::SelectOp>( |
| loc, xIsInf, |
| getConstantLikeInfValue(rewriter, loc, x, /*negative=*/false), lgamma); |
| } |
| |
| // Express `cosh` as |
| // cosh(x) = (e^x + e^-x) / 2 |
| // = e^(x + log(1/2)) + e^(-x + log(1/2)) |
| // |
| // The second formulation avoids overflowing when e^x = inf but (e^x)/2 is not. |
| // |
| // This incorrectly overflows to inf for two f32 input values, namely |
| // +/-89.4159851, due to rounding error when computing x +/- log(1/2). The |
| // correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so |
| // we deem this acceptable. |
| static Value materializeCoshApproximation(ConversionPatternRewriter &rewriter, |
| Location loc, ValueRange operands) { |
| mlir::chlo::CoshOp::Adaptor transformed(operands); |
| Value x = transformed.getOperand(); |
| |
| Value logOneHalf = rewriter.create<mlir::stablehlo::LogOp>( |
| loc, getConstantLike(rewriter, loc, 0.5, x)); |
| Value expAdd = rewriter.create<mlir::stablehlo::ExpOp>( |
| loc, rewriter.create<mlir::stablehlo::AddOp>(loc, x, logOneHalf)); |
| Value expSub = rewriter.create<mlir::stablehlo::ExpOp>( |
| loc, rewriter.create<mlir::stablehlo::SubtractOp>(loc, logOneHalf, x)); |
| return rewriter.create<mlir::stablehlo::AddOp>(loc, expAdd, expSub); |
| } |
| |
| struct ConvertCoshOp final : OpConversionPattern<mlir::chlo::CoshOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(mlir::chlo::CoshOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| rewriter.replaceOp( |
| op, materializeWithUpcast(rewriter, op.getLoc(), adaptor.getOperands(), |
| rewriter.getF32Type(), |
| &materializeCoshApproximation)); |
| return success(); |
| } |
| }; |
| |
| // Materializes digamma(args.front()) out of stablehlo ops and returns the |
| // result value. Uses Lanczos' approximation from "A Precision Approximation of |
| // the Gamma Function", SIAM Journal on Numerical Analysis series B, Vol. 1: |
| // digamma(z + 1) = log(t(z)) + a'(z) / a(z) - kLanczosGamma / t(z) |
| // with t(z) = z + kLanczosGamma + 1/2 |
| // a(z) = kBaseLanczosCoeff |
| // + sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k)) |
| // a'(z) = - sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k) / (z + k)) |
| static Value materializeDigamma(ConversionPatternRewriter &rewriter, |
| Location loc, ValueRange args) { |
| // If the input is less than 0.5 use Euler's reflection formula. |
| // digamma(x) = digamma(1 - x) - pi * cot(pi * x) |
| // Let z be |
| // z = -x if x < 1/2 |
| // z = x - 1 otherwise |
| Value x = args.front(); |
| Value half = getConstantLike(rewriter, loc, 0.5, x); |
| Value needToReflect = rewriter.create<mlir::stablehlo::CompareOp>( |
| loc, x, half, mlir::stablehlo::ComparisonDirection::LT); |
| Value negX = rewriter.create<mlir::stablehlo::NegOp>(loc, x); |
| Value one = getConstantLike(rewriter, loc, 1, x); |
| Value xSubOne = rewriter.create<mlir::stablehlo::SubtractOp>(loc, x, one); |
| Value z = rewriter.create<mlir::stablehlo::SelectOp>(loc, needToReflect, negX, |
| xSubOne); |
| |
| // Materialize (loop index i below corresponds to k - 1) |
| // a(z) = kBaseLanczosCoeff |
| // + sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k)) |
| // a'(z) = - sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k) / (z + k)) |
| Value zero = getConstantLike(rewriter, loc, 0.0, x); |
| Value a = getConstantLike(rewriter, loc, kBaseLanczosCoeff, x); |
| Value aPrime = zero; |
| for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) { |
| Value coeff = getConstantLike(rewriter, loc, kLanczosCoefficients[i], x); |
| Value oneBasedIndex = getConstantLike(rewriter, loc, i + 1, x); |
| Value zTerm = |
| rewriter.create<mlir::stablehlo::AddOp>(loc, z, oneBasedIndex); |
| aPrime = rewriter.create<mlir::stablehlo::SubtractOp>( |
| loc, aPrime, |
| rewriter.create<mlir::stablehlo::DivOp>( |
| loc, coeff, |
| rewriter.create<mlir::stablehlo::MulOp>(loc, zTerm, zTerm))); |
| a = rewriter.create<mlir::stablehlo::AddOp>( |
| loc, a, rewriter.create<mlir::stablehlo::DivOp>(loc, coeff, zTerm)); |
| } |
| |
| // To improve accuracy on platforms with less-precise log implementations, |
| // compute log(kLanczosGamma + 1/2) at compile time and use log1p on the |
| // device. |
| // Materialize as |
| // log(t) = log(kLanczosGamma + 1/2 + z) |
| // = log(kLanczosGamma + 1/2) + log1p(z / (kLanczosGamma + 1/2)). |
| Value lanczosPlusHalf = |
| getConstantLike(rewriter, loc, kLanczosGamma + 0.5, x); |
| Value t = rewriter.create<mlir::stablehlo::AddOp>(loc, lanczosPlusHalf, z); |
| Value logTerm = |
| getConstantLike(rewriter, loc, std::log(kLanczosGamma + 0.5), x); |
| Value log1pTerm = rewriter.create<mlir::stablehlo::Log1pOp>( |
| loc, rewriter.create<mlir::stablehlo::DivOp>(loc, z, lanczosPlusHalf)); |
| Value logT = rewriter.create<mlir::stablehlo::AddOp>(loc, logTerm, log1pTerm); |
| |
| // Materialize the final result (modulo reflection) as |
| // digamma(z + 1) = log(t(z)) + a'(z) / a(z) - kLanczosGamma / t(z). |
| Value aPrimeDivA = rewriter.create<mlir::stablehlo::DivOp>(loc, aPrime, a); |
| Value lanczosGammaDivT = rewriter.create<mlir::stablehlo::DivOp>( |
| loc, getConstantLike(rewriter, loc, kLanczosGamma, x), t); |
| Value digamma = rewriter.create<mlir::stablehlo::SubtractOp>( |
| loc, rewriter.create<mlir::stablehlo::AddOp>(loc, logT, aPrimeDivA), |
| lanczosGammaDivT); |
| |
| // We need to be careful how we compute cot(pi * input) below: For |
| // near-integral arguments, pi * input can lose precision. |
| // |
| // Input is already known to be less than 0.5 (otherwise we don't have to |
| // reflect). We shift values smaller than -0.5 into the range [-0.5, 0.5] to |
| // increase precision of pi * x and the resulting cotangent. |
| Value reducedX = rewriter.create<mlir::stablehlo::AddOp>( |
| loc, x, |
| rewriter.create<mlir::stablehlo::AbsOp>( |
| loc, rewriter.create<mlir::stablehlo::FloorOp>( |
| loc, rewriter.create<mlir::stablehlo::AddOp>( |
| loc, x, getConstantLike(rewriter, loc, 0.5, x))))); |
| |
| // Materialize reflection for inputs less than 0.5 as |
| // digamma(x) = digamma(1 - x) - pi * cot(pi * x) |
| // = digamma(1 - x) - pi * cos(pi * x) / sin(pi * x) |
| Value pi = getConstantLike(rewriter, loc, M_PI, x); |
| Value piMulReducedX = |
| rewriter.create<mlir::stablehlo::MulOp>(loc, pi, reducedX); |
| Value cos = rewriter.create<mlir::stablehlo::CosineOp>(loc, piMulReducedX); |
| Value sin = rewriter.create<mlir::stablehlo::SineOp>(loc, piMulReducedX); |
| Value reflection = rewriter.create<mlir::stablehlo::SubtractOp>( |
| loc, digamma, |
| rewriter.create<mlir::stablehlo::DivOp>( |
| loc, rewriter.create<mlir::stablehlo::MulOp>(loc, pi, cos), sin)); |
| |
| // Use the reflected value only where needToReflect holds (x < 0.5). |
| digamma = rewriter.create<mlir::stablehlo::SelectOp>(loc, needToReflect, |
| reflection, digamma); |
| |
| // Digamma has poles at negative integers and zero; return nan for those. |
| Value isLeZero = rewriter.create<mlir::stablehlo::CompareOp>( |
| loc, x, zero, mlir::stablehlo::ComparisonDirection::LE); |
| Value isInt = rewriter.create<mlir::stablehlo::CompareOp>( |
| loc, x, rewriter.create<mlir::stablehlo::FloorOp>(loc, x), |
| mlir::stablehlo::ComparisonDirection::EQ); |
| Value isPole = rewriter.create<mlir::stablehlo::AndOp>(loc, isLeZero, isInt); |
| return rewriter.create<mlir::stablehlo::SelectOp>( |
| loc, isPole, |
| getConstantLike(rewriter, loc, std::numeric_limits<double>::quiet_NaN(), |
| x), |
| digamma); |
| } |
| |
| // Builds a constant shaped like `val` holding the value returned by |
| // llvm::APFloat::getSmallest for the float semantics of `val`'s element type. |
| // `val` must have a floating-point element type (enforced by the cast). |
| static Value getConstantLikeSmallestFiniteValue(OpBuilder &b, Location loc, |
|                                                 Value val) { |
|   Type elementTy = getElementTypeOrSelf(val.getType()); |
|   const llvm::fltSemantics &semantics = |
|       cast<FloatType>(elementTy).getFloatSemantics(); |
|   llvm::APFloat smallest = llvm::APFloat::getSmallest(semantics); |
|   return getConstantLike(b, loc, smallest, val); |
| } |
| |
| static Value materializeZeta(ConversionPatternRewriter &rewriter, Location loc, |
| ValueRange args) { |
| // Materializes zeta(x = args[0], q = args[1]) out of stablehlo ops and returns the result value. Code should match StableHLO's materializeZeta. |
| assert(args.size() == 2); |
| Value x = args[0]; |
| Value q = args[1]; |
| static const std::array<double, 12> kZetaCoeffs{ |
| -7.1661652561756670113e18, |
| 1.8152105401943546773e17, |
| -4.5979787224074726105e15, |
| 1.1646782814350067249e14, |
| -2.950130727918164224e12, |
| 7.47242496e10, |
| -1.8924375803183791606e9, |
| 47900160.0, |
| -1209600.0, |
| 30240.0, |
| -720.0, |
| 12.0, |
| }; |
| |
| // For speed we'll always use 9 iterations for the initial series estimate, |
| // and a 12 term expansion for the Euler-Maclaurin formula. |
| Value a = q; |
| Value zero = getConstantLike(rewriter, loc, 0.0, a); |
| Value negPower = zero; |
| Value negX = rewriter.create<mlir::stablehlo::NegOp>(loc, x); |
| Value initialSum = rewriter.create<mlir::stablehlo::PowOp>(loc, q, negX); |
| Value one = getConstantLike(rewriter, loc, 1.0, a); |
| for (int i = 0; i < 9; ++i) { |
| a = rewriter.create<mlir::stablehlo::AddOp>(loc, a, one); |
| negPower = rewriter.create<mlir::stablehlo::PowOp>(loc, a, negX); |
| initialSum = |
| rewriter.create<mlir::stablehlo::AddOp>(loc, initialSum, negPower); |
| } |
| |
| a = rewriter.create<mlir::stablehlo::AddOp>(loc, a, one); |
| negPower = rewriter.create<mlir::stablehlo::PowOp>(loc, a, negX); |
| Value oneLikeX = getConstantLike(rewriter, loc, 1.0, x); |
| Value xMinusOne = |
| rewriter.create<mlir::stablehlo::SubtractOp>(loc, x, oneLikeX); |
| Value negPowerMulA = |
| rewriter.create<mlir::stablehlo::MulOp>(loc, negPower, a); |
| Value negPowerMulADivXMinusOne = |
| rewriter.create<mlir::stablehlo::DivOp>(loc, negPowerMulA, xMinusOne); |
| Value s = rewriter.create<mlir::stablehlo::AddOp>(loc, initialSum, |
| negPowerMulADivXMinusOne); |
| Value aInverseSquare = rewriter.create<mlir::stablehlo::DivOp>( |
| loc, one, rewriter.create<mlir::stablehlo::MulOp>(loc, a, a)); |
| |
| Value hornerSum = zero; |
| Value factor = one; |
| // Use Horner's rule for the first 11 coefficients (kZetaCoeffs[0..10]). |
| // Note this differs from Cephes which does a 'naive' polynomial evaluation. |
| // Using Horner's rule avoids some NaNs and Infs, |
| // resulting in more numerically stable code. |
| for (int i = 0; i < 11; ++i) { |
| Value factorLhs = rewriter.create<mlir::stablehlo::AddOp>( |
| loc, x, getConstantLike(rewriter, loc, 22 - 2 * i, x)); |
| Value factorRhs = rewriter.create<mlir::stablehlo::AddOp>( |
| loc, x, getConstantLike(rewriter, loc, 21 - 2 * i, x)); |
| factor = rewriter.create<mlir::stablehlo::MulOp>(loc, factorLhs, factorRhs); |
| hornerSum = rewriter.create<mlir::stablehlo::MulOp>( |
| loc, factor, |
| rewriter.create<mlir::stablehlo::MulOp>( |
| loc, aInverseSquare, |
| rewriter.create<mlir::stablehlo::AddOp>( |
| loc, hornerSum, |
| getConstantLike(rewriter, loc, 1. / kZetaCoeffs[i], a)))); |
| } |
| Value zeroPointFiveLikeNegPower = |
| getConstantLike(rewriter, loc, .5, negPower); |
| Value xDivA = rewriter.create<mlir::stablehlo::DivOp>(loc, x, a); |
| s = rewriter.create<mlir::stablehlo::AddOp>( |
| loc, s, |
| rewriter.create<mlir::stablehlo::MulOp>( |
| loc, negPower, |
| rewriter.create<mlir::stablehlo::AddOp>( |
| loc, zeroPointFiveLikeNegPower, |
| rewriter.create<mlir::stablehlo::MulOp>( |
| loc, xDivA, |
| rewriter.create<mlir::stablehlo::AddOp>( |
| loc, |
| getConstantLike(rewriter, loc, 1. / kZetaCoeffs[11], a), |
| hornerSum))))); |
| |
| // Use the initial zeta sum without the correction term coming |
| // from Euler-Maclaurin if it is accurate enough. |
| Value absNegPower = rewriter.create<mlir::stablehlo::AbsOp>(loc, negPower); |
| Value absInitialSum = |
| rewriter.create<mlir::stablehlo::AbsOp>(loc, initialSum); |
| Value output = rewriter.create<mlir::stablehlo::SelectOp>( |
| loc, |
| rewriter.create<mlir::stablehlo::CompareOp>( |
| loc, absNegPower, |
| rewriter.create<mlir::stablehlo::MulOp>( |
| loc, absInitialSum, |
| getConstantLikeSmallestFiniteValue(rewriter, loc, a)), |
| mlir::stablehlo::ComparisonDirection::LT), |
| initialSum, s); |
| |
| // Function is not defined for x < 1; return nan there. |
| Value nan = getConstantLike(rewriter, loc, |
| std::numeric_limits<double>::quiet_NaN(), x); |
| output = rewriter.create<mlir::stablehlo::SelectOp>( |
| loc, |
| rewriter.create<mlir::stablehlo::CompareOp>( |
| loc, x, oneLikeX, mlir::stablehlo::ComparisonDirection::LT), |
| nan, output); |
| |
| // For q <= 0, x must be an integer; otherwise return nan. |
| Value qLeZero = rewriter.create<mlir::stablehlo::CompareOp>( |
| loc, q, zero, mlir::stablehlo::ComparisonDirection::LE); |
| Value xNotInt = rewriter.create<mlir::stablehlo::CompareOp>( |
| loc, x, rewriter.create<mlir::stablehlo::FloorOp>(loc, x), |
| mlir::stablehlo::ComparisonDirection::NE); |
| Value xDomainError = |
| rewriter.create<mlir::stablehlo::AndOp>(loc, qLeZero, xNotInt); |
| output = rewriter.create<mlir::stablehlo::SelectOp>(loc, xDomainError, nan, |
| output); |
| |
| // For all integer q <= 0, zeta has a pole. The limit is only defined as |
| // +inf if x is an even integer; otherwise nan is returned. |
| Value inf = getConstantLike(rewriter, loc, |
| std::numeric_limits<double>::infinity(), x); |
| Value qIsInt = rewriter.create<mlir::stablehlo::CompareOp>( |
| loc, q, rewriter.create<mlir::stablehlo::FloorOp>(loc, q), |
| mlir::stablehlo::ComparisonDirection::EQ); |
| Value atPole = rewriter.create<mlir::stablehlo::AndOp>(loc, qLeZero, qIsInt); |
| Value two = getConstantLike(rewriter, loc, 2.0, x); |
| Value xIsInt = rewriter.create<mlir::stablehlo::CompareOp>( |
| loc, x, rewriter.create<mlir::stablehlo::FloorOp>(loc, x), |
| mlir::stablehlo::ComparisonDirection::EQ); |
| Value xIsEven = rewriter.create<mlir::stablehlo::CompareOp>( |
| loc, rewriter.create<mlir::stablehlo::RemOp>(loc, x, two), zero, |
| mlir::stablehlo::ComparisonDirection::EQ); |
| Value xIsEvenInt = |
| rewriter.create<mlir::stablehlo::AndOp>(loc, xIsInt, xIsEven); |
| output = rewriter.create<mlir::stablehlo::SelectOp>( |
| loc, atPole, |
| rewriter.create<mlir::stablehlo::SelectOp>(loc, xIsEvenInt, inf, nan), |
| output); |
| |
| // For x = 1, this is the harmonic series and diverges; return +inf. |
| output = rewriter.create<mlir::stablehlo::SelectOp>( |
| loc, |
| rewriter.create<mlir::stablehlo::CompareOp>( |
| loc, x, one, mlir::stablehlo::ComparisonDirection::EQ), |
| inf, output); |
| |
| return output; |
| } |
| |
| // Materializes polygamma(n, x) for the operands of a chlo.polygamma op: |
| //   sign * exp(lgamma(n + 1)) * zeta(n + 1, x), with sign = 2 * (n rem 2) - 1, |
| // replaced by digamma(x) when n == 0 and by NaN when n is not a natural |
| // number. The lgamma, zeta and digamma terms are emitted as chlo ops rather |
| // than being expanded inline here. |
| static Value materializePolygamma(ConversionPatternRewriter &rewriter, |
|                                   Location loc, ValueRange args) { |
|   namespace shlo = mlir::stablehlo; |
|   using Direction = shlo::ComparisonDirection; |
| |
|   mlir::chlo::PolygammaOp::Adaptor operands(args); |
|   Value order = operands.getN(); |
|   Value point = operands.getX(); |
| |
|   // Case: integer n > 0. |
|   Value cstOne = getConstantLike(rewriter, loc, 1.0, point); |
|   Value cstTwo = getConstantLike(rewriter, loc, 2.0, point); |
|   Value orderRemTwo = rewriter.create<shlo::RemOp>(loc, order, cstTwo); |
|   Value twiceParity = rewriter.create<shlo::MulOp>(loc, cstTwo, orderRemTwo); |
|   Value sign = rewriter.create<shlo::SubtractOp>(loc, twiceParity, cstOne); |
|   Value orderPlusOne = rewriter.create<shlo::AddOp>(loc, order, cstOne); |
|   Value lgammaTerm = rewriter.create<chlo::LgammaOp>(loc, orderPlusOne); |
|   Value expLgamma = rewriter.create<shlo::ExpOp>(loc, lgammaTerm); |
|   Value zetaTerm = rewriter.create<chlo::ZetaOp>(loc, orderPlusOne, point); |
|   Value signedScale = rewriter.create<shlo::MulOp>(loc, sign, expLgamma); |
|   Value result = rewriter.create<shlo::MulOp>(loc, signedScale, zetaTerm); |
| |
|   // Case: n == 0, where polygamma is digamma. |
|   Value cstZero = getConstantLike(rewriter, loc, 0.0, point); |
|   Value orderIsZero = |
|       rewriter.create<shlo::CompareOp>(loc, order, cstZero, Direction::EQ); |
|   Value digammaTerm = rewriter.create<chlo::DigammaOp>(loc, point); |
|   result = |
|       rewriter.create<shlo::SelectOp>(loc, orderIsZero, digammaTerm, result); |
| |
|   // Any n that is fractional or negative is not a natural number: yield NaN. |
|   Value flooredOrder = rewriter.create<shlo::FloorOp>(loc, order); |
|   Value orderIsFractional = rewriter.create<shlo::CompareOp>( |
|       loc, order, flooredOrder, Direction::NE); |
|   Value orderIsNegative = |
|       rewriter.create<shlo::CompareOp>(loc, order, cstZero, Direction::LT); |
|   Value orderNotNatural = |
|       rewriter.create<shlo::OrOp>(loc, orderIsFractional, orderIsNegative); |
|   Value cstNan = getConstantLike( |
|       rewriter, loc, std::numeric_limits<double>::quiet_NaN(), point); |
|   return rewriter.create<shlo::SelectOp>(loc, orderNotNatural, cstNan, result); |
| } |
| |
| // Conversion pattern that replaces chlo.lgamma with the value produced by |
| // materializeLgamma, routed through materializeWithUpcast with f32 as the |
| // minimum-precision type. |
| struct ConvertLgammaOp final : OpConversionPattern<mlir::chlo::LgammaOp> { |
|   using OpConversionPattern::OpConversionPattern; |
| |
|   LogicalResult |
|   matchAndRewrite(mlir::chlo::LgammaOp op, OpAdaptor adaptor, |
|                   ConversionPatternRewriter &rewriter) const override { |
|     Location loc = op.getLoc(); |
|     Value lowered = |
|         materializeWithUpcast(rewriter, loc, adaptor.getOperands(), |
|                               rewriter.getF32Type(), &materializeLgamma); |
|     rewriter.replaceOp(op, lowered); |
|     return success(); |
|   } |
| }; |
| |
| // Conversion pattern that replaces chlo.digamma with the value produced by |
| // materializeDigamma, routed through materializeWithUpcast with f32 as the |
| // minimum-precision type. |
| struct ConvertDigammaOp final : OpConversionPattern<mlir::chlo::DigammaOp> { |
|   using OpConversionPattern::OpConversionPattern; |
| |
|   LogicalResult |
|   matchAndRewrite(mlir::chlo::DigammaOp op, OpAdaptor adaptor, |
|                   ConversionPatternRewriter &rewriter) const override { |
|     Location loc = op.getLoc(); |
|     Value lowered = |
|         materializeWithUpcast(rewriter, loc, adaptor.getOperands(), |
|                               rewriter.getF32Type(), &materializeDigamma); |
|     rewriter.replaceOp(op, lowered); |
|     return success(); |
|   } |
| }; |
| |
| // Materializes chlo.next_after(x, y) with integer arithmetic on the operands' |
| // bit patterns: both operands are bitcast to integers of the same width, the |
| // representation of "x" is stepped by one towards "y", and the result is |
| // bitcast back to the original type. NaN inputs, x == y and x == +-0 are |
| // handled through selects applied at the end. |
| static Value materializeNextAfter(ConversionPatternRewriter &rewriter, |
|                                   Location loc, ValueRange operands) { |
|   namespace shlo = mlir::stablehlo; |
|   using Direction = shlo::ComparisonDirection; |
| |
|   mlir::chlo::NextAfterOp::Adaptor adaptor(operands); |
|   Value from = adaptor.getX(); |
|   Value to = adaptor.getY(); |
|   auto floatTy = cast<ShapedType>(from.getType()); |
|   auto numBits = floatTy.getElementType().getIntOrFloatBitWidth(); |
|   mlir::ImplicitLocOpBuilder builder(loc, rewriter); |
|   Type bitsTy = floatTy.clone(builder.getIntegerType(numBits)); |
|   Value fromBits = builder.create<shlo::BitcastConvertOp>(bitsTy, from); |
|   Value toBits = builder.create<shlo::BitcastConvertOp>(bitsTy, to); |
| |
|   // A NaN in either operand makes the result NaN (x != x detects NaN). |
|   Value fromIsNan = builder.create<shlo::CompareOp>(from, from, Direction::NE); |
|   Value toIsNan = builder.create<shlo::CompareOp>(to, to, Direction::NE); |
|   Value anyNan = builder.create<shlo::OrOp>(fromIsNan, toIsNan); |
|   Value nanFloat = getConstantLike( |
|       rewriter, loc, std::numeric_limits<double>::quiet_NaN(), from); |
|   Value nanBits = builder.create<shlo::BitcastConvertOp>(bitsTy, nanFloat); |
| |
|   // The sign bit is the MSB; masking it off leaves a non-negative magnitude. |
|   const int64_t signBit = int64_t{1} << (numBits - 1); |
|   Value signMask = getConstantLike(rewriter, loc, signBit, fromBits); |
|   Value magnitudeMask = getConstantLike(rewriter, loc, ~signBit, fromBits); |
|   Value fromMagnitude = builder.create<shlo::AndOp>(fromBits, magnitudeMask); |
|   Value toMagnitude = builder.create<shlo::AndOp>(toBits, magnitudeMask); |
| |
|   // Equal operands yield "y". |
|   Value operandsEqual = builder.create<shlo::CompareOp>(from, to, Direction::EQ); |
| |
|   // Both operands being zero also yields "y"; this is tracked separately |
|   // because the two zeros may carry different signs. |
|   Value zeroBits = getConstantLike(rewriter, loc, 0, fromBits); |
|   Value fromIsZero = |
|       builder.create<shlo::CompareOp>(fromMagnitude, zeroBits, Direction::EQ); |
|   Value toIsZero = |
|       builder.create<shlo::CompareOp>(toMagnitude, zeroBits, Direction::EQ); |
| |
|   Value fromSign = builder.create<shlo::AndOp>(fromBits, signMask); |
|   Value toSign = builder.create<shlo::AndOp>(toBits, signMask); |
| |
|   // from == 0 && to != 0: the answer is the smallest subnormal number carrying |
|   // the sign of "to". |
|   Value oneBits = getConstantLike(rewriter, loc, 1, fromBits); |
|   Value smallestTowardsTo = builder.create<shlo::OrOp>(toSign, oneBits); |
| |
|   // General case. The magnitude of "from" has to shrink when the signs differ |
|   // (step towards zero) or when the signs agree and |from| > |to|; it has to |
|   // grow when the signs agree and |from| < |to|. |
|   Value signsDiffer = |
|       builder.create<shlo::CompareOp>(fromSign, toSign, Direction::NE); |
|   Value fromIsLarger = builder.create<shlo::CompareOp>( |
|       fromMagnitude, toMagnitude, Direction::GT); |
|   Value shrinkMagnitude = builder.create<shlo::OrOp>(fromIsLarger, signsDiffer); |
|   Value minusOneBits = getConstantLike(rewriter, loc, -1, fromBits); |
|   Value step = |
|       builder.create<shlo::SelectOp>(shrinkMagnitude, minusOneBits, oneBits); |
|   Value result = builder.create<shlo::AddOp>(fromBits, step); |
| |
|   // Override for from == +-0. |
|   Value resultForZeroFrom = |
|       builder.create<shlo::SelectOp>(toIsZero, toBits, smallestTowardsTo); |
|   result = builder.create<shlo::SelectOp>(fromIsZero, resultForZeroFrom, result); |
|   // Override for from == to. |
|   result = builder.create<shlo::SelectOp>(operandsEqual, toBits, result); |
|   // Override for isnan(x) || isnan(y). |
|   result = builder.create<shlo::SelectOp>(anyNan, nanBits, result); |
| |
|   // Reinterpret the integer bits as the original floating-point type. |
|   return builder.create<shlo::BitcastConvertOp>(floatTy, result); |
| } |
| |
| // Conversion pattern that replaces chlo.next_after with the value built by |
| // materializeNextAfter from the adaptor operands. |
| struct ConvertNextAfterOp final : OpConversionPattern<mlir::chlo::NextAfterOp> { |
|   using OpConversionPattern::OpConversionPattern; |
| |
|   LogicalResult |
|   matchAndRewrite(mlir::chlo::NextAfterOp op, OpAdaptor adaptor, |
|                   ConversionPatternRewriter &rewriter) const override { |
|     Location loc = op.getLoc(); |
|     Value lowered = materializeNextAfter(rewriter, loc, adaptor.getOperands()); |
|     rewriter.replaceOp(op, lowered); |
|     return success(); |
|   } |
| }; |
| |
| // Conversion pattern that replaces chlo.polygamma with the value produced by |
| // materializePolygamma, routed through materializeWithUpcast with f32 as the |
| // minimum-precision type. |
| struct ConvertPolygammaOp final : OpConversionPattern<mlir::chlo::PolygammaOp> { |
|   using OpConversionPattern::OpConversionPattern; |
| |
|   LogicalResult |
|   matchAndRewrite(mlir::chlo::PolygammaOp op, OpAdaptor adaptor, |
|                   ConversionPatternRewriter &rewriter) const override { |
|     Value lowered = materializeWithUpcast( |
|         rewriter, op.getLoc(), adaptor.getOperands(), rewriter.getF32Type(), |
|         &materializePolygamma); |
|     rewriter.replaceOp(op, lowered); |
|     return success(); |
|   } |
| }; |
| |
| // Large-argument form of sinh: |
| //   sinh(x) = (e^x - e^-x) / 2 |
| //           = e^(x + log(1/2)) - e^(-x + log(1/2)). |
| // |
| // Folding the division by two into the exponent avoids overflow in the case |
| // where e^x is inf but (e^x)/2 is not. |
| // |
| // Known limitation: for the two f32 inputs +/-89.4159851 this still overflows |
| // to +/-inf because of rounding error in x +/- log(1/2). The correct answer, |
| // 3.40281961e+38 (0x7f7fffec), is very close to max-float, so this is deemed |
| // acceptable. |
| static Value |
| materializeSinhApproximationForLargeX(ConversionPatternRewriter &rewriter, |
|                                       Location loc, ValueRange operands) { |
|   namespace shlo = mlir::stablehlo; |
| |
|   mlir::chlo::SinhOp::Adaptor adaptor(operands); |
|   Value x = adaptor.getOperand(); |
| |
|   // log(1/2) is emitted as a stablehlo.log of a 0.5 constant shaped like x. |
|   Value cstHalf = getConstantLike(rewriter, loc, 0.5, x); |
|   Value logHalf = rewriter.create<shlo::LogOp>(loc, cstHalf); |
| |
|   // e^(x + log(1/2)) |
|   Value positiveExponent = rewriter.create<shlo::AddOp>(loc, x, logHalf); |
|   Value positiveTerm = rewriter.create<shlo::ExpOp>(loc, positiveExponent); |
| |
|   // e^(-x + log(1/2)) |
|   Value negativeExponent = rewriter.create<shlo::SubtractOp>(loc, logHalf, x); |
|   Value negativeTerm = rewriter.create<shlo::ExpOp>(loc, negativeExponent); |
| |
|   return rewriter.create<shlo::SubtractOp>(loc, positiveTerm, negativeTerm); |
| } |
| |
| // Express `sinh` as |
| //   sinh(x) = (expm1(x) + expm1(x) / (expm1(x) + 1)) / 2   if |x| < 1 |
| //           = e^(x + log(1/2)) - e^(-x + log(1/2))          otherwise. |
| // Both forms are materialized and a select on |x| < 1 picks between them. |
| static Value materializeSinhApproximation(ConversionPatternRewriter &rewriter, |
|                                           Location loc, ValueRange operands) { |
|   namespace shlo = mlir::stablehlo; |
| |
|   // The large-|x| expansion is emitted first. |
|   Value resultForLargeX = |
|       materializeSinhApproximationForLargeX(rewriter, loc, operands); |
| |
|   mlir::chlo::SinhOp::Adaptor adaptor(operands); |
|   Value x = adaptor.getOperand(); |
| |
|   // For small x, e^x - e^-x cancels and collapses to 0. Rewriting in terms of |
|   // expm1(x) preserves the first-order term of the Taylor series of e^x: |
|   //   (e^x - e^-x) / 2 |
|   //     = (e^x - 1 + 1 - e^-x) / 2 |
|   //     = (expm1(x) + (e^x - 1) / e^x) / 2 |
|   //     = (expm1(x) + expm1(x) / (expm1(x) + 1)) / 2. |
|   Value expm1X = rewriter.create<shlo::Expm1Op>(loc, x); |
|   Value cstOne = getConstantLike(rewriter, loc, 1.0, x); |
|   Value cstHalf = getConstantLike(rewriter, loc, 0.5, x); |
|   Value expX = rewriter.create<shlo::AddOp>(loc, expm1X, cstOne); |
|   Value oneMinusExpNegX = rewriter.create<shlo::DivOp>(loc, expm1X, expX); |
|   Value twiceSinh = rewriter.create<shlo::AddOp>(loc, expm1X, oneMinusExpNegX); |
|   Value resultForSmallX = rewriter.create<shlo::MulOp>(loc, cstHalf, twiceSinh); |
| |
|   // Pick the small-x form exactly when |x| < 1. |
|   Value magnitude = rewriter.create<shlo::AbsOp>(loc, x); |
|   Value isSmall = rewriter.create<shlo::CompareOp>( |
|       loc, magnitude, cstOne, shlo::ComparisonDirection::LT); |
|   return rewriter.create<shlo::SelectOp>(loc, isSmall, resultForSmallX, |
|                                          resultForLargeX); |
| } |
| |
| // Conversion pattern for chlo.sinh. Complex element types use the large-x |
| // expansion directly; every other element type goes through |
| // materializeWithUpcast (f32 as the minimum-precision type) with |
| // materializeSinhApproximation. |
| struct ConvertSinhOp final : OpConversionPattern<mlir::chlo::SinhOp> { |
|   using OpConversionPattern::OpConversionPattern; |
| |
|   LogicalResult |
|   matchAndRewrite(mlir::chlo::SinhOp op, OpAdaptor adaptor, |
|                   ConversionPatternRewriter &rewriter) const override { |
|     Location loc = op.getLoc(); |
|     Type elementTy = |
|         cast<ShapedType>(adaptor.getOperand().getType()).getElementType(); |
| |
|     Value lowered; |
|     if (isa<ComplexType>(elementTy)) { |
|       lowered = materializeSinhApproximationForLargeX(rewriter, loc, |
|                                                       adaptor.getOperands()); |
|     } else { |
|       lowered = materializeWithUpcast(rewriter, loc, adaptor.getOperands(), |
|                                       rewriter.getF32Type(), |
|                                       &materializeSinhApproximation); |
|     } |
|     rewriter.replaceOp(op, lowered); |
|     return success(); |
|   } |
| }; |
| |
| // Converts chlo.top_k to HLO iota, sort, and slice ops. |
| // |
| // chlo.top_k sorts along last dimension of the input tensor and then returns |
| // the top K components' values and indices. This is translated into a few |
| // ops in HLO: first generating an integer sequence for the indices, |
| // then sort both the original input tensor and the indices together, and |
| // at last slice out the top K components. |
| // |
| // For example, for the following IR: |
| // |
| // %0:2 = "chlo.top_k"(%input, k=8): tensor<16x16xf32> -> |
| // (tensor<16x8xf32>, tensor<16x8xi32>) |
| // |
| // We will get: |
| // |
| // %1 = "hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<16x16xi32> |
| // %2 = "hlo.sort"(%input, %1) ({ |
| // ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>, |
| // %arg3: tensor<i32>, %arg4: tensor<i32>): |
| // %7 = "hlo.compare"(%arg1, %arg2) {comparison_direction = "GT"}: ... |
| // "hlo.return"(%7) : (tensor<i1>) -> () |
| // }) {dimension = 1 : i64, is_stable = true} : ... |
| // %3 = "hlo.get_tuple_element"(%2) {index = 0 : i32} : ... |
| // %4 = "hlo.get_tuple_element"(%2) {index = 1 : i32} : ... |
| // %5 = "hlo.slice"(%3) {limit_indices = dense<[16, 8]> : tensor<2xi64>, |
| // start_indices = dense<0> : tensor<2xi64>, |
| // strides = dense<1> : tensor<2xi64>} : |
| // (tensor<16x16xf32>) -> tensor<16x8xf32> |
| // %6 = "hlo.slice"(%4) ... |
| // |
| // When the operand shape is not fully static, the dynamic variants |
| // (dynamic_iota / real_dynamic_slice) are emitted instead, driven by shape |
| // tensors assembled from get_dimension_size ops. |
| // |
| // TODO(b/284078162): Decide what to do with this pattern given that we now |
| // have mlir::stablehlo::TopKOp. No action needed for now given that |
| // mlir::stablehlo::TopKOp is currently categorized as |
| // `hasPrivateFeaturesNotInStablehlo`. |
| struct ConvertTopKOp final : OpConversionPattern<mlir::chlo::TopKOp> { |
|   using OpConversionPattern::OpConversionPattern; |
| |
|   // Returns failure() when the operand is not a ranked tensor; otherwise |
|   // replaces the op with the {values, indices} pair and returns success(). |
|   LogicalResult |
|   matchAndRewrite(mlir::chlo::TopKOp op, OpAdaptor /*adaptor*/, |
|                   ConversionPatternRewriter &rewriter) const override { |
|     auto operandType = dyn_cast<RankedTensorType>(op.getOperand().getType()); |
|     if (!operandType) |
|       return failure(); |
|     Location loc = op.getLoc(); |
|     int64_t operandRank = operandType.getRank(); |
|     int64_t lastDimIndex = operandRank - 1; |
|     int64_t lastDimSize = operandType.getDimSize(lastDimIndex); |
|     // k is clamped to the size of the last dimension when that size is static. |
|     int64_t lastDimResultSize = |
|         mlir::hlo::isDynamicDimSize(lastDimSize) |
|             ? static_cast<int64_t>(op.getK()) |
|             : std::min(static_cast<int64_t>(op.getK()), lastDimSize); |
|     // Selects between the static and the dynamic lowering below. |
|     const bool isDynamic = !operandType.hasStaticShape(); |
|     auto i32Type = rewriter.getIntegerType(32); |
|     Value opShapeValue, resultShapeValue; |
|     if (isDynamic) { |
|       // Assemble the operand shape as a 1-D i32 tensor, one element per |
|       // dimension. |
|       SmallVector<Value> sizesI32x1; |
|       for (int64_t i = 0; i < operandRank; ++i) { |
|         auto sizeI32 = rewriter.create<mlir::stablehlo::GetDimensionSizeOp>( |
|             loc, op.getOperand(), i); |
|         auto sizeI32x1 = rewriter.create<mlir::stablehlo::ReshapeOp>( |
|             loc, RankedTensorType::get({1}, i32Type), sizeI32); |
|         sizesI32x1.push_back(sizeI32x1); |
|       } |
|       opShapeValue = rewriter.create<mlir::stablehlo::ConcatenateOp>( |
|           loc, sizesI32x1, |
|           /*dimension=*/0); |
|       // The result shape is the operand shape with the last extent replaced |
|       // by the constant result size. |
|       auto lastDimI32 = rewriter.create<mlir::stablehlo::ConstantOp>( |
|           loc, |
|           rewriter.getI32IntegerAttr(static_cast<int32_t>(lastDimResultSize))); |
|       auto lastDimI32x1 = rewriter.create<mlir::stablehlo::ReshapeOp>( |
|           loc, RankedTensorType::get({1}, i32Type), lastDimI32); |
|       sizesI32x1.back() = lastDimI32x1; |
|       resultShapeValue = rewriter.create<mlir::stablehlo::ConcatenateOp>( |
|           loc, sizesI32x1, |
|           /*dimension=*/0); |
|     } |
| |
|     // Create an Iota op for indices. |
|     Type iotaType = RankedTensorType::get(operandType.getShape(), i32Type); |
|     Value iotaOp; |
|     if (isDynamic) { |
|       iotaOp = rewriter.create<mlir::stablehlo::DynamicIotaOp>( |
|           loc, iotaType, opShapeValue, |
|           rewriter.getI64IntegerAttr(lastDimIndex)); |
|     } else { |
|       iotaOp = rewriter.create<mlir::stablehlo::IotaOp>( |
|           loc, iotaType, rewriter.getI64IntegerAttr(lastDimIndex)); |
|     } |
| |
|     // Create the sort op. It takes two inputs, one for the original input, the |
|     // other for the indices. Use TOTALORDER comparison type instead of the |
|     // default comparison if the element type is of type float. |
|     Type elementType = operandType.getElementType(); |
|     mlir::stablehlo::SortOp sortOp = |
|         createSortOp(&rewriter, loc, {op.getOperand(), iotaOp}, |
|                      {elementType, i32Type}, lastDimIndex, |
|                      /*isStable=*/true, |
|                      /*direction=*/mlir::stablehlo::ComparisonDirection::GT); |
| |
|     // Get the sorted input and index tuple element. |
|     Value tupleFirstElement = sortOp.getResult(0); |
|     Value tupleSecondElement = sortOp.getResult(1); |
| |
|     SmallVector<int64_t> beginIndices(operandRank, 0); |
|     auto endIndices = llvm::to_vector(operandType.getShape()); |
|     endIndices.back() = lastDimResultSize; |
|     SmallVector<int64_t> strides(operandRank, 1); |
| |
|     // Get the slice for the top K elements. |
|     auto indicesTy = RankedTensorType::get(operandRank, rewriter.getI64Type()); |
|     Value values, indices; |
|     if (isDynamic) { |
|       Value startIndices = rewriter.create<mlir::stablehlo::ConstantOp>( |
|           loc, DenseIntElementsAttr::get(indicesTy, beginIndices)); |
|       // The limit indices are the i32 result shape converted to i64. |
|       Value lastIndices = rewriter.create<mlir::stablehlo::ConvertOp>( |
|           loc, resultShapeValue, rewriter.getI64Type()); |
|       Value stridesOp = rewriter.create<mlir::stablehlo::ConstantOp>( |
|           loc, DenseIntElementsAttr::get(indicesTy, strides)); |
| |
|       SmallVector<int64_t> resultShape = |
|           llvm::to_vector(operandType.getShape()); |
|       resultShape.back() = lastDimResultSize; |
|       RankedTensorType resultType = RankedTensorType::get( |
|           resultShape, elementType, operandType.getEncoding()); |
|       RankedTensorType indexResultType = |
|           RankedTensorType::get(resultShape, i32Type); |
| |
|       values = rewriter.create<mlir::stablehlo::RealDynamicSliceOp>( |
|           loc, resultType, tupleFirstElement, startIndices, lastIndices, |
|           stridesOp); |
|       indices = rewriter.create<mlir::stablehlo::RealDynamicSliceOp>( |
|           loc, indexResultType, tupleSecondElement, startIndices, lastIndices, |
|           stridesOp); |
|     } else { |
|       values = rewriter.create<mlir::stablehlo::SliceOp>( |
|           loc, tupleFirstElement, rewriter.getDenseI64ArrayAttr(beginIndices), |
|           rewriter.getDenseI64ArrayAttr(endIndices), |
|           rewriter.getDenseI64ArrayAttr(strides)); |
|       indices = rewriter.create<mlir::stablehlo::SliceOp>( |
|           loc, tupleSecondElement, rewriter.getDenseI64ArrayAttr(beginIndices), |
|           rewriter.getDenseI64ArrayAttr(endIndices), |
|           rewriter.getDenseI64ArrayAttr(strides)); |
|     } |
| |
|     rewriter.replaceOp(op, {values, indices}); |
|     return success(); |
|   } |
| }; |
| |
| // Conversion pattern that replaces chlo.zeta with the value produced by |
| // materializeZeta, routed through materializeWithUpcast with f32 as the |
| // minimum-precision type. |
| struct ConvertZetaOp final : OpConversionPattern<mlir::chlo::ZetaOp> { |
|   using OpConversionPattern::OpConversionPattern; |
| |
|   LogicalResult |
|   matchAndRewrite(mlir::chlo::ZetaOp op, OpAdaptor adaptor, |
|                   ConversionPatternRewriter &rewriter) const override { |
|     Value lowered = materializeWithUpcast( |
|         rewriter, op.getLoc(), adaptor.getOperands(), rewriter.getF32Type(), |
|         &materializeZeta); |
|     rewriter.replaceOp(op, lowered); |
|     return success(); |
|   } |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // Pass Definition. |
| //===----------------------------------------------------------------------===// |
| |
| // Pass that legalizes CHLO ops. It runs in two phases: a partial dialect |
| // conversion in which the CHLO dialect is illegal, followed by a greedy |
| // rewrite that applies canonicalization patterns to the ops produced by the |
| // first phase. |
| struct LegalizeChlo final : impl::LegalizeChloBase<LegalizeChlo> { |
|   // Declares the dialects whose ops this pass may create. |
|   void getDependentDialects(DialectRegistry &registry) const override { |
|     registry.insert<mlir::scf::SCFDialect, mlir::shape::ShapeDialect, |
|                     mlir::stablehlo::StablehloDialect, |
|                     mlir::tensor::TensorDialect>(); |
|   } |
| |
|   void runOnOperation() override { |
|     MLIRContext *ctx = &getContext(); |
|     { |
|       // Phase 1: convert away every CHLO op; signal failure if any remains. |
|       ConversionTarget conversionTarget(*ctx); |
|       RewritePatternSet conversionPatterns(ctx); |
|       conversionTarget.addIllegalDialect<chlo::ChloDialect>(); |
|       conversionTarget.addLegalDialect< |
|           mlir::stablehlo::StablehloDialect, mlir::arith::ArithDialect, |
|           mlir::shape::ShapeDialect, mlir::scf::SCFDialect, |
|           mlir::tensor::TensorDialect>(); |
| |
|       populateLegalizeChloPatterns(ctx, &conversionPatterns); |
|       if (failed(applyPartialConversion(getOperation(), conversionTarget, |
|                                         std::move(conversionPatterns)))) { |
|         return signalPassFailure(); |
|       } |
|     } |
| |
|     { |
|       // Phase 2: add canonicalization patterns to simplify produced ops from |
|       // other dialects. |
|       RewritePatternSet patterns(ctx); |
|       populateCanonicalizationPatterns(ctx, &patterns); |
|       mlir::shape::AssumingOp::getCanonicalizationPatterns(patterns, ctx); |
|       mlir::shape::ShapeOfOp::getCanonicalizationPatterns(patterns, ctx); |
|       mlir::shape::BroadcastOp::getCanonicalizationPatterns(patterns, ctx); |
|       mlir::shape::CstrBroadcastableOp::getCanonicalizationPatterns(patterns, |
|                                                                     ctx); |
|       mlir::tensor::CastOp::getCanonicalizationPatterns(patterns, ctx); |
|       if (failed(applyPatternsAndFoldGreedily(getOperation(), |
|                                               std::move(patterns)))) { |
|         return signalPassFailure(); |
|       } |
|     } |
|   } |
| }; |
| } // namespace |
| |
| namespace { |
| #include "compiler/plugins/input/StableHLO/Conversion/CHLODecompositionPatterns.h.inc" |
| } // end anonymous namespace |
| |
| namespace { |
| // Adds the patterns that handle CHLO broadcasting binary ops, plus the |
| // constant_like and select conversions, to `patterns`. |
| static void populateBroadcastingPatterns(MLIRContext *context, |
|                                          RewritePatternSet *patterns) { |
|   // Conversion templates are instantiated for conforming binary elementwise |
|   // ops, i.e. ops whose operand and result dtypes agree and which carry no |
|   // special attributes that need to be preserved. Two families are |
|   // registered, with the values 10 and 5 forwarded as extra constructor |
|   // arguments (presumably pattern benefits -- confirm against |
|   // populateForBroadcastingBinaryOp). |
|   populateForBroadcastingBinaryOp<ConvertTrivialNonBroadcastBinaryOp>( |
|       context, patterns, 10); |
|   populateForBroadcastingBinaryOp<ConvertRankedDynamicBroadcastBinaryOp>( |
|       context, patterns, 5); |
| |
|   patterns->add<ConvertConstantLikeOp>(context); |
|   patterns->add<ConvertSelectOp>(context); |
| } |
| |
| // Adds the patterns that decompose CHLO ops: first the generated patterns |
| // (populateWithGenerated), then the hand-written Convert*Op patterns. |
| static void populateDecompositionPatterns(MLIRContext *context, |
|                                           RewritePatternSet *patterns) { |
|   populateWithGenerated(*patterns); |
| |
|   // Registration order of the hand-written patterns is kept as-is. |
|   patterns->add<ConvertConstantOp>(context); |
|   patterns->add<ConvertBesselI1eOp, ConvertCoshOp, ConvertDigammaOp, |
|                 ConvertErfOp, ConvertErfcOp, ConvertErfInvOp, ConvertLgammaOp, |
|                 ConvertNextAfterOp, ConvertPolygammaOp, ConvertSinhOp>( |
|       context); |
|   patterns->add<ConvertTopKOp, ConvertZetaOp>(context); |
| } |
| } // namespace |
| |
| // Public entry point: fills `patterns` with everything needed to legalize |
| // CHLO, i.e. the broadcasting patterns followed by the decomposition patterns. |
| void populateLegalizeChloPatterns(MLIRContext *context, |
|                                   RewritePatternSet *patterns) { |
|   // Broadcasting patterns are registered before the decompositions. |
|   populateBroadcastingPatterns(context, patterns); |
|   populateDecompositionPatterns(context, patterns); |
| } |
| } // namespace mlir::iree_compiler::stablehlo |