// Copyright 2020 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

// Implements logic for lowering CHLO ops to StableHLO and Shape dialect ops,
// taking care of CHLO's broadcasting semantics.

#include "compiler/plugins/input/StableHLO/Conversion/Passes.h"
#include "compiler/plugins/input/StableHLO/Conversion/Preprocessing/Rewriters.h"
#include "compiler/plugins/input/StableHLO/Conversion/Rewriters.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "stablehlo/dialect/BroadcastUtils.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"

namespace mlir::iree_compiler::stablehlo {

#define GEN_PASS_DEF_LEGALIZECHLO
#include "compiler/plugins/input/StableHLO/Conversion/Passes.h.inc"

namespace {

//===----------------------------------------------------------------------===//
// Helpers.
//===----------------------------------------------------------------------===//

template <typename FromOpTy, typename ToOpTy>
struct HloNaryElementwiseAdaptor {
  static ToOpTy createOp(FromOpTy fromOp, Type resultType,
                         ValueRange broadcastedOperands, OpBuilder &builder) {
    return builder.create<ToOpTy>(fromOp.getLoc(), resultType,
                                  broadcastedOperands);
  }
};

// Translates a CHLO comparison direction into the StableHLO enumerator of the
// same name. Yields an empty optional when `value` matches none of the
// enumerators handled below.
static std::optional<mlir::stablehlo::ComparisonDirection>
toStableHloComparisonDirection(mlir::chlo::ComparisonDirection value) {
  using ChloDir = mlir::chlo::ComparisonDirection;
  using HloDir = mlir::stablehlo::ComparisonDirection;

  if (value == ChloDir::EQ)
    return HloDir::EQ;
  if (value == ChloDir::NE)
    return HloDir::NE;
  if (value == ChloDir::GE)
    return HloDir::GE;
  if (value == ChloDir::GT)
    return HloDir::GT;
  if (value == ChloDir::LE)
    return HloDir::LE;
  if (value == ChloDir::LT)
    return HloDir::LT;
  return std::nullopt;
}

// Translates a CHLO comparison type into the StableHLO enumerator of the same
// name. Yields an empty optional when `value` matches none of the enumerators
// handled below.
static std::optional<mlir::stablehlo::ComparisonType>
toStableHloComparisonType(mlir::chlo::ComparisonType value) {
  using ChloTy = mlir::chlo::ComparisonType;
  using HloTy = mlir::stablehlo::ComparisonType;

  if (value == ChloTy::NOTYPE)
    return HloTy::NOTYPE;
  if (value == ChloTy::FLOAT)
    return HloTy::FLOAT;
  if (value == ChloTy::TOTALORDER)
    return HloTy::TOTALORDER;
  if (value == ChloTy::SIGNED)
    return HloTy::SIGNED;
  if (value == ChloTy::UNSIGNED)
    return HloTy::UNSIGNED;
  return std::nullopt;
}

// Adaptor that builds a stablehlo.compare from a chlo.broadcast_compare.
// Returns a null op when the comparison direction or type cannot be translated
// to its StableHLO counterpart.
struct HloCompareAdaptor {
  static mlir::stablehlo::CompareOp
  createOp(mlir::chlo::BroadcastCompareOp fromOp, Type resultType,
           ValueRange broadcastedOperands, OpBuilder &builder) {
    auto direction =
        toStableHloComparisonDirection(fromOp.getComparisonDirection());
    if (!direction.has_value())
      return nullptr;

    // The compare type is optional on the source op; NOTYPE is only used to
    // validate the translation when it is absent.
    auto explicitChloType = fromOp.getCompareType();
    auto compareType = toStableHloComparisonType(
        explicitChloType.value_or(mlir::chlo::ComparisonType::NOTYPE));
    if (!compareType.has_value())
      return nullptr;

    // Only attach a compare-type attribute if the source op carried one;
    // otherwise the attribute stays null.
    mlir::stablehlo::ComparisonTypeAttr compareTypeAttr;
    if (explicitChloType) {
      compareTypeAttr = mlir::stablehlo::ComparisonTypeAttr::get(
          builder.getContext(), *compareType);
    }

    Value lhs = broadcastedOperands[0];
    Value rhs = broadcastedOperands[1];
    return builder.create<mlir::stablehlo::CompareOp>(
        fromOp.getLoc(), resultType, lhs, rhs, *direction, compareTypeAttr);
  }
};

// Populate a pattern for each Broadcasting Chlo op. This requires the pattern
// to take a ChloOpTy, NonBroadcastingOpTy, and an Adaptor as templated values.
//
// `Pattern` is instantiated once per (broadcasting op, non-broadcasting op)
// pair listed below and added to `patterns`; `context` and `args...` are
// passed to each pattern's constructor.
template <template <typename, typename, typename> typename Pattern,
          typename... ConstructorArgs>
static void populateForBroadcastingBinaryOp(MLIRContext *context,
                                            RewritePatternSet *patterns,
                                            ConstructorArgs &&...args) {
// Registers `Pattern` for one op pair using the generic n-ary adaptor.
#define POPULATE_BCAST(ChloOp, HloOp)                                          \
  patterns                                                                     \
      ->add<Pattern<ChloOp, HloOp, HloNaryElementwiseAdaptor<ChloOp, HloOp>>>( \
          context, args...);

  // Note that a few entries map onto non-broadcasting CHLO ops (NextAfter,
  // Polygamma, Zeta) rather than StableHLO ops.
  POPULATE_BCAST(mlir::chlo::BroadcastAddOp, mlir::stablehlo::AddOp);
  POPULATE_BCAST(mlir::chlo::BroadcastAndOp, mlir::stablehlo::AndOp);
  POPULATE_BCAST(mlir::chlo::BroadcastAtan2Op, mlir::stablehlo::Atan2Op);
  POPULATE_BCAST(mlir::chlo::BroadcastComplexOp, mlir::stablehlo::ComplexOp);
  POPULATE_BCAST(mlir::chlo::BroadcastDivOp, mlir::stablehlo::DivOp);
  POPULATE_BCAST(mlir::chlo::BroadcastMaxOp, mlir::stablehlo::MaxOp);
  POPULATE_BCAST(mlir::chlo::BroadcastMinOp, mlir::stablehlo::MinOp);
  POPULATE_BCAST(mlir::chlo::BroadcastMulOp, mlir::stablehlo::MulOp);
  POPULATE_BCAST(mlir::chlo::BroadcastNextAfterOp, mlir::chlo::NextAfterOp);
  POPULATE_BCAST(mlir::chlo::BroadcastOrOp, mlir::stablehlo::OrOp);
  POPULATE_BCAST(mlir::chlo::BroadcastPolygammaOp, mlir::chlo::PolygammaOp);
  POPULATE_BCAST(mlir::chlo::BroadcastPowOp, mlir::stablehlo::PowOp);
  POPULATE_BCAST(mlir::chlo::BroadcastRemOp, mlir::stablehlo::RemOp);
  POPULATE_BCAST(mlir::chlo::BroadcastShiftLeftOp,
                 mlir::stablehlo::ShiftLeftOp);
  POPULATE_BCAST(mlir::chlo::BroadcastShiftRightArithmeticOp,
                 mlir::stablehlo::ShiftRightArithmeticOp);
  POPULATE_BCAST(mlir::chlo::BroadcastShiftRightLogicalOp,
                 mlir::stablehlo::ShiftRightLogicalOp);
  POPULATE_BCAST(mlir::chlo::BroadcastSubOp, mlir::stablehlo::SubtractOp);
  POPULATE_BCAST(mlir::chlo::BroadcastXorOp, mlir::stablehlo::XorOp);
  POPULATE_BCAST(mlir::chlo::BroadcastZetaOp, mlir::chlo::ZetaOp);

#undef POPULATE_BCAST

  // Broadcasting ops requiring special construction.
  // The compare op needs HloCompareAdaptor to translate its comparison
  // direction/type attributes, so it cannot go through POPULATE_BCAST.
  patterns->add<Pattern<mlir::chlo::BroadcastCompareOp,
                        mlir::stablehlo::CompareOp, HloCompareAdaptor>>(
      context, args...);
}

// Creates a chlo.constant_like op at `loc` that splats `constant` to the shape
// of `val`. The attribute kind follows the element type of `val`: integer,
// float, or complex (with `constant` as the real part and a zero imaginary
// part). Any other element type is treated as unreachable.
template <typename T>
static Value getConstantLike(OpBuilder &b, Location loc, T constant,
                             Value val) {
  Type elementTy = getElementTypeOrSelf(val.getType());

  Attribute splatAttr;
  if (isa<IntegerType>(elementTy)) {
    splatAttr = b.getIntegerAttr(elementTy, constant);
  } else if (isa<FloatType>(elementTy)) {
    splatAttr = b.getFloatAttr(elementTy, constant);
  } else if (auto complexTy = dyn_cast<ComplexType>(elementTy)) {
    splatAttr = complex::NumberAttr::get(complexTy, constant, 0);
  } else {
    llvm_unreachable("unhandled element type");
  }

  return b.create<mlir::chlo::ConstantLikeOp>(loc, cast<TypedAttr>(splatAttr),
                                              val);
}

// APFloat overload: creates a chlo.constant_like op at `loc` whose value is a
// float attribute built from `constant` using the element type of `val`.
static Value getConstantLike(OpBuilder &b, Location loc,
                             const APFloat &constant, Value val) {
  Type elementTy = getElementTypeOrSelf(val.getType());
  auto floatAttr = b.getFloatAttr(elementTy, constant);
  return b.create<mlir::chlo::ConstantLikeOp>(loc, floatAttr, val);
}

// Creates a constant shaped like `val` holding the largest value returned by
// APFloat::getLargest for the float element type of `val`. The element type
// must be a FloatType (enforced by the cast).
static Value getConstantLikeMaxFiniteValue(OpBuilder &b, Location loc,
                                           Value val) {
  auto floatTy = cast<FloatType>(getElementTypeOrSelf(val.getType()));
  const llvm::APFloat largest =
      llvm::APFloat::getLargest(floatTy.getFloatSemantics());
  return getConstantLike(b, loc, largest, val);
}

// Creates a constant shaped like `val` holding an infinity in the float
// element type of `val`; `negative` is forwarded to APFloat::getInf to choose
// the sign. The element type must be a FloatType (enforced by the cast).
static Value getConstantLikeInfValue(OpBuilder &b, Location loc, Value val,
                                     bool negative) {
  auto floatTy = cast<FloatType>(getElementTypeOrSelf(val.getType()));
  const llvm::APFloat infinity =
      llvm::APFloat::getInf(floatTy.getFloatSemantics(), negative);
  return getConstantLike(b, loc, infinity, val);
}

//===----------------------------------------------------------------------===//
// Broadcasting Patterns.
//===----------------------------------------------------------------------===//

// Converts binary ops that statically are determined to not broadcast directly
// to the corresponding stablehlo non-broadcasting op.
//
// The pattern only matches when both (converted) operands are ranked tensors
// with identical, fully static shapes. `Adaptor::createOp` builds the
// replacement; if it reports failure by returning a null op, the pattern fails
// to match instead of dereferencing the null op.
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertTrivialNonBroadcastBinaryOp final
    : OpConversionPattern<ChloOpTy> {
  using OpConversionPattern<ChloOpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ChloOpTy op, typename ChloOpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite for statically determinable non-broadcasting cases.
    auto lhsType = dyn_cast<RankedTensorType>(adaptor.getLhs().getType());
    auto rhsType = dyn_cast<RankedTensorType>(adaptor.getRhs().getType());
    if (!lhsType || !rhsType)
      return failure();

    // Requires rank broadcast.
    if (lhsType.getRank() != rhsType.getRank())
      return failure();

    // Any dynamic dimension may require broadcasting and requires more
    // analysis.
    if (!lhsType.hasStaticShape() || !rhsType.hasStaticShape()) {
      return failure();
    }

    if (!llvm::equal(lhsType.getShape(), rhsType.getShape())) {
      return failure();
    }

    // The adaptor may signal failure with a null op (e.g. HloCompareAdaptor
    // when an attribute cannot be translated). No IR has been modified at this
    // point, so simply report a match failure.
    auto newOp = Adaptor::createOp(op, op.getResult().getType(),
                                   adaptor.getOperands(), rewriter);
    if (!newOp) {
      return rewriter.notifyMatchFailure(
          op, "failed to create the non-broadcasting op");
    }

    rewriter.replaceOp(op, newOp->getResults());
    return success();
  }
};

// Converts a binary op with ranked broadcasting operands to explicitly
// broadcast and invoke the corresponding stablehlo non-broadcasting op.
// Note that dynamic broadcasting supported by this pattern is only valid for
// "numpy" broadcasting semantics as defined here:
//   https://docs.scipy.org/doc/numpy/reference/ufuncs.html
// Specifically, this includes the following cases:
//   - Same rank broadcast (operands have the same static rank).
//   - Different-rank broadcast, either without a broadcast_dims attribute or
//     with the broadcast_dims attribute set to map to a prefix padding.
//   - Legal combinations of degenerate (1-dim) implicit broadcasting.
// The restriction on broadcast_dims derives from the definition of the
// `shape.broadcast` op, which only supports prefix-padding.
//
// The generated IR is a `shape.assuming` region guarded by a
// `shape.cstr_broadcastable` witness; inside it both operands are broadcast
// with `stablehlo.dynamic_broadcast_in_dim` and fed to the op built by
// `Adaptor::createOp`. If the adaptor reports failure by returning a null op,
// the pattern fails to match instead of dereferencing the null op.
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertRankedDynamicBroadcastBinaryOp final
    : OpConversionPattern<ChloOpTy> {
  using OpConversionPattern<ChloOpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ChloOpTy op, typename ChloOpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only support ranked operands.
    Value lhs = adaptor.getLhs();
    Value rhs = adaptor.getRhs();
    auto lhsType = dyn_cast<RankedTensorType>(lhs.getType());
    auto rhsType = dyn_cast<RankedTensorType>(rhs.getType());
    auto resultType = dyn_cast<RankedTensorType>(op.getResult().getType());
    if (!lhsType || !rhsType || !resultType)
      return failure();

    // Check for "numpy"-style rank broadcast.
    auto broadcastDimensions = op.getBroadcastDimensions();
    if (broadcastDimensions && !mlir::hlo::isLegalNumpyRankedBroadcast(
                                   lhs, rhs, *broadcastDimensions)) {
      // Note: It is unclear whether the general specification of explicit
      // broadcast_dimensions on binary ops is a feature we want to carry
      // forward. While it can technically be implemented for ranked-dynamic,
      // it is incompatible with unranked inputs. If this warning is emitted
      // in real programs, it is an indication that the feature should be
      // implemented versus just falling back on the more standard definition
      // of numpy-like prefix-padding.
      op.emitWarning() << "unsupported non prefix-padded dynamic rank "
                       << "broadcast_dimensions = " << *broadcastDimensions;
      return failure();
    }

    // Compute result shape.
    Location loc = op.getLoc();

    // Insert a constraint on the shapes being broadcastable and insert all
    // future code into an assuming block reliant on the constraint.
    Value lhsShape = rewriter.create<shape::ShapeOfOp>(loc, lhs);
    Value rhsShape = rewriter.create<shape::ShapeOfOp>(loc, rhs);
    auto broadcastableCstr =
        rewriter.create<shape::CstrBroadcastableOp>(loc, lhsShape, rhsShape);
    auto assumingOp = rewriter.create<shape::AssumingOp>(
        loc, ArrayRef<Type>{resultType}, broadcastableCstr.getResult());

    // The guard restores the insertion point when this function returns; all
    // ops created below land inside the assuming region.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.createBlock(&assumingOp.getDoRegion());

    int64_t resultRank = std::max(lhsType.getRank(), rhsType.getRank());
    Value resultExtents =
        hlo::computeBinaryElementwiseBroadcastingResultExtents(loc, lhs, rhs,
                                                               rewriter);

    // Note that we unconditionally emit DynamicBroadcastInDim ops and let
    // downstream canonicalizations fold them away if possible. This is
    // because, in the dynamic case, there are many corner cases regarding
    // when it is safe to omit, and some of them require analysis to prove
    // properly.
    // Each operand's dimensions map onto the trailing dimensions of the
    // result (prefix padding).
    auto lhsBroadcastDimensions = llvm::to_vector(
        llvm::seq<int64_t>(resultRank - lhsType.getRank(), resultRank));
    Value broadcastedLhs =
        rewriter.create<mlir::stablehlo::DynamicBroadcastInDimOp>(
            loc,
            RankedTensorType::get(resultType.getShape(),
                                  lhsType.getElementType()),
            lhs, resultExtents,
            rewriter.getDenseI64ArrayAttr(lhsBroadcastDimensions));
    auto rhsBroadcastDimensions = llvm::to_vector(
        llvm::seq<int64_t>(resultRank - rhsType.getRank(), resultRank));
    Value broadcastedRhs =
        rewriter.create<mlir::stablehlo::DynamicBroadcastInDimOp>(
            loc,
            RankedTensorType::get(resultType.getShape(),
                                  rhsType.getElementType()),
            rhs, resultExtents,
            rewriter.getDenseI64ArrayAttr(rhsBroadcastDimensions));

    // And generate the final non-broadcasted binary op.
    // The adaptor may signal failure with a null op (e.g. HloCompareAdaptor
    // when an attribute cannot be translated). Report a match failure rather
    // than converting the null op to a Value; the conversion rewriter is
    // expected to discard the ops created above when the pattern fails.
    auto newOp = Adaptor::createOp(
        op, resultType, {broadcastedLhs, broadcastedRhs}, rewriter);
    if (!newOp) {
      return rewriter.notifyMatchFailure(
          op, "failed to create the non-broadcasting op");
    }
    Value finalResult = newOp->getResult(0);
    rewriter.create<shape::AssumingYieldOp>(loc, finalResult);
    rewriter.replaceOp(op, {assumingOp.getResult(0)});
    return success();
  }
};

