[StableHLO][CHLO]Refactor CHLO decompositions to follow upstream StableHLO (#21682)
Refactored the CHLO decompositions to align with the existing upstream
StableHLO patterns:
https://github.com/openxla/stablehlo/blob/main/stablehlo/transforms/ChloDecompositionPatternsMath.td
Updated the test files accordingly by registering the upstream StableHLO
pass.
---------
Signed-off-by: LekkalaSravya3 <lekkala.sravya@multicorewareinc.com>
diff --git a/compiler/plugins/input/StableHLO/BUILD.bazel b/compiler/plugins/input/StableHLO/BUILD.bazel
index da1abae..3309373 100644
--- a/compiler/plugins/input/StableHLO/BUILD.bazel
+++ b/compiler/plugins/input/StableHLO/BUILD.bazel
@@ -32,6 +32,7 @@
"@llvm-project//mlir:Transforms",
"@stablehlo//:chlo_ops",
"@stablehlo//:stablehlo_ops",
+ "@stablehlo//:stablehlo_passes",
"@stablehlo//:vhlo_ops",
],
)
diff --git a/compiler/plugins/input/StableHLO/CMakeLists.txt b/compiler/plugins/input/StableHLO/CMakeLists.txt
index 1875e08..c27c580 100644
--- a/compiler/plugins/input/StableHLO/CMakeLists.txt
+++ b/compiler/plugins/input/StableHLO/CMakeLists.txt
@@ -33,6 +33,7 @@
MLIRShapeDialect
MLIRTransforms
StablehloOps
+ StablehloPasses
VhloOps
iree::compiler::PluginAPI
iree::compiler::plugins::input::StableHLO::Conversion
diff --git a/compiler/plugins/input/StableHLO/Conversion/BUILD.bazel b/compiler/plugins/input/StableHLO/Conversion/BUILD.bazel
index 391f5ce..7231138 100644
--- a/compiler/plugins/input/StableHLO/Conversion/BUILD.bazel
+++ b/compiler/plugins/input/StableHLO/Conversion/BUILD.bazel
@@ -42,28 +42,11 @@
],
)
-iree_gentbl_cc_library(
- name = "CHLODecompositionPatterns",
- tbl_outs = [
- (
- ["--gen-rewriters"],
- "CHLODecompositionPatterns.h.inc",
- ),
- ],
- tblgen = "@llvm-project//mlir:mlir-tblgen",
- td_file = "CHLODecompositionPatterns.td",
- deps = [
- "@stablehlo//:chlo_ops_td_files",
- "@stablehlo//:stablehlo_ops_td_files",
- ],
-)
-
iree_compiler_cc_library(
name = "StableHLOLegalization",
srcs = [
"CheckVHLOStableHloMixUsage.cpp",
"ConvertCollectives.cpp",
- "LegalizeCHLO.cpp",
"LegalizeControlFlow.cpp",
"LegalizeShapeComputations.cpp",
"LegalizeToLinalgUtils.cpp",
@@ -75,7 +58,6 @@
"VerifyCompilerInputLegality.cpp",
],
deps = [
- ":CHLODecompositionPatterns",
":PassHeaders",
"//compiler/plugins/input/StableHLO/Conversion/Preprocessing",
"//compiler/src/iree/compiler/Dialect/Flow/IR",
diff --git a/compiler/plugins/input/StableHLO/Conversion/CHLODecompositionPatterns.td b/compiler/plugins/input/StableHLO/Conversion/CHLODecompositionPatterns.td
deleted file mode 100644
index 20afb72..0000000
--- a/compiler/plugins/input/StableHLO/Conversion/CHLODecompositionPatterns.td
+++ /dev/null
@@ -1,405 +0,0 @@
-// Copyright 2020 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-// This is the legalization pattern definition file for CHLO to StableHLO.
-// These are included in the populateDecompositionPatterns factory
-// and should only include canonical expansions which are not actually
-// ambiguous/different for various backends. Avoid patterns that are actually
-// lowering to non-canonical forms.
-
-include "mlir/IR/OpBase.td"
-include "stablehlo/dialect/ChloOps.td"
-include "stablehlo/dialect/StablehloOps.td"
-
-class StableHLO_ComparisonDirectionValue<string enumStr> :
- ConstantAttr<StableHLO_ComparisonDirectionAttr,
- "::mlir::stablehlo::ComparisonDirection::" # enumStr>;
-
-class ConstantLike<string value> : NativeCodeCall<
- "::mlir::iree_compiler::stablehlo::getConstantLike($_builder, $_loc, " # value # ", $0)">;
-
-def ComplexElementType : Type<
- CPred<"isa<ComplexType>(cast<ShapedType>($_self).getElementType())">,
- "Complex element type">;
-
-def NonComplexElementType : Type<
- CPred<"!isa<ComplexType>(cast<ShapedType>($_self).getElementType())">,
- "Non-complex element type">;
-
-def ConstantLikeMaxFiniteValue : NativeCodeCall<
- "::mlir::iree_compiler::stablehlo::getConstantLikeMaxFiniteValue($_builder, $_loc, $0)">;
-
-def ConstantLikePosInfValue : NativeCodeCall<
- "::mlir::iree_compiler::stablehlo::getConstantLikeInfValue($_builder, $_loc, $0, /*negative=*/false)">;
-
-def ConstantLikeNegInfValue : NativeCodeCall<
- "::mlir::iree_compiler::stablehlo::getConstantLikeInfValue($_builder, $_loc, $0, /*negative=*/true)">;
-
-def STABLEHLO_DEFAULT_RESULT_ACCURACY :
- ConstantAttr<StableHLO_ResultAccuracyAttr, "::mlir::stablehlo::ResultAccuracyMode::DEFAULT">;
-
-//===----------------------------------------------------------------------===//
-// Unary op patterns.
-//===----------------------------------------------------------------------===//
-
-// Expand acos for non-complex arguments to MHLO dialect as follows:
-// acos(x) = 2 * atan2(sqrt(1 - x^2), (1 + x)) if x != -1
-// = pi if x == -1
-//
-// TODO(b/237376133): Support operands with complex element types separately
-// using the following formula.
-// acos(x) = -(i * log(x + i * sqrt((1 + x) * (1 - x))))
-def : Pat<(CHLO_AcosOp NonComplexElementType:$input),
- (StableHLO_SelectOp
- (StableHLO_CompareOp
- $input,
- (ConstantLike<"-1"> $input),
- StableHLO_ComparisonDirectionValue<"NE">,
- (STABLEHLO_DEFAULT_COMPARISON_TYPE)
- ),
- (StableHLO_MulOp
- (ConstantLike<"2"> $input),
- (StableHLO_Atan2Op
- (StableHLO_SqrtOp
- (StableHLO_SubtractOp
- (ConstantLike<"1"> $input),
- (StableHLO_MulOp $input, $input)
- ),
- STABLEHLO_DEFAULT_RESULT_ACCURACY
- ),
- (StableHLO_AddOp
- (ConstantLike<"1"> $input),
- $input
- )
- )
- ),
- (ConstantLike<"M_PI"> $input)
- )>;
-
-// Expand acosh to MHLO dialect as follows:
-// acosh(x) = log(x + sqrt(x^2 - 1)) if x >= -1
-// = log(x + sqrt((x+1)*(x-1)))
-// acosh(x) = nan if x < -1
-//
-// If x^2 will overflow, we approximate sqrt(x^2 - 1) == x and compute as
-// log(2*x) = log(2) + log(x). (Note this works because negative x never
-// overflows; x < -1 simply yields nan.
-def : Pat<(CHLO_AcoshOp NonComplexElementType:$input),
- (StableHLO_SelectOp
- (StableHLO_CompareOp
- $input,
- (ConstantLike<"-1"> $input),
- StableHLO_ComparisonDirectionValue<"LT">,
- (STABLEHLO_DEFAULT_COMPARISON_TYPE)
- ),
- (ConstantLike<"NAN"> $input),
- (StableHLO_SelectOp
- (StableHLO_CompareOp
- $input,
- (StableHLO_SqrtOp
- (ConstantLikeMaxFiniteValue $input),
- STABLEHLO_DEFAULT_RESULT_ACCURACY
- ),
- StableHLO_ComparisonDirectionValue<"GE">,
- (STABLEHLO_DEFAULT_COMPARISON_TYPE)
- ),
- (StableHLO_AddOp
- (StableHLO_LogOp $input, STABLEHLO_DEFAULT_RESULT_ACCURACY),
- (StableHLO_LogOp
- (ConstantLike<"2"> $input),
- STABLEHLO_DEFAULT_RESULT_ACCURACY
- )
- ),
- (StableHLO_LogOp
- (StableHLO_AddOp
- $input,
- (StableHLO_SqrtOp
- (StableHLO_MulOp
- (StableHLO_AddOp
- (ConstantLike<"1"> $input),
- $input
- ),
- (StableHLO_AddOp
- (ConstantLike<"-1"> $input),
- $input
- )
- ),
- STABLEHLO_DEFAULT_RESULT_ACCURACY
- )
- ),
- STABLEHLO_DEFAULT_RESULT_ACCURACY
- )
- )
- )>;
-
-// Expand acosh for complex arguments to MHLO dialect as
-// acosh(x) = log(x + sqrt((x+1)*(x-1)))
-//
-// Per tensorflow/compiler/xla/client/lib/math.cc at the time of writing:
-// "For now, we ignore the question of overflow if x is a
-// complex type, because we don't yet have exhaustive tests for complex trig
-// functions".
-def : Pat<(CHLO_AcoshOp ComplexElementType:$input),
- (StableHLO_LogOp
- (StableHLO_AddOp
- $input,
- (StableHLO_SqrtOp
- (StableHLO_MulOp
- (StableHLO_AddOp
- $input,
- (ConstantLike<"1"> $input)
- ),
- (StableHLO_SubtractOp
- $input,
- (ConstantLike<"1"> $input)
- )
- ),
- STABLEHLO_DEFAULT_RESULT_ACCURACY
- )
- ),
- STABLEHLO_DEFAULT_RESULT_ACCURACY
- )>;
-
-
-// Expand asin to MHLO dialect as follows:
-// asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2)))
-def : Pat<(CHLO_AsinOp $input),
- (StableHLO_MulOp
- (ConstantLike<"2"> $input),
- (StableHLO_Atan2Op
- $input,
- (StableHLO_AddOp
- (ConstantLike<"1"> $input),
- (StableHLO_SqrtOp
- (StableHLO_SubtractOp
- (ConstantLike<"1"> $input),
- (StableHLO_MulOp $input, $input)
- ),
- STABLEHLO_DEFAULT_RESULT_ACCURACY
- )
- )
- )
- )>;
-
-// Expand asinh for non-complex arguments to MHLO dialect as
-// asinh(x) = log(x + sqrt(x^2 + 1))
-//
-// If x^2 will overflow and x is positive, we can approximate x + sqrt(x^2 + 1)
-// as 2*x and return log(2) + log(x).
-//
-// For small x, sqrt(x^2 + 1) will evaluate to 1 due to floating point
-// arithmetic. However, we would like to retain the low order term of this,
-// which is around 0.5 * x^2 using a binomial expansion.
-// Let z = sqrt(a^2 + 1)
-// The following rewrite retains the lower order term.
-// log(a + sqrt(a^2 + 1))
-// = log((a + sqrt(a^2 + 1)) * (1 + sqrt(a^2 + 1)) / (1 + sqrt(a^2 + 1)))
-// = log((a + a^2 + 1 + a * z + z) / (1 + z))
-// = log(1 + a + a^2 / (1 + z))
-// = log(1 + a + a^2 / (1 + sqrt(a^2 + 1)))
-//
-// If x is negative, the above would give us some trouble; we can't approximate
-// the result as x + abs(x) = 0 but we are saved by the fact that asinh(-x) =
-// -asinh(x).
-def : Pat<(CHLO_AsinhOp NonComplexElementType:$input),
- (StableHLO_MulOp
- (StableHLO_SignOp $input),
- (StableHLO_SelectOp
- (StableHLO_CompareOp
- (StableHLO_AbsOp $input),
- (StableHLO_SqrtOp
- (ConstantLikeMaxFiniteValue $input),
- STABLEHLO_DEFAULT_RESULT_ACCURACY
- ),
- StableHLO_ComparisonDirectionValue<"GE">,
- (STABLEHLO_DEFAULT_COMPARISON_TYPE)
- ),
- (StableHLO_AddOp
- (StableHLO_LogOp
- (StableHLO_AbsOp $input),
- STABLEHLO_DEFAULT_RESULT_ACCURACY
- ),
- (StableHLO_LogOp
- (ConstantLike<"2"> $input),
- STABLEHLO_DEFAULT_RESULT_ACCURACY
- )
- ),
- (StableHLO_SelectOp
- (StableHLO_CompareOp
- (StableHLO_AbsOp $input),
- (ConstantLike<"1"> $input),
- StableHLO_ComparisonDirectionValue<"LE">,
- (STABLEHLO_DEFAULT_COMPARISON_TYPE)
- ),
- (StableHLO_Log1pOp
- (StableHLO_AddOp
- (StableHLO_AbsOp $input),
- (StableHLO_MulOp
- (StableHLO_AbsOp $input),
- (StableHLO_DivOp
- (StableHLO_AbsOp $input),
- (StableHLO_AddOp
- (ConstantLike<"1"> $input),
- (StableHLO_SqrtOp
- (StableHLO_AddOp
- (StableHLO_MulOp
- (StableHLO_AbsOp $input),
- (StableHLO_AbsOp $input)
- ),
- (ConstantLike<"1"> $input)
- ),
- STABLEHLO_DEFAULT_RESULT_ACCURACY
- )
- )
- )
- )
- ),
- STABLEHLO_DEFAULT_RESULT_ACCURACY
- ),
- (StableHLO_LogOp
- (StableHLO_AddOp
- (StableHLO_AbsOp $input),
- (StableHLO_SqrtOp
- (StableHLO_AddOp
- (StableHLO_MulOp
- (StableHLO_AbsOp $input),
- (StableHLO_AbsOp $input)
- ),
- (ConstantLike<"1"> $input)
- ),
- STABLEHLO_DEFAULT_RESULT_ACCURACY
- )
- ),
- STABLEHLO_DEFAULT_RESULT_ACCURACY
- )
- )
- )
- )>;
-
-// Expand asinh for complex arguments to MHLO dialect as
-// asinh(x) = log(x + sqrt(x^2 + 1))
-//
-// Per tensorflow/compiler/xla/client/lib/math.cc at the time of writing:
-// "For now, we ignore the question of overflow if x is a
-// complex type, because we don't yet have exhaustive tests for complex trig
-// functions".
-def : Pat<(CHLO_AsinhOp ComplexElementType:$input),
- (StableHLO_LogOp
- (StableHLO_AddOp
- $input,
- (StableHLO_SqrtOp
- (StableHLO_AddOp
- (StableHLO_MulOp $input, $input),
- (ConstantLike<"1"> $input)
- ),
- STABLEHLO_DEFAULT_RESULT_ACCURACY
- )
- ),
- STABLEHLO_DEFAULT_RESULT_ACCURACY
- )>;
-
-// Express `atan` as
-// atan(x) = atan2(x, 1)
-def : Pat<(CHLO_AtanOp $input),
- (StableHLO_Atan2Op
- $input,
- (ConstantLike<"1"> $input)
- )>;
-
-// Express `atanh` for non-complex arguments as follows:
-// atanh(x) = 0.5 * log((1 + x) / (1 - x)) if abs(x) <= 1
-// atanh(x) = nan otherwise
-def : Pat<(CHLO_AtanhOp NonComplexElementType:$input),
- (StableHLO_SelectOp
- (StableHLO_CompareOp
- (StableHLO_AbsOp $input),
- (ConstantLike<"1"> $input),
- StableHLO_ComparisonDirectionValue<"GT">,
- (STABLEHLO_DEFAULT_COMPARISON_TYPE)
- ),
- (ConstantLike<"NAN"> $input),
- (StableHLO_MulOp
- (StableHLO_SubtractOp
- (StableHLO_Log1pOp $input, STABLEHLO_DEFAULT_RESULT_ACCURACY),
- (StableHLO_Log1pOp
- (StableHLO_NegOp $input),
- STABLEHLO_DEFAULT_RESULT_ACCURACY
- )
- ),
- (ConstantLike<"0.5"> $input)
- )
- )>;
-
-// Express `atanh` for complex arguments as follows:
-// atanh(x) = (log(1 + x) - log(1 + (-x))) * 0.5
-//
-// Per tensorflow/compiler/xla/client/lib/math.cc at the time of writing:
-// "For now, we ignore the nan edge case for complex inputs,
-// because we don't yet have exhaustive tests for complex trig functions".
-def : Pat<(CHLO_AtanhOp ComplexElementType:$input),
- (StableHLO_MulOp
- (StableHLO_SubtractOp
- (StableHLO_Log1pOp $input, STABLEHLO_DEFAULT_RESULT_ACCURACY),
- (StableHLO_Log1pOp
- (StableHLO_NegOp $input),
- STABLEHLO_DEFAULT_RESULT_ACCURACY
- )
- ),
- (ConstantLike<"0.5"> $input)
- )>;
-
-// Express `conj` as
-// conj(x) = (re(x), -im(x)).
-def : Pat<(CHLO_ConjOp $v),
- (StableHLO_ComplexOp (StableHLO_RealOp $v), (StableHLO_NegOp (StableHLO_ImagOp $v)))>;
-
-// Express `is_inf` as
-// is_inf(x) = is_pos_inf(|x|)
-def : Pat<(CHLO_IsInfOp NonComplexElementType:$input),
- (CHLO_IsPosInfOp
- (StableHLO_AbsOp $input)
- )>;
-
-// Express `is_pos_inf` as
-// is_pos_inf(x) = (x == +inf)
-def : Pat<(CHLO_IsPosInfOp NonComplexElementType:$input),
- (StableHLO_CompareOp
- $input,
- (ConstantLikePosInfValue $input),
- StableHLO_ComparisonDirectionValue<"EQ">,
- (STABLEHLO_DEFAULT_COMPARISON_TYPE)
- )>;
-
-// Express `is_neg_inf` as
-// is_neg_inf(x) = (x == -inf)
-def : Pat<(CHLO_IsNegInfOp NonComplexElementType:$input),
- (StableHLO_CompareOp
- $input,
- (ConstantLikeNegInfValue $input),
- StableHLO_ComparisonDirectionValue<"EQ">,
- (STABLEHLO_DEFAULT_COMPARISON_TYPE)
- )>;
-
-// Express `tan` as
-// sine(x) / cosine(x)
-def : Pat<(CHLO_TanOp NonComplexElementType:$input),
- (StableHLO_DivOp
- (StableHLO_SineOp $input, STABLEHLO_DEFAULT_RESULT_ACCURACY),
- (StableHLO_CosineOp $input, STABLEHLO_DEFAULT_RESULT_ACCURACY)
- )>;
-
-
-// Express `tan(a + bi)` as
-// (tan(a) + i tanh(b)) / (1 - i tan(a) * tanh(b))
-def : Pat<(CHLO_TanOp ComplexElementType:$input),
- (StableHLO_DivOp
- (StableHLO_ComplexOp
- (CHLO_TanOp:$tan (StableHLO_RealOp $input)),
- (StableHLO_TanhOp:$tanh (StableHLO_ImagOp $input), STABLEHLO_DEFAULT_RESULT_ACCURACY)),
- (StableHLO_ComplexOp
- (ConstantLike<"1.0"> $tan),
- (StableHLO_NegOp (StableHLO_MulOp $tan, $tanh)))
- )>;
diff --git a/compiler/plugins/input/StableHLO/Conversion/CMakeLists.txt b/compiler/plugins/input/StableHLO/Conversion/CMakeLists.txt
index 7ee0ff6..018b437 100644
--- a/compiler/plugins/input/StableHLO/Conversion/CMakeLists.txt
+++ b/compiler/plugins/input/StableHLO/Conversion/CMakeLists.txt
@@ -34,22 +34,12 @@
PUBLIC
)
-iree_tablegen_library(
- NAME
- CHLODecompositionPatterns
- TD_FILE
- "CHLODecompositionPatterns.td"
- OUTS
- --gen-rewriters CHLODecompositionPatterns.h.inc
-)
-
iree_cc_library(
NAME
StableHLOLegalization
SRCS
"CheckVHLOStableHloMixUsage.cpp"
"ConvertCollectives.cpp"
- "LegalizeCHLO.cpp"
"LegalizeControlFlow.cpp"
"LegalizeShapeComputations.cpp"
"LegalizeToLinalgUtils.cpp"
@@ -60,7 +50,6 @@
"StableHLOToLinalgExt.cpp"
"VerifyCompilerInputLegality.cpp"
DEPS
- ::CHLODecompositionPatterns
::PassHeaders
ChloOps
LLVMSupport
diff --git a/compiler/plugins/input/StableHLO/Conversion/LegalizeCHLO.cpp b/compiler/plugins/input/StableHLO/Conversion/LegalizeCHLO.cpp
deleted file mode 100644
index 713280c..0000000
--- a/compiler/plugins/input/StableHLO/Conversion/LegalizeCHLO.cpp
+++ /dev/null
@@ -1,2225 +0,0 @@
-// Copyright 2020 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-// Implements logic for lowering CHLO ops to StableHLO and Shape dialect ops,
-// taking care of CHLO's broadcasting semantics
-
-#include "compiler/plugins/input/StableHLO/Conversion/Passes.h"
-#include "compiler/plugins/input/StableHLO/Conversion/Preprocessing/Rewriters.h"
-#include "compiler/plugins/input/StableHLO/Conversion/Rewriters.h"
-#include "llvm/ADT/STLExtras.h"
-#include "mlir/Dialect/Complex/IR/Complex.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Shape/IR/Shape.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
-#include "mlir/IR/TypeUtilities.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "stablehlo/dialect/BroadcastUtils.h"
-#include "stablehlo/dialect/ChloOps.h"
-#include "stablehlo/dialect/StablehloOps.h"
-
-namespace mlir::iree_compiler::stablehlo {
-
-#define GEN_PASS_DEF_LEGALIZECHLO
-#include "compiler/plugins/input/StableHLO/Conversion/Passes.h.inc"
-
-namespace {
-
-//===----------------------------------------------------------------------===//
-// Helpers.
-//===----------------------------------------------------------------------===//
-
-template <typename FromOpTy, typename ToOpTy>
-struct HloNaryElementwiseAdaptor {
- static ToOpTy createOp(FromOpTy fromOp, Type resultType,
- ValueRange broadcastedOperands, OpBuilder &builder) {
- return builder.create<ToOpTy>(fromOp.getLoc(), resultType,
- broadcastedOperands);
- }
-};
-
-static std::optional<mlir::stablehlo::ComparisonDirection>
-toStableHloComparisonDirection(mlir::chlo::ComparisonDirection value) {
- switch (value) {
- case mlir::chlo::ComparisonDirection::EQ:
- return mlir::stablehlo::ComparisonDirection::EQ;
- case mlir::chlo::ComparisonDirection::NE:
- return mlir::stablehlo::ComparisonDirection::NE;
- case mlir::chlo::ComparisonDirection::GE:
- return mlir::stablehlo::ComparisonDirection::GE;
- case mlir::chlo::ComparisonDirection::GT:
- return mlir::stablehlo::ComparisonDirection::GT;
- case mlir::chlo::ComparisonDirection::LE:
- return mlir::stablehlo::ComparisonDirection::LE;
- case mlir::chlo::ComparisonDirection::LT:
- return mlir::stablehlo::ComparisonDirection::LT;
- }
- return {};
-}
-
-static std::optional<mlir::stablehlo::ComparisonType>
-toStableHloComparisonType(mlir::chlo::ComparisonType value) {
- switch (value) {
- case mlir::chlo::ComparisonType::NOTYPE:
- return mlir::stablehlo::ComparisonType::NOTYPE;
- case mlir::chlo::ComparisonType::FLOAT:
- return mlir::stablehlo::ComparisonType::FLOAT;
- case mlir::chlo::ComparisonType::TOTALORDER:
- return mlir::stablehlo::ComparisonType::TOTALORDER;
- case mlir::chlo::ComparisonType::SIGNED:
- return mlir::stablehlo::ComparisonType::SIGNED;
- case mlir::chlo::ComparisonType::UNSIGNED:
- return mlir::stablehlo::ComparisonType::UNSIGNED;
- }
- return {};
-}
-
-struct HloCompareAdaptor {
- static mlir::stablehlo::CompareOp
- createOp(mlir::chlo::BroadcastCompareOp fromOp, Type resultType,
- ValueRange broadcastedOperands, OpBuilder &builder) {
- auto chloDirection = fromOp.getComparisonDirection();
- auto hloDirection = toStableHloComparisonDirection(chloDirection);
- if (!hloDirection)
- return nullptr;
- auto chloType =
- fromOp.getCompareType().value_or(mlir::chlo::ComparisonType::NOTYPE);
- auto hloType = toStableHloComparisonType(chloType);
- if (!hloType)
- return nullptr;
- auto hloTypeAttr = fromOp.getCompareType()
- ? mlir::stablehlo::ComparisonTypeAttr::get(
- builder.getContext(), *hloType)
- : nullptr;
- return builder.create<mlir::stablehlo::CompareOp>(
- fromOp.getLoc(), resultType, broadcastedOperands[0],
- broadcastedOperands[1], *hloDirection, hloTypeAttr);
- }
-};
-
-// Populate a pattern for each Broadcasting Chlo op. This requires the pattern
-// to take a ChloOpTy, NonBroadcastingOpTy, and an Adaptor as templated values.
-template <template <typename, typename, typename> typename Pattern,
- typename... ConstructorArgs>
-static void populateForBroadcastingBinaryOp(MLIRContext *context,
- RewritePatternSet *patterns,
- ConstructorArgs &&...args) {
-#define POPULATE_BCAST(ChloOp, HloOp) \
- patterns \
- ->add<Pattern<ChloOp, HloOp, HloNaryElementwiseAdaptor<ChloOp, HloOp>>>( \
- context, args...);
-
- POPULATE_BCAST(mlir::chlo::BroadcastAddOp, mlir::stablehlo::AddOp);
- POPULATE_BCAST(mlir::chlo::BroadcastAndOp, mlir::stablehlo::AndOp);
- POPULATE_BCAST(mlir::chlo::BroadcastAtan2Op, mlir::stablehlo::Atan2Op);
- POPULATE_BCAST(mlir::chlo::BroadcastComplexOp, mlir::stablehlo::ComplexOp);
- POPULATE_BCAST(mlir::chlo::BroadcastDivOp, mlir::stablehlo::DivOp);
- POPULATE_BCAST(mlir::chlo::BroadcastMaxOp, mlir::stablehlo::MaxOp);
- POPULATE_BCAST(mlir::chlo::BroadcastMinOp, mlir::stablehlo::MinOp);
- POPULATE_BCAST(mlir::chlo::BroadcastMulOp, mlir::stablehlo::MulOp);
- POPULATE_BCAST(mlir::chlo::BroadcastNextAfterOp, mlir::chlo::NextAfterOp);
- POPULATE_BCAST(mlir::chlo::BroadcastOrOp, mlir::stablehlo::OrOp);
- POPULATE_BCAST(mlir::chlo::BroadcastPolygammaOp, mlir::chlo::PolygammaOp);
- POPULATE_BCAST(mlir::chlo::BroadcastPowOp, mlir::stablehlo::PowOp);
- POPULATE_BCAST(mlir::chlo::BroadcastRemOp, mlir::stablehlo::RemOp);
- POPULATE_BCAST(mlir::chlo::BroadcastShiftLeftOp,
- mlir::stablehlo::ShiftLeftOp);
- POPULATE_BCAST(mlir::chlo::BroadcastShiftRightArithmeticOp,
- mlir::stablehlo::ShiftRightArithmeticOp);
- POPULATE_BCAST(mlir::chlo::BroadcastShiftRightLogicalOp,
- mlir::stablehlo::ShiftRightLogicalOp);
- POPULATE_BCAST(mlir::chlo::BroadcastSubOp, mlir::stablehlo::SubtractOp);
- POPULATE_BCAST(mlir::chlo::BroadcastXorOp, mlir::stablehlo::XorOp);
- POPULATE_BCAST(mlir::chlo::BroadcastZetaOp, mlir::chlo::ZetaOp);
-
-#undef POPULATE_BCAST
-
- // Broadcasting ops requiring special construction.
- patterns->add<Pattern<mlir::chlo::BroadcastCompareOp,
- mlir::stablehlo::CompareOp, HloCompareAdaptor>>(
- context, args...);
-}
-
-template <typename T>
-static Value getConstantLike(OpBuilder &b, Location loc, T constant,
- Value val) {
- Type ty = getElementTypeOrSelf(val.getType());
- auto getAttr = [&]() -> Attribute {
- if (isa<IntegerType>(ty))
- return b.getIntegerAttr(ty, constant);
- if (isa<FloatType>(ty))
- return b.getFloatAttr(ty, constant);
- if (auto complexTy = dyn_cast<ComplexType>(ty)) {
- return complex::NumberAttr::get(complexTy, constant, 0);
- }
- llvm_unreachable("unhandled element type");
- };
- return b.create<mlir::chlo::ConstantLikeOp>(loc, cast<TypedAttr>(getAttr()),
- val);
-}
-
-static Value getConstantLike(OpBuilder &b, Location loc,
- const APFloat &constant, Value val) {
- Type ty = getElementTypeOrSelf(val.getType());
- return b.create<mlir::chlo::ConstantLikeOp>(loc, b.getFloatAttr(ty, constant),
- val);
-}
-
-static Value getConstantLikeMaxFiniteValue(OpBuilder &b, Location loc,
- Value val) {
- auto ty = cast<FloatType>(getElementTypeOrSelf(val.getType()));
- return getConstantLike(
- b, loc, llvm::APFloat::getLargest(ty.getFloatSemantics()), val);
-}
-
-static Value getConstantLikeInfValue(OpBuilder &b, Location loc, Value val,
- bool negative) {
- auto ty = cast<FloatType>(getElementTypeOrSelf(val.getType()));
- return getConstantLike(
- b, loc, llvm::APFloat::getInf(ty.getFloatSemantics(), negative), val);
-}
-
-//===----------------------------------------------------------------------===//
-// Broadcasting Patterns.
-//===----------------------------------------------------------------------===//
-
-// Converts binary ops that statically are determined to not broadcast directly
-// to the corresponding stablehlo non-broadcasting op.
-template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
-struct ConvertTrivialNonBroadcastBinaryOp final
- : OpConversionPattern<ChloOpTy> {
- using OpConversionPattern<ChloOpTy>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(ChloOpTy op, typename ChloOpTy::Adaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- // Only rewrite for statically determinable non-broadcasting cases.
- auto lhsType = dyn_cast<RankedTensorType>(adaptor.getLhs().getType());
- auto rhsType = dyn_cast<RankedTensorType>(adaptor.getRhs().getType());
- if (!lhsType || !rhsType)
- return failure();
-
- // Requires rank broadcast.
- if (lhsType.getRank() != rhsType.getRank())
- return failure();
-
- // Any dynamic dimension may require broadcasting and requires more
- // analysis.
- if (!lhsType.hasStaticShape() || !rhsType.hasStaticShape()) {
- return failure();
- }
-
- if (!llvm::equal(lhsType.getShape(), rhsType.getShape())) {
- return failure();
- }
-
- rewriter.replaceOp(
- op, ValueRange{Adaptor::createOp(op, op.getResult().getType(),
- adaptor.getOperands(), rewriter)});
- return success();
- }
-};
-
-// Converts a binary op with ranked broadcasting operands to explicitly
-// broadcast and invoke the corresponding stablehlo non-broadcasting op.
-// Note that dynamic broadcasting supported by this pattern is only valid for
-// "numpy" broadcasting semantics as defined here:
-// https://docs.scipy.org/doc/numpy/reference/ufuncs.html
-// Specifically, this includes the following cases:
-// - Same rank broadcast (operands have the same static rank).
-// - Different-rank broadcast, either without a broadcast_dims attribute or
-// with the broadcast_dims attribute set to map to a prefix padding.
-// - Legal combinations of degenerate (1-dim) implicit broadcasting.
-// The restriction on broadcast_dims derives from the definition of the
-// `shape.broadcast` op, which only supports prefix-padding.
-template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
-struct ConvertRankedDynamicBroadcastBinaryOp final
- : OpConversionPattern<ChloOpTy> {
- using OpConversionPattern<ChloOpTy>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(ChloOpTy op, typename ChloOpTy::Adaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- // Only support ranked operands.
- Value lhs = adaptor.getLhs();
- Value rhs = adaptor.getRhs();
- auto lhsType = dyn_cast<RankedTensorType>(lhs.getType());
- auto rhsType = dyn_cast<RankedTensorType>(rhs.getType());
- auto resultType = dyn_cast<RankedTensorType>(op.getResult().getType());
- if (!lhsType || !rhsType || !resultType)
- return failure();
-
- // Check for "numpy"-style rank broadcast.
- auto broadcastDimensions = op.getBroadcastDimensions();
- if (broadcastDimensions && !mlir::hlo::isLegalNumpyRankedBroadcast(
- lhs, rhs, *broadcastDimensions)) {
- // Note: It is unclear whether the general specification of explicit
- // broadcast_dimensions on binary ops is a feature we want to carry
- // forward. While it can technically be implemented for ranked-dynamic,
- // it is incompatible with unranked inputs. If this warning is emitted
- // in real programs, it is an indication that the feature should be
- // implemented versus just falling back on the more standard definition
- // of numpy-like prefix-padding.
- op.emitWarning() << "unsupported non prefix-padded dynamic rank "
- << "broadcast_dimensions = " << *broadcastDimensions;
- return failure();
- }
-
- // Compute result shape.
- Location loc = op.getLoc();
-
- // Insert a constraint on the shapes being broadcastable and insert all
- // future code into an assuming block reliant on the constraint.
- Value lhsShape = rewriter.create<shape::ShapeOfOp>(loc, lhs);
- Value rhsShape = rewriter.create<shape::ShapeOfOp>(loc, rhs);
- auto broadcastableCstr =
- rewriter.create<shape::CstrBroadcastableOp>(loc, lhsShape, rhsShape);
- auto assumingOp = rewriter.create<shape::AssumingOp>(
- loc, ArrayRef<Type>{resultType}, broadcastableCstr.getResult());
-
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.createBlock(&assumingOp.getDoRegion());
-
- int64_t resultRank = std::max(lhsType.getRank(), rhsType.getRank());
- Value resultExtents =
- hlo::computeBinaryElementwiseBroadcastingResultExtents(loc, lhs, rhs,
- rewriter);
-
- // Note that we unconditionally emit DynamicBroadcastInDim ops and let
- // downstream canonicalizations fold them away if possible. This is
- // because, in the dynamic case, there are many corner cases regarding
- // when it is safe to omit, and some of them require analysis to prove
- // properly.
- auto lhsBroadcastDimensions = llvm::to_vector(
- llvm::seq<int64_t>(resultRank - lhsType.getRank(), resultRank));
- Value broadcastedLhs =
- rewriter.create<mlir::stablehlo::DynamicBroadcastInDimOp>(
- loc,
- RankedTensorType::get(resultType.getShape(),
- lhsType.getElementType()),
- lhs, resultExtents,
- rewriter.getDenseI64ArrayAttr(lhsBroadcastDimensions));
- auto rhsBroadcastDimensions = llvm::to_vector(
- llvm::seq<int64_t>(resultRank - rhsType.getRank(), resultRank));
- Value broadcastedRhs =
- rewriter.create<mlir::stablehlo::DynamicBroadcastInDimOp>(
- loc,
- RankedTensorType::get(resultType.getShape(),
- rhsType.getElementType()),
- rhs, resultExtents,
- rewriter.getDenseI64ArrayAttr(rhsBroadcastDimensions));
-
- // And generate the final non-broadcasted binary op.
- Value finalResult = Adaptor::createOp(
- op, resultType, {broadcastedLhs, broadcastedRhs}, rewriter);
- rewriter.create<shape::AssumingYieldOp>(loc, finalResult);
- rewriter.replaceOp(op, {assumingOp.getResult(0)});
- return success();
- }
-};
-
-struct ConvertConstantLikeOp final
- : OpConversionPattern<mlir::chlo::ConstantLikeOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(mlir::chlo::ConstantLikeOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto resultTy = cast<ShapedType>(op.getType());
-
- // Unranked uses are not supported.
- if (!resultTy.hasRank())
- return failure();
-
- // Lower to HLO constant if statically shaped.
- if (resultTy.hasStaticShape()) {
- auto complexAttr = dyn_cast<mlir::complex::NumberAttr>(op.getValue());
- auto attr = DenseElementsAttr::get(resultTy, complexAttr ? complexAttr
- : op.getValue());
- rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, attr);
- return success();
- }
-
- // Lower to broadcasted constant.
- Location loc = op.getLoc();
- Value constant =
- rewriter.create<mlir::stablehlo::ConstantOp>(loc, op.getValue());
- Value shape = rewriter.create<shape::ShapeOfOp>(loc, adaptor.getOperand());
- rewriter.replaceOpWithNewOp<mlir::stablehlo::DynamicBroadcastInDimOp>(
- op, resultTy, constant, shape, rewriter.getDenseI64ArrayAttr({}));
- return success();
- }
-};
-
-struct ConvertSelectOp final
- : OpConversionPattern<mlir::chlo::BroadcastSelectOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(mlir::chlo::BroadcastSelectOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- // Only support ranked operands.
- Value pred = adaptor.getPred();
- Value onTrue = adaptor.getOnTrue();
- Value onFalse = adaptor.getOnFalse();
- auto predType = dyn_cast<RankedTensorType>(pred.getType());
- auto onTrueType = dyn_cast<RankedTensorType>(onTrue.getType());
- auto onFalseType = dyn_cast<RankedTensorType>(onFalse.getType());
- auto resultType = dyn_cast<RankedTensorType>(op.getResult().getType());
- if (!predType || !onTrueType || !onFalseType || !resultType) {
- return failure();
- }
-
- Location loc = op.getLoc();
- Value predShape = rewriter.createOrFold<shape::ShapeOfOp>(loc, pred);
- Value onTrueShape = rewriter.createOrFold<shape::ShapeOfOp>(loc, onTrue);
- Value onFalseShape = rewriter.createOrFold<shape::ShapeOfOp>(loc, onFalse);
- int64_t resultRank = std::max(
- {predType.getRank(), onTrueType.getRank(), onFalseType.getRank()});
-
- Value broadcastableCstr = rewriter.createOrFold<shape::CstrBroadcastableOp>(
- loc, ValueRange{predShape, onTrueShape, onFalseShape});
- auto assumingOp = rewriter.create<shape::AssumingOp>(
- loc, ArrayRef<Type>{resultType}, broadcastableCstr);
-
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.createBlock(&assumingOp.getDoRegion());
-
- Value resultExtents = rewriter.createOrFold<shape::BroadcastOp>(
- loc, shape::getExtentTensorType(op.getContext()),
- ValueRange{predShape, onTrueShape, onFalseShape},
- /*error=*/nullptr);
- auto shapeType =
- RankedTensorType::get({resultRank}, rewriter.getIndexType());
- resultExtents =
- rewriter.createOrFold<tensor::CastOp>(loc, shapeType, resultExtents);
-
- Value broadcastedPred = pred;
- // Pred has an implicit broadcast for scalars, so use that when convenient.
- if (predType.getRank() > 0) {
- auto predBroadcastDimensions = llvm::to_vector(
- llvm::seq<int64_t>(resultRank - predType.getRank(), resultRank));
- broadcastedPred =
- rewriter.create<mlir::stablehlo::DynamicBroadcastInDimOp>(
- loc,
- RankedTensorType::get(resultType.getShape(),
- predType.getElementType()),
- pred, resultExtents,
- rewriter.getDenseI64ArrayAttr(predBroadcastDimensions));
- }
- auto onTrueBroadcastDimensions = llvm::to_vector(
- llvm::seq<int64_t>(resultRank - onTrueType.getRank(), resultRank));
- Value broadcastedOnTrue =
- rewriter.create<mlir::stablehlo::DynamicBroadcastInDimOp>(
- loc,
- RankedTensorType::get(resultType.getShape(),
- onTrueType.getElementType()),
- onTrue, resultExtents,
- rewriter.getDenseI64ArrayAttr(onTrueBroadcastDimensions));
- auto onFalseBroadcastDimensions = llvm::to_vector(
- llvm::seq<int64_t>(resultRank - onFalseType.getRank(), resultRank));
- Value broadcastedOnFalse =
- rewriter.create<mlir::stablehlo::DynamicBroadcastInDimOp>(
- loc,
- RankedTensorType::get(resultType.getShape(),
- onFalseType.getElementType()),
- onFalse, resultExtents,
- rewriter.getDenseI64ArrayAttr(onFalseBroadcastDimensions));
-
- // And generate the final non-broadcasted ternary op.
- Value finalResult = rewriter.create<mlir::stablehlo::SelectOp>(
- loc, resultType, broadcastedPred, broadcastedOnTrue,
- broadcastedOnFalse);
- rewriter.create<shape::AssumingYieldOp>(loc, finalResult);
- rewriter.replaceOp(op, {assumingOp.getResult(0)});
- return success();
- }
-};
-
-//===----------------------------------------------------------------------===//
-// Decomposition Patterns.
-//===----------------------------------------------------------------------===//
-
-struct ConvertConstantOp final : OpConversionPattern<mlir::chlo::ConstantOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(mlir::chlo::ConstantOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, op.getValue());
- return success();
- }
-};
-
-template <typename FTy>
-static Value
-materializeChebyshevPolynomialApproximation(ConversionPatternRewriter &rewriter,
- Location loc, Value x,
- ArrayRef<FTy> coefficients) {
- Value b0 = getConstantLike(rewriter, loc, 0.0, x);
- Value b1 = getConstantLike(rewriter, loc, 0.0, x);
- Value b2 = getConstantLike(rewriter, loc, 0.0, x);
- for (FTy c : coefficients) {
- b2 = b1;
- b1 = b0;
- b0 = rewriter.create<mlir::stablehlo::MulOp>(loc, x.getType(), x, b1);
- b0 = rewriter.create<mlir::stablehlo::SubtractOp>(loc, x.getType(), b0, b2);
- b0 = rewriter.create<mlir::stablehlo::AddOp>(
- loc, x.getType(), b0, getConstantLike(rewriter, loc, c, x));
- }
- Value result =
- rewriter.create<mlir::stablehlo::SubtractOp>(loc, x.getType(), b0, b2);
- result = rewriter.create<mlir::stablehlo::MulOp>(
- loc, x.getType(), result, getConstantLike(rewriter, loc, 0.5, x));
- return result;
-}
-
-template <typename FTy>
-static Value materializeBesselI1eApproximation(
- ConversionPatternRewriter &rewriter, Location loc, Value x,
- ArrayRef<FTy> kI1eCoeffsA, ArrayRef<FTy> kI1eCoeffsB) {
- Value z = rewriter.create<mlir::stablehlo::AbsOp>(loc, x);
- Value half = getConstantLike(rewriter, loc, 0.5, x);
- Value two = getConstantLike(rewriter, loc, 2.0, x);
- Value thirtyTwo = getConstantLike(rewriter, loc, 32.0, x);
- Value eight = getConstantLike(rewriter, loc, 8.0, x);
-
- Value tmp = rewriter.create<mlir::stablehlo::MulOp>(loc, half, z);
- tmp = rewriter.create<mlir::stablehlo::SubtractOp>(loc, tmp, two);
-
- Value xLe8 = materializeChebyshevPolynomialApproximation(rewriter, loc, tmp,
- kI1eCoeffsA);
- xLe8 = rewriter.create<mlir::stablehlo::MulOp>(loc, z, xLe8);
-
- tmp = rewriter.create<mlir::stablehlo::DivOp>(loc, thirtyTwo, z);
- tmp = rewriter.create<mlir::stablehlo::SubtractOp>(loc, tmp, two);
- Value xGt8 = materializeChebyshevPolynomialApproximation(rewriter, loc, tmp,
- kI1eCoeffsB);
- xGt8 = rewriter.create<mlir::stablehlo::DivOp>(
- loc, xGt8, rewriter.create<mlir::stablehlo::SqrtOp>(loc, z));
-
- Value isLe8 = rewriter.create<mlir::stablehlo::CompareOp>(
- loc, z, eight, mlir::stablehlo::ComparisonDirection::LE);
-
- Value select =
- rewriter.create<mlir::stablehlo::SelectOp>(loc, isLe8, xLe8, xGt8);
- return rewriter.create<mlir::stablehlo::MulOp>(
- loc, rewriter.create<mlir::stablehlo::SignOp>(loc, x), select);
-}
-
-Value materializeBesselI1eApproximationF32(ConversionPatternRewriter &rewriter,
- Location loc, ValueRange args) {
- Value x = args.front();
- assert(cast<ShapedType>(x.getType()).getElementType().isF32() &&
- "expect f32 element type");
- const float kI1eCoeffsA[] = {
- 9.38153738649577178388E-9f, -4.44505912879632808065E-8f,
- 2.00329475355213526229E-7f, -8.56872026469545474066E-7f,
- 3.47025130813767847674E-6f, -1.32731636560394358279E-5f,
- 4.78156510755005422638E-5f, -1.61760815825896745588E-4f,
- 5.12285956168575772895E-4f, -1.51357245063125314899E-3f,
- 4.15642294431288815669E-3f, -1.05640848946261981558E-2f,
- 2.47264490306265168283E-2f, -5.29459812080949914269E-2f,
- 1.02643658689847095384E-1f, -1.76416518357834055153E-1f,
- 2.52587186443633654823E-1f};
-
- const float kI1eCoeffsB[] = {
- -3.83538038596423702205E-9f, -2.63146884688951950684E-8f,
- -2.51223623787020892529E-7f, -3.88256480887769039346E-6f,
- -1.10588938762623716291E-4f, -9.76109749136146840777E-3f,
- 7.78576235018280120474E-1f};
-
- return materializeBesselI1eApproximation<float>(rewriter, loc, x, kI1eCoeffsA,
- kI1eCoeffsB);
-}
-
-static Value
-materializeBesselI1eApproximationF64(ConversionPatternRewriter &rewriter,
- Location loc, ValueRange args) {
- Value x = args.front();
- assert(cast<ShapedType>(x.getType()).getElementType().isF64() &&
- "expect f64 element type");
-
- const double kI1eCoeffsA[] = {
- 2.77791411276104639959E-18, -2.11142121435816608115E-17,
- 1.55363195773620046921E-16, -1.10559694773538630805E-15,
- 7.60068429473540693410E-15, -5.04218550472791168711E-14,
- 3.22379336594557470981E-13, -1.98397439776494371520E-12,
- 1.17361862988909016308E-11, -6.66348972350202774223E-11,
- 3.62559028155211703701E-10, -1.88724975172282928790E-9,
- 9.38153738649577178388E-9, -4.44505912879632808065E-8,
- 2.00329475355213526229E-7, -8.56872026469545474066E-7,
- 3.47025130813767847674E-6, -1.32731636560394358279E-5,
- 4.78156510755005422638E-5, -1.61760815825896745588E-4,
- 5.12285956168575772895E-4, -1.51357245063125314899E-3,
- 4.15642294431288815669E-3, -1.05640848946261981558E-2,
- 2.47264490306265168283E-2, -5.29459812080949914269E-2,
- 1.02643658689847095384E-1, -1.76416518357834055153E-1,
- 2.52587186443633654823E-1};
-
- const double kI1eCoeffsB[] = {
- 7.51729631084210481353E-18, 4.41434832307170791151E-18,
- -4.65030536848935832153E-17, -3.20952592199342395980E-17,
- 2.96262899764595013876E-16, 3.30820231092092828324E-16,
- -1.88035477551078244854E-15, -3.81440307243700780478E-15,
- 1.04202769841288027642E-14, 4.27244001671195135429E-14,
- -2.10154184277266431302E-14, -4.08355111109219731823E-13,
- -7.19855177624590851209E-13, 2.03562854414708950722E-12,
- 1.41258074366137813316E-11, 3.25260358301548823856E-11,
- -1.89749581235054123450E-11, -5.58974346219658380687E-10,
- -3.83538038596423702205E-9, -2.63146884688951950684E-8,
- -2.51223623787020892529E-7, -3.88256480887769039346E-6,
- -1.10588938762623716291E-4, -9.76109749136146840777E-3,
- 7.78576235018280120474E-1};
-
- return materializeBesselI1eApproximation<double>(rewriter, loc, x,
- kI1eCoeffsA, kI1eCoeffsB);
-}
-
-static Value materializeWithUpcast(ConversionPatternRewriter &rewriter,
- Location loc, ValueRange args,
- FloatType minPrecisionTy,
- Value callback(ConversionPatternRewriter &,
- Location, ValueRange)) {
- Type originalTy = getElementTypeOrSelf(args.front().getType());
- auto floatOriginalTy = dyn_cast<FloatType>(originalTy);
- bool needsUpcast =
- floatOriginalTy && floatOriginalTy.getWidth() < minPrecisionTy.getWidth();
-
- // Upcast arguments if necessary.
- llvm::SmallVector<Value, 2> castedArgs;
- if (needsUpcast) {
- for (Value a : args) {
- castedArgs.push_back(
- rewriter.create<mlir::stablehlo::ConvertOp>(loc, a, minPrecisionTy));
- }
- args = castedArgs;
- }
-
- Value result = callback(rewriter, loc, args);
-
- // Cast back if necessary.
- if (needsUpcast) {
- result =
- rewriter.create<mlir::stablehlo::ConvertOp>(loc, result, originalTy);
- }
-
- return result;
-}
-
-struct ConvertBesselI1eOp final : OpConversionPattern<mlir::chlo::BesselI1eOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(mlir::chlo::BesselI1eOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- Location loc = op.getLoc();
- Value x = adaptor.getOperand();
- Type ty = cast<ShapedType>(x.getType()).getElementType();
-
- // For now, we support only f64, f32, f16 and bf16.
- // See https://www.tensorflow.org/api_docs/python/tf/math/bessel_i1e
- if (!ty.isF64() && !ty.isF32() && !ty.isF16() && !ty.isBF16()) {
- return failure();
- }
-
- if (ty.isF64()) {
- rewriter.replaceOp(
- op, materializeBesselI1eApproximationF64(rewriter, loc, x));
- return success();
- }
-
- rewriter.replaceOp(
- op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(),
- rewriter.getF32Type(),
- &materializeBesselI1eApproximationF32));
- return success();
- }
-};
-
-template <typename FTy>
-static Value
-materializePolynomialApproximation(ConversionPatternRewriter &rewriter,
- Location loc, Value x,
- ArrayRef<FTy> coefficients) {
- if (coefficients.empty())
- return getConstantLike(rewriter, loc, 0.0, x);
-
- Value poly = getConstantLike(rewriter, loc, coefficients[0], x);
- for (size_t i = 1, e = coefficients.size(); i < e; ++i) {
- poly = rewriter.create<mlir::stablehlo::MulOp>(loc, x.getType(), poly, x);
- poly = rewriter.create<mlir::stablehlo::AddOp>(
- loc, x.getType(), poly,
- getConstantLike(rewriter, loc, coefficients[i], x));
- }
- return poly;
-}
-
-// Precondition is |x| >= 1. Use erf approximation, otherwise.
-//
-// We rely on multiple polynomial approximations for x >= 1. We pass |x| as an
-// argument and derive the final approximation for all |x| >= 1.
-// This implementation is based on Cephes.
-static Value materializeErfcApproximationF64ForMagnituteGeOne(
- ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
- Value x = args.front();
- assert(cast<ShapedType>(x.getType()).getElementType().isF64() &&
- "expect f64 element type");
- const double kMaxlog = 7.09782712893383996843E2;
- const double kErfcPCoefficients[] = {
- 2.46196981473530512524E-10, 5.64189564831068821977E-1,
- 7.46321056442269912687E0, 4.86371970985681366614E1,
- 1.96520832956077098242E2, 5.26445194995477358631E2,
- 9.34528527171957607540E2, 1.02755188689515710272E3,
- 5.57535335369399327526E2};
- const double kErfcQCoefficients[] = {
- 1.00000000000000000000E0, 1.32281951154744992508E1,
- 8.67072140885989742329E1, 3.54937778887819891062E2,
- 9.75708501743205489753E2, 1.82390916687909736289E3,
- 2.24633760818710981792E3, 1.65666309194161350182E3,
- 5.57535340817727675546E2};
- const double kErfcRCoefficients[] = {
- 5.64189583547755073984E-1, 1.27536670759978104416E0,
- 5.01905042251180477414E0, 6.16021097993053585195E0,
- 7.40974269950448939160E0, 2.97886665372100240670E0};
- const double kErfcSCoefficients[] = {
- 1.00000000000000000000E0, 2.26052863220117276590E0,
- 9.39603524938001434673E0, 1.20489539808096656605E1,
- 1.70814450747565897222E1, 9.60896809063285878198E0,
- 3.36907645100081516050E0};
-
- // Let z = -x^2.
- Value xSq = rewriter.create<mlir::stablehlo::MulOp>(loc, x, x);
- Value z = rewriter.create<mlir::stablehlo::NegOp>(loc, xSq);
-
- // Materialize polynomial approximation for x in [1, 8) as
- // erfc(x) = exp(z) P(|x|) / Q(|x|).
- Value expZ = rewriter.create<mlir::stablehlo::ExpOp>(loc, z);
- Value absX = rewriter.create<mlir::stablehlo::AbsOp>(loc, x);
- Value polP = materializePolynomialApproximation(
- rewriter, loc, absX, llvm::ArrayRef(kErfcPCoefficients));
- Value expZMulPolyP = rewriter.create<mlir::stablehlo::MulOp>(loc, expZ, polP);
- Value polQ = materializePolynomialApproximation(
- rewriter, loc, absX, llvm::ArrayRef(kErfcQCoefficients));
- Value erfcApprox18 =
- rewriter.create<mlir::stablehlo::DivOp>(loc, expZMulPolyP, polQ);
-
- // Materialize polynomial approximation for x in >= 8 as
- // erfc(x) exp(z) R(|x|) / S(|x|).
- Value polR = materializePolynomialApproximation(
- rewriter, loc, absX, llvm::ArrayRef(kErfcRCoefficients));
- Value expZMulPolyR = rewriter.create<mlir::stablehlo::MulOp>(loc, expZ, polR);
- Value polS = materializePolynomialApproximation(
- rewriter, loc, absX, llvm::ArrayRef(kErfcSCoefficients));
- Value erfcApprox8Inf =
- rewriter.create<mlir::stablehlo::DivOp>(loc, expZMulPolyR, polS);
-
- // Combine polynomial approximations for x >= 1.
- Value eight = getConstantLike(rewriter, loc, 8.0, x);
- Value absXLt8 = rewriter.create<mlir::stablehlo::CompareOp>(
- loc, absX, eight, mlir::stablehlo::ComparisonDirection::LT);
- Value erfcApprox = rewriter.create<mlir::stablehlo::SelectOp>(
- loc, absXLt8, erfcApprox18, erfcApprox8Inf);
-
- // Clamp to prevent overflow and materialize approximation for large x as
- // erfc(x) = 0.
- Value zLtNegMaxlog = rewriter.create<mlir::stablehlo::CompareOp>(
- loc, z, getConstantLike(rewriter, loc, -kMaxlog, x),
- mlir::stablehlo::ComparisonDirection::LT);
- Value zero = getConstantLike(rewriter, loc, 0.0, x);
- Value erfcApproxClamped = rewriter.create<mlir::stablehlo::SelectOp>(
- loc, zLtNegMaxlog, zero, erfcApprox);
-
- // Derive approximation for x <= -1 as
- // erfc(x) = 2 - erfc(-x).
- // Reuse previously materialized approximations all of which take |x| as their
- // argument.
- Value xLtZero = rewriter.create<mlir::stablehlo::CompareOp>(
- loc, x, zero, mlir::stablehlo::ComparisonDirection::LT);
- Value two = getConstantLike(rewriter, loc, 2.0, x);
- Value twoSubErfcApproxClamped =
- rewriter.create<mlir::stablehlo::SubtractOp>(loc, two, erfcApproxClamped);
- return rewriter.create<mlir::stablehlo::SelectOp>(
- loc, xLtZero, twoSubErfcApproxClamped, erfcApproxClamped);
-}
-
-// Precondition is |x| <= 1. Use erfc approximation, otherwise.
-// This implementation is based on Cephes.
-static Value materializeErfApproximationF64ForMagnituteLeOne(
- ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
- Value x = args.front();
- assert(cast<ShapedType>(x.getType()).getElementType().isF64() &&
- "expect f64 element type");
- const double kErfTCoefficients[] = {
- 9.60497373987051638749E0, 9.00260197203842689217E1,
- 2.23200534594684319226E3, 7.00332514112805075473E3,
- 5.55923013010394962768E4};
- const double kErfUCoefficients[] = {
- 1.00000000000000000000E0, 3.35617141647503099647E1,
- 5.21357949780152679795E2, 4.59432382970980127987E3,
- 2.26290000613890934246E4, 4.92673942608635921086E4};
-
- // Materialize polynomial approximation for |x| <= 1 as
- // erf(x) = x T(x^2) / U(x^2).
- Value xSq = rewriter.create<mlir::stablehlo::MulOp>(loc, x, x);
- Value polyT = materializePolynomialApproximation(
- rewriter, loc, xSq, llvm::ArrayRef(kErfTCoefficients));
- Value xMulPolyT = rewriter.create<mlir::stablehlo::MulOp>(loc, x, polyT);
- Value polyU = materializePolynomialApproximation(
- rewriter, loc, xSq, llvm::ArrayRef(kErfUCoefficients));
- return rewriter.create<mlir::stablehlo::DivOp>(loc, xMulPolyT, polyU);
-}
-
-// This implementation is based on Cephes.
-static Value materializeErfApproximationF64(ConversionPatternRewriter &rewriter,
- Location loc, ValueRange args) {
- Value x = args.front();
- assert(cast<ShapedType>(x.getType()).getElementType().isF64() &&
- "expect f64 element type");
-
- // Rely on erf approximation for |x| < 1
- // erf(x) = erf_approx(x)
- Value erfApprox =
- materializeErfApproximationF64ForMagnituteLeOne(rewriter, loc, x);
-
- // Rely on erfc approximation for |x| >= 1 and materialize erf as
- // erf(x) = 1 - erfc_approx(x)
- Value one = getConstantLike(rewriter, loc, 1.0, x);
- Value erfcApprox =
- materializeErfcApproximationF64ForMagnituteGeOne(rewriter, loc, x);
- Value erfcBasedApprox =
- rewriter.create<mlir::stablehlo::SubtractOp>(loc, one, erfcApprox);
-
- // Materialize approximation selection based on argument.
- Value absX = rewriter.create<mlir::stablehlo::AbsOp>(loc, x);
- Value absXLtOne = rewriter.create<mlir::stablehlo::CompareOp>(
- loc, absX, one, mlir::stablehlo::ComparisonDirection::LT);
- return rewriter.create<mlir::stablehlo::SelectOp>(loc, absXLtOne, erfApprox,
- erfcBasedApprox);
-}
-
-static Value
-materializeErfcApproximationF64(ConversionPatternRewriter &rewriter,
- Location loc, ValueRange args) {
- Value x = args.front();
- assert(cast<ShapedType>(x.getType()).getElementType().isF64() &&
- "expect f64 element type");
-
- // Rely on erfc approximation for |x| >= 1
- // erfc(x) = erfc_approx(x)
- Value erfcApprox =
- materializeErfcApproximationF64ForMagnituteGeOne(rewriter, loc, x);
-
- // Rely on erf approximation for |x| < 1 and materialize erfc as
- // erfc(x) = 1 - erf_approx(x)
- Value one = getConstantLike(rewriter, loc, 1.0, x);
- Value erfApprox =
- materializeErfApproximationF64ForMagnituteLeOne(rewriter, loc, x);
- Value erfBasedApprox =
- rewriter.create<mlir::stablehlo::SubtractOp>(loc, one, erfApprox);
-
- // Materialize approximation selection based on argument.
- Value absX = rewriter.create<mlir::stablehlo::AbsOp>(loc, x);
- Value absXLtOne = rewriter.create<mlir::stablehlo::CompareOp>(
- loc, absX, one, mlir::stablehlo::ComparisonDirection::LT);
- return rewriter.create<mlir::stablehlo::SelectOp>(loc, absXLtOne,
- erfBasedApprox, erfcApprox);
-}
-
-// Precondition is |x| >= 1. Use erf approximation, otherwise.
-//
-// We rely on multiple polynomial approximations for x >= 1. We pass |x| as an
-// argument and derive the final approximation for all |x| >= 1.
-// This implementation is based on Cephes.
-static Value materializeErfcApproximationF32ForMagnitudeGeOne(
- ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
- Value x = args.front();
- assert(cast<ShapedType>(x.getType()).getElementType().isF32() &&
- "expect f32 element type");
- const double kMaxlog = 88.72283905206835;
- const float kErfcPCoefficients[] = {
- +2.326819970068386E-2f, -1.387039388740657E-1f, +3.687424674597105E-1f,
- -5.824733027278666E-1f, +6.210004621745983E-1f, -4.944515323274145E-1f,
- +3.404879937665872E-1f, -2.741127028184656E-1f, +5.638259427386472E-1f,
- };
- const float kErfcRCoefficients[] = {
- -1.047766399936249E+1f, +1.297719955372516E+1f, -7.495518717768503E+0f,
- +2.921019019210786E+0f, -1.015265279202700E+0f, +4.218463358204948E-1f,
- -2.820767439740514E-1f, +5.641895067754075E-1f,
- };
-
- // Let z = -x^2.
- Value xSq = rewriter.create<mlir::stablehlo::MulOp>(loc, x, x);
- Value z = rewriter.create<mlir::stablehlo::NegOp>(loc, xSq);
-
- // Materialize polynomial approximation for x >= 1 as
- // erfc(x) = exp(z) 1/x P(1/x^2) if x in [1, 2)
- // erfc(x) = exp(z) 1/x R(1/x^2) if x >= 2
- Value absX = rewriter.create<mlir::stablehlo::AbsOp>(loc, x);
- Value one = getConstantLike(rewriter, loc, 1.0, x);
- Value reciprocalXSq = rewriter.create<mlir::stablehlo::DivOp>(loc, one, xSq);
- Value expZ = rewriter.create<mlir::stablehlo::ExpOp>(loc, z);
- Value oneDivAbsX = rewriter.create<mlir::stablehlo::DivOp>(loc, one, absX);
- Value expZMulOneDivAbsX =
- rewriter.create<mlir::stablehlo::MulOp>(loc, expZ, oneDivAbsX);
- Value two = getConstantLike(rewriter, loc, 2.0, x);
- Value absXLtTwo = rewriter.create<mlir::stablehlo::CompareOp>(
- loc, absX, two, mlir::stablehlo::ComparisonDirection::LT);
- Value polP = materializePolynomialApproximation(
- rewriter, loc, reciprocalXSq, llvm::ArrayRef(kErfcPCoefficients));
- Value polR = materializePolynomialApproximation(
- rewriter, loc, reciprocalXSq, llvm::ArrayRef(kErfcRCoefficients));
- Value poly =
- rewriter.create<mlir::stablehlo::SelectOp>(loc, absXLtTwo, polP, polR);
- Value erfcApprox =
- rewriter.create<mlir::stablehlo::MulOp>(loc, expZMulOneDivAbsX, poly);
-
- // Clamp to prevent overflow and materialize approximation for large x as
- // erfc(x) = 0.
- Value zLtNeqMaxlog = rewriter.create<mlir::stablehlo::CompareOp>(
- loc, z, getConstantLike(rewriter, loc, -kMaxlog, x),
- mlir::stablehlo::ComparisonDirection::LT);
- Value zero = getConstantLike(rewriter, loc, 0.0, x);
- Value erfcApproxClamped = rewriter.create<mlir::stablehlo::SelectOp>(
- loc, zLtNeqMaxlog, zero, erfcApprox);
-
- // Derive approximation for x <= -1 as
- // erfc(x) = 2 - erfc(-x).
- // Reuse previously materialized approximations all of which take |x| as their
- // argument.
- Value xLtZero = rewriter.create<mlir::stablehlo::CompareOp>(
- loc, x, zero, mlir::stablehlo::ComparisonDirection::LT);
- Value twoSubErfcApprox =
- rewriter.create<mlir::stablehlo::SubtractOp>(loc, two, erfcApproxClamped);
- return rewriter.create<mlir::stablehlo::SelectOp>(
- loc, xLtZero, twoSubErfcApprox, erfcApproxClamped);
-}
-
-// Precondition is |x| <= 1. Use erfc approximation, otherwise.
-// This implementation is based on Cephes.
-static Value materializeErfApproximationF32ForMagnitudeLeOne(
- ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
- Value x = args.front();
- assert(cast<ShapedType>(x.getType()).getElementType().isF32() &&
- "expect f32 element type");
- const float kErfTCoefficients[] = {
- +7.853861353153693E-5f, -8.010193625184903E-4f, +5.188327685732524E-3f,
- -2.685381193529856E-2f, +1.128358514861418E-1f, -3.761262582423300E-1f,
- +1.128379165726710E+0f,
- };
-
- // Materialize polynomial approximation for |x| <= 1 as
- // erf(x) = x T(x^2).
- Value xSq = rewriter.create<mlir::stablehlo::MulOp>(loc, x, x);
- Value polyT = materializePolynomialApproximation(
- rewriter, loc, xSq, llvm::ArrayRef(kErfTCoefficients));
- return rewriter.create<mlir::stablehlo::MulOp>(loc, x, polyT);
-}
-
-// This is the same approximation as used in Eigen.
-static Value materializeErfApproximationF32(ConversionPatternRewriter &rewriter,
- Location loc, ValueRange args) {
- Value x = args.front();
- assert(cast<ShapedType>(x.getType()).getElementType().isF32() &&
- "expect f32 element type");
- const float kAlpha[] = {
- -2.72614225801306e-10f, 2.77068142495902e-08f, -2.10102402082508e-06f,
- -5.69250639462346e-05f, -7.34990630326855e-04f, -2.95459980854025e-03f,
- -1.60960333262415e-02f,
- };
- const float kBeta[] = {
- -1.45660718464996e-05f, -2.13374055278905e-04f, -1.68282697438203e-03f,
- -7.37332916720468e-03f, -1.42647390514189e-02f,
- };
-
- // Clamp argument between -4 and 4.
- Value lb = getConstantLike(rewriter, loc, -4.0, x);
- Value ub = getConstantLike(rewriter, loc, 4.0, x);
- x = rewriter.create<mlir::stablehlo::ClampOp>(loc, x.getType(), lb, x, ub);
- Value xSq = rewriter.create<mlir::stablehlo::MulOp>(loc, x, x);
-
- // Materialize polynomial approximation for x in [-4, 4] as
- // erf(x) = x * Alpha(x^2) / Beta(x^2).
- Value alphaPoly = materializePolynomialApproximation(rewriter, loc, xSq,
- llvm::ArrayRef(kAlpha));
- Value betaPoly = materializePolynomialApproximation(rewriter, loc, xSq,
- llvm::ArrayRef(kBeta));
- Value xMulAlphaPoly =
- rewriter.create<mlir::stablehlo::MulOp>(loc, x, alphaPoly);
- Value erf =
- rewriter.create<mlir::stablehlo::DivOp>(loc, xMulAlphaPoly, betaPoly);
- Value lbErf = getConstantLike(rewriter, loc, -1.0, x);
- Value ubErf = getConstantLike(rewriter, loc, 1.0, x);
- return rewriter.create<mlir::stablehlo::ClampOp>(loc, erf.getType(), lbErf,
- erf, ubErf);
-}
-
-static Value
-materializeErfcApproximationF32(ConversionPatternRewriter &rewriter,
- Location loc, ValueRange args) {
- Value x = args.front();
- assert(cast<ShapedType>(x.getType()).getElementType().isF32() &&
- "expect f32 element type");
-
- // Rely on erfc approximation for |x| >= 1
- // erfc(x) = erfc_approx(x)
- Value erfcApprox =
- materializeErfcApproximationF32ForMagnitudeGeOne(rewriter, loc, x);
-
- // Rely on erf approximation for |x| < 1 and materialize erfc as
- // erfc(x) = 1 - erf_approx(x)
- Value one = getConstantLike(rewriter, loc, 1.0, x);
- Value erfApprox =
- materializeErfApproximationF32ForMagnitudeLeOne(rewriter, loc, x);
- Value erfBasedApprox =
- rewriter.create<mlir::stablehlo::SubtractOp>(loc, one, erfApprox);
-
- // Materialize approximation selection based on argument.
- Value absX = rewriter.create<mlir::stablehlo::AbsOp>(loc, x);
- Value absXLtOne = rewriter.create<mlir::stablehlo::CompareOp>(
- loc, absX, one, mlir::stablehlo::ComparisonDirection::LT);
- return rewriter.create<mlir::stablehlo::SelectOp>(loc, absXLtOne,
- erfBasedApprox, erfcApprox);
-}
-
-struct ConvertErfOp final : OpConversionPattern<mlir::chlo::ErfOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(mlir::chlo::ErfOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- Location loc = op.getLoc();
- Value x = adaptor.getOperand();
- Type ty = cast<ShapedType>(x.getType()).getElementType();
-
- // For now, we support only f64, f32, f16 and bf16.
- if (!ty.isF64() && !ty.isF32() && !ty.isF16() && !ty.isBF16()) {
- return failure();
- }
-
- if (ty.isF64()) {
- rewriter.replaceOp(op, materializeErfApproximationF64(rewriter, loc, x));
- return success();
- }
-
- rewriter.replaceOp(
- op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(),
- rewriter.getF32Type(),
- &materializeErfApproximationF32));
- return success();
- }
-};
-
-struct ConvertErfcOp final : OpConversionPattern<mlir::chlo::ErfcOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(mlir::chlo::ErfcOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- Location loc = op.getLoc();
- Value x = adaptor.getOperand();
- Type ty = cast<ShapedType>(x.getType()).getElementType();
-
- // For now, we support only f64, f32, f16 and bf16.
- if (!ty.isF64() && !ty.isF32() && !ty.isF16() && !ty.isBF16()) {
- return failure();
- }
-
- if (ty.isF64()) {
- rewriter.replaceOp(op, materializeErfcApproximationF64(rewriter, loc, x));
- return success();
- }
-
- rewriter.replaceOp(
- op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(),
- rewriter.getF32Type(),
- &materializeErfcApproximationF32));
- return success();
- }
-};
-
-static Value erfInv32(ConversionPatternRewriter &b, Location loc,
- ValueRange args) {
- constexpr int kDegree = 9;
- constexpr std::array<float, 9> wLessThan5Constants = {
- 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f,
- -4.39150654e-06f, 0.00021858087f, -0.00125372503f,
- -0.00417768164f, 0.246640727f, 1.50140941f};
- constexpr std::array<float, 9> wGreaterThan5Constants = {
- -0.000200214257f, 0.000100950558f, 0.00134934322f,
- -0.00367342844f, 0.00573950773f, -0.0076224613f,
- 0.00943887047f, 1.00167406f, 2.83297682f};
-
- Value x = args[0];
- // Compute logarithm of (1+arg) using log1p(arg) which is more precise than
- // log(1+arg) when arg is close to zero. For more details, see
- // https://en.cppreference.com/w/cpp/numeric/math/log1p
- Value minusXSquared = b.create<mlir::stablehlo::MulOp>(
- loc, x, b.create<mlir::stablehlo::NegOp>(loc, x));
- Value w = b.create<mlir::stablehlo::NegOp>(
- loc, b.create<mlir::stablehlo::Log1pOp>(loc, minusXSquared));
-
- Value lt = b.create<mlir::stablehlo::CompareOp>(
- loc, w, getConstantLike(b, loc, 5.0, x),
- mlir::stablehlo::ComparisonDirection::LT);
- auto coefficient = [&](int i) {
- return b.create<mlir::stablehlo::SelectOp>(
- loc, lt, getConstantLike(b, loc, wLessThan5Constants[i], x),
- getConstantLike(b, loc, wGreaterThan5Constants[i], x));
- };
- w = b.create<mlir::stablehlo::SelectOp>(
- loc, lt,
- b.create<mlir::stablehlo::SubtractOp>(loc, w,
- getConstantLike(b, loc, 2.5, x)),
- b.create<mlir::stablehlo::SubtractOp>(
- loc, b.create<mlir::stablehlo::SqrtOp>(loc, w),
- getConstantLike(b, loc, 3.0, x)));
- Value p = coefficient(0);
- for (int i = 1; i < kDegree; ++i) {
- p = b.create<mlir::stablehlo::AddOp>(
- loc, coefficient(i), b.create<mlir::stablehlo::MulOp>(loc, p, w));
- }
-
- // Result modulo edge cases.
- Value result = b.create<mlir::stablehlo::MulOp>(loc, p, x);
-
- // Handle edge cases, namely erfinv(+/-1) = +/-inf. (The above computation is
- // indeterminate, and can give nan or -/+inf.)
- return b.create<mlir::stablehlo::SelectOp>(
- loc,
- b.create<mlir::stablehlo::CompareOp>(
- loc, b.create<mlir::stablehlo::AbsOp>(loc, x),
- getConstantLike(b, loc, 1, x),
- mlir::stablehlo::ComparisonDirection::EQ),
- b.create<mlir::stablehlo::MulOp>(
- loc, x, getConstantLikeInfValue(b, loc, x, false)),
- result);
-}
-
-static Value erfInv64(ConversionPatternRewriter &b, Location loc,
- ValueRange args) {
- constexpr std::array<double, 23> wLessThan625Constants = {
- -3.6444120640178196996e-21, -1.685059138182016589e-19,
- 1.2858480715256400167e-18, 1.115787767802518096e-17,
- -1.333171662854620906e-16, 2.0972767875968561637e-17,
- 6.6376381343583238325e-15, -4.0545662729752068639e-14,
- -8.1519341976054721522e-14, 2.6335093153082322977e-12,
- -1.2975133253453532498e-11, -5.4154120542946279317e-11,
- 1.051212273321532285e-09, -4.1126339803469836976e-09,
- -2.9070369957882005086e-08, 4.2347877827932403518e-07,
- -1.3654692000834678645e-06, -1.3882523362786468719e-05,
- 0.0001867342080340571352, -0.00074070253416626697512,
- -0.0060336708714301490533, 0.24015818242558961693,
- 1.6536545626831027356};
- constexpr std::array<double, 19> wLessThan16Constants = {
- 2.2137376921775787049e-09, 9.0756561938885390979e-08,
- -2.7517406297064545428e-07, 1.8239629214389227755e-08,
- 1.5027403968909827627e-06, -4.013867526981545969e-06,
- 2.9234449089955446044e-06, 1.2475304481671778723e-05,
- -4.7318229009055733981e-05, 6.8284851459573175448e-05,
- 2.4031110387097893999e-05, -0.0003550375203628474796,
- 0.00095328937973738049703, -0.0016882755560235047313,
- 0.0024914420961078508066, -0.0037512085075692412107,
- 0.005370914553590063617, 1.0052589676941592334,
- 3.0838856104922207635,
- };
- constexpr std::array<double, 17> wGreaterThan16Constants = {
- -2.7109920616438573243e-11, -2.5556418169965252055e-10,
- 1.5076572693500548083e-09, -3.7894654401267369937e-09,
- 7.6157012080783393804e-09, -1.4960026627149240478e-08,
- 2.9147953450901080826e-08, -6.7711997758452339498e-08,
- 2.2900482228026654717e-07, -9.9298272942317002539e-07,
- 4.5260625972231537039e-06, -1.9681778105531670567e-05,
- 7.5995277030017761139e-05, -0.00021503011930044477347,
- -0.00013871931833623122026, 1.0103004648645343977,
- 4.8499064014085844221,
- };
-
- Value x = args[0];
- // Compute logarithm of (1+arg) using log1p(arg) which is more precise than
- // log(1+arg) when arg is close to zero. For more details, see
- // https://en.cppreference.com/w/cpp/numeric/math/log1p
- Value minusXSquared = b.create<mlir::stablehlo::MulOp>(
- loc, x, b.create<mlir::stablehlo::NegOp>(loc, x));
- Value w = b.create<mlir::stablehlo::NegOp>(
- loc, b.create<mlir::stablehlo::Log1pOp>(loc, minusXSquared));
-
- Value lt625 = b.create<mlir::stablehlo::CompareOp>(
- loc, w, getConstantLike(b, loc, 6.25, x),
- mlir::stablehlo::ComparisonDirection::LT);
- Value lt16 = b.create<mlir::stablehlo::CompareOp>(
- loc, w, getConstantLike(b, loc, 16, x),
- mlir::stablehlo::ComparisonDirection::LT);
-
- auto coefficient = [&](int i) {
- Value c = getConstantLike(b, loc, wLessThan625Constants[i], x);
- if (i < 19) {
- c = b.create<mlir::stablehlo::SelectOp>(
- loc, lt625, c, getConstantLike(b, loc, wLessThan16Constants[i], x));
- }
- if (i < 17) {
- c = b.create<mlir::stablehlo::SelectOp>(
- loc, lt16, c, getConstantLike(b, loc, wGreaterThan16Constants[i], x));
- }
- return c;
- };
-
- Value sqrtW = b.create<mlir::stablehlo::SqrtOp>(loc, w);
- Value wMinus3125 = b.create<mlir::stablehlo::SubtractOp>(
- loc, w, getConstantLike(b, loc, 3.125, x));
- Value select2 = b.create<mlir::stablehlo::SelectOp>(
- loc, lt16, getConstantLike(b, loc, 3.25, w),
- getConstantLike(b, loc, 5.0, w));
- Value select2Result =
- b.create<mlir::stablehlo::SubtractOp>(loc, sqrtW, select2);
- w = b.create<mlir::stablehlo::SelectOp>(loc, lt625, wMinus3125,
- select2Result);
-
- Value p = coefficient(0);
- for (int i = 1; i < 17; ++i) {
- p = b.create<mlir::stablehlo::AddOp>(
- loc, coefficient(i), b.create<mlir::stablehlo::MulOp>(loc, p, w));
- }
- for (int i = 17; i < 19; ++i) {
- p = b.create<mlir::stablehlo::SelectOp>(
- loc, lt16,
- b.create<mlir::stablehlo::AddOp>(
- loc, coefficient(i), b.create<mlir::stablehlo::MulOp>(loc, p, w)),
- p);
- }
- for (int i = 19; i < 23; ++i) {
- p = b.create<mlir::stablehlo::SelectOp>(
- loc, lt625,
- b.create<mlir::stablehlo::AddOp>(
- loc, coefficient(i), b.create<mlir::stablehlo::MulOp>(loc, p, w)),
- p);
- }
-
- // Result modulo edge cases.
- Value result = b.create<mlir::stablehlo::MulOp>(loc, p, x);
-
- // Handle edge cases, namely erfinv(+/-1) = +/-inf. (The above computation is
- // indeterminate, and can give nan or -/+inf.)
- return b.create<mlir::stablehlo::SelectOp>(
- loc,
- b.create<mlir::stablehlo::CompareOp>(
- loc, b.create<mlir::stablehlo::AbsOp>(loc, x),
- getConstantLike(b, loc, 1, x),
- mlir::stablehlo::ComparisonDirection::EQ),
- b.create<mlir::stablehlo::MulOp>(
- loc, x, getConstantLikeInfValue(b, loc, x, false)),
- result);
-}
-
-struct ConvertErfInvOp final : OpConversionPattern<mlir::chlo::ErfInvOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(mlir::chlo::ErfInvOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- Location loc = op.getLoc();
- if (op.getResult().getType().getElementType().isF64()) {
- rewriter.replaceOp(op, erfInv64(rewriter, loc, adaptor.getOperands()));
- return success();
- }
- FloatType minPrecisionTy = rewriter.getF32Type();
- rewriter.replaceOp(op, materializeWithUpcast(rewriter, loc,
- adaptor.getOperands(),
- minPrecisionTy, &erfInv32));
- return success();
- }
-};
-
// Coefficients for the Lanczos approximation of the gamma function. The
// coefficients are uniquely determined by the choice of g and n (kLanczosGamma
// and kLanczosCoefficients.size() + 1). The coefficients below correspond to
// [7, 9]. [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were evaluated and
// [7, 9] seemed to be the least sensitive to the quality of the log function.
// In particular, [5, 7] is the only choice where -1.5e-5 <= lgamma(2) <= 1.5e-5
// for a particularly inaccurate log function.
constexpr double kLanczosGamma = 7; // aka g
// Leading (k = 0) term of the Lanczos series a(z).
constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478;
// Series coefficients for k = 1..8 of a(z); consumed by the lgamma and
// digamma materializers below. Do not reformat or round these values.
constexpr std::array<double, 8> kLanczosCoefficients = {
    676.520368121885098567009190444019, -1259.13921672240287047156078755283,
    771.3234287776530788486528258894, -176.61502916214059906584551354,
    12.507343278686904814458936853, -0.13857109526572011689554707,
    9.984369578019570859563e-6, 1.50563273514931155834e-7};
