[StableHLO][CHLO]Refactor CHLO decompositions to follow upstream StableHLO (#21682)

Refactored the CHLO decompositions to align with the existing StableHLO
upstream
https://github.com/openxla/stablehlo/blob/main/stablehlo/transforms/ChloDecompositionPatternsMath.td
updated the test files accordingly by registering the StableHLO upstream
pass.

---------

Signed-off-by: LekkalaSravya3 <lekkala.sravya@multicorewareinc.com>
diff --git a/compiler/plugins/input/StableHLO/BUILD.bazel b/compiler/plugins/input/StableHLO/BUILD.bazel
index da1abae..3309373 100644
--- a/compiler/plugins/input/StableHLO/BUILD.bazel
+++ b/compiler/plugins/input/StableHLO/BUILD.bazel
@@ -32,6 +32,7 @@
         "@llvm-project//mlir:Transforms",
         "@stablehlo//:chlo_ops",
         "@stablehlo//:stablehlo_ops",
+        "@stablehlo//:stablehlo_passes",
         "@stablehlo//:vhlo_ops",
     ],
 )
diff --git a/compiler/plugins/input/StableHLO/CMakeLists.txt b/compiler/plugins/input/StableHLO/CMakeLists.txt
index 1875e08..c27c580 100644
--- a/compiler/plugins/input/StableHLO/CMakeLists.txt
+++ b/compiler/plugins/input/StableHLO/CMakeLists.txt
@@ -33,6 +33,7 @@
     MLIRShapeDialect
     MLIRTransforms
     StablehloOps
+    StablehloPasses
     VhloOps
     iree::compiler::PluginAPI
     iree::compiler::plugins::input::StableHLO::Conversion
diff --git a/compiler/plugins/input/StableHLO/Conversion/BUILD.bazel b/compiler/plugins/input/StableHLO/Conversion/BUILD.bazel
index 391f5ce..7231138 100644
--- a/compiler/plugins/input/StableHLO/Conversion/BUILD.bazel
+++ b/compiler/plugins/input/StableHLO/Conversion/BUILD.bazel
@@ -42,28 +42,11 @@
     ],
 )
 