// Lowers chlo.constant_like. A statically shaped result becomes a single
// stablehlo.constant; a ranked but dynamically shaped result becomes a
// constant built from the op's value attribute that is broadcast to the
// runtime shape of the operand. Unranked results are rejected.
struct ConvertConstantLikeOp final
    : OpConversionPattern<mlir::chlo::ConstantLikeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(mlir::chlo::ConstantLikeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resultTy = cast<ShapedType>(op.getType());

    // Unranked uses are not supported.
    if (!resultTy.hasRank())
      return failure();

    if (!resultTy.hasStaticShape()) {
      // Dynamic shape: materialize the value once and broadcast it to the
      // shape of the operand. The empty broadcast-dimensions list means no
      // input dimension is mapped onto the result.
      Location loc = op.getLoc();
      Value scalar =
          rewriter.create<mlir::stablehlo::ConstantOp>(loc, op.getValue());
      Value operandShape =
          rewriter.create<shape::ShapeOfOp>(loc, adaptor.getOperand());
      rewriter.replaceOpWithNewOp<mlir::stablehlo::DynamicBroadcastInDimOp>(
          op, resultTy, scalar, operandShape,
          rewriter.getDenseI64ArrayAttr({}));
      return success();
    }

    // Static shape: emit a dense constant of the result type directly.
    auto complexAttr = dyn_cast<mlir::complex::NumberAttr>(op.getValue());
    auto denseAttr = DenseElementsAttr::get(
        resultTy, complexAttr ? complexAttr : op.getValue());
    rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, denseAttr);
    return success();
  }
};

// Lowers chlo.broadcast_select on ranked tensors. The three operand shapes are
// constrained to be broadcastable; inside the resulting shape.assuming region
// the operands are explicitly broadcast to the common shape and combined with
// stablehlo.select.
struct ConvertSelectOp final
    : OpConversionPattern<mlir::chlo::BroadcastSelectOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(mlir::chlo::BroadcastSelectOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value pred = adaptor.getPred();
    Value onTrue = adaptor.getOnTrue();
    Value onFalse = adaptor.getOnFalse();

    // Every operand, and the result, has to be a ranked tensor.
    auto predType = dyn_cast<RankedTensorType>(pred.getType());
    auto onTrueType = dyn_cast<RankedTensorType>(onTrue.getType());
    auto onFalseType = dyn_cast<RankedTensorType>(onFalse.getType());
    auto resultType = dyn_cast<RankedTensorType>(op.getResult().getType());
    bool allRanked = predType && onTrueType && onFalseType && resultType;
    if (!allRanked)
      return failure();

    Location loc = op.getLoc();
    SmallVector<Value, 3> operandShapes;
    operandShapes.push_back(
        rewriter.createOrFold<shape::ShapeOfOp>(loc, pred));
    operandShapes.push_back(
        rewriter.createOrFold<shape::ShapeOfOp>(loc, onTrue));
    operandShapes.push_back(
        rewriter.createOrFold<shape::ShapeOfOp>(loc, onFalse));
    int64_t resultRank = std::max(
        {predType.getRank(), onTrueType.getRank(), onFalseType.getRank()});

    // Everything after the witness is emitted into the assuming region.
    Value witness = rewriter.createOrFold<shape::CstrBroadcastableOp>(
        loc, ValueRange(operandShapes));
    auto assumingOp = rewriter.create<shape::AssumingOp>(
        loc, ArrayRef<Type>{resultType}, witness);

    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.createBlock(&assumingOp.getDoRegion());

    // Broadcast the operand shapes together, then cast the extents to a
    // statically sized index tensor of length `resultRank`.
    Value resultExtents = rewriter.createOrFold<shape::BroadcastOp>(
        loc, shape::getExtentTensorType(op.getContext()),
        ValueRange(operandShapes),
        /*error=*/nullptr);
    auto extentsType =
        RankedTensorType::get({resultRank}, rewriter.getIndexType());
    resultExtents =
        rewriter.createOrFold<tensor::CastOp>(loc, extentsType, resultExtents);

    // Broadcasts `operand` to the result shape, keeping its own element type
    // and mapping its dimensions onto the trailing result dimensions.
    auto broadcastToResult = [&](Value operand,
                                 RankedTensorType operandType) -> Value {
      auto dimensions = llvm::to_vector(
          llvm::seq<int64_t>(resultRank - operandType.getRank(), resultRank));
      auto broadcastedType = RankedTensorType::get(
          resultType.getShape(), operandType.getElementType());
      return rewriter.create<mlir::stablehlo::DynamicBroadcastInDimOp>(
          loc, broadcastedType, operand, resultExtents,
          rewriter.getDenseI64ArrayAttr(dimensions));
    };

    // Pred has an implicit broadcast for scalars, so use that when convenient.
    Value broadcastedPred =
        predType.getRank() > 0 ? broadcastToResult(pred, predType) : pred;
    Value broadcastedOnTrue = broadcastToResult(onTrue, onTrueType);
    Value broadcastedOnFalse = broadcastToResult(onFalse, onFalseType);

    // And generate the final non-broadcasted ternary op.
    Value selected = rewriter.create<mlir::stablehlo::SelectOp>(
        loc, resultType, broadcastedPred, broadcastedOnTrue,
        broadcastedOnFalse);
    rewriter.create<shape::AssumingYieldOp>(loc, selected);
    rewriter.replaceOp(op, {assumingOp.getResult(0)});
    return success();
  }
};

// Lowers chlo.dynamic_reshape to a shape.assuming region guarded by a
// stablehlo.cstr_reshapable witness. Inside the region the requested shape is
// passed through stablehlo.compute_reshape_shape and the operand is reshaped
// with stablehlo.dynamic_reshape.
struct ConvertDynamicReshapeOp final
    : OpRewritePattern<mlir::chlo::DynamicReshapeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(mlir::chlo::DynamicReshapeOp op,
                                PatternRewriter &rewriter) const override {
    TypedValue<TensorType> input = op.getOperand();
    TypedValue<RankedTensorType> requestedShape = op.getOutputShape();
    auto requestedShapeTy = cast<ShapedType>(requestedShape.getType());
    auto reshapedTy = cast<ShapedType>(op.getType());

    // Witness: the element count of the input must be compatible with the
    // requested shape.
    Location loc = op.getLoc();
    Value inputExtents = rewriter.create<shape::ShapeOfOp>(loc, input);
    Value elementCount =
        rewriter.create<shape::NumElementsOp>(loc, inputExtents);
    Value reshapable = rewriter.create<mlir::stablehlo::CstrReshapableOp>(
        loc, elementCount, requestedShape);

    // Body of the assuming region; yields the reshaped tensor.
    auto buildAssumingBody = [&](OpBuilder &b, Location l) {
      Value resolvedShape = b.create<mlir::stablehlo::ComputeReshapeShapeOp>(
          l, requestedShapeTy, elementCount, requestedShape);
      Value reshaped = b.create<mlir::stablehlo::DynamicReshapeOp>(
          l, reshapedTy, input, resolvedShape);
      return SmallVector<Value>{reshaped};
    };
    rewriter.replaceOpWithNewOp<shape::AssumingOp>(op, reshapable,
                                                   buildAssumingBody);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// Decomposition Patterns.
//===----------------------------------------------------------------------===//

// Replaces chlo.constant with a stablehlo.constant carrying the same value
// attribute.
struct ConvertConstantOp final : OpConversionPattern<mlir::chlo::ConstantOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(mlir::chlo::ConstantOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto hloConstant = rewriter.create<mlir::stablehlo::ConstantOp>(
        op.getLoc(), op.getValue());
    rewriter.replaceOp(op, hloConstant->getResults());
    return success();
  }
};

// Emits StableHLO ops that evaluate a Chebyshev series at `x` using the
// three-term recurrence
//   b0' = x * b0 - b1 + c
// applied once per coefficient (in the order given), and returns
// 0.5 * (b0 - b2) from the final state. All ops use the type of `x`.
template <typename FTy>
static Value
materializeChebyshevPolynomialApproximation(ConversionPatternRewriter &rewriter,
                                            Location loc, Value x,
                                            ArrayRef<FTy> coefficients) {
  Type ty = x.getType();

  // Recurrence state, all starting at zero.
  Value current = getConstantLike(rewriter, loc, 0.0, x);
  Value previous = getConstantLike(rewriter, loc, 0.0, x);
  Value beforePrevious = getConstantLike(rewriter, loc, 0.0, x);

  for (size_t i = 0, e = coefficients.size(); i != e; ++i) {
    beforePrevious = previous;
    previous = current;
    Value scaled =
        rewriter.create<mlir::stablehlo::MulOp>(loc, ty, x, previous);
    Value reduced = rewriter.create<mlir::stablehlo::SubtractOp>(
        loc, ty, scaled, beforePrevious);
    Value coefficient = getConstantLike(rewriter, loc, coefficients[i], x);
    current = rewriter.create<mlir::stablehlo::AddOp>(loc, ty, reduced,
                                                      coefficient);
  }

  Value difference = rewriter.create<mlir::stablehlo::SubtractOp>(
      loc, ty, current, beforePrevious);
  Value half = getConstantLike(rewriter, loc, 0.5, x);
  Value halved =
      rewriter.create<mlir::stablehlo::MulOp>(loc, ty, difference, half);
  return halved;
}

// Emits the piecewise approximation of bessel_i1e at `x`:
//   |x| <= 8:  |x| * Chebyshev(kI1eCoeffsA, |x| / 2 - 2)
//   |x| >  8:  Chebyshev(kI1eCoeffsB, 32 / |x| - 2) / sqrt(|x|)
// Both branches are materialized and combined with a select; the result is
// multiplied by sign(x).
template <typename FTy>
static Value materializeBesselI1eApproximation(
    ConversionPatternRewriter &rewriter, Location loc, Value x,
    ArrayRef<FTy> kI1eCoeffsA, ArrayRef<FTy> kI1eCoeffsB) {
  Value absX = rewriter.create<mlir::stablehlo::AbsOp>(loc, x);
  Value cHalf = getConstantLike(rewriter, loc, 0.5, x);
  Value cTwo = getConstantLike(rewriter, loc, 2.0, x);
  Value cThirtyTwo = getConstantLike(rewriter, loc, 32.0, x);
  Value cEight = getConstantLike(rewriter, loc, 8.0, x);

  // Branch used when |x| <= 8.
  Value halfAbsX = rewriter.create<mlir::stablehlo::MulOp>(loc, cHalf, absX);
  Value smallArg =
      rewriter.create<mlir::stablehlo::SubtractOp>(loc, halfAbsX, cTwo);
  Value smallPoly = materializeChebyshevPolynomialApproximation(
      rewriter, loc, smallArg, kI1eCoeffsA);
  Value smallBranch =
      rewriter.create<mlir::stablehlo::MulOp>(loc, absX, smallPoly);

  // Branch used when |x| > 8.
  Value scaledInverse =
      rewriter.create<mlir::stablehlo::DivOp>(loc, cThirtyTwo, absX);
  Value largeArg =
      rewriter.create<mlir::stablehlo::SubtractOp>(loc, scaledInverse, cTwo);
  Value largePoly = materializeChebyshevPolynomialApproximation(
      rewriter, loc, largeArg, kI1eCoeffsB);
  Value sqrtAbsX = rewriter.create<mlir::stablehlo::SqrtOp>(loc, absX);
  Value largeBranch =
      rewriter.create<mlir::stablehlo::DivOp>(loc, largePoly, sqrtAbsX);

  // Pick the branch per element and restore the sign of the input.
  Value isSmall = rewriter.create<mlir::stablehlo::CompareOp>(
      loc, absX, cEight, mlir::stablehlo::ComparisonDirection::LE);
  Value magnitude = rewriter.create<mlir::stablehlo::SelectOp>(
      loc, isSmall, smallBranch, largeBranch);
  Value signX = rewriter.create<mlir::stablehlo::SignOp>(loc, x);
  return rewriter.create<mlir::stablehlo::MulOp>(loc, signX, magnitude);
}

// Materializes the bessel_i1e approximation for an f32 operand.
//
// `args.front()` is the operand; its element type must be f32 (asserted). The
// two coefficient tables are forwarded to materializeBesselI1eApproximation,
// which uses the first for |x| <= 8 and the second for |x| > 8.
//
// Declared `static` for consistency with the f64 variant and the other
// file-local helpers.
static Value
materializeBesselI1eApproximationF32(ConversionPatternRewriter &rewriter,
                                     Location loc, ValueRange args) {
  Value x = args.front();
  assert(cast<ShapedType>(x.getType()).getElementType().isF32() &&
         "expect f32 element type");
  const float kI1eCoeffsA[] = {
      9.38153738649577178388E-9f, -4.44505912879632808065E-8f,
      2.00329475355213526229E-7f, -8.56872026469545474066E-7f,
      3.47025130813767847674E-6f, -1.32731636560394358279E-5f,
      4.78156510755005422638E-5f, -1.61760815825896745588E-4f,
      5.12285956168575772895E-4f, -1.51357245063125314899E-3f,
      4.15642294431288815669E-3f, -1.05640848946261981558E-2f,
      2.47264490306265168283E-2f, -5.29459812080949914269E-2f,
      1.02643658689847095384E-1f, -1.76416518357834055153E-1f,
      2.52587186443633654823E-1f};

  const float kI1eCoeffsB[] = {
      -3.83538038596423702205E-9f, -2.63146884688951950684E-8f,
      -2.51223623787020892529E-7f, -3.88256480887769039346E-6f,
      -1.10588938762623716291E-4f, -9.76109749136146840777E-3f,
      7.78576235018280120474E-1f};

  return materializeBesselI1eApproximation<float>(rewriter, loc, x, kI1eCoeffsA,
                                                  kI1eCoeffsB);
}