-
// Compute the Lgamma function using Lanczos' approximation from "A Precision
// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
// series B. Vol. 1:
//   lgamma(z + 1) = (log(2) + log(pi)) / 2
//                   + (z + 1/2) * log(t(z))
//                   - t(z) + log(a(z))
// with t(z) = z + kLanczosGamma + 1/2
//      a(z) = kBaseLanczosCoeff
//             + sum(k = 1, n, kLanczosCoefficients[i] / (z + k))
// Emits a StableHLO op graph computing lgamma elementwise for `args[0]`.
static Value materializeLgamma(ConversionPatternRewriter &rewriter,
                               Location loc, ValueRange args) {
  // If the input is less than 0.5 use Euler's reflection formula.
  //   gamma(x) = pi / (sin(pi * x) * gamma(1 - x))
  // Let z be
  //   z = -x     if x < 1/2
  //   z = x - 1  otherwise
  Value x = args.front();
  Value half = getConstantLike(rewriter, loc, 0.5, x);
  Value needToReflect = rewriter.create<mlir::stablehlo::CompareOp>(
      loc, x, half, mlir::stablehlo::ComparisonDirection::LT);
  Value negX = rewriter.create<mlir::stablehlo::NegOp>(loc, x);
  Value one = getConstantLike(rewriter, loc, 1, x);
  Value xSubOne = rewriter.create<mlir::stablehlo::SubtractOp>(loc, x, one);
  Value z = rewriter.create<mlir::stablehlo::SelectOp>(loc, needToReflect, negX,
                                                       xSubOne);

  // Materialize
  //   a(z) = kBaseLanczosCoeff
  //          + sum(k = 1, n, kLanczosCoefficients[i] / (z + k))
  Value a = getConstantLike(rewriter, loc, kBaseLanczosCoeff, x);
  for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
    Value coeff = getConstantLike(rewriter, loc, kLanczosCoefficients[i], x);
    Value oneBasedIndex = getConstantLike(rewriter, loc, i + 1, x);
    Value quotient = rewriter.create<mlir::stablehlo::DivOp>(
        loc, coeff,
        rewriter.create<mlir::stablehlo::AddOp>(loc, z, oneBasedIndex));
    a = rewriter.create<mlir::stablehlo::AddOp>(loc, a, quotient);
  }

  // To improve accuracy on platforms with less-precise log implementations,
  // compute log(kLanczosGamma + 1/2) at compile time and use log1p on the
  // device.
  // Materialize as
  //   log(t) = log(kLanczosGamma + 1/2 + z)
  //          = log(kLanczosGamma + 1/2) + log1p(z / (kLanczosGamma + 1/2)).
  Value lanczosPlusHalf =
      getConstantLike(rewriter, loc, kLanczosGamma + 0.5, x);
  Value t = rewriter.create<mlir::stablehlo::AddOp>(loc, lanczosPlusHalf, z);
  Value logTerm =
      getConstantLike(rewriter, loc, std::log(kLanczosGamma + 0.5), x);
  Value log1pTerm = rewriter.create<mlir::stablehlo::Log1pOp>(
      loc, rewriter.create<mlir::stablehlo::DivOp>(loc, z, lanczosPlusHalf));
  Value logT = rewriter.create<mlir::stablehlo::AddOp>(loc, logTerm, log1pTerm);

  // Note that t(z) may be large and we need to be careful not to overflow to
  // infinity in the relevant term
  //   r = (z + 1/2) * log(t(z)) - t(z).
  // Therefore, we compute this as
  //   r = (z + 1/2 - t(z) / log(t(z))) * log(t(z)).
  Value tDivLogT = rewriter.create<mlir::stablehlo::DivOp>(loc, t, logT);
  Value sum = rewriter.create<mlir::stablehlo::SubtractOp>(
      loc, rewriter.create<mlir::stablehlo::AddOp>(loc, z, half), tDivLogT);
  Value r = rewriter.create<mlir::stablehlo::MulOp>(loc, sum, logT);

  // Compute the final result (modulo reflection) as
  //   lgamma(z + 1) = (log(2) + log(pi)) / 2 + r + log(a(z)).
  Value logA = rewriter.create<mlir::stablehlo::LogOp>(loc, a);
  Value lgamma = rewriter.create<mlir::stablehlo::AddOp>(
      loc,
      rewriter.create<mlir::stablehlo::AddOp>(
          loc,
          getConstantLike(rewriter, loc, (std::log(2) + std::log(M_PI)) / 2, x),
          r),
      logA);

  // Compute the reflected value for x < 0.5 as
  //   lgamma(x) = log(pi) - lgamma(1-x) - log(abs(sin(pi * x))).
  //
  // The abs is needed because lgamma is the log of the absolute value of the
  // gamma function.
  //
  // We have to be careful when computing the final term above. gamma(x) goes
  // to +/-inf at every integer x < 0, and this is controlled by the sin(pi * x)
  // term. The slope is large, so precision is particularly important.
  //
  // Because abs(sin(pi * x)) has period of 1 we can equivalently use
  // abs(sin(pi * frac(x))) where frac(x) is the fractional part of x. This is
  // more numerically accurate: It doesn't overflow to inf like pi * x would and
  // if x is an integer it evaluates to exactly 0 which is important because we
  // then take the log of this value, and log(0) is inf.
  //
  // We don't have a frac(x) primitive in HLO and computing it is tricky, but
  // because abs(sin(pi * x)) = abs(sin(pi * abs(x))), it's good enough for our
  // purposes to use abs(frac(x)) = abs(x) - floor(abs(x)).
  //
  // Furthermore, pi * abs(frac(x)) loses precision when abs(frac(x)) is close
  // to 1. To remedy this, we can use the fact that sin(pi * x) in the domain
  // [0, 1] is symmetric across the line Y=0.5.
  //

  // Convert values of abs_frac > 0.5 to (1 - abs_frac) to improve precision of
  // pi * abs_frac for values of abs_frac close to 1.
  Value abs = rewriter.create<mlir::stablehlo::AbsOp>(loc, x);
  Value absFrac = rewriter.create<mlir::stablehlo::SubtractOp>(
      loc, abs, rewriter.create<mlir::stablehlo::FloorOp>(loc, abs));
  Value reduceAbsFrac = rewriter.create<mlir::stablehlo::CompareOp>(
      loc, half, absFrac, mlir::stablehlo::ComparisonDirection::LT);
  absFrac = rewriter.create<mlir::stablehlo::SelectOp>(
      loc, reduceAbsFrac,
      rewriter.create<mlir::stablehlo::SubtractOp>(loc, one, absFrac), absFrac);

  // Materialize reflection.
  Value reflectionDenom = rewriter.create<mlir::stablehlo::LogOp>(
      loc,
      rewriter.create<mlir::stablehlo::SineOp>(
          loc, rewriter.create<mlir::stablehlo::MulOp>(
                   loc, getConstantLike(rewriter, loc, M_PI, x), absFrac)));
  Value lgammaReflection = rewriter.create<mlir::stablehlo::SubtractOp>(
      loc,
      rewriter.create<mlir::stablehlo::SubtractOp>(
          loc, getConstantLike(rewriter, loc, std::log(M_PI), x),
          reflectionDenom),
      lgamma);

  // Avoid computing -inf - inf, which is nan. If reflection_denom is +/-inf,
  // then it "wins" and the result is +/-inf.
  Value finiteReflectionDenom =
      rewriter.create<mlir::stablehlo::IsFiniteOp>(loc, reflectionDenom);
  Value negReflectionDenom =
      rewriter.create<mlir::stablehlo::NegOp>(loc, reflectionDenom);
  lgammaReflection = rewriter.create<mlir::stablehlo::SelectOp>(
      loc, finiteReflectionDenom, lgammaReflection, negReflectionDenom);

  // Select whether or not to rely on the reflection.
  lgamma = rewriter.create<mlir::stablehlo::SelectOp>(loc, needToReflect,
                                                      lgammaReflection, lgamma);

  // Materialize +/-inf behavior as
  //   lgamma(+/-inf) = +inf.
  Value xIsInf = rewriter.create<chlo::IsInfOp>(loc, x);
  return rewriter.create<mlir::stablehlo::SelectOp>(
      loc, xIsInf,
      getConstantLikeInfValue(rewriter, loc, x, /*negative=*/false), lgamma);
}
-
-// Express `cosh` as
-// cosh(x) = (e^x + e^-x) / 2
-// = e^(x + log(1/2)) + e^(-x + log(1/2))
-//
-// The second formulation avoids overflowing when e^x = inf but (e^x)/2 is not.
-//
-// This incorrectly overflows to inf for two f32 input values, namely
-// +/-89.4159851, due to rounding error when computing x +/- log(1/2). The
-// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
-// we deem this acceptable.
-static Value materializeCoshApproximation(ConversionPatternRewriter &rewriter,
- Location loc, ValueRange operands) {
- mlir::chlo::CoshOp::Adaptor transformed(operands);
- Value x = transformed.getOperand();
-
- Value logOneHalf = rewriter.create<mlir::stablehlo::LogOp>(
- loc, getConstantLike(rewriter, loc, 0.5, x));
- Value expAdd = rewriter.create<mlir::stablehlo::ExpOp>(
- loc, rewriter.create<mlir::stablehlo::AddOp>(loc, x, logOneHalf));
- Value expSub = rewriter.create<mlir::stablehlo::ExpOp>(
- loc, rewriter.create<mlir::stablehlo::SubtractOp>(loc, logOneHalf, x));
- return rewriter.create<mlir::stablehlo::AddOp>(loc, expAdd, expSub);
-}
-
-struct ConvertCoshOp final : OpConversionPattern<mlir::chlo::CoshOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(mlir::chlo::CoshOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOp(
- op, materializeWithUpcast(rewriter, op.getLoc(), adaptor.getOperands(),
- rewriter.getF32Type(),
- &materializeCoshApproximation));
- return success();
- }
-};
-
// Compute the Digamma function using Lanczos' approximation from "A Precision
// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
// series B. Vol. 1:
//   digamma(z + 1) = log(t(z)) + a'(z) / a(z) - kLanczosGamma / t(z)
// with t(z) = z + kLanczosGamma + 1/2
//      a(z) = kBaseLanczosCoeff
//             + sum(k = 1, n, kLanczosCoefficients[i] / (z + k))
//      a'(z) = - sum(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k))
// Emits a StableHLO op graph computing digamma elementwise for `args[0]`.
static Value materializeDigamma(ConversionPatternRewriter &rewriter,
                                Location loc, ValueRange args) {
  // If the input is less than 0.5 use Euler's reflection formula.
  //   digamma(x) = digamma(1 - x) - pi * cot(pi * x)
  // Let z be
  //   z = -x     if x < 1/2
  //   z = x - 1  otherwise
  Value x = args.front();
  Value half = getConstantLike(rewriter, loc, 0.5, x);
  Value needToReflect = rewriter.create<mlir::stablehlo::CompareOp>(
      loc, x, half, mlir::stablehlo::ComparisonDirection::LT);
  Value negX = rewriter.create<mlir::stablehlo::NegOp>(loc, x);
  Value one = getConstantLike(rewriter, loc, 1, x);
  Value xSubOne = rewriter.create<mlir::stablehlo::SubtractOp>(loc, x, one);
  Value z = rewriter.create<mlir::stablehlo::SelectOp>(loc, needToReflect, negX,
                                                       xSubOne);

  // Materialize
  //   a(z) = kBaseLanczosCoeff
  //          + sum(k = 1, n, kLanczosCoefficients[i] / (z + k))
  //   a'(z) = - sum(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k))
  Value zero = getConstantLike(rewriter, loc, 0.0, x);
  Value a = getConstantLike(rewriter, loc, kBaseLanczosCoeff, x);
  Value aPrime = zero;
  for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
    Value coeff = getConstantLike(rewriter, loc, kLanczosCoefficients[i], x);
    Value oneBasedIndex = getConstantLike(rewriter, loc, i + 1, x);
    Value zTerm =
        rewriter.create<mlir::stablehlo::AddOp>(loc, z, oneBasedIndex);
    aPrime = rewriter.create<mlir::stablehlo::SubtractOp>(
        loc, aPrime,
        rewriter.create<mlir::stablehlo::DivOp>(
            loc, coeff,
            rewriter.create<mlir::stablehlo::MulOp>(loc, zTerm, zTerm)));
    a = rewriter.create<mlir::stablehlo::AddOp>(
        loc, a, rewriter.create<mlir::stablehlo::DivOp>(loc, coeff, zTerm));
  }

  // To improve accuracy on platforms with less-precise log implementations,
  // compute log(kLanczosGamma + 1/2) at compile time and use log1p on the
  // device.
  // Materialize as
  //   log(t) = log(kLanczosGamma + 1/2 + z)
  //          = log(kLanczosGamma + 1/2) + log1p(z / (kLanczosGamma + 1/2)).
  Value lanczosPlusHalf =
      getConstantLike(rewriter, loc, kLanczosGamma + 0.5, x);
  Value t = rewriter.create<mlir::stablehlo::AddOp>(loc, lanczosPlusHalf, z);
  Value logTerm =
      getConstantLike(rewriter, loc, std::log(kLanczosGamma + 0.5), x);
  Value log1pTerm = rewriter.create<mlir::stablehlo::Log1pOp>(
      loc, rewriter.create<mlir::stablehlo::DivOp>(loc, z, lanczosPlusHalf));
  Value logT = rewriter.create<mlir::stablehlo::AddOp>(loc, logTerm, log1pTerm);

  // Materialize the final result (modulo reflection) as
  //   digamma(z + 1) = log(t(z)) + a'(z) / a(z) - kLanczosGamma / t(z).
  Value aPrimeDivA = rewriter.create<mlir::stablehlo::DivOp>(loc, aPrime, a);
  Value lanczosGammaDivT = rewriter.create<mlir::stablehlo::DivOp>(
      loc, getConstantLike(rewriter, loc, kLanczosGamma, x), t);
  Value digamma = rewriter.create<mlir::stablehlo::SubtractOp>(
      loc, rewriter.create<mlir::stablehlo::AddOp>(loc, logT, aPrimeDivA),
      lanczosGammaDivT);

  // We need to be careful how we compute cot(pi * input) below: For
  // near-integral arguments, pi * input can lose precision.
  //
  // Input is already known to be less than 0.5 (otherwise we don't have to
  // reflect). We shift values smaller than -0.5 into the range [-0.5, 0.5] to
  // increase precision of pi * x and the resulting cotangent.
  Value reducedX = rewriter.create<mlir::stablehlo::AddOp>(
      loc, x,
      rewriter.create<mlir::stablehlo::AbsOp>(
          loc, rewriter.create<mlir::stablehlo::FloorOp>(
                   loc, rewriter.create<mlir::stablehlo::AddOp>(
                            loc, x, getConstantLike(rewriter, loc, 0.5, x)))));

  // Materialize reflection for inputs less than 0.5 as
  //   digamma(x) = digamma(1 - x) - pi * cot(pi * x)
  //              = digamma(1 - x) - pi * cos(pi * x) / sin(pi * x)
  Value pi = getConstantLike(rewriter, loc, M_PI, x);
  Value piMulReducedX =
      rewriter.create<mlir::stablehlo::MulOp>(loc, pi, reducedX);
  Value cos = rewriter.create<mlir::stablehlo::CosineOp>(loc, piMulReducedX);
  Value sin = rewriter.create<mlir::stablehlo::SineOp>(loc, piMulReducedX);
  Value reflection = rewriter.create<mlir::stablehlo::SubtractOp>(
      loc, digamma,
      rewriter.create<mlir::stablehlo::DivOp>(
          loc, rewriter.create<mlir::stablehlo::MulOp>(loc, pi, cos), sin));

  // Select whether or not to rely on the reflection.
  digamma = rewriter.create<mlir::stablehlo::SelectOp>(loc, needToReflect,
                                                       reflection, digamma);

  // Digamma has poles at negative integers and zero; return nan for those.
  Value isLeZero = rewriter.create<mlir::stablehlo::CompareOp>(
      loc, x, zero, mlir::stablehlo::ComparisonDirection::LE);
  Value isInt = rewriter.create<mlir::stablehlo::CompareOp>(
      loc, x, rewriter.create<mlir::stablehlo::FloorOp>(loc, x),
      mlir::stablehlo::ComparisonDirection::EQ);
  Value isPole = rewriter.create<mlir::stablehlo::AndOp>(loc, isLeZero, isInt);
  return rewriter.create<mlir::stablehlo::SelectOp>(
      loc, isPole,
      getConstantLike(rewriter, loc, std::numeric_limits<double>::quiet_NaN(),
                      x),
      digamma);
}
-
-static Value getConstantLikeSmallestFiniteValue(OpBuilder &b, Location loc,
- Value val) {
- auto ty = cast<FloatType>(getElementTypeOrSelf(val.getType()));
- return getConstantLike(
- b, loc, llvm::APFloat::getSmallest(ty.getFloatSemantics()), val);
-}
-
// Materializes the Hurwitz zeta function zeta(x, q) as a StableHLO op graph:
// a 9-term direct series estimate plus a 12-term Euler-Maclaurin correction,
// followed by edge-case handling for the domain boundaries.
static Value materializeZeta(ConversionPatternRewriter &rewriter, Location loc,
                             ValueRange args) {
  // Code should match StableHLO's materializeZeta
  assert(args.size() == 2);
  Value x = args[0];
  Value q = args[1];
  // Euler-Maclaurin coefficients; kZetaCoeffs[i] = (2i+2)! / B_{2i+2} reversed
  // so they can be consumed via 1/kZetaCoeffs[i] in the Horner loop below.
  static const std::array<double, 12> kZetaCoeffs{
      -7.1661652561756670113e18,
      1.8152105401943546773e17,
      -4.5979787224074726105e15,
      1.1646782814350067249e14,
      -2.950130727918164224e12,
      7.47242496e10,
      -1.8924375803183791606e9,
      47900160.0,
      -1209600.0,
      30240.0,
      -720.0,
      12.0,
  };

  // For speed we'll always use 9 iterations for the initial series estimate,
  // and a 12 term expansion for the Euler-Maclaurin formula.
  Value a = q;
  Value zero = getConstantLike(rewriter, loc, 0.0, a);
  Value negPower = zero;
  Value negX = rewriter.create<mlir::stablehlo::NegOp>(loc, x);
  Value initialSum = rewriter.create<mlir::stablehlo::PowOp>(loc, q, negX);
  Value one = getConstantLike(rewriter, loc, 1.0, a);
  for (int i = 0; i < 9; ++i) {
    a = rewriter.create<mlir::stablehlo::AddOp>(loc, a, one);
    negPower = rewriter.create<mlir::stablehlo::PowOp>(loc, a, negX);
    initialSum =
        rewriter.create<mlir::stablehlo::AddOp>(loc, initialSum, negPower);
  }

  // Tail term of Euler-Maclaurin: a^(1-x) / (x - 1).
  a = rewriter.create<mlir::stablehlo::AddOp>(loc, a, one);
  negPower = rewriter.create<mlir::stablehlo::PowOp>(loc, a, negX);
  Value oneLikeX = getConstantLike(rewriter, loc, 1.0, x);
  Value xMinusOne =
      rewriter.create<mlir::stablehlo::SubtractOp>(loc, x, oneLikeX);
  Value negPowerMulA =
      rewriter.create<mlir::stablehlo::MulOp>(loc, negPower, a);
  Value negPowerMulADivXMinusOne =
      rewriter.create<mlir::stablehlo::DivOp>(loc, negPowerMulA, xMinusOne);
  Value s = rewriter.create<mlir::stablehlo::AddOp>(loc, initialSum,
                                                    negPowerMulADivXMinusOne);
  Value aInverseSquare = rewriter.create<mlir::stablehlo::DivOp>(
      loc, one, rewriter.create<mlir::stablehlo::MulOp>(loc, a, a));

  Value hornerSum = zero;
  Value factor = one;
  // Use Horner's rule for this.
  // Note this differs from Cephes which does a 'naive' polynomial evaluation.
  // Using Horner's rule allows to avoid some NaN's and Infs from happening,
  // resulting in more numerically stable code.
  for (int i = 0; i < 11; ++i) {
    Value factorLhs = rewriter.create<mlir::stablehlo::AddOp>(
        loc, x, getConstantLike(rewriter, loc, 22 - 2 * i, x));
    Value factorRhs = rewriter.create<mlir::stablehlo::AddOp>(
        loc, x, getConstantLike(rewriter, loc, 21 - 2 * i, x));
    factor = rewriter.create<mlir::stablehlo::MulOp>(loc, factorLhs, factorRhs);
    hornerSum = rewriter.create<mlir::stablehlo::MulOp>(
        loc, factor,
        rewriter.create<mlir::stablehlo::MulOp>(
            loc, aInverseSquare,
            rewriter.create<mlir::stablehlo::AddOp>(
                loc, hornerSum,
                getConstantLike(rewriter, loc, 1. / kZetaCoeffs[i], a))));
  }
  // Fold the Horner polynomial into the running sum:
  //   s += a^-x * (1/2 + (x/a) * (1/kZetaCoeffs[11] + hornerSum)).
  Value zeroPointFiveLikeNegPower =
      getConstantLike(rewriter, loc, .5, negPower);
  Value xDivA = rewriter.create<mlir::stablehlo::DivOp>(loc, x, a);
  s = rewriter.create<mlir::stablehlo::AddOp>(
      loc, s,
      rewriter.create<mlir::stablehlo::MulOp>(
          loc, negPower,
          rewriter.create<mlir::stablehlo::AddOp>(
              loc, zeroPointFiveLikeNegPower,
              rewriter.create<mlir::stablehlo::MulOp>(
                  loc, xDivA,
                  rewriter.create<mlir::stablehlo::AddOp>(
                      loc,
                      getConstantLike(rewriter, loc, 1. / kZetaCoeffs[11], a),
                      hornerSum)))));

  // Use the initial zeta sum without the correction term coming
  // from Euler-Maclaurin if it is accurate enough.
  Value absNegPower = rewriter.create<mlir::stablehlo::AbsOp>(loc, negPower);
  Value absInitialSum =
      rewriter.create<mlir::stablehlo::AbsOp>(loc, initialSum);
  Value output = rewriter.create<mlir::stablehlo::SelectOp>(
      loc,
      rewriter.create<mlir::stablehlo::CompareOp>(
          loc, absNegPower,
          rewriter.create<mlir::stablehlo::MulOp>(
              loc, absInitialSum,
              getConstantLikeSmallestFiniteValue(rewriter, loc, a)),
          mlir::stablehlo::ComparisonDirection::LT),
      initialSum, s);

  // Function is not defined for x < 1.
  Value nan = getConstantLike(rewriter, loc,
                              std::numeric_limits<double>::quiet_NaN(), x);
  output = rewriter.create<mlir::stablehlo::SelectOp>(
      loc,
      rewriter.create<mlir::stablehlo::CompareOp>(
          loc, x, oneLikeX, mlir::stablehlo::ComparisonDirection::LT),
      nan, output);

  // For q <= 0, x must be an integer.
  Value qLeZero = rewriter.create<mlir::stablehlo::CompareOp>(
      loc, q, zero, mlir::stablehlo::ComparisonDirection::LE);
  Value xNotInt = rewriter.create<mlir::stablehlo::CompareOp>(
      loc, x, rewriter.create<mlir::stablehlo::FloorOp>(loc, x),
      mlir::stablehlo::ComparisonDirection::NE);
  Value xDomainError =
      rewriter.create<mlir::stablehlo::AndOp>(loc, qLeZero, xNotInt);
  output = rewriter.create<mlir::stablehlo::SelectOp>(loc, xDomainError, nan,
                                                      output);

  // For all integer q <= 0, zeta has a pole. The limit is only defined as
  // +inf if x is an even integer.
  Value inf = getConstantLike(rewriter, loc,
                              std::numeric_limits<double>::infinity(), x);
  Value qIsInt = rewriter.create<mlir::stablehlo::CompareOp>(
      loc, q, rewriter.create<mlir::stablehlo::FloorOp>(loc, q),
      mlir::stablehlo::ComparisonDirection::EQ);
  Value atPole = rewriter.create<mlir::stablehlo::AndOp>(loc, qLeZero, qIsInt);
  Value two = getConstantLike(rewriter, loc, 2.0, x);
  Value xIsInt = rewriter.create<mlir::stablehlo::CompareOp>(
      loc, x, rewriter.create<mlir::stablehlo::FloorOp>(loc, x),
      mlir::stablehlo::ComparisonDirection::EQ);
  Value xIsEven = rewriter.create<mlir::stablehlo::CompareOp>(
      loc, rewriter.create<mlir::stablehlo::RemOp>(loc, x, two), zero,
      mlir::stablehlo::ComparisonDirection::EQ);
  Value xIsEvenInt =
      rewriter.create<mlir::stablehlo::AndOp>(loc, xIsInt, xIsEven);
  output = rewriter.create<mlir::stablehlo::SelectOp>(
      loc, atPole,
      rewriter.create<mlir::stablehlo::SelectOp>(loc, xIsEvenInt, inf, nan),
      output);

  // For x = 1, this is the harmonic series and diverges.
  output = rewriter.create<mlir::stablehlo::SelectOp>(
      loc,
      rewriter.create<mlir::stablehlo::CompareOp>(
          loc, x, one, mlir::stablehlo::ComparisonDirection::EQ),
      inf, output);

  return output;
}
-
-static Value materializePolygamma(ConversionPatternRewriter &rewriter,
- Location loc, ValueRange args) {
- mlir::chlo::PolygammaOp::Adaptor transformed(args);
- Value n = transformed.getN();
- Value x = transformed.getX();
-
- // Handle integer n > 0.
- Value one = getConstantLike(rewriter, loc, 1.0, x);
- Value two = getConstantLike(rewriter, loc, 2.0, x);
- Value sign = rewriter.create<mlir::stablehlo::SubtractOp>(
- loc,
- rewriter.create<mlir::stablehlo::MulOp>(
- loc, two, rewriter.create<mlir::stablehlo::RemOp>(loc, n, two)),
- one);
- Value nPlusOne = rewriter.create<mlir::stablehlo::AddOp>(loc, n, one);
- Value expLgammaNp1 = rewriter.create<mlir::stablehlo::ExpOp>(
- loc, rewriter.create<chlo::LgammaOp>(loc, nPlusOne));
- Value zeta = rewriter.create<chlo::ZetaOp>(loc, nPlusOne, x);
- Value result = rewriter.create<mlir::stablehlo::MulOp>(
- loc, rewriter.create<mlir::stablehlo::MulOp>(loc, sign, expLgammaNp1),
- zeta);
-
- // Handle n = 0.
- Value zero = getConstantLike(rewriter, loc, 0.0, x);
- Value nEqZero = rewriter.create<mlir::stablehlo::CompareOp>(
- loc, n, zero, mlir::stablehlo::ComparisonDirection::EQ);
- result = rewriter.create<mlir::stablehlo::SelectOp>(
- loc, nEqZero, rewriter.create<chlo::DigammaOp>(loc, x), result);
-
- // Check that n is a natural number. Return nan, otherwise.
- Value nonInt = rewriter.create<mlir::stablehlo::CompareOp>(
- loc, n, rewriter.create<mlir::stablehlo::FloorOp>(loc, n),
- mlir::stablehlo::ComparisonDirection::NE);
- Value negative = rewriter.create<mlir::stablehlo::CompareOp>(
- loc, n, zero, mlir::stablehlo::ComparisonDirection::LT);
- Value nonNatural =
- rewriter.create<mlir::stablehlo::OrOp>(loc, nonInt, negative);
- return rewriter.create<mlir::stablehlo::SelectOp>(
- loc, nonNatural,
- getConstantLike(rewriter, loc, std::numeric_limits<double>::quiet_NaN(),
- x),
- result);
-}
-
-struct ConvertLgammaOp final : OpConversionPattern<mlir::chlo::LgammaOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(mlir::chlo::LgammaOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- FloatType minPrecisionTy = rewriter.getF32Type();
- rewriter.replaceOp(
- op, materializeWithUpcast(rewriter, op.getLoc(), adaptor.getOperands(),
- minPrecisionTy, &materializeLgamma));
- return success();
- }
-};
-
-struct ConvertDigammaOp final : OpConversionPattern<mlir::chlo::DigammaOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(mlir::chlo::DigammaOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- FloatType minPrecisionTy = rewriter.getF32Type();
- rewriter.replaceOp(
- op, materializeWithUpcast(rewriter, op.getLoc(), adaptor.getOperands(),
- minPrecisionTy, &materializeDigamma));
- return success();
- }
-};
-
// Materializes nextafter(x, y) by bitcasting the floats to same-width
// integers and adjusting the integer representation by +/-1 ULP, with
// explicit handling of NaN inputs, equal inputs, and signed zeros.
static Value materializeNextAfter(ConversionPatternRewriter &rewriter,
                                  Location loc, ValueRange operands) {
  mlir::chlo::NextAfterOp::Adaptor transformed(operands);
  Value x = transformed.getX();
  Value y = transformed.getY();
  auto resultTy = cast<ShapedType>(x.getType());
  auto bitwidth = resultTy.getElementType().getIntOrFloatBitWidth();
  mlir::ImplicitLocOpBuilder b(loc, rewriter);
  // Same-shape integer type so we can manipulate the raw float bits.
  Type intTy = resultTy.clone(b.getIntegerType(bitwidth));
  auto xAsInt = b.create<mlir::stablehlo::BitcastConvertOp>(intTy, x);
  auto yAsInt = b.create<mlir::stablehlo::BitcastConvertOp>(intTy, y);

  // The result is NaN if either "x" or "y" are NaN.
  // NaN is detected via self-inequality (v != v holds only for NaN).
  auto xIsNan = b.create<mlir::stablehlo::CompareOp>(
      x, x, mlir::stablehlo::ComparisonDirection::NE);
  auto yIsNan = b.create<mlir::stablehlo::CompareOp>(
      y, y, mlir::stablehlo::ComparisonDirection::NE);
  auto nanInput = b.create<mlir::stablehlo::OrOp>(xIsNan, yIsNan);
  auto resultForNan = getConstantLike(
      rewriter, loc, std::numeric_limits<double>::quiet_NaN(), x);
  auto resultForNanAsInt =
      b.create<mlir::stablehlo::BitcastConvertOp>(intTy, resultForNan);

  // The sign bit is the MSB.
  const int64_t signBit = int64_t{1} << (bitwidth - 1);
  // Discard the sign bit to make the result non-negative.
  Value signMask = getConstantLike(rewriter, loc, signBit, xAsInt);
  Value negatedSignMask = getConstantLike(rewriter, loc, ~signBit, xAsInt);
  auto xAbs = b.create<mlir::stablehlo::AndOp>(xAsInt, negatedSignMask);
  auto yAbs = b.create<mlir::stablehlo::AndOp>(yAsInt, negatedSignMask);

  // When both "x" and "y" are equal, the result is "y".
  auto xAndYAreEqual = b.create<mlir::stablehlo::CompareOp>(
      x, y, mlir::stablehlo::ComparisonDirection::EQ);
  auto resultForEqual = yAsInt;

  // When both "x" and "y" are 0, the result is "y". This is a separate case
  // from above because "x" and "y" might have a different sign.
  Value zero = getConstantLike(rewriter, loc, 0, xAsInt);
  auto xIsZero = b.create<mlir::stablehlo::CompareOp>(
      xAbs, zero, mlir::stablehlo::ComparisonDirection::EQ);
  auto yIsZero = b.create<mlir::stablehlo::CompareOp>(
      yAbs, zero, mlir::stablehlo::ComparisonDirection::EQ);
  auto resultForBothZero = yAsInt;

  auto xSign = b.create<mlir::stablehlo::AndOp>(xAsInt, signMask);
  auto ySign = b.create<mlir::stablehlo::AndOp>(yAsInt, signMask);

  // If from == 0 && to != 0, we need to return the smallest subnormal number
  // signed like "to".
  Value one = getConstantLike(rewriter, loc, 1, xAsInt);
  auto resultForXZeroYNonZero = b.create<mlir::stablehlo::OrOp>(ySign, one);

  // If the sign of "x" and "y" disagree:
  // - we need to make the magnitude of "from" smaller so that it is closer to
  //   zero.
  //
  // Otherwise the signs agree:
  // - "x" with a magnitude larger than "y" means we need to make the magnitude
  //   smaller.
  // - "x" with a magnitude smaller than "y" means we need to make the magnitude
  //   larger.
  auto signsDisagree = b.create<mlir::stablehlo::CompareOp>(
      xSign, ySign, mlir::stablehlo::ComparisonDirection::NE);
  auto xMagnitudeLargerThanY = b.create<mlir::stablehlo::CompareOp>(
      xAbs, yAbs, mlir::stablehlo::ComparisonDirection::GT);
  auto resultHasSmallerMagnitude =
      b.create<mlir::stablehlo::OrOp>(xMagnitudeLargerThanY, signsDisagree);
  auto minusOne = getConstantLike(rewriter, loc, -1, xAsInt);
  auto magnitudeAdjustment = b.create<mlir::stablehlo::SelectOp>(
      resultHasSmallerMagnitude, minusOne, one);
  Value result = b.create<mlir::stablehlo::AddOp>(xAsInt, magnitudeAdjustment);
  // Handle from == +-0.
  result = b.create<mlir::stablehlo::SelectOp>(
      xIsZero,
      b.create<mlir::stablehlo::SelectOp>(yIsZero, resultForBothZero,
                                          resultForXZeroYNonZero),
      result);
  // Handle from == to.
  result = b.create<mlir::stablehlo::SelectOp>(xAndYAreEqual, resultForEqual,
                                               result);
  // Handle isnan(x) || isnan(y).
  result =
      b.create<mlir::stablehlo::SelectOp>(nanInput, resultForNanAsInt, result);

  // Cast back to the original type.
  return b.create<mlir::stablehlo::BitcastConvertOp>(resultTy, result);
}
-
-struct ConvertNextAfterOp final : OpConversionPattern<mlir::chlo::NextAfterOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(mlir::chlo::NextAfterOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOp(
- op, materializeNextAfter(rewriter, op.getLoc(), adaptor.getOperands()));
- return success();
- }
-};
-
-struct ConvertPolygammaOp final : OpConversionPattern<mlir::chlo::PolygammaOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(mlir::chlo::PolygammaOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- Location loc = op.getLoc();
- FloatType minPrecisionTy = rewriter.getF32Type();
- rewriter.replaceOp(
- op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(),
- minPrecisionTy, materializePolygamma));
- return success();
- }
-};
-
-// Sinh(x) = (e^x - e^-x) / 2
-// = e^(x + log(1/2)) - e^(-x + log(1/2)).
-//
-// The second formulation avoids overflowing when e^x = inf but (e^x)/2 is not
-// inf.
-//
-// This incorrectly overflows to +/-inf for two f32 input values, namely
-// +/-89.4159851, due to rounding error when computing x +/- log(1/2). The
-// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
-// we deem this acceptable.
-static Value
-materializeSinhApproximationForLargeX(ConversionPatternRewriter &rewriter,
- Location loc, ValueRange operands) {
- mlir::chlo::SinhOp::Adaptor transformed(operands);
- Value x = transformed.getOperand();
-
- Value logOneHalf = rewriter.create<mlir::stablehlo::LogOp>(
- loc, getConstantLike(rewriter, loc, 0.5, x));
- Value expAdd = rewriter.create<mlir::stablehlo::ExpOp>(
- loc, rewriter.create<mlir::stablehlo::AddOp>(loc, x, logOneHalf));
- Value expSub = rewriter.create<mlir::stablehlo::ExpOp>(
- loc, rewriter.create<mlir::stablehlo::SubtractOp>(loc, logOneHalf, x));
- return rewriter.create<mlir::stablehlo::SubtractOp>(loc, expAdd, expSub);
-}
-
// Express `sinh` as
//   sinh(x) = (e^x - e^-x) / 2                         if |x| < 1
//           = e^(x + log(1/2)) - e^(-x + log(1/2))     otherwise.
// The small-|x| branch is rewritten with expm1 to avoid catastrophic
// cancellation near zero.
static Value materializeSinhApproximation(ConversionPatternRewriter &rewriter,
                                          Location loc, ValueRange operands) {
  Value largeSinhResult =
      materializeSinhApproximationForLargeX(rewriter, loc, operands);

  mlir::chlo::SinhOp::Adaptor transformed(operands);
  Value x = transformed.getOperand();

  // For smaller x, we get unwanted cancellations of e^x - e^-x, resulting in
  // 0.
  // Rewrite this to avoid that. We use expm1(x) because that preserves the
  // first order term of the taylor series of e^x.
  //   (e^(x) - e^(-x)) / 2. =
  //   (e^(x) - 1 + 1 - e^(-x)) / 2.
  //   (expm1(x) + (e^(x) - 1) / e^x) / 2.
  //   (expm1(x) + expm1(x) / (expm1(x) + 1)) / 2.
  Value expm1 = rewriter.create<mlir::stablehlo::Expm1Op>(loc, x);
  Value one = getConstantLike(rewriter, loc, 1.0, x);
  Value oneHalf = getConstantLike(rewriter, loc, 0.5, x);
  Value expm1PlusOne = rewriter.create<mlir::stablehlo::AddOp>(loc, expm1, one);
  Value ratio =
      rewriter.create<mlir::stablehlo::DivOp>(loc, expm1, expm1PlusOne);
  Value sum = rewriter.create<mlir::stablehlo::AddOp>(loc, expm1, ratio);
  Value smallSinhResult =
      rewriter.create<mlir::stablehlo::MulOp>(loc, oneHalf, sum);

  // Select the numerically appropriate branch based on |x| < 1.
  Value absX = rewriter.create<mlir::stablehlo::AbsOp>(loc, x);
  Value absXLtOne = rewriter.create<mlir::stablehlo::CompareOp>(
      loc, absX, one, mlir::stablehlo::ComparisonDirection::LT);
  return rewriter.create<mlir::stablehlo::SelectOp>(
      loc, absXLtOne, smallSinhResult, largeSinhResult);
}
-
-struct ConvertSinhOp final : OpConversionPattern<mlir::chlo::SinhOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(mlir::chlo::SinhOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- Value x = adaptor.getOperand();
- if (isa<ComplexType>(cast<ShapedType>(x.getType()).getElementType())) {
- rewriter.replaceOp(op, materializeSinhApproximationForLargeX(
- rewriter, op.getLoc(), adaptor.getOperands()));
- return success();
- }
- rewriter.replaceOp(
- op, materializeWithUpcast(rewriter, op.getLoc(), adaptor.getOperands(),
- rewriter.getF32Type(),
- &materializeSinhApproximation));
- return success();
- }
-};
-
-// Converts chlo.top_k to HLO iota, sort, and slice ops.
-//
-// chlo.top_k sorts along last dimension of the input tensor and then returns
-// the top K components' values and indices. This is translated into a few
-// ops in HLO: first generating an integer sequence for the indices,
-// then sort both the original input tensor and the indices together, and
-// at last slice out the top K components.
-//
-// For example, for the following IR:
-//
-// %0:2 = "chlo.top_k"(%input, k=8): tensor<16x16xf32> ->
-// (tensor<16x8xf32>, tensor<16x8xi32>)
-//
-// We will get:
-//
-// %1 = "hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<16x16xi32>
-// %2 = "hlo.sort"(%input, %1) ({
-// ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>,
-// %arg3: tensor<i32>, %arg4: tensor<i32>):
-// %7 = "hlo.compare"(%arg1, %arg2) {comparison_direction = "GT"}: ...
-// "hlo.return"(%7) : (tensor<i1>) -> ()
-// }) {dimension = 1 : i64, is_stable = true} : ...
-// %3 = "hlo.get_tuple_element"(%2) {index = 0 : i32} : ...
-// %4 = "hlo.get_tuple_element"(%2) {index = 1 : i32} : ...
-// %5 = "hlo.slice"(%3) {limit_indices = dense<[16, 8]> : tensor<2xi64>,
-// start_indices dense<0> : tensor<2xi64>,
-// strides = dense<1> : tensor<2xi64>} :
-// (tensor<16x16xf32>) -> tensor<16x8xf32>
-// %6 = "hlo.slice"(%4) ...
-//
-// TODO(b/284078162): Decide what to do with this pattern given that we now
-// have mlir::stablehlo::TopKOp. No action needed for now given that
-// mlir::stablehlo::TopKOp is currently categorized as
-// `hasPrivateFeaturesNotInStablehlo`.
-struct ConvertTopKOp final : OpConversionPattern<mlir::chlo::TopKOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(mlir::chlo::TopKOp op, OpAdaptor /*adaptor*/,
- ConversionPatternRewriter &rewriter) const override {
- auto operandType = dyn_cast<RankedTensorType>(op.getOperand().getType());
- if (!operandType)
- return failure();
- int64_t operandRank = operandType.getRank();
- int64_t lastDimIndex = operandRank - 1;
- int64_t lastDimSize = operandType.getDimSize(lastDimIndex);
- int64_t lastDimResultSize =
- mlir::hlo::isDynamicDimSize(lastDimSize)
- ? static_cast<int64_t>(op.getK())
- : std::min(static_cast<int64_t>(op.getK()), lastDimSize);
- int64_t isDynamic = !operandType.hasStaticShape();
- auto i32Type = rewriter.getIntegerType(32);
- Value opShapeValue, resultShapeValue;
- if (isDynamic) {
- SmallVector<Value> sizesI32x1;
- for (auto i = 0; i < operandType.getRank(); ++i) {
- auto sizeI32 = rewriter.create<mlir::stablehlo::GetDimensionSizeOp>(
- op.getLoc(), op.getOperand(), i);
- auto sizeI32x1 = rewriter.create<mlir::stablehlo::ReshapeOp>(
- op.getLoc(), RankedTensorType::get({1}, i32Type), sizeI32);
- sizesI32x1.push_back(sizeI32x1);
- }
- opShapeValue = rewriter.create<mlir::stablehlo::ConcatenateOp>(
- op.getLoc(), sizesI32x1,
- /*dimension=*/0);
- auto lastDimI32 = rewriter.create<mlir::stablehlo::ConstantOp>(
- op.getLoc(),
- rewriter.getI32IntegerAttr(static_cast<int32_t>(lastDimResultSize)));
- auto lastDimI32x1 = rewriter.create<mlir::stablehlo::ReshapeOp>(
- op.getLoc(), RankedTensorType::get({1}, i32Type), lastDimI32);
- sizesI32x1.back() = lastDimI32x1;
- resultShapeValue = rewriter.create<mlir::stablehlo::ConcatenateOp>(
- op.getLoc(), sizesI32x1,
- /*dimension=*/0);
- }
-
- // Create an Iota op for indices.
- Type iotaType = RankedTensorType::get(operandType.getShape(), i32Type);
- Value iotaOp;
- if (isDynamic) {
- iotaOp = rewriter.create<mlir::stablehlo::DynamicIotaOp>(
- op.getLoc(), iotaType, opShapeValue,
- rewriter.getI64IntegerAttr(lastDimIndex));
- } else {
- iotaOp = rewriter.create<mlir::stablehlo::IotaOp>(
- op.getLoc(), iotaType, rewriter.getI64IntegerAttr(lastDimIndex));
- }
-
- // Create the sort op. It takes two inputs, one for the original input, the
- // other for the indices. Use TOTALORDER comparison type instead of the
- // default comparison if the element type is of type float.
- Type elementType = operandType.getElementType();
- mlir::stablehlo::SortOp sortOp =
- createSortOp(&rewriter, op.getLoc(), {op.getOperand(), iotaOp},
- {elementType, i32Type}, lastDimIndex,
- /*isStable=*/true,
- /*direction=*/mlir::stablehlo::ComparisonDirection::GT);
-
- // Get the sorted input and index tuple element.
- Value tupleFirstElement = sortOp.getResult(0);
- Value tupleSecondElement = sortOp.getResult(1);
-
- SmallVector<int64_t> beginIndices(operandRank, 0);
- auto endIndices = llvm::to_vector(operandType.getShape());
- endIndices.back() = lastDimResultSize;
- SmallVector<int64_t> strides(operandRank, 1);
-
- // Get the slice for the top K elements.
- auto indicesTy = RankedTensorType::get(operandRank, rewriter.getI64Type());
- Value values, indices;
- if (isDynamic) {
- Value startIndices = rewriter.create<mlir::stablehlo::ConstantOp>(
- op.getLoc(), DenseIntElementsAttr::get(indicesTy, beginIndices));
- Value lastIndices = rewriter.create<mlir::stablehlo::ConvertOp>(
- op.getLoc(), resultShapeValue, rewriter.getI64Type());
- Value stridesOp = rewriter.create<mlir::stablehlo::ConstantOp>(
- op.getLoc(), DenseIntElementsAttr::get(indicesTy, strides));
-
- SmallVector<int64_t> resultShape =
- llvm::to_vector(operandType.getShape());
- resultShape.back() = lastDimResultSize;
- RankedTensorType resultType = RankedTensorType::get(
- resultShape, elementType, operandType.getEncoding());
- RankedTensorType indexResultType =
- RankedTensorType::get(resultShape, i32Type);
-
- values = rewriter.create<mlir::stablehlo::RealDynamicSliceOp>(
- op.getLoc(), resultType, tupleFirstElement, startIndices, lastIndices,
- stridesOp);
- indices = rewriter.create<mlir::stablehlo::RealDynamicSliceOp>(
- op.getLoc(), indexResultType, tupleSecondElement, startIndices,
- lastIndices, stridesOp);
- } else {
- values = rewriter.create<mlir::stablehlo::SliceOp>(
- op.getLoc(), tupleFirstElement,
- rewriter.getDenseI64ArrayAttr(beginIndices),
- rewriter.getDenseI64ArrayAttr(endIndices),
- rewriter.getDenseI64ArrayAttr(strides));
- indices = rewriter.create<mlir::stablehlo::SliceOp>(
- op.getLoc(), tupleSecondElement,
- rewriter.getDenseI64ArrayAttr(beginIndices),
- rewriter.getDenseI64ArrayAttr(endIndices),
- rewriter.getDenseI64ArrayAttr(strides));
- }
-
- rewriter.replaceOp(op, {values, indices});
- return success();
- }
-};
-
-struct ConvertZetaOp final : OpConversionPattern<mlir::chlo::ZetaOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(mlir::chlo::ZetaOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- Location loc = op.getLoc();
- FloatType minPrecisionTy = rewriter.getF32Type();
- rewriter.replaceOp(
- op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(),
- minPrecisionTy, &materializeZeta));
- return success();
- }
-};
-
-//===----------------------------------------------------------------------===//
-// Pass Definition.
-//===----------------------------------------------------------------------===//
-
-struct LegalizeChlo final : impl::LegalizeChloBase<LegalizeChlo> {
- void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<mlir::scf::SCFDialect, mlir::shape::ShapeDialect,
- mlir::stablehlo::StablehloDialect,
- mlir::tensor::TensorDialect>();
- }
-
- void runOnOperation() override {
- MLIRContext *ctx = &getContext();
- {
- ConversionTarget conversionTarget(getContext());
- RewritePatternSet conversionPatterns(ctx);
- conversionTarget.addIllegalDialect<chlo::ChloDialect>();
- conversionTarget.addLegalDialect<
- mlir::stablehlo::StablehloDialect, mlir::arith::ArithDialect,
- mlir::shape::ShapeDialect, mlir::scf::SCFDialect,
- mlir::tensor::TensorDialect>();
-
- populateLegalizeChloPatterns(ctx, &conversionPatterns);
- if (failed(applyPartialConversion(getOperation(), conversionTarget,
- std::move(conversionPatterns)))) {
- return signalPassFailure();
- }
- }
-
- {
- // Add canonicalization patterns to simplify produced ops from other
- // dialects.
- RewritePatternSet patterns(ctx);
- populateCanonicalizationPatterns(ctx, &patterns);
- mlir::shape::AssumingOp::getCanonicalizationPatterns(patterns, ctx);
- mlir::shape::ShapeOfOp::getCanonicalizationPatterns(patterns, ctx);
- mlir::shape::BroadcastOp::getCanonicalizationPatterns(patterns, ctx);
- mlir::shape::CstrBroadcastableOp::getCanonicalizationPatterns(patterns,
- ctx);
- mlir::tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
- if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
- return signalPassFailure();
- }
- }
- }
-};
-} // namespace
-
-namespace {
-#include "compiler/plugins/input/StableHLO/Conversion/CHLODecompositionPatterns.h.inc"
-} // end anonymous namespace
-
-namespace {
-static void populateBroadcastingPatterns(MLIRContext *context,
- RewritePatternSet *patterns) {
- // Instantiate conversion templates for conforming binary elementwise ops
- // that do not have different dtypes between operands and results and do
- // not have special attributes that need to be preserved.
- populateForBroadcastingBinaryOp<ConvertTrivialNonBroadcastBinaryOp>(
- context, patterns, 10);
- populateForBroadcastingBinaryOp<ConvertRankedDynamicBroadcastBinaryOp>(
- context, patterns, 5);
- patterns->add<ConvertConstantLikeOp, ConvertSelectOp>(context);
-}
-
-static void populateDecompositionPatterns(MLIRContext *context,
- RewritePatternSet *patterns) {
- populateWithGenerated(*patterns);
- patterns->add<ConvertConstantOp, ConvertBesselI1eOp, ConvertCoshOp,
- ConvertDigammaOp, ConvertErfOp, ConvertErfcOp, ConvertErfInvOp,
- ConvertLgammaOp, ConvertNextAfterOp, ConvertPolygammaOp,
- ConvertSinhOp, ConvertTopKOp, ConvertZetaOp>(context);
-}
-} // namespace
-
-void populateLegalizeChloPatterns(MLIRContext *context,
- RewritePatternSet *patterns) {
- populateBroadcastingPatterns(context, patterns);
- populateDecompositionPatterns(context, patterns);
-}
-} // namespace mlir::iree_compiler::stablehlo
diff --git a/compiler/plugins/input/StableHLO/Conversion/Passes.cpp b/compiler/plugins/input/StableHLO/Conversion/Passes.cpp
index d5bad85..7887e2d 100644
--- a/compiler/plugins/input/StableHLO/Conversion/Passes.cpp
+++ b/compiler/plugins/input/StableHLO/Conversion/Passes.cpp
@@ -85,7 +85,8 @@
stablehlo::createLegalizeShapeComputations());
passManager.addNestedPass<func::FuncOp>(
stablehlo::createConvertStableHloToLinalgExt());
- passManager.addNestedPass<func::FuncOp>(stablehlo::createLegalizeChlo());
+ passManager.addNestedPass<func::FuncOp>(
+ mlir::stablehlo::createChloLegalizeToStablehloPass());
passManager.addPass(createConvertStableHloToIreeInputDialects());
passManager.addPass(createReconcileUnrealizedCastsPass());
diff --git a/compiler/plugins/input/StableHLO/Conversion/Passes.td b/compiler/plugins/input/StableHLO/Conversion/Passes.td
index 7852112..915a487 100644
--- a/compiler/plugins/input/StableHLO/Conversion/Passes.td
+++ b/compiler/plugins/input/StableHLO/Conversion/Passes.td
@@ -37,11 +37,6 @@
let summary = "Legalizes from StableHLO control flow to SCF control flow";
}
-def LegalizeChlo :
- InterfacePass<"iree-stablehlo-legalize-chlo", "mlir::FunctionOpInterface"> {
- let summary = "Legalizes from CHLO ops flow to StableHLO and Shape ops";
-}
-
def LegalizeStableHLOCustomCalls :
InterfacePass<"iree-stablehlo-legalize-custom-calls", "mlir::FunctionOpInterface"> {
let summary = "Legalizes specialized custom calls to decomposed implementations";
diff --git a/compiler/plugins/input/StableHLO/Conversion/Rewriters.h b/compiler/plugins/input/StableHLO/Conversion/Rewriters.h
index 1a328ff..956a3a4 100644
--- a/compiler/plugins/input/StableHLO/Conversion/Rewriters.h
+++ b/compiler/plugins/input/StableHLO/Conversion/Rewriters.h
@@ -15,11 +15,6 @@
// General StableHLO/CHLO lowering patterns.
//===----------------------------------------------------------------------===//
-/// Collection of rewrite patterns for lowering of CHLO ops to StableHLO and
-/// Shape ops.
-void populateLegalizeChloPatterns(MLIRContext *context,
- RewritePatternSet *patterns);
-
/// Collection of rewrite patterns for lowering of StableHLO ops to SCF control
/// flow ops.
void populateLegalizeControlFlowPatterns(MLIRContext *context,
diff --git a/compiler/plugins/input/StableHLO/Conversion/test/legalize_chlo_decomposition.mlir b/compiler/plugins/input/StableHLO/Conversion/test/legalize_chlo_decomposition.mlir
index d2f9353..851d002 100644
--- a/compiler/plugins/input/StableHLO/Conversion/test/legalize_chlo_decomposition.mlir
+++ b/compiler/plugins/input/StableHLO/Conversion/test/legalize_chlo_decomposition.mlir
@@ -1,5 +1,5 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-stablehlo-legalize-chlo))" \
-// RUN: --split-input-file --verify-diagnostics %s | FileCheck %s
+// RUN: iree-opt --chlo-legalize-to-stablehlo \
+// RUN: --split-input-file --verify-diagnostics %s | FileCheck %s
// CHECK-LABEL: func.func @asin_bf16(
func.func @asin_bf16(%arg : tensor<bf16>) -> tensor<bf16> {
@@ -508,6 +508,8 @@
// -----
+// CHECK-LABEL: @complex_tan
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1xf32>, %[[ARG1:.+]]: tensor<1xf32>
func.func @complex_tan(%arg0 : tensor<1xf32>, %arg1 : tensor<1xf32>) -> (tensor<1xf32>, tensor<1xf32>) {
%0 = stablehlo.complex %arg0, %arg1 : tensor<1xcomplex<f32>>
%1 = chlo.tan %0 : tensor<1xcomplex<f32>> -> tensor<1xcomplex<f32>>
@@ -516,18 +518,21 @@
func.return %2, %3 : tensor<1xf32>, tensor<1xf32>
}
-// CHECK-LABEL: @complex_tan
-// CHECK-SAME: %[[ARG0:.+]]: tensor<1xf32>, %[[ARG1:.+]]: tensor<1xf32>
-// CHECK: %[[ONE:.+]] = stablehlo.constant dense<1.000000e+00> : tensor<1xf32>
-// CHECK: %[[SINE:.+]] = stablehlo.sine %[[ARG0]]
-// CHECK: %[[COS:.+]] = stablehlo.cosine %[[ARG0]]
-// CHECK: %[[TAN:.+]] = stablehlo.divide %[[SINE]], %[[COS]]
-// CHECK: %[[TANH:.+]] = stablehlo.tanh %[[ARG1]]
-// CHECK: %[[NUM:.+]] = stablehlo.complex %[[TAN]], %[[TANH]]
-// CHECK: %[[MUL:.+]] = stablehlo.multiply %[[TAN]], %[[TANH]]
-// CHECK: %[[NEG:.+]] = stablehlo.negate %[[MUL]]
-// CHECK: %[[DEN:.+]] = stablehlo.complex %[[ONE]], %[[NEG]]
-// CHECK: %[[RES:.+]] = stablehlo.divide %[[NUM]], %[[DEN]]
-// CHECK: %[[REAL:.+]] = stablehlo.real %[[RES]]
-// CHECK: %[[IMAG:.+]] = stablehlo.imag %[[RES]]
-// CHECK: return %[[REAL]], %[[IMAG]]
+
+// -----
+
+// CHECK-LABEL: func.func @acos_complex_f32
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<complex<f32>>) -> tensor<complex<f32>>
+func.func @acos_complex_f32(%arg : tensor<complex<f32>>) -> tensor<complex<f32>> {
+ %result = "chlo.acos"(%arg) : (tensor<complex<f32>>) -> tensor<complex<f32>>
+ func.return %result : tensor<complex<f32>>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @acos_complex_f64_dynamic
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<?xcomplex<f64>>) -> tensor<?xcomplex<f64>>
+func.func @acos_complex_f64_dynamic(%arg : tensor<?xcomplex<f64>>) -> tensor<?xcomplex<f64>> {
+ %result = "chlo.acos"(%arg) : (tensor<?xcomplex<f64>>) -> tensor<?xcomplex<f64>>
+ func.return %result : tensor<?xcomplex<f64>>
+}
diff --git a/compiler/plugins/input/StableHLO/Conversion/test/legalize_chlo_no_broadcast.mlir b/compiler/plugins/input/StableHLO/Conversion/test/legalize_chlo_no_broadcast.mlir
index 3a6fbd1..2cc1506 100644
--- a/compiler/plugins/input/StableHLO/Conversion/test/legalize_chlo_no_broadcast.mlir
+++ b/compiler/plugins/input/StableHLO/Conversion/test/legalize_chlo_no_broadcast.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-stablehlo-legalize-chlo),cse)" \
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(chlo-legalize-to-stablehlo), cse)" \
// RUN: --split-input-file --verify-diagnostics %s | FileCheck %s
// Check the non-broadcast case for each registered op, then just check a
@@ -17,17 +17,10 @@
// CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>
// CHECK-SAME: %[[ARG1:.+]]: tensor<?x?xf32>
func.func @dynamicBroadcast(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
- // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]]
- // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
- // CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
- // CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
- // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
- // CHECK-DAG: %[[ARG0_B:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG0]], %[[RESULT_EXTENTS]], dims = [1]
- // CHECK-DAG: %[[ARG1_B:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG1]], %[[RESULT_EXTENTS]], dims = [0, 1]
- // CHECK-NEXT: %[[RESULT:.+]] = stablehlo.add %[[ARG0_B]], %[[ARG1_B]]
- // CHECK-NEXT: shape.assuming_yield %[[RESULT]]
- // CHECK-NEXT: }
- // CHECK-NEXT: return %[[FINAL_RESULT]] : tensor<?x?xf32>
+
+ // CHECK: stablehlo.dynamic_broadcast_in_dim
+ // CHECK: stablehlo.add
+
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
func.return %0 : tensor<?x?xf32>
}
@@ -38,17 +31,9 @@
// CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>
// CHECK-SAME: %[[ARG1:.+]]: tensor<?x?xf32>
func.func @dynamicBroadcastComplex(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>> {
- // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]]
- // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
- // CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
- // CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
- // CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
- // CHECK-DAG: %[[ARG0_B:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG0]], %[[RESULT_EXTENTS]], dims = [1] : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
- // CHECK-DAG: %[[ARG1_B:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG1]], %[[RESULT_EXTENTS]], dims = [0, 1] : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
- // CHECK-NEXT: %[[RESULT:.+]] = stablehlo.complex %[[ARG0_B]], %[[ARG1_B]] : tensor<?x?xcomplex<f32>>
- // CHECK-NEXT: shape.assuming_yield %[[RESULT]]
- // CHECK-NEXT: }
- // CHECK-NEXT: return %[[FINAL_RESULT]] : tensor<?x?xcomplex<f32>>
+
+ // CHECK: stablehlo.dynamic_broadcast_in_dim
+ // CHECK: stablehlo.complex
%0 = chlo.broadcast_complex %arg0, %arg1 : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
func.return %0 : tensor<?x?xcomplex<f32>>
}
@@ -59,17 +44,8 @@
// CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>
// CHECK-SAME: %[[ARG1:.+]]: tensor<?x?xf32>
func.func @dynamicBroadcastCompare(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xi1> {
- // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]]
- // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
- // CHECK: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
- // CHECK: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
- // CHECK: %[[RESULT_EXTENTS:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
- // CHECK-DAG: %[[ARG0_B:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG0]], %[[RESULT_EXTENTS]], dims = [1] : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
- // CHECK-DAG: %[[ARG1_B:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG1]], %[[RESULT_EXTENTS]], dims = [0, 1] : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
- // CHECK: %[[RESULT:.+]] = stablehlo.compare EQ, %[[ARG0_B]], %[[ARG1_B]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
- // CHECK: shape.assuming_yield %[[RESULT]]
- // CHECK-NEXT: }
- // CHECK: return %[[FINAL_RESULT]] : tensor<?x?xi1>
+ // CHECK: stablehlo.dynamic_broadcast_in_dim
+ // CHECK: stablehlo.compare
%0 = chlo.broadcast_compare %arg0, %arg1 {comparison_direction = #chlo<comparison_direction EQ>} : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
func.return %0 : tensor<?x?xi1>
}
@@ -78,77 +54,76 @@
// CHECK-LABEL: func @selectv2
func.func @selectv2(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> {
- // CHECK-NEXT: stablehlo.select %arg0, %arg1, %arg2
+ // CHECK: tensor.cast
+ // CHECK: stablehlo.dynamic_broadcast_in_dim
+ // CHECK: stablehlo.select
%0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
func.return %0: tensor<2xi32>
}
// CHECK-LABEL: func @selectv2_pred_scalar
func.func @selectv2_pred_scalar(%arg0: tensor<i1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> {
- // CHECK-NEXT: stablehlo.select %arg0, %arg1, %arg2
+ // CHECK: stablehlo.select
%0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
func.return %0: tensor<2xi32>
}
// CHECK-LABEL: func @selectv2_broadcast_then
func.func @selectv2_broadcast_then(%arg0: tensor<i1>, %arg1: tensor<8x1xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> {
- // CHECK-NEXT: %[[BROADCAST:.*]] = stablehlo.broadcast_in_dim %arg1, dims = [1, 2] : (tensor<8x1xi32>) -> tensor<2x8x8xi32>
- // CHECK-NEXT: stablehlo.select %arg0, %[[BROADCAST]], %arg2
+ // CHECK: tensor.cast
+ // CHECK: stablehlo.dynamic_broadcast_in_dim
+ // CHECK: stablehlo.select
%0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<8x1xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32>
func.return %0: tensor<2x8x8xi32>
}
// CHECK-LABEL: func @selectv2_broadcast_else
func.func @selectv2_broadcast_else(%arg0: tensor<i1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<8x1xi32>) -> tensor<2x8x8xi32> {
- // CHECK-NEXT: %[[BROADCAST:.*]] = stablehlo.broadcast_in_dim %arg2, dims = [1, 2] : (tensor<8x1xi32>) -> tensor<2x8x8xi32>
- // CHECK-NEXT: stablehlo.select %arg0, %arg1, %[[BROADCAST]]
+ // CHECK: stablehlo.select
%0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2x8x8xi32>, tensor<8x1xi32>) -> tensor<2x8x8xi32>
func.return %0: tensor<2x8x8xi32>
}
// CHECK-LABEL: func @selectv2_broadcast_pred
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<1xi1>, %[[ARG1:.*]]: tensor<2x8x8xi32>, %[[ARG2:.*]]: tensor<2x8x8xi32>) -> tensor<2x8x8xi32>
func.func @selectv2_broadcast_pred(%arg0: tensor<1xi1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> {
- // CHECK-NEXT: %[[BROADCAST:.*]] = stablehlo.broadcast_in_dim %arg0, dims = [2] : (tensor<1xi1>) -> tensor<2x8x8xi1>
- // CHECK-NEXT: stablehlo.select %[[BROADCAST]], %arg1, %arg2
%0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x8x8xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32>
func.return %0: tensor<2x8x8xi32>
+
+ // CHECK: stablehlo.dynamic_broadcast_in_dim
+ // CHECK: stablehlo.select
+
}
// CHECK-LABEL: func @selectv2_broadcast_tensor_pred
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<3xi1>, %[[ARG1:.*]]: tensor<2x3xf16>, %[[ARG2:.*]]: tensor<2x3xf16>)
func.func @selectv2_broadcast_tensor_pred(%arg0: tensor<3xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> {
- // CHECK-NEXT: %[[BROADCAST:.*]] = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<3xi1>) -> tensor<2x3xi1>
- // CHECK-NEXT: stablehlo.select %[[BROADCAST]], %arg1, %arg2
%0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16>
func.return %0: tensor<2x3xf16>
+
+ // CHECK: stablehlo.dynamic_broadcast_in_dim
+ // CHECK: stablehlo.select
+
}
// CHECK-LABEL: func @selectv2_broadcast_all
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<8x1x1xi1>, %[[ARG1:.*]]: tensor<1x8x1xi32>, %[[ARG2:.*]]: tensor<1x1x8xi32>)
func.func @selectv2_broadcast_all(%arg0: tensor<8x1x1xi1>, %arg1: tensor<1x8x1xi32>, %arg2: tensor<1x1x8xi32>) -> tensor<8x8x8xi32> {
- // CHECK-DAG: %[[BROADCAST_0:.*]] = stablehlo.broadcast_in_dim %arg0, dims = [0, 1, 2] : (tensor<8x1x1xi1>) -> tensor<8x8x8xi1>
- // CHECK-DAG: %[[BROADCAST_1:.*]] = stablehlo.broadcast_in_dim %arg1, dims = [0, 1, 2] : (tensor<1x8x1xi32>) -> tensor<8x8x8xi32>
- // CHECK-DAG: %[[BROADCAST_2:.*]] = stablehlo.broadcast_in_dim %arg2, dims = [0, 1, 2] : (tensor<1x1x8xi32>) -> tensor<8x8x8xi32>
- // CHECK: stablehlo.select %[[BROADCAST_0]], %[[BROADCAST_1]], %[[BROADCAST_2]]
%0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<8x1x1xi1>, tensor<1x8x1xi32>, tensor<1x1x8xi32>) -> tensor<8x8x8xi32>
func.return %0: tensor<8x8x8xi32>
+
+ // CHECK: stablehlo.dynamic_broadcast_in_dim
+ // CHECK: stablehlo.select
}
// CHECK-LABEL: func @selectv2_dynamic_ranked
func.func @selectv2_dynamic_ranked(%arg0: tensor<1xi1>, %arg1: tensor<2x?x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x?x8xi32> {
- // CHECK-DAG: %[[SHAPE0:.*]] = shape.const_shape [1] : tensor<1xindex>
- // CHECK-DAG: %[[SHAPE2:.*]] = shape.const_shape [2, 8, 8] : tensor<3xindex>
- // CHECK-NEXT: %[[SHAPE1:.*]] = shape.shape_of %arg1 : tensor<2x?x8xi32> -> tensor<3xindex>
- // CHECK-NEXT: %[[CSTR:.*]] = shape.cstr_broadcastable %[[SHAPE1]], %[[SHAPE0]], %[[SHAPE2]] : tensor<3xindex>, tensor<1xindex>, tensor<3xindex>
- // CHECK-NEXT: %[[ASSUME:.*]] = shape.assuming %[[CSTR]] -> (tensor<2x?x8xi32>) {
- // CHECK-NEXT: %[[BCST:.*]] = shape.broadcast %[[SHAPE1]], %[[SHAPE2]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex>
- // CHECK-NEXT: %[[BCST0:.*]] = stablehlo.dynamic_broadcast_in_dim %arg0, %[[BCST]], dims = [2] : (tensor<1xi1>, tensor<3xindex>) -> tensor<2x?x8xi1>
- // CHECK-NEXT: %[[BCST1:.*]] = stablehlo.dynamic_broadcast_in_dim %arg1, %[[BCST]], dims = [0, 1, 2] : (tensor<2x?x8xi32>, tensor<3xindex>) -> tensor<2x?x8xi32>
- // CHECK-NEXT: %[[BCST2:.*]] = stablehlo.dynamic_broadcast_in_dim %arg2, %[[BCST]], dims = [0, 1, 2] : (tensor<2x8x8xi32>, tensor<3xindex>) -> tensor<2x?x8xi32>
- // CHECK-NEXT: %[[SELECT:.*]] = stablehlo.select %[[BCST0]], %[[BCST1]], %[[BCST2]] : tensor<2x?x8xi1>, tensor<2x?x8xi32>
- // CHECK-NEXT: shape.assuming_yield %[[SELECT]] : tensor<2x?x8xi32>
- // CHECK-NEXT: }
- // CHECK-NEXT: return %[[ASSUME]] : tensor<2x?x8xi32>
%0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x?x8xi32>, tensor<2x8x8xi32>) -> tensor<2x?x8xi32>
func.return %0: tensor<2x?x8xi32>
+
+ // CHECK: stablehlo.dynamic_broadcast_in_dim
+ // CHECK: stablehlo.select
+
}
// -----
diff --git a/compiler/plugins/input/StableHLO/Conversion/test/legalize_chlo_with_broadcast.mlir b/compiler/plugins/input/StableHLO/Conversion/test/legalize_chlo_with_broadcast.mlir
index 6ee69ca..df5970c 100644
--- a/compiler/plugins/input/StableHLO/Conversion/test/legalize_chlo_with_broadcast.mlir
+++ b/compiler/plugins/input/StableHLO/Conversion/test/legalize_chlo_with_broadcast.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-stablehlo-legalize-chlo),cse)" \
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(chlo-legalize-to-stablehlo), cse)" \
// RUN: --split-input-file --verify-diagnostics %s | FileCheck %s
// Check the non-broadcast case for each registered op, then just check a
@@ -16,17 +16,8 @@
// CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>
// CHECK-SAME: %[[ARG1:.+]]: tensor<?x?xf32>
func.func @dynamicBroadcast(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
- // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]]
- // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
- // CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
- // CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
- // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
- // CHECK-DAG: %[[ARG0_B:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG0]], %[[RESULT_EXTENTS]], dims = [1]
- // CHECK-DAG: %[[ARG1_B:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG1]], %[[RESULT_EXTENTS]], dims = [0, 1]
- // CHECK-NEXT: %[[RESULT:.+]] = stablehlo.add %[[ARG0_B]], %[[ARG1_B]]
- // CHECK-NEXT: shape.assuming_yield %[[RESULT]]
- // CHECK-NEXT: }
- // CHECK-NEXT: return %[[FINAL_RESULT]] : tensor<?x?xf32>
+ // CHECK: stablehlo.dynamic_broadcast_in_dim
+ // CHECK: stablehlo.add
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
func.return %0 : tensor<?x?xf32>
}
@@ -36,18 +27,9 @@
// CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>
// CHECK-SAME: %[[ARG1:.+]]: tensor<?x?xf32>
func.func @dynamicBroadcastComplex(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>> {
- // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]]
- // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
- // CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
- // CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
- // CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
- // CHECK-DAG: %[[ARG0_B:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG0]], %[[RESULT_EXTENTS]], dims = [1] : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
- // CHECK-DAG: %[[ARG1_B:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG1]], %[[RESULT_EXTENTS]], dims = [0, 1] : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
- // CHECK-NEXT: %[[RESULT:.+]] = stablehlo.complex %[[ARG0_B]], %[[ARG1_B]] : tensor<?x?xcomplex<f32>>
- // CHECK-NEXT: shape.assuming_yield %[[RESULT]]
- // CHECK-NEXT: }
- // CHECK-NEXT: return %[[FINAL_RESULT]] : tensor<?x?xcomplex<f32>>
- %0 = chlo.broadcast_complex %arg0, %arg1 : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
+ // CHECK: stablehlo.dynamic_broadcast_in_dim
+ // CHECK: stablehlo.complex
+ %0 = chlo.broadcast_complex %arg0, %arg1 : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
func.return %0 : tensor<?x?xcomplex<f32>>
}
@@ -56,17 +38,8 @@
// CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>
// CHECK-SAME: %[[ARG1:.+]]: tensor<?x?xf32>
func.func @dynamicBroadcastCompare(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xi1> {
- // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]]
- // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
- // CHECK: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
- // CHECK: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
- // CHECK: %[[RESULT_EXTENTS:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
- // CHECK-DAG: %[[ARG0_B:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG0]], %[[RESULT_EXTENTS]], dims = [1] : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
- // CHECK-DAG: %[[ARG1_B:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG1]], %[[RESULT_EXTENTS]], dims = [0, 1] : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
- // CHECK: %[[RESULT:.+]] = stablehlo.compare EQ, %[[ARG0_B]], %[[ARG1_B]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
- // CHECK: shape.assuming_yield %[[RESULT]]
- // CHECK-NEXT: }
- // CHECK: return %[[FINAL_RESULT]] : tensor<?x?xi1>
+ // CHECK: stablehlo.dynamic_broadcast_in_dim
+ // CHECK: stablehlo.compare
%0 = chlo.broadcast_compare %arg0, %arg1 {comparison_direction = #chlo<comparison_direction EQ>} : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
func.return %0 : tensor<?x?xi1>
}
@@ -75,75 +48,74 @@
// CHECK-LABEL: func @selectv2
func.func @selectv2(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> {
- // CHECK-NEXT: stablehlo.select %arg0, %arg1, %arg2
+ // CHECK: tensor.cast
+ // CHECK: stablehlo.dynamic_broadcast_in_dim
+ // CHECK: stablehlo.select
+
%0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
func.return %0: tensor<2xi32>
}
// CHECK-LABEL: func @selectv2_pred_scalar
func.func @selectv2_pred_scalar(%arg0: tensor<i1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> {
- // CHECK-NEXT: stablehlo.select %arg0, %arg1, %arg2
+ // CHECK: stablehlo.select
%0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
func.return %0: tensor<2xi32>
}
// CHECK-LABEL: func @selectv2_broadcast_then
func.func @selectv2_broadcast_then(%arg0: tensor<i1>, %arg1: tensor<8x1xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> {
- // CHECK-NEXT: %[[BROADCAST:.*]] = stablehlo.broadcast_in_dim %arg1, dims = [1, 2] : (tensor<8x1xi32>) -> tensor<2x8x8xi32>
- // CHECK-NEXT: stablehlo.select %arg0, %[[BROADCAST]], %arg2
+
+ // CHECK: tensor.cast
+ // CHECK: stablehlo.dynamic_broadcast_in_dim
+ // CHECK: stablehlo.select
+
%0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<8x1xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32>
func.return %0: tensor<2x8x8xi32>
}
// CHECK-LABEL: func @selectv2_broadcast_else
func.func @selectv2_broadcast_else(%arg0: tensor<i1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<8x1xi32>) -> tensor<2x8x8xi32> {
- // CHECK-NEXT: %[[BROADCAST:.*]] = stablehlo.broadcast_in_dim %arg2, dims = [1, 2] : (tensor<8x1xi32>) -> tensor<2x8x8xi32>
- // CHECK-NEXT: stablehlo.select %arg0, %arg1, %[[BROADCAST]]
+ // CHECK: stablehlo.select
%0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2x8x8xi32>, tensor<8x1xi32>) -> tensor<2x8x8xi32>
func.return %0: tensor<2x8x8xi32>
}
// CHECK-LABEL: func @selectv2_broadcast_pred
func.func @selectv2_broadcast_pred(%arg0: tensor<1xi1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> {
- // CHECK-NEXT: %[[BROADCAST:.*]] = stablehlo.broadcast_in_dim %arg0, dims = [2] : (tensor<1xi1>) -> tensor<2x8x8xi1>
- // CHECK-NEXT: stablehlo.select %[[BROADCAST]], %arg1, %arg2
%0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x8x8xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32>
func.return %0: tensor<2x8x8xi32>
+ // CHECK: stablehlo.dynamic_broadcast_in_dim
+ // CHECK: stablehlo.select
}
// CHECK-LABEL: func @selectv2_broadcast_tensor_pred
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<3xi1>, %[[ARG1:.*]]: tensor<2x3xf16>, %[[ARG2:.*]]: tensor<2x3xf16>)
func.func @selectv2_broadcast_tensor_pred(%arg0: tensor<3xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> {
- // CHECK-NEXT: %[[BROADCAST:.*]] = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<3xi1>) -> tensor<2x3xi1>
- // CHECK-NEXT: stablehlo.select %[[BROADCAST]], %arg1, %arg2
%0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16>
func.return %0: tensor<2x3xf16>
+
+ // CHECK: stablehlo.dynamic_broadcast_in_dim
+ // CHECK: stablehlo.select
+
}
// CHECK-LABEL: func @selectv2_broadcast_all
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<8x1x1xi1>, %[[ARG1:.*]]: tensor<1x8x1xi32>, %[[ARG2:.*]]: tensor<1x1x8xi32>)
func.func @selectv2_broadcast_all(%arg0: tensor<8x1x1xi1>, %arg1: tensor<1x8x1xi32>, %arg2: tensor<1x1x8xi32>) -> tensor<8x8x8xi32> {
- // CHECK-DAG: %[[BROADCAST_0:.*]] = stablehlo.broadcast_in_dim %arg0, dims = [0, 1, 2] : (tensor<8x1x1xi1>) -> tensor<8x8x8xi1>
- // CHECK-DAG: %[[BROADCAST_1:.*]] = stablehlo.broadcast_in_dim %arg1, dims = [0, 1, 2] : (tensor<1x8x1xi32>) -> tensor<8x8x8xi32>
- // CHECK-DAG: %[[BROADCAST_2:.*]] = stablehlo.broadcast_in_dim %arg2, dims = [0, 1, 2] : (tensor<1x1x8xi32>) -> tensor<8x8x8xi32>
- // CHECK: stablehlo.select %[[BROADCAST_0]], %[[BROADCAST_1]], %[[BROADCAST_2]]
%0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<8x1x1xi1>, tensor<1x8x1xi32>, tensor<1x1x8xi32>) -> tensor<8x8x8xi32>
func.return %0: tensor<8x8x8xi32>
+
+ // CHECK: stablehlo.dynamic_broadcast_in_dim
+ // CHECK: stablehlo.select
+
}
// CHECK-LABEL: func @selectv2_dynamic_ranked
func.func @selectv2_dynamic_ranked(%arg0: tensor<1xi1>, %arg1: tensor<2x?x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x?x8xi32> {
- // CHECK-DAG: %[[SHAPE0:.*]] = shape.const_shape [1] : tensor<1xindex>
- // CHECK-DAG: %[[SHAPE2:.*]] = shape.const_shape [2, 8, 8] : tensor<3xindex>
- // CHECK-NEXT: %[[SHAPE1:.*]] = shape.shape_of %arg1 : tensor<2x?x8xi32> -> tensor<3xindex>
- // CHECK-NEXT: %[[CSTR:.*]] = shape.cstr_broadcastable %[[SHAPE1]], %[[SHAPE0]], %[[SHAPE2]] : tensor<3xindex>, tensor<1xindex>, tensor<3xindex>
- // CHECK-NEXT: %[[ASSUME:.*]] = shape.assuming %[[CSTR]] -> (tensor<2x?x8xi32>) {
- // CHECK-NEXT: %[[BCST:.*]] = shape.broadcast %[[SHAPE1]], %[[SHAPE2]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex>
- // CHECK-NEXT: %[[BCST0:.*]] = stablehlo.dynamic_broadcast_in_dim %arg0, %[[BCST]], dims = [2] : (tensor<1xi1>, tensor<3xindex>) -> tensor<2x?x8xi1>
- // CHECK-NEXT: %[[BCST1:.*]] = stablehlo.dynamic_broadcast_in_dim %arg1, %[[BCST]], dims = [0, 1, 2] : (tensor<2x?x8xi32>, tensor<3xindex>) -> tensor<2x?x8xi32>
- // CHECK-NEXT: %[[BCST2:.*]] = stablehlo.dynamic_broadcast_in_dim %arg2, %[[BCST]], dims = [0, 1, 2] : (tensor<2x8x8xi32>, tensor<3xindex>) -> tensor<2x?x8xi32>
- // CHECK-NEXT: %[[SELECT:.*]] = stablehlo.select %[[BCST0]], %[[BCST1]], %[[BCST2]] : tensor<2x?x8xi1>, tensor<2x?x8xi32>
- // CHECK-NEXT: shape.assuming_yield %[[SELECT]] : tensor<2x?x8xi32>
- // CHECK-NEXT: }
- // CHECK-NEXT: return %[[ASSUME]] : tensor<2x?x8xi32>
+ // CHECK: stablehlo.dynamic_broadcast_in_dim
+ // CHECK: stablehlo.select
+
%0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x?x8xi32>, tensor<2x8x8xi32>) -> tensor<2x?x8xi32>
func.return %0: tensor<2x?x8xi32>
}
diff --git a/compiler/plugins/input/StableHLO/PluginRegistration.cpp b/compiler/plugins/input/StableHLO/PluginRegistration.cpp
index df9f1a4..db4dead 100644
--- a/compiler/plugins/input/StableHLO/PluginRegistration.cpp
+++ b/compiler/plugins/input/StableHLO/PluginRegistration.cpp
@@ -13,6 +13,7 @@
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "stablehlo/dialect/VhloOps.h"
+#include "stablehlo/transforms/Passes.h"
namespace mlir::iree_compiler::stablehlo {
@@ -61,6 +62,7 @@
static void registerPasses() {
// TODO(scotttodd): register other StableHLO passes?
registerStableHLOConversionPasses();
+ mlir::stablehlo::registerChloLegalizeToStablehloPass();
}
void onRegisterDialects(DialectRegistry &registry) override {