| // Copyright 2019 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #ifndef IREE_COMPILER_PLUGINS_INPUT_STABLEHLO_CONVERSION_MAP_STABLEHLO_TO_SCALAR_OP_H |
| #define IREE_COMPILER_PLUGINS_INPUT_STABLEHLO_CONVERSION_MAP_STABLEHLO_TO_SCALAR_OP_H |
| |
| #include <optional> |
| |
| #include "llvm/ADT/ArrayRef.h" |
| #include "llvm/ADT/StringRef.h" |
| #include "llvm/ADT/StringSwitch.h" |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/Dialect/Complex/IR/Complex.h" |
| #include "mlir/Dialect/Math/IR/Math.h" |
| #include "mlir/Dialect/SCF/IR/SCF.h" |
| #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/IR/ImplicitLocOpBuilder.h" |
| #include "mlir/IR/TypeUtilities.h" |
| #include "stablehlo/dialect/StablehloOps.h" |
| |
| namespace mlir { |
| namespace stablehlo { |
| namespace impl { |
| |
// Trait associating a StableHLO op type with the scalar op to emit for each
// element-type category:
//   FOp - floating-point, IOp - signed integer, UOp - unsigned integer,
//   COp - complex.
// This primary template is the fallback used for ops without an explicit
// specialization: every category maps to `void`, i.e. "no scalar op".
template <typename StableHloBinaryOpTy>
struct StableHloToScalarOp {
  using COp = void;
  using UOp = void;
  using IOp = void;
  using FOp = void;
};
| |
| template <> |
| struct StableHloToScalarOp<stablehlo::AddOp> { |
| using FOp = ::mlir::arith::AddFOp; |
| using IOp = ::mlir::arith::AddIOp; |
| using UOp = ::mlir::arith::AddIOp; |
| using COp = ::mlir::complex::AddOp; |
| }; |
| template <> |
| struct StableHloToScalarOp<stablehlo::AndOp> { |
| using IOp = ::mlir::arith::AndIOp; |
| using UOp = ::mlir::arith::AndIOp; |
| }; |
| template <> |
| struct StableHloToScalarOp<stablehlo::CbrtOp> { |
| using FOp = ::mlir::math::CbrtOp; |
| }; |
| template <> |
| struct StableHloToScalarOp<stablehlo::CompareOp> { |
| using FOp = ::mlir::arith::CmpFOp; |
| using IOp = ::mlir::arith::CmpIOp; |
| using UOp = ::mlir::arith::CmpIOp; |
| }; |
| template <> |
| struct StableHloToScalarOp<stablehlo::CeilOp> { |
| using FOp = ::mlir::math::CeilOp; |
| }; |
| template <> |
| struct StableHloToScalarOp<stablehlo::ClzOp> { |
| using IOp = ::mlir::math::CountLeadingZerosOp; |
| using UOp = ::mlir::math::CountLeadingZerosOp; |
| }; |
| template <> |
| struct StableHloToScalarOp<stablehlo::CosineOp> { |
| using FOp = ::mlir::math::CosOp; |
| using COp = ::mlir::complex::CosOp; |
| }; |
| template <> |
| struct StableHloToScalarOp<stablehlo::ExpOp> { |
| using FOp = ::mlir::math::ExpOp; |
| using COp = ::mlir::complex::ExpOp; |
| }; |
| template <> |
| struct StableHloToScalarOp<stablehlo::Expm1Op> { |
| using FOp = ::mlir::math::ExpM1Op; |
| using COp = ::mlir::complex::Expm1Op; |
| }; |
| template <> |
| struct StableHloToScalarOp<stablehlo::FloorOp> { |
| using FOp = ::mlir::math::FloorOp; |
| }; |
| template <> |
| struct StableHloToScalarOp<stablehlo::LogOp> { |
| using FOp = ::mlir::math::LogOp; |
| using COp = ::mlir::complex::LogOp; |
| }; |
| template <> |
| struct StableHloToScalarOp<stablehlo::Log1pOp> { |
| using FOp = ::mlir::math::Log1pOp; |
| using COp = ::mlir::complex::Log1pOp; |
| }; |
| template <> |
| struct StableHloToScalarOp<stablehlo::MulOp> { |
| using FOp = ::mlir::arith::MulFOp; |
| using IOp = ::mlir::arith::MulIOp; |
| using UOp = ::mlir::arith::MulIOp; |
| using COp = ::mlir::complex::MulOp; |
| }; |
| template <> |
| struct StableHloToScalarOp<stablehlo::OrOp> { |
| using IOp = ::mlir::arith::OrIOp; |
| using UOp = ::mlir::arith::OrIOp; |
| }; |
| template <> |
| struct StableHloToScalarOp<stablehlo::PopulationCountOp> { |
| using IOp = ::mlir::math::CtPopOp; |
| using UOp = ::mlir::math::CtPopOp; |
| }; |
| template <> |
| struct StableHloToScalarOp<stablehlo::RsqrtOp> { |
| using FOp = ::mlir::math::RsqrtOp; |
| using COp = ::mlir::complex::RsqrtOp; |
| }; |
| template <> |
| struct StableHloToScalarOp<stablehlo::RoundNearestEvenOp> { |
| using FOp = ::mlir::math::RoundEvenOp; |
| }; |
| template <> |
| struct StableHloToScalarOp<stablehlo::RoundOp> { |
| using FOp = ::mlir::math::RoundOp; |
| }; |
| template <> |
| struct StableHloToScalarOp<stablehlo::SubtractOp> { |
| using FOp = ::mlir::arith::SubFOp; |
| using IOp = ::mlir::arith::SubIOp; |
| using UOp = ::mlir::arith::SubIOp; |
| using COp = ::mlir::complex::SubOp; |
| }; |
| template <> |
| struct StableHloToScalarOp<stablehlo::SqrtOp> { |
| using FOp = ::mlir::math::SqrtOp; |
| using COp = ::mlir::complex::SqrtOp; |
| }; |
| template <> |
| struct StableHloToScalarOp<stablehlo::SineOp> { |
| using FOp = ::mlir::math::SinOp; |
| using COp = ::mlir::complex::SinOp; |
| }; |
| // FIXME(Jakub) |
| /* |
| template <> |
| struct StableHloToScalarOp<stablehlo::TanOp> { |
| using FOp = ::mlir::math::TanOp; |
| using COp = ::mlir::complex::TanOp; |
| }; |
| */ |
| template <> |
| struct StableHloToScalarOp<stablehlo::Atan2Op> { |
| using FOp = ::mlir::math::Atan2Op; |
| using COp = ::mlir::complex::Atan2Op; |
| }; |
| template <> |
| struct StableHloToScalarOp<stablehlo::TanhOp> { |
| using FOp = ::mlir::math::TanhOp; |
| using COp = ::mlir::complex::TanhOp; |
| }; |
| template <> |
| struct StableHloToScalarOp<stablehlo::XorOp> { |
| using IOp = ::mlir::arith::XOrIOp; |
| using UOp = ::mlir::arith::XOrIOp; |
| }; |
| |
| // Alias for the map from StableHLO binary op type to STD floating-point op |
| // type. |
| template <typename StableHloOp> |
| using ScalarFOp = typename StableHloToScalarOp<StableHloOp>::FOp; |
| // Alias for the map from StableHLO binary op type to STD signed integer op |
| // type. |
| template <typename StableHloOp> |
| using ScalarIOp = typename StableHloToScalarOp<StableHloOp>::IOp; |
| // Alias for the map from StableHLO binary op type to STD unsigned integer op |
| // type. |
| template <typename StableHloOp> |
| using ScalarUOp = typename StableHloToScalarOp<StableHloOp>::UOp; |
| // Alias for the map from StableHLO binary op type to STD complex op type. |
| template <typename StableHloOp> |
| using ScalarCOp = typename StableHloToScalarOp<StableHloOp>::COp; |
| |
| template <typename... Args> |
| struct MapStableHloOpToScalarOpImpl { |
| Value operator()(Location /*loc*/, ArrayRef<Type> /*ResultTypes*/, |
| ArrayRef<Type> /*argTypes*/, ValueRange /*args*/, |
| OpBuilder * /*b*/) { |
| return nullptr; |
| } |
| }; |
| |
| template <typename StdScalarOp> |
| struct MapStableHloOpToScalarOpImpl<StdScalarOp> { |
| Value operator()(Location loc, ArrayRef<Type> resultTypes, |
| ArrayRef<Type> /*argTypes*/, ValueRange args, OpBuilder *b) { |
| return b->template create<StdScalarOp>(loc, resultTypes, args, |
| std::nullopt); |
| } |
| }; |
| |
| template <typename SupportedType, typename StdScalarOp, typename... Args> |
| struct MapStableHloOpToScalarOpImpl<SupportedType, StdScalarOp, Args...> { |
| Value operator()(Location loc, ArrayRef<Type> resultTypes, |
| ArrayRef<Type> argTypes, ValueRange args, OpBuilder *b) { |
| Type elementType = getElementTypeOrSelf(argTypes.front()); |
| if (SupportedType{}(elementType)) { |
| return b->template create<StdScalarOp>(loc, resultTypes, args, |
| std::nullopt); |
| } |
| return MapStableHloOpToScalarOpImpl<Args...>{}(loc, resultTypes, argTypes, |
| args, b); |
| } |
| }; |
| |
| template <typename SupportedType, typename... Args> |
| struct MapStableHloOpToScalarOpImpl<SupportedType, void, Args...> { |
| Value operator()(Location loc, ArrayRef<Type> resultTypes, |
| ArrayRef<Type> argTypes, ValueRange args, OpBuilder *b) { |
| return MapStableHloOpToScalarOpImpl<Args...>{}(loc, resultTypes, argTypes, |
| args, b); |
| } |
| }; |
| |
| struct IsAnyIntegerType { |
| bool operator()(Type t) { return t.isa<IntegerType>(); } |
| }; |
| |
| struct IsSignedIntegerType { |
| bool operator()(Type t) { |
| // Pretend that signless is signed. This will change eventually. |
| return t.isa<IntegerType>() && !t.isUnsignedInteger() && |
| !t.isSignlessInteger(1); |
| } |
| }; |
| |
| struct IsUnsignedIntegerType { |
| bool operator()(Type t) { |
| return t.isUnsignedInteger() || t.isSignlessInteger(1); |
| } |
| }; |
| |
| struct IsFloatType { |
| bool operator()(Type t) { return t.isa<FloatType>(); } |
| }; |
| |
| struct IsComplexType { |
| bool operator()(Type t) { return t.isa<ComplexType>(); } |
| }; |
| |
| template <template <typename T> class MapTy, typename OpTy, |
| typename PredTy = llvm::is_detected<MapTy, OpTy>> |
| struct MapableIf { |
| using type = void; |
| }; |
| template <template <typename T> class MapTy, typename OpTy> |
| struct MapableIf<MapTy, OpTy, std::true_type> { |
| using type = MapTy<OpTy>; |
| }; |
| |
| // Inserts the computation that corresponds to the body of the loop for lowered |
| // StableHLO unary/binary op. Returns the value for the result. |
| template <typename StableHloOpTy> |
| inline Value mapStableHloOpToStdScalarOp( |
| Location loc, ArrayRef<Type> resultTypes, ArrayRef<Type> argTypes, |
| typename StableHloOpTy::Adaptor adaptor, OpBuilder *b) { |
| using ScalarIOpOrVoid = typename MapableIf<ScalarIOp, StableHloOpTy>::type; |
| using ScalarUOpOrVoid = typename MapableIf<ScalarUOp, StableHloOpTy>::type; |
| using ScalarFOpOrVoid = typename MapableIf<ScalarFOp, StableHloOpTy>::type; |
| using ScalarCOpOrVoid = typename MapableIf<ScalarCOp, StableHloOpTy>::type; |
| return MapStableHloOpToScalarOpImpl<IsSignedIntegerType, ScalarIOpOrVoid, |
| IsUnsignedIntegerType, ScalarUOpOrVoid, |
| IsFloatType, ScalarFOpOrVoid, |
| IsComplexType, ScalarCOpOrVoid>{}( |
| loc, resultTypes, argTypes, adaptor.getOperands(), b); |
| } |
| |
| template <> |
| inline Value mapStableHloOpToStdScalarOp<stablehlo::AbsOp>( |
| Location loc, ArrayRef<Type> resultTypes, ArrayRef<Type> argTypes, |
| stablehlo::AbsOp::Adaptor adaptor, OpBuilder *b) { |
| Type elementType = getElementTypeOrSelf(argTypes.front()); |
| if (elementType.isa<FloatType>()) { |
| return MapStableHloOpToScalarOpImpl<IsFloatType, ::mlir::math::AbsFOp>{}( |
| loc, resultTypes, argTypes, adaptor.getOperands(), b); |
| } |
| if (elementType.isa<ComplexType>()) { |
| return MapStableHloOpToScalarOpImpl<IsComplexType, |
| ::mlir::complex::AbsOp>{}( |
| loc, resultTypes, argTypes, adaptor.getOperands(), b); |
| } |
| if (elementType.isSignlessInteger() || elementType.isSignedInteger()) { |
| // lmhlo.abs(x, result) -> result = select((x > 0), x, sub(0, x)) |
| Value lhs = adaptor.getOperand(); |
| Value zeroIntval = |
| b->create<arith::ConstantOp>(loc, b->getZeroAttr(lhs.getType())); |
| auto lhsGtZero = b->create<ScalarIOp<CompareOp>>( |
| loc, arith::CmpIPredicate::sge, lhs, zeroIntval); |
| auto negVal = |
| b->create<ScalarIOp<stablehlo::SubtractOp>>(loc, zeroIntval, lhs); |
| return b->create<::mlir::arith::SelectOp>(loc, lhsGtZero, lhs, negVal); |
| } |
| return nullptr; |
| } |
| |
| // Return a constant for v of type t, splat if t is a vector type. |
| inline Value getConstantOrSplat(OpBuilder *b, Location loc, Type t, |
| Attribute v) { |
| if (VectorType vecType = t.dyn_cast<VectorType>()) { |
| v = SplatElementsAttr::get(vecType, v); |
| } |
| return b->create<arith::ConstantOp>(loc, t, cast<TypedAttr>(v)); |
| } |
| |
| template <typename PredicateType> |
| inline std::optional<PredicateType> |
| getCmpPredicate(stablehlo::ComparisonDirection, bool) { |
| return std::nullopt; |
| } |
| |
| template <> |
| inline std::optional<arith::CmpFPredicate> |
| getCmpPredicate<arith::CmpFPredicate>( |
| stablehlo::ComparisonDirection comparisonDirection, bool isSigned) { |
| assert(isSigned && "cannot have an unsigned float!"); |
| return llvm::StringSwitch<std::optional<arith::CmpFPredicate>>( |
| stringifyComparisonDirection(comparisonDirection)) |
| .Case("EQ", arith::CmpFPredicate::OEQ) |
| .Case("NE", arith::CmpFPredicate::UNE) |
| .Case("GE", arith::CmpFPredicate::OGE) |
| .Case("GT", arith::CmpFPredicate::OGT) |
| .Case("LE", arith::CmpFPredicate::OLE) |
| .Case("LT", arith::CmpFPredicate::OLT) |
| .Default(std::nullopt); |
| } |
| |
| template <> |
| inline std::optional<arith::CmpIPredicate> |
| getCmpPredicate<arith::CmpIPredicate>( |
| stablehlo::ComparisonDirection comparisonDirection, bool isSigned) { |
| return llvm::StringSwitch<std::optional<arith::CmpIPredicate>>( |
| stringifyComparisonDirection(comparisonDirection)) |
| .Case("EQ", arith::CmpIPredicate::eq) |
| .Case("NE", arith::CmpIPredicate::ne) |
| .Case("GE", |
| isSigned ? arith::CmpIPredicate::sge : arith::CmpIPredicate::uge) |
| .Case("GT", |
| isSigned ? arith::CmpIPredicate::sgt : arith::CmpIPredicate::ugt) |
| .Case("LE", |
| isSigned ? arith::CmpIPredicate::sle : arith::CmpIPredicate::ule) |
| .Case("LT", |
| isSigned ? arith::CmpIPredicate::slt : arith::CmpIPredicate::ult) |
| .Default(std::nullopt); |
| } |
| |
| inline Value cmpComplex(Location loc, Value lhs, Value rhs, |
| stablehlo::ComparisonDirection comparisonDirection, |
| OpBuilder *b) { |
| auto complexType = lhs.getType().cast<ComplexType>(); |
| if (complexType.getElementType().isa<FloatType>()) { |
| if (comparisonDirection == stablehlo::ComparisonDirection::EQ) { |
| return b->create<complex::EqualOp>(loc, lhs, rhs); |
| } |
| if (comparisonDirection == stablehlo::ComparisonDirection::NE) { |
| return b->create<complex::NotEqualOp>(loc, lhs, rhs); |
| } |
| |
| // Perform a lexicographical comparison for the (real, imaginary) pair. |
| Type complexFloatTy = complexType.getElementType(); |
| |
| Value lhsReal = b->create<complex::ReOp>(loc, complexFloatTy, lhs); |
| Value rhsReal = b->create<complex::ReOp>(loc, complexFloatTy, rhs); |
| |
| Value lhsImag = b->create<complex::ImOp>(loc, complexFloatTy, lhs); |
| Value rhsImag = b->create<complex::ImOp>(loc, complexFloatTy, rhs); |
| |
| auto predicate = getCmpPredicate<arith::CmpFPredicate>(comparisonDirection, |
| /*is_signed=*/true); |
| assert(predicate.has_value() && "expected valid comparison direction"); |
| |
| // if (lhsReal == rhsReal && lhsImag `predicate` rhsImag || |
| // lhsReal `predicate` rhsReal) |
| Value realsAreEq = b->create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ, |
| lhsReal, rhsReal); |
| Value imagsAreOrdered = |
| b->create<arith::CmpFOp>(loc, *predicate, lhsImag, rhsImag); |
| Value realsAreOrdered = |
| b->create<arith::CmpFOp>(loc, *predicate, lhsReal, rhsReal); |
| |
| Value orLhs = b->create<arith::AndIOp>(loc, realsAreEq, imagsAreOrdered); |
| return b->create<arith::OrIOp>(loc, orLhs, realsAreOrdered); |
| } |
| return nullptr; |
| } |
| |
| template <> |
| inline Value mapStableHloOpToStdScalarOp<stablehlo::CompareOp>( |
| Location loc, ArrayRef<Type> /*resultTypes*/, ArrayRef<Type> argTypes, |
| stablehlo::CompareOp::Adaptor adaptor, OpBuilder *b) { |
| stablehlo::ComparisonDirection comparisonDirection = |
| adaptor.getComparisonDirection(); |
| const auto &lhs = adaptor.getLhs(); |
| const auto &rhs = adaptor.getRhs(); |
| Type elementType = getElementTypeOrSelf(argTypes.front()); |
| if (elementType.isa<IntegerType>()) { |
| bool isUnsigned = IsUnsignedIntegerType{}(elementType); |
| std::optional<arith::CmpIPredicate> predicate = |
| getCmpPredicate<arith::CmpIPredicate>(comparisonDirection, !isUnsigned); |
| assert(predicate.has_value() && "expected valid comparison direction"); |
| return b->create<ScalarIOp<stablehlo::CompareOp>>(loc, predicate.value(), |
| lhs, rhs); |
| } |
| if (auto floatType = elementType.dyn_cast<FloatType>()) { |
| if (adaptor.getCompareType() && |
| *adaptor.getCompareType() == stablehlo::ComparisonType::TOTALORDER) { |
| // The semantics of totalorder fp compare are |
| // -NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN |
| auto intType = b->getIntegerType(floatType.getWidth()); |
| auto zero = |
| b->create<arith::ConstantOp>(loc, intType, b->getZeroAttr(intType)); |
| auto max = b->create<arith::ConstantOp>( |
| loc, intType, |
| b->getIntegerAttr(intType, |
| APInt::getSignedMaxValue(floatType.getWidth()))); |
| // Switch from a floating point value to a integer value in such a way |
| // that when using the integer value to compare, we get the same result |
| // for normal values, and -NaN is treated as the smallest value, and NaN |
| // is treated as the largest value. |
| // If f is a float, and |
| // x = bit_cast<int32_t>(f); |
| // y = x < 0 ? numeric_limits<int32_t>::max() - x : x; |
| // then y is ordered as an int32_t such that finite values have the |
| // obvious order, -0 is ordered before 0, and -NaN and NaN appear at the |
| // beginning and end of the ordering. |
| auto toIntegral = [&](Value v) { |
| auto x = b->create<arith::BitcastOp>(loc, intType, v); |
| auto cmp = |
| b->create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, x, zero); |
| auto sub = b->create<arith::SubIOp>(loc, max, x); |
| return b->create<arith::SelectOp>(loc, cmp, sub, x); |
| }; |
| auto lhsInt = toIntegral(lhs); |
| auto rhsInt = toIntegral(rhs); |
| auto predicate = |
| getCmpPredicate<arith::CmpIPredicate>(comparisonDirection, |
| /*is_signed=*/true); |
| assert(predicate.has_value() && "expected valid comparison direction"); |
| return b->create<arith::CmpIOp>(loc, *predicate, lhsInt, rhsInt); |
| } |
| std::optional<arith::CmpFPredicate> predicate = |
| getCmpPredicate<arith::CmpFPredicate>(comparisonDirection, |
| /*is_signed=*/true); |
| assert(predicate.has_value() && "expected valid comparison direction"); |
| return b->create<ScalarFOp<stablehlo::CompareOp>>(loc, predicate.value(), |
| lhs, rhs); |
| } |
| if (auto complexType = elementType.dyn_cast<ComplexType>()) |
| return cmpComplex(loc, lhs, rhs, comparisonDirection, b); |
| return nullptr; |
| } |
| |
| template <> |
| inline Value mapStableHloOpToStdScalarOp<stablehlo::ReducePrecisionOp>( |
| Location loc, ArrayRef<Type> /*resultTypes*/, ArrayRef<Type> argTypes, |
| stablehlo::ReducePrecisionOp::Adaptor adaptor, OpBuilder *builder) { |
| using llvm::APInt; |
| mlir::ImplicitLocOpBuilder b(loc, *builder); |
| |
| // Integer and float types for casting and constant generation. |
| auto floatType = |
| argTypes.front().cast<TensorType>().getElementType().cast<FloatType>(); |
| int64_t nbits = floatType.getWidth(); |
| auto intType = mlir::IntegerType::get(loc.getContext(), floatType.getWidth()); |
| |
| Value xAsInt = b.create<arith::BitcastOp>(intType, adaptor.getOperand()); |
| |
| // SignificandWidth includes the implicit extra bit. |
| auto srcMantissaBits = floatType.getFPMantissaWidth() - 1; |
| int srcExponentBits = nbits - 1 - srcMantissaBits; |
| |
| // Clear the sign bit, it does not participate in rounding and we will restore |
| // it later. |
| APInt signBitMask(nbits, 1); |
| signBitMask <<= nbits - 1; |
| |
| APInt expBitsMask(nbits, 1); |
| expBitsMask = ((expBitsMask << srcExponentBits) - 1) << srcMantissaBits; |
| |
| auto createConstant = [&](const APInt &v) { |
| return b.create<arith::ConstantIntOp>(v.getZExtValue(), intType) |
| .getResult(); |
| }; |
| |
| Value xAbsBits = |
| b.create<arith::AndIOp>(xAsInt, createConstant(~signBitMask)); |
| Value xIsNan = b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, xAbsBits, |
| createConstant(expBitsMask)); |
| |
| int destMantissaBits = adaptor.getMantissaBits(); |
| if (destMantissaBits < static_cast<int>(srcMantissaBits)) { |
| // Last remaining mantissa bit. |
| APInt lastMantissaBitMask(nbits, 1); |
| lastMantissaBitMask <<= srcMantissaBits - destMantissaBits; |
| |
| // Compute rounding bias for round-to-nearest with ties to even. This is |
| // equal to a base value of 0111... plus one bit if the last remaining |
| // mantissa bit is 1. |
| APInt baseRoundingBias = lastMantissaBitMask.lshr(1) - 1; |
| |
| Value mantissaDiff = b.create<arith::ConstantIntOp>( |
| srcMantissaBits - destMantissaBits, intType); |
| Value highestMantissaMaskVal = createConstant(lastMantissaBitMask); |
| Value baseRoundingBiasVal = createConstant(baseRoundingBias); |
| Value xLastMantissaBit = b.create<arith::ShRUIOp>( |
| b.create<arith::AndIOp>(xAsInt, highestMantissaMaskVal), mantissaDiff); |
| Value xRoundingBias = |
| b.create<arith::AddIOp>(xLastMantissaBit, baseRoundingBiasVal); |
| |
| // Add rounding bias, and mask out truncated bits. Note that the case |
| // where adding the rounding bias overflows into the exponent bits is |
| // correct; the non-masked mantissa bits will all be zero, and the |
| // exponent will be incremented by one. |
| APInt truncationMask = ~(lastMantissaBitMask - 1); |
| Value xRounded = b.create<arith::AddIOp>(xAsInt, xRoundingBias); |
| xAsInt = b.create<arith::AndIOp>(xRounded, createConstant(truncationMask)); |
| } |
| |
| int destExponentBits = adaptor.getExponentBits(); |
| if (destExponentBits < srcExponentBits) { |
| // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most- |
| // significant bit -- is equal to 1.0f for all exponent sizes. Adding |
| // 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit- |
| // size of n, and subtracting 2^(n-1)-1 from this gives us the lowest' |
| // exponent (corresponding to 0.0f). |
| // |
| // Thus, the f32 exponent corresponding to the highest non-infinite |
| // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32 |
| // exponent corresponding to the lowest exponent for a bit size of n is |
| // (2^7-1) - 2^(n-1)-1. |
| // |
| // Note that we have already checked that exponents_bits >= 1. |
| APInt exponentBias(nbits, 1); |
| exponentBias = (exponentBias << (srcExponentBits - 1)) - 1; |
| |
| APInt reducedExponentBias(nbits, 1); |
| reducedExponentBias = (reducedExponentBias << (destExponentBits - 1)) - 1; |
| |
| APInt reducedMaxExponent = exponentBias + reducedExponentBias; |
| APInt reducedMinExponent = exponentBias - reducedExponentBias; |
| |
| // Do we overflow or underflow? |
| Value xExponent = |
| b.create<arith::AndIOp>(xAsInt, createConstant(expBitsMask)); |
| Value xOverflows = b.create<arith::CmpIOp>( |
| arith::CmpIPredicate::ugt, xExponent, |
| createConstant(reducedMaxExponent << srcMantissaBits)); |
| Value xUnderflows = b.create<arith::CmpIOp>( |
| arith::CmpIPredicate::ule, xExponent, |
| createConstant(reducedMinExponent << srcMantissaBits)); |
| |
| // Compute appropriately-signed values of zero and infinity. |
| Value xSignedZero = |
| b.create<arith::AndIOp>(xAsInt, createConstant(signBitMask)); |
| Value xSignedInf = |
| b.create<arith::OrIOp>(xSignedZero, createConstant(expBitsMask)); |
| |
| // Force to zero or infinity if overflow or underflow. (Note that this |
| // truncates all denormal values to zero, rather than rounding them.) |
| xAsInt = b.create<arith::SelectOp>(xOverflows, xSignedInf, xAsInt); |
| xAsInt = b.create<arith::SelectOp>(xUnderflows, xSignedZero, xAsInt); |
| } |
| |
| Value result = b.create<arith::BitcastOp>(floatType, xAsInt); |
| return b.create<arith::SelectOp>(xIsNan, adaptor.getOperand(), result); |
| } |
| |
| // FIXME(Jakub) |
| // template <> |
| // inline Value mapStableHloOpToStdScalarOp<stablehlo::CopyOp>( |
| // Location /*loc*/, ArrayRef<Type> /*ResultTypes*/, |
| // ArrayRef<Type> /*argTypes*/, stablehlo::CopyOp::Adaptor adaptor, |
| // OpBuilder* /*b*/) { |
| // return adaptor.getOperands().front(); |
| // } |
| |
| template <> |
| inline Value mapStableHloOpToStdScalarOp<stablehlo::ComplexOp>( |
| Location loc, ArrayRef<Type> resultTypes, ArrayRef<Type> argTypes, |
| stablehlo::ComplexOp::Adaptor adaptor, OpBuilder *b) { |
| return MapStableHloOpToScalarOpImpl<complex::CreateOp>{}( |
| loc, resultTypes, argTypes, adaptor.getOperands(), b); |
| } |
| |
| template <> |
| inline Value mapStableHloOpToStdScalarOp<stablehlo::MaxOp>( |
| Location loc, ArrayRef<Type> resultTypes, ArrayRef<Type> argTypes, |
| stablehlo::MaxOp::Adaptor adaptor, OpBuilder *b) { |
| ValueRange operands = adaptor.getOperands(); |
| Value lhs = operands.front(); |
| Type complexTy = lhs.getType(); |
| |
| if (!complexTy.isa<ComplexType>()) |
| return MapStableHloOpToScalarOpImpl< |
| IsFloatType, arith::MaximumFOp, IsSignedIntegerType, arith::MaxSIOp, |
| IsUnsignedIntegerType, arith::MaxUIOp>{}(loc, resultTypes, argTypes, |
| adaptor.getOperands(), b); |
| |
| assert(resultTypes.size() == 1 && "MaxOp should return a single result"); |
| assert(operands.size() == 2 && "MaxOp should take exactly two arguments"); |
| |
| Value rhs = operands.back(); |
| // 'max' performs a lexicographical comparison for the (real, imaginary) pair. |
| Value cond = cmpComplex(loc, lhs, rhs, stablehlo::ComparisonDirection::GE, b); |
| |
| return b->create<arith::SelectOp>(loc, cond, lhs, rhs).getResult(); |
| } |
| |
| template <> |
| inline Value mapStableHloOpToStdScalarOp<stablehlo::MinOp>( |
| Location loc, ArrayRef<Type> resultTypes, ArrayRef<Type> argTypes, |
| stablehlo::MinOp::Adaptor adaptor, OpBuilder *b) { |
| ValueRange operands = adaptor.getOperands(); |
| Value lhs = operands.front(); |
| Type complexTy = lhs.getType(); |
| |
| if (!complexTy.isa<ComplexType>()) |
| return MapStableHloOpToScalarOpImpl< |
| IsFloatType, arith::MinimumFOp, IsSignedIntegerType, arith::MinSIOp, |
| IsUnsignedIntegerType, arith::MinUIOp>{}(loc, resultTypes, argTypes, |
| adaptor.getOperands(), b); |
| |
| assert(resultTypes.size() == 1 && "MinOp should return a single result"); |
| assert(operands.size() == 2 && "MinOp should take exactly two arguments"); |
| |
| Value rhs = operands.back(); |
| // 'min' performs a lexicographical comparison for the (real, imaginary) pair. |
| Value cond = cmpComplex(loc, lhs, rhs, stablehlo::ComparisonDirection::LE, b); |
| |
| return b->create<arith::SelectOp>(loc, cond, lhs, rhs).getResult(); |
| } |
| |
| template <> |
| inline Value mapStableHloOpToStdScalarOp<stablehlo::RealOp>( |
| Location loc, ArrayRef<Type> resultTypes, ArrayRef<Type> argTypes, |
| stablehlo::RealOp::Adaptor adaptor, OpBuilder *b) { |
| if (!adaptor.getOperand().getType().isa<ComplexType>()) |
| return adaptor.getOperand(); |
| return MapStableHloOpToScalarOpImpl<complex::ReOp>{}( |
| loc, resultTypes, argTypes, adaptor.getOperands(), b); |
| } |
| |
| template <> |
| inline Value mapStableHloOpToStdScalarOp<stablehlo::ImagOp>( |
| Location loc, ArrayRef<Type> resultTypes, ArrayRef<Type> argTypes, |
| stablehlo::ImagOp::Adaptor adaptor, OpBuilder *b) { |
| if (!adaptor.getOperand().getType().isa<ComplexType>()) |
| return b->create<arith::ConstantOp>( |
| loc, b->getZeroAttr(adaptor.getOperand().getType())); |
| return MapStableHloOpToScalarOpImpl<complex::ImOp>{}( |
| loc, resultTypes, argTypes, adaptor.getOperands(), b); |
| } |
| |
| // 'target_types' is the unconverted type (signed or unsigned if integer), |
| // 'ResultTypes' is the converted type (signless if integer). |
| inline Value mapConvertOpToStdScalarOp(Location loc, ArrayRef<Type> targetTypes, |
| ArrayRef<Type> resultTypes, |
| ArrayRef<Type> argTypes, ValueRange args, |
| OpBuilder *b) { |
| assert(targetTypes.size() == 1 && "ConvertOp should return a single result"); |
| assert(resultTypes.size() == 1 && "ConvertOp should return a single result"); |
| assert(argTypes.size() == 1 && "ConvertOp should take a single argument"); |
| assert(args.size() == 1 && "ConvertOp should take a single argument"); |
| |
| Type sourceType = getElementTypeOrSelf(argTypes.front()); |
| Type targetType = getElementTypeOrSelf(targetTypes.front()); |
| Type convertedSourceType = getElementTypeOrSelf(args.front()); |
| |
| // A boolean value is considered to be unsigned when converting to |
| // floating-point. Otherwise, it will become `-1`. |
| if (IsUnsignedIntegerType{}(sourceType) && |
| mlir::arith::UIToFPOp::areCastCompatible(convertedSourceType, |
| targetType)) { |
| return b->create<mlir::arith::UIToFPOp>(loc, resultTypes, args, |
| std::nullopt); |
| } |
| if (mlir::arith::SIToFPOp::areCastCompatible(sourceType, targetType)) { |
| return b->create<mlir::arith::SIToFPOp>(loc, resultTypes, args, |
| std::nullopt); |
| } |
| if (sourceType.isa<FloatType>() && targetType.isa<FloatType>()) { |
| auto src = sourceType.cast<FloatType>(); |
| auto res = targetType.cast<FloatType>(); |
| if (src.getWidth() > res.getWidth()) { |
| return b->create<mlir::arith::TruncFOp>(loc, resultTypes, args, |
| std::nullopt); |
| } |
| if (src.getWidth() < res.getWidth()) { |
| return b->create<mlir::arith::ExtFOp>(loc, resultTypes, args, |
| std::nullopt); |
| } |
| // There's no direct conversion between different 16 bit floating point |
| // types, so go through 32 bit float. |
| if (sourceType != targetType) { |
| assert(sourceType.isBF16() || targetType.isBF16()); |
| Value ext = b->create<arith::ExtFOp>(loc, b->getF32Type(), args); |
| return b->create<arith::TruncFOp>(loc, resultTypes, ext); |
| } |
| // No conversion is needed for identical float types. |
| return args.front(); |
| } |
| if (targetType.isInteger(/*width=*/1)) { |
| // When casting to bool, we need to compare whether the value is equal to |
| // zero. |
| if (sourceType.isSignlessInteger() || sourceType.isUnsignedInteger()) { |
| Value zeroIntval = b->create<arith::ConstantOp>( |
| loc, b->getZeroAttr(args.front().getType())); |
| return b->create<mlir::arith::CmpIOp>(loc, arith::CmpIPredicate::ne, |
| args.front(), zeroIntval); |
| } |
| if (sourceType.isa<FloatType>()) { |
| Value zero = b->create<arith::ConstantOp>( |
| loc, b->getZeroAttr(args.front().getType())); |
| return b->create<mlir::arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, |
| args.front(), zero); |
| } |
| } |
| if (sourceType.isa<IntegerType>() && targetType.isa<IntegerType>()) { |
| auto src = sourceType.cast<IntegerType>(); |
| auto res = targetType.cast<IntegerType>(); |
| if (src.getWidth() > res.getWidth()) { |
| return b->create<mlir::arith::TruncIOp>(loc, resultTypes, args, |
| std::nullopt); |
| } |
| if (src.getWidth() < res.getWidth()) { |
| // Special case boolean values, so they get casted to `1` instead of `-1`. |
| if (IsUnsignedIntegerType{}(src)) { |
| return b->create<mlir::arith::ExtUIOp>(loc, resultTypes, args, |
| std::nullopt); |
| } |
| return b->create<mlir::arith::ExtSIOp>(loc, resultTypes, args, |
| std::nullopt); |
| } |
| // No conversion is needed for the same width integers |
| return args.front(); |
| } |
| if (targetType.isUnsignedInteger() && |
| mlir::arith::FPToUIOp::areCastCompatible(convertedSourceType, |
| targetType)) { |
| return b->create<mlir::arith::FPToUIOp>(loc, resultTypes, args, |
| std::nullopt); |
| } |
| if (mlir::arith::FPToSIOp::areCastCompatible(convertedSourceType, |
| targetType)) { |
| return b->create<mlir::arith::FPToSIOp>(loc, resultTypes, args, |
| std::nullopt); |
| } |
| if (targetType.isa<ComplexType>()) { |
| Type targetElementType = targetType.cast<ComplexType>().getElementType(); |
| assert(!targetElementType.isa<ComplexType>() && |
| "elements of complex numbers should not be complex"); |
| Value targetReal; |
| Value targetImag; |
| if (sourceType.isa<ComplexType>()) { |
| // We are converting from complex type: convert real and imaginary parts |
| // separately. |
| Type sourceElementType = sourceType.cast<ComplexType>().getElementType(); |
| assert(!sourceElementType.isa<ComplexType>() && |
| "elements of complex numbers should not be complex"); |
| Value sourceReal = |
| b->create<mlir::complex::ReOp>(loc, sourceElementType, args.front()); |
| targetReal = |
| mapConvertOpToStdScalarOp(loc, targetElementType, targetElementType, |
| sourceElementType, sourceReal, b); |
| Value sourceImag = |
| b->create<mlir::complex::ImOp>(loc, sourceElementType, args.front()); |
| targetImag = |
| mapConvertOpToStdScalarOp(loc, targetElementType, targetElementType, |
| sourceElementType, sourceImag, b); |
| } else { |
| // We are converting from real (float, integer, etc.) type, convert the |
| // real part and set the imaginary part to 0. |
| targetReal = mapConvertOpToStdScalarOp( |
| loc, targetElementType, targetElementType, argTypes, args, b); |
| targetImag = b->create<mlir::arith::ConstantOp>( |
| loc, b->getFloatAttr(targetElementType, 0.0)); |
| } |
| return b->create<mlir::complex::CreateOp>(loc, targetType, targetReal, |
| targetImag); |
| } |
| if (auto sourceComplexType = sourceType.dyn_cast<ComplexType>()) { |
| auto sourceElementType = sourceComplexType.getElementType(); |
| // When converting from complex to a non-complex type, we take just the real |
| // part of the complex number. |
| Value sourceReal = |
| b->create<mlir::complex::ReOp>(loc, sourceElementType, args.front()); |
| return mapConvertOpToStdScalarOp(loc, targetTypes, resultTypes, |
| sourceElementType, sourceReal, b); |
| } |
| return nullptr; |
| } |
| |
| /// Scalar lowering of stablehlo.bitcast_convert for the case where source and |
| /// result element types have the same bit width, so the op maps one-to-one |
| /// onto a scalar bitcast without any change of shape. Returns the operand |
| /// itself when the element types already match, and a null Value when either |
| /// element type is not an integer/float or when the bit widths differ. |
| template <> |
| inline Value mapStableHloOpToStdScalarOp<stablehlo::BitcastConvertOp>( |
| Location loc, ArrayRef<Type> resultTypes, ArrayRef<Type> argTypes, |
| stablehlo::BitcastConvertOp::Adaptor adaptor, OpBuilder *b) { |
| Type srcElemTy = getElementTypeOrSelf(argTypes.front()); |
| Type dstElemTy = getElementTypeOrSelf(resultTypes.front()); |
| |
| // Identical element types need no cast at all. |
| if (srcElemTy == dstElemTy) |
| return adaptor.getOperand(); |
| |
| // Only integer/float element types of equal width are supported. The width |
| // query is reached only after both types are known to be integer or float. |
| const bool bothIntOrFloat = isa<FloatType, IntegerType>(dstElemTy) && |
| isa<FloatType, IntegerType>(srcElemTy); |
| if (!bothIntOrFloat || dstElemTy.getIntOrFloatBitWidth() != |
| srcElemTy.getIntOrFloatBitWidth()) |
| return nullptr; |
| |
| return b->create<mlir::arith::BitcastOp>(loc, resultTypes, |
| adaptor.getOperands()); |
| } |
| |
| /// Scalar multiply-accumulate for stablehlo.dot: multiplies adaptor operands 0 |
| /// and 1 and adds adaptor operand 2 to the product. The dispatch is made on |
| /// the type of operand 0 as-is; float and integer types are handled, anything |
| /// else yields a null Value. |
| /// NOTE(review): the accumulator is read from a third adaptor operand, which |
| /// looks inherited from the buffer-based (lhlo) form of this op -- confirm |
| /// that callers always supply three values. |
| template <> |
| inline Value mapStableHloOpToStdScalarOp<stablehlo::DotOp>( |
| Location loc, ArrayRef<Type> resultTypes, ArrayRef<Type> argTypes, |
| stablehlo::DotOp::Adaptor adaptor, OpBuilder *b) { |
| Value lhs = adaptor.getOperands()[0]; |
| Value rhs = adaptor.getOperands()[1]; |
| Value accumulator = adaptor.getOperands()[2]; |
| Type operandType = lhs.getType(); |
| |
| if (isa<FloatType>(operandType)) { |
| using FloatMul = |
| MapStableHloOpToScalarOpImpl<IsFloatType, ::mlir::arith::MulFOp>; |
| using FloatAdd = |
| MapStableHloOpToScalarOpImpl<IsFloatType, ::mlir::arith::AddFOp>; |
| Value product = FloatMul{}(loc, resultTypes, argTypes, {lhs, rhs}, b); |
| return FloatAdd{}(loc, resultTypes, argTypes, {product, accumulator}, b); |
| } |
| |
| if (isa<IntegerType>(operandType)) { |
| using IntMul = |
| MapStableHloOpToScalarOpImpl<IsAnyIntegerType, ::mlir::arith::MulIOp>; |
| using IntAdd = |
| MapStableHloOpToScalarOpImpl<IsAnyIntegerType, ::mlir::arith::AddIOp>; |
| Value product = IntMul{}(loc, resultTypes, argTypes, {lhs, rhs}, b); |
| return IntAdd{}(loc, resultTypes, argTypes, {product, accumulator}, b); |
| } |
| |
| // Neither float nor integer: no scalar mapping. |
| return nullptr; |
| } |
| |
| /// Scalar lowering of stablehlo.is_finite for float operands: emits |
| /// `abs(x) ONE +inf`, an ordered not-equal comparison of the absolute value |
| /// against positive infinity. Returns a null Value when the operand type is |
| /// not a float type. |
| template <> |
| inline Value mapStableHloOpToStdScalarOp<stablehlo::IsFiniteOp>( |
| Location loc, ArrayRef<Type> /*ResultTypes*/, ArrayRef<Type> /*argTypes*/, |
| stablehlo::IsFiniteOp::Adaptor adaptor, OpBuilder *b) { |
| Value x = adaptor.getX(); |
| auto floatType = dyn_cast<FloatType>(x.getType()); |
| if (!floatType) |
| return nullptr; |
| |
| // +inf expressed in the float semantics of the operand type. |
| APFloat positiveInfinity = APFloat::getInf(floatType.getFloatSemantics()); |
| Value infinityConstant = b->create<arith::ConstantOp>( |
| loc, b->getFloatAttr(x.getType(), positiveInfinity)); |
| Value magnitude = b->create<::mlir::math::AbsFOp>(loc, x); |
| return b->create<::mlir::arith::CmpFOp>(loc, arith::CmpFPredicate::ONE, |
| magnitude, infinityConstant); |
| } |
| |
| /// Compare-select mapping of an HLO op to scalar ops (for use inside the |
| /// region of a linalg.generic op), used for min/max style operations. This is |
| /// the fallback template: it emits nothing and returns a null Value. |
| template <typename... Args> |
| struct CompareSelectOpToStdScalarOp { |
| static Value map(Location /*loc*/, |
| stablehlo::ComparisonDirection /*comparison_direction*/, |
| ArrayRef<Type> /*ResultTypes*/, ArrayRef<Type> /*argTypes*/, |
| ValueRange /*args*/, OpBuilder * /*b*/) { |
| // No candidate matched: report "no mapping" with a null Value. |
| return Value(); |
| } |
| }; |
| |
| /// Specialization that peels one (SupportedType, StdCompareOp, Predicate) |
| /// triple off the template argument list. When the element type of the first |
| /// argument type is a `SupportedType`, it emits `StdCompareOp` on args[0] and |
| /// args[1] with the predicate derived from `comparisonDirection`, followed by |
| /// a select that picks args[0] when the comparison holds and args[1] |
| /// otherwise. If the element type does not match, the remaining `Args...` |
| /// are tried. The predicate is requested as signed unless the element type is |
| /// an unsigned integer. |
| template <typename SupportedType, typename StdCompareOp, typename Predicate, |
| typename... Args> |
| struct CompareSelectOpToStdScalarOp<SupportedType, StdCompareOp, Predicate, |
| Args...> { |
| static Value map(Location loc, |
| stablehlo::ComparisonDirection comparisonDirection, |
| ArrayRef<Type> resultTypes, ArrayRef<Type> argTypes, |
| ValueRange args, OpBuilder *b) { |
| Type elementType = getElementTypeOrSelf(argTypes.front()); |
| if (elementType.isa<SupportedType>()) { |
| auto predicate = getCmpPredicate<Predicate>( |
| comparisonDirection, !elementType.isUnsignedInteger()); |
| assert(predicate.has_value() && "expected valid comparison direction"); |
| // Dereference with operator* rather than getValue(): std::optional has no |
| // getValue() member, so operator* keeps this template instantiable. |
| auto cmp = b->template create<StdCompareOp>(loc, *predicate, args[0], |
| args[1]); |
| return b->create<::mlir::arith::SelectOp>(loc, cmp, args[0], args[1]); |
| } |
| // Element type not handled by this triple: try the remaining candidates. |
| return CompareSelectOpToStdScalarOp<Args...>::map( |
| loc, comparisonDirection, resultTypes, argTypes, args, b); |
| } |
| }; |
| |
| /// Scalar lowering of stablehlo.clamp, composed from the scalar mappings of |
| /// max and min: clamp(min, x, max) = min(max(min, x), max). |
| template <> |
| inline Value mapStableHloOpToStdScalarOp<stablehlo::ClampOp>( |
| Location loc, ArrayRef<Type> resultTypes, ArrayRef<Type> argTypes, |
| stablehlo::ClampOp::Adaptor op, OpBuilder *b) { |
| // First raise the operand to at least the lower bound ... |
| Value lowerBounded = mapStableHloOpToStdScalarOp<stablehlo::MaxOp>( |
| loc, resultTypes, argTypes, ValueRange{op.getMin(), op.getOperand()}, b); |
| // ... then cap the intermediate value at the upper bound. |
| return mapStableHloOpToStdScalarOp<stablehlo::MinOp>( |
| loc, resultTypes, argTypes, ValueRange{lowerBounded, op.getMax()}, b); |
| } |
| |
| /// Emits an integer division-like op (`U` when `originalType` is an unsigned |
| /// integer, `S` otherwise) with its divisor guarded: |
| ///   - rhs == 0: the divisor is replaced by 1 and `returnedOnZero` is selected |
| ///     as the result. |
| ///   - signed only, lhs == INT_MIN and rhs == -1: the divisor is replaced by 1 |
| ///     and `returnedOnSignedOverflow` is selected as the result. |
| /// Constants are built for the type of `lhs`; `originalType` is consulted only |
| /// for its signedness. |
| template <typename U, typename S> |
| inline Value makeSafeIntDiv(ImplicitLocOpBuilder &lb, Type originalType, |
| Value lhs, Value rhs, Value returnedOnZero, |
| Value returnedOnSignedOverflow) { |
| Type type = lhs.getType(); |
| auto intType = cast<IntegerType>(getElementTypeOrSelf(type)); |
| const unsigned width = intType.getWidth(); |
| |
| // Small emitters shared by the unsigned and signed paths. |
| auto splat = [&](const APInt &value) { |
| return getConstantOrSplat(&lb, lb.getLoc(), type, |
| lb.getIntegerAttr(intType, value)); |
| }; |
| auto isEqual = [&](Value a, Value c) -> Value { |
| return lb.create<arith::CmpIOp>(arith::CmpIPredicate::eq, a, c); |
| }; |
| auto choose = [&](Value cond, Value onTrue, Value onFalse) -> Value { |
| return lb.create<arith::SelectOp>(cond, onTrue, onFalse); |
| }; |
| |
| Value zero = lb.create<arith::ConstantOp>(lb.getZeroAttr(type)); |
| Value one = splat(APInt(width, 1)); |
| Value rhsIsZero = isEqual(rhs, zero); |
| |
| // Unsigned: the only hazard is a zero divisor, which is swapped for 1. |
| if (originalType.isUnsignedInteger()) { |
| Value divisor = choose(rhsIsZero, one, rhs); |
| Value quotient = lb.create<U>(lhs, divisor); |
| return choose(rhsIsZero, returnedOnZero, quotient); |
| } |
| |
| // Signed: INT_MIN / -1 is a second hazard next to the zero divisor. |
| Value smin = splat(APInt::getSignedMinValue(width)); |
| Value lhsIsSmin = isEqual(lhs, smin); |
| Value minusOne = splat(APInt::getAllOnes(width)); |
| Value rhsIsMinusOne = isEqual(rhs, minusOne); |
| Value overflows = lb.create<arith::AndIOp>(lhsIsSmin, rhsIsMinusOne); |
| Value divisorIsUnsafe = lb.create<arith::OrIOp>(rhsIsZero, overflows); |
| Value divisor = choose(divisorIsUnsafe, one, rhs); |
| Value quotient = lb.create<S>(lhs, divisor); |
| Value overflowHandled = |
| choose(overflows, returnedOnSignedOverflow, quotient); |
| return choose(rhsIsZero, returnedOnZero, overflowHandled); |
| } |
| |
| /// Scalar lowering of stablehlo.divide. Float and complex element types go |
| /// through the generic float/complex division mapping. Integer division is |
| /// emitted through makeSafeIntDiv with these results for the guarded cases: |
| ///   X / 0            == -1 (all bits set) |
| ///   INT_SMIN /s -1   == INT_SMIN |
| template <> |
| inline Value mapStableHloOpToStdScalarOp<stablehlo::DivOp>( |
| Location loc, ArrayRef<Type> resultTypes, ArrayRef<Type> argTypes, |
| stablehlo::DivOp::Adaptor adaptor, OpBuilder *b) { |
| Type originalType = getElementTypeOrSelf(argTypes.front()); |
| if (isa<ComplexType, FloatType>(originalType)) { |
| using FloatOrComplexDiv = |
| MapStableHloOpToScalarOpImpl<IsFloatType, arith::DivFOp, IsComplexType, |
| complex::DivOp>; |
| return FloatOrComplexDiv{}(loc, resultTypes, argTypes, |
| adaptor.getOperands(), b); |
| } |
| |
| ImplicitLocOpBuilder lb(loc, *b); |
| Value lhs = adaptor.getLhs(); |
| Type type = lhs.getType(); |
| auto intType = cast<IntegerType>(getElementTypeOrSelf(type)); |
| const unsigned width = intType.getWidth(); |
| |
| // Result constants for the two guarded cases, built for the lhs type. |
| Value minusOne = |
| getConstantOrSplat(&lb, lb.getLoc(), type, |
| lb.getIntegerAttr(intType, APInt::getAllOnes(width))); |
| Value smin = getConstantOrSplat( |
| &lb, lb.getLoc(), type, |
| lb.getIntegerAttr(intType, APInt::getSignedMinValue(width))); |
| return makeSafeIntDiv<arith::DivUIOp, arith::DivSIOp>( |
| lb, originalType, lhs, adaptor.getRhs(), |
| /*returnedOnZero=*/minusOne, |
| /*returnedOnSignedOverflow=*/smin); |
| } |
| |
| /// Scalar lowering of stablehlo.remainder. Float and complex element types |
| /// are handed to the generic mapping, which lists only a float remainder op. |
| /// Integer remainder is emitted through makeSafeIntDiv with these results for |
| /// the guarded cases: |
| ///   X % 0            == X |
| ///   INT_SMIN %s -1   == 0 |
| template <> |
| inline Value mapStableHloOpToStdScalarOp<stablehlo::RemOp>( |
| Location loc, ArrayRef<Type> resultTypes, ArrayRef<Type> argTypes, |
| stablehlo::RemOp::Adaptor adaptor, OpBuilder *b) { |
| Type originalType = getElementTypeOrSelf(argTypes.front()); |
| if (isa<ComplexType, FloatType>(originalType)) { |
| using FloatRem = MapStableHloOpToScalarOpImpl<IsFloatType, arith::RemFOp>; |
| return FloatRem{}(loc, resultTypes, argTypes, adaptor.getOperands(), b); |
| } |
| |
| ImplicitLocOpBuilder lb(loc, *b); |
| Value lhs = adaptor.getLhs(); |
| Value zero = lb.create<arith::ConstantOp>(lb.getZeroAttr(lhs.getType())); |
| return makeSafeIntDiv<arith::RemUIOp, arith::RemSIOp>( |
| lb, originalType, lhs, adaptor.getRhs(), |
| /*returnedOnZero=*/lhs, |
| /*returnedOnSignedOverflow=*/zero); |
| } |
| |
| /// Scalar lowering of stablehlo.negate. Float and complex element types use |
| /// the dedicated negation ops; integers are negated as `0 - x`. Any other |
| /// element type yields a null Value. |
| template <> |
| inline Value mapStableHloOpToStdScalarOp<stablehlo::NegOp>( |
| Location loc, ArrayRef<Type> resultTypes, ArrayRef<Type> argTypes, |
| stablehlo::NegOp::Adaptor adaptor, OpBuilder *b) { |
| Type elementType = getElementTypeOrSelf(adaptor.getOperand().getType()); |
| |
| if (isa<ComplexType, FloatType>(elementType)) { |
| using FloatOrComplexNeg = |
| MapStableHloOpToScalarOpImpl<IsFloatType, ::mlir::arith::NegFOp, |
| IsComplexType, ::mlir::complex::NegOp>; |
| return FloatOrComplexNeg{}(loc, resultTypes, argTypes, |
| adaptor.getOperands(), b); |
| } |
| |
| if (!isa<IntegerType>(elementType)) |
| return nullptr; |
| |
| // Integer negation: subtract the operand from a zero of the same type. |
| Value operand = adaptor.getOperand(); |
| Value zero = |
| b->create<arith::ConstantOp>(loc, b->getZeroAttr(operand.getType())); |
| return b->create<ScalarIOp<stablehlo::SubtractOp>>(loc, zero, operand); |
| } |
| |
| /// Scalar lowering of stablehlo.not for integer element types: the operand is |
| /// xor-ed with an all-ones constant of the same type (not(x) == x ^ -1). |
| /// Returns a null Value for non-integer element types. |
| template <> |
| inline Value mapStableHloOpToStdScalarOp<stablehlo::NotOp>( |
| Location loc, ArrayRef<Type> /*ResultTypes*/, ArrayRef<Type> /*argTypes*/, |
| stablehlo::NotOp::Adaptor adaptor, OpBuilder *b) { |
| Value operand = adaptor.getOperand(); |
| auto integerType = |
| dyn_cast<IntegerType>(getElementTypeOrSelf(operand.getType())); |
| if (!integerType) |
| return nullptr; |
| |
| // Every bit set, at the width of the element type. |
| APInt allBitsSet = APInt::getAllOnes(integerType.getWidth()); |
| Value allOnes = getConstantOrSplat( |
| b, loc, operand.getType(), b->getIntegerAttr(integerType, allBitsSet)); |
| return b->create<::mlir::arith::XOrIOp>(loc, allOnes, operand); |
| } |
| |
| /// Scalar lowering of stablehlo.logistic, assembled from the scalar mappings |
| /// of negate, exponential, add and divide: |
| ///   logistic(x) = 1 / (1 + exp(-x)) |
| /// The result types are used as the argument types of every intermediate op. |
| template <> |
| inline Value mapStableHloOpToStdScalarOp<stablehlo::LogisticOp>( |
| Location loc, ArrayRef<Type> resultTypes, ArrayRef<Type> /*argTypes*/, |
| stablehlo::LogisticOp::Adaptor adaptor, OpBuilder *b) { |
| // exp(-x) |
| Value negated = mapStableHloOpToStdScalarOp<stablehlo::NegOp>( |
| loc, resultTypes, resultTypes, {adaptor.getOperand()}, b); |
| Value expOfNegated = mapStableHloOpToStdScalarOp<stablehlo::ExpOp>( |
| loc, resultTypes, resultTypes, {{negated}}, b); |
| |
| // The constant 1 is materialized as f32 and then converted to the result |
| // type through the scalar convert mapping. |
| Value f32One = b->create<arith::ConstantOp>(loc, b->getF32FloatAttr(1.0)); |
| Value one = mapConvertOpToStdScalarOp(loc, resultTypes, resultTypes, |
| {f32One.getType()}, {{f32One}}, b); |
| |
| // 1 / (exp(-x) + 1) |
| Value denominator = mapStableHloOpToStdScalarOp<stablehlo::AddOp>( |
| loc, resultTypes, resultTypes, {{expOfNegated, one}}, b); |
| return mapStableHloOpToStdScalarOp<stablehlo::DivOp>( |
| loc, resultTypes, resultTypes, {{one, denominator}}, b); |
| } |
| |
| /// Scalar lowering of stablehlo.power. Float and complex result types use the |
| /// dedicated pow ops. Integer powers are computed by exponentiation by |
| /// squaring (https://en.wikipedia.org/wiki/Exponentiation_by_squaring) in an |
| /// scf.for loop, followed by a fix-up for negative exponents. |
| template <> |
| inline Value mapStableHloOpToStdScalarOp<stablehlo::PowOp>( |
| Location loc, ArrayRef<Type> resultTypes, ArrayRef<Type> argTypes, |
| stablehlo::PowOp::Adaptor adaptor, OpBuilder *b) { |
| ImplicitLocOpBuilder lb(loc, *b); |
| Type resultType = resultTypes.front(); |
| |
| // Floating point and complex have a direct pow op. |
| if (isa<ComplexType, FloatType>(resultType)) { |
| using FloatOrComplexPow = |
| MapStableHloOpToScalarOpImpl<IsFloatType, math::PowFOp, IsComplexType, |
| complex::PowOp>; |
| return FloatOrComplexPow{}(loc, resultTypes, argTypes, |
| adaptor.getOperands(), b); |
| } |
| |
| auto intConstant = [&](int64_t value) -> Value { |
| return lb.create<arith::ConstantOp>(lb.getIntegerAttr(resultType, value)); |
| }; |
| Value negOne = intConstant(-1); |
| Value zero = intConstant(0); |
| Value one = intConstant(1); |
| Value two = intConstant(2); |
| Value step = lb.create<arith::ConstantIndexOp>(1); |
| Value lowerBound = lb.create<arith::ConstantIndexOp>(0); |
| // Six squaring steps consume exponent bits 0..5 (6 == log2(64)). Larger |
| // exponents are not iterated because they would overflow a 64-bit integer |
| // anyway. |
| Value upperBound = lb.create<arith::ConstantIndexOp>(6); |
| |
| Value lhs = adaptor.getLhs(); |
| Value rhs = adaptor.getRhs(); |
| |
| // Loop-carried values: (accumulator, current base, remaining exponent). |
| SmallVector<Value> loopInit({one, lhs, rhs}); |
| auto squareAndMultiply = [&](OpBuilder &nested, Location, Value /*iv*/, |
| ValueRange carried) { |
| Value acc = carried[0]; |
| Value base = carried[1]; |
| Value exponent = carried[2]; |
| |
| // Multiply the accumulator by the base when the low exponent bit is set. |
| Value lowBit = nested.create<::mlir::arith::AndIOp>(loc, exponent, one); |
| Value lowBitIsSet = nested.create<arith::CmpIOp>( |
| loc, arith::CmpIPredicate::eq, lowBit, one); |
| Value accTimesBase = nested.create<::mlir::arith::MulIOp>(loc, acc, base); |
| Value nextAcc = nested.create<::mlir::arith::SelectOp>(loc, lowBitIsSet, |
| accTimesBase, acc); |
| // Square the base and drop the consumed exponent bit. |
| Value nextBase = nested.create<::mlir::arith::MulIOp>(loc, base, base); |
| Value nextExponent = |
| nested.create<::mlir::arith::ShRUIOp>(loc, exponent, one); |
| nested.create<scf::YieldOp>( |
| loc, SmallVector<Value>({nextAcc, nextBase, nextExponent})); |
| }; |
| Value accum = lb.create<scf::ForOp>(lowerBound, upperBound, step, loopInit, |
| squareAndMultiply) |
| .getResult(0); |
| |
| Value rhsMod2 = lb.create<arith::RemSIOp>(rhs, two); |
| Value rhsIsEven = |
| lb.create<arith::CmpIOp>(arith::CmpIPredicate::eq, rhsMod2, zero); |
| Value rhsIsNegative = |
| lb.create<arith::CmpIOp>(arith::CmpIPredicate::slt, rhs, zero); |
| Value lhsIsOne = |
| lb.create<arith::CmpIOp>(arith::CmpIPredicate::eq, lhs, one); |
| Value lhsIsNegOne = |
| lb.create<arith::CmpIOp>(arith::CmpIPredicate::eq, lhs, negOne); |
| |
| // The loop result is used as-is when rhs is non-negative. For a negative |
| // rhs the integer result is 0, except for the bases 1 and -1: |
| // |
| // - lhs == -1: 1 or -1 depending on the parity of rhs. |
| // - lhs == 1: 1. |
| // - otherwise: 0. |
| Value oneIfLhsIsOne = |
| lb.create<::mlir::arith::SelectOp>(lhsIsOne, one, zero); |
| Value signByParity = |
| lb.create<::mlir::arith::SelectOp>(rhsIsEven, one, negOne); |
| Value negativeExponentResult = lb.create<::mlir::arith::SelectOp>( |
| lhsIsNegOne, signByParity, oneIfLhsIsOne); |
| return lb.create<::mlir::arith::SelectOp>(rhsIsNegative, |
| negativeExponentResult, accum); |
| } |
| |
| /// Scalar lowering of stablehlo.select: forwards all adaptor operands to the |
| /// generic mapping instantiated with the arith select op. |
| template <> |
| inline Value mapStableHloOpToStdScalarOp<stablehlo::SelectOp>( |
| Location loc, ArrayRef<Type> resultTypes, ArrayRef<Type> argTypes, |
| stablehlo::SelectOp::Adaptor adaptor, OpBuilder *b) { |
| using ScalarSelect = MapStableHloOpToScalarOpImpl<::mlir::arith::SelectOp>; |
| return ScalarSelect{}(loc, resultTypes, argTypes, adaptor.getOperands(), b); |
| } |
| |
| /// Scalar lowering of stablehlo.sign. |
| ///   - Float: copysign(x != 0 ? 1.0 : 0.0, x), with NaN operands passed |
| ///     through unchanged. |
| ///   - Integer: x == 0 ? 0 : ((x s>> (bitwidth - 1)) | 1). |
| ///   - Complex: the complex sign op. |
| /// Any other element type yields a null Value. |
| template <> |
| inline Value mapStableHloOpToStdScalarOp<stablehlo::SignOp>( |
| Location loc, ArrayRef<Type> resultTypes, ArrayRef<Type> /*argTypes*/, |
| stablehlo::SignOp::Adaptor adaptor, OpBuilder *b) { |
| Value operand = adaptor.getOperand(); |
| Type operandType = operand.getType(); |
| Type elementType = getElementTypeOrSelf(operandType); |
| |
| if (isa<FloatType>(elementType)) { |
| Value zero = b->create<arith::ConstantOp>(loc, b->getZeroAttr(operandType)); |
| // (x != 0) as an i1, widened to 0.0 / 1.0 in the operand's float type. |
| Value isNonZero = b->create<::mlir::arith::CmpFOp>( |
| loc, arith::CmpFPredicate::ONE, operand, zero); |
| Value magnitude = |
| b->create<::mlir::arith::UIToFPOp>(loc, zero.getType(), isNonZero); |
| Value signedMagnitude = b->create<::mlir::math::CopySignOp>( |
| loc, resultTypes, magnitude, operand); |
| // An unordered self-comparison is true exactly for NaN operands. |
| Value isNan = b->create<::mlir::arith::CmpFOp>( |
| loc, arith::CmpFPredicate::UNO, operand, operand); |
| return b->create<::mlir::arith::SelectOp>(loc, isNan, operand, |
| signedMagnitude); |
| } |
| |
| if (auto integerType = dyn_cast<IntegerType>(elementType)) { |
| Value zero = b->create<arith::ConstantOp>(loc, b->getZeroAttr(operandType)); |
| Value signBitIndex = getConstantOrSplat( |
| b, loc, operandType, |
| b->getIntegerAttr(integerType, integerType.getWidth() - 1)); |
| Value one = getConstantOrSplat(b, loc, operandType, |
| b->getIntegerAttr(integerType, 1)); |
| Value isZero = b->create<::mlir::arith::CmpIOp>( |
| loc, arith::CmpIPredicate::eq, operand, zero); |
| // The arithmetic shift smears the sign bit; or-ing with 1 gives -1 or 1. |
| Value signSmear = |
| b->create<::mlir::arith::ShRSIOp>(loc, operand, signBitIndex); |
| Value plusOrMinusOne = b->create<::mlir::arith::OrIOp>(loc, signSmear, one); |
| return b->create<::mlir::arith::SelectOp>(loc, isZero, zero, |
| plusOrMinusOne); |
| } |
| |
| if (isa<ComplexType>(elementType)) |
| return b->create<::mlir::complex::SignOp>(loc, elementType, operand); |
| |
| return nullptr; |
| } |
| |
| /// Emits a select between `shifted` and `saturated`. `shifted` is chosen when |
| /// the element bit width of `type` is (unsigned) greater than the shift |
| /// amount `rhs`; `saturated` is chosen otherwise, i.e. for shift amounts |
| /// greater than or equal to the bit width. |
| inline Value selectShiftedOrSaturated(ImplicitLocOpBuilder &lb, Value rhs, |
| Value shifted, Value saturated, |
| Type type) { |
| Type elementType = getElementTypeOrSelf(type); |
| auto width = elementType.getIntOrFloatBitWidth(); |
| Value widthConstant = getConstantOrSplat( |
| &lb, lb.getLoc(), type, lb.getIntegerAttr(elementType, width)); |
| // True while the shift amount is still inside the valid range. |
| Value shiftIsInRange = lb.create<mlir::arith::CmpIOp>( |
| mlir::arith::CmpIPredicate::ugt, widthConstant, rhs); |
| return lb.create<mlir::arith::SelectOp>(shiftIsInRange, shifted, saturated); |
| } |
| |
| /// Scalar lowering of stablehlo.shift_left: an arith left shift whose result |
| /// is replaced by zero when the shift amount is not smaller than the bit |
| /// width of the element type. |
| template <> |
| inline Value mapStableHloOpToStdScalarOp<stablehlo::ShiftLeftOp>( |
| Location loc, ArrayRef<Type> /*ResultTypes*/, ArrayRef<Type> /*argTypes*/, |
| stablehlo::ShiftLeftOp::Adaptor adaptor, OpBuilder *b) { |
| ImplicitLocOpBuilder lb(loc, *b); |
| Value value = adaptor.getLhs(); |
| Value amount = adaptor.getRhs(); |
| Type valueType = value.getType(); |
| |
| // Over-shifting "saturates" to zero: every bit has been shifted out. |
| Value allShiftedOut = lb.create<arith::ConstantOp>(lb.getZeroAttr(valueType)); |
| Value shifted = lb.create<mlir::arith::ShLIOp>(value, amount); |
| |
| return selectShiftedOrSaturated(lb, amount, shifted, allShiftedOut, |
| valueType); |
| } |
| |
| /// Scalar lowering of stablehlo.shift_right_logical: an arith unsigned right |
| /// shift whose result is replaced by zero when the shift amount is not |
| /// smaller than the bit width of the element type. |
| template <> |
| inline Value mapStableHloOpToStdScalarOp<stablehlo::ShiftRightLogicalOp>( |
| Location loc, ArrayRef<Type> /*ResultTypes*/, ArrayRef<Type> /*argTypes*/, |
| stablehlo::ShiftRightLogicalOp::Adaptor adaptor, OpBuilder *b) { |
| ImplicitLocOpBuilder lb(loc, *b); |
| Value value = adaptor.getLhs(); |
| Value amount = adaptor.getRhs(); |
| Type valueType = value.getType(); |
| |
| // Over-shifting "saturates" to zero: every bit has been shifted out. |
| Value allShiftedOut = lb.create<arith::ConstantOp>(lb.getZeroAttr(valueType)); |
| Value shifted = lb.create<mlir::arith::ShRUIOp>(value, amount); |
| |
| return selectShiftedOrSaturated(lb, amount, shifted, allShiftedOut, |
| valueType); |
| } |
| |
| /// Scalar lowering of stablehlo.shift_right_arithmetic: an arith signed right |
| /// shift. When the shift amount is not smaller than the bit width of the |
| /// element type, the result is instead the operand shifted by |
| /// (bit width - 1). |
| template <> |
| inline Value mapStableHloOpToStdScalarOp<stablehlo::ShiftRightArithmeticOp>( |
| Location loc, ArrayRef<Type> /*ResultTypes*/, ArrayRef<Type> /*argTypes*/, |
| stablehlo::ShiftRightArithmeticOp::Adaptor adaptor, OpBuilder *b) { |
| ImplicitLocOpBuilder lb(loc, *b); |
| Value value = adaptor.getLhs(); |
| Value amount = adaptor.getRhs(); |
| Type valueType = value.getType(); |
| Type elementType = getElementTypeOrSelf(valueType); |
| auto width = elementType.getIntOrFloatBitWidth(); |
| |
| // The largest in-range shift amount serves as the "saturated" shift. |
| Value largestShift = getConstantOrSplat( |
| b, loc, valueType, lb.getIntegerAttr(elementType, width - 1)); |
| Value saturated = lb.create<mlir::arith::ShRSIOp>(value, largestShift); |
| Value shifted = lb.create<mlir::arith::ShRSIOp>(value, amount); |
| |
| return selectShiftedOrSaturated(lb, amount, shifted, saturated, valueType); |
| } |
| } // namespace impl |
| |
| /// Entry points that map a StableHLO op onto scalar ops by dispatching to the |
| /// per-op mappings in the `impl` namespace. |
| struct StableHloOpToStdScalarOp { |
| // Converts stablehlo 'op' to linalg and arith ops. The argument types are |
| // taken from the operand types of 'op'. |
| template <typename StableHloOpTy> |
| static Value mapOp(StableHloOpTy op, ArrayRef<Type> resultTypes, |
| ValueRange args, OpBuilder *b) { |
| auto argTypes = llvm::to_vector(op->getOperandTypes()); |
| return mapOpWithArgTypes(op, resultTypes, argTypes, args, b); |
| } |
| |
| // Converts stablehlo 'op' to linalg and arith ops. The types of 'args' may |
| // already be converted, 'argTypes' are their original types. The adaptor is |
| // built from 'args' and the attribute dictionary of 'op'. |
| template <typename StableHloOpTy> |
| static Value mapOpWithArgTypes(StableHloOpTy op, ArrayRef<Type> resultTypes, |
| ArrayRef<Type> argTypes, ValueRange args, |
| OpBuilder *b) { |
| // stablehlo::ConvertOp must resolve to the non-template overload below. |
| static_assert(!std::is_same<StableHloOpTy, stablehlo::ConvertOp>::value); |
| return mapOpOfType<StableHloOpTy>( |
| op.getLoc(), resultTypes, argTypes, |
| typename StableHloOpTy::Adaptor(args, op->getAttrDictionary()), b); |
| } |
| // Overload for stablehlo::ConvertOp. The op's own result type is passed as |
| // the conversion target type. |
| static Value mapOpWithArgTypes(stablehlo::ConvertOp op, |
| ArrayRef<Type> resultTypes, |
| ArrayRef<Type> argTypes, ValueRange args, |
| OpBuilder *b) { |
| return impl::mapConvertOpToStdScalarOp(op.getLoc(), op.getType(), |
| resultTypes, argTypes, args, b); |
| } |
| |
| // Converts stablehlo 'op' to linalg and arith ops, given only the op type, a |
| // location and an adaptor. |
| template <typename StableHloOpTy> |
| static Value |
| mapOpOfType(Location loc, ArrayRef<Type> resultTypes, ArrayRef<Type> argTypes, |
| typename StableHloOpTy::Adaptor adaptor, OpBuilder *b) { |
| // The op type is known at compile time, so only the selected branch is |
| // instantiated. |
| if constexpr (std::is_same_v<StableHloOpTy, stablehlo::ConvertOp>) { |
| // Note: this assumes that the caller is passing result/arg types with |
| // appropriate signedness. |
| return impl::mapConvertOpToStdScalarOp( |
| loc, resultTypes, resultTypes, argTypes, adaptor.getOperands(), b); |
| } else { |
| return impl::mapStableHloOpToStdScalarOp<StableHloOpTy>( |
| loc, resultTypes, argTypes, adaptor, b); |
| } |
| } |
| }; |
| |
| } // namespace stablehlo |
| } // namespace mlir |
| |
| #endif // IREE_COMPILER_PLUGINS_INPUT_STABLEHLO_CONVERSION_MAP_STABLEHLO_TO_SCALAR_OP_H |