// Materializes the bessel_i1e approximation for an f64 operand.
//
// `args.front()` is the operand; its element type must be f64 (asserted). The
// two coefficient tables are forwarded to materializeBesselI1eApproximation,
// which uses the first for |x| <= 8 and the second for |x| > 8.
static Value
materializeBesselI1eApproximationF64(ConversionPatternRewriter &rewriter,
                                     Location loc, ValueRange args) {
  Value x = args.front();
  assert(cast<ShapedType>(x.getType()).getElementType().isF64() &&
         "expect f64 element type");

  // Coefficients for the |x| <= 8 branch.
  const double kI1eCoeffsA[] = {
      2.77791411276104639959E-18, -2.11142121435816608115E-17,
      1.55363195773620046921E-16, -1.10559694773538630805E-15,
      7.60068429473540693410E-15, -5.04218550472791168711E-14,
      3.22379336594557470981E-13, -1.98397439776494371520E-12,
      1.17361862988909016308E-11, -6.66348972350202774223E-11,
      3.62559028155211703701E-10, -1.88724975172282928790E-9,
      9.38153738649577178388E-9,  -4.44505912879632808065E-8,
      2.00329475355213526229E-7,  -8.56872026469545474066E-7,
      3.47025130813767847674E-6,  -1.32731636560394358279E-5,
      4.78156510755005422638E-5,  -1.61760815825896745588E-4,
      5.12285956168575772895E-4,  -1.51357245063125314899E-3,
      4.15642294431288815669E-3,  -1.05640848946261981558E-2,
      2.47264490306265168283E-2,  -5.29459812080949914269E-2,
      1.02643658689847095384E-1,  -1.76416518357834055153E-1,
      2.52587186443633654823E-1};

  // Coefficients for the |x| > 8 branch.
  const double kI1eCoeffsB[] = {
      7.51729631084210481353E-18,  4.41434832307170791151E-18,
      -4.65030536848935832153E-17, -3.20952592199342395980E-17,
      2.96262899764595013876E-16,  3.30820231092092828324E-16,
      -1.88035477551078244854E-15, -3.81440307243700780478E-15,
      1.04202769841288027642E-14,  4.27244001671195135429E-14,
      -2.10154184277266431302E-14, -4.08355111109219731823E-13,
      -7.19855177624590851209E-13, 2.03562854414708950722E-12,
      1.41258074366137813316E-11,  3.25260358301548823856E-11,
      -1.89749581235054123450E-11, -5.58974346219658380687E-10,
      -3.83538038596423702205E-9,  -2.63146884688951950684E-8,
      -2.51223623787020892529E-7,  -3.88256480887769039346E-6,
      -1.10588938762623716291E-4,  -9.76109749136146840777E-3,
      7.78576235018280120474E-1};

  return materializeBesselI1eApproximation<double>(rewriter, loc, x,
                                                   kI1eCoeffsA, kI1eCoeffsB);
}

// Runs `callback` on `args`, temporarily widening the computation when the
// element type of the first argument is a float narrower than
// `minPrecisionTy`. In that case every argument is converted to
// `minPrecisionTy` before the call and the result is converted back to the
// original element type; otherwise `callback` is invoked on `args` as-is.
static Value materializeWithUpcast(ConversionPatternRewriter &rewriter,
                                   Location loc, ValueRange args,
                                   FloatType minPrecisionTy,
                                   Value callback(ConversionPatternRewriter &,
                                                  Location, ValueRange)) {
  Type originalElementTy = getElementTypeOrSelf(args.front().getType());

  bool needsUpcast = false;
  if (auto originalFloatTy = dyn_cast<FloatType>(originalElementTy))
    needsUpcast = originalFloatTy.getWidth() < minPrecisionTy.getWidth();

  // Fast path: precision is already sufficient (or the type is not a float).
  if (!needsUpcast)
    return callback(rewriter, loc, args);

  // Widen every argument, compute, then narrow the result back.
  llvm::SmallVector<Value, 2> widenedArgs;
  for (Value arg : args) {
    widenedArgs.push_back(
        rewriter.create<mlir::stablehlo::ConvertOp>(loc, arg, minPrecisionTy));
  }
  Value widenedResult = callback(rewriter, loc, widenedArgs);
  return rewriter.create<mlir::stablehlo::ConvertOp>(loc, widenedResult,
                                                     originalElementTy);
}

// Lowers chlo.bessel_i1e for f64, f32, f16 and bf16 element types. f64 uses
// the f64 approximation directly; the remaining types go through the f32
// approximation, with materializeWithUpcast widening narrower floats to f32.
struct ConvertBesselI1eOp final : OpConversionPattern<mlir::chlo::BesselI1eOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(mlir::chlo::BesselI1eOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value operand = adaptor.getOperand();
    Type elementTy = cast<ShapedType>(operand.getType()).getElementType();

    // For now, we support only f64, f32, f16 and bf16.
    // See https://www.tensorflow.org/api_docs/python/tf/math/bessel_i1e
    bool isSupported = elementTy.isF64() || elementTy.isF32() ||
                       elementTy.isF16() || elementTy.isBF16();
    if (!isSupported)
      return failure();

    Location loc = op.getLoc();
    Value replacement;
    if (elementTy.isF64()) {
      replacement =
          materializeBesselI1eApproximationF64(rewriter, loc, operand);
    } else {
      replacement = materializeWithUpcast(
          rewriter, loc, adaptor.getOperands(), rewriter.getF32Type(),
          &materializeBesselI1eApproximationF32);
    }
    rewriter.replaceOp(op, replacement);
    return success();
  }
};

// Emits StableHLO ops that evaluate a polynomial at `x` with Horner's scheme.
// `coefficients` are ordered from the highest-degree term down to the constant
// term. An empty coefficient list yields a zero constant shaped like `x`.
template <typename FTy>
static Value
materializePolynomialApproximation(ConversionPatternRewriter &rewriter,
                                   Location loc, Value x,
                                   ArrayRef<FTy> coefficients) {
  if (coefficients.empty())
    return getConstantLike(rewriter, loc, 0.0, x);

  Type ty = x.getType();
  Value accumulator = getConstantLike(rewriter, loc, coefficients.front(), x);
  for (FTy coefficient : coefficients.drop_front()) {
    Value scaled =
        rewriter.create<mlir::stablehlo::MulOp>(loc, ty, accumulator, x);
    Value term = getConstantLike(rewriter, loc, coefficient, x);
    accumulator =
        rewriter.create<mlir::stablehlo::AddOp>(loc, ty, scaled, term);
  }
  return accumulator;
}

// Materializes erfc(x) for f64 inputs with |x| >= 1.
//
// Precondition is |x| >= 1. Use erf approximation, otherwise.
//
// We rely on multiple polynomial approximations for x >= 1. We pass |x| as an
// argument and derive the final approximation for all |x| >= 1.
// This implementation is based on Cephes.
//
// `args.front()` is the operand; its element type must be f64 (asserted).
// Returns the value holding the approximation.
static Value materializeErfcApproximationF64ForMagnituteGeOne(
    ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
  Value x = args.front();
  assert(cast<ShapedType>(x.getType()).getElementType().isF64() &&
         "expect f64 element type");
  const double kMaxlog = 7.09782712893383996843E2;
  const double kErfcPCoefficients[] = {
      2.46196981473530512524E-10, 5.64189564831068821977E-1,
      7.46321056442269912687E0,   4.86371970985681366614E1,
      1.96520832956077098242E2,   5.26445194995477358631E2,
      9.34528527171957607540E2,   1.02755188689515710272E3,
      5.57535335369399327526E2};
  const double kErfcQCoefficients[] = {
      1.00000000000000000000E0, 1.32281951154744992508E1,
      8.67072140885989742329E1, 3.54937778887819891062E2,
      9.75708501743205489753E2, 1.82390916687909736289E3,
      2.24633760818710981792E3, 1.65666309194161350182E3,
      5.57535340817727675546E2};
  const double kErfcRCoefficients[] = {
      5.64189583547755073984E-1, 1.27536670759978104416E0,
      5.01905042251180477414E0,  6.16021097993053585195E0,
      7.40974269950448939160E0,  2.97886665372100240670E0};
  const double kErfcSCoefficients[] = {
      1.00000000000000000000E0, 2.26052863220117276590E0,
      9.39603524938001434673E0, 1.20489539808096656605E1,
      1.70814450747565897222E1, 9.60896809063285878198E0,
      3.36907645100081516050E0};

  // Let z = -x^2.
  Value xSq = rewriter.create<mlir::stablehlo::MulOp>(loc, x, x);
  Value z = rewriter.create<mlir::stablehlo::NegOp>(loc, xSq);

  // Materialize polynomial approximation for |x| in [1, 8) as
  //   erfc(x) = exp(z) P(|x|) / Q(|x|).
  Value expZ = rewriter.create<mlir::stablehlo::ExpOp>(loc, z);
  Value absX = rewriter.create<mlir::stablehlo::AbsOp>(loc, x);
  Value polP = materializePolynomialApproximation(
      rewriter, loc, absX, llvm::ArrayRef(kErfcPCoefficients));
  Value expZMulPolyP = rewriter.create<mlir::stablehlo::MulOp>(loc, expZ, polP);
  Value polQ = materializePolynomialApproximation(
      rewriter, loc, absX, llvm::ArrayRef(kErfcQCoefficients));
  Value erfcApprox18 =
      rewriter.create<mlir::stablehlo::DivOp>(loc, expZMulPolyP, polQ);

  // Materialize polynomial approximation for |x| >= 8 as
  //   erfc(x) = exp(z) R(|x|) / S(|x|).
  Value polR = materializePolynomialApproximation(
      rewriter, loc, absX, llvm::ArrayRef(kErfcRCoefficients));
  Value expZMulPolyR = rewriter.create<mlir::stablehlo::MulOp>(loc, expZ, polR);
  Value polS = materializePolynomialApproximation(
      rewriter, loc, absX, llvm::ArrayRef(kErfcSCoefficients));
  Value erfcApprox8Inf =
      rewriter.create<mlir::stablehlo::DivOp>(loc, expZMulPolyR, polS);

  // Combine polynomial approximations for x >= 1.
  Value eight = getConstantLike(rewriter, loc, 8.0, x);
  Value absXLt8 = rewriter.create<mlir::stablehlo::CompareOp>(
      loc, absX, eight, mlir::stablehlo::ComparisonDirection::LT);
  Value erfcApprox = rewriter.create<mlir::stablehlo::SelectOp>(
      loc, absXLt8, erfcApprox18, erfcApprox8Inf);

  // Clamp to prevent overflow and materialize approximation for large x as
  //   erfc(x) = 0.
  Value zLtNegMaxlog = rewriter.create<mlir::stablehlo::CompareOp>(
      loc, z, getConstantLike(rewriter, loc, -kMaxlog, x),
      mlir::stablehlo::ComparisonDirection::LT);
  Value zero = getConstantLike(rewriter, loc, 0.0, x);
  Value erfcApproxClamped = rewriter.create<mlir::stablehlo::SelectOp>(
      loc, zLtNegMaxlog, zero, erfcApprox);

  // Derive approximation for x <= -1 as
  //   erfc(x) = 2 - erfc(-x).
  // Reuse previously materialized approximations all of which take |x| as their
  // argument.
  Value xLtZero = rewriter.create<mlir::stablehlo::CompareOp>(
      loc, x, zero, mlir::stablehlo::ComparisonDirection::LT);
  Value two = getConstantLike(rewriter, loc, 2.0, x);
  Value twoSubErfcApproxClamped =
      rewriter.create<mlir::stablehlo::SubtractOp>(loc, two, erfcApproxClamped);
  return rewriter.create<mlir::stablehlo::SelectOp>(
      loc, xLtZero, twoSubErfcApproxClamped, erfcApproxClamped);
}

// Materializes erf(x) for f64 inputs with |x| <= 1 as a rational function in
// x^2. For larger magnitudes the erfc-based approximation should be used
// instead. This implementation is based on Cephes.
static Value materializeErfApproximationF64ForMagnituteLeOne(
    ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
  Value x = args.front();
  assert(cast<ShapedType>(x.getType()).getElementType().isF64() &&
         "expect f64 element type");

  // Numerator polynomial T.
  const double kErfTCoefficients[] = {
      9.60497373987051638749E0, 9.00260197203842689217E1,
      2.23200534594684319226E3, 7.00332514112805075473E3,
      5.55923013010394962768E4};
  // Denominator polynomial U.
  const double kErfUCoefficients[] = {
      1.00000000000000000000E0, 3.35617141647503099647E1,
      5.21357949780152679795E2, 4.59432382970980127987E3,
      2.26290000613890934246E4, 4.92673942608635921086E4};

  // erf(x) = x * T(x^2) / U(x^2) for |x| <= 1.
  Value squared = rewriter.create<mlir::stablehlo::MulOp>(loc, x, x);
  Value numeratorPoly = materializePolynomialApproximation(
      rewriter, loc, squared, llvm::ArrayRef(kErfTCoefficients));
  Value numerator =
      rewriter.create<mlir::stablehlo::MulOp>(loc, x, numeratorPoly);
  Value denominator = materializePolynomialApproximation(
      rewriter, loc, squared, llvm::ArrayRef(kErfUCoefficients));
  return rewriter.create<mlir::stablehlo::DivOp>(loc, numerator, denominator);
}

// Materializes erf(x) for f64 inputs over the whole domain by selecting
// between the |x| < 1 erf approximation and 1 - erfc_approx(x).
// This implementation is based on Cephes.
static Value materializeErfApproximationF64(ConversionPatternRewriter &rewriter,
                                            Location loc, ValueRange args) {
  Value x = args.front();
  assert(cast<ShapedType>(x.getType()).getElementType().isF64() &&
         "expect f64 element type");

  // Candidate for |x| < 1:
  //   erf(x) = erf_approx(x)
  Value smallMagnitudeResult =
      materializeErfApproximationF64ForMagnituteLeOne(rewriter, loc, x);

  // Candidate for |x| >= 1:
  //   erf(x) = 1 - erfc_approx(x)
  Value one = getConstantLike(rewriter, loc, 1.0, x);
  Value erfc =
      materializeErfcApproximationF64ForMagnituteGeOne(rewriter, loc, x);
  Value largeMagnitudeResult =
      rewriter.create<mlir::stablehlo::SubtractOp>(loc, one, erfc);

  // Pick the candidate that is valid for the actual magnitude of x.
  Value magnitude = rewriter.create<mlir::stablehlo::AbsOp>(loc, x);
  Value isSmallMagnitude = rewriter.create<mlir::stablehlo::CompareOp>(
      loc, magnitude, one, mlir::stablehlo::ComparisonDirection::LT);
  return rewriter.create<mlir::stablehlo::SelectOp>(
      loc, isSmallMagnitude, smallMagnitudeResult, largeMagnitudeResult);
}

// Materializes erfc(x) for f64 inputs over the whole domain by selecting
// between 1 - erf_approx(x) for |x| < 1 and the erfc approximation otherwise.
//
// `args.front()` is the operand; its element type must be f64 (asserted).
static Value
materializeErfcApproximationF64(ConversionPatternRewriter &rewriter,
                                Location loc, ValueRange args) {
  Value x = args.front();
  assert(cast<ShapedType>(x.getType()).getElementType().isF64() &&
         "expect f64 element type");

  // Rely on erfc approximation for |x| >= 1
  //   erfc(x) = erfc_approx(x)
  Value erfcApprox =
      materializeErfcApproximationF64ForMagnituteGeOne(rewriter, loc, x);

  // Rely on erf approximation for |x| < 1 and materialize erfc as
  //   erfc(x) = 1 - erf_approx(x)
  Value one = getConstantLike(rewriter, loc, 1.0, x);
  Value erfApprox =
      materializeErfApproximationF64ForMagnituteLeOne(rewriter, loc, x);
  Value erfBasedApprox =
      rewriter.create<mlir::stablehlo::SubtractOp>(loc, one, erfApprox);

  // Materialize approximation selection based on argument.
  Value absX = rewriter.create<mlir::stablehlo::AbsOp>(loc, x);
  Value absXLtOne = rewriter.create<mlir::stablehlo::CompareOp>(
      loc, absX, one, mlir::stablehlo::ComparisonDirection::LT);
  return rewriter.create<mlir::stablehlo::SelectOp>(loc, absXLtOne,
                                                    erfBasedApprox, erfcApprox);
}