-iree_gentbl_cc_library(
-    name = "CHLODecompositionPatterns",
-    tbl_outs = [
-        (
-            ["--gen-rewriters"],
-            "CHLODecompositionPatterns.h.inc",
-        ),
-    ],
-    tblgen = "@llvm-project//mlir:mlir-tblgen",
-    td_file = "CHLODecompositionPatterns.td",
-    deps = [
-        "@stablehlo//:chlo_ops_td_files",
-        "@stablehlo//:stablehlo_ops_td_files",
-    ],
-)
-
 iree_compiler_cc_library(
     name = "StableHLOLegalization",
     srcs = [
         "CheckVHLOStableHloMixUsage.cpp",
         "ConvertCollectives.cpp",
-        "LegalizeCHLO.cpp",
         "LegalizeControlFlow.cpp",
         "LegalizeShapeComputations.cpp",
         "LegalizeToLinalgUtils.cpp",
@@ -75,7 +58,6 @@
         "VerifyCompilerInputLegality.cpp",
     ],
     deps = [
-        ":CHLODecompositionPatterns",
         ":PassHeaders",
         "//compiler/plugins/input/StableHLO/Conversion/Preprocessing",
         "//compiler/src/iree/compiler/Dialect/Flow/IR",
diff --git a/compiler/plugins/input/StableHLO/Conversion/CHLODecompositionPatterns.td b/compiler/plugins/input/StableHLO/Conversion/CHLODecompositionPatterns.td
deleted file mode 100644
index 20afb72..0000000
--- a/compiler/plugins/input/StableHLO/Conversion/CHLODecompositionPatterns.td
+++ /dev/null
@@ -1,405 +0,0 @@
-// Copyright 2020 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-// This is the legalization pattern definition file for CHLO to StableHLO.
-// These are included in the populateDecompositionPatterns factory
-// and should only include canonical expansions which are not actually
-// ambiguous/different for various backends. Avoid patterns that are actually
-// lowering to non-canonical forms.
-
-include "mlir/IR/OpBase.td"
-include "stablehlo/dialect/ChloOps.td"
-include "stablehlo/dialect/StablehloOps.td"
-
-class StableHLO_ComparisonDirectionValue<string enumStr> :
-  ConstantAttr<StableHLO_ComparisonDirectionAttr,
-               "::mlir::stablehlo::ComparisonDirection::" # enumStr>;
-
-class ConstantLike<string value> : NativeCodeCall<
-    "::mlir::iree_compiler::stablehlo::getConstantLike($_builder, $_loc, " # value # ", $0)">;
-
-def ComplexElementType : Type<
-  CPred<"isa<ComplexType>(cast<ShapedType>($_self).getElementType())">,
-  "Complex element type">;
-
-def NonComplexElementType : Type<
-  CPred<"!isa<ComplexType>(cast<ShapedType>($_self).getElementType())">,
-  "Non-complex element type">;
-
-def ConstantLikeMaxFiniteValue : NativeCodeCall<
-    "::mlir::iree_compiler::stablehlo::getConstantLikeMaxFiniteValue($_builder, $_loc, $0)">;
-
-def ConstantLikePosInfValue : NativeCodeCall<
-    "::mlir::iree_compiler::stablehlo::getConstantLikeInfValue($_builder, $_loc, $0, /*negative=*/false)">;
-
-def ConstantLikeNegInfValue : NativeCodeCall<
-    "::mlir::iree_compiler::stablehlo::getConstantLikeInfValue($_builder, $_loc, $0, /*negative=*/true)">;
-
-def STABLEHLO_DEFAULT_RESULT_ACCURACY :
-  ConstantAttr<StableHLO_ResultAccuracyAttr, "::mlir::stablehlo::ResultAccuracyMode::DEFAULT">;
-
-//===----------------------------------------------------------------------===//
-// Unary op patterns.
-//===----------------------------------------------------------------------===//
-
-// Expand acos for non-complex arguments to MHLO dialect as follows:
-//   acos(x) = 2 * atan2(sqrt(1 - x^2), (1 + x))  if x != -1
-//           = pi                                 if x == -1
-//
-// TODO(b/237376133): Support operands with complex element types separately
-// using the following formula.
-//   acos(x) = -(i * log(x + i * sqrt((1 + x) * (1 - x))))
-def : Pat<(CHLO_AcosOp NonComplexElementType:$input),
-  (StableHLO_SelectOp
-    (StableHLO_CompareOp
-      $input,
-      (ConstantLike<"-1"> $input),
-      StableHLO_ComparisonDirectionValue<"NE">,
-      (STABLEHLO_DEFAULT_COMPARISON_TYPE)
-    ),
-    (StableHLO_MulOp
-      (ConstantLike<"2"> $input),
-      (StableHLO_Atan2Op
-        (StableHLO_SqrtOp
-          (StableHLO_SubtractOp
-            (ConstantLike<"1"> $input),
-            (StableHLO_MulOp $input, $input)
-          ),
-          STABLEHLO_DEFAULT_RESULT_ACCURACY
-        ),
-        (StableHLO_AddOp
-          (ConstantLike<"1"> $input),
-          $input
-        )
-      )
-    ),
-    (ConstantLike<"M_PI"> $input)
-  )>;
-
-// Expand acosh to MHLO dialect as follows:
-//   acosh(x) = log(x + sqrt(x^2 - 1))      if x >= -1
-//            = log(x + sqrt((x+1)*(x-1)))
-//   acosh(x) = nan                         if x < -1
-//
-// If x^2 will overflow, we approximate sqrt(x^2 - 1) == x and compute as
-// log(2*x) = log(2) + log(x).  (Note this works because negative x never
-// overflows; x < -1 simply yields nan.
-def : Pat<(CHLO_AcoshOp NonComplexElementType:$input),
-  (StableHLO_SelectOp
-    (StableHLO_CompareOp
-      $input,
-      (ConstantLike<"-1"> $input),
-      StableHLO_ComparisonDirectionValue<"LT">,
-      (STABLEHLO_DEFAULT_COMPARISON_TYPE)
-    ),
-    (ConstantLike<"NAN"> $input),
-    (StableHLO_SelectOp
-      (StableHLO_CompareOp
-        $input,
-        (StableHLO_SqrtOp
-          (ConstantLikeMaxFiniteValue $input),
-          STABLEHLO_DEFAULT_RESULT_ACCURACY
-        ),
-        StableHLO_ComparisonDirectionValue<"GE">,
-        (STABLEHLO_DEFAULT_COMPARISON_TYPE)
-      ),
-      (StableHLO_AddOp
-        (StableHLO_LogOp $input, STABLEHLO_DEFAULT_RESULT_ACCURACY),
-        (StableHLO_LogOp
-          (ConstantLike<"2"> $input),
-          STABLEHLO_DEFAULT_RESULT_ACCURACY
-        )
-      ),
-      (StableHLO_LogOp
-        (StableHLO_AddOp
-          $input,
-          (StableHLO_SqrtOp
-            (StableHLO_MulOp
-              (StableHLO_AddOp
-                (ConstantLike<"1"> $input),
-                $input
-              ),
-              (StableHLO_AddOp
-                (ConstantLike<"-1"> $input),
-                $input
-              )
-            ),
-            STABLEHLO_DEFAULT_RESULT_ACCURACY
-          )
-        ),
-        STABLEHLO_DEFAULT_RESULT_ACCURACY
-      )
-    )
-  )>;
-
-// Expand acosh for complex arguments to MHLO dialect as
-//   acosh(x) = log(x + sqrt((x+1)*(x-1)))
-//
-// Per tensorflow/compiler/xla/client/lib/math.cc at the time of writing:
-// "For now, we ignore the question of overflow if x is a
-// complex type, because we don't yet have exhaustive tests for complex trig
-// functions".
-def : Pat<(CHLO_AcoshOp ComplexElementType:$input),
-  (StableHLO_LogOp
-    (StableHLO_AddOp
-      $input,
-      (StableHLO_SqrtOp
-        (StableHLO_MulOp
-          (StableHLO_AddOp
-            $input,
-            (ConstantLike<"1"> $input)
-          ),
-          (StableHLO_SubtractOp
-            $input,
-            (ConstantLike<"1"> $input)
-          )
-        ),
-        STABLEHLO_DEFAULT_RESULT_ACCURACY
-      )
-    ),
-    STABLEHLO_DEFAULT_RESULT_ACCURACY
-  )>;
-
-
-// Expand asin to MHLO dialect as follows:
-//   asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2)))
-def : Pat<(CHLO_AsinOp $input),
-  (StableHLO_MulOp
-    (ConstantLike<"2"> $input),
-    (StableHLO_Atan2Op
-      $input,
-      (StableHLO_AddOp
-        (ConstantLike<"1"> $input),
-        (StableHLO_SqrtOp
-          (StableHLO_SubtractOp
-            (ConstantLike<"1"> $input),
-            (StableHLO_MulOp $input, $input)
-          ),
-          STABLEHLO_DEFAULT_RESULT_ACCURACY
-        )
-      )
-    )
-  )>;
-
-// Expand asinh for non-complex arguments to MHLO dialect as
-//   asinh(x) = log(x + sqrt(x^2 + 1))
-//
-// If x^2 will overflow and x is positive, we can approximate x + sqrt(x^2 + 1)
-// as 2*x and return log(2) + log(x).
-//
-// For small x, sqrt(x^2 + 1) will evaluate to 1 due to floating point
-// arithmetic. However, we would like to retain the low order term of this,
-// which is around 0.5 * x^2 using a binomial expansion.
-// Let z = sqrt(a^2 + 1)
-// The following rewrite retains the lower order term.
-// log(a + sqrt(a^2 + 1))
-//   = log((a + sqrt(a^2 + 1)) * (1 + sqrt(a^2 + 1)) / (1 + sqrt(a^2 + 1)))
-//   = log((a + a^2 + 1 + a * z + z) / (1 + z))
-//   = log(1 + a + a^2 / (1 + z))
-//   = log(1 + a + a^2 / (1 + sqrt(a^2 + 1)))
-//
-// If x is negative, the above would give us some trouble; we can't approximate
-// the result as x + abs(x) = 0 but we are saved by the fact that asinh(-x) =
-// -asinh(x).
-def : Pat<(CHLO_AsinhOp NonComplexElementType:$input),
-  (StableHLO_MulOp
-    (StableHLO_SignOp $input),
-    (StableHLO_SelectOp
-      (StableHLO_CompareOp
-        (StableHLO_AbsOp $input),
-        (StableHLO_SqrtOp
-          (ConstantLikeMaxFiniteValue $input),
-          STABLEHLO_DEFAULT_RESULT_ACCURACY
-        ),
-        StableHLO_ComparisonDirectionValue<"GE">,
-        (STABLEHLO_DEFAULT_COMPARISON_TYPE)
-      ),
-      (StableHLO_AddOp
-        (StableHLO_LogOp
-          (StableHLO_AbsOp $input),
-          STABLEHLO_DEFAULT_RESULT_ACCURACY
-        ),
-        (StableHLO_LogOp
-          (ConstantLike<"2"> $input),
-          STABLEHLO_DEFAULT_RESULT_ACCURACY
-        )
-      ),
-      (StableHLO_SelectOp
-        (StableHLO_CompareOp
-          (StableHLO_AbsOp $input),
-          (ConstantLike<"1"> $input),
-          StableHLO_ComparisonDirectionValue<"LE">,
-          (STABLEHLO_DEFAULT_COMPARISON_TYPE)
-        ),
-        (StableHLO_Log1pOp
-          (StableHLO_AddOp
-            (StableHLO_AbsOp $input),
-            (StableHLO_MulOp
-              (StableHLO_AbsOp $input),
-              (StableHLO_DivOp
-                (StableHLO_AbsOp $input),
-                (StableHLO_AddOp
-                  (ConstantLike<"1"> $input),
-                  (StableHLO_SqrtOp
-                    (StableHLO_AddOp
-                      (StableHLO_MulOp
-                        (StableHLO_AbsOp $input),
-                        (StableHLO_AbsOp $input)
-                      ),
-                      (ConstantLike<"1"> $input)
-                    ),
-                    STABLEHLO_DEFAULT_RESULT_ACCURACY
-                  )
-                )
-              )
-            )
-          ),
-          STABLEHLO_DEFAULT_RESULT_ACCURACY
-        ),
-        (StableHLO_LogOp
-          (StableHLO_AddOp
-            (StableHLO_AbsOp $input),
-            (StableHLO_SqrtOp
-              (StableHLO_AddOp
-                (StableHLO_MulOp
-                  (StableHLO_AbsOp $input),
-                  (StableHLO_AbsOp $input)
-                ),
-                (ConstantLike<"1"> $input)
-              ),
-              STABLEHLO_DEFAULT_RESULT_ACCURACY
-            )
-          ),
-          STABLEHLO_DEFAULT_RESULT_ACCURACY
-        )
-      )
-    )
-  )>;
-
-// Expand asinh for complex arguments to MHLO dialect as
-//   asinh(x) = log(x + sqrt(x^2 + 1))
-//
-// Per tensorflow/compiler/xla/client/lib/math.cc at the time of writing:
-// "For now, we ignore the question of overflow if x is a
-// complex type, because we don't yet have exhaustive tests for complex trig
-// functions".
-def : Pat<(CHLO_AsinhOp ComplexElementType:$input),
-  (StableHLO_LogOp
-    (StableHLO_AddOp
-      $input,
-      (StableHLO_SqrtOp
-        (StableHLO_AddOp
-          (StableHLO_MulOp $input, $input),
-          (ConstantLike<"1"> $input)
-        ),
-        STABLEHLO_DEFAULT_RESULT_ACCURACY
-      )
-    ),
-    STABLEHLO_DEFAULT_RESULT_ACCURACY
-  )>;
-
-// Express `atan` as
-//   atan(x) = atan2(x, 1)
-def : Pat<(CHLO_AtanOp $input),
-  (StableHLO_Atan2Op
-    $input,
-    (ConstantLike<"1"> $input)
-  )>;
-
-// Express `atanh` for non-complex arguments as follows:
-//   atanh(x) = 0.5 * log((1 + x) / (1 - x)) if abs(x) <= 1
-//   atanh(x) = nan                          otherwise
-def : Pat<(CHLO_AtanhOp NonComplexElementType:$input),
-  (StableHLO_SelectOp
-    (StableHLO_CompareOp
-      (StableHLO_AbsOp $input),
-      (ConstantLike<"1"> $input),
-      StableHLO_ComparisonDirectionValue<"GT">,
-      (STABLEHLO_DEFAULT_COMPARISON_TYPE)
-    ),
-    (ConstantLike<"NAN"> $input),
-    (StableHLO_MulOp
-      (StableHLO_SubtractOp
-        (StableHLO_Log1pOp $input, STABLEHLO_DEFAULT_RESULT_ACCURACY),
-        (StableHLO_Log1pOp
-          (StableHLO_NegOp $input),
-          STABLEHLO_DEFAULT_RESULT_ACCURACY
-        )
-      ),
-      (ConstantLike<"0.5"> $input)
-    )
-  )>;
-
-// Express `atanh` for complex arguments as follows:
-//   atanh(x) = (log(1 + x) - log(1 + (-x))) * 0.5
-//
-// Per tensorflow/compiler/xla/client/lib/math.cc at the time of writing:
-// "For now, we ignore the nan edge case for complex inputs,
-// because we don't yet have exhaustive tests for complex trig functions".
-def : Pat<(CHLO_AtanhOp ComplexElementType:$input),
-  (StableHLO_MulOp
-    (StableHLO_SubtractOp
-      (StableHLO_Log1pOp $input, STABLEHLO_DEFAULT_RESULT_ACCURACY),
-      (StableHLO_Log1pOp
-        (StableHLO_NegOp $input),
-        STABLEHLO_DEFAULT_RESULT_ACCURACY
-      )
-    ),
-    (ConstantLike<"0.5"> $input)
-  )>;
-
-// Express `conj` as
-//   conj(x) = (re(x), -im(x)).
-def : Pat<(CHLO_ConjOp $v),
-          (StableHLO_ComplexOp (StableHLO_RealOp $v), (StableHLO_NegOp (StableHLO_ImagOp $v)))>;
-
-// Express `is_inf` as
-//   is_inf(x) = is_pos_inf(|x|)
-def : Pat<(CHLO_IsInfOp NonComplexElementType:$input),
-  (CHLO_IsPosInfOp
-    (StableHLO_AbsOp $input)
-  )>;
-
-// Express `is_pos_inf` as
-//   is_pos_inf(x) = (x == +inf)
-def : Pat<(CHLO_IsPosInfOp NonComplexElementType:$input),
-  (StableHLO_CompareOp
-    $input,
-    (ConstantLikePosInfValue $input),
-    StableHLO_ComparisonDirectionValue<"EQ">,
-    (STABLEHLO_DEFAULT_COMPARISON_TYPE)
-  )>;
-
-// Express `is_neg_inf` as
-//   is_neg_inf(x) = (x == -inf)
-def : Pat<(CHLO_IsNegInfOp NonComplexElementType:$input),
-  (StableHLO_CompareOp
-    $input,
-    (ConstantLikeNegInfValue $input),
-    StableHLO_ComparisonDirectionValue<"EQ">,
-    (STABLEHLO_DEFAULT_COMPARISON_TYPE)
-  )>;
-
-// Express `tan` as
-//   sine(x) / cosine(x)
-def : Pat<(CHLO_TanOp NonComplexElementType:$input),
-  (StableHLO_DivOp
-    (StableHLO_SineOp $input, STABLEHLO_DEFAULT_RESULT_ACCURACY),
-    (StableHLO_CosineOp $input, STABLEHLO_DEFAULT_RESULT_ACCURACY)
-  )>;
-
-
-// Express `tan(a + bi)` as
-//   (tan(a) + i tanh(b)) / (1 - i tan(a) * tanh(b))
-def : Pat<(CHLO_TanOp ComplexElementType:$input),
-  (StableHLO_DivOp
-    (StableHLO_ComplexOp
-      (CHLO_TanOp:$tan (StableHLO_RealOp $input)),
-      (StableHLO_TanhOp:$tanh (StableHLO_ImagOp $input), STABLEHLO_DEFAULT_RESULT_ACCURACY)),
-    (StableHLO_ComplexOp
-      (ConstantLike<"1.0"> $tan),
-      (StableHLO_NegOp (StableHLO_MulOp $tan, $tanh)))
-  )>;
diff --git a/compiler/plugins/input/StableHLO/Conversion/CMakeLists.txt b/compiler/plugins/input/StableHLO/Conversion/CMakeLists.txt
index 7ee0ff6..018b437 100644
--- a/compiler/plugins/input/StableHLO/Conversion/CMakeLists.txt
+++ b/compiler/plugins/input/StableHLO/Conversion/CMakeLists.txt
@@ -34,22 +34,12 @@
   PUBLIC
 )
 
-iree_tablegen_library(
-  NAME
-    CHLODecompositionPatterns
-  TD_FILE
-    "CHLODecompositionPatterns.td"
-  OUTS
-    --gen-rewriters CHLODecompositionPatterns.h.inc
-)
-
 iree_cc_library(
   NAME
     StableHLOLegalization
   SRCS
     "CheckVHLOStableHloMixUsage.cpp"
     "ConvertCollectives.cpp"
-    "LegalizeCHLO.cpp"
     "LegalizeControlFlow.cpp"
     "LegalizeShapeComputations.cpp"
     "LegalizeToLinalgUtils.cpp"
@@ -60,7 +50,6 @@
     "StableHLOToLinalgExt.cpp"
     "VerifyCompilerInputLegality.cpp"
   DEPS
-    ::CHLODecompositionPatterns
     ::PassHeaders
     ChloOps
     LLVMSupport
diff --git a/compiler/plugins/input/StableHLO/Conversion/LegalizeCHLO.cpp b/compiler/plugins/input/StableHLO/Conversion/LegalizeCHLO.cpp
deleted file mode 100644
index 713280c..0000000
--- a/compiler/plugins/input/StableHLO/Conversion/LegalizeCHLO.cpp
+++ /dev/null
@@ -1,2225 +0,0 @@
-// Copyright 2020 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-// Implements logic for lowering CHLO ops to StableHLO and Shape dialect ops,
-// taking care of CHLO's broadcasting semantics
-
-#include "compiler/plugins/input/StableHLO/Conversion/Passes.h"
-#include "compiler/plugins/input/StableHLO/Conversion/Preprocessing/Rewriters.h"
-#include "compiler/plugins/input/StableHLO/Conversion/Rewriters.h"
-#include "llvm/ADT/STLExtras.h"
-#include "mlir/Dialect/Complex/IR/Complex.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Shape/IR/Shape.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
-#include "mlir/IR/TypeUtilities.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "stablehlo/dialect/BroadcastUtils.h"
-#include "stablehlo/dialect/ChloOps.h"
-#include "stablehlo/dialect/StablehloOps.h"
-
-namespace mlir::iree_compiler::stablehlo {
-
-#define GEN_PASS_DEF_LEGALIZECHLO
-#include "compiler/plugins/input/StableHLO/Conversion/Passes.h.inc"
-
-namespace {
-
-//===----------------------------------------------------------------------===//
-// Helpers.
-//===----------------------------------------------------------------------===//
-
-template <typename FromOpTy, typename ToOpTy>
-struct HloNaryElementwiseAdaptor {
-  static ToOpTy createOp(FromOpTy fromOp, Type resultType,
-                         ValueRange broadcastedOperands, OpBuilder &builder) {
-    return builder.create<ToOpTy>(fromOp.getLoc(), resultType,
-                                  broadcastedOperands);
-  }
-};
-
-static std::optional<mlir::stablehlo::ComparisonDirection>
-toStableHloComparisonDirection(mlir::chlo::ComparisonDirection value) {
-  switch (value) {
-  case mlir::chlo::ComparisonDirection::EQ:
-    return mlir::stablehlo::ComparisonDirection::EQ;
-  case mlir::chlo::ComparisonDirection::NE:
-    return mlir::stablehlo::ComparisonDirection::NE;
-  case mlir::chlo::ComparisonDirection::GE:
-    return mlir::stablehlo::ComparisonDirection::GE;
-  case mlir::chlo::ComparisonDirection::GT:
-    return mlir::stablehlo::ComparisonDirection::GT;
-  case mlir::chlo::ComparisonDirection::LE:
-    return mlir::stablehlo::ComparisonDirection::LE;
-  case mlir::chlo::ComparisonDirection::LT:
-    return mlir::stablehlo::ComparisonDirection::LT;
-  }
-  return {};
-}
-
-static std::optional<mlir::stablehlo::ComparisonType>
-toStableHloComparisonType(mlir::chlo::ComparisonType value) {
-  switch (value) {
-  case mlir::chlo::ComparisonType::NOTYPE:
-    return mlir::stablehlo::ComparisonType::NOTYPE;
-  case mlir::chlo::ComparisonType::FLOAT:
-    return mlir::stablehlo::ComparisonType::FLOAT;
-  case mlir::chlo::ComparisonType::TOTALORDER:
-    return mlir::stablehlo::ComparisonType::TOTALORDER;
-  case mlir::chlo::ComparisonType::SIGNED:
-    return mlir::stablehlo::ComparisonType::SIGNED;
-  case mlir::chlo::ComparisonType::UNSIGNED:
-    return mlir::stablehlo::ComparisonType::UNSIGNED;
-  }
-  return {};
-}
-
-struct HloCompareAdaptor {
-  static mlir::stablehlo::CompareOp
-  createOp(mlir::chlo::BroadcastCompareOp fromOp, Type resultType,
-           ValueRange broadcastedOperands, OpBuilder &builder) {
-    auto chloDirection = fromOp.getComparisonDirection();
-    auto hloDirection = toStableHloComparisonDirection(chloDirection);
-    if (!hloDirection)
-      return nullptr;
-    auto chloType =
-        fromOp.getCompareType().value_or(mlir::chlo::ComparisonType::NOTYPE);
-    auto hloType = toStableHloComparisonType(chloType);
-    if (!hloType)
-      return nullptr;
-    auto hloTypeAttr = fromOp.getCompareType()
-                           ? mlir::stablehlo::ComparisonTypeAttr::get(
-                                 builder.getContext(), *hloType)
-                           : nullptr;
-    return builder.create<mlir::stablehlo::CompareOp>(
-        fromOp.getLoc(), resultType, broadcastedOperands[0],
-        broadcastedOperands[1], *hloDirection, hloTypeAttr);
-  }
-};
-
-// Populate a pattern for each Broadcasting Chlo op. This requires the pattern
-// to take a ChloOpTy, NonBroadcastingOpTy, and an Adaptor as templated values.
-template <template <typename, typename, typename> typename Pattern,
-          typename... ConstructorArgs>
-static void populateForBroadcastingBinaryOp(MLIRContext *context,
-                                            RewritePatternSet *patterns,
-                                            ConstructorArgs &&...args) {
-#define POPULATE_BCAST(ChloOp, HloOp)                                          \
-  patterns                                                                     \
-      ->add<Pattern<ChloOp, HloOp, HloNaryElementwiseAdaptor<ChloOp, HloOp>>>( \
-          context, args...);
-
-  POPULATE_BCAST(mlir::chlo::BroadcastAddOp, mlir::stablehlo::AddOp);
-  POPULATE_BCAST(mlir::chlo::BroadcastAndOp, mlir::stablehlo::AndOp);
-  POPULATE_BCAST(mlir::chlo::BroadcastAtan2Op, mlir::stablehlo::Atan2Op);
-  POPULATE_BCAST(mlir::chlo::BroadcastComplexOp, mlir::stablehlo::ComplexOp);
-  POPULATE_BCAST(mlir::chlo::BroadcastDivOp, mlir::stablehlo::DivOp);
-  POPULATE_BCAST(mlir::chlo::BroadcastMaxOp, mlir::stablehlo::MaxOp);
-  POPULATE_BCAST(mlir::chlo::BroadcastMinOp, mlir::stablehlo::MinOp);
-  POPULATE_BCAST(mlir::chlo::BroadcastMulOp, mlir::stablehlo::MulOp);
-  POPULATE_BCAST(mlir::chlo::BroadcastNextAfterOp, mlir::chlo::NextAfterOp);
-  POPULATE_BCAST(mlir::chlo::BroadcastOrOp, mlir::stablehlo::OrOp);
-  POPULATE_BCAST(mlir::chlo::BroadcastPolygammaOp, mlir::chlo::PolygammaOp);
-  POPULATE_BCAST(mlir::chlo::BroadcastPowOp, mlir::stablehlo::PowOp);
-  POPULATE_BCAST(mlir::chlo::BroadcastRemOp, mlir::stablehlo::RemOp);
-  POPULATE_BCAST(mlir::chlo::BroadcastShiftLeftOp,
-                 mlir::stablehlo::ShiftLeftOp);
-  POPULATE_BCAST(mlir::chlo::BroadcastShiftRightArithmeticOp,
-                 mlir::stablehlo::ShiftRightArithmeticOp);
-  POPULATE_BCAST(mlir::chlo::BroadcastShiftRightLogicalOp,
-                 mlir::stablehlo::ShiftRightLogicalOp);
-  POPULATE_BCAST(mlir::chlo::BroadcastSubOp, mlir::stablehlo::SubtractOp);
-  POPULATE_BCAST(mlir::chlo::BroadcastXorOp, mlir::stablehlo::XorOp);
-  POPULATE_BCAST(mlir::chlo::BroadcastZetaOp, mlir::chlo::ZetaOp);
-
-#undef POPULATE_BCAST
-
-  // Broadcasting ops requiring special construction.
-  patterns->add<Pattern<mlir::chlo::BroadcastCompareOp,
-                        mlir::stablehlo::CompareOp, HloCompareAdaptor>>(
-      context, args...);
-}
-
-template <typename T>
-static Value getConstantLike(OpBuilder &b, Location loc, T constant,
-                             Value val) {
-  Type ty = getElementTypeOrSelf(val.getType());
-  auto getAttr = [&]() -> Attribute {
-    if (isa<IntegerType>(ty))
-      return b.getIntegerAttr(ty, constant);
-    if (isa<FloatType>(ty))
-      return b.getFloatAttr(ty, constant);
-    if (auto complexTy = dyn_cast<ComplexType>(ty)) {
-      return complex::NumberAttr::get(complexTy, constant, 0);
-    }
-    llvm_unreachable("unhandled element type");
-  };
-  return b.create<mlir::chlo::ConstantLikeOp>(loc, cast<TypedAttr>(getAttr()),
-                                              val);
-}
-
-static Value getConstantLike(OpBuilder &b, Location loc,
-                             const APFloat &constant, Value val) {
-  Type ty = getElementTypeOrSelf(val.getType());
-  return b.create<mlir::chlo::ConstantLikeOp>(loc, b.getFloatAttr(ty, constant),
-                                              val);
-}
-
-static Value getConstantLikeMaxFiniteValue(OpBuilder &b, Location loc,
-                                           Value val) {
-  auto ty = cast<FloatType>(getElementTypeOrSelf(val.getType()));
-  return getConstantLike(
-      b, loc, llvm::APFloat::getLargest(ty.getFloatSemantics()), val);
-}
-
-static Value getConstantLikeInfValue(OpBuilder &b, Location loc, Value val,
-                                     bool negative) {
-  auto ty = cast<FloatType>(getElementTypeOrSelf(val.getType()));
-  return getConstantLike(
-      b, loc, llvm::APFloat::getInf(ty.getFloatSemantics(), negative), val);
-}
-
-//===----------------------------------------------------------------------===//
-// Broadcasting Patterns.
-//===----------------------------------------------------------------------===//
-
-// Converts binary ops that statically are determined to not broadcast directly
-// to the corresponding stablehlo non-broadcasting op.
-template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
-struct ConvertTrivialNonBroadcastBinaryOp final
-    : OpConversionPattern<ChloOpTy> {
-  using OpConversionPattern<ChloOpTy>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(ChloOpTy op, typename ChloOpTy::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    // Only rewrite for statically determinable non-broadcasting cases.
-    auto lhsType = dyn_cast<RankedTensorType>(adaptor.getLhs().getType());
-    auto rhsType = dyn_cast<RankedTensorType>(adaptor.getRhs().getType());
-    if (!lhsType || !rhsType)
-      return failure();
-
-    // Requires rank broadcast.
-    if (lhsType.getRank() != rhsType.getRank())
-      return failure();
-
-    // Any dynamic dimension may require broadcasting and requires more
-    // analysis.
-    if (!lhsType.hasStaticShape() || !rhsType.hasStaticShape()) {
-      return failure();
-    }
-
-    if (!llvm::equal(lhsType.getShape(), rhsType.getShape())) {
-      return failure();
-    }
-
-    rewriter.replaceOp(
-        op, ValueRange{Adaptor::createOp(op, op.getResult().getType(),
-                                         adaptor.getOperands(), rewriter)});
-    return success();
-  }
-};
-
-// Converts a binary op with ranked broadcasting operands to explicitly
-// broadcast and invoke the corresponding stablehlo non-broadcasting op.
-// Note that dynamic broadcasting supported by this pattern is only valid for
-// "numpy" broadcasting semantics as defined here:
-//   https://docs.scipy.org/doc/numpy/reference/ufuncs.html
-// Specifically, this includes the following cases:
-//   - Same rank broadcast (operands have the same static rank).
-//   - Different-rank broadcast, either without a broadcast_dims attribute or
-//     with the broadcast_dims attribute set to map to a prefix padding.
-//   - Legal combinations of degenerate (1-dim) implicit broadcasting.
-// The restriction on broadcast_dims derives from the definition of the
-// `shape.broadcast` op, which only supports prefix-padding.
-template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
-struct ConvertRankedDynamicBroadcastBinaryOp final
-    : OpConversionPattern<ChloOpTy> {
-  using OpConversionPattern<ChloOpTy>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(ChloOpTy op, typename ChloOpTy::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    // Only support ranked operands.
-    Value lhs = adaptor.getLhs();
-    Value rhs = adaptor.getRhs();
-    auto lhsType = dyn_cast<RankedTensorType>(lhs.getType());
-    auto rhsType = dyn_cast<RankedTensorType>(rhs.getType());
-    auto resultType = dyn_cast<RankedTensorType>(op.getResult().getType());
-    if (!lhsType || !rhsType || !resultType)
-      return failure();
-
-    // Check for "numpy"-style rank broadcast.
-    auto broadcastDimensions = op.getBroadcastDimensions();
-    if (broadcastDimensions && !mlir::hlo::isLegalNumpyRankedBroadcast(
-                                   lhs, rhs, *broadcastDimensions)) {
-      // Note: It is unclear whether the general specification of explicit
-      // broadcast_dimensions on binary ops is a feature we want to carry
-      // forward. While it can technically be implemented for ranked-dynamic,
-      // it is incompatible with unranked inputs. If this warning is emitted
-      // in real programs, it is an indication that the feature should be
-      // implemented versus just falling back on the more standard definition
-      // of numpy-like prefix-padding.
-      op.emitWarning() << "unsupported non prefix-padded dynamic rank "
-                       << "broadcast_dimensions = " << *broadcastDimensions;
-      return failure();
-    }
-
-    // Compute result shape.
-    Location loc = op.getLoc();
-
-    // Insert a constraint on the shapes being broadcastable and insert all
-    // future code into an assuming block reliant on the constraint.
-    Value lhsShape = rewriter.create<shape::ShapeOfOp>(loc, lhs);
-    Value rhsShape = rewriter.create<shape::ShapeOfOp>(loc, rhs);
-    auto broadcastableCstr =
-        rewriter.create<shape::CstrBroadcastableOp>(loc, lhsShape, rhsShape);
-    auto assumingOp = rewriter.create<shape::AssumingOp>(
-        loc, ArrayRef<Type>{resultType}, broadcastableCstr.getResult());
-
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.createBlock(&assumingOp.getDoRegion());
-
-    int64_t resultRank = std::max(lhsType.getRank(), rhsType.getRank());
-    Value resultExtents =
-        hlo::computeBinaryElementwiseBroadcastingResultExtents(loc, lhs, rhs,
-                                                               rewriter);
-
-    // Note that we unconditionally emit DynamicBroadcastInDim ops and let
-    // downstream canonicalizations fold them away if possible. This is
-    // because, in the dynamic case, there are many corner cases regarding
-    // when it is safe to omit, and some of them require analysis to prove
-    // properly.
-    auto lhsBroadcastDimensions = llvm::to_vector(
-        llvm::seq<int64_t>(resultRank - lhsType.getRank(), resultRank));
-    Value broadcastedLhs =
-        rewriter.create<mlir::stablehlo::DynamicBroadcastInDimOp>(
-            loc,
-            RankedTensorType::get(resultType.getShape(),
-                                  lhsType.getElementType()),
-            lhs, resultExtents,
-            rewriter.getDenseI64ArrayAttr(lhsBroadcastDimensions));
-    auto rhsBroadcastDimensions = llvm::to_vector(
-        llvm::seq<int64_t>(resultRank - rhsType.getRank(), resultRank));
-    Value broadcastedRhs =
-        rewriter.create<mlir::stablehlo::DynamicBroadcastInDimOp>(
-            loc,
-            RankedTensorType::get(resultType.getShape(),
-                                  rhsType.getElementType()),
-            rhs, resultExtents,
-            rewriter.getDenseI64ArrayAttr(rhsBroadcastDimensions));
-
-    // And generate the final non-broadcasted binary op.
-    Value finalResult = Adaptor::createOp(
-        op, resultType, {broadcastedLhs, broadcastedRhs}, rewriter);
-    rewriter.create<shape::AssumingYieldOp>(loc, finalResult);
-    rewriter.replaceOp(op, {assumingOp.getResult(0)});
-    return success();
-  }
-};
-
-struct ConvertConstantLikeOp final
-    : OpConversionPattern<mlir::chlo::ConstantLikeOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(mlir::chlo::ConstantLikeOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto resultTy = cast<ShapedType>(op.getType());
-
-    // Unranked uses are not supported.
-    if (!resultTy.hasRank())
-      return failure();
-
-    // Lower to HLO constant if statically shaped.
-    if (resultTy.hasStaticShape()) {
-      auto complexAttr = dyn_cast<mlir::complex::NumberAttr>(op.getValue());
-      auto attr = DenseElementsAttr::get(resultTy, complexAttr ? complexAttr
-                                                               : op.getValue());
-      rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, attr);
-      return success();
-    }
-
-    // Lower to broadcasted constant.
-    Location loc = op.getLoc();
-    Value constant =
-        rewriter.create<mlir::stablehlo::ConstantOp>(loc, op.getValue());
-    Value shape = rewriter.create<shape::ShapeOfOp>(loc, adaptor.getOperand());
-    rewriter.replaceOpWithNewOp<mlir::stablehlo::DynamicBroadcastInDimOp>(
-        op, resultTy, constant, shape, rewriter.getDenseI64ArrayAttr({}));
-    return success();
-  }
-};
-
-struct ConvertSelectOp final
-    : OpConversionPattern<mlir::chlo::BroadcastSelectOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(mlir::chlo::BroadcastSelectOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    // Only support ranked operands.
-    Value pred = adaptor.getPred();
-    Value onTrue = adaptor.getOnTrue();
-    Value onFalse = adaptor.getOnFalse();
-    auto predType = dyn_cast<RankedTensorType>(pred.getType());
-    auto onTrueType = dyn_cast<RankedTensorType>(onTrue.getType());
-    auto onFalseType = dyn_cast<RankedTensorType>(onFalse.getType());
-    auto resultType = dyn_cast<RankedTensorType>(op.getResult().getType());
-    if (!predType || !onTrueType || !onFalseType || !resultType) {
-      return failure();
-    }
-
-    Location loc = op.getLoc();
-    Value predShape = rewriter.createOrFold<shape::ShapeOfOp>(loc, pred);
-    Value onTrueShape = rewriter.createOrFold<shape::ShapeOfOp>(loc, onTrue);
-    Value onFalseShape = rewriter.createOrFold<shape::ShapeOfOp>(loc, onFalse);
-    int64_t resultRank = std::max(
-        {predType.getRank(), onTrueType.getRank(), onFalseType.getRank()});
-
-    Value broadcastableCstr = rewriter.createOrFold<shape::CstrBroadcastableOp>(
-        loc, ValueRange{predShape, onTrueShape, onFalseShape});
-    auto assumingOp = rewriter.create<shape::AssumingOp>(
-        loc, ArrayRef<Type>{resultType}, broadcastableCstr);
-
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.createBlock(&assumingOp.getDoRegion());
-
-    Value resultExtents = rewriter.createOrFold<shape::BroadcastOp>(
-        loc, shape::getExtentTensorType(op.getContext()),
-        ValueRange{predShape, onTrueShape, onFalseShape},
-        /*error=*/nullptr);
-    auto shapeType =
-        RankedTensorType::get({resultRank}, rewriter.getIndexType());
-    resultExtents =
-        rewriter.createOrFold<tensor::CastOp>(loc, shapeType, resultExtents);
-
-    Value broadcastedPred = pred;
-    // Pred has an implicit broadcast for scalars, so use that when convenient.
-    if (predType.getRank() > 0) {
-      auto predBroadcastDimensions = llvm::to_vector(
-          llvm::seq<int64_t>(resultRank - predType.getRank(), resultRank));
-      broadcastedPred =
-          rewriter.create<mlir::stablehlo::DynamicBroadcastInDimOp>(
-              loc,
-              RankedTensorType::get(resultType.getShape(),
-                                    predType.getElementType()),
-              pred, resultExtents,
-              rewriter.getDenseI64ArrayAttr(predBroadcastDimensions));
-    }
-    auto onTrueBroadcastDimensions = llvm::to_vector(
-        llvm::seq<int64_t>(resultRank - onTrueType.getRank(), resultRank));
-    Value broadcastedOnTrue =
-        rewriter.create<mlir::stablehlo::DynamicBroadcastInDimOp>(
-            loc,
-            RankedTensorType::get(resultType.getShape(),
-                                  onTrueType.getElementType()),
-            onTrue, resultExtents,
-            rewriter.getDenseI64ArrayAttr(onTrueBroadcastDimensions));
-    auto onFalseBroadcastDimensions = llvm::to_vector(
-        llvm::seq<int64_t>(resultRank - onFalseType.getRank(), resultRank));
-    Value broadcastedOnFalse =
-        rewriter.create<mlir::stablehlo::DynamicBroadcastInDimOp>(
-            loc,
-            RankedTensorType::get(resultType.getShape(),
-                                  onFalseType.getElementType()),
-            onFalse, resultExtents,
-            rewriter.getDenseI64ArrayAttr(onFalseBroadcastDimensions));
-
-    // And generate the final non-broadcasted ternary op.
-    Value finalResult = rewriter.create<mlir::stablehlo::SelectOp>(
-        loc, resultType, broadcastedPred, broadcastedOnTrue,
-        broadcastedOnFalse);
-    rewriter.create<shape::AssumingYieldOp>(loc, finalResult);
-    rewriter.replaceOp(op, {assumingOp.getResult(0)});
-    return success();
-  }
-};
-
-//===----------------------------------------------------------------------===//
-// Decomposition Patterns.
-//===----------------------------------------------------------------------===//
-
-struct ConvertConstantOp final : OpConversionPattern<mlir::chlo::ConstantOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(mlir::chlo::ConstantOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, op.getValue());
-    return success();
-  }
-};
-
-template <typename FTy>
-static Value
-materializeChebyshevPolynomialApproximation(ConversionPatternRewriter &rewriter,
-                                            Location loc, Value x,
-                                            ArrayRef<FTy> coefficients) {
-  Value b0 = getConstantLike(rewriter, loc, 0.0, x);
-  Value b1 = getConstantLike(rewriter, loc, 0.0, x);
-  Value b2 = getConstantLike(rewriter, loc, 0.0, x);
-  for (FTy c : coefficients) {
-    b2 = b1;
-    b1 = b0;
-    b0 = rewriter.create<mlir::stablehlo::MulOp>(loc, x.getType(), x, b1);
-    b0 = rewriter.create<mlir::stablehlo::SubtractOp>(loc, x.getType(), b0, b2);
-    b0 = rewriter.create<mlir::stablehlo::AddOp>(
-        loc, x.getType(), b0, getConstantLike(rewriter, loc, c, x));
-  }
-  Value result =
-      rewriter.create<mlir::stablehlo::SubtractOp>(loc, x.getType(), b0, b2);
-  result = rewriter.create<mlir::stablehlo::MulOp>(
-      loc, x.getType(), result, getConstantLike(rewriter, loc, 0.5, x));
-  return result;
-}
-
-template <typename FTy>
-static Value materializeBesselI1eApproximation(
-    ConversionPatternRewriter &rewriter, Location loc, Value x,
-    ArrayRef<FTy> kI1eCoeffsA, ArrayRef<FTy> kI1eCoeffsB) {
-  Value z = rewriter.create<mlir::stablehlo::AbsOp>(loc, x);
-  Value half = getConstantLike(rewriter, loc, 0.5, x);
-  Value two = getConstantLike(rewriter, loc, 2.0, x);
-  Value thirtyTwo = getConstantLike(rewriter, loc, 32.0, x);
-  Value eight = getConstantLike(rewriter, loc, 8.0, x);
-
-  Value tmp = rewriter.create<mlir::stablehlo::MulOp>(loc, half, z);
-  tmp = rewriter.create<mlir::stablehlo::SubtractOp>(loc, tmp, two);
-
-  Value xLe8 = materializeChebyshevPolynomialApproximation(rewriter, loc, tmp,
-                                                           kI1eCoeffsA);
-  xLe8 = rewriter.create<mlir::stablehlo::MulOp>(loc, z, xLe8);
-
-  tmp = rewriter.create<mlir::stablehlo::DivOp>(loc, thirtyTwo, z);
-  tmp = rewriter.create<mlir::stablehlo::SubtractOp>(loc, tmp, two);
-  Value xGt8 = materializeChebyshevPolynomialApproximation(rewriter, loc, tmp,
-                                                           kI1eCoeffsB);
-  xGt8 = rewriter.create<mlir::stablehlo::DivOp>(
-      loc, xGt8, rewriter.create<mlir::stablehlo::SqrtOp>(loc, z));
-
-  Value isLe8 = rewriter.create<mlir::stablehlo::CompareOp>(
-      loc, z, eight, mlir::stablehlo::ComparisonDirection::LE);
-
-  Value select =
-      rewriter.create<mlir::stablehlo::SelectOp>(loc, isLe8, xLe8, xGt8);
-  return rewriter.create<mlir::stablehlo::MulOp>(
-      loc, rewriter.create<mlir::stablehlo::SignOp>(loc, x), select);
-}
-
-Value materializeBesselI1eApproximationF32(ConversionPatternRewriter &rewriter,
-                                           Location loc, ValueRange args) {
-  Value x = args.front();
-  assert(cast<ShapedType>(x.getType()).getElementType().isF32() &&
-         "expect f32 element type");
-  const float kI1eCoeffsA[] = {
-      9.38153738649577178388E-9f, -4.44505912879632808065E-8f,
-      2.00329475355213526229E-7f, -8.56872026469545474066E-7f,
-      3.47025130813767847674E-6f, -1.32731636560394358279E-5f,
-      4.78156510755005422638E-5f, -1.61760815825896745588E-4f,
-      5.12285956168575772895E-4f, -1.51357245063125314899E-3f,
-      4.15642294431288815669E-3f, -1.05640848946261981558E-2f,
-      2.47264490306265168283E-2f, -5.29459812080949914269E-2f,
-      1.02643658689847095384E-1f, -1.76416518357834055153E-1f,
-      2.52587186443633654823E-1f};
-
-  const float kI1eCoeffsB[] = {
-      -3.83538038596423702205E-9f, -2.63146884688951950684E-8f,
-      -2.51223623787020892529E-7f, -3.88256480887769039346E-6f,
-      -1.10588938762623716291E-4f, -9.76109749136146840777E-3f,
-      7.78576235018280120474E-1f};
-
-  return materializeBesselI1eApproximation<float>(rewriter, loc, x, kI1eCoeffsA,
-                                                  kI1eCoeffsB);
-}
-
-static Value
-materializeBesselI1eApproximationF64(ConversionPatternRewriter &rewriter,
-                                     Location loc, ValueRange args) {
-  Value x = args.front();
-  assert(cast<ShapedType>(x.getType()).getElementType().isF64() &&
-         "expect f64 element type");
-
-  const double kI1eCoeffsA[] = {
-      2.77791411276104639959E-18, -2.11142121435816608115E-17,
-      1.55363195773620046921E-16, -1.10559694773538630805E-15,
-      7.60068429473540693410E-15, -5.04218550472791168711E-14,
-      3.22379336594557470981E-13, -1.98397439776494371520E-12,
-      1.17361862988909016308E-11, -6.66348972350202774223E-11,
-      3.62559028155211703701E-10, -1.88724975172282928790E-9,
-      9.38153738649577178388E-9,  -4.44505912879632808065E-8,
-      2.00329475355213526229E-7,  -8.56872026469545474066E-7,
-      3.47025130813767847674E-6,  -1.32731636560394358279E-5,
-      4.78156510755005422638E-5,  -1.61760815825896745588E-4,
-      5.12285956168575772895E-4,  -1.51357245063125314899E-3,
-      4.15642294431288815669E-3,  -1.05640848946261981558E-2,
-      2.47264490306265168283E-2,  -5.29459812080949914269E-2,
-      1.02643658689847095384E-1,  -1.76416518357834055153E-1,
-      2.52587186443633654823E-1};
-
-  const double kI1eCoeffsB[] = {
-      7.51729631084210481353E-18,  4.41434832307170791151E-18,
-      -4.65030536848935832153E-17, -3.20952592199342395980E-17,
-      2.96262899764595013876E-16,  3.30820231092092828324E-16,
-      -1.88035477551078244854E-15, -3.81440307243700780478E-15,
-      1.04202769841288027642E-14,  4.27244001671195135429E-14,
-      -2.10154184277266431302E-14, -4.08355111109219731823E-13,
-      -7.19855177624590851209E-13, 2.03562854414708950722E-12,
-      1.41258074366137813316E-11,  3.25260358301548823856E-11,
-      -1.89749581235054123450E-11, -5.58974346219658380687E-10,
-      -3.83538038596423702205E-9,  -2.63146884688951950684E-8,
-      -2.51223623787020892529E-7,  -3.88256480887769039346E-6,
-      -1.10588938762623716291E-4,  -9.76109749136146840777E-3,
-      7.78576235018280120474E-1};
-
-  return materializeBesselI1eApproximation<double>(rewriter, loc, x,
-                                                   kI1eCoeffsA, kI1eCoeffsB);
-}
-
-static Value materializeWithUpcast(ConversionPatternRewriter &rewriter,
-                                   Location loc, ValueRange args,
-                                   FloatType minPrecisionTy,
-                                   Value callback(ConversionPatternRewriter &,
-                                                  Location, ValueRange)) {
-  Type originalTy = getElementTypeOrSelf(args.front().getType());
-  auto floatOriginalTy = dyn_cast<FloatType>(originalTy);
-  bool needsUpcast =
-      floatOriginalTy && floatOriginalTy.getWidth() < minPrecisionTy.getWidth();
-
-  // Upcast arguments if necessary.
-  llvm::SmallVector<Value, 2> castedArgs;
-  if (needsUpcast) {
-    for (Value a : args) {
-      castedArgs.push_back(
-          rewriter.create<mlir::stablehlo::ConvertOp>(loc, a, minPrecisionTy));
-    }
-    args = castedArgs;
-  }
-
-  Value result = callback(rewriter, loc, args);
-
-  // Cast back if necessary.
-  if (needsUpcast) {
-    result =
-        rewriter.create<mlir::stablehlo::ConvertOp>(loc, result, originalTy);
-  }
-
-  return result;
-}
-
-struct ConvertBesselI1eOp final : OpConversionPattern<mlir::chlo::BesselI1eOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(mlir::chlo::BesselI1eOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    Value x = adaptor.getOperand();
-    Type ty = cast<ShapedType>(x.getType()).getElementType();
-
-    // For now, we support only f64, f32, f16 and bf16.
-    // See https://www.tensorflow.org/api_docs/python/tf/math/bessel_i1e
-    if (!ty.isF64() && !ty.isF32() && !ty.isF16() && !ty.isBF16()) {
-      return failure();
-    }
-
-    if (ty.isF64()) {
-      rewriter.replaceOp(
-          op, materializeBesselI1eApproximationF64(rewriter, loc, x));
-      return success();
-    }
-
-    rewriter.replaceOp(
-        op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(),
-                                  rewriter.getF32Type(),
-                                  &materializeBesselI1eApproximationF32));
-    return success();
-  }
-};
-
-template <typename FTy>
-static Value
-materializePolynomialApproximation(ConversionPatternRewriter &rewriter,
-                                   Location loc, Value x,
-                                   ArrayRef<FTy> coefficients) {
-  if (coefficients.empty())
-    return getConstantLike(rewriter, loc, 0.0, x);
-
-  Value poly = getConstantLike(rewriter, loc, coefficients[0], x);
-  for (size_t i = 1, e = coefficients.size(); i < e; ++i) {
-    poly = rewriter.create<mlir::stablehlo::MulOp>(loc, x.getType(), poly, x);
-    poly = rewriter.create<mlir::stablehlo::AddOp>(
-        loc, x.getType(), poly,
-        getConstantLike(rewriter, loc, coefficients[i], x));
-  }
-  return poly;
-}
-
-// Precondition is |x| >= 1. Use erf approximation, otherwise.
-//
-// We rely on multiple polynomial approximations for x >= 1. We pass |x| as an
-// argument and derive the final approximation for all |x| >= 1.
-// This implementation is based on Cephes.
-static Value materializeErfcApproximationF64ForMagnituteGeOne(
-    ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
-  Value x = args.front();
-  assert(cast<ShapedType>(x.getType()).getElementType().isF64() &&
-         "expect f64 element type");
-  const double kMaxlog = 7.09782712893383996843E2;
-  const double kErfcPCoefficients[] = {
-      2.46196981473530512524E-10, 5.64189564831068821977E-1,
-      7.46321056442269912687E0,   4.86371970985681366614E1,
-      1.96520832956077098242E2,   5.26445194995477358631E2,
-      9.34528527171957607540E2,   1.02755188689515710272E3,
-      5.57535335369399327526E2};
-  const double kErfcQCoefficients[] = {
-      1.00000000000000000000E0, 1.32281951154744992508E1,
-      8.67072140885989742329E1, 3.54937778887819891062E2,
-      9.75708501743205489753E2, 1.82390916687909736289E3,
-      2.24633760818710981792E3, 1.65666309194161350182E3,
-      5.57535340817727675546E2};
-  const double kErfcRCoefficients[] = {
-      5.64189583547755073984E-1, 1.27536670759978104416E0,
-      5.01905042251180477414E0,  6.16021097993053585195E0,
-      7.40974269950448939160E0,  2.97886665372100240670E0};
-  const double kErfcSCoefficients[] = {
-      1.00000000000000000000E0, 2.26052863220117276590E0,
-      9.39603524938001434673E0, 1.20489539808096656605E1,
-      1.70814450747565897222E1, 9.60896809063285878198E0,
-      3.36907645100081516050E0};
-
-  // Let z = -x^2.
-  Value xSq = rewriter.create<mlir::stablehlo::MulOp>(loc, x, x);
-  Value z = rewriter.create<mlir::stablehlo::NegOp>(loc, xSq);
-
-  // Materialize polynomial approximation for x in [1, 8) as
-  //   erfc(x) = exp(z) P(|x|) / Q(|x|).
-  Value expZ = rewriter.create<mlir::stablehlo::ExpOp>(loc, z);
-  Value absX = rewriter.create<mlir::stablehlo::AbsOp>(loc, x);
-  Value polP = materializePolynomialApproximation(
-      rewriter, loc, absX, llvm::ArrayRef(kErfcPCoefficients));
-  Value expZMulPolyP = rewriter.create<mlir::stablehlo::MulOp>(loc, expZ, polP);
-  Value polQ = materializePolynomialApproximation(
-      rewriter, loc, absX, llvm::ArrayRef(kErfcQCoefficients));
-  Value erfcApprox18 =
-      rewriter.create<mlir::stablehlo::DivOp>(loc, expZMulPolyP, polQ);
-
-  // Materialize polynomial approximation for x in >= 8 as
-  //   erfc(x) exp(z) R(|x|) / S(|x|).
-  Value polR = materializePolynomialApproximation(
-      rewriter, loc, absX, llvm::ArrayRef(kErfcRCoefficients));
-  Value expZMulPolyR = rewriter.create<mlir::stablehlo::MulOp>(loc, expZ, polR);
-  Value polS = materializePolynomialApproximation(
-      rewriter, loc, absX, llvm::ArrayRef(kErfcSCoefficients));
-  Value erfcApprox8Inf =
-      rewriter.create<mlir::stablehlo::DivOp>(loc, expZMulPolyR, polS);
-
-  // Combine polynomial approximations for x >= 1.
-  Value eight = getConstantLike(rewriter, loc, 8.0, x);
-  Value absXLt8 = rewriter.create<mlir::stablehlo::CompareOp>(
-      loc, absX, eight, mlir::stablehlo::ComparisonDirection::LT);
-  Value erfcApprox = rewriter.create<mlir::stablehlo::SelectOp>(
-      loc, absXLt8, erfcApprox18, erfcApprox8Inf);
-
-  // Clamp to prevent overflow and materialize approximation for large x as
-  //   erfc(x) = 0.
-  Value zLtNegMaxlog = rewriter.create<mlir::stablehlo::CompareOp>(
-      loc, z, getConstantLike(rewriter, loc, -kMaxlog, x),
-      mlir::stablehlo::ComparisonDirection::LT);
-  Value zero = getConstantLike(rewriter, loc, 0.0, x);
-  Value erfcApproxClamped = rewriter.create<mlir::stablehlo::SelectOp>(
-      loc, zLtNegMaxlog, zero, erfcApprox);
-
-  // Derive approximation for x <= -1 as
-  //   erfc(x) = 2 - erfc(-x).
-  // Reuse previously materialized approximations all of which take |x| as their
-  // argument.
-  Value xLtZero = rewriter.create<mlir::stablehlo::CompareOp>(
-      loc, x, zero, mlir::stablehlo::ComparisonDirection::LT);
-  Value two = getConstantLike(rewriter, loc, 2.0, x);
-  Value twoSubErfcApproxClamped =
-      rewriter.create<mlir::stablehlo::SubtractOp>(loc, two, erfcApproxClamped);
-  return rewriter.create<mlir::stablehlo::SelectOp>(
-      loc, xLtZero, twoSubErfcApproxClamped, erfcApproxClamped);
-}
-
-// Precondition is |x| <= 1. Use erfc approximation, otherwise.
-// This implementation is based on Cephes.
-static Value materializeErfApproximationF64ForMagnituteLeOne(
-    ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
-  Value x = args.front();
-  assert(cast<ShapedType>(x.getType()).getElementType().isF64() &&
-         "expect f64 element type");
-  const double kErfTCoefficients[] = {
-      9.60497373987051638749E0, 9.00260197203842689217E1,
-      2.23200534594684319226E3, 7.00332514112805075473E3,
-      5.55923013010394962768E4};
-  const double kErfUCoefficients[] = {
-      1.00000000000000000000E0, 3.35617141647503099647E1,
-      5.21357949780152679795E2, 4.59432382970980127987E3,
-      2.26290000613890934246E4, 4.92673942608635921086E4};
-
-  // Materialize polynomial approximation for |x| <= 1 as
-  //   erf(x) = x T(x^2) / U(x^2).
-  Value xSq = rewriter.create<mlir::stablehlo::MulOp>(loc, x, x);
-  Value polyT = materializePolynomialApproximation(
-      rewriter, loc, xSq, llvm::ArrayRef(kErfTCoefficients));
-  Value xMulPolyT = rewriter.create<mlir::stablehlo::MulOp>(loc, x, polyT);
-  Value polyU = materializePolynomialApproximation(
-      rewriter, loc, xSq, llvm::ArrayRef(kErfUCoefficients));
-  return rewriter.create<mlir::stablehlo::DivOp>(loc, xMulPolyT, polyU);
-}
-
-// This implementation is based on Cephes.
-static Value materializeErfApproximationF64(ConversionPatternRewriter &rewriter,
-                                            Location loc, ValueRange args) {
-  Value x = args.front();
-  assert(cast<ShapedType>(x.getType()).getElementType().isF64() &&
-         "expect f64 element type");
-
-  // Rely on erf approximation for |x| < 1
-  //   erf(x) = erf_approx(x)
-  Value erfApprox =
-      materializeErfApproximationF64ForMagnituteLeOne(rewriter, loc, x);
-
-  // Rely on erfc approximation for |x| >= 1 and materialize erf as
-  //   erf(x) = 1 - erfc_approx(x)
-  Value one = getConstantLike(rewriter, loc, 1.0, x);
-  Value erfcApprox =
-      materializeErfcApproximationF64ForMagnituteGeOne(rewriter, loc, x);
-  Value erfcBasedApprox =
-      rewriter.create<mlir::stablehlo::SubtractOp>(loc, one, erfcApprox);
-
-  // Materialize approximation selection based on argument.
-  Value absX = rewriter.create<mlir::stablehlo::AbsOp>(loc, x);
-  Value absXLtOne = rewriter.create<mlir::stablehlo::CompareOp>(
-      loc, absX, one, mlir::stablehlo::ComparisonDirection::LT);
-  return rewriter.create<mlir::stablehlo::SelectOp>(loc, absXLtOne, erfApprox,
-                                                    erfcBasedApprox);
-}
-
-static Value
-materializeErfcApproximationF64(ConversionPatternRewriter &rewriter,
-                                Location loc, ValueRange args) {
-  Value x = args.front();
-  assert(cast<ShapedType>(x.getType()).getElementType().isF64() &&
-         "expect f64 element type");
-
-  // Rely on erfc approximation for |x| >= 1
-  //   erfc(x) = erfc_approx(x)
-  Value erfcApprox =
-      materializeErfcApproximationF64ForMagnituteGeOne(rewriter, loc, x);
-
-  // Rely on erf approximation for |x| < 1 and materialize erfc as
-  //   erfc(x) = 1 - erf_approx(x)
-  Value one = getConstantLike(rewriter, loc, 1.0, x);
-  Value erfApprox =
-      materializeErfApproximationF64ForMagnituteLeOne(rewriter, loc, x);
-  Value erfBasedApprox =
-      rewriter.create<mlir::stablehlo::SubtractOp>(loc, one, erfApprox);
-
-  // Materialize approximation selection based on argument.
-  Value absX = rewriter.create<mlir::stablehlo::AbsOp>(loc, x);
-  Value absXLtOne = rewriter.create<mlir::stablehlo::CompareOp>(
-      loc, absX, one, mlir::stablehlo::ComparisonDirection::LT);
-  return rewriter.create<mlir::stablehlo::SelectOp>(loc, absXLtOne,
-                                                    erfBasedApprox, erfcApprox);
-}
-
-// Precondition is |x| >= 1. Use erf approximation, otherwise.
-//
-// We rely on multiple polynomial approximations for x >= 1. We pass |x| as an
-// argument and derive the final approximation for all |x| >= 1.
-// This implementation is based on Cephes.
-static Value materializeErfcApproximationF32ForMagnitudeGeOne(
-    ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
-  Value x = args.front();
-  assert(cast<ShapedType>(x.getType()).getElementType().isF32() &&
-         "expect f32 element type");
-  const double kMaxlog = 88.72283905206835;
-  const float kErfcPCoefficients[] = {
-      +2.326819970068386E-2f, -1.387039388740657E-1f, +3.687424674597105E-1f,
-      -5.824733027278666E-1f, +6.210004621745983E-1f, -4.944515323274145E-1f,
-      +3.404879937665872E-1f, -2.741127028184656E-1f, +5.638259427386472E-1f,
-  };
-  const float kErfcRCoefficients[] = {
-      -1.047766399936249E+1f, +1.297719955372516E+1f, -7.495518717768503E+0f,
-      +2.921019019210786E+0f, -1.015265279202700E+0f, +4.218463358204948E-1f,
-      -2.820767439740514E-1f, +5.641895067754075E-1f,
-  };
-
-  // Let z = -x^2.
-  Value xSq = rewriter.create<mlir::stablehlo::MulOp>(loc, x, x);
-  Value z = rewriter.create<mlir::stablehlo::NegOp>(loc, xSq);
-
-  // Materialize polynomial approximation for x >= 1 as
-  //   erfc(x) = exp(z) 1/x P(1/x^2)   if x in [1, 2)
-  //   erfc(x) = exp(z) 1/x R(1/x^2)   if x >= 2
-  Value absX = rewriter.create<mlir::stablehlo::AbsOp>(loc, x);
-  Value one = getConstantLike(rewriter, loc, 1.0, x);
-  Value reciprocalXSq = rewriter.create<mlir::stablehlo::DivOp>(loc, one, xSq);
-  Value expZ = rewriter.create<mlir::stablehlo::ExpOp>(loc, z);
-  Value oneDivAbsX = rewriter.create<mlir::stablehlo::DivOp>(loc, one, absX);
-  Value expZMulOneDivAbsX =
-      rewriter.create<mlir::stablehlo::MulOp>(loc, expZ, oneDivAbsX);
-  Value two = getConstantLike(rewriter, loc, 2.0, x);
-  Value absXLtTwo = rewriter.create<mlir::stablehlo::CompareOp>(
-      loc, absX, two, mlir::stablehlo::ComparisonDirection::LT);
-  Value polP = materializePolynomialApproximation(
-      rewriter, loc, reciprocalXSq, llvm::ArrayRef(kErfcPCoefficients));
-  Value polR = materializePolynomialApproximation(
-      rewriter, loc, reciprocalXSq, llvm::ArrayRef(kErfcRCoefficients));
-  Value poly =
-      rewriter.create<mlir::stablehlo::SelectOp>(loc, absXLtTwo, polP, polR);
-  Value erfcApprox =
-      rewriter.create<mlir::stablehlo::MulOp>(loc, expZMulOneDivAbsX, poly);
-
-  // Clamp to prevent overflow and materialize approximation for large x as
-  //   erfc(x) = 0.
-  Value zLtNeqMaxlog = rewriter.create<mlir::stablehlo::CompareOp>(
-      loc, z, getConstantLike(rewriter, loc, -kMaxlog, x),
-      mlir::stablehlo::ComparisonDirection::LT);
-  Value zero = getConstantLike(rewriter, loc, 0.0, x);
-  Value erfcApproxClamped = rewriter.create<mlir::stablehlo::SelectOp>(
-      loc, zLtNeqMaxlog, zero, erfcApprox);
-
-  // Derive approximation for x <= -1 as
-  //   erfc(x) = 2 - erfc(-x).
-  // Reuse previously materialized approximations all of which take |x| as their
-  // argument.
-  Value xLtZero = rewriter.create<mlir::stablehlo::CompareOp>(
-      loc, x, zero, mlir::stablehlo::ComparisonDirection::LT);
-  Value twoSubErfcApprox =
-      rewriter.create<mlir::stablehlo::SubtractOp>(loc, two, erfcApproxClamped);
-  return rewriter.create<mlir::stablehlo::SelectOp>(
-      loc, xLtZero, twoSubErfcApprox, erfcApproxClamped);
-}
-
-// Precondition is |x| <= 1. Use erfc approximation, otherwise.
-// This implementation is based on Cephes.
-static Value materializeErfApproximationF32ForMagnitudeLeOne(
-    ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
-  Value x = args.front();
-  assert(cast<ShapedType>(x.getType()).getElementType().isF32() &&
-         "expect f32 element type");
-  const float kErfTCoefficients[] = {
-      +7.853861353153693E-5f, -8.010193625184903E-4f, +5.188327685732524E-3f,
-      -2.685381193529856E-2f, +1.128358514861418E-1f, -3.761262582423300E-1f,
-      +1.128379165726710E+0f,
-  };
-
-  // Materialize polynomial approximation for |x| <= 1 as
-  //   erf(x) = x T(x^2).
-  Value xSq = rewriter.create<mlir::stablehlo::MulOp>(loc, x, x);
-  Value polyT = materializePolynomialApproximation(
-      rewriter, loc, xSq, llvm::ArrayRef(kErfTCoefficients));
-  return rewriter.create<mlir::stablehlo::MulOp>(loc, x, polyT);
-}
-
-// This is the same approximation as used in Eigen.
-static Value materializeErfApproximationF32(ConversionPatternRewriter &rewriter,
-                                            Location loc, ValueRange args) {
-  Value x = args.front();
-  assert(cast<ShapedType>(x.getType()).getElementType().isF32() &&
-         "expect f32 element type");
-  const float kAlpha[] = {
-      -2.72614225801306e-10f, 2.77068142495902e-08f,  -2.10102402082508e-06f,
-      -5.69250639462346e-05f, -7.34990630326855e-04f, -2.95459980854025e-03f,
-      -1.60960333262415e-02f,
-  };
-  const float kBeta[] = {
-      -1.45660718464996e-05f, -2.13374055278905e-04f, -1.68282697438203e-03f,
-      -7.37332916720468e-03f, -1.42647390514189e-02f,
-  };
-
-  // Clamp argument between -4 and 4.
-  Value lb = getConstantLike(rewriter, loc, -4.0, x);
-  Value ub = getConstantLike(rewriter, loc, 4.0, x);
-  x = rewriter.create<mlir::stablehlo::ClampOp>(loc, x.getType(), lb, x, ub);
-  Value xSq = rewriter.create<mlir::stablehlo::MulOp>(loc, x, x);
-
-  // Materialize polynomial approximation for x in [-4, 4] as
-  //   erf(x) = x * Alpha(x^2) / Beta(x^2).
-  Value alphaPoly = materializePolynomialApproximation(rewriter, loc, xSq,
-                                                       llvm::ArrayRef(kAlpha));
-  Value betaPoly = materializePolynomialApproximation(rewriter, loc, xSq,
-                                                      llvm::ArrayRef(kBeta));
-  Value xMulAlphaPoly =
-      rewriter.create<mlir::stablehlo::MulOp>(loc, x, alphaPoly);
-  Value erf =
-      rewriter.create<mlir::stablehlo::DivOp>(loc, xMulAlphaPoly, betaPoly);
-  Value lbErf = getConstantLike(rewriter, loc, -1.0, x);
-  Value ubErf = getConstantLike(rewriter, loc, 1.0, x);
-  return rewriter.create<mlir::stablehlo::ClampOp>(loc, erf.getType(), lbErf,
-                                                   erf, ubErf);
-}
-
-static Value
-materializeErfcApproximationF32(ConversionPatternRewriter &rewriter,
-                                Location loc, ValueRange args) {
-  Value x = args.front();
-  assert(cast<ShapedType>(x.getType()).getElementType().isF32() &&
-         "expect f32 element type");
-
-  // Rely on erfc approximation for |x| >= 1
-  //   erfc(x) = erfc_approx(x)
-  Value erfcApprox =
-      materializeErfcApproximationF32ForMagnitudeGeOne(rewriter, loc, x);
-
-  // Rely on erf approximation for |x| < 1 and materialize erfc as
-  //   erfc(x) = 1 - erf_approx(x)
-  Value one = getConstantLike(rewriter, loc, 1.0, x);
-  Value erfApprox =
-      materializeErfApproximationF32ForMagnitudeLeOne(rewriter, loc, x);
-  Value erfBasedApprox =
-      rewriter.create<mlir::stablehlo::SubtractOp>(loc, one, erfApprox);
-
-  // Materialize approximation selection based on argument.
-  Value absX = rewriter.create<mlir::stablehlo::AbsOp>(loc, x);
-  Value absXLtOne = rewriter.create<mlir::stablehlo::CompareOp>(
-      loc, absX, one, mlir::stablehlo::ComparisonDirection::LT);
-  return rewriter.create<mlir::stablehlo::SelectOp>(loc, absXLtOne,
-                                                    erfBasedApprox, erfcApprox);
-}
-
-struct ConvertErfOp final : OpConversionPattern<mlir::chlo::ErfOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(mlir::chlo::ErfOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    Value x = adaptor.getOperand();
-    Type ty = cast<ShapedType>(x.getType()).getElementType();
-
-    // For now, we support only f64, f32, f16 and bf16.
-    if (!ty.isF64() && !ty.isF32() && !ty.isF16() && !ty.isBF16()) {
-      return failure();
-    }
-
-    if (ty.isF64()) {
-      rewriter.replaceOp(op, materializeErfApproximationF64(rewriter, loc, x));
-      return success();
-    }
-
-    rewriter.replaceOp(
-        op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(),
-                                  rewriter.getF32Type(),
-                                  &materializeErfApproximationF32));
-    return success();
-  }
-};
-
-struct ConvertErfcOp final : OpConversionPattern<mlir::chlo::ErfcOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(mlir::chlo::ErfcOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    Value x = adaptor.getOperand();
-    Type ty = cast<ShapedType>(x.getType()).getElementType();
-
-    // For now, we support only f64, f32, f16 and bf16.
-    if (!ty.isF64() && !ty.isF32() && !ty.isF16() && !ty.isBF16()) {
-      return failure();
-    }
-
-    if (ty.isF64()) {
-      rewriter.replaceOp(op, materializeErfcApproximationF64(rewriter, loc, x));
-      return success();
-    }
-
-    rewriter.replaceOp(
-        op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(),
-                                  rewriter.getF32Type(),
-                                  &materializeErfcApproximationF32));
-    return success();
-  }
-};
-
-static Value erfInv32(ConversionPatternRewriter &b, Location loc,
-                      ValueRange args) {
-  constexpr int kDegree = 9;
-  constexpr std::array<float, 9> wLessThan5Constants = {
-      2.81022636e-08f,  3.43273939e-07f, -3.5233877e-06f,
-      -4.39150654e-06f, 0.00021858087f,  -0.00125372503f,
-      -0.00417768164f,  0.246640727f,    1.50140941f};
-  constexpr std::array<float, 9> wGreaterThan5Constants = {
-      -0.000200214257f, 0.000100950558f, 0.00134934322f,
-      -0.00367342844f,  0.00573950773f,  -0.0076224613f,
-      0.00943887047f,   1.00167406f,     2.83297682f};
-
-  Value x = args[0];
-  // Compute logarithm of (1+arg) using log1p(arg) which is more precise than
-  // log(1+arg) when arg is close to zero. For more details, see
-  // https://en.cppreference.com/w/cpp/numeric/math/log1p
-  Value minusXSquared = b.create<mlir::stablehlo::MulOp>(
-      loc, x, b.create<mlir::stablehlo::NegOp>(loc, x));
-  Value w = b.create<mlir::stablehlo::NegOp>(
-      loc, b.create<mlir::stablehlo::Log1pOp>(loc, minusXSquared));
-
-  Value lt = b.create<mlir::stablehlo::CompareOp>(
-      loc, w, getConstantLike(b, loc, 5.0, x),
-      mlir::stablehlo::ComparisonDirection::LT);
-  auto coefficient = [&](int i) {
-    return b.create<mlir::stablehlo::SelectOp>(
-        loc, lt, getConstantLike(b, loc, wLessThan5Constants[i], x),
-        getConstantLike(b, loc, wGreaterThan5Constants[i], x));
-  };
-  w = b.create<mlir::stablehlo::SelectOp>(
-      loc, lt,
-      b.create<mlir::stablehlo::SubtractOp>(loc, w,
-                                            getConstantLike(b, loc, 2.5, x)),
-      b.create<mlir::stablehlo::SubtractOp>(
-          loc, b.create<mlir::stablehlo::SqrtOp>(loc, w),
-          getConstantLike(b, loc, 3.0, x)));
-  Value p = coefficient(0);
-  for (int i = 1; i < kDegree; ++i) {
-    p = b.create<mlir::stablehlo::AddOp>(
-        loc, coefficient(i), b.create<mlir::stablehlo::MulOp>(loc, p, w));
-  }
-
-  // Result modulo edge cases.
-  Value result = b.create<mlir::stablehlo::MulOp>(loc, p, x);
-
-  // Handle edge cases, namely erfinv(+/-1) = +/-inf.  (The above computation is
-  // indeterminate, and can give nan or -/+inf.)
-  return b.create<mlir::stablehlo::SelectOp>(
-      loc,
-      b.create<mlir::stablehlo::CompareOp>(
-          loc, b.create<mlir::stablehlo::AbsOp>(loc, x),
-          getConstantLike(b, loc, 1, x),
-          mlir::stablehlo::ComparisonDirection::EQ),
-      b.create<mlir::stablehlo::MulOp>(
-          loc, x, getConstantLikeInfValue(b, loc, x, false)),
-      result);
-}
-
-static Value erfInv64(ConversionPatternRewriter &b, Location loc,
-                      ValueRange args) {
-  constexpr std::array<double, 23> wLessThan625Constants = {
-      -3.6444120640178196996e-21, -1.685059138182016589e-19,
-      1.2858480715256400167e-18,  1.115787767802518096e-17,
-      -1.333171662854620906e-16,  2.0972767875968561637e-17,
-      6.6376381343583238325e-15,  -4.0545662729752068639e-14,
-      -8.1519341976054721522e-14, 2.6335093153082322977e-12,
-      -1.2975133253453532498e-11, -5.4154120542946279317e-11,
-      1.051212273321532285e-09,   -4.1126339803469836976e-09,
-      -2.9070369957882005086e-08, 4.2347877827932403518e-07,
-      -1.3654692000834678645e-06, -1.3882523362786468719e-05,
-      0.0001867342080340571352,   -0.00074070253416626697512,
-      -0.0060336708714301490533,  0.24015818242558961693,
-      1.6536545626831027356};
-  constexpr std::array<double, 19> wLessThan16Constants = {
-      2.2137376921775787049e-09,  9.0756561938885390979e-08,
-      -2.7517406297064545428e-07, 1.8239629214389227755e-08,
-      1.5027403968909827627e-06,  -4.013867526981545969e-06,
-      2.9234449089955446044e-06,  1.2475304481671778723e-05,
-      -4.7318229009055733981e-05, 6.8284851459573175448e-05,
-      2.4031110387097893999e-05,  -0.0003550375203628474796,
-      0.00095328937973738049703,  -0.0016882755560235047313,
-      0.0024914420961078508066,   -0.0037512085075692412107,
-      0.005370914553590063617,    1.0052589676941592334,
-      3.0838856104922207635,
-  };
-  constexpr std::array<double, 17> wGreaterThan16Constants = {
-      -2.7109920616438573243e-11, -2.5556418169965252055e-10,
-      1.5076572693500548083e-09,  -3.7894654401267369937e-09,
-      7.6157012080783393804e-09,  -1.4960026627149240478e-08,
-      2.9147953450901080826e-08,  -6.7711997758452339498e-08,
-      2.2900482228026654717e-07,  -9.9298272942317002539e-07,
-      4.5260625972231537039e-06,  -1.9681778105531670567e-05,
-      7.5995277030017761139e-05,  -0.00021503011930044477347,
-      -0.00013871931833623122026, 1.0103004648645343977,
-      4.8499064014085844221,
-  };
-
-  Value x = args[0];
-  // Compute logarithm of (1+arg) using log1p(arg) which is more precise than
-  // log(1+arg) when arg is close to zero. For more details, see
-  // https://en.cppreference.com/w/cpp/numeric/math/log1p
-  Value minusXSquared = b.create<mlir::stablehlo::MulOp>(
-      loc, x, b.create<mlir::stablehlo::NegOp>(loc, x));
-  Value w = b.create<mlir::stablehlo::NegOp>(
-      loc, b.create<mlir::stablehlo::Log1pOp>(loc, minusXSquared));
-
-  Value lt625 = b.create<mlir::stablehlo::CompareOp>(
-      loc, w, getConstantLike(b, loc, 6.25, x),
-      mlir::stablehlo::ComparisonDirection::LT);
-  Value lt16 = b.create<mlir::stablehlo::CompareOp>(
-      loc, w, getConstantLike(b, loc, 16, x),
-      mlir::stablehlo::ComparisonDirection::LT);
-
-  auto coefficient = [&](int i) {
-    Value c = getConstantLike(b, loc, wLessThan625Constants[i], x);
-    if (i < 19) {
-      c = b.create<mlir::stablehlo::SelectOp>(
-          loc, lt625, c, getConstantLike(b, loc, wLessThan16Constants[i], x));
-    }
-    if (i < 17) {
-      c = b.create<mlir::stablehlo::SelectOp>(
-          loc, lt16, c, getConstantLike(b, loc, wGreaterThan16Constants[i], x));
-    }
-    return c;
-  };
-
-  Value sqrtW = b.create<mlir::stablehlo::SqrtOp>(loc, w);
-  Value wMinus3125 = b.create<mlir::stablehlo::SubtractOp>(
-      loc, w, getConstantLike(b, loc, 3.125, x));
-  Value select2 = b.create<mlir::stablehlo::SelectOp>(
-      loc, lt16, getConstantLike(b, loc, 3.25, w),
-      getConstantLike(b, loc, 5.0, w));
-  Value select2Result =
-      b.create<mlir::stablehlo::SubtractOp>(loc, sqrtW, select2);
-  w = b.create<mlir::stablehlo::SelectOp>(loc, lt625, wMinus3125,
-                                          select2Result);
-
-  Value p = coefficient(0);
-  for (int i = 1; i < 17; ++i) {
-    p = b.create<mlir::stablehlo::AddOp>(
-        loc, coefficient(i), b.create<mlir::stablehlo::MulOp>(loc, p, w));
-  }
-  for (int i = 17; i < 19; ++i) {
-    p = b.create<mlir::stablehlo::SelectOp>(
-        loc, lt16,
-        b.create<mlir::stablehlo::AddOp>(
-            loc, coefficient(i), b.create<mlir::stablehlo::MulOp>(loc, p, w)),
-        p);
-  }
-  for (int i = 19; i < 23; ++i) {
-    p = b.create<mlir::stablehlo::SelectOp>(
-        loc, lt625,
-        b.create<mlir::stablehlo::AddOp>(
-            loc, coefficient(i), b.create<mlir::stablehlo::MulOp>(loc, p, w)),
-        p);
-  }
-
-  // Result modulo edge cases.
-  Value result = b.create<mlir::stablehlo::MulOp>(loc, p, x);
-
-  // Handle edge cases, namely erfinv(+/-1) = +/-inf.  (The above computation is
-  // indeterminate, and can give nan or -/+inf.)
-  return b.create<mlir::stablehlo::SelectOp>(
-      loc,
-      b.create<mlir::stablehlo::CompareOp>(
-          loc, b.create<mlir::stablehlo::AbsOp>(loc, x),
-          getConstantLike(b, loc, 1, x),
-          mlir::stablehlo::ComparisonDirection::EQ),
-      b.create<mlir::stablehlo::MulOp>(
-          loc, x, getConstantLikeInfValue(b, loc, x, false)),
-      result);
-}
-
-struct ConvertErfInvOp final : OpConversionPattern<mlir::chlo::ErfInvOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(mlir::chlo::ErfInvOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    if (op.getResult().getType().getElementType().isF64()) {
-      rewriter.replaceOp(op, erfInv64(rewriter, loc, adaptor.getOperands()));
-      return success();
-    }
-    FloatType minPrecisionTy = rewriter.getF32Type();
-    rewriter.replaceOp(op, materializeWithUpcast(rewriter, loc,
-                                                 adaptor.getOperands(),
-                                                 minPrecisionTy, &erfInv32));
-    return success();
-  }
-};
-
-// Coefficients for the Lanczos approximation of the gamma function. The
-// coefficients are uniquely determined by the choice of g and n (kLanczosGamma
-// and kLanczosCoefficients.size() + 1). The coefficients below correspond to
-// [7, 9]. [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were evaluated and
-// [7, 9] seemed to be the least sensitive to the quality of the log function.
-// In particular, [5, 7] is the only choice where -1.5e-5 <= lgamma(2) <= 1.5e-5
-// for a particularly inaccurate log function.
-constexpr double kLanczosGamma = 7; // aka g
-constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478;
-constexpr std::array<double, 8> kLanczosCoefficients = {
-    676.520368121885098567009190444019, -1259.13921672240287047156078755283,
-    771.3234287776530788486528258894,   -176.61502916214059906584551354,
-    12.507343278686904814458936853,     -0.13857109526572011689554707,
-    9.984369578019570859563e-6,         1.50563273514931155834e-7};
-
-// Compute the Lgamma function using Lanczos' approximation from "A Precision
-// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
-// series B. Vol. 1:
-//   lgamma(z + 1) = (log(2) + log(pi)) / 2
-//                     + (z + 1/2) * log(t(z))
-//                     - t(z) + log(a(z))
-//   with   t(z) = z + kLanczosGamma + 1/2
-//          a(z) = kBaseLanczosCoeff
-//                   + sum(k = 1, n, kLanczosCoefficients[i] / (z + k))
-static Value materializeLgamma(ConversionPatternRewriter &rewriter,
-                               Location loc, ValueRange args) {
-  // If the input is less than 0.5 use Euler's reflection formula.
-  //   gamma(x) = pi / (sin(pi * x) * gamma(1 - x))
-  // Let z be
-  //   z = -x      if x < 1/2
-  //   z = x - 1   otheriwse
-  Value x = args.front();
-  Value half = getConstantLike(rewriter, loc, 0.5, x);
-  Value needToReflect = rewriter.create<mlir::stablehlo::CompareOp>(
-      loc, x, half, mlir::stablehlo::ComparisonDirection::LT);
-  Value negX = rewriter.create<mlir::stablehlo::NegOp>(loc, x);
-  Value one = getConstantLike(rewriter, loc, 1, x);
-  Value xSubOne = rewriter.create<mlir::stablehlo::SubtractOp>(loc, x, one);
-  Value z = rewriter.create<mlir::stablehlo::SelectOp>(loc, needToReflect, negX,
-                                                       xSubOne);
-
-  // Materialize
-  //   a(z) = kBaseLanczosCoeff
-  //            + sum(k = 1, n, kLanczosCoefficients[i] / (z + k))
-  Value a = getConstantLike(rewriter, loc, kBaseLanczosCoeff, x);
-  for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
-    Value coeff = getConstantLike(rewriter, loc, kLanczosCoefficients[i], x);
-    Value oneBasedIndex = getConstantLike(rewriter, loc, i + 1, x);
-    Value quotient = rewriter.create<mlir::stablehlo::DivOp>(
-        loc, coeff,
-        rewriter.create<mlir::stablehlo::AddOp>(loc, z, oneBasedIndex));
-    a = rewriter.create<mlir::stablehlo::AddOp>(loc, a, quotient);
-  }
-
-  // To improve accuracy on platforms with less-precise log implementations,
-  // compute log(kLanczosGamma + 1/2) at compile time and use log1p on the
-  // device.
-  // Materialize as
-  //   log(t) = log(kLanczosGamma + 1/2 + z)
-  //          = log(kLanczosGamma + 1/2) + log1p(z / (kLanczosGamma + 1/2)).
-  Value lanczosPlusHalf =
-      getConstantLike(rewriter, loc, kLanczosGamma + 0.5, x);
-  Value t = rewriter.create<mlir::stablehlo::AddOp>(loc, lanczosPlusHalf, z);
-  Value logTerm =
-      getConstantLike(rewriter, loc, std::log(kLanczosGamma + 0.5), x);
-  Value log1pTerm = rewriter.create<mlir::stablehlo::Log1pOp>(
-      loc, rewriter.create<mlir::stablehlo::DivOp>(loc, z, lanczosPlusHalf));
-  Value logT = rewriter.create<mlir::stablehlo::AddOp>(loc, logTerm, log1pTerm);
-
-  // Note that t(z) may be large and we need to be careful not to overflow to
-  // infinity in the relevant term
-  //   r = (z + 1/2) * log(t(z)) - t(z).
-  // Therefore, we compute this as
-  //   r = (z + 1/2 - t(z) / log(t(z))) * log(t(z)).
-  Value tDivLogT = rewriter.create<mlir::stablehlo::DivOp>(loc, t, logT);
-  Value sum = rewriter.create<mlir::stablehlo::SubtractOp>(
-      loc, rewriter.create<mlir::stablehlo::AddOp>(loc, z, half), tDivLogT);
-  Value r = rewriter.create<mlir::stablehlo::MulOp>(loc, sum, logT);
-
-  // Compute the final result (modulo reflection) as
-  //   lgamma(z + 1) = (log(2) + log(pi)) / 2 + r + log(a(z)).
-  Value logA = rewriter.create<mlir::stablehlo::LogOp>(loc, a);
-  Value lgamma = rewriter.create<mlir::stablehlo::AddOp>(
-      loc,
-      rewriter.create<mlir::stablehlo::AddOp>(
-          loc,
-          getConstantLike(rewriter, loc, (std::log(2) + std::log(M_PI)) / 2, x),
-          r),
-      logA);
-
-  // Compute the reflected value for x < 0.5 as
-  //   lgamma(x) = log(pi) - lgamma(1-x) - log(abs(sin(pi * x))).
-  //
-  // The abs is needed because lgamma is the log of the absolute value of the
-  // gamma function.
-  //
-  // We have to be careful when computing the final term above. gamma(x) goes
-  // to +/-inf at every integer x < 0, and this is controlled by the sin(pi * x)
-  // term. The slope is large, so precision is particularly important.
-  //
-  // Because abs(sin(pi * x)) has period of 1 we can equivalently use
-  // abs(sin(pi * frac(x))) where frac(x) is the fractional part of x. This is
-  // more numerically accurate: It doesn't overflow to inf like pi * x would and
-  // if x is an integer it evaluates to exactly 0 which is important because we
-  // then take the log of this value, and log(0) is inf.
-  //
-  // We don't have a frac(x) primitive in HLO and computing it is tricky, but
-  // because abs(sin(pi * x)) = abs(sin(pi * abs(x))), it's good enough for our
-  // purposes to use abs(frac(x)) = abs(x) - floor(abs(x)).
-  //
-  // Furthermore, pi * abs(frac(x)) loses precision when abs(frac(x)) is close
-  // to 1. To remedy this, we can use the fact that sin(pi * x) in the domain
-  // [0, 1] is symmetric across the line Y=0.5.
-  //
-
-  // Convert values of abs_frac > 0.5 to (1 - abs_frac) to improve precision of
-  // pi * abs_frac for values of abs_frac close to 1.
-  Value abs = rewriter.create<mlir::stablehlo::AbsOp>(loc, x);
-  Value absFrac = rewriter.create<mlir::stablehlo::SubtractOp>(
-      loc, abs, rewriter.create<mlir::stablehlo::FloorOp>(loc, abs));
-  Value reduceAbsFrac = rewriter.create<mlir::stablehlo::CompareOp>(
-      loc, half, absFrac, mlir::stablehlo::ComparisonDirection::LT);
-  absFrac = rewriter.create<mlir::stablehlo::SelectOp>(
-      loc, reduceAbsFrac,
-      rewriter.create<mlir::stablehlo::SubtractOp>(loc, one, absFrac), absFrac);
-
-  // Materialize reflection.
-  Value reflectionDenom = rewriter.create<mlir::stablehlo::LogOp>(
-      loc,
-      rewriter.create<mlir::stablehlo::SineOp>(
-          loc, rewriter.create<mlir::stablehlo::MulOp>(
-                   loc, getConstantLike(rewriter, loc, M_PI, x), absFrac)));
-  Value lgammaReflection = rewriter.create<mlir::stablehlo::SubtractOp>(
-      loc,
-      rewriter.create<mlir::stablehlo::SubtractOp>(
-          loc, getConstantLike(rewriter, loc, std::log(M_PI), x),
-          reflectionDenom),
-      lgamma);
-
-  // Avoid computing -inf - inf, which is nan. If reflection_denom is +/-inf,
-  // then it "wins" and the result is +/-inf.
-  Value finiteReflectionDenom =
-      rewriter.create<mlir::stablehlo::IsFiniteOp>(loc, reflectionDenom);
-  Value negReflectionDenom =
-      rewriter.create<mlir::stablehlo::NegOp>(loc, reflectionDenom);
-  lgammaReflection = rewriter.create<mlir::stablehlo::SelectOp>(
-      loc, finiteReflectionDenom, lgammaReflection, negReflectionDenom);
-
-  // Select whether or not to rely on the reflection.
-  lgamma = rewriter.create<mlir::stablehlo::SelectOp>(loc, needToReflect,
-                                                      lgammaReflection, lgamma);
-
-  // Materialize +/-inf behavior as
-  //   lgamma(+/-inf) = +inf.
-  Value xIsInf = rewriter.create<chlo::IsInfOp>(loc, x);
-  return rewriter.create<mlir::stablehlo::SelectOp>(
-      loc, xIsInf,
-      getConstantLikeInfValue(rewriter, loc, x, /*negative=*/false), lgamma);
-}
-
-// Express `cosh` as
-//   cosh(x) = (e^x + e^-x) / 2
-//           = e^(x + log(1/2)) + e^(-x + log(1/2))
-//
-// The second formulation avoids overflowing when e^x = inf but (e^x)/2 is not.
-//
-// This incorrectly overflows to inf for two f32 input values, namely
-// +/-89.4159851, due to rounding error when computing x +/- log(1/2).  The
-// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
-// we deem this acceptable.
-static Value materializeCoshApproximation(ConversionPatternRewriter &rewriter,
-                                          Location loc, ValueRange operands) {
-  mlir::chlo::CoshOp::Adaptor transformed(operands);
-  Value x = transformed.getOperand();
-
-  Value logOneHalf = rewriter.create<mlir::stablehlo::LogOp>(
-      loc, getConstantLike(rewriter, loc, 0.5, x));
-  Value expAdd = rewriter.create<mlir::stablehlo::ExpOp>(
-      loc, rewriter.create<mlir::stablehlo::AddOp>(loc, x, logOneHalf));
-  Value expSub = rewriter.create<mlir::stablehlo::ExpOp>(
-      loc, rewriter.create<mlir::stablehlo::SubtractOp>(loc, logOneHalf, x));
-  return rewriter.create<mlir::stablehlo::AddOp>(loc, expAdd, expSub);
-}
-
-struct ConvertCoshOp final : OpConversionPattern<mlir::chlo::CoshOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(mlir::chlo::CoshOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOp(
-        op, materializeWithUpcast(rewriter, op.getLoc(), adaptor.getOperands(),
-                                  rewriter.getF32Type(),
-                                  &materializeCoshApproximation));
-    return success();
-  }
-};
-
-// Compute the Digamma function using Lanczos' approximation from "A Precision
-// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
-// series B. Vol. 1:
-//   digamma(z + 1) = log(t(z)) + a'(z) / a(z) - kLanczosGamma / t(z)
-//   with   t(z) = z + kLanczosGamma + 1/2
-//          a(z) = kBaseLanczosCoeff
-//                   + sum(k = 1, n, kLanczosCoefficients[i] / (z + k))
-//          a'(z) = - sum(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k))
-static Value materializeDigamma(ConversionPatternRewriter &rewriter,
-                                Location loc, ValueRange args) {
-  // If the input is less than 0.5 use Euler's reflection formula.
-  //   digamma(x) = digamma(1 - x) - pi * cot(pi * x)
-  // Let z be
-  //   z = -x      if x < 1/2
-  //   z = x - 1   otheriwse
-  Value x = args.front();
-  Value half = getConstantLike(rewriter, loc, 0.5, x);
-  Value needToReflect = rewriter.create<mlir::stablehlo::CompareOp>(
-      loc, x, half, mlir::stablehlo::ComparisonDirection::LT);
-  Value negX = rewriter.create<mlir::stablehlo::NegOp>(loc, x);
-  Value one = getConstantLike(rewriter, loc, 1, x);
-  Value xSubOne = rewriter.create<mlir::stablehlo::SubtractOp>(loc, x, one);
-  Value z = rewriter.create<mlir::stablehlo::SelectOp>(loc, needToReflect, negX,
-                                                       xSubOne);
-
-  // Materialize
-  //   a(z) = kBaseLanczosCoeff
-  //            + sum(k = 1, n, kLanczosCoefficients[i] / (z + k))
-  //   a'(z) = - sum(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k))
-  Value zero = getConstantLike(rewriter, loc, 0.0, x);
-  Value a = getConstantLike(rewriter, loc, kBaseLanczosCoeff, x);
-  Value aPrime = zero;
-  for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
-    Value coeff = getConstantLike(rewriter, loc, kLanczosCoefficients[i], x);
-    Value oneBasedIndex = getConstantLike(rewriter, loc, i + 1, x);
-    Value zTerm =
-        rewriter.create<mlir::stablehlo::AddOp>(loc, z, oneBasedIndex);
-    aPrime = rewriter.create<mlir::stablehlo::SubtractOp>(
-        loc, aPrime,
-        rewriter.create<mlir::stablehlo::DivOp>(
-            loc, coeff,
-            rewriter.create<mlir::stablehlo::MulOp>(loc, zTerm, zTerm)));
-    a = rewriter.create<mlir::stablehlo::AddOp>(
-        loc, a, rewriter.create<mlir::stablehlo::DivOp>(loc, coeff, zTerm));
-  }
-
-  // To improve accuracy on platforms with less-precise log implementations,
-  // compute log(kLanczosGamma + 1/2) at compile time and use log1p on the
-  // device.
-  // Materialize as
-  //   log(t) = log(kLanczosGamma + 1/2 + z)
-  //          = log(kLanczosGamma + 1/2) + log1p(z / (kLanczosGamma + 1/2)).
-  Value lanczosPlusHalf =
-      getConstantLike(rewriter, loc, kLanczosGamma + 0.5, x);
-  Value t = rewriter.create<mlir::stablehlo::AddOp>(loc, lanczosPlusHalf, z);
-  Value logTerm =
-      getConstantLike(rewriter, loc, std::log(kLanczosGamma + 0.5), x);
-  Value log1pTerm = rewriter.create<mlir::stablehlo::Log1pOp>(
-      loc, rewriter.create<mlir::stablehlo::DivOp>(loc, z, lanczosPlusHalf));
-  Value logT = rewriter.create<mlir::stablehlo::AddOp>(loc, logTerm, log1pTerm);
-
-  // Materialize the final result (modulo reflection) as
-  //   digamma(z + 1) = log(t(z)) + a'(z) / a(z) - kLanczosGamma / t(z).
-  Value aPrimeDivA = rewriter.create<mlir::stablehlo::DivOp>(loc, aPrime, a);
-  Value lanczosGammaDivT = rewriter.create<mlir::stablehlo::DivOp>(
-      loc, getConstantLike(rewriter, loc, kLanczosGamma, x), t);
-  Value digamma = rewriter.create<mlir::stablehlo::SubtractOp>(
-      loc, rewriter.create<mlir::stablehlo::AddOp>(loc, logT, aPrimeDivA),
-      lanczosGammaDivT);
-
-  // We need to be careful how we compute cot(pi * input) below: For
-  // near-integral arguments, pi * input can lose precision.
-  //
-  // Input is already known to be less than 0.5 (otherwise we don't have to
-  // reflect). We shift values smaller than -0.5 into the range [-0.5, 0.5] to
-  // increase precision of pi * x and the resulting cotangent.
-  Value reducedX = rewriter.create<mlir::stablehlo::AddOp>(
-      loc, x,
-      rewriter.create<mlir::stablehlo::AbsOp>(
-          loc, rewriter.create<mlir::stablehlo::FloorOp>(
-                   loc, rewriter.create<mlir::stablehlo::AddOp>(
-                            loc, x, getConstantLike(rewriter, loc, 0.5, x)))));
-
-  // Materialize reflection for inputs less than 0.5 as
-  //   digamma(x) = digamma(1 - x) - pi * cot(pi * x)
-  //              = digamma(1 - x) - pi * cos(pi * x) / sin(pi * x)
-  Value pi = getConstantLike(rewriter, loc, M_PI, x);
-  Value piMulReducedX =
-      rewriter.create<mlir::stablehlo::MulOp>(loc, pi, reducedX);
-  Value cos = rewriter.create<mlir::stablehlo::CosineOp>(loc, piMulReducedX);
-  Value sin = rewriter.create<mlir::stablehlo::SineOp>(loc, piMulReducedX);
-  Value reflection = rewriter.create<mlir::stablehlo::SubtractOp>(
-      loc, digamma,
-      rewriter.create<mlir::stablehlo::DivOp>(
-          loc, rewriter.create<mlir::stablehlo::MulOp>(loc, pi, cos), sin));
-
-  // Select whether or not to rely on the reflection.
-  digamma = rewriter.create<mlir::stablehlo::SelectOp>(loc, needToReflect,
-                                                       reflection, digamma);
-
-  // Digamma has poles at negative integers and zero; return nan for those.
-  Value isLeZero = rewriter.create<mlir::stablehlo::CompareOp>(
-      loc, x, zero, mlir::stablehlo::ComparisonDirection::LE);
-  Value isInt = rewriter.create<mlir::stablehlo::CompareOp>(
-      loc, x, rewriter.create<mlir::stablehlo::FloorOp>(loc, x),
-      mlir::stablehlo::ComparisonDirection::EQ);
-  Value isPole = rewriter.create<mlir::stablehlo::AndOp>(loc, isLeZero, isInt);
-  return rewriter.create<mlir::stablehlo::SelectOp>(
-      loc, isPole,
-      getConstantLike(rewriter, loc, std::numeric_limits<double>::quiet_NaN(),
-                      x),
-      digamma);
-}
-
-static Value getConstantLikeSmallestFiniteValue(OpBuilder &b, Location loc,
-                                                Value val) {
-  auto ty = cast<FloatType>(getElementTypeOrSelf(val.getType()));
-  return getConstantLike(
-      b, loc, llvm::APFloat::getSmallest(ty.getFloatSemantics()), val);
-}
-
-static Value materializeZeta(ConversionPatternRewriter &rewriter, Location loc,
-                             ValueRange args) {
-  // Code should match StableHLO's materializeZeta
-  assert(args.size() == 2);
-  Value x = args[0];
-  Value q = args[1];
-  static const std::array<double, 12> kZetaCoeffs{
-      -7.1661652561756670113e18,
-      1.8152105401943546773e17,
-      -4.5979787224074726105e15,
-      1.1646782814350067249e14,
-      -2.950130727918164224e12,
-      7.47242496e10,
-      -1.8924375803183791606e9,
-      47900160.0,
-      -1209600.0,
-      30240.0,
-      -720.0,
-      12.0,
-  };
-
-  // For speed we'll always use 9 iterations for the initial series estimate,
-  // and a 12 term expansion for the Euler-Maclaurin formula.
-  Value a = q;
-  Value zero = getConstantLike(rewriter, loc, 0.0, a);
-  Value negPower = zero;
-  Value negX = rewriter.create<mlir::stablehlo::NegOp>(loc, x);
-  Value initialSum = rewriter.create<mlir::stablehlo::PowOp>(loc, q, negX);
-  Value one = getConstantLike(rewriter, loc, 1.0, a);
-  for (int i = 0; i < 9; ++i) {
-    a = rewriter.create<mlir::stablehlo::AddOp>(loc, a, one);
-    negPower = rewriter.create<mlir::stablehlo::PowOp>(loc, a, negX);
-    initialSum =
-        rewriter.create<mlir::stablehlo::AddOp>(loc, initialSum, negPower);
-  }
-
-  a = rewriter.create<mlir::stablehlo::AddOp>(loc, a, one);
-  negPower = rewriter.create<mlir::stablehlo::PowOp>(loc, a, negX);
-  Value oneLikeX = getConstantLike(rewriter, loc, 1.0, x);
-  Value xMinusOne =
-      rewriter.create<mlir::stablehlo::SubtractOp>(loc, x, oneLikeX);
-  Value negPowerMulA =
-      rewriter.create<mlir::stablehlo::MulOp>(loc, negPower, a);
-  Value negPowerMulADivXMinusOne =
-      rewriter.create<mlir::stablehlo::DivOp>(loc, negPowerMulA, xMinusOne);
-  Value s = rewriter.create<mlir::stablehlo::AddOp>(loc, initialSum,
-                                                    negPowerMulADivXMinusOne);
-  Value aInverseSquare = rewriter.create<mlir::stablehlo::DivOp>(
-      loc, one, rewriter.create<mlir::stablehlo::MulOp>(loc, a, a));
-
-  Value hornerSum = zero;
-  Value factor = one;
-  // Use Horner's rule for this.
-  // Note this differs from Cephes which does a 'naive' polynomial evaluation.
-  // Using Horner's rule allows to avoid some NaN's and Infs from happening,
-  // resulting in more numerically stable code.
-  for (int i = 0; i < 11; ++i) {
-    Value factorLhs = rewriter.create<mlir::stablehlo::AddOp>(
-        loc, x, getConstantLike(rewriter, loc, 22 - 2 * i, x));
-    Value factorRhs = rewriter.create<mlir::stablehlo::AddOp>(
-        loc, x, getConstantLike(rewriter, loc, 21 - 2 * i, x));
-    factor = rewriter.create<mlir::stablehlo::MulOp>(loc, factorLhs, factorRhs);
-    hornerSum = rewriter.create<mlir::stablehlo::MulOp>(
-        loc, factor,
-        rewriter.create<mlir::stablehlo::MulOp>(
-            loc, aInverseSquare,
-            rewriter.create<mlir::stablehlo::AddOp>(
-                loc, hornerSum,
-                getConstantLike(rewriter, loc, 1. / kZetaCoeffs[i], a))));
-  }
-  Value zeroPointFiveLikeNegPower =
-      getConstantLike(rewriter, loc, .5, negPower);
-  Value xDivA = rewriter.create<mlir::stablehlo::DivOp>(loc, x, a);
-  s = rewriter.create<mlir::stablehlo::AddOp>(
-      loc, s,
-      rewriter.create<mlir::stablehlo::MulOp>(
-          loc, negPower,
-          rewriter.create<mlir::stablehlo::AddOp>(
-              loc, zeroPointFiveLikeNegPower,
-              rewriter.create<mlir::stablehlo::MulOp>(
-                  loc, xDivA,
-                  rewriter.create<mlir::stablehlo::AddOp>(
-                      loc,
-                      getConstantLike(rewriter, loc, 1. / kZetaCoeffs[11], a),
-                      hornerSum)))));
-
-  // Use the initial zeta sum without the correction term coming
-  // from Euler-Maclaurin if it is accurate enough.
-  Value absNegPower = rewriter.create<mlir::stablehlo::AbsOp>(loc, negPower);
-  Value absInitialSum =
-      rewriter.create<mlir::stablehlo::AbsOp>(loc, initialSum);
-  Value output = rewriter.create<mlir::stablehlo::SelectOp>(
-      loc,
-      rewriter.create<mlir::stablehlo::CompareOp>(
-          loc, absNegPower,
-          rewriter.create<mlir::stablehlo::MulOp>(
-              loc, absInitialSum,
-              getConstantLikeSmallestFiniteValue(rewriter, loc, a)),
-          mlir::stablehlo::ComparisonDirection::LT),
-      initialSum, s);
-
-  // Function is not defined for x < 1.
-  Value nan = getConstantLike(rewriter, loc,
-                              std::numeric_limits<double>::quiet_NaN(), x);
-  output = rewriter.create<mlir::stablehlo::SelectOp>(
-      loc,
-      rewriter.create<mlir::stablehlo::CompareOp>(
-          loc, x, oneLikeX, mlir::stablehlo::ComparisonDirection::LT),
-      nan, output);
-
-  // For q <= 0, x must be an integer.
-  Value qLeZero = rewriter.create<mlir::stablehlo::CompareOp>(
-      loc, q, zero, mlir::stablehlo::ComparisonDirection::LE);
-  Value xNotInt = rewriter.create<mlir::stablehlo::CompareOp>(
-      loc, x, rewriter.create<mlir::stablehlo::FloorOp>(loc, x),
-      mlir::stablehlo::ComparisonDirection::NE);
-  Value xDomainError =
-      rewriter.create<mlir::stablehlo::AndOp>(loc, qLeZero, xNotInt);
-  output = rewriter.create<mlir::stablehlo::SelectOp>(loc, xDomainError, nan,
-                                                      output);
-
-  // For all integer q <= 0, zeta has a pole. The limit is only defined as
-  // +inf if x is and even integer.
-  Value inf = getConstantLike(rewriter, loc,
-                              std::numeric_limits<double>::infinity(), x);
-  Value qIsInt = rewriter.create<mlir::stablehlo::CompareOp>(
-      loc, q, rewriter.create<mlir::stablehlo::FloorOp>(loc, q),
-      mlir::stablehlo::ComparisonDirection::EQ);
-  Value atPole = rewriter.create<mlir::stablehlo::AndOp>(loc, qLeZero, qIsInt);
-  Value two = getConstantLike(rewriter, loc, 2.0, x);
-  Value xIsInt = rewriter.create<mlir::stablehlo::CompareOp>(
-      loc, x, rewriter.create<mlir::stablehlo::FloorOp>(loc, x),
-      mlir::stablehlo::ComparisonDirection::EQ);
-  Value xIsEven = rewriter.create<mlir::stablehlo::CompareOp>(
-      loc, rewriter.create<mlir::stablehlo::RemOp>(loc, x, two), zero,
-      mlir::stablehlo::ComparisonDirection::EQ);
-  Value xIsEvenInt =
-      rewriter.create<mlir::stablehlo::AndOp>(loc, xIsInt, xIsEven);
-  output = rewriter.create<mlir::stablehlo::SelectOp>(
-      loc, atPole,
-      rewriter.create<mlir::stablehlo::SelectOp>(loc, xIsEvenInt, inf, nan),
-      output);
-
-  // For x = 1, this is the harmonic series and diverges.
-  output = rewriter.create<mlir::stablehlo::SelectOp>(
-      loc,
-      rewriter.create<mlir::stablehlo::CompareOp>(
-          loc, x, one, mlir::stablehlo::ComparisonDirection::EQ),
-      inf, output);
-
-  return output;
-}
-
-static Value materializePolygamma(ConversionPatternRewriter &rewriter,
-                                  Location loc, ValueRange args) {
-  mlir::chlo::PolygammaOp::Adaptor transformed(args);
-  Value n = transformed.getN();
-  Value x = transformed.getX();
-
-  // Handle integer n > 0.
-  Value one = getConstantLike(rewriter, loc, 1.0, x);
-  Value two = getConstantLike(rewriter, loc, 2.0, x);
-  Value sign = rewriter.create<mlir::stablehlo::SubtractOp>(
-      loc,
-      rewriter.create<mlir::stablehlo::MulOp>(
-          loc, two, rewriter.create<mlir::stablehlo::RemOp>(loc, n, two)),
-      one);
-  Value nPlusOne = rewriter.create<mlir::stablehlo::AddOp>(loc, n, one);
-  Value expLgammaNp1 = rewriter.create<mlir::stablehlo::ExpOp>(
-      loc, rewriter.create<chlo::LgammaOp>(loc, nPlusOne));
-  Value zeta = rewriter.create<chlo::ZetaOp>(loc, nPlusOne, x);
-  Value result = rewriter.create<mlir::stablehlo::MulOp>(
-      loc, rewriter.create<mlir::stablehlo::MulOp>(loc, sign, expLgammaNp1),
-      zeta);
-
-  // Handle n = 0.
-  Value zero = getConstantLike(rewriter, loc, 0.0, x);
-  Value nEqZero = rewriter.create<mlir::stablehlo::CompareOp>(
-      loc, n, zero, mlir::stablehlo::ComparisonDirection::EQ);
-  result = rewriter.create<mlir::stablehlo::SelectOp>(
-      loc, nEqZero, rewriter.create<chlo::DigammaOp>(loc, x), result);
-
-  // Check that n is a natural number. Return nan, otherwise.
-  Value nonInt = rewriter.create<mlir::stablehlo::CompareOp>(
-      loc, n, rewriter.create<mlir::stablehlo::FloorOp>(loc, n),
-      mlir::stablehlo::ComparisonDirection::NE);
-  Value negative = rewriter.create<mlir::stablehlo::CompareOp>(
-      loc, n, zero, mlir::stablehlo::ComparisonDirection::LT);
-  Value nonNatural =
-      rewriter.create<mlir::stablehlo::OrOp>(loc, nonInt, negative);
-  return rewriter.create<mlir::stablehlo::SelectOp>(
-      loc, nonNatural,
-      getConstantLike(rewriter, loc, std::numeric_limits<double>::quiet_NaN(),
-                      x),
-      result);
-}
-
-struct ConvertLgammaOp final : OpConversionPattern<mlir::chlo::LgammaOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(mlir::chlo::LgammaOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    FloatType minPrecisionTy = rewriter.getF32Type();
-    rewriter.replaceOp(
-        op, materializeWithUpcast(rewriter, op.getLoc(), adaptor.getOperands(),
-                                  minPrecisionTy, &materializeLgamma));
-    return success();
-  }
-};
-
-struct ConvertDigammaOp final : OpConversionPattern<mlir::chlo::DigammaOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(mlir::chlo::DigammaOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    FloatType minPrecisionTy = rewriter.getF32Type();
-    rewriter.replaceOp(
-        op, materializeWithUpcast(rewriter, op.getLoc(), adaptor.getOperands(),
-                                  minPrecisionTy, &materializeDigamma));
-    return success();
-  }
-};
-
-static Value materializeNextAfter(ConversionPatternRewriter &rewriter,
-                                  Location loc, ValueRange operands) {
-  mlir::chlo::NextAfterOp::Adaptor transformed(operands);
-  Value x = transformed.getX();
-  Value y = transformed.getY();
-  auto resultTy = cast<ShapedType>(x.getType());
-  auto bitwidth = resultTy.getElementType().getIntOrFloatBitWidth();
-  mlir::ImplicitLocOpBuilder b(loc, rewriter);
-  Type intTy = resultTy.clone(b.getIntegerType(bitwidth));
-  auto xAsInt = b.create<mlir::stablehlo::BitcastConvertOp>(intTy, x);
-  auto yAsInt = b.create<mlir::stablehlo::BitcastConvertOp>(intTy, y);
-
-  // The result is NaN if either "x" or "y" are NaN.
-  auto xIsNan = b.create<mlir::stablehlo::CompareOp>(
-      x, x, mlir::stablehlo::ComparisonDirection::NE);
-  auto yIsNan = b.create<mlir::stablehlo::CompareOp>(
-      y, y, mlir::stablehlo::ComparisonDirection::NE);
-  auto nanInput = b.create<mlir::stablehlo::OrOp>(xIsNan, yIsNan);
-  auto resultForNan = getConstantLike(
-      rewriter, loc, std::numeric_limits<double>::quiet_NaN(), x);
-  auto resultForNanAsInt =
-      b.create<mlir::stablehlo::BitcastConvertOp>(intTy, resultForNan);
-
-  // The sign bit is the MSB.
-  const int64_t signBit = int64_t{1} << (bitwidth - 1);
-  // Discard the sign bit to make the result non-negative.
-  Value signMask = getConstantLike(rewriter, loc, signBit, xAsInt);
-  Value negatedSignMask = getConstantLike(rewriter, loc, ~signBit, xAsInt);
-  auto xAbs = b.create<mlir::stablehlo::AndOp>(xAsInt, negatedSignMask);
-  auto yAbs = b.create<mlir::stablehlo::AndOp>(yAsInt, negatedSignMask);
-
-  // When both "x" and "y" are equal, the result is "y".
-  auto xAndYAreEqual = b.create<mlir::stablehlo::CompareOp>(
-      x, y, mlir::stablehlo::ComparisonDirection::EQ);
-  auto resultForEqual = yAsInt;
-
-  // When both "x" and "y" are 0, the result is "y". This is a separate case
-  // from above because "x" and "y" might have a different sign.
-  Value zero = getConstantLike(rewriter, loc, 0, xAsInt);
-  auto xIsZero = b.create<mlir::stablehlo::CompareOp>(
-      xAbs, zero, mlir::stablehlo::ComparisonDirection::EQ);
-  auto yIsZero = b.create<mlir::stablehlo::CompareOp>(
-      yAbs, zero, mlir::stablehlo::ComparisonDirection::EQ);
-  auto resultForBothZero = yAsInt;
-
-  auto xSign = b.create<mlir::stablehlo::AndOp>(xAsInt, signMask);
-  auto ySign = b.create<mlir::stablehlo::AndOp>(yAsInt, signMask);
-
-  // If from == 0 && to != 0, we need to return the smallest subnormal number
-  // signed like "to".
-  Value one = getConstantLike(rewriter, loc, 1, xAsInt);
-  auto resultForXZeroYNonZero = b.create<mlir::stablehlo::OrOp>(ySign, one);
-
-  // If the sign of "x" and "y" disagree:
-  // - we need to make the magnitude of "from" smaller so that it is closer to
-  //   zero.
-  //
-  // Otherwise the signs agree:
-  // - "x" with a magnitude larger than "y" means we need to make the magnitude
-  //   smaller.
-  // - "x" with a magnitude smaller than "y" means we need to make the magnitude
-  //   larger.
-  auto signsDisagree = b.create<mlir::stablehlo::CompareOp>(
-      xSign, ySign, mlir::stablehlo::ComparisonDirection::NE);
-  auto xMagnitudeLargerThanY = b.create<mlir::stablehlo::CompareOp>(
-      xAbs, yAbs, mlir::stablehlo::ComparisonDirection::GT);
-  auto resultHasSmallerMagnitude =
-      b.create<mlir::stablehlo::OrOp>(xMagnitudeLargerThanY, signsDisagree);
-  auto minusOne = getConstantLike(rewriter, loc, -1, xAsInt);
-  auto magnitudeAdjustment = b.create<mlir::stablehlo::SelectOp>(
-      resultHasSmallerMagnitude, minusOne, one);
-  Value result = b.create<mlir::stablehlo::AddOp>(xAsInt, magnitudeAdjustment);
-  // Handle from == +-0.
-  result = b.create<mlir::stablehlo::SelectOp>(
-      xIsZero,
-      b.create<mlir::stablehlo::SelectOp>(yIsZero, resultForBothZero,
-                                          resultForXZeroYNonZero),
-      result);
-  // Handle from == to.
-  result = b.create<mlir::stablehlo::SelectOp>(xAndYAreEqual, resultForEqual,
-                                               result);
-  // Handle isnan(x) || isnan(y).
-  result =
-      b.create<mlir::stablehlo::SelectOp>(nanInput, resultForNanAsInt, result);
-
-  // Cast back to the original type.
-  return b.create<mlir::stablehlo::BitcastConvertOp>(resultTy, result);
-}
-
-struct ConvertNextAfterOp final : OpConversionPattern<mlir::chlo::NextAfterOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(mlir::chlo::NextAfterOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOp(
-        op, materializeNextAfter(rewriter, op.getLoc(), adaptor.getOperands()));
-    return success();
-  }
-};
-
-struct ConvertPolygammaOp final : OpConversionPattern<mlir::chlo::PolygammaOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(mlir::chlo::PolygammaOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    FloatType minPrecisionTy = rewriter.getF32Type();
-    rewriter.replaceOp(
-        op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(),
-                                  minPrecisionTy, materializePolygamma));
-    return success();
-  }
-};
-
-// Sinh(x) = (e^x - e^-x) / 2
-//         = e^(x + log(1/2)) - e^(-x + log(1/2)).
-//
-// The second formulation avoids overflowing when e^x = inf but (e^x)/2 is not
-// inf.
-//
-// This incorrectly overflows to +/-inf for two f32 input values, namely
-// +/-89.4159851, due to rounding error when computing x +/- log(1/2).  The
-// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
-// we deem this acceptable.
-static Value
-materializeSinhApproximationForLargeX(ConversionPatternRewriter &rewriter,
-                                      Location loc, ValueRange operands) {
-  mlir::chlo::SinhOp::Adaptor transformed(operands);
-  Value x = transformed.getOperand();
-
-  Value logOneHalf = rewriter.create<mlir::stablehlo::LogOp>(
-      loc, getConstantLike(rewriter, loc, 0.5, x));
-  Value expAdd = rewriter.create<mlir::stablehlo::ExpOp>(
-      loc, rewriter.create<mlir::stablehlo::AddOp>(loc, x, logOneHalf));
-  Value expSub = rewriter.create<mlir::stablehlo::ExpOp>(
-      loc, rewriter.create<mlir::stablehlo::SubtractOp>(loc, logOneHalf, x));
-  return rewriter.create<mlir::stablehlo::SubtractOp>(loc, expAdd, expSub);
-}
-
-// Express `sinh` as
-//   sinh(x) = (e^x - e^-x) / 2                     if |x| < 1
-//           = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise.
-static Value materializeSinhApproximation(ConversionPatternRewriter &rewriter,
-                                          Location loc, ValueRange operands) {
-  Value largeSinhResult =
-      materializeSinhApproximationForLargeX(rewriter, loc, operands);
-
-  mlir::chlo::SinhOp::Adaptor transformed(operands);
-  Value x = transformed.getOperand();
-
-  // For smaller x, we get unwanted cancellations of e^x - e^-x, resulting in
-  // 0.
-  // Rewrite this to avoid that. We use expm1(x) because that preserves the
-  // first order term of the taylor series of e^x.
-  // (e^(x) - e^(-x)) / 2. =
-  // (e^(x) - 1 + 1 - e^(-x)) / 2.
-  // (expm1(x) + (e^(x) - 1) / e^x) / 2.
-  // (expm1(x) + expm1(x) / (expm1(x) + 1)) / 2.
-  Value expm1 = rewriter.create<mlir::stablehlo::Expm1Op>(loc, x);
-  Value one = getConstantLike(rewriter, loc, 1.0, x);
-  Value oneHalf = getConstantLike(rewriter, loc, 0.5, x);
-  Value expm1PlusOne = rewriter.create<mlir::stablehlo::AddOp>(loc, expm1, one);
-  Value ratio =
-      rewriter.create<mlir::stablehlo::DivOp>(loc, expm1, expm1PlusOne);
-  Value sum = rewriter.create<mlir::stablehlo::AddOp>(loc, expm1, ratio);
-  Value smallSinhResult =
-      rewriter.create<mlir::stablehlo::MulOp>(loc, oneHalf, sum);
-
-  Value absX = rewriter.create<mlir::stablehlo::AbsOp>(loc, x);
-  Value absXLtOne = rewriter.create<mlir::stablehlo::CompareOp>(
-      loc, absX, one, mlir::stablehlo::ComparisonDirection::LT);
-  return rewriter.create<mlir::stablehlo::SelectOp>(
-      loc, absXLtOne, smallSinhResult, largeSinhResult);
-}
-
-struct ConvertSinhOp final : OpConversionPattern<mlir::chlo::SinhOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(mlir::chlo::SinhOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Value x = adaptor.getOperand();
-    if (isa<ComplexType>(cast<ShapedType>(x.getType()).getElementType())) {
-      rewriter.replaceOp(op, materializeSinhApproximationForLargeX(
-                                 rewriter, op.getLoc(), adaptor.getOperands()));
-      return success();
-    }
-    rewriter.replaceOp(
-        op, materializeWithUpcast(rewriter, op.getLoc(), adaptor.getOperands(),
-                                  rewriter.getF32Type(),
-                                  &materializeSinhApproximation));
-    return success();
-  }
-};
-
-// Converts chlo.top_k to HLO iota, sort, and slice ops.
-//
-// chlo.top_k sorts along last dimension of the input tensor and then returns
-// the top K components' values and indices. This is translated into a few
-// ops in HLO: first generating an integer sequence for the indices,
-// then sort both the original input tensor and the indices together, and
-// at last slice out the top K components.
-//
-// For example, for the following IR:
-//
-// %0:2 = "chlo.top_k"(%input, k=8): tensor<16x16xf32> ->
-//                                   (tensor<16x8xf32>, tensor<16x8xi32>)
-//
-// We will get:
-//
-// %1 = "hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<16x16xi32>
-// %2 = "hlo.sort"(%input, %1) ({
-// ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>,
-//      %arg3: tensor<i32>, %arg4: tensor<i32>):
-//   %7 = "hlo.compare"(%arg1, %arg2) {comparison_direction = "GT"}: ...
-//   "hlo.return"(%7) : (tensor<i1>) -> ()
-// }) {dimension = 1 : i64, is_stable = true} : ...
-// %3 = "hlo.get_tuple_element"(%2) {index = 0 : i32} : ...
-// %4 = "hlo.get_tuple_element"(%2) {index = 1 : i32} : ...
-// %5 = "hlo.slice"(%3) {limit_indices = dense<[16, 8]> : tensor<2xi64>,
-//                           start_indices dense<0> : tensor<2xi64>,
-//                           strides = dense<1> : tensor<2xi64>} :
-//                              (tensor<16x16xf32>) -> tensor<16x8xf32>
-// %6 = "hlo.slice"(%4) ...
-//
-// TODO(b/284078162): Decide what to do with this pattern given that we now
-// have mlir::stablehlo::TopKOp. No action needed for now given that
-// mlir::stablehlo::TopKOp is currently categorized as
-// `hasPrivateFeaturesNotInStablehlo`.
-struct ConvertTopKOp final : OpConversionPattern<mlir::chlo::TopKOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(mlir::chlo::TopKOp op, OpAdaptor /*adaptor*/,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto operandType = dyn_cast<RankedTensorType>(op.getOperand().getType());
-    if (!operandType)
-      return failure();
-    int64_t operandRank = operandType.getRank();
-    int64_t lastDimIndex = operandRank - 1;
-    int64_t lastDimSize = operandType.getDimSize(lastDimIndex);
-    int64_t lastDimResultSize =
-        mlir::hlo::isDynamicDimSize(lastDimSize)
-            ? static_cast<int64_t>(op.getK())
-            : std::min(static_cast<int64_t>(op.getK()), lastDimSize);
-    int64_t isDynamic = !operandType.hasStaticShape();
-    auto i32Type = rewriter.getIntegerType(32);
-    Value opShapeValue, resultShapeValue;
-    if (isDynamic) {
-      SmallVector<Value> sizesI32x1;
-      for (auto i = 0; i < operandType.getRank(); ++i) {
-        auto sizeI32 = rewriter.create<mlir::stablehlo::GetDimensionSizeOp>(
-            op.getLoc(), op.getOperand(), i);
-        auto sizeI32x1 = rewriter.create<mlir::stablehlo::ReshapeOp>(
-            op.getLoc(), RankedTensorType::get({1}, i32Type), sizeI32);
-        sizesI32x1.push_back(sizeI32x1);
-      }
-      opShapeValue = rewriter.create<mlir::stablehlo::ConcatenateOp>(
-          op.getLoc(), sizesI32x1,
-          /*dimension=*/0);
-      auto lastDimI32 = rewriter.create<mlir::stablehlo::ConstantOp>(
-          op.getLoc(),
-          rewriter.getI32IntegerAttr(static_cast<int32_t>(lastDimResultSize)));
-      auto lastDimI32x1 = rewriter.create<mlir::stablehlo::ReshapeOp>(
-          op.getLoc(), RankedTensorType::get({1}, i32Type), lastDimI32);
-      sizesI32x1.back() = lastDimI32x1;
-      resultShapeValue = rewriter.create<mlir::stablehlo::ConcatenateOp>(
-          op.getLoc(), sizesI32x1,
-          /*dimension=*/0);
-    }
-
-    // Create an Iota op for indices.
-    Type iotaType = RankedTensorType::get(operandType.getShape(), i32Type);
-    Value iotaOp;
-    if (isDynamic) {
-      iotaOp = rewriter.create<mlir::stablehlo::DynamicIotaOp>(
-          op.getLoc(), iotaType, opShapeValue,
-          rewriter.getI64IntegerAttr(lastDimIndex));
-    } else {
-      iotaOp = rewriter.create<mlir::stablehlo::IotaOp>(
-          op.getLoc(), iotaType, rewriter.getI64IntegerAttr(lastDimIndex));
-    }
-
-    // Create the sort op. It takes two inputs, one for the original input, the
-    // other for the indices. Use TOTALORDER comparison type instead of the
-    // default comparison if the element type is of type float.
-    Type elementType = operandType.getElementType();
-    mlir::stablehlo::SortOp sortOp =
-        createSortOp(&rewriter, op.getLoc(), {op.getOperand(), iotaOp},
-                     {elementType, i32Type}, lastDimIndex,
-                     /*isStable=*/true,
-                     /*direction=*/mlir::stablehlo::ComparisonDirection::GT);
-
-    // Get the sorted input and index tuple element.
-    Value tupleFirstElement = sortOp.getResult(0);
-    Value tupleSecondElement = sortOp.getResult(1);
-
-    SmallVector<int64_t> beginIndices(operandRank, 0);
-    auto endIndices = llvm::to_vector(operandType.getShape());
-    endIndices.back() = lastDimResultSize;
-    SmallVector<int64_t> strides(operandRank, 1);
-
-    // Get the slice for the top K elements.
-    auto indicesTy = RankedTensorType::get(operandRank, rewriter.getI64Type());
-    Value values, indices;
-    if (isDynamic) {
-      Value startIndices = rewriter.create<mlir::stablehlo::ConstantOp>(
-          op.getLoc(), DenseIntElementsAttr::get(indicesTy, beginIndices));
-      Value lastIndices = rewriter.create<mlir::stablehlo::ConvertOp>(
-          op.getLoc(), resultShapeValue, rewriter.getI64Type());
-      Value stridesOp = rewriter.create<mlir::stablehlo::ConstantOp>(
-          op.getLoc(), DenseIntElementsAttr::get(indicesTy, strides));
-
-      SmallVector<int64_t> resultShape =
-          llvm::to_vector(operandType.getShape());
-      resultShape.back() = lastDimResultSize;
-      RankedTensorType resultType = RankedTensorType::get(
-          resultShape, elementType, operandType.getEncoding());
-      RankedTensorType indexResultType =
-          RankedTensorType::get(resultShape, i32Type);
-
-      values = rewriter.create<mlir::stablehlo::RealDynamicSliceOp>(
-          op.getLoc(), resultType, tupleFirstElement, startIndices, lastIndices,
-          stridesOp);
-      indices = rewriter.create<mlir::stablehlo::RealDynamicSliceOp>(
-          op.getLoc(), indexResultType, tupleSecondElement, startIndices,
-          lastIndices, stridesOp);
-    } else {
-      values = rewriter.create<mlir::stablehlo::SliceOp>(
-          op.getLoc(), tupleFirstElement,
-          rewriter.getDenseI64ArrayAttr(beginIndices),
-          rewriter.getDenseI64ArrayAttr(endIndices),
-          rewriter.getDenseI64ArrayAttr(strides));
-      indices = rewriter.create<mlir::stablehlo::SliceOp>(
-          op.getLoc(), tupleSecondElement,
-          rewriter.getDenseI64ArrayAttr(beginIndices),
-          rewriter.getDenseI64ArrayAttr(endIndices),
-          rewriter.getDenseI64ArrayAttr(strides));
-    }
-
-    rewriter.replaceOp(op, {values, indices});
-    return success();
-  }
-};
-
-struct ConvertZetaOp final : OpConversionPattern<mlir::chlo::ZetaOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(mlir::chlo::ZetaOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    FloatType minPrecisionTy = rewriter.getF32Type();
-    rewriter.replaceOp(
-        op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(),
-                                  minPrecisionTy, &materializeZeta));
-    return success();
-  }
-};
-
-//===----------------------------------------------------------------------===//
-// Pass Definition.
-//===----------------------------------------------------------------------===//
-
-struct LegalizeChlo final : impl::LegalizeChloBase<LegalizeChlo> {
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<mlir::scf::SCFDialect, mlir::shape::ShapeDialect,
-                    mlir::stablehlo::StablehloDialect,
-                    mlir::tensor::TensorDialect>();
-  }
-
-  void runOnOperation() override {
-    MLIRContext *ctx = &getContext();
-    {
-      ConversionTarget conversionTarget(getContext());
-      RewritePatternSet conversionPatterns(ctx);
-      conversionTarget.addIllegalDialect<chlo::ChloDialect>();
-      conversionTarget.addLegalDialect<
-          mlir::stablehlo::StablehloDialect, mlir::arith::ArithDialect,
-          mlir::shape::ShapeDialect, mlir::scf::SCFDialect,
-          mlir::tensor::TensorDialect>();
-
-      populateLegalizeChloPatterns(ctx, &conversionPatterns);
-      if (failed(applyPartialConversion(getOperation(), conversionTarget,
-                                        std::move(conversionPatterns)))) {
-        return signalPassFailure();
-      }
-    }
-
-    {
-      // Add canonicalization patterns to simplify produced ops from other
-      // dialects.
-      RewritePatternSet patterns(ctx);
-      populateCanonicalizationPatterns(ctx, &patterns);
-      mlir::shape::AssumingOp::getCanonicalizationPatterns(patterns, ctx);
-      mlir::shape::ShapeOfOp::getCanonicalizationPatterns(patterns, ctx);
-      mlir::shape::BroadcastOp::getCanonicalizationPatterns(patterns, ctx);
-      mlir::shape::CstrBroadcastableOp::getCanonicalizationPatterns(patterns,
-                                                                    ctx);
-      mlir::tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
-      if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
-        return signalPassFailure();
-      }
-    }
-  }
-};
-} // namespace
-
-namespace {
-#include "compiler/plugins/input/StableHLO/Conversion/CHLODecompositionPatterns.h.inc"
-} // end anonymous namespace
-
-namespace {
-static void populateBroadcastingPatterns(MLIRContext *context,
-                                         RewritePatternSet *patterns) {
-  // Instantiate conversion templates for conforming binary elementwise ops
-  // that do not have different dtypes between operands and results and do
-  // not have special attributes that need to be preserved.
-  populateForBroadcastingBinaryOp<ConvertTrivialNonBroadcastBinaryOp>(
-      context, patterns, 10);
-  populateForBroadcastingBinaryOp<ConvertRankedDynamicBroadcastBinaryOp>(
-      context, patterns, 5);
-  patterns->add<ConvertConstantLikeOp, ConvertSelectOp>(context);
-}
-
-static void populateDecompositionPatterns(MLIRContext *context,
-                                          RewritePatternSet *patterns) {
-  populateWithGenerated(*patterns);
-  patterns->add<ConvertConstantOp, ConvertBesselI1eOp, ConvertCoshOp,
-                ConvertDigammaOp, ConvertErfOp, ConvertErfcOp, ConvertErfInvOp,
-                ConvertLgammaOp, ConvertNextAfterOp, ConvertPolygammaOp,
-                ConvertSinhOp, ConvertTopKOp, ConvertZetaOp>(context);
-}
-} // namespace
-
-void populateLegalizeChloPatterns(MLIRContext *context,
-                                  RewritePatternSet *patterns) {
-  populateBroadcastingPatterns(context, patterns);
-  populateDecompositionPatterns(context, patterns);
-}
-} // namespace mlir::iree_compiler::stablehlo
diff --git a/compiler/plugins/input/StableHLO/Conversion/Passes.cpp b/compiler/plugins/input/StableHLO/Conversion/Passes.cpp
index d5bad85..7887e2d 100644
--- a/compiler/plugins/input/StableHLO/Conversion/Passes.cpp
+++ b/compiler/plugins/input/StableHLO/Conversion/Passes.cpp
@@ -85,7 +85,8 @@
       stablehlo::createLegalizeShapeComputations());
   passManager.addNestedPass<func::FuncOp>(
       stablehlo::createConvertStableHloToLinalgExt());
