| // Copyright 2023 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| // Implements optional canonicalization patterns for StableHLO ops. |
| |
| #include <cassert> |
| #include <functional> |
| #include <numeric> |
| |
| #include "compiler/plugins/input/StableHLO/Conversion/Preprocessing/Passes.h" |
| #include "compiler/plugins/input/StableHLO/Conversion/Preprocessing/Rewriters.h" |
| #include "llvm/ADT/APFloat.h" |
| #include "llvm/ADT/APInt.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/SmallVector.h" |
| #include "llvm/Support/Casting.h" |
| #include "mlir/Dialect/CommonFolders.h" |
| #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| #include "mlir/IR/Attributes.h" |
| #include "mlir/IR/BuiltinAttributeInterfaces.h" |
| #include "mlir/IR/BuiltinAttributes.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/IR/Matchers.h" |
| #include "mlir/IR/PatternMatch.h" |
| #include "mlir/IR/TypeUtilities.h" |
| #include "mlir/Interfaces/FunctionInterfaces.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| #include "stablehlo/dialect/StablehloOps.h" |
| |
| namespace mlir::iree_compiler::stablehlo { |
| |
| #define GEN_PASS_DEF_STABLEHLOCANONICALIZE |
| #include "compiler/plugins/input/StableHLO/Conversion/Preprocessing/Passes.h.inc" |
| |
| namespace { |
| |
// Cap on the number of elements that a canonicalization pattern may
// materialize into a freshly created constant.
constexpr int64_t kFoldOpEltLimit = int64_t{1} << 16;
| |
| static bool isIotaRange(ArrayRef<int64_t> dims) { |
| for (auto [idx, value] : llvm::enumerate(dims)) { |
| if (idx != value) { |
| return false; |
| } |
| } |
| |
| return true; |
| } |
| |
| static bool isIotaRange(ElementsAttr attr) { |
| auto elems = attr.tryGetValues<APInt>(); |
| if (!elems) |
| return false; |
| |
| for (auto [idx, value] : llvm::enumerate(*elems)) { |
| if (idx != value) { |
| return false; |
| } |
| } |
| |
| return true; |
| } |
| |
| /// Matches when either of the submatchers match. |
| template <typename MatcherA, typename MatcherB> |
| struct m_AnyOf { |
| m_AnyOf(MatcherA a, MatcherB b) : matcherA(a), matcherB(b) {} |
| |
| bool match(Operation *op) { return matcherA.match(op) || matcherB.match(op); } |
| |
| MatcherA matcherA; |
| MatcherB matcherB; |
| }; |
| |
| template <typename MatcherA, typename MatcherB> |
| m_AnyOf(MatcherA, MatcherB) -> m_AnyOf<MatcherA, MatcherB>; |
| |
| /// Binary constant folder that used a generic folder function to handle both |
| /// ints and floats. |
| template <typename Fn> |
| static TypedAttr foldBinaryOpIntOrFloat(TypedAttr lhs, TypedAttr rhs, |
| Fn &&folder) { |
| Attribute operands[2] = {lhs, rhs}; |
| Type elemTy = getElementTypeOrSelf(cast<TypedAttr>(lhs).getType()); |
| |
| if (isa<IntegerType>(elemTy)) { |
| if (Attribute res = |
| constFoldBinaryOp<IntegerAttr, IntegerAttr::ValueType, void>( |
| operands, [&folder](const APInt &lhs, const APInt &rhs) { |
| return folder(lhs, rhs); |
| })) { |
| return cast<TypedAttr>(res); |
| } |
| return nullptr; |
| } |
| |
| if (isa<FloatType>(elemTy)) { |
| if (Attribute res = |
| constFoldBinaryOp<FloatAttr, FloatAttr::ValueType, void>( |
| operands, [&folder](const APFloat &lhs, const APFloat &rhs) { |
| return folder(lhs, rhs); |
| })) { |
| return cast<TypedAttr>(res); |
| } |
| return nullptr; |
| } |
| |
| return nullptr; |
| } |
| |
| struct AddOpCanon final : OpRewritePattern<mlir::stablehlo::AddOp> { |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(mlir::stablehlo::AddOp op, |
| PatternRewriter &rewriter) const override { |
| auto type = dyn_cast<RankedTensorType>(op.getType()); |
| if (!type) |
| return failure(); |
| |
| Value lhs = op.getLhs(); |
| Value rhs = op.getRhs(); |
| |
| if (matchPattern(lhs, m_Zero())) { |
| rewriter.replaceOp(op, rhs); |
| return success(); |
| } |
| |
| if (matchPattern(rhs, m_AnyOf(m_Zero(), m_NegZeroFloat()))) { |
| rewriter.replaceOp(op, lhs); |
| return success(); |
| } |
| |
| TypedAttr lhsAttr; |
| matchPattern(lhs, m_Constant(&lhsAttr)); |
| |
| TypedAttr rhsAttr; |
| matchPattern(rhs, m_Constant(&rhsAttr)); |
| |
| // The canonical form has the constant operand as the RHS. |
| if (isa<IntegerType>(type.getElementType()) && lhsAttr && !rhsAttr) { |
| rewriter.modifyOpInPlace( |
| op, [op, lhs, rhs] { op->setOperands(ValueRange{rhs, lhs}); }); |
| return success(); |
| } |
| |
| if (lhsAttr && rhsAttr) { |
| if (TypedAttr res = |
| foldBinaryOpIntOrFloat(lhsAttr, rhsAttr, std::plus<>{})) { |
| rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, res); |
| return success(); |
| } |
| } |
| |
| return failure(); |
| } |
| }; |
| |
| struct SubtractOpCanon final : OpRewritePattern<mlir::stablehlo::SubtractOp> { |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(mlir::stablehlo::SubtractOp op, |
| PatternRewriter &rewriter) const override { |
| auto type = dyn_cast<RankedTensorType>(op.getType()); |
| if (!type) |
| return failure(); |
| |
| Value lhs = op.getLhs(); |
| Value rhs = op.getRhs(); |
| |
| if (isa<IntegerType>(type.getElementType()) && lhs == rhs) { |
| rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>( |
| op, rewriter.getZeroAttr(op.getType())); |
| return success(); |
| } |
| |
| // Subtraction of 0. |
| if (matchPattern(rhs, m_AnyOf(m_Zero(), m_PosZeroFloat()))) { |
| rewriter.replaceOp(op, lhs); |
| return success(); |
| } |
| |
| TypedAttr lhsAttr; |
| matchPattern(lhs, m_Constant(&lhsAttr)); |
| |
| TypedAttr rhsAttr; |
| matchPattern(rhs, m_Constant(&rhsAttr)); |
| |
| if (lhsAttr && rhsAttr) { |
| if (TypedAttr res = |
| foldBinaryOpIntOrFloat(lhsAttr, rhsAttr, std::minus<>{})) { |
| rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, res); |
| return success(); |
| } |
| } |
| |
| return failure(); |
| } |
| }; |
| |
| struct MulOpCanon final : OpRewritePattern<mlir::stablehlo::MulOp> { |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(mlir::stablehlo::MulOp op, |
| PatternRewriter &rewriter) const override { |
| auto type = dyn_cast<RankedTensorType>(op.getType()); |
| if (!type) |
| return failure(); |
| |
| Value lhs = op.getLhs(); |
| Value rhs = op.getRhs(); |
| |
| // Multiplication by 0. This fold is not trivial for floats in presence of |
| // NaN values. |
| if (matchPattern(lhs, m_Zero())) { |
| rewriter.replaceOp(op, lhs); |
| return success(); |
| } |
| if (matchPattern(rhs, m_Zero())) { |
| rewriter.replaceOp(op, rhs); |
| return success(); |
| } |
| |
| // Multiplication by 1. |
| if (matchPattern(rhs, m_One())) { |
| rewriter.replaceOp(op, lhs); |
| return success(); |
| } |
| |
| TypedAttr lhsAttr; |
| matchPattern(lhs, m_Constant(&lhsAttr)); |
| |
| TypedAttr rhsAttr; |
| matchPattern(rhs, m_Constant(&rhsAttr)); |
| |
| // The canonical form has the constant operand as the RHS. |
| if (isa<IntegerType>(type.getElementType()) && lhsAttr && !rhsAttr) { |
| rewriter.modifyOpInPlace( |
| op, [op, lhs, rhs] { op->setOperands(ValueRange{rhs, lhs}); }); |
| return success(); |
| } |
| |
| if (lhsAttr && rhsAttr) { |
| if (TypedAttr res = |
| foldBinaryOpIntOrFloat(lhsAttr, rhsAttr, std::multiplies<>{})) { |
| rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, res); |
| return success(); |
| } |
| } |
| |
| return failure(); |
| } |
| }; |
| |
| static mlir::stablehlo::ComparisonDirection |
| invertDirection(mlir::stablehlo::ComparisonDirection direction) { |
| using mlir::stablehlo::ComparisonDirection; |
| |
| switch (direction) { |
| case ComparisonDirection::EQ: |
| return ComparisonDirection::EQ; |
| case ComparisonDirection::GE: |
| return ComparisonDirection::LE; |
| case ComparisonDirection::LE: |
| return ComparisonDirection::GE; |
| case ComparisonDirection::GT: |
| return ComparisonDirection::LT; |
| case ComparisonDirection::LT: |
| return ComparisonDirection::GT; |
| case ComparisonDirection::NE: |
| return ComparisonDirection::NE; |
| } |
| |
| llvm_unreachable("Unhandled case"); |
| } |
| |
| static APInt calculateComp(mlir::stablehlo::ComparisonType kind, |
| mlir::stablehlo::ComparisonDirection direction, |
| const APInt &lhs, const APInt &rhs) { |
| using mlir::stablehlo::ComparisonDirection; |
| using mlir::stablehlo::ComparisonType; |
| assert(llvm::is_contained({ComparisonType::SIGNED, ComparisonType::UNSIGNED}, |
| kind) && |
| "Not an integer comparison"); |
| |
| auto asBit = [](bool value) { |
| return value ? APInt::getAllOnes(1) : APInt::getZero(1); |
| }; |
| |
| // Signed comparison. |
| if (kind == ComparisonType::SIGNED) { |
| switch (direction) { |
| case ComparisonDirection::EQ: |
| return asBit(lhs == rhs); |
| case ComparisonDirection::GE: |
| return asBit(lhs.sge(rhs)); |
| case ComparisonDirection::GT: |
| return asBit(lhs.sgt(rhs)); |
| case ComparisonDirection::LE: |
| return asBit(lhs.sle(rhs)); |
| case ComparisonDirection::LT: |
| return asBit(lhs.slt(rhs)); |
| case ComparisonDirection::NE: |
| return asBit(lhs != rhs); |
| } |
| } |
| |
| // Unsigned comparison. |
| switch (direction) { |
| case ComparisonDirection::EQ: |
| return asBit(lhs == rhs); |
| case ComparisonDirection::GE: |
| return asBit(lhs.uge(rhs)); |
| case ComparisonDirection::GT: |
| return asBit(lhs.ugt(rhs)); |
| case ComparisonDirection::LE: |
| return asBit(lhs.ule(rhs)); |
| case ComparisonDirection::LT: |
| return asBit(lhs.ult(rhs)); |
| case ComparisonDirection::NE: |
| return asBit(lhs != rhs); |
| } |
| |
| llvm_unreachable("Unhandled case"); |
| } |
| |
| struct CompareOpCanon final : OpRewritePattern<mlir::stablehlo::CompareOp> { |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(mlir::stablehlo::CompareOp op, |
| PatternRewriter &rewriter) const override { |
| auto type = dyn_cast<RankedTensorType>(op.getType()); |
| if (!type) |
| return failure(); |
| |
| // Bail out on non-integer comparison. |
| // TODO: Support more comparison types. |
| using mlir::stablehlo::ComparisonType; |
| std::optional<ComparisonType> compType = op.getCompareType(); |
| if (!compType || |
| !llvm::is_contained({ComparisonType::SIGNED, ComparisonType::UNSIGNED}, |
| *compType)) { |
| return failure(); |
| } |
| |
| using mlir::stablehlo::ComparisonDirection; |
| ComparisonDirection direction = op.getComparisonDirection(); |
| Value lhs = op.getLhs(); |
| Value rhs = op.getRhs(); |
| |
| if (lhs == rhs) { |
| switch (direction) { |
| case ComparisonDirection::EQ: |
| case ComparisonDirection::GE: |
| case ComparisonDirection::LE: { |
| rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>( |
| op, SplatElementsAttr::get(type, rewriter.getBoolAttr(true))); |
| return success(); |
| } |
| case ComparisonDirection::GT: |
| case ComparisonDirection::LT: |
| case ComparisonDirection::NE: { |
| rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>( |
| op, rewriter.getZeroAttr(type)); |
| return success(); |
| } |
| } |
| llvm_unreachable("Unhandled case"); |
| } |
| |
| TypedAttr lhsAttr; |
| matchPattern(lhs, m_Constant(&lhsAttr)); |
| |
| TypedAttr rhsAttr; |
| matchPattern(rhs, m_Constant(&rhsAttr)); |
| |
| // The canonical form has the constant operand as the RHS. |
| if (lhsAttr && !rhsAttr) { |
| rewriter.modifyOpInPlace(op, [&op, direction, lhs, rhs] { |
| op.setComparisonDirection(invertDirection(direction)); |
| op->setOperands(ValueRange{rhs, lhs}); |
| }); |
| return success(); |
| } |
| |
| if (lhsAttr && rhsAttr) { |
| if (Attribute res = |
| constFoldBinaryOp<IntegerAttr, IntegerAttr::ValueType, void>( |
| ArrayRef<Attribute>({lhsAttr, rhsAttr}), op.getType(), |
| [direction, kind = *compType](const APInt &a, |
| const APInt &b) { |
| return calculateComp(kind, direction, a, b); |
| })) { |
| rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, res); |
| return success(); |
| } |
| } |
| |
| return failure(); |
| } |
| }; |
| |
| struct SelectOpCanon final : OpRewritePattern<mlir::stablehlo::SelectOp> { |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(mlir::stablehlo::SelectOp op, |
| PatternRewriter &rewriter) const override { |
| auto type = dyn_cast<RankedTensorType>(op.getType()); |
| if (!type) |
| return failure(); |
| |
| Value trueVal = op.getOnTrue(); |
| Value falseVal = op.getOnFalse(); |
| |
| // Eliminate select with two identical outcomes. |
| if (trueVal == falseVal) { |
| rewriter.replaceOp(op, trueVal); |
| return success(); |
| } |
| |
| // Simplify when the condition is a constant. |
| Value pred = op.getPred(); |
| ElementsAttr cond; |
| if (!matchPattern(pred, m_Constant(&cond))) { |
| return failure(); |
| } |
| |
| // Handle splat predicate and select either `trueVal` or `falseVal`. |
| if (cond.isSplat()) { |
| rewriter.replaceOp(op, cond.getSplatValue<bool>() ? trueVal : falseVal); |
| return success(); |
| } |
| |
| // Handle elementwise selection when both outcomes are also constants. This |
| // will create a new, likely non-splat constant. |
| if (cond.getNumElements() > kFoldOpEltLimit) |
| return failure(); |
| |
| ElementsAttr trueAttr; |
| if (!matchPattern(trueVal, m_Constant(&trueAttr))) |
| return failure(); |
| |
| ElementsAttr falseAttr; |
| if (!matchPattern(falseVal, m_Constant(&falseAttr))) |
| return failure(); |
| |
| SmallVector<Attribute> newValues; |
| newValues.reserve(cond.getNumElements()); |
| for (auto [condElem, trueElem, falseElem] : llvm::zip_equal( |
| cond.getValues<bool>(), trueAttr.getValues<Attribute>(), |
| falseAttr.getValues<Attribute>())) { |
| newValues.push_back(condElem ? trueElem : falseElem); |
| } |
| |
| rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>( |
| op, DenseElementsAttr::get(type, newValues)); |
| return success(); |
| } |
| }; |
| |
| struct BroadcastInDimOpCanon final |
| : OpRewritePattern<mlir::stablehlo::BroadcastInDimOp> { |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(mlir::stablehlo::BroadcastInDimOp op, |
| PatternRewriter &rewriter) const override { |
| auto type = dyn_cast<RankedTensorType>(op.getType()); |
| if (!type) |
| return failure(); |
| |
| Value operand = op.getOperand(); |
| auto operandTy = dyn_cast<RankedTensorType>(operand.getType()); |
| if (!operandTy) |
| return failure(); |
| |
| // Fold when broadcast is a noop. |
| auto dims = op.getBroadcastDimensions(); |
| bool isDimsIota = isIotaRange(dims); |
| if (type == operandTy && isDimsIota) { |
| rewriter.replaceOp(op, operand); |
| return success(); |
| } |
| |
| // Handle splat broadcasts. |
| if (SplatElementsAttr cstAttr; |
| matchPattern(operand, m_Constant(&cstAttr))) { |
| rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>( |
| op, SplatElementsAttr::get(op.getType(), |
| cstAttr.getSplatValue<Attribute>())); |
| return success(); |
| } |
| |
| auto bsDimIndices = dims; |
| if (operandTy.hasStaticShape() && type.hasStaticShape() && |
| type.getNumElements() == operandTy.getNumElements()) { |
| // BroadcastInDim equivalent to reshape. |
| if (isDimsIota) { |
| rewriter.replaceOpWithNewOp<mlir::stablehlo::ReshapeOp>(op, type, |
| operand); |
| return success(); |
| } |
| // BroadcastInDim equivalent to transpose. |
| if (type.getRank() == operandTy.getRank()) { |
| rewriter.replaceOpWithNewOp<mlir::stablehlo::TransposeOp>( |
| op, type, operand, dims); |
| return success(); |
| } |
| } |
| |
| // Eliminate redundant nested BroadcastInDim. |
| if (auto broadcastInDimOp = |
| operand.getDefiningOp<mlir::stablehlo::BroadcastInDimOp>()) { |
| auto newIndices = |
| rewriter.getDenseI64ArrayAttr(llvm::to_vector(llvm::map_range( |
| broadcastInDimOp.getBroadcastDimensions(), |
| [&bsDimIndices](int64_t dim) { return bsDimIndices[dim]; }))); |
| rewriter.replaceOpWithNewOp<mlir::stablehlo::BroadcastInDimOp>( |
| op, type, broadcastInDimOp.getOperand(), newIndices); |
| return success(); |
| } |
| |
| return failure(); |
| } |
| }; |
| |
| struct ConcatenateOpCanon final |
| : OpRewritePattern<mlir::stablehlo::ConcatenateOp> { |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(mlir::stablehlo::ConcatenateOp op, |
| PatternRewriter &rewriter) const override { |
| auto type = dyn_cast<RankedTensorType>(op.getType()); |
| if (!type || !type.hasStaticShape()) |
| return failure(); |
| |
| size_t numElems = type.getNumElements(); |
| if (numElems > kFoldOpEltLimit) |
| return failure(); |
| |
| // Fold concatenate when all inputs are constants. |
| OperandRange inputs = op.getInputs(); |
| SmallVector<DenseElementsAttr> constants(inputs.size()); |
| for (auto [input, constant] : llvm::zip_equal(inputs, constants)) { |
| if (!matchPattern(input, m_Constant(&constant))) { |
| return failure(); |
| } |
| } |
| |
| uint64_t axis = op.getDimension(); |
| ArrayRef<int64_t> shape = type.getShape(); |
| int64_t topSize = std::accumulate(shape.begin(), shape.begin() + axis, |
| int64_t{1}, std::multiplies<>{}); |
| |
| SmallVector<Attribute> newElems; |
| newElems.reserve(numElems); |
| |
| for (int64_t i = 0; i != topSize; ++i) { |
| for (ElementsAttr attr : constants) { |
| size_t bottomSize = attr.getNumElements() / topSize; |
| auto begin = attr.value_begin<Attribute>() + (i * bottomSize); |
| newElems.append(begin, begin + bottomSize); |
| } |
| } |
| |
| assert(newElems.size() == numElems); |
| rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>( |
| op, DenseElementsAttr::get(op.getType(), newElems)); |
| return success(); |
| } |
| }; |
| |
| struct ConvertOpCanon final : OpRewritePattern<mlir::stablehlo::ConvertOp> { |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(mlir::stablehlo::ConvertOp op, |
| PatternRewriter &rewriter) const override { |
| // Check if this convert is a noop. |
| if (op.getOperand().getType() != op.getType()) |
| return failure(); |
| |
| rewriter.replaceOp(op, op.getOperand()); |
| return success(); |
| } |
| }; |
| |
| /// Does the same as PatternRewriter::replaceOpWithNewOp, but with a twist. |
| /// |
| /// Sometimes, we want to replace an op with a new op and simultaneously refine |
| /// the result type from a dynamically-shaped type to a statically-shaped type. |
| /// (Search for usages of this function for examples). |
| // |
| /// Oftentimes, this works just fine because HLO is designed to accommodate |
| /// this kind of type refinements. But sometimes, this doesn't work - when |
| /// the op is used outside of the HLO dialect (e.g. in func.return). In these |
| /// cases, we insert a tensor.cast to smooth things out. |
| template <typename OpTy, typename... Args> |
| static OpTy refineOpWithNewOp(PatternRewriter &rewriter, Operation *op, |
| Args &&...args) { |
| auto newOp = rewriter.create<OpTy>(op->getLoc(), std::forward<Args>(args)...); |
| |
| llvm::SmallVector<Value> replacementResults; |
| assert(op->getNumResults() == newOp->getNumResults() && |
| "replacement op doesn't match results of original op"); |
| for (auto [opResult, newOpResult] : |
| llvm::zip(op->getResults(), newOp->getResults())) { |
| Value replacementResult = newOpResult; |
| if (llvm::any_of(opResult.getUsers(), [&](Operation *user) { |
| return user->getDialect() != op->getDialect(); |
| })) { |
| replacementResult = rewriter.create<mlir::tensor::CastOp>( |
| op->getLoc(), opResult.getType(), newOpResult); |
| } |
| replacementResults.push_back(replacementResult); |
| } |
| |
| rewriter.replaceOp(op, replacementResults); |
| return newOp; |
| } |
| |
| /// If a DynamicBroadCastInDimOp is not actually dynamic, use an ordinary |
| /// BroadcastInDimOp. |
| struct DynamicBroadcastInDimOpNotActuallyDynamic final |
| : OpRewritePattern<mlir::stablehlo::DynamicBroadcastInDimOp> { |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(mlir::stablehlo::DynamicBroadcastInDimOp op, |
| PatternRewriter &rewriter) const override { |
| auto type = dyn_cast<RankedTensorType>(op.getType()); |
| auto operandType = dyn_cast<RankedTensorType>(op.getOperand().getType()); |
| if (!type || !operandType || !operandType.hasStaticShape()) { |
| return rewriter.notifyMatchFailure(op, "requires operand static shape"); |
| } |
| |
| // output has static shape, replace with broadcast_in_dim |
| if (type.hasStaticShape()) { |
| rewriter.replaceOpWithNewOp<mlir::stablehlo::BroadcastInDimOp>( |
| op, type, op.getOperand(), op.getBroadcastDimensionsAttr()); |
| return success(); |
| } |
| |
| // output_dimensions are constant, set output shape with output_dimensions, |
| // then replace with broadcast_in_dim |
| auto *outputDimOp = op.getOutputDimensions().getDefiningOp(); |
| if (outputDimOp && outputDimOp->hasTrait<mlir::OpTrait::ConstantLike>()) { |
| DenseIntElementsAttr shapeAttr; |
| if (matchPattern(outputDimOp, m_Constant(&shapeAttr))) { |
| SmallVector<int64_t> outputShape; |
| for (APInt shape : shapeAttr.getValues<APInt>()) { |
| outputShape.push_back(shape.getZExtValue()); |
| } |
| refineOpWithNewOp<mlir::stablehlo::BroadcastInDimOp>( |
| rewriter, op, |
| RankedTensorType::get(outputShape, type.getElementType()), |
| op.getOperand(), op.getBroadcastDimensionsAttr()); |
| return success(); |
| } |
| } |
| return rewriter.notifyMatchFailure( |
| op, "requires output static shape or constant broadcast dimensions"); |
| } |
| }; |
| |
| struct ChainedDynamicBroadcastInDimCanonicalization final |
| : OpRewritePattern<mlir::stablehlo::DynamicBroadcastInDimOp> { |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(mlir::stablehlo::DynamicBroadcastInDimOp bcast, |
| PatternRewriter &rewriter) const override { |
| auto precedingBcast = |
| bcast.getOperand() |
| .getDefiningOp<mlir::stablehlo::DynamicBroadcastInDimOp>(); |
| if (!precedingBcast) |
| return failure(); |
| |
| // Compose broadcast dimensions. |
| SmallVector<int64_t> composition; |
| for (int64_t precedingDim : precedingBcast.getBroadcastDimensions()) { |
| composition.push_back(bcast.getBroadcastDimensions()[precedingDim]); |
| } |
| auto composedBcastDims = rewriter.getDenseI64ArrayAttr(composition); |
| |
| rewriter.replaceOpWithNewOp<mlir::stablehlo::DynamicBroadcastInDimOp>( |
| bcast, bcast.getType(), precedingBcast.getOperand(), |
| bcast.getOutputDimensions(), composedBcastDims); |
| return success(); |
| } |
| }; |
| |
| // If all dimensions are known to be nonexpanding from the attribute, replace |
| // the dynamic broadcast with a cast. |
| struct DynamicBroadcastInDimAllDimsNonExpanding final |
| : OpRewritePattern<mlir::stablehlo::DynamicBroadcastInDimOp> { |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(mlir::stablehlo::DynamicBroadcastInDimOp op, |
| PatternRewriter &rewriter) const override { |
| auto resultType = dyn_cast<RankedTensorType>(op.getResult().getType()); |
| if (!resultType) { |
| return rewriter.notifyMatchFailure(op, "requires ranked result type"); |
| } |
| |
| if (!op.getKnownNonexpandingDimensions() || |
| op.getKnownNonexpandingDimensions()->size() != resultType.getRank()) { |
| return rewriter.notifyMatchFailure( |
| op, "known_nonexpanding_dimensions don't cover all output dims"); |
| } |
| |
| auto cast = rewriter.createOrFold<tensor::CastOp>(op.getLoc(), resultType, |
| op.getOperand()); |
| rewriter.replaceOp(op, cast); |
| return success(); |
| } |
| }; |
| |
| struct NoopReduceOpCanon final : OpRewritePattern<mlir::stablehlo::ReduceOp> { |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(mlir::stablehlo::ReduceOp op, |
| PatternRewriter &rewriter) const override { |
| // No dimensions to reduce. |
| if (op.getDimensions().empty()) { |
| rewriter.replaceOp(op, op.getInputs()); |
| return success(); |
| } |
| |
| // If all returned values in the ReduceOp region exists outside the |
| // region, replace the ReduceOp with those values. |
| if (auto retOp = dyn_cast<mlir::stablehlo::ReturnOp>( |
| op.getBody().front().getTerminator())) { |
| Region *retRegion = retOp->getParentRegion(); |
| if (llvm::any_of(retOp.getResults(), [retRegion](Value result) { |
| return result.getParentRegion() == retRegion; |
| })) { |
| return failure(); |
| } |
| |
| rewriter.replaceOp(op, retOp.getResults()); |
| return success(); |
| } |
| |
| return failure(); |
| } |
| }; |
| |
| struct EmptyReduceOpCanon final : OpRewritePattern<mlir::stablehlo::ReduceOp> { |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(mlir::stablehlo::ReduceOp op, |
| PatternRewriter &rewriter) const override { |
| // We require all reduce shapes to be the same, up to the element types, so |
| // we can just the first operand and the first result as a representative. |
| auto elemTy = dyn_cast<RankedTensorType>(op.getInputs().getType().front()); |
| if (!elemTy) { |
| return rewriter.notifyMatchFailure(op.getLoc(), |
| "unranked input unsupported"); |
| } |
| |
| if (!llvm::is_contained(elemTy.getShape(), 0)) |
| return failure(); |
| |
| Location loc = op.getLoc(); |
| DenseI64ArrayAttr empty = rewriter.getDenseI64ArrayAttr({}); |
| if (elemTy.hasStaticShape()) { |
| SmallVector<Value> broadcasts(op.getNumResults()); |
| for (auto [bcast, init, outTy] : llvm::zip_equal( |
| broadcasts, op.getInitValues(), op.getResultTypes())) { |
| bcast = rewriter.create<mlir::stablehlo::BroadcastInDimOp>(loc, outTy, |
| init, empty); |
| } |
| rewriter.replaceOp(op, broadcasts); |
| return success(); |
| } |
| |
| SmallVector<Value> shapes; |
| if (failed(op.reifyReturnTypeShapes(rewriter, op.getOperands(), shapes))) { |
| return failure(); |
| } |
| |
| SmallVector<Value> broadcasts(op.getNumResults()); |
| for (auto [bcast, init, shape, outTy] : llvm::zip_equal( |
| broadcasts, op.getInitValues(), shapes, op.getResultTypes())) { |
| bcast = rewriter.create<mlir::stablehlo::DynamicBroadcastInDimOp>( |
| loc, outTy, init, shape, empty); |
| } |
| rewriter.replaceOp(op, broadcasts); |
| return success(); |
| } |
| }; |
| |
| struct DynamicReshapeOpCanon final |
| : OpRewritePattern<mlir::stablehlo::DynamicReshapeOp> { |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(mlir::stablehlo::DynamicReshapeOp op, |
| PatternRewriter &rewriter) const override { |
| // This is a noop when the output type is already a static shape. |
| auto type = dyn_cast<RankedTensorType>(op.getType()); |
| if (!type || !type.hasStaticShape()) |
| return failure(); |
| |
| rewriter.replaceOpWithNewOp<mlir::stablehlo::ReshapeOp>(op, type, |
| op.getOperand()); |
| return success(); |
| } |
| }; |
| |
| struct GetTupleElementOpCanon final |
| : OpRewritePattern<mlir::stablehlo::GetTupleElementOp> { |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(mlir::stablehlo::GetTupleElementOp op, |
| PatternRewriter &rewriter) const override { |
| auto constructor = |
| op.getOperand().getDefiningOp<mlir::stablehlo::TupleOp>(); |
| if (!constructor) |
| return failure(); |
| |
| Value result = constructor.getOperand(op.getIndex()); |
| rewriter.replaceOp(op, result); |
| return success(); |
| } |
| }; |
| |
| struct RealOpCanon final : OpRewritePattern<mlir::stablehlo::RealOp> { |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(mlir::stablehlo::RealOp op, |
| PatternRewriter &rewriter) const override { |
| auto complex = op.getOperand().getDefiningOp<mlir::stablehlo::ComplexOp>(); |
| if (!complex) |
| return failure(); |
| |
| rewriter.replaceOp(op, complex.getLhs()); |
| return success(); |
| } |
| }; |
| |
| struct ImagOpCanon final : OpRewritePattern<mlir::stablehlo::ImagOp> { |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(mlir::stablehlo::ImagOp op, |
| PatternRewriter &rewriter) const override { |
| auto complex = op.getOperand().getDefiningOp<mlir::stablehlo::ComplexOp>(); |
| if (!complex) |
| return failure(); |
| |
| rewriter.replaceOp(op, complex.getRhs()); |
| return success(); |
| } |
| }; |
| |
| struct GetDimensionSizeOpCanon final |
| : OpRewritePattern<mlir::stablehlo::GetDimensionSizeOp> { |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(mlir::stablehlo::GetDimensionSizeOp op, |
| PatternRewriter &rewriter) const override { |
| // Fold get_dimension_size when the queried dim is statically known. |
| auto tensorTy = dyn_cast<RankedTensorType>(op.getOperand().getType()); |
| if (!tensorTy) |
| return failure(); |
| |
| int64_t dimSize = tensorTy.getDimSize(op.getDimension()); |
| if (dimSize < 0) |
| return failure(); |
| |
| auto elemTy = cast<IntegerType>(op.getType().getElementType()); |
| IntegerAttr elemVal = rewriter.getIntegerAttr(elemTy, dimSize); |
| rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>( |
| op, DenseElementsAttr::get(op.getType(), elemVal)); |
| return success(); |
| } |
| }; |
| |
| /// Converts gather ops to slice ops in case we have a single set of constant |
| /// indices. |
| struct GatherOpCanon final : OpRewritePattern<mlir::stablehlo::GatherOp> { |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(mlir::stablehlo::GatherOp gather, |
| PatternRewriter &rewriter) const override { |
| DenseIntElementsAttr index; |
| if (!matchPattern(gather.getStartIndices(), m_Constant(&index))) { |
| return failure(); |
| } |
| |
| mlir::stablehlo::GatherDimensionNumbersAttr dnums = |
| gather.getDimensionNumbers(); |
| if (dnums.getIndexVectorDim() != 0 || index.getType().getRank() > 1) { |
| return failure(); |
| } |
| |
| // TODO: Remove when the verifier catches this case what is |
| // invalid if all previous condition holds. |
| if (index.getNumElements() != |
| static_cast<int64_t>(dnums.getStartIndexMap().size())) { |
| return failure(); |
| } |
| |
| auto operandType = |
| dyn_cast<RankedTensorType>(gather->getOperand(0).getType()); |
| if (!operandType || !operandType.hasStaticShape()) |
| return failure(); |
| |
| auto sliceEnd = llvm::to_vector(gather.getSliceSizes()); |
| SmallVector<int64_t> sliceStart(sliceEnd.size(), 0); |
| for (auto [mapIndex, value] : |
| llvm::zip_equal(dnums.getStartIndexMap(), index.getValues<APInt>())) { |
| // Clamp the indices within bounds to faithfully mirror gather semantics. |
| int64_t offset = |
| std::clamp(value.getSExtValue(), static_cast<int64_t>(0), |
| operandType.getDimSize(mapIndex) - sliceEnd[mapIndex]); |
| sliceStart[mapIndex] += offset; |
| sliceEnd[mapIndex] += offset; |
| } |
| |
| SmallVector<int64_t> sliceStride(sliceEnd.size(), 1); |
| SmallVector<int64_t> sliceShape(sliceEnd.size()); |
| for (auto [shapeElem, startElem, endElem] : |
| llvm::zip_equal(sliceShape, sliceStart, sliceEnd)) { |
| shapeElem = endElem - startElem; |
| } |
| |
| Type elementType = gather.getType().getElementType(); |
| auto sliceType = RankedTensorType::get(sliceShape, elementType); |
| Value result = rewriter.create<mlir::stablehlo::SliceOp>( |
| gather.getLoc(), sliceType, gather.getOperand(), |
| rewriter.getDenseI64ArrayAttr(sliceStart), |
| rewriter.getDenseI64ArrayAttr(sliceEnd), |
| rewriter.getDenseI64ArrayAttr(sliceStride)); |
| |
| ArrayRef<int64_t> collapsedSliceDims = dnums.getCollapsedSliceDims(); |
| if (!collapsedSliceDims.empty()) { |
| llvm::SmallVector<int64_t> reshapeShape; |
| for (auto [idx, dim] : llvm::enumerate(sliceShape)) { |
| if (!llvm::is_contained(collapsedSliceDims, idx)) { |
| reshapeShape.push_back(dim); |
| } |
| } |
| auto reshapeType = RankedTensorType::get(reshapeShape, elementType); |
| result = rewriter.create<mlir::stablehlo::ReshapeOp>(gather.getLoc(), |
| reshapeType, result); |
| } |
| |
| result.setType(gather.getType()); |
| rewriter.replaceOp(gather, result); |
| return success(); |
| } |
| }; |
| |
| struct ReshapeOpCanon final : OpRewritePattern<mlir::stablehlo::ReshapeOp> { |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(mlir::stablehlo::ReshapeOp op, |
| PatternRewriter &rewriter) const override { |
| // Fold noop reshape. |
| if (op.getType() == op.getOperand().getType()) { |
| rewriter.replaceOp(op, op.getOperand()); |
| return success(); |
| } |
| |
| // Fold reshape of a constant. |
| ElementsAttr cstAttr; |
| if (!matchPattern(op.getOperand(), m_Constant(&cstAttr))) { |
| return failure(); |
| } |
| |
| if (auto splat = dyn_cast<SplatElementsAttr>(cstAttr)) { |
| rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>( |
| op, SplatElementsAttr::get(op.getType(), |
| splat.getSplatValue<Attribute>())); |
| return success(); |
| } |
| |
| auto elements = |
| llvm::to_vector_of<Attribute>(cstAttr.getValues<Attribute>()); |
| rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>( |
| op, DenseElementsAttr::get(op.getType(), elements)); |
| return success(); |
| } |
| }; |
| |
| struct MergeConsecutiveReshapes final |
| : OpRewritePattern<mlir::stablehlo::ReshapeOp> { |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(mlir::stablehlo::ReshapeOp op, |
| PatternRewriter &rewriter) const override { |
| // Fold noop reshape. |
| auto operand = op.getOperand(); |
| if (op.getType() == operand.getType()) { |
| rewriter.replaceOp(op, op.getOperand()); |
| return success(); |
| } |
| |
| // Fold reshape(reshape(x)). |
| auto reshapeOp = operand.getDefiningOp<mlir::stablehlo::ReshapeOp>(); |
| if (!reshapeOp) { |
| return rewriter.notifyMatchFailure( |
| op, "requires defining op of operand to be Reshape"); |
| } |
| |
| op.setOperand(reshapeOp->getOperand(0)); |
| return success(); |
| } |
| }; |
| |
| struct TransposeIsReshape final |
| : OpRewritePattern<mlir::stablehlo::TransposeOp> { |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(mlir::stablehlo::TransposeOp op, |
| PatternRewriter &rewriter) const override { |
| auto input = op.getOperand(); |
| auto permutation = op.getPermutation(); |
| |
| if (isIotaRange(permutation)) { |
| rewriter.replaceOp(op, op.getOperand()); |
| return success(); |
| } |
| |
| auto inputTy = dyn_cast<RankedTensorType>(input.getType()); |
| if (!inputTy || !inputTy.hasStaticShape() || |
| !op.getType().hasStaticShape()) { |
| return rewriter.notifyMatchFailure( |
| op, "requires input/output to be of a statically-shaped ranked " |
| "tensor type"); |
| } |
| |
| SmallVector<int64_t> permValues(permutation); |
| |
| SmallVector<int64_t> nonZeroPerms; |
| nonZeroPerms.reserve(permValues.size()); |
| for (auto idx : permValues) { |
| auto sz = inputTy.getDimSize(idx); |
| if (sz != 1) |
| nonZeroPerms.push_back(idx); |
| } |
| |
| for (int i = 1, s = nonZeroPerms.size(); i < s; ++i) |
| if (nonZeroPerms[i - 1] > nonZeroPerms[i]) |
| return rewriter.notifyMatchFailure(op, "memory layout change"); |
| |
| rewriter.replaceOpWithNewOp<mlir::stablehlo::ReshapeOp>(op, op.getType(), |
| input); |
| return success(); |
| } |
| }; |
| |
| /// Check if a `t` is a tensor with zero extents. |
| static std::optional<RankedTensorType> isZeroExtent(Type t) { |
| auto type = dyn_cast<RankedTensorType>(t); |
| if (type && type.hasStaticShape() && |
| llvm::any_of(type.getShape(), [](int64_t s) { return s == 0; })) { |
| return type; |
| } |
| return std::nullopt; |
| } |
| |
| // Replace instances of zero extent tensors with empty tensors of the same |
| // type. |
| struct ZeroExtentTensorCanon final : RewritePattern { |
| ZeroExtentTensorCanon(MLIRContext *context, PatternBenefit benefit) |
| : RewritePattern(MatchAnyOpTypeTag(), benefit, context) {} |
| LogicalResult matchAndRewrite(Operation *op, |
| PatternRewriter &rewriter) const override { |
| auto loc = op->getLoc(); |
| |
| if (!isa_and_present<mlir::stablehlo::StablehloDialect>(op->getDialect())) { |
| return rewriter.notifyMatchFailure(op, "not stablehlo"); |
| } |
| |
| // If the result is a zero-extent tensor, replace the whole op with an empty |
| // tensor. |
| bool didUpdate = false; |
| for (auto result : op->getResults()) { |
| auto resultType = isZeroExtent(result.getType()); |
| if (!resultType || result.use_empty()) { |
| continue; |
| } |
| rewriter.replaceAllUsesWith(result, rewriter.create<tensor::EmptyOp>( |
| loc, resultType->getShape(), |
| resultType->getElementType())); |
| didUpdate = true; |
| } |
| |
| // If one of the operands is a zero-extent tensor, replace the operand with |
| // an empty tensor. |
| for (OpOperand &operand : op->getOpOperands()) { |
| auto operandType = isZeroExtent(operand.get().getType()); |
| if (!operandType || operand.get().getDefiningOp<tensor::EmptyOp>()) { |
| continue; |
| } |
| Operation *owner = operand.getOwner(); |
| int operandNum = operand.getOperandNumber(); |
| auto emptyTensorOp = rewriter.create<tensor::EmptyOp>( |
| loc, operandType->getShape(), operandType->getElementType()); |
| rewriter.modifyOpInPlace( |
| owner, [&]() { owner->setOperand(operandNum, emptyTensorOp); }); |
| didUpdate = true; |
| } |
| return success(didUpdate); |
| } |
| }; |
| |
| struct ReorderElementwiseAndShapeOp final |
| : OpTraitRewritePattern<OpTrait::Elementwise> { |
| using OpTraitRewritePattern::OpTraitRewritePattern; |
| |
| LogicalResult matchAndRewrite(Operation *op, |
| PatternRewriter &rewriter) const override { |
| if (op->getOperands().size() != 1) { |
| return rewriter.notifyMatchFailure(op, "expected to be unary"); |
| } |
| |
| auto definingOp = op->getOperand(0).getDefiningOp(); |
| if (!definingOp) { |
| return rewriter.notifyMatchFailure( |
| op, "expected to have an op before elementise op"); |
| } |
| |
| if (!isa<mlir::stablehlo::ReshapeOp>(definingOp) && |
| !isa<mlir::stablehlo::TransposeOp>(definingOp) && |
| !isa<mlir::stablehlo::BroadcastOp>(definingOp)) { |
| return rewriter.notifyMatchFailure( |
| op, "defining operation of unexpected type"); |
| } |
| |
| // Only reorder if the defining op has no other uses. |
| if (!llvm::hasSingleElement(definingOp->getResult(0).getUses())) { |
| return rewriter.notifyMatchFailure(op, "operation has more than one use"); |
| } |
| |
| Value input = definingOp->getOperand(0); |
| Value result = op->getResult(0); |
| auto intermediateType = cast<ShapedType>(input.getType()) |
| .clone(getElementTypeOrSelf(result.getType())); |
| |
| // Reorder the operation and rewire the inputs/outputs. |
| op->moveBefore(definingOp); |
| definingOp->getResult(0).setType(result.getType()); |
| rewriter.replaceAllUsesWith(result, definingOp->getResult(0)); |
| result.setType(intermediateType); |
| op->setOperands(input); |
| definingOp->setOperands(result); |
| return success(); |
| } |
| }; |
| |
| struct StableHLOCanonicalize final |
| : impl::StableHLOCanonicalizeBase<StableHLOCanonicalize> { |
| void runOnOperation() override { |
| MLIRContext *ctx = &getContext(); |
| RewritePatternSet patterns(ctx); |
| populateCanonicalizationPatterns(ctx, &patterns); |
| if (failed(applyPatternsAndFoldGreedily(getOperation(), |
| std::move(patterns)))) { |
| signalPassFailure(); |
| } |
| } |
| |
| void getDependentDialects(DialectRegistry ®istry) const final { |
| registry.insert<tensor::TensorDialect>(); |
| } |
| }; |
| |
| } // namespace |
| void populateCanonicalizationPatterns(MLIRContext *context, |
| RewritePatternSet *patterns, |
| PatternBenefit benefit) { |
| patterns->add< |
| // Arithmetic ops. |
| AddOpCanon, SubtractOpCanon, MulOpCanon, CompareOpCanon, SelectOpCanon, |
| // Complex ops. |
| RealOpCanon, ImagOpCanon, |
| // Query ops. |
| GetDimensionSizeOpCanon, GetTupleElementOpCanon, |
| // Broadcast ops. |
| BroadcastInDimOpCanon, DynamicBroadcastInDimOpNotActuallyDynamic, |
| ChainedDynamicBroadcastInDimCanonicalization, |
| DynamicBroadcastInDimAllDimsNonExpanding, |
| // Reduce op. |
| NoopReduceOpCanon, EmptyReduceOpCanon, |
| // Shape manipulation(-ish) ops. |
| ConcatenateOpCanon, ConvertOpCanon, DynamicReshapeOpCanon, GatherOpCanon, |
| ReshapeOpCanon, MergeConsecutiveReshapes, TransposeIsReshape, |
| // Types. |
| ZeroExtentTensorCanon>(context, benefit); |
| patterns->add<ReorderElementwiseAndShapeOp>(context); |
| } |
| } // namespace mlir::iree_compiler::stablehlo |