// Materializes erfc(x) for f32 inputs with |x| >= 1.
//
// Precondition is |x| >= 1. Use erf approximation, otherwise.
//
// We rely on multiple polynomial approximations for x >= 1. We pass |x| as an
// argument and derive the final approximation for all |x| >= 1.
// This implementation is based on Cephes.
//
// `args.front()` is the operand; its element type must be f32 (asserted).
static Value materializeErfcApproximationF32ForMagnitudeGeOne(
    ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
  Value x = args.front();
  assert(cast<ShapedType>(x.getType()).getElementType().isF32() &&
         "expect f32 element type");
  const double kMaxlog = 88.72283905206835;
  const float kErfcPCoefficients[] = {
      +2.326819970068386E-2f, -1.387039388740657E-1f, +3.687424674597105E-1f,
      -5.824733027278666E-1f, +6.210004621745983E-1f, -4.944515323274145E-1f,
      +3.404879937665872E-1f, -2.741127028184656E-1f, +5.638259427386472E-1f,
  };
  const float kErfcRCoefficients[] = {
      -1.047766399936249E+1f, +1.297719955372516E+1f, -7.495518717768503E+0f,
      +2.921019019210786E+0f, -1.015265279202700E+0f, +4.218463358204948E-1f,
      -2.820767439740514E-1f, +5.641895067754075E-1f,
  };

  // Let z = -x^2.
  Value xSq = rewriter.create<mlir::stablehlo::MulOp>(loc, x, x);
  Value z = rewriter.create<mlir::stablehlo::NegOp>(loc, xSq);

  // Materialize polynomial approximation for x >= 1 as
  //   erfc(x) = exp(z) 1/x P(1/x^2)   if x in [1, 2)
  //   erfc(x) = exp(z) 1/x R(1/x^2)   if x >= 2
  Value absX = rewriter.create<mlir::stablehlo::AbsOp>(loc, x);
  Value one = getConstantLike(rewriter, loc, 1.0, x);
  Value reciprocalXSq = rewriter.create<mlir::stablehlo::DivOp>(loc, one, xSq);
  Value expZ = rewriter.create<mlir::stablehlo::ExpOp>(loc, z);
  Value oneDivAbsX = rewriter.create<mlir::stablehlo::DivOp>(loc, one, absX);
  Value expZMulOneDivAbsX =
      rewriter.create<mlir::stablehlo::MulOp>(loc, expZ, oneDivAbsX);
  Value two = getConstantLike(rewriter, loc, 2.0, x);
  Value absXLtTwo = rewriter.create<mlir::stablehlo::CompareOp>(
      loc, absX, two, mlir::stablehlo::ComparisonDirection::LT);
  Value polP = materializePolynomialApproximation(
      rewriter, loc, reciprocalXSq, llvm::ArrayRef(kErfcPCoefficients));
  Value polR = materializePolynomialApproximation(
      rewriter, loc, reciprocalXSq, llvm::ArrayRef(kErfcRCoefficients));
  Value poly =
      rewriter.create<mlir::stablehlo::SelectOp>(loc, absXLtTwo, polP, polR);
  Value erfcApprox =
      rewriter.create<mlir::stablehlo::MulOp>(loc, expZMulOneDivAbsX, poly);

  // Clamp to prevent overflow and materialize approximation for large x as
  //   erfc(x) = 0.
  Value zLtNegMaxlog = rewriter.create<mlir::stablehlo::CompareOp>(
      loc, z, getConstantLike(rewriter, loc, -kMaxlog, x),
      mlir::stablehlo::ComparisonDirection::LT);
  Value zero = getConstantLike(rewriter, loc, 0.0, x);
  Value erfcApproxClamped = rewriter.create<mlir::stablehlo::SelectOp>(
      loc, zLtNegMaxlog, zero, erfcApprox);

  // Derive approximation for x <= -1 as
  //   erfc(x) = 2 - erfc(-x).
  // Reuse previously materialized approximations all of which take |x| as their
  // argument.
  Value xLtZero = rewriter.create<mlir::stablehlo::CompareOp>(
      loc, x, zero, mlir::stablehlo::ComparisonDirection::LT);
  Value twoSubErfcApprox =
      rewriter.create<mlir::stablehlo::SubtractOp>(loc, two, erfcApproxClamped);
  return rewriter.create<mlir::stablehlo::SelectOp>(
      loc, xLtZero, twoSubErfcApprox, erfcApproxClamped);
}

// Materializes erf(x) for f32 inputs with |x| <= 1.
//
// Precondition is |x| <= 1. Use erfc approximation, otherwise.
// This implementation is based on Cephes.
//
// `args.front()` is the operand; its element type must be f32 (asserted).
static Value materializeErfApproximationF32ForMagnitudeLeOne(
    ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
  Value x = args.front();
  assert(cast<ShapedType>(x.getType()).getElementType().isF32() &&
         "expect f32 element type");
  const float kErfTCoefficients[] = {
      +7.853861353153693E-5f, -8.010193625184903E-4f, +5.188327685732524E-3f,
      -2.685381193529856E-2f, +1.128358514861418E-1f, -3.761262582423300E-1f,
      +1.128379165726710E+0f,
  };

  // Materialize polynomial approximation for |x| <= 1 as
  //   erf(x) = x T(x^2).
  Value xSq = rewriter.create<mlir::stablehlo::MulOp>(loc, x, x);
  Value polyT = materializePolynomialApproximation(
      rewriter, loc, xSq, llvm::ArrayRef(kErfTCoefficients));
  return rewriter.create<mlir::stablehlo::MulOp>(loc, x, polyT);
}

// Materializes erf(x) for f32 inputs as a rational approximation evaluated on
// the argument clamped to [-4, 4], with the result clamped to [-1, 1].
// This is the same approximation as used in Eigen.
//
// `args.front()` is the operand; its element type must be f32 (asserted).
static Value materializeErfApproximationF32(ConversionPatternRewriter &rewriter,
                                            Location loc, ValueRange args) {
  Value x = args.front();
  assert(cast<ShapedType>(x.getType()).getElementType().isF32() &&
         "expect f32 element type");
  const float kAlpha[] = {
      -2.72614225801306e-10f, 2.77068142495902e-08f,  -2.10102402082508e-06f,
      -5.69250639462346e-05f, -7.34990630326855e-04f, -2.95459980854025e-03f,
      -1.60960333262415e-02f,
  };
  const float kBeta[] = {
      -1.45660718464996e-05f, -2.13374055278905e-04f, -1.68282697438203e-03f,
      -7.37332916720468e-03f, -1.42647390514189e-02f,
  };

  // Clamp argument between -4 and 4.
  Value lb = getConstantLike(rewriter, loc, -4.0, x);
  Value ub = getConstantLike(rewriter, loc, 4.0, x);
  x = rewriter.create<mlir::stablehlo::ClampOp>(loc, x.getType(), lb, x, ub);
  Value xSq = rewriter.create<mlir::stablehlo::MulOp>(loc, x, x);

  // Materialize polynomial approximation for x in [-4, 4] as
  //   erf(x) = x * Alpha(x^2) / Beta(x^2).
  Value alphaPoly = materializePolynomialApproximation(rewriter, loc, xSq,
                                                       llvm::ArrayRef(kAlpha));
  Value betaPoly = materializePolynomialApproximation(rewriter, loc, xSq,
                                                      llvm::ArrayRef(kBeta));
  Value xMulAlphaPoly =
      rewriter.create<mlir::stablehlo::MulOp>(loc, x, alphaPoly);
  Value erf =
      rewriter.create<mlir::stablehlo::DivOp>(loc, xMulAlphaPoly, betaPoly);

  // Clamp the result to [-1, 1].
  Value lbErf = getConstantLike(rewriter, loc, -1.0, x);
  Value ubErf = getConstantLike(rewriter, loc, 1.0, x);
  return rewriter.create<mlir::stablehlo::ClampOp>(loc, erf.getType(), lbErf,
                                                   erf, ubErf);
}

// Materializes erfc(x) for f32 inputs over the whole domain by selecting
// between 1 - erf_approx(x) for |x| < 1 and the erfc approximation otherwise.
//
// `args.front()` is the operand; its element type must be f32 (asserted).
static Value
materializeErfcApproximationF32(ConversionPatternRewriter &rewriter,
                                Location loc, ValueRange args) {
  Value x = args.front();
  assert(cast<ShapedType>(x.getType()).getElementType().isF32() &&
         "expect f32 element type");

  // Rely on erfc approximation for |x| >= 1
  //   erfc(x) = erfc_approx(x)
  Value erfcApprox =
      materializeErfcApproximationF32ForMagnitudeGeOne(rewriter, loc, x);

  // Rely on erf approximation for |x| < 1 and materialize erfc as
  //   erfc(x) = 1 - erf_approx(x)
  Value one = getConstantLike(rewriter, loc, 1.0, x);
  Value erfApprox =
      materializeErfApproximationF32ForMagnitudeLeOne(rewriter, loc, x);
  Value erfBasedApprox =
      rewriter.create<mlir::stablehlo::SubtractOp>(loc, one, erfApprox);

  // Materialize approximation selection based on argument.
  Value absX = rewriter.create<mlir::stablehlo::AbsOp>(loc, x);
  Value absXLtOne = rewriter.create<mlir::stablehlo::CompareOp>(
      loc, absX, one, mlir::stablehlo::ComparisonDirection::LT);
  return rewriter.create<mlir::stablehlo::SelectOp>(loc, absXLtOne,
                                                    erfBasedApprox, erfcApprox);
}

struct ConvertErfOp final : OpConversionPattern<mlir::chlo::ErfOp> {
  using OpConversionPattern::OpConversionPattern;

  // Rewrites chlo.erf for f64, f32, f16 and bf16 element types; any other
  // element type makes the pattern fail.
  LogicalResult
  matchAndRewrite(mlir::chlo::ErfOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value operand = adaptor.getOperand();
    Type elementTy = cast<ShapedType>(operand.getType()).getElementType();
    Location loc = op.getLoc();

    // f64 has its own dedicated approximation.
    if (elementTy.isF64()) {
      rewriter.replaceOp(
          op, materializeErfApproximationF64(rewriter, loc, operand));
      return success();
    }

    // The remaining supported types go through the f32 approximation, with
    // f32 handed to materializeWithUpcast as the type argument.
    const bool usesF32Path =
        elementTy.isF32() || elementTy.isF16() || elementTy.isBF16();
    if (!usesF32Path) {
      return failure();
    }

    rewriter.replaceOp(
        op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(),
                                  rewriter.getF32Type(),
                                  &materializeErfApproximationF32));
    return success();
  }
};

struct ConvertErfcOp final : OpConversionPattern<mlir::chlo::ErfcOp> {
  using OpConversionPattern::OpConversionPattern;

  // Rewrites chlo.erfc for f64, f32, f16 and bf16 element types; any other
  // element type makes the pattern fail.
  LogicalResult
  matchAndRewrite(mlir::chlo::ErfcOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value operand = adaptor.getOperand();
    Type elementTy = cast<ShapedType>(operand.getType()).getElementType();
    Location loc = op.getLoc();

    // f64 has its own dedicated approximation.
    if (elementTy.isF64()) {
      rewriter.replaceOp(
          op, materializeErfcApproximationF64(rewriter, loc, operand));
      return success();
    }

    // The remaining supported types go through the f32 approximation, with
    // f32 handed to materializeWithUpcast as the type argument.
    const bool usesF32Path =
        elementTy.isF32() || elementTy.isF16() || elementTy.isBF16();
    if (!usesF32Path) {
      return failure();
    }

    rewriter.replaceOp(
        op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(),
                                  rewriter.getF32Type(),
                                  &materializeErfcApproximationF32));
    return success();
  }
};

// Materializes erfinv(x) for the operand `args[0]` using two polynomial
// approximations in w = -log1p(-x^2): one for w < 5 and one for w >= 5.
// Inputs with |x| == 1 are mapped to x * inf.
//
// Every op is created in its own statement so that the emitted op order does
// not depend on the host compiler's (unspecified) argument evaluation order.
static Value erfInv32(ConversionPatternRewriter &b, Location loc,
                      ValueRange args) {
  constexpr std::array<float, 9> wLessThan5Constants = {
      2.81022636e-08f,  3.43273939e-07f, -3.5233877e-06f,
      -4.39150654e-06f, 0.00021858087f,  -0.00125372503f,
      -0.00417768164f,  0.246640727f,    1.50140941f};
  constexpr std::array<float, 9> wGreaterThan5Constants = {
      -0.000200214257f, 0.000100950558f, 0.00134934322f,
      -0.00367342844f,  0.00573950773f,  -0.0076224613f,
      0.00943887047f,   1.00167406f,     2.83297682f};
  static_assert(wLessThan5Constants.size() == wGreaterThan5Constants.size(),
                "both polynomials must have the same number of coefficients");
  constexpr size_t kNumCoefficients = wLessThan5Constants.size();

  Value x = args[0];
  // Compute logarithm of (1+arg) using log1p(arg) which is more precise than
  // log(1+arg) when arg is close to zero. For more details, see
  // https://en.cppreference.com/w/cpp/numeric/math/log1p
  Value negX = b.create<mlir::stablehlo::NegOp>(loc, x);
  Value minusXSquared = b.create<mlir::stablehlo::MulOp>(loc, x, negX);
  Value log1pTerm = b.create<mlir::stablehlo::Log1pOp>(loc, minusXSquared);
  Value w = b.create<mlir::stablehlo::NegOp>(loc, log1pTerm);

  Value five = getConstantLike(b, loc, 5.0, x);
  Value lt = b.create<mlir::stablehlo::CompareOp>(
      loc, w, five, mlir::stablehlo::ComparisonDirection::LT);

  // Selects the i-th coefficient of whichever polynomial applies.
  auto coefficient = [&](size_t i) -> Value {
    Value ifLess = getConstantLike(b, loc, wLessThan5Constants[i], x);
    Value ifGreater = getConstantLike(b, loc, wGreaterThan5Constants[i], x);
    return b.create<mlir::stablehlo::SelectOp>(loc, lt, ifLess, ifGreater);
  };

  // Shift w into the polynomial's argument:
  //   w - 2.5        if w < 5
  //   sqrt(w) - 3    otherwise
  Value twoAndHalf = getConstantLike(b, loc, 2.5, x);
  Value wSmall = b.create<mlir::stablehlo::SubtractOp>(loc, w, twoAndHalf);
  Value sqrtW = b.create<mlir::stablehlo::SqrtOp>(loc, w);
  Value three = getConstantLike(b, loc, 3.0, x);
  Value wLarge = b.create<mlir::stablehlo::SubtractOp>(loc, sqrtW, three);
  w = b.create<mlir::stablehlo::SelectOp>(loc, lt, wSmall, wLarge);

  // Horner evaluation: p <- coefficient(i) + p * w.
  Value p = coefficient(0);
  for (size_t i = 1; i < kNumCoefficients; ++i) {
    Value c = coefficient(i);
    Value scaled = b.create<mlir::stablehlo::MulOp>(loc, p, w);
    p = b.create<mlir::stablehlo::AddOp>(loc, c, scaled);
  }

  // Result modulo edge cases.
  Value result = b.create<mlir::stablehlo::MulOp>(loc, p, x);

  // Handle edge cases, namely erfinv(+/-1) = +/-inf.  (The above computation is
  // indeterminate, and can give nan or -/+inf.)
  Value absX = b.create<mlir::stablehlo::AbsOp>(loc, x);
  Value one = getConstantLike(b, loc, 1, x);
  Value absXIsOne = b.create<mlir::stablehlo::CompareOp>(
      loc, absX, one, mlir::stablehlo::ComparisonDirection::EQ);
  Value inf = getConstantLikeInfValue(b, loc, x, false);
  Value signedInf = b.create<mlir::stablehlo::MulOp>(loc, x, inf);
  return b.create<mlir::stablehlo::SelectOp>(loc, absXIsOne, signedInf, result);
}