-  passManager.addNestedPass<func::FuncOp>(stablehlo::createLegalizeChlo());
+  passManager.addNestedPass<func::FuncOp>(
+      mlir::stablehlo::createChloLegalizeToStablehloPass());
   passManager.addPass(createConvertStableHloToIreeInputDialects());
   passManager.addPass(createReconcileUnrealizedCastsPass());
 
diff --git a/compiler/plugins/input/StableHLO/Conversion/Passes.td b/compiler/plugins/input/StableHLO/Conversion/Passes.td
index 7852112..915a487 100644
--- a/compiler/plugins/input/StableHLO/Conversion/Passes.td
+++ b/compiler/plugins/input/StableHLO/Conversion/Passes.td
@@ -37,11 +37,6 @@
   let summary = "Legalizes from StableHLO control flow to SCF control flow";
 }
 
-def LegalizeChlo :
-    InterfacePass<"iree-stablehlo-legalize-chlo", "mlir::FunctionOpInterface"> {
-  let summary = "Legalizes from CHLO ops flow to StableHLO and Shape ops";
-}
-
 def LegalizeStableHLOCustomCalls :
     InterfacePass<"iree-stablehlo-legalize-custom-calls", "mlir::FunctionOpInterface"> {
   let summary = "Legalizes specialized custom calls to decomposed implementations";
diff --git a/compiler/plugins/input/StableHLO/Conversion/Rewriters.h b/compiler/plugins/input/StableHLO/Conversion/Rewriters.h
index 1a328ff..956a3a4 100644
--- a/compiler/plugins/input/StableHLO/Conversion/Rewriters.h
+++ b/compiler/plugins/input/StableHLO/Conversion/Rewriters.h
@@ -15,11 +15,6 @@
 // General StableHLO/CHLO lowering patterns.
 //===----------------------------------------------------------------------===//
 
-/// Collection of rewrite patterns for lowering of CHLO ops to StableHLO and
-/// Shape ops.
-void populateLegalizeChloPatterns(MLIRContext *context,
-                                  RewritePatternSet *patterns);
-
 /// Collection of rewrite patterns for lowering of StableHLO ops to SCF control
 /// flow ops.
 void populateLegalizeControlFlowPatterns(MLIRContext *context,
diff --git a/compiler/plugins/input/StableHLO/Conversion/test/legalize_chlo_decomposition.mlir b/compiler/plugins/input/StableHLO/Conversion/test/legalize_chlo_decomposition.mlir
index d2f9353..851d002 100644
--- a/compiler/plugins/input/StableHLO/Conversion/test/legalize_chlo_decomposition.mlir
+++ b/compiler/plugins/input/StableHLO/Conversion/test/legalize_chlo_decomposition.mlir
@@ -1,5 +1,5 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-stablehlo-legalize-chlo))" \
-// RUN:   --split-input-file --verify-diagnostics %s | FileCheck %s
+// RUN: iree-opt --chlo-legalize-to-stablehlo \
+// RUN:     --split-input-file --verify-diagnostics %s | FileCheck %s
 
 // CHECK-LABEL: func.func @asin_bf16(
 func.func @asin_bf16(%arg : tensor<bf16>) -> tensor<bf16> {
@@ -508,6 +508,8 @@
 
 // -----
 
+// CHECK-LABEL: @complex_tan
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1xf32>, %[[ARG1:.+]]: tensor<1xf32>
 func.func @complex_tan(%arg0 : tensor<1xf32>, %arg1 : tensor<1xf32>) -> (tensor<1xf32>, tensor<1xf32>) {
   %0 = stablehlo.complex %arg0, %arg1 : tensor<1xcomplex<f32>>
   %1 = chlo.tan %0 : tensor<1xcomplex<f32>> -> tensor<1xcomplex<f32>>
@@ -516,18 +518,21 @@
   func.return %2, %3 : tensor<1xf32>, tensor<1xf32>
 }
 
-// CHECK-LABEL: @complex_tan
-// CHECK-SAME: %[[ARG0:.+]]: tensor<1xf32>, %[[ARG1:.+]]: tensor<1xf32>
-// CHECK: %[[ONE:.+]] = stablehlo.constant dense<1.000000e+00> : tensor<1xf32>
-// CHECK: %[[SINE:.+]] = stablehlo.sine %[[ARG0]]
-// CHECK: %[[COS:.+]] = stablehlo.cosine %[[ARG0]]
-// CHECK: %[[TAN:.+]] = stablehlo.divide %[[SINE]], %[[COS]]
-// CHECK: %[[TANH:.+]] = stablehlo.tanh %[[ARG1]]
-// CHECK: %[[NUM:.+]] = stablehlo.complex %[[TAN]], %[[TANH]]
-// CHECK: %[[MUL:.+]] = stablehlo.multiply %[[TAN]], %[[TANH]]
-// CHECK: %[[NEG:.+]] = stablehlo.negate %[[MUL]]
-// CHECK: %[[DEN:.+]] = stablehlo.complex %[[ONE]], %[[NEG]]
-// CHECK: %[[RES:.+]] = stablehlo.divide %[[NUM]], %[[DEN]]
-// CHECK: %[[REAL:.+]] = stablehlo.real %[[RES]]
-// CHECK: %[[IMAG:.+]] = stablehlo.imag %[[RES]]
-// CHECK: return %[[REAL]], %[[IMAG]]
+
+// -----
+
+// CHECK-LABEL:  func.func @acos_complex_f32
+// CHECK-SAME:    (%[[ARG0:.+]]: tensor<complex<f32>>) -> tensor<complex<f32>>
+func.func @acos_complex_f32(%arg : tensor<complex<f32>>) -> tensor<complex<f32>> {
+  %result = "chlo.acos"(%arg) : (tensor<complex<f32>>) -> tensor<complex<f32>>
+  func.return %result : tensor<complex<f32>>
+}
+
+// -----
+
+// CHECK-LABEL:  func.func @acos_complex_f64_dynamic
+// CHECK-SAME:    (%[[ARG0:.+]]: tensor<?xcomplex<f64>>) -> tensor<?xcomplex<f64>>
+func.func @acos_complex_f64_dynamic(%arg : tensor<?xcomplex<f64>>) -> tensor<?xcomplex<f64>> {
+  %result = "chlo.acos"(%arg) : (tensor<?xcomplex<f64>>) -> tensor<?xcomplex<f64>>
+  func.return %result : tensor<?xcomplex<f64>>
+}
diff --git a/compiler/plugins/input/StableHLO/Conversion/test/legalize_chlo_no_broadcast.mlir b/compiler/plugins/input/StableHLO/Conversion/test/legalize_chlo_no_broadcast.mlir
index 3a6fbd1..2cc1506 100644
--- a/compiler/plugins/input/StableHLO/Conversion/test/legalize_chlo_no_broadcast.mlir
+++ b/compiler/plugins/input/StableHLO/Conversion/test/legalize_chlo_no_broadcast.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-stablehlo-legalize-chlo),cse)" \
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(chlo-legalize-to-stablehlo), cse)" \
 // RUN:   --split-input-file --verify-diagnostics %s | FileCheck %s
 
 // Check the non-broadcast case for each registered op, then just check a