// Materializes erfinv(x) for the operand `args[0]` using three polynomial
// approximations in w = -log1p(-x^2): one for w < 6.25, one for
// 6.25 <= w < 16 and one for w >= 16. The three polynomials have different
// lengths; Horner steps past the end of the applicable polynomial are
// discarded through selects. Inputs with |x| == 1 are mapped to x * inf.
//
// Every op is created in its own statement so that the emitted op order does
// not depend on the host compiler's (unspecified) argument evaluation order.
static Value erfInv64(ConversionPatternRewriter &b, Location loc,
                      ValueRange args) {
  constexpr std::array<double, 23> wLessThan625Constants = {
      -3.6444120640178196996e-21, -1.685059138182016589e-19,
      1.2858480715256400167e-18,  1.115787767802518096e-17,
      -1.333171662854620906e-16,  2.0972767875968561637e-17,
      6.6376381343583238325e-15,  -4.0545662729752068639e-14,
      -8.1519341976054721522e-14, 2.6335093153082322977e-12,
      -1.2975133253453532498e-11, -5.4154120542946279317e-11,
      1.051212273321532285e-09,   -4.1126339803469836976e-09,
      -2.9070369957882005086e-08, 4.2347877827932403518e-07,
      -1.3654692000834678645e-06, -1.3882523362786468719e-05,
      0.0001867342080340571352,   -0.00074070253416626697512,
      -0.0060336708714301490533,  0.24015818242558961693,
      1.6536545626831027356};
  constexpr std::array<double, 19> wLessThan16Constants = {
      2.2137376921775787049e-09,  9.0756561938885390979e-08,
      -2.7517406297064545428e-07, 1.8239629214389227755e-08,
      1.5027403968909827627e-06,  -4.013867526981545969e-06,
      2.9234449089955446044e-06,  1.2475304481671778723e-05,
      -4.7318229009055733981e-05, 6.8284851459573175448e-05,
      2.4031110387097893999e-05,  -0.0003550375203628474796,
      0.00095328937973738049703,  -0.0016882755560235047313,
      0.0024914420961078508066,   -0.0037512085075692412107,
      0.005370914553590063617,    1.0052589676941592334,
      3.0838856104922207635,
  };
  constexpr std::array<double, 17> wGreaterThan16Constants = {
      -2.7109920616438573243e-11, -2.5556418169965252055e-10,
      1.5076572693500548083e-09,  -3.7894654401267369937e-09,
      7.6157012080783393804e-09,  -1.4960026627149240478e-08,
      2.9147953450901080826e-08,  -6.7711997758452339498e-08,
      2.2900482228026654717e-07,  -9.9298272942317002539e-07,
      4.5260625972231537039e-06,  -1.9681778105531670567e-05,
      7.5995277030017761139e-05,  -0.00021503011930044477347,
      -0.00013871931833623122026, 1.0103004648645343977,
      4.8499064014085844221,
  };
  // Polynomial lengths, derived from the tables so they cannot drift apart.
  constexpr size_t kNumLt625 = wLessThan625Constants.size();
  constexpr size_t kNumLt16 = wLessThan16Constants.size();
  constexpr size_t kNumGe16 = wGreaterThan16Constants.size();
  static_assert(kNumGe16 <= kNumLt16 && kNumLt16 <= kNumLt625,
                "the evaluation below assumes the polynomials get shorter as "
                "w grows");

  Value x = args[0];
  // Compute logarithm of (1+arg) using log1p(arg) which is more precise than
  // log(1+arg) when arg is close to zero. For more details, see
  // https://en.cppreference.com/w/cpp/numeric/math/log1p
  Value negX = b.create<mlir::stablehlo::NegOp>(loc, x);
  Value minusXSquared = b.create<mlir::stablehlo::MulOp>(loc, x, negX);
  Value log1pTerm = b.create<mlir::stablehlo::Log1pOp>(loc, minusXSquared);
  Value w = b.create<mlir::stablehlo::NegOp>(loc, log1pTerm);

  Value c625 = getConstantLike(b, loc, 6.25, x);
  Value lt625 = b.create<mlir::stablehlo::CompareOp>(
      loc, w, c625, mlir::stablehlo::ComparisonDirection::LT);
  Value c16 = getConstantLike(b, loc, 16, x);
  Value lt16 = b.create<mlir::stablehlo::CompareOp>(
      loc, w, c16, mlir::stablehlo::ComparisonDirection::LT);

  // Selects the i-th coefficient of whichever polynomial applies. Indices
  // beyond the end of a shorter polynomial simply skip that table.
  auto coefficient = [&](size_t i) -> Value {
    Value c = getConstantLike(b, loc, wLessThan625Constants[i], x);
    if (i < kNumLt16) {
      Value alt = getConstantLike(b, loc, wLessThan16Constants[i], x);
      c = b.create<mlir::stablehlo::SelectOp>(loc, lt625, c, alt);
    }
    if (i < kNumGe16) {
      Value alt = getConstantLike(b, loc, wGreaterThan16Constants[i], x);
      c = b.create<mlir::stablehlo::SelectOp>(loc, lt16, c, alt);
    }
    return c;
  };

  // Shift w into the polynomial's argument:
  //   w - 3.125        if w < 6.25
  //   sqrt(w) - 3.25   if 6.25 <= w < 16
  //   sqrt(w) - 5      otherwise
  Value sqrtW = b.create<mlir::stablehlo::SqrtOp>(loc, w);
  Value c3125 = getConstantLike(b, loc, 3.125, x);
  Value wMinus3125 = b.create<mlir::stablehlo::SubtractOp>(loc, w, c3125);
  Value c325 = getConstantLike(b, loc, 3.25, w);
  Value c5 = getConstantLike(b, loc, 5.0, w);
  Value select2 = b.create<mlir::stablehlo::SelectOp>(loc, lt16, c325, c5);
  Value select2Result =
      b.create<mlir::stablehlo::SubtractOp>(loc, sqrtW, select2);
  w = b.create<mlir::stablehlo::SelectOp>(loc, lt625, wMinus3125,
                                          select2Result);

  // One Horner step: coefficient(i) + acc * w.
  auto hornerStep = [&](Value acc, size_t i) -> Value {
    Value c = coefficient(i);
    Value scaled = b.create<mlir::stablehlo::MulOp>(loc, acc, w);
    return b.create<mlir::stablehlo::AddOp>(loc, c, scaled);
  };

  // Steps shared by all three polynomials.
  Value p = coefficient(0);
  for (size_t i = 1; i < kNumGe16; ++i) {
    p = hornerStep(p, i);
  }
  // Steps that only apply when w < 16.
  for (size_t i = kNumGe16; i < kNumLt16; ++i) {
    Value next = hornerStep(p, i);
    p = b.create<mlir::stablehlo::SelectOp>(loc, lt16, next, p);
  }
  // Steps that only apply when w < 6.25.
  for (size_t i = kNumLt16; i < kNumLt625; ++i) {
    Value next = hornerStep(p, i);
    p = b.create<mlir::stablehlo::SelectOp>(loc, lt625, next, p);
  }

  // Result modulo edge cases.
  Value result = b.create<mlir::stablehlo::MulOp>(loc, p, x);

  // Handle edge cases, namely erfinv(+/-1) = +/-inf.  (The above computation is
  // indeterminate, and can give nan or -/+inf.)
  Value absX = b.create<mlir::stablehlo::AbsOp>(loc, x);
  Value one = getConstantLike(b, loc, 1, x);
  Value absXIsOne = b.create<mlir::stablehlo::CompareOp>(
      loc, absX, one, mlir::stablehlo::ComparisonDirection::EQ);
  Value inf = getConstantLikeInfValue(b, loc, x, false);
  Value signedInf = b.create<mlir::stablehlo::MulOp>(loc, x, inf);
  return b.create<mlir::stablehlo::SelectOp>(loc, absXIsOne, signedInf, result);
}

struct ConvertErfInvOp final : OpConversionPattern<mlir::chlo::ErfInvOp> {
  using OpConversionPattern::OpConversionPattern;

  // Rewrites chlo.erf_inv: f64 results use erfInv64 directly, every other
  // element type goes through materializeWithUpcast with erfInv32 and f32 as
  // the type argument.
  LogicalResult
  matchAndRewrite(mlir::chlo::ErfInvOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const bool isF64 = op.getResult().getType().getElementType().isF64();

    if (!isF64) {
      FloatType f32Ty = rewriter.getF32Type();
      rewriter.replaceOp(
          op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(),
                                    f32Ty, &erfInv32));
      return success();
    }

    rewriter.replaceOp(op, erfInv64(rewriter, loc, adaptor.getOperands()));
    return success();
  }
};

// Coefficients for the Lanczos approximation of the gamma function. The
// coefficients are uniquely determined by the choice of g and n (kLanczosGamma
// and kLanczosCoefficients.size() + 1). The coefficients below correspond to
// [7, 9]. [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were evaluated and
// [7, 9] seemed to be the least sensitive to the quality of the log function.
// In particular, [5, 7] is the only choice where -1.5e-5 <= lgamma(2) <= 1.5e-5
// for a particularly inaccurate log function.
//
// In materializeLgamma, kBaseLanczosCoeff is the constant term of a(z) and
// kLanczosCoefficients[k - 1] is the numerator of the 1 / (z + k) term.
constexpr double kLanczosGamma = 7; // aka g
constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478;
constexpr std::array<double, 8> kLanczosCoefficients = {
    676.520368121885098567009190444019, -1259.13921672240287047156078755283,
    771.3234287776530788486528258894,   -176.61502916214059906584551354,
    12.507343278686904814458936853,     -0.13857109526572011689554707,
    9.984369578019570859563e-6,         1.50563273514931155834e-7};

// Materializes lgamma(x) for the operand `args.front()` using Lanczos'
// approximation from "A Precision Approximation of the Gamma Function", SIAM
// Journal on Numerical Analysis series B, Vol. 1:
//   lgamma(z + 1) = (log(2) + log(pi)) / 2
//                     + (z + 1/2) * log(t(z))
//                     - t(z) + log(a(z))
//   with   t(z) = z + kLanczosGamma + 1/2
//          a(z) = kBaseLanczosCoeff
//                   + sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k))
static Value materializeLgamma(ConversionPatternRewriter &rewriter,
                               Location loc, ValueRange args) {
  // Inputs below 0.5 are handled through Euler's reflection formula
  //   gamma(x) = pi / (sin(pi * x) * gamma(1 - x)).
  // Let z be
  //   z = -x      if x < 1/2
  //   z = x - 1   otherwise
  Value x = args.front();
  Value oneHalf = getConstantLike(rewriter, loc, 0.5, x);
  Value useReflection = rewriter.create<mlir::stablehlo::CompareOp>(
      loc, x, oneHalf, mlir::stablehlo::ComparisonDirection::LT);
  Value minusX = rewriter.create<mlir::stablehlo::NegOp>(loc, x);
  Value one = getConstantLike(rewriter, loc, 1, x);
  Value xMinusOne = rewriter.create<mlir::stablehlo::SubtractOp>(loc, x, one);
  Value z = rewriter.create<mlir::stablehlo::SelectOp>(loc, useReflection,
                                                       minusX, xMinusOne);

  // Accumulate
  //   a(z) = kBaseLanczosCoeff
  //            + sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k))
  Value a = getConstantLike(rewriter, loc, kBaseLanczosCoeff, x);
  const int numTerms = kLanczosCoefficients.size();
  for (int k = 1; k <= numTerms; ++k) {
    Value numerator =
        getConstantLike(rewriter, loc, kLanczosCoefficients[k - 1], x);
    Value kValue = getConstantLike(rewriter, loc, k, x);
    Value denominator =
        rewriter.create<mlir::stablehlo::AddOp>(loc, z, kValue);
    Value term =
        rewriter.create<mlir::stablehlo::DivOp>(loc, numerator, denominator);
    a = rewriter.create<mlir::stablehlo::AddOp>(loc, a, term);
  }

  // To improve accuracy on platforms with less-precise log implementations,
  // log(kLanczosGamma + 1/2) is computed at compile time and log1p is used on
  // the device:
  //   log(t) = log(kLanczosGamma + 1/2 + z)
  //          = log(kLanczosGamma + 1/2) + log1p(z / (kLanczosGamma + 1/2)).
  Value gammaPlusHalf = getConstantLike(rewriter, loc, kLanczosGamma + 0.5, x);
  Value t = rewriter.create<mlir::stablehlo::AddOp>(loc, gammaPlusHalf, z);
  Value logGammaPlusHalf =
      getConstantLike(rewriter, loc, std::log(kLanczosGamma + 0.5), x);
  Value zOverGammaPlusHalf =
      rewriter.create<mlir::stablehlo::DivOp>(loc, z, gammaPlusHalf);
  Value log1pPart =
      rewriter.create<mlir::stablehlo::Log1pOp>(loc, zOverGammaPlusHalf);
  Value logT =
      rewriter.create<mlir::stablehlo::AddOp>(loc, logGammaPlusHalf, log1pPart);

  // t(z) may be large, so care is needed not to overflow to infinity in
  //   r = (z + 1/2) * log(t(z)) - t(z).
  // It is therefore computed as
  //   r = (z + 1/2 - t(z) / log(t(z))) * log(t(z)).
  Value tOverLogT = rewriter.create<mlir::stablehlo::DivOp>(loc, t, logT);
  Value zPlusHalf = rewriter.create<mlir::stablehlo::AddOp>(loc, z, oneHalf);
  Value bracket =
      rewriter.create<mlir::stablehlo::SubtractOp>(loc, zPlusHalf, tOverLogT);
  Value r = rewriter.create<mlir::stablehlo::MulOp>(loc, bracket, logT);

  // Final result (modulo reflection):
  //   lgamma(z + 1) = (log(2) + log(pi)) / 2 + r + log(a(z)).
  Value logA = rewriter.create<mlir::stablehlo::LogOp>(loc, a);
  Value halfLogTwoPi =
      getConstantLike(rewriter, loc, (std::log(2) + std::log(M_PI)) / 2, x);
  Value halfLogTwoPiPlusR =
      rewriter.create<mlir::stablehlo::AddOp>(loc, halfLogTwoPi, r);
  Value lgamma =
      rewriter.create<mlir::stablehlo::AddOp>(loc, halfLogTwoPiPlusR, logA);

  // The reflected value for x < 0.5 is
  //   lgamma(x) = log(pi) - lgamma(1-x) - log(abs(sin(pi * x))).
  //
  // The abs is needed because lgamma is the log of the absolute value of the
  // gamma function.
  //
  // The final term needs care. gamma(x) goes to +/-inf at every integer x < 0,
  // and this is controlled by the sin(pi * x) term. The slope is large, so
  // precision is particularly important.
  //
  // Because abs(sin(pi * x)) has period 1, abs(sin(pi * frac(x))) can be used
  // instead, where frac(x) is the fractional part of x. This is more
  // numerically accurate: it doesn't overflow to inf like pi * x would, and if
  // x is an integer it evaluates to exactly 0, which matters because the log
  // of this value is taken next, and log(0) is inf.
  //
  // There is no frac(x) primitive in HLO and computing it is tricky, but
  // because abs(sin(pi * x)) = abs(sin(pi * abs(x))), it is good enough here
  // to use abs(frac(x)) = abs(x) - floor(abs(x)).
  //
  // Furthermore, pi * abs(frac(x)) loses precision when abs(frac(x)) is close
  // to 1. To remedy this, use the fact that sin(pi * x) in the domain [0, 1]
  // is symmetric across the line Y=0.5.

  // Map abs_frac > 0.5 to (1 - abs_frac) to improve precision of
  // pi * abs_frac for values of abs_frac close to 1.
  Value absX = rewriter.create<mlir::stablehlo::AbsOp>(loc, x);
  Value floorAbsX = rewriter.create<mlir::stablehlo::FloorOp>(loc, absX);
  Value absFrac =
      rewriter.create<mlir::stablehlo::SubtractOp>(loc, absX, floorAbsX);
  Value fracAboveHalf = rewriter.create<mlir::stablehlo::CompareOp>(
      loc, oneHalf, absFrac, mlir::stablehlo::ComparisonDirection::LT);
  Value oneMinusAbsFrac =
      rewriter.create<mlir::stablehlo::SubtractOp>(loc, one, absFrac);
  absFrac = rewriter.create<mlir::stablehlo::SelectOp>(
      loc, fracAboveHalf, oneMinusAbsFrac, absFrac);

  // Materialize reflection.
  Value pi = getConstantLike(rewriter, loc, M_PI, x);
  Value piTimesAbsFrac =
      rewriter.create<mlir::stablehlo::MulOp>(loc, pi, absFrac);
  Value sinTerm =
      rewriter.create<mlir::stablehlo::SineOp>(loc, piTimesAbsFrac);
  Value reflectionDenom = rewriter.create<mlir::stablehlo::LogOp>(loc, sinTerm);
  Value logPi = getConstantLike(rewriter, loc, std::log(M_PI), x);
  Value logPiMinusDenom =
      rewriter.create<mlir::stablehlo::SubtractOp>(loc, logPi, reflectionDenom);
  Value reflected =
      rewriter.create<mlir::stablehlo::SubtractOp>(loc, logPiMinusDenom, lgamma);

  // Avoid computing -inf - inf, which is nan. If reflection_denom is +/-inf,
  // then it "wins" and the result is +/-inf.
  Value denomIsFinite =
      rewriter.create<mlir::stablehlo::IsFiniteOp>(loc, reflectionDenom);
  Value negatedDenom =
      rewriter.create<mlir::stablehlo::NegOp>(loc, reflectionDenom);
  reflected = rewriter.create<mlir::stablehlo::SelectOp>(
      loc, denomIsFinite, reflected, negatedDenom);

  // Select whether or not to rely on the reflection.
  lgamma = rewriter.create<mlir::stablehlo::SelectOp>(loc, useReflection,
                                                      reflected, lgamma);

  // Materialize +/-inf behavior as
  //   lgamma(+/-inf) = +inf.
  Value xIsInf = rewriter.create<chlo::IsInfOp>(loc, x);
  Value posInf = getConstantLikeInfValue(rewriter, loc, x, /*negative=*/false);
  return rewriter.create<mlir::stablehlo::SelectOp>(loc, xIsInf, posInf,
                                                    lgamma);
}

// Express `cosh` as
//   cosh(x) = (e^x + e^-x) / 2
//           = e^(x + log(1/2)) + e^(-x + log(1/2))
//
// The second formulation avoids overflowing when e^x = inf but (e^x)/2 is not.
//
// This incorrectly overflows to inf for two f32 input values, namely
// +/-89.4159851, due to rounding error when computing x +/- log(1/2).  The
// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
// we deem this acceptable.
static Value materializeCoshApproximation(ConversionPatternRewriter &rewriter,
                                          Location loc, ValueRange operands) {
  mlir::chlo::CoshOp::Adaptor adaptor(operands);
  Value input = adaptor.getOperand();

  // log(1/2), computed by a log op applied to a constant shaped like the
  // input.
  Value halfConst = getConstantLike(rewriter, loc, 0.5, input);
  Value logHalf = rewriter.create<mlir::stablehlo::LogOp>(loc, halfConst);

  // e^(x + log(1/2))
  Value shiftedUp =
      rewriter.create<mlir::stablehlo::AddOp>(loc, input, logHalf);
  Value expOfShiftedUp =
      rewriter.create<mlir::stablehlo::ExpOp>(loc, shiftedUp);

  // e^(-x + log(1/2))
  Value shiftedDown =
      rewriter.create<mlir::stablehlo::SubtractOp>(loc, logHalf, input);
  Value expOfShiftedDown =
      rewriter.create<mlir::stablehlo::ExpOp>(loc, shiftedDown);

  // The two halves add up to cosh(x).
  return rewriter.create<mlir::stablehlo::AddOp>(loc, expOfShiftedUp,
                                                 expOfShiftedDown);
}

// Replaces `chlo.cosh` with the expansion built by
// materializeCoshApproximation, routed through materializeWithUpcast with f32
// as the precision type.
struct ConvertCoshOp final : OpConversionPattern<mlir::chlo::CoshOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(mlir::chlo::CoshOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value lowered = materializeWithUpcast(
        rewriter, op.getLoc(), adaptor.getOperands(), rewriter.getF32Type(),
        &materializeCoshApproximation);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};

// Compute the Digamma function using Lanczos' approximation from "A Precision
// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
// series B. Vol. 1:
//   digamma(z + 1) = log(t(z)) + a'(z) / a(z) - kLanczosGamma / t(z)
//   with   t(z) = z + kLanczosGamma + 1/2
//          a(z) = kBaseLanczosCoeff
//                   + sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k))
//          a'(z) = - sum(k = 1, n,
//                        kLanczosCoefficients[k - 1] / (z + k) / (z + k))
static Value materializeDigamma(ConversionPatternRewriter &rewriter,
                                Location loc, ValueRange args) {
  Value x = args.front();

  // Shorthand: a constant shaped like the input.
  auto constLike = [&](auto value) -> Value {
    return getConstantLike(rewriter, loc, value, x);
  };

  // Inputs below 0.5 are handled through Euler's reflection formula
  //   digamma(x) = digamma(1 - x) - pi * cot(pi * x)
  // so the Lanczos series is evaluated at
  //   z = -x      if x < 1/2
  //   z = x - 1   otherwise
  Value half = constLike(0.5);
  Value needToReflect = rewriter.create<mlir::stablehlo::CompareOp>(
      loc, x, half, mlir::stablehlo::ComparisonDirection::LT);
  Value negX = rewriter.create<mlir::stablehlo::NegOp>(loc, x);
  Value one = constLike(1);
  Value xMinusOne = rewriter.create<mlir::stablehlo::SubtractOp>(loc, x, one);
  Value z = rewriter.create<mlir::stablehlo::SelectOp>(loc, needToReflect, negX,
                                                       xMinusOne);

  // Accumulate
  //   a(z)  = kBaseLanczosCoeff
  //             + sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k))
  //   a'(z) = - sum(k = 1, n,
  //                 kLanczosCoefficients[k - 1] / (z + k) / (z + k))
  Value zero = constLike(0.0);
  Value series = constLike(kBaseLanczosCoeff);
  Value seriesDeriv = zero;
  const int numCoeffs = kLanczosCoefficients.size();
  for (int k = 1; k <= numCoeffs; ++k) {
    Value coeff = constLike(kLanczosCoefficients[k - 1]);
    Value kConst = constLike(k);
    Value zPlusK = rewriter.create<mlir::stablehlo::AddOp>(loc, z, kConst);
    Value zPlusKSquared =
        rewriter.create<mlir::stablehlo::MulOp>(loc, zPlusK, zPlusK);
    Value derivTerm =
        rewriter.create<mlir::stablehlo::DivOp>(loc, coeff, zPlusKSquared);
    seriesDeriv = rewriter.create<mlir::stablehlo::SubtractOp>(loc, seriesDeriv,
                                                               derivTerm);
    Value seriesTerm =
        rewriter.create<mlir::stablehlo::DivOp>(loc, coeff, zPlusK);
    series = rewriter.create<mlir::stablehlo::AddOp>(loc, series, seriesTerm);
  }

  // For better accuracy on platforms with a less-precise log, fold
  // log(kLanczosGamma + 1/2) at compile time and leave only a log1p for the
  // device:
  //   log(t) = log(kLanczosGamma + 1/2 + z)
  //          = log(kLanczosGamma + 1/2) + log1p(z / (kLanczosGamma + 1/2)).
  Value gammaPlusHalf = constLike(kLanczosGamma + 0.5);
  Value t = rewriter.create<mlir::stablehlo::AddOp>(loc, gammaPlusHalf, z);
  Value logGammaPlusHalf = constLike(std::log(kLanczosGamma + 0.5));
  Value zOverGammaPlusHalf =
      rewriter.create<mlir::stablehlo::DivOp>(loc, z, gammaPlusHalf);
  Value log1pTerm =
      rewriter.create<mlir::stablehlo::Log1pOp>(loc, zOverGammaPlusHalf);
  Value logT = rewriter.create<mlir::stablehlo::AddOp>(loc, logGammaPlusHalf,
                                                       log1pTerm);

  // Result before reflection:
  //   digamma(z + 1) = log(t(z)) + a'(z) / a(z) - kLanczosGamma / t(z).
  Value derivOverSeries =
      rewriter.create<mlir::stablehlo::DivOp>(loc, seriesDeriv, series);
  Value gammaConst = constLike(kLanczosGamma);
  Value gammaOverT =
      rewriter.create<mlir::stablehlo::DivOp>(loc, gammaConst, t);
  Value logTPlusRatio =
      rewriter.create<mlir::stablehlo::AddOp>(loc, logT, derivOverSeries);
  Value digamma = rewriter.create<mlir::stablehlo::SubtractOp>(
      loc, logTPlusRatio, gammaOverT);

  // cot(pi * input) needs care: for near-integral arguments pi * input can
  // lose precision.
  //
  // Reflection only applies to inputs below 0.5, so values smaller than -0.5
  // are shifted into [-0.5, 0.5] to improve the precision of pi * x and of
  // the resulting cotangent.
  Value halfForShift = constLike(0.5);
  Value xPlusHalf =
      rewriter.create<mlir::stablehlo::AddOp>(loc, x, halfForShift);
  Value floorXPlusHalf =
      rewriter.create<mlir::stablehlo::FloorOp>(loc, xPlusHalf);
  Value shift = rewriter.create<mlir::stablehlo::AbsOp>(loc, floorXPlusHalf);
  Value reducedX = rewriter.create<mlir::stablehlo::AddOp>(loc, x, shift);

  // Reflection for inputs less than 0.5:
  //   digamma(x) = digamma(1 - x) - pi * cot(pi * x)
  //              = digamma(1 - x) - pi * cos(pi * x) / sin(pi * x)
  Value pi = constLike(M_PI);
  Value piTimesReducedX =
      rewriter.create<mlir::stablehlo::MulOp>(loc, pi, reducedX);
  Value cosine =
      rewriter.create<mlir::stablehlo::CosineOp>(loc, piTimesReducedX);
  Value sine = rewriter.create<mlir::stablehlo::SineOp>(loc, piTimesReducedX);
  Value piTimesCos = rewriter.create<mlir::stablehlo::MulOp>(loc, pi, cosine);
  Value piCot = rewriter.create<mlir::stablehlo::DivOp>(loc, piTimesCos, sine);
  Value reflected =
      rewriter.create<mlir::stablehlo::SubtractOp>(loc, digamma, piCot);

  // Pick the reflected value only where reflection was required.
  digamma = rewriter.create<mlir::stablehlo::SelectOp>(loc, needToReflect,
                                                       reflected, digamma);

  // Digamma has poles at zero and at the negative integers; yield nan there.
  Value isLeZero = rewriter.create<mlir::stablehlo::CompareOp>(
      loc, x, zero, mlir::stablehlo::ComparisonDirection::LE);
  Value floorX = rewriter.create<mlir::stablehlo::FloorOp>(loc, x);
  Value isInt = rewriter.create<mlir::stablehlo::CompareOp>(
      loc, x, floorX, mlir::stablehlo::ComparisonDirection::EQ);
  Value isPole = rewriter.create<mlir::stablehlo::AndOp>(loc, isLeZero, isInt);
  Value nan = constLike(std::numeric_limits<double>::quiet_NaN());
  return rewriter.create<mlir::stablehlo::SelectOp>(loc, isPole, nan, digamma);
}

// Returns a constant shaped like `val` holding the value that
// llvm::APFloat::getSmallest yields for the float element type of `val`.
// The element type must be a FloatType (the cast asserts otherwise).
static Value getConstantLikeSmallestFiniteValue(OpBuilder &b, Location loc,
                                                Value val) {
  Type elementTy = getElementTypeOrSelf(val.getType());
  auto floatTy = cast<FloatType>(elementTy);
  const llvm::fltSemantics &semantics = floatTy.getFloatSemantics();
  return getConstantLike(b, loc, llvm::APFloat::getSmallest(semantics), val);
}

// Emits the op sequence that evaluates zeta(x, q) for `args = {x, q}`.
//
// The leading terms q^-x + (q + 1)^-x + ... + (q + 9)^-x are summed directly;
// a correction built from kZetaCoeffs (Euler-Maclaurin expansion, evaluated
// with Horner's rule) is then added. A chain of selects at the end patches in
// nan/inf for the special cases documented next to each select; later selects
// take priority over earlier ones.
static Value materializeZeta(ConversionPatternRewriter &rewriter, Location loc,
                             ValueRange args) {
  // Code should match StableHLO's materializeZeta
  assert(args.size() == 2);
  Value x = args[0];
  Value q = args[1];
  // Coefficients of the Euler-Maclaurin correction; they are used through
  // their reciprocals (1. / kZetaCoeffs[i]) below.
  static const std::array<double, 12> kZetaCoeffs{
      -7.1661652561756670113e18,
      1.8152105401943546773e17,
      -4.5979787224074726105e15,
      1.1646782814350067249e14,
      -2.950130727918164224e12,
      7.47242496e10,
      -1.8924375803183791606e9,
      47900160.0,
      -1209600.0,
      30240.0,
      -720.0,
      12.0,
  };

  // For speed we'll always use 9 iterations for the initial series estimate,
  // and a 12 term expansion for the Euler-Maclaurin formula.
  Value a = q;
  Value zero = getConstantLike(rewriter, loc, 0.0, a);
  Value negPower = zero;
  Value negX = rewriter.create<mlir::stablehlo::NegOp>(loc, x);
  // initialSum starts at q^-x and accumulates (q + i)^-x for i = 1..9.
  Value initialSum = rewriter.create<mlir::stablehlo::PowOp>(loc, q, negX);
  Value one = getConstantLike(rewriter, loc, 1.0, a);
  for (int i = 0; i < 9; ++i) {
    a = rewriter.create<mlir::stablehlo::AddOp>(loc, a, one);
    negPower = rewriter.create<mlir::stablehlo::PowOp>(loc, a, negX);
    initialSum =
        rewriter.create<mlir::stablehlo::AddOp>(loc, initialSum, negPower);
  }

  // After this step a = q + 10 and negPower = a^-x, the first term that is
  // not part of initialSum.
  a = rewriter.create<mlir::stablehlo::AddOp>(loc, a, one);
  negPower = rewriter.create<mlir::stablehlo::PowOp>(loc, a, negX);
  Value oneLikeX = getConstantLike(rewriter, loc, 1.0, x);
  Value xMinusOne =
      rewriter.create<mlir::stablehlo::SubtractOp>(loc, x, oneLikeX);
  Value negPowerMulA =
      rewriter.create<mlir::stablehlo::MulOp>(loc, negPower, a);
  Value negPowerMulADivXMinusOne =
      rewriter.create<mlir::stablehlo::DivOp>(loc, negPowerMulA, xMinusOne);
  // s = initialSum + a^-x * a / (x - 1)
  Value s = rewriter.create<mlir::stablehlo::AddOp>(loc, initialSum,
                                                    negPowerMulADivXMinusOne);
  Value aInverseSquare = rewriter.create<mlir::stablehlo::DivOp>(
      loc, one, rewriter.create<mlir::stablehlo::MulOp>(loc, a, a));

  Value hornerSum = zero;
  Value factor = one;
  // Use Horner's rule for this.
  // Note this differs from Cephes which does a 'naive' polynomial evaluation.
  // Using Horner's rule allows to avoid some NaN's and Infs from happening,
  // resulting in more numerically stable code.
  for (int i = 0; i < 11; ++i) {
    // factor = (x + 22 - 2i) * (x + 21 - 2i)
    Value factorLhs = rewriter.create<mlir::stablehlo::AddOp>(
        loc, x, getConstantLike(rewriter, loc, 22 - 2 * i, x));
    Value factorRhs = rewriter.create<mlir::stablehlo::AddOp>(
        loc, x, getConstantLike(rewriter, loc, 21 - 2 * i, x));
    factor = rewriter.create<mlir::stablehlo::MulOp>(loc, factorLhs, factorRhs);
    // hornerSum = factor * (1 / a^2) * (hornerSum + 1 / kZetaCoeffs[i])
    hornerSum = rewriter.create<mlir::stablehlo::MulOp>(
        loc, factor,
        rewriter.create<mlir::stablehlo::MulOp>(
            loc, aInverseSquare,
            rewriter.create<mlir::stablehlo::AddOp>(
                loc, hornerSum,
                getConstantLike(rewriter, loc, 1. / kZetaCoeffs[i], a))));
  }
  Value zeroPointFiveLikeNegPower =
      getConstantLike(rewriter, loc, .5, negPower);
  Value xDivA = rewriter.create<mlir::stablehlo::DivOp>(loc, x, a);
  // s += a^-x * (0.5 + (x / a) * (1 / kZetaCoeffs[11] + hornerSum))
  s = rewriter.create<mlir::stablehlo::AddOp>(
      loc, s,
      rewriter.create<mlir::stablehlo::MulOp>(
          loc, negPower,
          rewriter.create<mlir::stablehlo::AddOp>(
              loc, zeroPointFiveLikeNegPower,
              rewriter.create<mlir::stablehlo::MulOp>(
                  loc, xDivA,
                  rewriter.create<mlir::stablehlo::AddOp>(
                      loc,
                      getConstantLike(rewriter, loc, 1. / kZetaCoeffs[11], a),
                      hornerSum)))));

  // Use the initial zeta sum without the correction term coming
  // from Euler-Maclaurin if it is accurate enough.
  Value absNegPower = rewriter.create<mlir::stablehlo::AbsOp>(loc, negPower);
  Value absInitialSum =
      rewriter.create<mlir::stablehlo::AbsOp>(loc, initialSum);
  // "Accurate enough" means |a^-x| < |initialSum| * smallest finite value.
  Value output = rewriter.create<mlir::stablehlo::SelectOp>(
      loc,
      rewriter.create<mlir::stablehlo::CompareOp>(
          loc, absNegPower,
          rewriter.create<mlir::stablehlo::MulOp>(
              loc, absInitialSum,
              getConstantLikeSmallestFiniteValue(rewriter, loc, a)),
          mlir::stablehlo::ComparisonDirection::LT),
      initialSum, s);

  // Function is not defined for x < 1.
  Value nan = getConstantLike(rewriter, loc,
                              std::numeric_limits<double>::quiet_NaN(), x);
  output = rewriter.create<mlir::stablehlo::SelectOp>(
      loc,
      rewriter.create<mlir::stablehlo::CompareOp>(
          loc, x, oneLikeX, mlir::stablehlo::ComparisonDirection::LT),
      nan, output);

  // For q <= 0, x must be an integer.
  Value qLeZero = rewriter.create<mlir::stablehlo::CompareOp>(
      loc, q, zero, mlir::stablehlo::ComparisonDirection::LE);
  Value xNotInt = rewriter.create<mlir::stablehlo::CompareOp>(
      loc, x, rewriter.create<mlir::stablehlo::FloorOp>(loc, x),
      mlir::stablehlo::ComparisonDirection::NE);
  Value xDomainError =
      rewriter.create<mlir::stablehlo::AndOp>(loc, qLeZero, xNotInt);
  output = rewriter.create<mlir::stablehlo::SelectOp>(loc, xDomainError, nan,
                                                      output);

  // For all integer q <= 0, zeta has a pole. The limit is only defined as
  // +inf if x is an even integer.
  Value inf = getConstantLike(rewriter, loc,
                              std::numeric_limits<double>::infinity(), x);
  Value qIsInt = rewriter.create<mlir::stablehlo::CompareOp>(
      loc, q, rewriter.create<mlir::stablehlo::FloorOp>(loc, q),
      mlir::stablehlo::ComparisonDirection::EQ);
  Value atPole = rewriter.create<mlir::stablehlo::AndOp>(loc, qLeZero, qIsInt);
  Value two = getConstantLike(rewriter, loc, 2.0, x);
  Value xIsInt = rewriter.create<mlir::stablehlo::CompareOp>(
      loc, x, rewriter.create<mlir::stablehlo::FloorOp>(loc, x),
      mlir::stablehlo::ComparisonDirection::EQ);
  Value xIsEven = rewriter.create<mlir::stablehlo::CompareOp>(
      loc, rewriter.create<mlir::stablehlo::RemOp>(loc, x, two), zero,
      mlir::stablehlo::ComparisonDirection::EQ);
  Value xIsEvenInt =
      rewriter.create<mlir::stablehlo::AndOp>(loc, xIsInt, xIsEven);
  // At a pole: +inf for even integer x, nan otherwise.
  output = rewriter.create<mlir::stablehlo::SelectOp>(
      loc, atPole,
      rewriter.create<mlir::stablehlo::SelectOp>(loc, xIsEvenInt, inf, nan),
      output);

  // For x = 1, this is the harmonic series and diverges.
  output = rewriter.create<mlir::stablehlo::SelectOp>(
      loc,
      rewriter.create<mlir::stablehlo::CompareOp>(
          loc, x, one, mlir::stablehlo::ComparisonDirection::EQ),
      inf, output);

  return output;
}