@@ -17,17 +17,10 @@
 // CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>
 // CHECK-SAME: %[[ARG1:.+]]: tensor<?x?xf32>
 func.func @dynamicBroadcast(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
-  // CHECK-DAG:  %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]]
-  // CHECK-DAG:  %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
-  // CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
-  // CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
-  // CHECK-DAG:    %[[RESULT_EXTENTS:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
-  // CHECK-DAG:    %[[ARG0_B:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG0]], %[[RESULT_EXTENTS]], dims = [1]
-  // CHECK-DAG:    %[[ARG1_B:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG1]], %[[RESULT_EXTENTS]], dims = [0, 1]
-  // CHECK-NEXT:   %[[RESULT:.+]] = stablehlo.add %[[ARG0_B]], %[[ARG1_B]]
-  // CHECK-NEXT:   shape.assuming_yield %[[RESULT]]
-  // CHECK-NEXT: }
-  // CHECK-NEXT: return %[[FINAL_RESULT]] : tensor<?x?xf32>
+
+  // CHECK:stablehlo.dynamic_broadcast_in_dim
+  // CHECK:stablehlo.add
+
   %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
   func.return %0 : tensor<?x?xf32>
 }
@@ -38,17 +31,9 @@
 // CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>
 // CHECK-SAME: %[[ARG1:.+]]: tensor<?x?xf32>
 func.func @dynamicBroadcastComplex(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>> {
-  // CHECK-DAG:  %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]]
-  // CHECK-DAG:  %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
-  // CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
-  // CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
-  // CHECK-NEXT:   %[[RESULT_EXTENTS:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
-  // CHECK-DAG:    %[[ARG0_B:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG0]], %[[RESULT_EXTENTS]], dims = [1] : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
-  // CHECK-DAG:    %[[ARG1_B:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG1]], %[[RESULT_EXTENTS]], dims = [0, 1] : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
-  // CHECK-NEXT:   %[[RESULT:.+]] = stablehlo.complex %[[ARG0_B]], %[[ARG1_B]] : tensor<?x?xcomplex<f32>>
-  // CHECK-NEXT:   shape.assuming_yield %[[RESULT]]
-  // CHECK-NEXT: }
-  // CHECK-NEXT: return %[[FINAL_RESULT]] : tensor<?x?xcomplex<f32>>
+
+  // CHECK:stablehlo.dynamic_broadcast_in_dim
+  // CHECK:stablehlo.complex
   %0 = chlo.broadcast_complex %arg0, %arg1 : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
   func.return %0 : tensor<?x?xcomplex<f32>>
 }
@@ -59,17 +44,8 @@
 // CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>
 // CHECK-SAME: %[[ARG1:.+]]: tensor<?x?xf32>
 func.func @dynamicBroadcastCompare(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xi1> {
-  // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]]
-  // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
-  // CHECK: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
-  // CHECK: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
-  // CHECK: %[[RESULT_EXTENTS:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
-  // CHECK-DAG: %[[ARG0_B:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG0]], %[[RESULT_EXTENTS]], dims = [1] : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
-  // CHECK-DAG: %[[ARG1_B:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG1]], %[[RESULT_EXTENTS]], dims = [0, 1] : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
-  // CHECK: %[[RESULT:.+]] = stablehlo.compare EQ, %[[ARG0_B]], %[[ARG1_B]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
-  // CHECK: shape.assuming_yield %[[RESULT]]
-  // CHECK-NEXT: }
-  // CHECK: return %[[FINAL_RESULT]] : tensor<?x?xi1>
+  // CHECK:stablehlo.dynamic_broadcast_in_dim
+  // CHECK: stablehlo.compare
   %0 = chlo.broadcast_compare %arg0, %arg1 {comparison_direction = #chlo<comparison_direction EQ>} : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
   func.return %0 : tensor<?x?xi1>
 }
@@ -78,77 +54,76 @@
 
 // CHECK-LABEL: func @selectv2
 func.func @selectv2(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> {
-  // CHECK-NEXT: stablehlo.select %arg0, %arg1, %arg2
+  // CHECK: tensor.cast
+  // CHECK: stablehlo.dynamic_broadcast_in_dim
+  // CHECK: stablehlo.select
   %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
   func.return %0: tensor<2xi32>
 }
 
 // CHECK-LABEL: func @selectv2_pred_scalar
 func.func @selectv2_pred_scalar(%arg0: tensor<i1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> {
-  // CHECK-NEXT: stablehlo.select %arg0, %arg1, %arg2
+  // CHECK: stablehlo.select
   %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
   func.return %0: tensor<2xi32>
 }
 
 // CHECK-LABEL: func @selectv2_broadcast_then
 func.func @selectv2_broadcast_then(%arg0: tensor<i1>, %arg1: tensor<8x1xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> {
-  // CHECK-NEXT: %[[BROADCAST:.*]] = stablehlo.broadcast_in_dim %arg1, dims = [1, 2] : (tensor<8x1xi32>) -> tensor<2x8x8xi32>
-  // CHECK-NEXT: stablehlo.select %arg0, %[[BROADCAST]], %arg2
+  // CHECK:  tensor.cast
+  // CHECK:  stablehlo.dynamic_broadcast_in_dim
+  // CHECK:  stablehlo.select
   %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<8x1xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32>
   func.return %0: tensor<2x8x8xi32>
 }
 
 // CHECK-LABEL: func @selectv2_broadcast_else
 func.func @selectv2_broadcast_else(%arg0: tensor<i1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<8x1xi32>) -> tensor<2x8x8xi32> {
-  // CHECK-NEXT: %[[BROADCAST:.*]] = stablehlo.broadcast_in_dim %arg2, dims = [1, 2] : (tensor<8x1xi32>) -> tensor<2x8x8xi32>
-  // CHECK-NEXT: stablehlo.select %arg0, %arg1, %[[BROADCAST]]
+ // CHECK : stablehlo.select
   %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2x8x8xi32>, tensor<8x1xi32>) -> tensor<2x8x8xi32>
   func.return %0: tensor<2x8x8xi32>
 }
 
 // CHECK-LABEL: func @selectv2_broadcast_pred
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<1xi1>, %[[ARG1:.*]]: tensor<2x8x8xi32>, %[[ARG2:.*]]: tensor<2x8x8xi32>) -> tensor<2x8x8xi32>
 func.func @selectv2_broadcast_pred(%arg0: tensor<1xi1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> {
-  // CHECK-NEXT: %[[BROADCAST:.*]] = stablehlo.broadcast_in_dim %arg0, dims = [2] : (tensor<1xi1>) -> tensor<2x8x8xi1>
-  // CHECK-NEXT: stablehlo.select %[[BROADCAST]], %arg1, %arg2
   %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x8x8xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32>
   func.return %0: tensor<2x8x8xi32>
+
+  // CHECK: stablehlo.dynamic_broadcast_in_dim
+  // CHECK: stablehlo.select
+
 }
 
 // CHECK-LABEL: func @selectv2_broadcast_tensor_pred
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<3xi1>, %[[ARG1:.*]]: tensor<2x3xf16>, %[[ARG2:.*]]: tensor<2x3xf16>)
 func.func @selectv2_broadcast_tensor_pred(%arg0: tensor<3xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> {
-  // CHECK-NEXT: %[[BROADCAST:.*]] = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<3xi1>) -> tensor<2x3xi1>
-  // CHECK-NEXT: stablehlo.select %[[BROADCAST]], %arg1, %arg2
   %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16>
   func.return %0: tensor<2x3xf16>
+
+  // CHECK: stablehlo.dynamic_broadcast_in_dim
+  // CHECK: stablehlo.select
+
 }
 
 // CHECK-LABEL: func @selectv2_broadcast_all
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<8x1x1xi1>, %[[ARG1:.*]]: tensor<1x8x1xi32>, %[[ARG2:.*]]: tensor<1x1x8xi32>)
 func.func @selectv2_broadcast_all(%arg0: tensor<8x1x1xi1>, %arg1: tensor<1x8x1xi32>, %arg2: tensor<1x1x8xi32>) -> tensor<8x8x8xi32> {
-  // CHECK-DAG: %[[BROADCAST_0:.*]] = stablehlo.broadcast_in_dim %arg0, dims = [0, 1, 2] : (tensor<8x1x1xi1>) -> tensor<8x8x8xi1>
-  // CHECK-DAG: %[[BROADCAST_1:.*]] = stablehlo.broadcast_in_dim %arg1, dims = [0, 1, 2] : (tensor<1x8x1xi32>) -> tensor<8x8x8xi32>
-  // CHECK-DAG: %[[BROADCAST_2:.*]] = stablehlo.broadcast_in_dim %arg2, dims = [0, 1, 2] : (tensor<1x1x8xi32>) -> tensor<8x8x8xi32>
-  // CHECK: stablehlo.select %[[BROADCAST_0]], %[[BROADCAST_1]], %[[BROADCAST_2]]
   %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<8x1x1xi1>, tensor<1x8x1xi32>, tensor<1x1x8xi32>) -> tensor<8x8x8xi32>
   func.return %0: tensor<8x8x8xi32>
+
+  // CHECK: stablehlo.dynamic_broadcast_in_dim
+  // CHECK: stablehlo.select
 }
 
 // CHECK-LABEL: func @selectv2_dynamic_ranked
 func.func @selectv2_dynamic_ranked(%arg0: tensor<1xi1>, %arg1: tensor<2x?x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x?x8xi32> {
-  // CHECK-DAG: %[[SHAPE0:.*]] = shape.const_shape [1] : tensor<1xindex>
-  // CHECK-DAG: %[[SHAPE2:.*]] = shape.const_shape [2, 8, 8] : tensor<3xindex>
-  // CHECK-NEXT: %[[SHAPE1:.*]] = shape.shape_of %arg1 : tensor<2x?x8xi32> -> tensor<3xindex>
-  // CHECK-NEXT: %[[CSTR:.*]] = shape.cstr_broadcastable %[[SHAPE1]], %[[SHAPE0]], %[[SHAPE2]] : tensor<3xindex>, tensor<1xindex>, tensor<3xindex>
-  // CHECK-NEXT: %[[ASSUME:.*]] = shape.assuming %[[CSTR]] -> (tensor<2x?x8xi32>) {
-  // CHECK-NEXT:   %[[BCST:.*]] = shape.broadcast %[[SHAPE1]], %[[SHAPE2]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex>
-  // CHECK-NEXT:   %[[BCST0:.*]] = stablehlo.dynamic_broadcast_in_dim %arg0, %[[BCST]], dims = [2] : (tensor<1xi1>, tensor<3xindex>) -> tensor<2x?x8xi1>
-  // CHECK-NEXT:   %[[BCST1:.*]] = stablehlo.dynamic_broadcast_in_dim %arg1, %[[BCST]], dims = [0, 1, 2] : (tensor<2x?x8xi32>, tensor<3xindex>) -> tensor<2x?x8xi32>
-  // CHECK-NEXT:   %[[BCST2:.*]] = stablehlo.dynamic_broadcast_in_dim %arg2, %[[BCST]], dims = [0, 1, 2] : (tensor<2x8x8xi32>, tensor<3xindex>) -> tensor<2x?x8xi32>
-  // CHECK-NEXT:   %[[SELECT:.*]] = stablehlo.select %[[BCST0]], %[[BCST1]], %[[BCST2]] : tensor<2x?x8xi1>, tensor<2x?x8xi32>
-  // CHECK-NEXT:   shape.assuming_yield %[[SELECT]] : tensor<2x?x8xi32>
-  // CHECK-NEXT: }
-  // CHECK-NEXT: return %[[ASSUME]] : tensor<2x?x8xi32>
   %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x?x8xi32>, tensor<2x8x8xi32>) -> tensor<2x?x8xi32>
   func.return %0: tensor<2x?x8xi32>
+
+  // CHECK: stablehlo.dynamic_broadcast_in_dim
+  // CHECK: stablehlo.select
+
 }
 
 // -----
diff --git a/compiler/plugins/input/StableHLO/Conversion/test/legalize_chlo_with_broadcast.mlir b/compiler/plugins/input/StableHLO/Conversion/test/legalize_chlo_with_broadcast.mlir
index 6ee69ca..df5970c 100644
--- a/compiler/plugins/input/StableHLO/Conversion/test/legalize_chlo_with_broadcast.mlir
+++ b/compiler/plugins/input/StableHLO/Conversion/test/legalize_chlo_with_broadcast.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-stablehlo-legalize-chlo),cse)" \
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(chlo-legalize-to-stablehlo), cse)" \
 // RUN:   --split-input-file --verify-diagnostics %s | FileCheck %s
 
 // Check the non-broadcast case for each registered op, then just check a
@@ -16,17 +16,8 @@
 // CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>
 // CHECK-SAME: %[[ARG1:.+]]: tensor<?x?xf32>
 func.func @dynamicBroadcast(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
-  // CHECK-DAG:  %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]]
-  // CHECK-DAG:  %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
-  // CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
-  // CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
-  // CHECK-DAG:    %[[RESULT_EXTENTS:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
-  // CHECK-DAG:    %[[ARG0_B:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG0]], %[[RESULT_EXTENTS]], dims = [1]
-  // CHECK-DAG:    %[[ARG1_B:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG1]], %[[RESULT_EXTENTS]], dims = [0, 1]
-  // CHECK-NEXT:   %[[RESULT:.+]] = stablehlo.add %[[ARG0_B]], %[[ARG1_B]]
-  // CHECK-NEXT:   shape.assuming_yield %[[RESULT]]
-  // CHECK-NEXT: }
-  // CHECK-NEXT:      return %[[FINAL_RESULT]] : tensor<?x?xf32>
+  // CHECK: stablehlo.dynamic_broadcast_in_dim
+  // CHECK: stablehlo.add
   %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
   func.return %0 : tensor<?x?xf32>
 }
@@ -36,18 +27,9 @@
 // CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>
 // CHECK-SAME: %[[ARG1:.+]]: tensor<?x?xf32>
 func.func @dynamicBroadcastComplex(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>> {
-  // CHECK-DAG:  %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]]
-  // CHECK-DAG:  %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
-  // CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
-  // CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
-  // CHECK-NEXT:   %[[RESULT_EXTENTS:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
-  // CHECK-DAG:    %[[ARG0_B:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG0]], %[[RESULT_EXTENTS]], dims = [1] : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
-  // CHECK-DAG:    %[[ARG1_B:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG1]], %[[RESULT_EXTENTS]], dims = [0, 1] : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
-  // CHECK-NEXT:   %[[RESULT:.+]] = stablehlo.complex %[[ARG0_B]], %[[ARG1_B]] : tensor<?x?xcomplex<f32>>
-  // CHECK-NEXT:   shape.assuming_yield %[[RESULT]]
-  // CHECK-NEXT: }
-  // CHECK-NEXT: return %[[FINAL_RESULT]] : tensor<?x?xcomplex<f32>>
-  %0 = chlo.broadcast_complex %arg0, %arg1 : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
+  // CHECK: stablehlo.dynamic_broadcast_in_dim
+  // CHECK: stablehlo.complex
+   %0 = chlo.broadcast_complex %arg0, %arg1 : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
   func.return %0 : tensor<?x?xcomplex<f32>>
 }
 
@@ -56,17 +38,8 @@
 // CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>
 // CHECK-SAME: %[[ARG1:.+]]: tensor<?x?xf32>
 func.func @dynamicBroadcastCompare(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xi1> {
-  // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]]
-  // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
-  // CHECK: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
-  // CHECK: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
-  // CHECK: %[[RESULT_EXTENTS:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
-  // CHECK-DAG: %[[ARG0_B:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG0]], %[[RESULT_EXTENTS]], dims = [1] : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
-  // CHECK-DAG: %[[ARG1_B:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG1]], %[[RESULT_EXTENTS]], dims = [0, 1] : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
-  // CHECK: %[[RESULT:.+]] = stablehlo.compare EQ, %[[ARG0_B]], %[[ARG1_B]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
-  // CHECK: shape.assuming_yield %[[RESULT]]
-  // CHECK-NEXT: }
-  // CHECK: return %[[FINAL_RESULT]] : tensor<?x?xi1>
+  // CHECK:stablehlo.dynamic_broadcast_in_dim
+  // CHECK: stablehlo.compare
   %0 = chlo.broadcast_compare %arg0, %arg1 {comparison_direction = #chlo<comparison_direction EQ>} : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
   func.return %0 : tensor<?x?xi1>
 }
@@ -75,75 +48,74 @@
 
 // CHECK-LABEL: func @selectv2
 func.func @selectv2(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> {
-  // CHECK-NEXT: stablehlo.select %arg0, %arg1, %arg2
+  // CHECK: tensor.cast
+  // CHECK: stablehlo.dynamic_broadcast_in_dim
+  // CHECK: stablehlo.select
+
   %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
   func.return %0: tensor<2xi32>
 }
 
 // CHECK-LABEL: func @selectv2_pred_scalar
 func.func @selectv2_pred_scalar(%arg0: tensor<i1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> {
-  // CHECK-NEXT: stablehlo.select %arg0, %arg1, %arg2
+  // CHECK: stablehlo.select
   %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
   func.return %0: tensor<2xi32>
 }
 
 // CHECK-LABEL: func @selectv2_broadcast_then
 func.func @selectv2_broadcast_then(%arg0: tensor<i1>, %arg1: tensor<8x1xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> {
-  // CHECK-NEXT: %[[BROADCAST:.*]] = stablehlo.broadcast_in_dim %arg1, dims = [1, 2] : (tensor<8x1xi32>) -> tensor<2x8x8xi32>
-  // CHECK-NEXT: stablehlo.select %arg0, %[[BROADCAST]], %arg2
+
+  // CHECK:  tensor.cast
+  // CHECK:  stablehlo.dynamic_broadcast_in_dim
+  // CHECK:  stablehlo.select
+
   %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<8x1xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32>
   func.return %0: tensor<2x8x8xi32>
 }
 
 // CHECK-LABEL: func @selectv2_broadcast_else
 func.func @selectv2_broadcast_else(%arg0: tensor<i1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<8x1xi32>) -> tensor<2x8x8xi32> {
-  // CHECK-NEXT: %[[BROADCAST:.*]] = stablehlo.broadcast_in_dim %arg2, dims = [1, 2] : (tensor<8x1xi32>) -> tensor<2x8x8xi32>
-  // CHECK-NEXT: stablehlo.select %arg0, %arg1, %[[BROADCAST]]
+  // CHECK: stablehlo.select
   %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2x8x8xi32>, tensor<8x1xi32>) -> tensor<2x8x8xi32>
   func.return %0: tensor<2x8x8xi32>
 }
 
 // CHECK-LABEL: func @selectv2_broadcast_pred
 func.func @selectv2_broadcast_pred(%arg0: tensor<1xi1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> {
-  // CHECK-NEXT: %[[BROADCAST:.*]] = stablehlo.broadcast_in_dim %arg0, dims = [2] : (tensor<1xi1>) -> tensor<2x8x8xi1>
-  // CHECK-NEXT: stablehlo.select %[[BROADCAST]], %arg1, %arg2
   %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x8x8xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32>
   func.return %0: tensor<2x8x8xi32>
+  // CHECK: stablehlo.dynamic_broadcast_in_dim
+  // CHECK: stablehlo.select
 }
 
 // CHECK-LABEL: func @selectv2_broadcast_tensor_pred
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<3xi1>, %[[ARG1:.*]]: tensor<2x3xf16>, %[[ARG2:.*]]: tensor<2x3xf16>)
 func.func @selectv2_broadcast_tensor_pred(%arg0: tensor<3xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> {
-  // CHECK-NEXT: %[[BROADCAST:.*]] = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<3xi1>) -> tensor<2x3xi1>
-  // CHECK-NEXT: stablehlo.select %[[BROADCAST]], %arg1, %arg2
   %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16>
   func.return %0: tensor<2x3xf16>
+
+  // CHECK: stablehlo.dynamic_broadcast_in_dim
+  // CHECK: stablehlo.select
+
 }
 
 // CHECK-LABEL: func @selectv2_broadcast_all
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<8x1x1xi1>, %[[ARG1:.*]]: tensor<1x8x1xi32>, %[[ARG2:.*]]: tensor<1x1x8xi32>)
 func.func @selectv2_broadcast_all(%arg0: tensor<8x1x1xi1>, %arg1: tensor<1x8x1xi32>, %arg2: tensor<1x1x8xi32>) -> tensor<8x8x8xi32> {
-  // CHECK-DAG: %[[BROADCAST_0:.*]] = stablehlo.broadcast_in_dim %arg0, dims = [0, 1, 2] : (tensor<8x1x1xi1>) -> tensor<8x8x8xi1>
-  // CHECK-DAG: %[[BROADCAST_1:.*]] = stablehlo.broadcast_in_dim %arg1, dims = [0, 1, 2] : (tensor<1x8x1xi32>) -> tensor<8x8x8xi32>
-  // CHECK-DAG: %[[BROADCAST_2:.*]] = stablehlo.broadcast_in_dim %arg2, dims = [0, 1, 2] : (tensor<1x1x8xi32>) -> tensor<8x8x8xi32>
-  // CHECK: stablehlo.select %[[BROADCAST_0]], %[[BROADCAST_1]], %[[BROADCAST_2]]
   %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<8x1x1xi1>, tensor<1x8x1xi32>, tensor<1x1x8xi32>) -> tensor<8x8x8xi32>
   func.return %0: tensor<8x8x8xi32>
+
+  // CHECK: stablehlo.dynamic_broadcast_in_dim
+  // CHECK: stablehlo.select
+
 }
 
 // CHECK-LABEL: func @selectv2_dynamic_ranked
 func.func @selectv2_dynamic_ranked(%arg0: tensor<1xi1>, %arg1: tensor<2x?x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x?x8xi32> {
-  // CHECK-DAG: %[[SHAPE0:.*]] = shape.const_shape [1] : tensor<1xindex>
-  // CHECK-DAG: %[[SHAPE2:.*]] = shape.const_shape [2, 8, 8] : tensor<3xindex>
-  // CHECK-NEXT: %[[SHAPE1:.*]] = shape.shape_of %arg1 : tensor<2x?x8xi32> -> tensor<3xindex>
-  // CHECK-NEXT: %[[CSTR:.*]] = shape.cstr_broadcastable %[[SHAPE1]], %[[SHAPE0]], %[[SHAPE2]] : tensor<3xindex>, tensor<1xindex>, tensor<3xindex>
-  // CHECK-NEXT: %[[ASSUME:.*]] = shape.assuming %[[CSTR]] -> (tensor<2x?x8xi32>) {
-  // CHECK-NEXT:   %[[BCST:.*]] = shape.broadcast %[[SHAPE1]], %[[SHAPE2]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex>
-  // CHECK-NEXT:   %[[BCST0:.*]] = stablehlo.dynamic_broadcast_in_dim %arg0, %[[BCST]], dims = [2] : (tensor<1xi1>, tensor<3xindex>) -> tensor<2x?x8xi1>
-  // CHECK-NEXT:   %[[BCST1:.*]] = stablehlo.dynamic_broadcast_in_dim %arg1, %[[BCST]], dims = [0, 1, 2] : (tensor<2x?x8xi32>, tensor<3xindex>) -> tensor<2x?x8xi32>
-  // CHECK-NEXT:   %[[BCST2:.*]] = stablehlo.dynamic_broadcast_in_dim %arg2, %[[BCST]], dims = [0, 1, 2] : (tensor<2x8x8xi32>, tensor<3xindex>) -> tensor<2x?x8xi32>
-  // CHECK-NEXT:   %[[SELECT:.*]] = stablehlo.select %[[BCST0]], %[[BCST1]], %[[BCST2]] : tensor<2x?x8xi1>, tensor<2x?x8xi32>
-  // CHECK-NEXT:   shape.assuming_yield %[[SELECT]] : tensor<2x?x8xi32>
-  // CHECK-NEXT: }
-  // CHECK-NEXT: return %[[ASSUME]] : tensor<2x?x8xi32>
+  // CHECK: stablehlo.dynamic_broadcast_in_dim
+  // CHECK: stablehlo.select
+
   %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x?x8xi32>, tensor<2x8x8xi32>) -> tensor<2x?x8xi32>
   func.return %0: tensor<2x?x8xi32>
 }
diff --git a/compiler/plugins/input/StableHLO/PluginRegistration.cpp b/compiler/plugins/input/StableHLO/PluginRegistration.cpp
index df9f1a4..db4dead 100644
--- a/compiler/plugins/input/StableHLO/PluginRegistration.cpp
+++ b/compiler/plugins/input/StableHLO/PluginRegistration.cpp
@@ -13,6 +13,7 @@
 #include "stablehlo/dialect/ChloOps.h"
 #include "stablehlo/dialect/StablehloOps.h"
 #include "stablehlo/dialect/VhloOps.h"
+#include "stablehlo/transforms/Passes.h"
 
 namespace mlir::iree_compiler::stablehlo {
 
@@ -61,6 +62,7 @@
   static void registerPasses() {
     // TODO(scotttodd): register other StableHLO passes?
     registerStableHLOConversionPasses();
+    mlir::stablehlo::registerChloLegalizeToStablehloPass();
   }
 
   void onRegisterDialects(DialectRegistry &registry) override {