// Emits polygamma(n, x) for `args` decoded as chlo.polygamma operands:
//   - general case: (2 * (n rem 2) - 1) * exp(lgamma(n + 1)) * zeta(n + 1, x)
//   - n == 0:       digamma(x)
//   - nan when n is negative or not an integer.
// lgamma, zeta and digamma are emitted as CHLO ops.
static Value materializePolygamma(ConversionPatternRewriter &rewriter,
                                  Location loc, ValueRange args) {
  mlir::chlo::PolygammaOp::Adaptor adaptor(args);
  Value n = adaptor.getN();
  Value x = adaptor.getX();

  // Integer n > 0: the sign factor is +1 for odd n and -1 for even n.
  Value one = getConstantLike(rewriter, loc, 1.0, x);
  Value two = getConstantLike(rewriter, loc, 2.0, x);
  Value nRemTwo = rewriter.create<mlir::stablehlo::RemOp>(loc, n, two);
  Value twiceNRemTwo =
      rewriter.create<mlir::stablehlo::MulOp>(loc, two, nRemTwo);
  Value sign =
      rewriter.create<mlir::stablehlo::SubtractOp>(loc, twiceNRemTwo, one);
  Value nPlusOne = rewriter.create<mlir::stablehlo::AddOp>(loc, n, one);
  Value lgammaNp1 = rewriter.create<chlo::LgammaOp>(loc, nPlusOne);
  Value expLgammaNp1 = rewriter.create<mlir::stablehlo::ExpOp>(loc, lgammaNp1);
  Value zeta = rewriter.create<chlo::ZetaOp>(loc, nPlusOne, x);
  Value signedScale =
      rewriter.create<mlir::stablehlo::MulOp>(loc, sign, expLgammaNp1);
  Value result =
      rewriter.create<mlir::stablehlo::MulOp>(loc, signedScale, zeta);

  // n == 0 falls back to digamma.
  Value zero = getConstantLike(rewriter, loc, 0.0, x);
  Value nIsZero = rewriter.create<mlir::stablehlo::CompareOp>(
      loc, n, zero, mlir::stablehlo::ComparisonDirection::EQ);
  Value digamma = rewriter.create<chlo::DigammaOp>(loc, x);
  result = rewriter.create<mlir::stablehlo::SelectOp>(loc, nIsZero, digamma,
                                                      result);

  // n must be a natural number; anything else yields nan.
  Value floorN = rewriter.create<mlir::stablehlo::FloorOp>(loc, n);
  Value nIsFractional = rewriter.create<mlir::stablehlo::CompareOp>(
      loc, n, floorN, mlir::stablehlo::ComparisonDirection::NE);
  Value nIsNegative = rewriter.create<mlir::stablehlo::CompareOp>(
      loc, n, zero, mlir::stablehlo::ComparisonDirection::LT);
  Value nIsNotNatural =
      rewriter.create<mlir::stablehlo::OrOp>(loc, nIsFractional, nIsNegative);
  Value nan = getConstantLike(rewriter, loc,
                              std::numeric_limits<double>::quiet_NaN(), x);
  return rewriter.create<mlir::stablehlo::SelectOp>(loc, nIsNotNatural, nan,
                                                    result);
}

// Replaces `chlo.lgamma` with the expansion built by materializeLgamma,
// routed through materializeWithUpcast with f32 as the precision type.
struct ConvertLgammaOp final : OpConversionPattern<mlir::chlo::LgammaOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(mlir::chlo::LgammaOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    FloatType f32Ty = rewriter.getF32Type();
    Value lowered = materializeWithUpcast(
        rewriter, op.getLoc(), adaptor.getOperands(), f32Ty, &materializeLgamma);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};

// Replaces `chlo.digamma` with the expansion built by materializeDigamma,
// routed through materializeWithUpcast with f32 as the precision type.
struct ConvertDigammaOp final : OpConversionPattern<mlir::chlo::DigammaOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(mlir::chlo::DigammaOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    FloatType f32Ty = rewriter.getF32Type();
    Value lowered =
        materializeWithUpcast(rewriter, op.getLoc(), adaptor.getOperands(),
                              f32Ty, &materializeDigamma);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};

// Emits the op sequence for `chlo.next_after(x, y)`.
//
// Both operands are bitcast to integers of the same bitwidth and the result
// is computed on those bit patterns, then bitcast back to the type of `x`.
// The general-case result is `x`'s bit pattern plus or minus one; a chain of
// selects then overrides it for the special cases. Later selects take
// priority, so the effective precedence is:
//   nan input  >  x == y  >  x is +/-0  >  general case.
// `x` must have a ShapedType (the cast below asserts otherwise).
static Value materializeNextAfter(ConversionPatternRewriter &rewriter,
                                  Location loc, ValueRange operands) {
  mlir::chlo::NextAfterOp::Adaptor transformed(operands);
  Value x = transformed.getX();
  Value y = transformed.getY();
  auto resultTy = cast<ShapedType>(x.getType());
  auto bitwidth = resultTy.getElementType().getIntOrFloatBitWidth();
  mlir::ImplicitLocOpBuilder b(loc, rewriter);
  // Integer type with the same shape and element bitwidth as the input.
  Type intTy = resultTy.clone(b.getIntegerType(bitwidth));
  auto xAsInt = b.create<mlir::stablehlo::BitcastConvertOp>(intTy, x);
  auto yAsInt = b.create<mlir::stablehlo::BitcastConvertOp>(intTy, y);

  // The result is NaN if either "x" or "y" are NaN.
  // (A value compares unequal to itself only when it is NaN.)
  auto xIsNan = b.create<mlir::stablehlo::CompareOp>(
      x, x, mlir::stablehlo::ComparisonDirection::NE);
  auto yIsNan = b.create<mlir::stablehlo::CompareOp>(
      y, y, mlir::stablehlo::ComparisonDirection::NE);
  auto nanInput = b.create<mlir::stablehlo::OrOp>(xIsNan, yIsNan);
  auto resultForNan = getConstantLike(
      rewriter, loc, std::numeric_limits<double>::quiet_NaN(), x);
  auto resultForNanAsInt =
      b.create<mlir::stablehlo::BitcastConvertOp>(intTy, resultForNan);

  // The sign bit is the MSB.
  const int64_t signBit = int64_t{1} << (bitwidth - 1);
  // Discard the sign bit to make the result non-negative.
  Value signMask = getConstantLike(rewriter, loc, signBit, xAsInt);
  Value negatedSignMask = getConstantLike(rewriter, loc, ~signBit, xAsInt);
  auto xAbs = b.create<mlir::stablehlo::AndOp>(xAsInt, negatedSignMask);
  auto yAbs = b.create<mlir::stablehlo::AndOp>(yAsInt, negatedSignMask);

  // When both "x" and "y" are equal, the result is "y".
  auto xAndYAreEqual = b.create<mlir::stablehlo::CompareOp>(
      x, y, mlir::stablehlo::ComparisonDirection::EQ);
  auto resultForEqual = yAsInt;

  // When both "x" and "y" are 0, the result is "y". This is a separate case
  // from above because "x" and "y" might have a different sign.
  Value zero = getConstantLike(rewriter, loc, 0, xAsInt);
  auto xIsZero = b.create<mlir::stablehlo::CompareOp>(
      xAbs, zero, mlir::stablehlo::ComparisonDirection::EQ);
  auto yIsZero = b.create<mlir::stablehlo::CompareOp>(
      yAbs, zero, mlir::stablehlo::ComparisonDirection::EQ);
  auto resultForBothZero = yAsInt;

  // Isolated sign bits of both operands.
  auto xSign = b.create<mlir::stablehlo::AndOp>(xAsInt, signMask);
  auto ySign = b.create<mlir::stablehlo::AndOp>(yAsInt, signMask);

  // If from == 0 && to != 0, we need to return the smallest subnormal number
  // signed like "to".
  Value one = getConstantLike(rewriter, loc, 1, xAsInt);
  auto resultForXZeroYNonZero = b.create<mlir::stablehlo::OrOp>(ySign, one);

  // If the sign of "x" and "y" disagree:
  // - we need to make the magnitude of "from" smaller so that it is closer to
  //   zero.
  //
  // Otherwise the signs agree:
  // - "x" with a magnitude larger than "y" means we need to make the magnitude
  //   smaller.
  // - "x" with a magnitude smaller than "y" means we need to make the magnitude
  //   larger.
  auto signsDisagree = b.create<mlir::stablehlo::CompareOp>(
      xSign, ySign, mlir::stablehlo::ComparisonDirection::NE);
  auto xMagnitudeLargerThanY = b.create<mlir::stablehlo::CompareOp>(
      xAbs, yAbs, mlir::stablehlo::ComparisonDirection::GT);
  auto resultHasSmallerMagnitude =
      b.create<mlir::stablehlo::OrOp>(xMagnitudeLargerThanY, signsDisagree);
  auto minusOne = getConstantLike(rewriter, loc, -1, xAsInt);
  // General case: add -1 or +1 to the bit pattern of "x".
  auto magnitudeAdjustment = b.create<mlir::stablehlo::SelectOp>(
      resultHasSmallerMagnitude, minusOne, one);
  Value result = b.create<mlir::stablehlo::AddOp>(xAsInt, magnitudeAdjustment);
  // Handle from == +-0.
  result = b.create<mlir::stablehlo::SelectOp>(
      xIsZero,
      b.create<mlir::stablehlo::SelectOp>(yIsZero, resultForBothZero,
                                          resultForXZeroYNonZero),
      result);
  // Handle from == to.
  result = b.create<mlir::stablehlo::SelectOp>(xAndYAreEqual, resultForEqual,
                                               result);
  // Handle isnan(x) || isnan(y).
  result =
      b.create<mlir::stablehlo::SelectOp>(nanInput, resultForNanAsInt, result);

  // Cast back to the original type.
  return b.create<mlir::stablehlo::BitcastConvertOp>(resultTy, result);
}

// Replaces `chlo.next_after` with the bit-level expansion built by
// materializeNextAfter.
struct ConvertNextAfterOp final : OpConversionPattern<mlir::chlo::NextAfterOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(mlir::chlo::NextAfterOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value lowered =
        materializeNextAfter(rewriter, op.getLoc(), adaptor.getOperands());
    rewriter.replaceOp(op, lowered);
    return success();
  }
};

// Replaces `chlo.polygamma` with the expansion built by materializePolygamma,
// routed through materializeWithUpcast with f32 as the precision type.
struct ConvertPolygammaOp final : OpConversionPattern<mlir::chlo::PolygammaOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(mlir::chlo::PolygammaOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    FloatType f32Ty = rewriter.getF32Type();
    Value lowered =
        materializeWithUpcast(rewriter, op.getLoc(), adaptor.getOperands(),
                              f32Ty, materializePolygamma);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};

// Sinh(x) = (e^x - e^-x) / 2
//         = e^(x + log(1/2)) - e^(-x + log(1/2)).
//
// The second formulation avoids overflowing when e^x = inf but (e^x)/2 is not
// inf.
//
// This incorrectly overflows to +/-inf for two f32 input values, namely
// +/-89.4159851, due to rounding error when computing x +/- log(1/2).  The
// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
// we deem this acceptable.
static Value
materializeSinhApproximationForLargeX(ConversionPatternRewriter &rewriter,
                                      Location loc, ValueRange operands) {
  mlir::chlo::SinhOp::Adaptor adaptor(operands);
  Value input = adaptor.getOperand();

  // log(1/2), computed by a log op applied to a constant shaped like the
  // input.
  Value halfConst = getConstantLike(rewriter, loc, 0.5, input);
  Value logHalf = rewriter.create<mlir::stablehlo::LogOp>(loc, halfConst);

  // e^(x + log(1/2))
  Value shiftedUp =
      rewriter.create<mlir::stablehlo::AddOp>(loc, input, logHalf);
  Value expOfShiftedUp =
      rewriter.create<mlir::stablehlo::ExpOp>(loc, shiftedUp);

  // e^(-x + log(1/2))
  Value shiftedDown =
      rewriter.create<mlir::stablehlo::SubtractOp>(loc, logHalf, input);
  Value expOfShiftedDown =
      rewriter.create<mlir::stablehlo::ExpOp>(loc, shiftedDown);

  // Their difference is sinh(x).
  return rewriter.create<mlir::stablehlo::SubtractOp>(loc, expOfShiftedUp,
                                                      expOfShiftedDown);
}

// Express `sinh` as
//   sinh(x) = (e^x - e^-x) / 2                     if |x| < 1
//           = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise.
static Value materializeSinhApproximation(ConversionPatternRewriter &rewriter,
                                          Location loc, ValueRange operands) {
  // The large-|x| expansion is emitted unconditionally; the select at the end
  // decides per element which result is used.
  Value resultForLargeX =
      materializeSinhApproximationForLargeX(rewriter, loc, operands);

  mlir::chlo::SinhOp::Adaptor adaptor(operands);
  Value input = adaptor.getOperand();

  // For small x, e^x - e^-x cancels and collapses to 0. Going through
  // expm1(x) keeps the first-order term of the Taylor series of e^x:
  //   (e^x - e^-x) / 2
  //     = (e^x - 1 + 1 - e^-x) / 2
  //     = (expm1(x) + (e^x - 1) / e^x) / 2
  //     = (expm1(x) + expm1(x) / (expm1(x) + 1)) / 2
  Value expMinusOne = rewriter.create<mlir::stablehlo::Expm1Op>(loc, input);
  Value one = getConstantLike(rewriter, loc, 1.0, input);
  Value half = getConstantLike(rewriter, loc, 0.5, input);
  Value expValue =
      rewriter.create<mlir::stablehlo::AddOp>(loc, expMinusOne, one);
  Value expMinusOneOverExp =
      rewriter.create<mlir::stablehlo::DivOp>(loc, expMinusOne, expValue);
  Value numerator = rewriter.create<mlir::stablehlo::AddOp>(
      loc, expMinusOne, expMinusOneOverExp);
  Value resultForSmallX =
      rewriter.create<mlir::stablehlo::MulOp>(loc, half, numerator);

  // Use the small-x form exactly when |x| < 1.
  Value magnitude = rewriter.create<mlir::stablehlo::AbsOp>(loc, input);
  Value isSmall = rewriter.create<mlir::stablehlo::CompareOp>(
      loc, magnitude, one, mlir::stablehlo::ComparisonDirection::LT);
  return rewriter.create<mlir::stablehlo::SelectOp>(
      loc, isSmall, resultForSmallX, resultForLargeX);
}

// Replaces `chlo.sinh`. Complex element types use the large-x expansion
// directly; all other element types use materializeSinhApproximation, routed
// through materializeWithUpcast with f32 as the precision type.
struct ConvertSinhOp final : OpConversionPattern<mlir::chlo::SinhOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(mlir::chlo::SinhOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value x = adaptor.getOperand();
    Type elementTy = cast<ShapedType>(x.getType()).getElementType();

    Value lowered;
    if (isa<ComplexType>(elementTy)) {
      lowered = materializeSinhApproximationForLargeX(rewriter, op.getLoc(),
                                                      adaptor.getOperands());
    } else {
      lowered = materializeWithUpcast(
          rewriter, op.getLoc(), adaptor.getOperands(), rewriter.getF32Type(),
          &materializeSinhApproximation);
    }
    rewriter.replaceOp(op, lowered);
    return success();
  }
};

// Converts chlo.top_k to HLO iota, sort, and slice ops.
//
// chlo.top_k sorts along last dimension of the input tensor and then returns
// the top K components' values and indices. This is translated into a few
// ops in HLO: first generating an integer sequence for the indices,
// then sort both the original input tensor and the indices together, and
// at last slice out the top K components.
//
// For example, for the following IR:
//
// %0:2 = "chlo.top_k"(%input, k=8): tensor<16x16xf32> ->
//                                   (tensor<16x8xf32>, tensor<16x8xi32>)
//
// We will get:
//
// %1 = "hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<16x16xi32>
// %2 = "hlo.sort"(%input, %1) ({
// ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>,
//      %arg3: tensor<i32>, %arg4: tensor<i32>):
//   %7 = "hlo.compare"(%arg1, %arg2) {comparison_direction = "GT"}: ...
//   "hlo.return"(%7) : (tensor<i1>) -> ()
// }) {dimension = 1 : i64, is_stable = true} : ...
// %3 = "hlo.get_tuple_element"(%2) {index = 0 : i32} : ...
// %4 = "hlo.get_tuple_element"(%2) {index = 1 : i32} : ...
// %5 = "hlo.slice"(%3) {limit_indices = dense<[16, 8]> : tensor<2xi64>,
//                           start_indices = dense<0> : tensor<2xi64>,
//                           strides = dense<1> : tensor<2xi64>} :
//                              (tensor<16x16xf32>) -> tensor<16x8xf32>
// %6 = "hlo.slice"(%4) ...
//
// TODO(b/284078162): Decide what to do with this pattern given that we now
// have mlir::stablehlo::TopKOp. No action needed for now given that
// mlir::stablehlo::TopKOp is currently categorized as
// `hasPrivateFeaturesNotInStablehlo`.
struct ConvertTopKOp final : OpConversionPattern<mlir::chlo::TopKOp> {
  using OpConversionPattern::OpConversionPattern;

  // Rewrites `op` into an iota over the last dimension, a stable descending
  // sort of (operand, iota) along that dimension, and a slice keeping the
  // leading `k` entries of both sorted results.
  //
  // Fails to match when the operand is not a ranked tensor or has rank 0.
  // Operands with any dynamic dimension take the dynamic_iota /
  // real_dynamic_slice path; fully static operands use iota / slice.
  LogicalResult
  matchAndRewrite(mlir::chlo::TopKOp op, OpAdaptor /*adaptor*/,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto operandType = dyn_cast<RankedTensorType>(op.getOperand().getType());
    if (!operandType)
      return failure();
    int64_t operandRank = operandType.getRank();
    // A rank-0 operand has no last dimension to sort along; bail out rather
    // than querying dimension -1 and calling back() on empty vectors below.
    if (operandRank < 1)
      return rewriter.notifyMatchFailure(op, "expected operand of rank >= 1");
    int64_t lastDimIndex = operandRank - 1;
    int64_t lastDimSize = operandType.getDimSize(lastDimIndex);
    // k is clamped to the last dimension only when that size is static.
    int64_t k = static_cast<int64_t>(op.getK());
    int64_t lastDimResultSize =
        mlir::hlo::isDynamicDimSize(lastDimSize) ? k
                                                 : std::min(k, lastDimSize);
    bool isDynamic = !operandType.hasStaticShape();
    auto i32Type = rewriter.getIntegerType(32);
    Value opShapeValue, resultShapeValue;
    if (isDynamic) {
      // Build the operand shape as a 1-D i32 tensor by concatenating the
      // individual dimension sizes.
      SmallVector<Value> sizesI32x1;
      for (int64_t i = 0; i < operandRank; ++i) {
        auto sizeI32 = rewriter.create<mlir::stablehlo::GetDimensionSizeOp>(
            loc, op.getOperand(), i);
        auto sizeI32x1 = rewriter.create<mlir::stablehlo::ReshapeOp>(
            loc, RankedTensorType::get({1}, i32Type), sizeI32);
        sizesI32x1.push_back(sizeI32x1);
      }
      opShapeValue = rewriter.create<mlir::stablehlo::ConcatenateOp>(
          loc, sizesI32x1,
          /*dimension=*/0);
      // The result shape equals the operand shape with the last entry
      // replaced by the (static) result size.
      auto lastDimI32 = rewriter.create<mlir::stablehlo::ConstantOp>(
          loc,
          rewriter.getI32IntegerAttr(static_cast<int32_t>(lastDimResultSize)));
      auto lastDimI32x1 = rewriter.create<mlir::stablehlo::ReshapeOp>(
          loc, RankedTensorType::get({1}, i32Type), lastDimI32);
      sizesI32x1.back() = lastDimI32x1;
      resultShapeValue = rewriter.create<mlir::stablehlo::ConcatenateOp>(
          loc, sizesI32x1,
          /*dimension=*/0);
    }

    // Create an Iota op for indices.
    Type iotaType = RankedTensorType::get(operandType.getShape(), i32Type);
    Value iotaOp;
    if (isDynamic) {
      iotaOp = rewriter.create<mlir::stablehlo::DynamicIotaOp>(
          loc, iotaType, opShapeValue,
          rewriter.getI64IntegerAttr(lastDimIndex));
    } else {
      iotaOp = rewriter.create<mlir::stablehlo::IotaOp>(
          loc, iotaType, rewriter.getI64IntegerAttr(lastDimIndex));
    }

    // Create the sort op. It takes two inputs, one for the original input, the
    // other for the indices. Use TOTALORDER comparison type instead of the
    // default comparison if the element type is of type float.
    Type elementType = operandType.getElementType();
    mlir::stablehlo::SortOp sortOp =
        createSortOp(&rewriter, loc, {op.getOperand(), iotaOp},
                     {elementType, i32Type}, lastDimIndex,
                     /*isStable=*/true,
                     /*direction=*/mlir::stablehlo::ComparisonDirection::GT);

    // Sorted values are result 0, the matching indices are result 1.
    Value tupleFirstElement = sortOp.getResult(0);
    Value tupleSecondElement = sortOp.getResult(1);

    SmallVector<int64_t> beginIndices(operandRank, 0);
    auto endIndices = llvm::to_vector(operandType.getShape());
    endIndices.back() = lastDimResultSize;
    SmallVector<int64_t> strides(operandRank, 1);

    // Get the slice for the top K elements.
    auto indicesTy = RankedTensorType::get(operandRank, rewriter.getI64Type());
    Value values, indices;
    if (isDynamic) {
      Value startIndices = rewriter.create<mlir::stablehlo::ConstantOp>(
          loc, DenseIntElementsAttr::get(indicesTy, beginIndices));
      // The slice limits are the i32 result shape converted to i64.
      Value lastIndices = rewriter.create<mlir::stablehlo::ConvertOp>(
          loc, resultShapeValue, rewriter.getI64Type());
      Value stridesOp = rewriter.create<mlir::stablehlo::ConstantOp>(
          loc, DenseIntElementsAttr::get(indicesTy, strides));

      SmallVector<int64_t> resultShape =
          llvm::to_vector(operandType.getShape());
      resultShape.back() = lastDimResultSize;
      RankedTensorType resultType = RankedTensorType::get(
          resultShape, elementType, operandType.getEncoding());
      RankedTensorType indexResultType =
          RankedTensorType::get(resultShape, i32Type);

      values = rewriter.create<mlir::stablehlo::RealDynamicSliceOp>(
          loc, resultType, tupleFirstElement, startIndices, lastIndices,
          stridesOp);
      indices = rewriter.create<mlir::stablehlo::RealDynamicSliceOp>(
          loc, indexResultType, tupleSecondElement, startIndices, lastIndices,
          stridesOp);
    } else {
      values = rewriter.create<mlir::stablehlo::SliceOp>(
          loc, tupleFirstElement, rewriter.getDenseI64ArrayAttr(beginIndices),
          rewriter.getDenseI64ArrayAttr(endIndices),
          rewriter.getDenseI64ArrayAttr(strides));
      indices = rewriter.create<mlir::stablehlo::SliceOp>(
          loc, tupleSecondElement, rewriter.getDenseI64ArrayAttr(beginIndices),
          rewriter.getDenseI64ArrayAttr(endIndices),
          rewriter.getDenseI64ArrayAttr(strides));
    }

    rewriter.replaceOp(op, {values, indices});
    return success();
  }
};

/// Conversion pattern that replaces `chlo.zeta` with the value produced by
/// `materializeZeta`, routed through `materializeWithUpcast`.
struct ConvertZetaOp final : OpConversionPattern<mlir::chlo::ZetaOp> {
  using OpConversionPattern::OpConversionPattern;

  /// Always succeeds: the op is unconditionally replaced with the lowered
  /// value built from the already-converted operands in `adaptor`.
  LogicalResult
  matchAndRewrite(mlir::chlo::ZetaOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // f32 is handed to the helper as the minimum precision type; presumably
    // narrower float operands are upcast to it before evaluation -- see
    // `materializeWithUpcast` to confirm.
    FloatType f32Ty = rewriter.getF32Type();
    auto lowered = materializeWithUpcast(
        rewriter, op.getLoc(), adaptor.getOperands(), f32Ty, &materializeZeta);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// Pass Definition.
//===----------------------------------------------------------------------===//

/// Pass that legalizes CHLO ops in two phases:
///   1. a partial dialect conversion that rewrites CHLO ops using the patterns
///      from `populateLegalizeChloPatterns`, and
///   2. a greedy canonicalization sweep over the ops produced by phase 1.
/// The pass is marked as failed if either phase fails.
struct LegalizeChlo final : impl::LegalizeChloBase<LegalizeChlo> {
  /// Declares the dialects this pass registers as dependencies.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<mlir::scf::SCFDialect, mlir::shape::ShapeDialect,
                    mlir::stablehlo::StablehloDialect,
                    mlir::tensor::TensorDialect>();
  }

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    // `||` short-circuits, so canonicalization is skipped when the conversion
    // phase already failed.
    if (failed(convertChloOps(context)) ||
        failed(canonicalizeProducedOps(context))) {
      signalPassFailure();
    }
  }

private:
  /// Phase 1: partial conversion with the whole CHLO dialect marked illegal,
  /// except for `chlo.minimum_broadcast_shapes`, which is explicitly legal.
  LogicalResult convertChloOps(MLIRContext *context) {
    ConversionTarget target(*context);
    target.addIllegalDialect<chlo::ChloDialect>();
    target.addLegalOp<chlo::MinimumBroadcastShapesOp>();
    target.addLegalDialect<mlir::stablehlo::StablehloDialect,
                           mlir::arith::ArithDialect,
                           mlir::shape::ShapeDialect, mlir::scf::SCFDialect,
                           mlir::tensor::TensorDialect>();

    RewritePatternSet legalizationPatterns(context);
    populateLegalizeChloPatterns(context, &legalizationPatterns);
    return applyPartialConversion(getOperation(), target,
                                  std::move(legalizationPatterns));
  }

  /// Phase 2: greedily applies canonicalization patterns to simplify the ops
  /// from other dialects that were produced during conversion.
  LogicalResult canonicalizeProducedOps(MLIRContext *context) {
    RewritePatternSet cleanupPatterns(context);
    populateCanonicalizationPatterns(context, &cleanupPatterns);
    mlir::shape::AssumingOp::getCanonicalizationPatterns(cleanupPatterns,
                                                         context);
    mlir::shape::ShapeOfOp::getCanonicalizationPatterns(cleanupPatterns,
                                                        context);
    mlir::shape::BroadcastOp::getCanonicalizationPatterns(cleanupPatterns,
                                                          context);
    mlir::shape::CstrBroadcastableOp::getCanonicalizationPatterns(
        cleanupPatterns, context);
    mlir::tensor::CastOp::getCanonicalizationPatterns(cleanupPatterns, context);
    return applyPatternsAndFoldGreedily(getOperation(),
                                        std::move(cleanupPatterns));
  }
};
} // namespace

namespace {
#include "compiler/plugins/input/StableHLO/Conversion/CHLODecompositionPatterns.h.inc"
} // end anonymous namespace

namespace {
/// Adds the broadcasting-related CHLO conversion patterns to `patterns`.
static void populateBroadcastingPatterns(MLIRContext *context,
                                         RewritePatternSet *patterns) {
  // Binary elementwise ops whose operand and result dtypes agree and which
  // carry no special attributes to preserve are handled by instantiating the
  // generic conversion templates. The trailing integer is forwarded to the
  // populate helper (presumably the pattern benefit, which would make the
  // trivial non-broadcast form win over the ranked dynamic one -- confirm).
  populateForBroadcastingBinaryOp<ConvertTrivialNonBroadcastBinaryOp>(
      context, patterns, 10);
  populateForBroadcastingBinaryOp<ConvertRankedDynamicBroadcastBinaryOp>(
      context, patterns, 5);

  // Remaining ops that need dedicated patterns.
  RewritePatternSet &patternSet = *patterns;
  patternSet
      .add<ConvertConstantLikeOp, ConvertDynamicReshapeOp, ConvertSelectOp>(
          context);
}

/// Adds the patterns that decompose CHLO ops: first the generated ones
/// registered by `populateWithGenerated`, then the hand-written conversion
/// patterns listed below.
static void populateDecompositionPatterns(MLIRContext *context,
                                          RewritePatternSet *patterns) {
  RewritePatternSet &patternSet = *patterns;
  populateWithGenerated(patternSet);
  // clang-format off
  patternSet.add<ConvertConstantOp,
                 ConvertBesselI1eOp,
                 ConvertCoshOp,
                 ConvertDigammaOp,
                 ConvertErfOp,
                 ConvertErfcOp,
                 ConvertErfInvOp,
                 ConvertLgammaOp,
                 ConvertNextAfterOp,
                 ConvertPolygammaOp,
                 ConvertSinhOp,
                 ConvertTopKOp,
                 ConvertZetaOp>(context);
  // clang-format on
}
} // namespace

/// Populates `patterns` with every pattern used to legalize CHLO ops: the
/// broadcasting patterns are added first, followed by the decomposition
/// patterns.
void populateLegalizeChloPatterns(MLIRContext *context,
                                  RewritePatternSet *patterns) {
  populateBroadcastingPatterns(context, patterns);
  populateDecompositionPatterns(context, patterns);
}
} // namespace mlir::iree_compiler::stablehlo
