| // Copyright 2019 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // https://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include <algorithm> |
| |
| #include "iree/compiler/Dialect/VM/IR/VMDialect.h" |
| #include "iree/compiler/Dialect/VM/IR/VMOps.h" |
| #include "llvm/ADT/StringExtras.h" |
| #include "mlir/IR/Attributes.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/Matchers.h" |
| #include "mlir/IR/OpDefinition.h" |
| #include "mlir/IR/OpImplementation.h" |
| #include "mlir/IR/PatternMatch.h" |
| #include "mlir/IR/SymbolTable.h" |
| #include "mlir/Support/LogicalResult.h" |
| |
| namespace mlir { |
| namespace iree_compiler { |
| namespace IREE { |
| namespace VM { |
| |
| //===----------------------------------------------------------------------===// |
| // Utilities |
| //===----------------------------------------------------------------------===// |
| // TODO(benvanik): share these among dialects? |
| |
| namespace { |
| |
| /// Creates a constant zero attribute matching the given type. |
| Attribute zeroOfType(Type type) { |
| return Builder(type.getContext()).getZeroAttr(type); |
| } |
| |
| /// Creates a constant one attribute matching the given type. |
| Attribute oneOfType(Type type) { |
| Builder builder(type.getContext()); |
| if (type.isa<FloatType>()) return builder.getFloatAttr(type, 1.0); |
| if (auto integerTy = type.dyn_cast<IntegerType>()) |
| return builder.getIntegerAttr(integerTy, APInt(integerTy.getWidth(), 1)); |
| if (type.isa<RankedTensorType, VectorType>()) { |
| auto vtType = type.cast<ShapedType>(); |
| auto element = oneOfType(vtType.getElementType()); |
| if (!element) return {}; |
| return DenseElementsAttr::get(vtType, element); |
| } |
| return {}; |
| } |
| |
| } // namespace |
| |
| //===----------------------------------------------------------------------===// |
| // Structural ops |
| //===----------------------------------------------------------------------===// |
| |
| //===----------------------------------------------------------------------===// |
| // Globals |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| |
| /// Converts global initializer functions that evaluate to a constant to a |
| /// specified initial value. |
| template <typename T> |
| struct InlineConstGlobalOpInitializer : public OpRewritePattern<T> { |
| using OpRewritePattern<T>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(T op, |
| PatternRewriter &rewriter) const override { |
| if (!op.initializer()) return failure(); |
| auto initializer = dyn_cast_or_null<FuncOp>( |
| SymbolTable::lookupNearestSymbolFrom(op, op.initializer().getValue())); |
| if (!initializer) return failure(); |
| if (initializer.getBlocks().size() == 1 && |
| initializer.getBlocks().front().getOperations().size() == 2 && |
| isa<ReturnOp>(initializer.getBlocks().front().getOperations().back())) { |
| auto &primaryOp = initializer.getBlocks().front().getOperations().front(); |
| Attribute constResult; |
| if (matchPattern(primaryOp.getResult(0), m_Constant(&constResult))) { |
| rewriter.replaceOpWithNewOp<T>(op, op.sym_name(), op.is_mutable(), |
| op.type(), constResult); |
| return success(); |
| } |
| } |
| return failure(); |
| } |
| }; |
| |
| /// Drops initial_values from globals where the value is 0, as by default all |
| /// globals are zero-initialized upon module load. |
| template <typename T> |
| struct DropDefaultConstGlobalOpInitializer : public OpRewritePattern<T> { |
| using OpRewritePattern<T>::OpRewritePattern; |
| LogicalResult matchAndRewrite(T op, |
| PatternRewriter &rewriter) const override { |
| if (!op.initial_value().hasValue()) return failure(); |
| auto value = op.initial_valueAttr().template cast<IntegerAttr>(); |
| if (value.getValue() != 0) return failure(); |
| rewriter.replaceOpWithNewOp<T>(op, op.sym_name(), op.is_mutable(), |
| op.type(), |
| llvm::to_vector<4>(op->getDialectAttrs())); |
| return success(); |
| } |
| }; |
| |
| } // namespace |
| |
| void GlobalI32Op::getCanonicalizationPatterns(OwningRewritePatternList &results, |
| MLIRContext *context) { |
| results.insert<InlineConstGlobalOpInitializer<GlobalI32Op>, |
| DropDefaultConstGlobalOpInitializer<GlobalI32Op>>(context); |
| } |
| |
| void GlobalI64Op::getCanonicalizationPatterns(OwningRewritePatternList &results, |
| MLIRContext *context) { |
| results.insert<InlineConstGlobalOpInitializer<GlobalI64Op>, |
| DropDefaultConstGlobalOpInitializer<GlobalI64Op>>(context); |
| } |
| |
| void GlobalRefOp::getCanonicalizationPatterns(OwningRewritePatternList &results, |
| MLIRContext *context) { |
| results.insert<InlineConstGlobalOpInitializer<GlobalRefOp>>(context); |
| } |
| |
| namespace { |
| |
| /// Inlines immutable global constants into their loads. |
| template <typename LOAD_OP, typename GLOBAL_OP, typename CONST_OP, |
| typename CONST_ZERO_OP> |
| struct InlineConstGlobalLoadIntegerOp : public OpRewritePattern<LOAD_OP> { |
| using OpRewritePattern<LOAD_OP>::OpRewritePattern; |
| LogicalResult matchAndRewrite(LOAD_OP op, |
| PatternRewriter &rewriter) const override { |
| auto globalAttr = op->template getAttrOfType<FlatSymbolRefAttr>("global"); |
| auto globalOp = |
| op->template getParentOfType<VM::ModuleOp>() |
| .template lookupSymbol<GLOBAL_OP>(globalAttr.getValue()); |
| if (!globalOp) return failure(); |
| if (globalOp.is_mutable()) return failure(); |
| if (globalOp.initial_value()) { |
| rewriter.replaceOpWithNewOp<CONST_OP>( |
| op, globalOp.initial_value().getValue()); |
| } else { |
| rewriter.replaceOpWithNewOp<CONST_ZERO_OP>(op); |
| } |
| return success(); |
| } |
| }; |
| |
| } // namespace |
| |
| void GlobalLoadI32Op::getCanonicalizationPatterns( |
| OwningRewritePatternList &results, MLIRContext *context) { |
| results.insert<InlineConstGlobalLoadIntegerOp<GlobalLoadI32Op, GlobalI32Op, |
| ConstI32Op, ConstI32ZeroOp>>( |
| context); |
| } |
| |
| void GlobalLoadI64Op::getCanonicalizationPatterns( |
| OwningRewritePatternList &results, MLIRContext *context) { |
| results.insert<InlineConstGlobalLoadIntegerOp<GlobalLoadI64Op, GlobalI64Op, |
| ConstI64Op, ConstI64ZeroOp>>( |
| context); |
| } |
| |
| namespace { |
| |
| /// Inlines immutable global constants into their loads. |
| struct InlineConstGlobalLoadRefOp : public OpRewritePattern<GlobalLoadRefOp> { |
| using OpRewritePattern<GlobalLoadRefOp>::OpRewritePattern; |
| LogicalResult matchAndRewrite(GlobalLoadRefOp op, |
| PatternRewriter &rewriter) const override { |
| auto globalAttr = op->getAttrOfType<FlatSymbolRefAttr>("global"); |
| auto globalOp = |
| op->getParentOfType<VM::ModuleOp>().lookupSymbol<GlobalRefOp>( |
| globalAttr.getValue()); |
| if (!globalOp) return failure(); |
| if (globalOp.is_mutable()) return failure(); |
| rewriter.replaceOpWithNewOp<ConstRefZeroOp>(op, op.getType()); |
| return success(); |
| } |
| }; |
| |
| } // namespace |
| |
| void GlobalLoadRefOp::getCanonicalizationPatterns( |
| OwningRewritePatternList &results, MLIRContext *context) { |
| results.insert<InlineConstGlobalLoadRefOp>(context); |
| } |
| |
| namespace { |
| |
| template <typename INDIRECT, typename DIRECT> |
| struct PropagateGlobalLoadAddress : public OpRewritePattern<INDIRECT> { |
| using OpRewritePattern<INDIRECT>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(INDIRECT op, |
| PatternRewriter &rewriter) const override { |
| if (auto addressOp = |
| dyn_cast_or_null<GlobalAddressOp>(op.global().getDefiningOp())) { |
| rewriter.replaceOpWithNewOp<DIRECT>(op, op.value().getType(), |
| addressOp.global()); |
| return success(); |
| } |
| return failure(); |
| } |
| }; |
| |
| } // namespace |
| |
| void GlobalLoadIndirectI32Op::getCanonicalizationPatterns( |
| OwningRewritePatternList &results, MLIRContext *context) { |
| results.insert< |
| PropagateGlobalLoadAddress<GlobalLoadIndirectI32Op, GlobalLoadI32Op>>( |
| context); |
| } |
| |
| void GlobalLoadIndirectI64Op::getCanonicalizationPatterns( |
| OwningRewritePatternList &results, MLIRContext *context) { |
| results.insert< |
| PropagateGlobalLoadAddress<GlobalLoadIndirectI64Op, GlobalLoadI64Op>>( |
| context); |
| } |
| |
| void GlobalLoadIndirectRefOp::getCanonicalizationPatterns( |
| OwningRewritePatternList &results, MLIRContext *context) { |
| results.insert< |
| PropagateGlobalLoadAddress<GlobalLoadIndirectRefOp, GlobalLoadRefOp>>( |
| context); |
| } |
| |
| namespace { |
| |
| template <typename INDIRECT, typename DIRECT> |
| struct PropagateGlobalStoreAddress : public OpRewritePattern<INDIRECT> { |
| using OpRewritePattern<INDIRECT>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(INDIRECT op, |
| PatternRewriter &rewriter) const override { |
| if (auto addressOp = |
| dyn_cast_or_null<GlobalAddressOp>(op.global().getDefiningOp())) { |
| rewriter.replaceOpWithNewOp<DIRECT>(op, op.value(), addressOp.global()); |
| return success(); |
| } |
| return failure(); |
| } |
| }; |
| |
| } // namespace |
| |
| void GlobalStoreIndirectI32Op::getCanonicalizationPatterns( |
| OwningRewritePatternList &results, MLIRContext *context) { |
| results.insert< |
| PropagateGlobalStoreAddress<GlobalStoreIndirectI32Op, GlobalStoreI32Op>>( |
| context); |
| } |
| |
| void GlobalStoreIndirectI64Op::getCanonicalizationPatterns( |
| OwningRewritePatternList &results, MLIRContext *context) { |
| results.insert< |
| PropagateGlobalStoreAddress<GlobalStoreIndirectI64Op, GlobalStoreI64Op>>( |
| context); |
| } |
| |
| void GlobalStoreIndirectRefOp::getCanonicalizationPatterns( |
| OwningRewritePatternList &results, MLIRContext *context) { |
| results.insert< |
| PropagateGlobalStoreAddress<GlobalStoreIndirectRefOp, GlobalStoreRefOp>>( |
| context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Constants |
| //===----------------------------------------------------------------------===// |
| |
| OpFoldResult ConstI32Op::fold(ArrayRef<Attribute> operands) { return value(); } |
| |
| OpFoldResult ConstI64Op::fold(ArrayRef<Attribute> operands) { return value(); } |
| |
| OpFoldResult ConstI32ZeroOp::fold(ArrayRef<Attribute> operands) { |
| return IntegerAttr::get(getResult().getType(), 0); |
| } |
| |
| OpFoldResult ConstI64ZeroOp::fold(ArrayRef<Attribute> operands) { |
| return IntegerAttr::get(getResult().getType(), 0); |
| } |
| |
| OpFoldResult ConstRefZeroOp::fold(ArrayRef<Attribute> operands) { |
| // TODO(b/144027097): relace unit attr with a proper null ref_ptr attr. |
| return UnitAttr::get(getContext()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ref_ptr operations |
| //===----------------------------------------------------------------------===// |
| |
| //===----------------------------------------------------------------------===// |
| // Conditional assignment |
| //===----------------------------------------------------------------------===// |
| |
| template <typename T> |
| static OpFoldResult foldSelectOp(T op) { |
| if (matchPattern(op.condition(), m_Zero())) { |
| // 0 ? x : y = y |
| return op.false_value(); |
| } else if (matchPattern(op.condition(), m_NonZero())) { |
| // !0 ? x : y = x |
| return op.true_value(); |
| } else if (op.true_value() == op.false_value()) { |
| // c ? x : x = x |
| return op.true_value(); |
| } |
| return {}; |
| } |
| |
| OpFoldResult SelectI32Op::fold(ArrayRef<Attribute> operands) { |
| return foldSelectOp(*this); |
| } |
| |
| OpFoldResult SelectI64Op::fold(ArrayRef<Attribute> operands) { |
| return foldSelectOp(*this); |
| } |
| |
| OpFoldResult SelectRefOp::fold(ArrayRef<Attribute> operands) { |
| return foldSelectOp(*this); |
| } |
| |
| template <typename T> |
| static OpFoldResult foldSwitchOp(T op) { |
| APInt indexValue; |
| if (matchPattern(op.index(), m_ConstantInt(&indexValue))) { |
| // Index is constant and we can resolve immediately. |
| int64_t index = indexValue.getSExtValue(); |
| if (index < 0 || index >= op.values().size()) { |
| return op.default_value(); |
| } |
| return op.values()[index]; |
| } |
| |
| bool allValuesMatch = true; |
| for (auto value : op.values()) { |
| if (value != op.default_value()) { |
| allValuesMatch = false; |
| break; |
| } |
| } |
| if (allValuesMatch) { |
| // All values (and the default) are the same so just return it regardless of |
| // the provided index. |
| return op.default_value(); |
| } |
| |
| return {}; |
| } |
| |
| OpFoldResult SwitchI32Op::fold(ArrayRef<Attribute> operands) { |
| return foldSwitchOp(*this); |
| } |
| |
| OpFoldResult SwitchI64Op::fold(ArrayRef<Attribute> operands) { |
| return foldSwitchOp(*this); |
| } |
| |
| OpFoldResult SwitchRefOp::fold(ArrayRef<Attribute> operands) { |
| return foldSwitchOp(*this); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Native integer arithmetic |
| //===----------------------------------------------------------------------===// |
| |
| /// Performs const folding `calculate` with element-wise behavior on the given |
| /// attribute in `operands` and returns the result if possible. |
| template <class AttrElementT, |
| class ElementValueT = typename AttrElementT::ValueType, |
| class CalculationT = std::function<ElementValueT(ElementValueT)>> |
| static Attribute constFoldUnaryOp(ArrayRef<Attribute> operands, |
| const CalculationT &calculate) { |
| assert(operands.size() == 1 && "unary op takes one operand"); |
| if (auto operand = operands[0].dyn_cast_or_null<AttrElementT>()) { |
| return AttrElementT::get(operand.getType(), calculate(operand.getValue())); |
| } else if (auto operand = operands[0].dyn_cast_or_null<SplatElementsAttr>()) { |
| auto elementResult = |
| constFoldUnaryOp<AttrElementT>({operand.getSplatValue()}, calculate); |
| if (!elementResult) return {}; |
| return DenseElementsAttr::get(operand.getType(), elementResult); |
| } else if (auto operand = operands[0].dyn_cast_or_null<ElementsAttr>()) { |
| return operand.mapValues( |
| operand.getType().getElementType(), |
| llvm::function_ref<ElementValueT(const ElementValueT &)>( |
| [&](const ElementValueT &value) { return calculate(value); })); |
| } |
| return {}; |
| } |
| |
| /// Performs const folding `calculate` with element-wise behavior on the two |
| /// attributes in `operands` and returns the result if possible. |
| template <class AttrElementT, |
| class ElementValueT = typename AttrElementT::ValueType, |
| class CalculationT = |
| std::function<ElementValueT(ElementValueT, ElementValueT)>> |
| static Attribute constFoldBinaryOp(ArrayRef<Attribute> operands, |
| const CalculationT &calculate) { |
| assert(operands.size() == 2 && "binary op takes two operands"); |
| if (auto lhs = operands[0].dyn_cast_or_null<AttrElementT>()) { |
| auto rhs = operands[1].dyn_cast_or_null<AttrElementT>(); |
| if (!rhs || lhs.getType() != rhs.getType()) return {}; |
| return AttrElementT::get(lhs.getType(), |
| calculate(lhs.getValue(), rhs.getValue())); |
| } else if (auto lhs = operands[0].dyn_cast_or_null<SplatElementsAttr>()) { |
| // TODO(benvanik): handle splat/otherwise. |
| auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>(); |
| if (!rhs || lhs.getType() != rhs.getType()) return {}; |
| auto elementResult = constFoldBinaryOp<AttrElementT>( |
| {lhs.getSplatValue(), rhs.getSplatValue()}, calculate); |
| if (!elementResult) return {}; |
| return DenseElementsAttr::get(lhs.getType(), elementResult); |
| } else if (auto lhs = operands[0].dyn_cast_or_null<ElementsAttr>()) { |
| auto rhs = operands[1].dyn_cast_or_null<ElementsAttr>(); |
| if (!rhs || lhs.getType() != rhs.getType()) return {}; |
| auto lhsIt = lhs.getValues<AttrElementT>().begin(); |
| auto rhsIt = rhs.getValues<AttrElementT>().begin(); |
| SmallVector<Attribute, 4> resultAttrs(lhs.getNumElements()); |
| for (int64_t i = 0; i < lhs.getNumElements(); ++i) { |
| resultAttrs[i] = |
| constFoldBinaryOp<AttrElementT>({*lhsIt, *rhsIt}, calculate); |
| if (!resultAttrs[i]) return {}; |
| ++lhsIt; |
| ++rhsIt; |
| } |
| return DenseElementsAttr::get(lhs.getType(), resultAttrs); |
| } |
| return {}; |
| } |
| |
| template <typename ADD, typename SUB> |
| static OpFoldResult foldAddOp(ADD op, ArrayRef<Attribute> operands) { |
| if (matchPattern(op.rhs(), m_Zero())) { |
| // x + 0 = x or 0 + y = y (commutative) |
| return op.lhs(); |
| } |
| if (auto subOp = dyn_cast_or_null<SUB>(op.lhs().getDefiningOp())) { |
| if (subOp.lhs() == op.rhs()) return subOp.rhs(); |
| if (subOp.rhs() == op.rhs()) return subOp.lhs(); |
| } else if (auto subOp = dyn_cast_or_null<SUB>(op.rhs().getDefiningOp())) { |
| if (subOp.lhs() == op.lhs()) return subOp.rhs(); |
| if (subOp.rhs() == op.lhs()) return subOp.lhs(); |
| } |
| return constFoldBinaryOp<IntegerAttr>( |
| operands, [](const APInt &a, const APInt &b) { return a + b; }); |
| } |
| |
| OpFoldResult AddI32Op::fold(ArrayRef<Attribute> operands) { |
| return foldAddOp<AddI32Op, SubI32Op>(*this, operands); |
| } |
| |
| OpFoldResult AddI64Op::fold(ArrayRef<Attribute> operands) { |
| return foldAddOp<AddI64Op, SubI64Op>(*this, operands); |
| } |
| |
| template <typename SUB, typename ADD> |
| static OpFoldResult foldSubOp(SUB op, ArrayRef<Attribute> operands) { |
| if (matchPattern(op.rhs(), m_Zero())) { |
| // x - 0 = x |
| return op.lhs(); |
| } |
| if (auto addOp = dyn_cast_or_null<ADD>(op.lhs().getDefiningOp())) { |
| if (addOp.lhs() == op.rhs()) return addOp.rhs(); |
| if (addOp.rhs() == op.rhs()) return addOp.lhs(); |
| } else if (auto addOp = dyn_cast_or_null<ADD>(op.rhs().getDefiningOp())) { |
| if (addOp.lhs() == op.lhs()) return addOp.rhs(); |
| if (addOp.rhs() == op.lhs()) return addOp.lhs(); |
| } |
| return constFoldBinaryOp<IntegerAttr>( |
| operands, [](const APInt &a, const APInt &b) { return a - b; }); |
| } |
| |
| OpFoldResult SubI32Op::fold(ArrayRef<Attribute> operands) { |
| return foldSubOp<SubI32Op, AddI32Op>(*this, operands); |
| } |
| |
| OpFoldResult SubI64Op::fold(ArrayRef<Attribute> operands) { |
| return foldSubOp<SubI64Op, AddI64Op>(*this, operands); |
| } |
| |
| template <typename T> |
| static OpFoldResult foldMulOp(T op, ArrayRef<Attribute> operands) { |
| if (matchPattern(op.rhs(), m_Zero())) { |
| // x * 0 = 0 or 0 * y = 0 (commutative) |
| return zeroOfType(op.getType()); |
| } else if (matchPattern(op.rhs(), m_One())) { |
| // x * 1 = x or 1 * y = y (commutative) |
| return op.lhs(); |
| } |
| return constFoldBinaryOp<IntegerAttr>( |
| operands, [](const APInt &a, const APInt &b) { return a * b; }); |
| } |
| |
| template <typename T, typename CONST_OP> |
| struct FoldConstantMulOperand : public OpRewritePattern<T> { |
| using OpRewritePattern<T>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(T op, |
| PatternRewriter &rewriter) const override { |
| IntegerAttr c1, c2; |
| if (!matchPattern(op.rhs(), m_Constant(&c1))) return failure(); |
| if (auto mulOp = dyn_cast_or_null<T>(op.lhs().getDefiningOp())) { |
| if (matchPattern(mulOp.rhs(), m_Constant(&c2))) { |
| auto c = rewriter.createOrFold<CONST_OP>( |
| FusedLoc::get({mulOp.getLoc(), op.getLoc()}, rewriter.getContext()), |
| constFoldBinaryOp<IntegerAttr>( |
| {c1, c2}, |
| [](const APInt &a, const APInt &b) { return a * b; })); |
| rewriter.replaceOpWithNewOp<T>(op, op.getType(), mulOp.lhs(), c); |
| return success(); |
| } |
| } |
| return failure(); |
| } |
| }; |
| |
| OpFoldResult MulI32Op::fold(ArrayRef<Attribute> operands) { |
| return foldMulOp(*this, operands); |
| } |
| |
| void MulI32Op::getCanonicalizationPatterns(OwningRewritePatternList &results, |
| MLIRContext *context) { |
| results.insert<FoldConstantMulOperand<MulI32Op, ConstI32Op>>(context); |
| } |
| |
| OpFoldResult MulI64Op::fold(ArrayRef<Attribute> operands) { |
| return foldMulOp(*this, operands); |
| } |
| |
| void MulI64Op::getCanonicalizationPatterns(OwningRewritePatternList &results, |
| MLIRContext *context) { |
| results.insert<FoldConstantMulOperand<MulI64Op, ConstI64Op>>(context); |
| } |
| |
| template <typename T> |
| static OpFoldResult foldDivSOp(T op, ArrayRef<Attribute> operands) { |
| if (matchPattern(op.rhs(), m_Zero())) { |
| // x / 0 = death |
| op.emitOpError() << "is a divide by constant zero"; |
| return {}; |
| } else if (matchPattern(op.lhs(), m_Zero())) { |
| // 0 / y = 0 |
| return zeroOfType(op.getType()); |
| } else if (matchPattern(op.rhs(), m_One())) { |
| // x / 1 = x |
| return op.lhs(); |
| } |
| return constFoldBinaryOp<IntegerAttr>( |
| operands, [](const APInt &a, const APInt &b) { return a.sdiv(b); }); |
| } |
| |
| OpFoldResult DivI32SOp::fold(ArrayRef<Attribute> operands) { |
| return foldDivSOp(*this, operands); |
| } |
| |
| OpFoldResult DivI64SOp::fold(ArrayRef<Attribute> operands) { |
| return foldDivSOp(*this, operands); |
| } |
| |
| template <typename T> |
| static OpFoldResult foldDivUOp(T op, ArrayRef<Attribute> operands) { |
| if (matchPattern(op.rhs(), m_Zero())) { |
| // x / 0 = death |
| op.emitOpError() << "is a divide by constant zero"; |
| return {}; |
| } else if (matchPattern(op.lhs(), m_Zero())) { |
| // 0 / y = 0 |
| return zeroOfType(op.getType()); |
| } else if (matchPattern(op.rhs(), m_One())) { |
| // x / 1 = x |
| return op.lhs(); |
| } |
| return constFoldBinaryOp<IntegerAttr>( |
| operands, [](const APInt &a, const APInt &b) { return a.udiv(b); }); |
| } |
| |
| OpFoldResult DivI32UOp::fold(ArrayRef<Attribute> operands) { |
| return foldDivUOp(*this, operands); |
| } |
| |
| OpFoldResult DivI64UOp::fold(ArrayRef<Attribute> operands) { |
| return foldDivUOp(*this, operands); |
| } |
| |
| template <typename T> |
| static OpFoldResult foldRemSOp(T op, ArrayRef<Attribute> operands) { |
| if (matchPattern(op.rhs(), m_Zero())) { |
| // x % 0 = death |
| op.emitOpError() << "is a remainder by constant zero"; |
| return {}; |
| } else if (matchPattern(op.lhs(), m_Zero()) || |
| matchPattern(op.rhs(), m_One())) { |
| // x % 1 = 0 |
| // 0 % y = 0 |
| return zeroOfType(op.getType()); |
| } |
| return constFoldBinaryOp<IntegerAttr>( |
| operands, [](const APInt &a, const APInt &b) { return a.srem(b); }); |
| } |
| |
| OpFoldResult RemI32SOp::fold(ArrayRef<Attribute> operands) { |
| return foldRemSOp(*this, operands); |
| } |
| |
| OpFoldResult RemI64SOp::fold(ArrayRef<Attribute> operands) { |
| return foldRemSOp(*this, operands); |
| } |
| |
| template <typename T> |
| static OpFoldResult foldRemUOp(T op, ArrayRef<Attribute> operands) { |
| if (matchPattern(op.lhs(), m_Zero()) || matchPattern(op.rhs(), m_One())) { |
| // x % 1 = 0 |
| // 0 % y = 0 |
| return zeroOfType(op.getType()); |
| } |
| return constFoldBinaryOp<IntegerAttr>( |
| operands, [](const APInt &a, const APInt &b) { return a.urem(b); }); |
| } |
| |
| OpFoldResult RemI32UOp::fold(ArrayRef<Attribute> operands) { |
| return foldRemUOp(*this, operands); |
| } |
| |
| OpFoldResult RemI64UOp::fold(ArrayRef<Attribute> operands) { |
| return foldRemUOp(*this, operands); |
| } |
| |
| template <typename T> |
| static OpFoldResult foldNotOp(T op, ArrayRef<Attribute> operands) { |
| return constFoldUnaryOp<IntegerAttr>(operands, [](APInt a) { |
| a.flipAllBits(); |
| return a; |
| }); |
| } |
| |
| OpFoldResult NotI32Op::fold(ArrayRef<Attribute> operands) { |
| return foldNotOp(*this, operands); |
| } |
| |
| OpFoldResult NotI64Op::fold(ArrayRef<Attribute> operands) { |
| return foldNotOp(*this, operands); |
| } |
| |
| template <typename T> |
| static OpFoldResult foldAndOp(T op, ArrayRef<Attribute> operands) { |
| if (matchPattern(op.rhs(), m_Zero())) { |
| // x & 0 = 0 or 0 & y = 0 (commutative) |
| return zeroOfType(op.getType()); |
| } else if (op.lhs() == op.rhs()) { |
| // x & x = x |
| return op.lhs(); |
| } |
| return constFoldBinaryOp<IntegerAttr>( |
| operands, [](const APInt &a, const APInt &b) { return a & b; }); |
| } |
| |
| OpFoldResult AndI32Op::fold(ArrayRef<Attribute> operands) { |
| return foldAndOp(*this, operands); |
| } |
| |
| OpFoldResult AndI64Op::fold(ArrayRef<Attribute> operands) { |
| return foldAndOp(*this, operands); |
| } |
| |
| template <typename T> |
| static OpFoldResult foldOrOp(T op, ArrayRef<Attribute> operands) { |
| if (matchPattern(op.rhs(), m_Zero())) { |
| // x | 0 = x or 0 | y = y (commutative) |
| return op.lhs(); |
| } else if (op.lhs() == op.rhs()) { |
| // x | x = x |
| return op.lhs(); |
| } |
| return constFoldBinaryOp<IntegerAttr>( |
| operands, [](const APInt &a, const APInt &b) { return a | b; }); |
| } |
| |
| OpFoldResult OrI32Op::fold(ArrayRef<Attribute> operands) { |
| return foldOrOp(*this, operands); |
| } |
| |
| OpFoldResult OrI64Op::fold(ArrayRef<Attribute> operands) { |
| return foldOrOp(*this, operands); |
| } |
| |
| template <typename T> |
| static OpFoldResult foldXorOp(T op, ArrayRef<Attribute> operands) { |
| if (matchPattern(op.rhs(), m_Zero())) { |
| // x ^ 0 = x or 0 ^ y = y (commutative) |
| return op.lhs(); |
| } else if (op.lhs() == op.rhs()) { |
| // x ^ x = 0 |
| return zeroOfType(op.getType()); |
| } |
| return constFoldBinaryOp<IntegerAttr>( |
| operands, [](const APInt &a, const APInt &b) { return a ^ b; }); |
| } |
| |
| OpFoldResult XorI32Op::fold(ArrayRef<Attribute> operands) { |
| return foldXorOp(*this, operands); |
| } |
| |
| OpFoldResult XorI64Op::fold(ArrayRef<Attribute> operands) { |
| return foldXorOp(*this, operands); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Native bitwise shifts and rotates |
| //===----------------------------------------------------------------------===// |
| |
| template <typename T> |
| static OpFoldResult foldShlOp(T op, ArrayRef<Attribute> operands) { |
| if (matchPattern(op.operand(), m_Zero())) { |
| // 0 << y = 0 |
| return zeroOfType(op.getType()); |
| } else if (op.amount() == 0) { |
| // x << 0 = x |
| return op.operand(); |
| } |
| return constFoldUnaryOp<IntegerAttr>( |
| operands, [&](const APInt &a) { return a.shl(op.amount()); }); |
| } |
| |
| OpFoldResult ShlI32Op::fold(ArrayRef<Attribute> operands) { |
| return foldShlOp(*this, operands); |
| } |
| |
| OpFoldResult ShlI64Op::fold(ArrayRef<Attribute> operands) { |
| return foldShlOp(*this, operands); |
| } |
| |
| template <typename T> |
| static OpFoldResult foldShrSOp(T op, ArrayRef<Attribute> operands) { |
| if (matchPattern(op.operand(), m_Zero())) { |
| // 0 >> y = 0 |
| return zeroOfType(op.getType()); |
| } else if (op.amount() == 0) { |
| // x >> 0 = x |
| return op.operand(); |
| } |
| return constFoldUnaryOp<IntegerAttr>( |
| operands, [&](const APInt &a) { return a.ashr(op.amount()); }); |
| } |
| |
| OpFoldResult ShrI32SOp::fold(ArrayRef<Attribute> operands) { |
| return foldShrSOp(*this, operands); |
| } |
| |
| OpFoldResult ShrI64SOp::fold(ArrayRef<Attribute> operands) { |
| return foldShrSOp(*this, operands); |
| } |
| |
| template <typename T> |
| static OpFoldResult foldShrUOp(T op, ArrayRef<Attribute> operands) { |
| if (matchPattern(op.operand(), m_Zero())) { |
| // 0 >> y = 0 |
| return zeroOfType(op.getType()); |
| } else if (op.amount() == 0) { |
| // x >> 0 = x |
| return op.operand(); |
| } |
| return constFoldUnaryOp<IntegerAttr>( |
| operands, [&](const APInt &a) { return a.lshr(op.amount()); }); |
| } |
| |
| OpFoldResult ShrI32UOp::fold(ArrayRef<Attribute> operands) { |
| return foldShrUOp(*this, operands); |
| } |
| |
| OpFoldResult ShrI64UOp::fold(ArrayRef<Attribute> operands) { |
| return foldShrUOp(*this, operands); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Casting and type conversion/emulation |
| //===----------------------------------------------------------------------===// |
| |
| /// Performs const folding `calculate` with element-wise behavior on the given |
| /// attribute in `operands` and returns the result if possible. |
| template <class AttrElementT, |
| class ElementValueT = typename AttrElementT::ValueType, |
| class CalculationT = std::function<ElementValueT(ElementValueT)>> |
| static Attribute constFoldConversionOp(Type resultType, |
| ArrayRef<Attribute> operands, |
| const CalculationT &calculate) { |
| assert(operands.size() == 1 && "unary op takes one operand"); |
| if (auto operand = operands[0].dyn_cast_or_null<AttrElementT>()) { |
| return AttrElementT::get(resultType, calculate(operand.getValue())); |
| } |
| return {}; |
| } |
| |
| OpFoldResult TruncI32I8Op::fold(ArrayRef<Attribute> operands) { |
| return constFoldConversionOp<IntegerAttr>( |
| IntegerType::get(getContext(), 32), operands, |
| [&](const APInt &a) { return a.trunc(8).zext(32); }); |
| } |
| |
| OpFoldResult TruncI32I16Op::fold(ArrayRef<Attribute> operands) { |
| return constFoldConversionOp<IntegerAttr>( |
| IntegerType::get(getContext(), 32), operands, |
| [&](const APInt &a) { return a.trunc(16).zext(32); }); |
| } |
| |
| OpFoldResult TruncI64I8Op::fold(ArrayRef<Attribute> operands) { |
| return constFoldConversionOp<IntegerAttr>( |
| IntegerType::get(getContext(), 32), operands, |
| [&](const APInt &a) { return a.trunc(8).zext(32); }); |
| } |
| |
| OpFoldResult TruncI64I16Op::fold(ArrayRef<Attribute> operands) { |
| return constFoldConversionOp<IntegerAttr>( |
| IntegerType::get(getContext(), 32), operands, |
| [&](const APInt &a) { return a.trunc(16).zext(32); }); |
| } |
| |
| OpFoldResult TruncI64I32Op::fold(ArrayRef<Attribute> operands) { |
| return constFoldConversionOp<IntegerAttr>( |
| IntegerType::get(getContext(), 32), operands, |
| [&](const APInt &a) { return a.trunc(32); }); |
| } |
| |
| OpFoldResult ExtI8I32SOp::fold(ArrayRef<Attribute> operands) { |
| return constFoldConversionOp<IntegerAttr>( |
| IntegerType::get(getContext(), 32), operands, |
| [&](const APInt &a) { return a.trunc(8).sext(32); }); |
| } |
| |
| OpFoldResult ExtI8I32UOp::fold(ArrayRef<Attribute> operands) { |
| return constFoldConversionOp<IntegerAttr>( |
| IntegerType::get(getContext(), 32), operands, |
| [&](const APInt &a) { return a.trunc(8).zext(32); }); |
| } |
| |
| OpFoldResult ExtI16I32SOp::fold(ArrayRef<Attribute> operands) { |
| return constFoldConversionOp<IntegerAttr>( |
| IntegerType::get(getContext(), 32), operands, |
| [&](const APInt &a) { return a.trunc(16).sext(32); }); |
| } |
| |
| OpFoldResult ExtI16I32UOp::fold(ArrayRef<Attribute> operands) { |
| return constFoldConversionOp<IntegerAttr>( |
| IntegerType::get(getContext(), 32), operands, |
| [&](const APInt &a) { return a.trunc(16).zext(32); }); |
| } |
| |
| OpFoldResult ExtI8I64SOp::fold(ArrayRef<Attribute> operands) { |
| return constFoldConversionOp<IntegerAttr>( |
| IntegerType::get(getContext(), 64), operands, |
| [&](const APInt &a) { return a.trunc(8).sext(64); }); |
| } |
| |
| OpFoldResult ExtI8I64UOp::fold(ArrayRef<Attribute> operands) { |
| return constFoldConversionOp<IntegerAttr>( |
| IntegerType::get(getContext(), 64), operands, |
| [&](const APInt &a) { return a.trunc(8).zext(64); }); |
| } |
| |
| OpFoldResult ExtI16I64SOp::fold(ArrayRef<Attribute> operands) { |
| return constFoldConversionOp<IntegerAttr>( |
| IntegerType::get(getContext(), 64), operands, |
| [&](const APInt &a) { return a.trunc(16).sext(64); }); |
| } |
| |
| OpFoldResult ExtI16I64UOp::fold(ArrayRef<Attribute> operands) { |
| return constFoldConversionOp<IntegerAttr>( |
| IntegerType::get(getContext(), 64), operands, |
| [&](const APInt &a) { return a.trunc(16).zext(64); }); |
| } |
| |
| OpFoldResult ExtI32I64SOp::fold(ArrayRef<Attribute> operands) { |
| return constFoldConversionOp<IntegerAttr>( |
| IntegerType::get(getContext(), 64), operands, |
| [&](const APInt &a) { return a.sext(64); }); |
| } |
| |
| OpFoldResult ExtI32I64UOp::fold(ArrayRef<Attribute> operands) { |
| return constFoldConversionOp<IntegerAttr>( |
| IntegerType::get(getContext(), 64), operands, |
| [&](const APInt &a) { return a.zext(64); }); |
| } |
| |
| namespace { |
| |
| template <typename SRC_OP, typename OP_A, int SZ_T, typename OP_B> |
| struct PseudoIntegerConversionToSplitConversionOp |
| : public OpRewritePattern<SRC_OP> { |
| using OpRewritePattern<SRC_OP>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(SRC_OP op, |
| PatternRewriter &rewriter) const override { |
| auto tmp = rewriter.createOrFold<OP_A>( |
| op.getLoc(), rewriter.getIntegerType(SZ_T), op.operand()); |
| rewriter.replaceOpWithNewOp<OP_B>(op, op.result().getType(), tmp); |
| return success(); |
| } |
| }; |
| |
| } // namespace |
| |
| void TruncI64I8Op::getCanonicalizationPatterns( |
| OwningRewritePatternList &results, MLIRContext *context) { |
| results.insert<PseudoIntegerConversionToSplitConversionOp< |
| TruncI64I8Op, TruncI64I32Op, 32, TruncI32I8Op>>(context); |
| } |
| |
| void TruncI64I16Op::getCanonicalizationPatterns( |
| OwningRewritePatternList &results, MLIRContext *context) { |
| results.insert<PseudoIntegerConversionToSplitConversionOp< |
| TruncI64I16Op, TruncI64I32Op, 32, TruncI32I16Op>>(context); |
| } |
| |
| void ExtI8I64SOp::getCanonicalizationPatterns(OwningRewritePatternList &results, |
| MLIRContext *context) { |
| results.insert<PseudoIntegerConversionToSplitConversionOp< |
| ExtI8I64SOp, ExtI8I32SOp, 32, ExtI32I64SOp>>(context); |
| } |
| |
| void ExtI8I64UOp::getCanonicalizationPatterns(OwningRewritePatternList &results, |
| MLIRContext *context) { |
| results.insert<PseudoIntegerConversionToSplitConversionOp< |
| ExtI8I64UOp, ExtI8I32UOp, 32, ExtI32I64UOp>>(context); |
| } |
| |
| void ExtI16I64SOp::getCanonicalizationPatterns( |
| OwningRewritePatternList &results, MLIRContext *context) { |
| results.insert<PseudoIntegerConversionToSplitConversionOp< |
| ExtI16I64SOp, ExtI16I32SOp, 32, ExtI32I64SOp>>(context); |
| } |
| |
| void ExtI16I64UOp::getCanonicalizationPatterns( |
| OwningRewritePatternList &results, MLIRContext *context) { |
| results.insert<PseudoIntegerConversionToSplitConversionOp< |
| ExtI16I64UOp, ExtI16I32UOp, 32, ExtI32I64UOp>>(context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Native reduction (horizontal) arithmetic |
| //===----------------------------------------------------------------------===// |
| |
| //===----------------------------------------------------------------------===// |
| // Comparison ops |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| |
| /// Swaps the cmp op with its inverse if the result is inverted. |
| template <typename OP, typename INV> |
| struct SwapInvertedCmpOps : public OpRewritePattern<OP> { |
| using OpRewritePattern<OP>::OpRewritePattern; |
| LogicalResult matchAndRewrite(OP op, |
| PatternRewriter &rewriter) const override { |
| // We generate xor(cmp(...), 1) to flip conditions, so look for that pattern |
| // so that we can do the swap here and remove the xor. |
| if (!op.result().hasOneUse()) { |
| // Can't change if there are multiple users. |
| return failure(); |
| } |
| if (auto xorOp = dyn_cast_or_null<XorI32Op>(*op.result().user_begin())) { |
| Attribute rhs; |
| if (xorOp.lhs() == op.result() && |
| matchPattern(xorOp.rhs(), m_Constant(&rhs)) && |
| rhs.cast<IntegerAttr>().getInt() == 1) { |
| auto invValue = rewriter.createOrFold<INV>( |
| op.getLoc(), op.result().getType(), op.lhs(), op.rhs()); |
| rewriter.replaceOp(op, {invValue}); |
| rewriter.replaceOp(xorOp, {invValue}); |
| return success(); |
| } |
| } |
| return failure(); |
| } |
| }; |
| |
| } // namespace |
| |
| template <typename T> |
| static OpFoldResult foldCmpEQOp(T op, ArrayRef<Attribute> operands) { |
| if (op.lhs() == op.rhs()) { |
| // x == x = true |
| return oneOfType(op.getType()); |
| } |
| return constFoldBinaryOp<IntegerAttr>( |
| operands, [&](const APInt &a, const APInt &b) { return a.eq(b); }); |
| } |
| |
| OpFoldResult CmpEQI32Op::fold(ArrayRef<Attribute> operands) { |
| return foldCmpEQOp(*this, operands); |
| } |
| |
| void CmpEQI32Op::getCanonicalizationPatterns(OwningRewritePatternList &results, |
| MLIRContext *context) { |
| results.insert<SwapInvertedCmpOps<CmpEQI32Op, CmpNEI32Op>>(context); |
| } |
| |
| OpFoldResult CmpEQI64Op::fold(ArrayRef<Attribute> operands) { |
| return foldCmpEQOp(*this, operands); |
| } |
| |
| void CmpEQI64Op::getCanonicalizationPatterns(OwningRewritePatternList &results, |
| MLIRContext *context) { |
| results.insert<SwapInvertedCmpOps<CmpEQI64Op, CmpNEI64Op>>(context); |
| } |
| |
| template <typename T> |
| static OpFoldResult foldCmpNEOp(T op, ArrayRef<Attribute> operands) { |
| if (op.lhs() == op.rhs()) { |
| // x != x = false |
| return zeroOfType(op.getType()); |
| } |
| return constFoldBinaryOp<IntegerAttr>( |
| operands, [&](const APInt &a, const APInt &b) { return a.ne(b); }); |
| } |
| |
| OpFoldResult CmpNEI32Op::fold(ArrayRef<Attribute> operands) { |
| return foldCmpNEOp(*this, operands); |
| } |
| |
| OpFoldResult CmpNEI64Op::fold(ArrayRef<Attribute> operands) { |
| return foldCmpNEOp(*this, operands); |
| } |
| |
| namespace { |
| |
| /// Changes a cmp.ne.i32 check against 0 to a cmp.nz.i32. |
| template <typename NE_OP, typename NZ_OP> |
| struct CmpNEZeroToCmpNZ : public OpRewritePattern<NE_OP> { |
| using OpRewritePattern<NE_OP>::OpRewritePattern; |
| LogicalResult matchAndRewrite(NE_OP op, |
| PatternRewriter &rewriter) const override { |
| if (matchPattern(op.rhs(), m_Zero())) { |
| rewriter.replaceOpWithNewOp<NZ_OP>(op, op.getType(), op.lhs()); |
| return success(); |
| } |
| return failure(); |
| } |
| }; |
| |
| } // namespace |
| |
| void CmpNEI32Op::getCanonicalizationPatterns(OwningRewritePatternList &results, |
| MLIRContext *context) { |
| results.insert<SwapInvertedCmpOps<CmpNEI32Op, CmpEQI32Op>, |
| CmpNEZeroToCmpNZ<CmpNEI32Op, CmpNZI32Op>>(context); |
| } |
| |
| void CmpNEI64Op::getCanonicalizationPatterns(OwningRewritePatternList &results, |
| MLIRContext *context) { |
| results.insert<SwapInvertedCmpOps<CmpNEI64Op, CmpEQI64Op>, |
| CmpNEZeroToCmpNZ<CmpNEI64Op, CmpNZI64Op>>(context); |
| } |
| |
| template <typename T> |
| static OpFoldResult foldCmpLTSOp(T op, ArrayRef<Attribute> operands) { |
| if (op.lhs() == op.rhs()) { |
| // x < x = false |
| return zeroOfType(op.getType()); |
| } |
| return constFoldBinaryOp<IntegerAttr>( |
| operands, [&](const APInt &a, const APInt &b) { return a.slt(b); }); |
| } |
| |
| OpFoldResult CmpLTI32SOp::fold(ArrayRef<Attribute> operands) { |
| return foldCmpLTSOp(*this, operands); |
| } |
| |
| OpFoldResult CmpLTI64SOp::fold(ArrayRef<Attribute> operands) { |
| return foldCmpLTSOp(*this, operands); |
| } |
| |
| void CmpLTI32SOp::getCanonicalizationPatterns(OwningRewritePatternList &results, |
| MLIRContext *context) {} |
| |
| void CmpLTI64SOp::getCanonicalizationPatterns(OwningRewritePatternList &results, |
| MLIRContext *context) {} |
| |
| template <typename T> |
| static OpFoldResult foldCmpLTUOp(T op, ArrayRef<Attribute> operands) { |
| if (op.lhs() == op.rhs()) { |
| // x < x = false |
| return zeroOfType(op.getType()); |
| } |
| return constFoldBinaryOp<IntegerAttr>( |
| operands, [&](const APInt &a, const APInt &b) { return a.ult(b); }); |
| } |
| |
| OpFoldResult CmpLTI32UOp::fold(ArrayRef<Attribute> operands) { |
| return foldCmpLTUOp(*this, operands); |
| } |
| |
| OpFoldResult CmpLTI64UOp::fold(ArrayRef<Attribute> operands) { |
| return foldCmpLTUOp(*this, operands); |
| } |
| |
| void CmpLTI32UOp::getCanonicalizationPatterns(OwningRewritePatternList &results, |
| MLIRContext *context) {} |
| |
| void CmpLTI64UOp::getCanonicalizationPatterns(OwningRewritePatternList &results, |
| MLIRContext *context) {} |
| |
| namespace { |
| |
| /// Rewrites a vm.cmp.lte.* pseudo op to a vm.cmp.lt.* op. |
| template <typename T, typename U> |
| struct RewritePseudoCmpLTEToLT : public OpRewritePattern<T> { |
| using OpRewritePattern<T>::OpRewritePattern; |
| LogicalResult matchAndRewrite(T op, |
| PatternRewriter &rewriter) const override { |
| // !(lhs > rhs) |
| auto condValue = |
| rewriter.createOrFold<U>(op.getLoc(), op.getType(), op.rhs(), op.lhs()); |
| rewriter.replaceOpWithNewOp<XorI32Op>( |
| op, op.getType(), condValue, |
| rewriter.createOrFold<IREE::VM::ConstI32Op>(op.getLoc(), 1)); |
| return success(); |
| } |
| }; |
| |
| } // namespace |
| |
| template <typename T> |
| static OpFoldResult foldCmpLTESOp(T op, ArrayRef<Attribute> operands) { |
| if (op.lhs() == op.rhs()) { |
| // x <= x = true |
| return oneOfType(op.getType()); |
| } |
| return constFoldBinaryOp<IntegerAttr>( |
| operands, [&](const APInt &a, const APInt &b) { return a.sle(b); }); |
| } |
| |
| OpFoldResult CmpLTEI32SOp::fold(ArrayRef<Attribute> operands) { |
| return foldCmpLTESOp(*this, operands); |
| } |
| |
| OpFoldResult CmpLTEI64SOp::fold(ArrayRef<Attribute> operands) { |
| return foldCmpLTESOp(*this, operands); |
| } |
| |
| void CmpLTEI32SOp::getCanonicalizationPatterns( |
| OwningRewritePatternList &results, MLIRContext *context) { |
| results.insert<SwapInvertedCmpOps<CmpLTEI32SOp, CmpGTI32SOp>>(context); |
| results.insert<RewritePseudoCmpLTEToLT<CmpLTEI32SOp, CmpLTI32SOp>>(context); |
| } |
| |
| void CmpLTEI64SOp::getCanonicalizationPatterns( |
| OwningRewritePatternList &results, MLIRContext *context) { |
| results.insert<SwapInvertedCmpOps<CmpLTEI64SOp, CmpGTI64SOp>>(context); |
| results.insert<RewritePseudoCmpLTEToLT<CmpLTEI64SOp, CmpLTI64SOp>>(context); |
| } |
| |
| template <typename T> |
| static OpFoldResult foldCmpLTEUOp(T op, ArrayRef<Attribute> operands) { |
| if (op.lhs() == op.rhs()) { |
| // x <= x = true |
| return oneOfType(op.getType()); |
| } |
| return constFoldBinaryOp<IntegerAttr>( |
| operands, [&](const APInt &a, const APInt &b) { return a.ule(b); }); |
| } |
| |
| OpFoldResult CmpLTEI32UOp::fold(ArrayRef<Attribute> operands) { |
| return foldCmpLTEUOp(*this, operands); |
| } |
| |
| OpFoldResult CmpLTEI64UOp::fold(ArrayRef<Attribute> operands) { |
| return foldCmpLTEUOp(*this, operands); |
| } |
| |
| void CmpLTEI32UOp::getCanonicalizationPatterns( |
| OwningRewritePatternList &results, MLIRContext *context) { |
| results.insert<SwapInvertedCmpOps<CmpLTEI32UOp, CmpGTI32UOp>>(context); |
| results.insert<RewritePseudoCmpLTEToLT<CmpLTEI32UOp, CmpLTI32UOp>>(context); |
| } |
| |
| void CmpLTEI64UOp::getCanonicalizationPatterns( |
| OwningRewritePatternList &results, MLIRContext *context) { |
| results.insert<SwapInvertedCmpOps<CmpLTEI64UOp, CmpGTI64UOp>>(context); |
| results.insert<RewritePseudoCmpLTEToLT<CmpLTEI64UOp, CmpLTI64UOp>>(context); |
| } |
| |
| namespace { |
| |
| /// Rewrites a vm.cmp.gt.* pseudo op to a vm.cmp.lt.* op. |
| template <typename T, typename U> |
| struct RewritePseudoCmpGTToLT : public OpRewritePattern<T> { |
| using OpRewritePattern<T>::OpRewritePattern; |
| LogicalResult matchAndRewrite(T op, |
| PatternRewriter &rewriter) const override { |
| // rhs < lhs |
| rewriter.replaceOpWithNewOp<U>(op, op.getType(), op.rhs(), op.lhs()); |
| return success(); |
| } |
| }; |
| |
| } // namespace |
| |
| template <typename T> |
| static OpFoldResult foldCmpGTSOp(T op, ArrayRef<Attribute> operands) { |
| if (op.lhs() == op.rhs()) { |
| // x > x = false |
| return zeroOfType(op.getType()); |
| } |
| return constFoldBinaryOp<IntegerAttr>( |
| operands, [&](const APInt &a, const APInt &b) { return a.sgt(b); }); |
| } |
| |
| OpFoldResult CmpGTI32SOp::fold(ArrayRef<Attribute> operands) { |
| return foldCmpGTSOp(*this, operands); |
| } |
| |
| OpFoldResult CmpGTI64SOp::fold(ArrayRef<Attribute> operands) { |
| return foldCmpGTSOp(*this, operands); |
| } |
| |
| void CmpGTI32SOp::getCanonicalizationPatterns(OwningRewritePatternList &results, |
| MLIRContext *context) { |
| results.insert<SwapInvertedCmpOps<CmpGTI32SOp, CmpLTEI32SOp>>(context); |
| results.insert<RewritePseudoCmpGTToLT<CmpGTI32SOp, CmpLTI32SOp>>(context); |
| } |
| |
| void CmpGTI64SOp::getCanonicalizationPatterns(OwningRewritePatternList &results, |
| MLIRContext *context) { |
| results.insert<SwapInvertedCmpOps<CmpGTI64SOp, CmpLTEI64SOp>>(context); |
| results.insert<RewritePseudoCmpGTToLT<CmpGTI64SOp, CmpLTI64SOp>>(context); |
| } |
| |
| template <typename T> |
| static OpFoldResult foldCmpGTUOp(T op, ArrayRef<Attribute> operands) { |
| if (op.lhs() == op.rhs()) { |
| // x > x = false |
| return zeroOfType(op.getType()); |
| } |
| return constFoldBinaryOp<IntegerAttr>( |
| operands, [&](const APInt &a, const APInt &b) { return a.ugt(b); }); |
| } |
| |
| OpFoldResult CmpGTI32UOp::fold(ArrayRef<Attribute> operands) { |
| return foldCmpGTUOp(*this, operands); |
| } |
| |
| OpFoldResult CmpGTI64UOp::fold(ArrayRef<Attribute> operands) { |
| return foldCmpGTUOp(*this, operands); |
| } |
| |
| void CmpGTI32UOp::getCanonicalizationPatterns(OwningRewritePatternList &results, |
| MLIRContext *context) { |
| results.insert<SwapInvertedCmpOps<CmpGTI32UOp, CmpLTEI32UOp>>(context); |
| results.insert<RewritePseudoCmpGTToLT<CmpGTI32UOp, CmpLTI32UOp>>(context); |
| } |
| |
| void CmpGTI64UOp::getCanonicalizationPatterns(OwningRewritePatternList &results, |
| MLIRContext *context) { |
| results.insert<SwapInvertedCmpOps<CmpGTI64UOp, CmpLTEI64UOp>>(context); |
| results.insert<RewritePseudoCmpGTToLT<CmpGTI64UOp, CmpLTI64UOp>>(context); |
| } |
| |
| namespace { |
| |
| /// Rewrites a vm.cmp.gte.* pseudo op to a vm.cmp.lt.* op. |
| template <typename T, typename U> |
| struct RewritePseudoCmpGTEToLT : public OpRewritePattern<T> { |
| using OpRewritePattern<T>::OpRewritePattern; |
| LogicalResult matchAndRewrite(T op, |
| PatternRewriter &rewriter) const override { |
| // !(lhs < rhs) |
| auto condValue = |
| rewriter.createOrFold<U>(op.getLoc(), op.getType(), op.lhs(), op.rhs()); |
| rewriter.replaceOpWithNewOp<XorI32Op>( |
| op, op.getType(), condValue, |
| rewriter.createOrFold<IREE::VM::ConstI32Op>(op.getLoc(), 1)); |
| return success(); |
| } |
| }; |
| |
| } // namespace |
| |
| template <typename T> |
| static OpFoldResult foldCmpGTESOp(T op, ArrayRef<Attribute> operands) { |
| if (op.lhs() == op.rhs()) { |
| // x >= x = true |
| return oneOfType(op.getType()); |
| } |
| return constFoldBinaryOp<IntegerAttr>( |
| operands, [&](const APInt &a, const APInt &b) { return a.sge(b); }); |
| } |
| |
| OpFoldResult CmpGTEI32SOp::fold(ArrayRef<Attribute> operands) { |
| return foldCmpGTESOp(*this, operands); |
| } |
| |
| OpFoldResult CmpGTEI64SOp::fold(ArrayRef<Attribute> operands) { |
| return foldCmpGTESOp(*this, operands); |
| } |
| |
| void CmpGTEI32SOp::getCanonicalizationPatterns( |
| OwningRewritePatternList &results, MLIRContext *context) { |
| results.insert<SwapInvertedCmpOps<CmpGTEI32SOp, CmpLTI32SOp>>(context); |
| results.insert<RewritePseudoCmpGTEToLT<CmpGTEI32SOp, CmpLTI32SOp>>(context); |
| } |
| |
| void CmpGTEI64SOp::getCanonicalizationPatterns( |
| OwningRewritePatternList &results, MLIRContext *context) { |
| results.insert<SwapInvertedCmpOps<CmpGTEI64SOp, CmpLTI64SOp>>(context); |
| results.insert<RewritePseudoCmpGTEToLT<CmpGTEI64SOp, CmpLTI64SOp>>(context); |
| } |
| |
| template <typename T> |
| static OpFoldResult foldCmpGTEUOp(T op, ArrayRef<Attribute> operands) { |
| if (op.lhs() == op.rhs()) { |
| // x >= x = true |
| return oneOfType(op.getType()); |
| } |
| return constFoldBinaryOp<IntegerAttr>( |
| operands, [&](const APInt &a, const APInt &b) { return a.uge(b); }); |
| } |
| |
| OpFoldResult CmpGTEI32UOp::fold(ArrayRef<Attribute> operands) { |
| return foldCmpGTEUOp(*this, operands); |
| } |
| |
| OpFoldResult CmpGTEI64UOp::fold(ArrayRef<Attribute> operands) { |
| return foldCmpGTEUOp(*this, operands); |
| } |
| |
| void CmpGTEI32UOp::getCanonicalizationPatterns( |
| OwningRewritePatternList &results, MLIRContext *context) { |
| results.insert<SwapInvertedCmpOps<CmpGTEI32UOp, CmpLTI32UOp>>(context); |
| results.insert<RewritePseudoCmpGTEToLT<CmpGTEI32UOp, CmpLTI32UOp>>(context); |
| } |
| |
| void CmpGTEI64UOp::getCanonicalizationPatterns( |
| OwningRewritePatternList &results, MLIRContext *context) { |
| results.insert<SwapInvertedCmpOps<CmpGTEI64UOp, CmpLTI64UOp>>(context); |
| results.insert<RewritePseudoCmpGTEToLT<CmpGTEI64UOp, CmpLTI64UOp>>(context); |
| } |
| |
| OpFoldResult CmpNZI32Op::fold(ArrayRef<Attribute> operands) { |
| return constFoldUnaryOp<IntegerAttr>( |
| operands, [&](const APInt &a) { return APInt(32, a.getBoolValue()); }); |
| } |
| |
| OpFoldResult CmpNZI64Op::fold(ArrayRef<Attribute> operands) { |
| return constFoldUnaryOp<IntegerAttr>( |
| operands, [&](const APInt &a) { return APInt(64, a.getBoolValue()); }); |
| } |
| |
| OpFoldResult CmpEQRefOp::fold(ArrayRef<Attribute> operands) { |
| if (lhs() == rhs()) { |
| // x == x = true |
| return oneOfType(getType()); |
| } else if (operands[0] && operands[1]) { |
| // Constant null == null = true |
| return oneOfType(getType()); |
| } |
| return {}; |
| } |
| |
| namespace { |
| |
| /// Changes a cmp.eq.ref check against null to a cmp.nz.ref and inverted cond. |
| struct NullCheckCmpEQRefToCmpNZRef : public OpRewritePattern<CmpEQRefOp> { |
| using OpRewritePattern<CmpEQRefOp>::OpRewritePattern; |
| LogicalResult matchAndRewrite(CmpEQRefOp op, |
| PatternRewriter &rewriter) const override { |
| Attribute rhs; |
| if (matchPattern(op.rhs(), m_Constant(&rhs))) { |
| auto cmpNz = |
| rewriter.create<CmpNZRefOp>(op.getLoc(), op.getType(), op.lhs()); |
| rewriter.replaceOpWithNewOp<XorI32Op>( |
| op, op.getType(), cmpNz, |
| rewriter.createOrFold<IREE::VM::ConstI32Op>(op.getLoc(), 1)); |
| return success(); |
| } |
| return failure(); |
| } |
| }; |
| |
| } // namespace |
| |
| void CmpEQRefOp::getCanonicalizationPatterns(OwningRewritePatternList &results, |
| MLIRContext *context) { |
| results.insert<NullCheckCmpEQRefToCmpNZRef>(context); |
| } |
| |
| OpFoldResult CmpNERefOp::fold(ArrayRef<Attribute> operands) { |
| if (lhs() == rhs()) { |
| // x != x = false |
| return zeroOfType(getType()); |
| } |
| return {}; |
| } |
| |
| namespace { |
| |
| /// Changes a cmp.ne.ref check against null to a cmp.nz.ref. |
| struct NullCheckCmpNERefToCmpNZRef : public OpRewritePattern<CmpNERefOp> { |
| using OpRewritePattern<CmpNERefOp>::OpRewritePattern; |
| LogicalResult matchAndRewrite(CmpNERefOp op, |
| PatternRewriter &rewriter) const override { |
| Attribute rhs; |
| if (matchPattern(op.rhs(), m_Constant(&rhs))) { |
| rewriter.replaceOpWithNewOp<CmpNZRefOp>(op, op.getType(), op.lhs()); |
| return success(); |
| } |
| return failure(); |
| } |
| }; |
| |
| } // namespace |
| |
| void CmpNERefOp::getCanonicalizationPatterns(OwningRewritePatternList &results, |
| MLIRContext *context) { |
| results.insert<NullCheckCmpNERefToCmpNZRef>(context); |
| } |
| |
| OpFoldResult CmpNZRefOp::fold(ArrayRef<Attribute> operands) { |
| Attribute operandValue; |
| if (matchPattern(operand(), m_Constant(&operandValue))) { |
| // x == null |
| return zeroOfType(getType()); |
| } |
| return {}; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Control flow |
| //===----------------------------------------------------------------------===// |
| |
| /// Given a successor, try to collapse it to a new destination if it only |
| /// contains a passthrough unconditional branch. If the successor is |
| /// collapsable, `successor` and `successorOperands` are updated to reference |
| /// the new destination and values. `argStorage` is an optional storage to use |
| /// if operands to the collapsed successor need to be remapped. |
| static LogicalResult collapseBranch(Block *&successor, |
| ValueRange &successorOperands, |
| SmallVectorImpl<Value> &argStorage) { |
| // Check that the successor only contains a unconditional branch. |
| if (std::next(successor->begin()) != successor->end()) return failure(); |
| // Check that the terminator is an unconditional branch. |
| BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator()); |
| if (!successorBranch) return failure(); |
| // Check that the arguments are only used within the terminator. |
| for (BlockArgument arg : successor->getArguments()) { |
| for (Operation *user : arg.getUsers()) |
| if (user != successorBranch) return failure(); |
| } |
| // Don't try to collapse branches to infinite loops. |
| Block *successorDest = successorBranch.getDest(); |
| if (successorDest == successor) return failure(); |
| |
| // Update the operands to the successor. If the branch parent has no |
| // arguments, we can use the branch operands directly. |
| OperandRange operands = successorBranch.getOperands(); |
| if (successor->args_empty()) { |
| successor = successorDest; |
| successorOperands = operands; |
| return success(); |
| } |
| |
| // Otherwise, we need to remap any argument operands. |
| for (Value operand : operands) { |
| BlockArgument argOperand = operand.dyn_cast<BlockArgument>(); |
| if (argOperand && argOperand.getOwner() == successor) |
| argStorage.push_back(successorOperands[argOperand.getArgNumber()]); |
| else |
| argStorage.push_back(operand); |
| } |
| successor = successorDest; |
| successorOperands = argStorage; |
| return success(); |
| } |
| |
| namespace { |
| |
| /// Simplify a branch to a block that has a single predecessor. This effectively |
| /// merges the two blocks. |
| /// |
| /// (same logic as for std.br) |
| struct SimplifyBrToBlockWithSinglePred : public OpRewritePattern<BranchOp> { |
| using OpRewritePattern<BranchOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(BranchOp op, |
| PatternRewriter &rewriter) const override { |
| // Check that the successor block has a single predecessor. |
| Block *succ = op.getDest(); |
| Block *opParent = op.getOperation()->getBlock(); |
| if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors())) { |
| return failure(); |
| } |
| |
| // Merge the successor into the current block and erase the branch. |
| rewriter.mergeBlocks(succ, opParent, op.getOperands()); |
| rewriter.eraseOp(op); |
| return success(); |
| } |
| }; |
| |
| /// br ^bb1 |
| /// ^bb1 |
| /// br ^bbN(...) |
| /// |
| /// -> br ^bbN(...) |
| /// |
| /// (same logic as for std.br) |
| struct SimplifyPassThroughBr : public OpRewritePattern<BranchOp> { |
| using OpRewritePattern<BranchOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(BranchOp op, |
| PatternRewriter &rewriter) const override { |
| Block *dest = op.getDest(); |
| ValueRange destOperands = op.getOperands(); |
| SmallVector<Value, 4> destOperandStorage; |
| |
| // Try to collapse the successor if it points somewhere other than this |
| // block. |
| if (dest == op.getOperation()->getBlock() || |
| failed(collapseBranch(dest, destOperands, destOperandStorage))) { |
| return failure(); |
| } |
| |
| // Create a new branch with the collapsed successor. |
| rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands); |
| return success(); |
| } |
| }; |
| |
| } // namespace |
| |
| void BranchOp::getCanonicalizationPatterns(OwningRewritePatternList &results, |
| MLIRContext *context) { |
| results.insert<SimplifyBrToBlockWithSinglePred, SimplifyPassThroughBr>( |
| context); |
| } |
| |
| namespace { |
| |
| /// Simplifies a cond_br with a constant condition to an unconditional branch. |
| struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> { |
| using OpRewritePattern<CondBranchOp>::OpRewritePattern; |
| LogicalResult matchAndRewrite(CondBranchOp op, |
| PatternRewriter &rewriter) const override { |
| if (matchPattern(op.condition(), m_NonZero())) { |
| // True branch taken. |
| rewriter.replaceOpWithNewOp<BranchOp>(op, op.getTrueDest(), |
| op.getTrueOperands()); |
| return success(); |
| } else if (matchPattern(op.condition(), m_Zero())) { |
| // False branch taken. |
| rewriter.replaceOpWithNewOp<BranchOp>(op, op.getFalseDest(), |
| op.getFalseOperands()); |
| return success(); |
| } |
| return failure(); |
| } |
| }; |
| |
| /// Simplifies a cond_br with both targets (including operands) being equal to |
| /// an unconditional branch. |
| struct SimplifySameTargetCondBranchOp : public OpRewritePattern<CondBranchOp> { |
| using OpRewritePattern<CondBranchOp>::OpRewritePattern; |
| LogicalResult matchAndRewrite(CondBranchOp op, |
| PatternRewriter &rewriter) const override { |
| if (op.getTrueDest() != op.getFalseDest()) { |
| // Targets differ so we need to be a cond branch. |
| return failure(); |
| } |
| |
| // If all operands match between the targets then we can become a normal |
| // branch to the shared target. |
| auto trueOperands = llvm::to_vector<4>(op.getTrueOperands()); |
| auto falseOperands = llvm::to_vector<4>(op.getFalseOperands()); |
| if (trueOperands == falseOperands) { |
| rewriter.replaceOpWithNewOp<BranchOp>(op, op.getTrueDest(), trueOperands); |
| return success(); |
| } |
| |
| return failure(); |
| } |
| }; |
| |
| /// Swaps the cond_br true and false targets if the condition is inverted. |
| struct SwapInvertedCondBranchOpTargets : public OpRewritePattern<CondBranchOp> { |
| using OpRewritePattern<CondBranchOp>::OpRewritePattern; |
| LogicalResult matchAndRewrite(CondBranchOp op, |
| PatternRewriter &rewriter) const override { |
| // TODO(benvanik): figure out something more reliable when the xor may be |
| // used on a non-binary value. |
| // We generate xor(cmp(...), 1) to flip conditions, so look for that pattern |
| // so that we can do the swap here and remove the xor. |
| // auto condValue = op.getCondition(); |
| // if (auto xorOp = dyn_cast_or_null<XorI32Op>(condValue.getDefiningOp())) { |
| // Attribute rhs; |
| // if (matchPattern(xorOp.rhs(), m_Constant(&rhs)) && |
| // rhs.cast<IntegerAttr>().getInt() == 1) { |
| // rewriter.replaceOpWithNewOp<CondBranchOp>( |
| // op, xorOp.lhs(), op.getFalseDest(), op.getFalseOperands(), |
| // op.getTrueDest(), op.getTrueOperands()); |
| // return success(); |
| // } |
| // } |
| return failure(); |
| } |
| }; |
| |
| } // namespace |
| |
| void CondBranchOp::getCanonicalizationPatterns( |
| OwningRewritePatternList &results, MLIRContext *context) { |
| results.insert<SimplifyConstCondBranchPred, SimplifySameTargetCondBranchOp, |
| SwapInvertedCondBranchOpTargets>(context); |
| } |
| |
| namespace { |
| |
| /// Removes vm.call ops to functions that are marked as having no side-effects |
| /// if the results are unused. |
| template <typename T> |
| struct EraseUnusedCallOp : public OpRewritePattern<T> { |
| using OpRewritePattern<T>::OpRewritePattern; |
| LogicalResult matchAndRewrite(T op, |
| PatternRewriter &rewriter) const override { |
| // First check if the call is unused - this ensures we only do the symbol |
| // lookup if we are actually going to use it. |
| for (auto result : op.getResults()) { |
| if (!result.use_empty()) { |
| return failure(); |
| } |
| } |
| |
| auto *calleeOp = SymbolTable::lookupSymbolIn( |
| op->template getParentOfType<ModuleOp>(), op.callee()); |
| |
| bool hasNoSideEffects = false; |
| if (calleeOp->getAttr("nosideeffects")) { |
| hasNoSideEffects = true; |
| } else if (auto import = dyn_cast<ImportInterface>(calleeOp)) { |
| hasNoSideEffects = !import.hasSideEffects(); |
| } |
| if (!hasNoSideEffects) { |
| // Op has side-effects (or may have them); can't remove. |
| return failure(); |
| } |
| |
| // Erase op as it is unused. |
| rewriter.eraseOp(op); |
| return success(); |
| } |
| }; |
| |
| } // namespace |
| |
| void CallOp::getCanonicalizationPatterns(OwningRewritePatternList &results, |
| MLIRContext *context) { |
| results.insert<EraseUnusedCallOp<CallOp>>(context); |
| } |
| |
| namespace { |
| |
| /// Converts a vm.call.variadic to a non-variadic function to a normal vm.call. |
| struct ConvertNonVariadicToCallOp : public OpRewritePattern<CallVariadicOp> { |
| using OpRewritePattern<CallVariadicOp>::OpRewritePattern; |
| LogicalResult matchAndRewrite(CallVariadicOp op, |
| PatternRewriter &rewriter) const override { |
| // If any segment size is != -1 (which indicates variadic) we bail. |
| for (const auto &segmentSize : op.segment_sizes()) { |
| if (segmentSize.getSExtValue() != -1) { |
| return failure(); |
| } |
| } |
| rewriter.replaceOpWithNewOp<CallOp>(op, op.callee(), |
| llvm::to_vector<4>(op.getResultTypes()), |
| op.getOperands()); |
| return success(); |
| } |
| }; |
| |
| } // namespace |
| |
| void CallVariadicOp::getCanonicalizationPatterns( |
| OwningRewritePatternList &results, MLIRContext *context) { |
| results.insert<EraseUnusedCallOp<CallVariadicOp>, ConvertNonVariadicToCallOp>( |
| context); |
| } |
| |
| namespace { |
| |
| /// Rewrites a cond_fail op to a cond_branch to a fail op. |
| struct RewriteCondFailToBranchFail : public OpRewritePattern<CondFailOp> { |
| using OpRewritePattern<CondFailOp>::OpRewritePattern; |
| LogicalResult matchAndRewrite(CondFailOp op, |
| PatternRewriter &rewriter) const override { |
| auto *block = rewriter.getInsertionBlock(); |
| |
| // Create the block with the vm.fail in it. |
| // This is what we will branch to if the condition is true at runtime. |
| auto *failBlock = rewriter.createBlock(block, {op.status().getType()}); |
| block->moveBefore(failBlock); |
| rewriter.setInsertionPointToStart(failBlock); |
| rewriter.create<FailOp>( |
| op.getLoc(), failBlock->getArgument(0), |
| op.message().hasValue() ? op.message().getValue() : ""); |
| |
| // Replace the original cond_fail with our cond_branch, splitting the block |
| // and continuing if the condition is not taken. |
| auto *continueBlock = rewriter.splitBlock( |
| block, op.getOperation()->getNextNode()->getIterator()); |
| rewriter.setInsertionPointToEnd(block); |
| rewriter.replaceOpWithNewOp<CondBranchOp>(op, op.condition(), failBlock, |
| ValueRange{op.status()}, |
| continueBlock, ValueRange{}); |
| |
| return success(); |
| } |
| }; |
| |
| } // namespace |
| |
| void CondFailOp::getCanonicalizationPatterns(OwningRewritePatternList &results, |
| MLIRContext *context) { |
| results.insert<RewriteCondFailToBranchFail>(context); |
| } |
| |
| namespace { |
| |
| /// Rewrites a check op to a cmp and a cond_fail. |
| template <typename CheckOp, typename CmpI32Op, typename CmpI64Op, |
| typename CmpRefOp> |
| struct RewriteCheckToCondFail : public OpRewritePattern<CheckOp> { |
| using OpRewritePattern<CheckOp>::OpRewritePattern; |
| LogicalResult matchAndRewrite(CheckOp op, |
| PatternRewriter &rewriter) const override { |
| Type condType = rewriter.getI32Type(); |
| Value condValue; |
| Type operandType = op.getOperation()->getOperand(0).getType(); |
| if (operandType.template isa<RefType>()) { |
| condValue = rewriter.template createOrFold<CmpRefOp>( |
| op.getLoc(), ArrayRef<Type>{condType}, |
| op.getOperation()->getOperands()); |
| } else if (operandType.isInteger(64)) { |
| condValue = rewriter.template createOrFold<CmpI64Op>( |
| op.getLoc(), ArrayRef<Type>{condType}, |
| op.getOperation()->getOperands()); |
| } else if (operandType.isInteger(32)) { |
| condValue = rewriter.template createOrFold<CmpI32Op>( |
| op.getLoc(), ArrayRef<Type>{condType}, |
| op.getOperation()->getOperands()); |
| } else { |
| return failure(); |
| } |
| condValue = rewriter.createOrFold<XorI32Op>( |
| op.getLoc(), condType, condValue, |
| rewriter.createOrFold<IREE::VM::ConstI32Op>(op.getLoc(), 1)); |
| auto statusCode = rewriter.createOrFold<ConstI32Op>( |
| op.getLoc(), /*IREE_STATUS_FAILED_PRECONDITION=*/9); |
| rewriter.replaceOpWithNewOp<IREE::VM::CondFailOp>(op, condValue, statusCode, |
| op.messageAttr()); |
| return success(); |
| } |
| }; |
| |
| } // namespace |
| |
| void CheckEQOp::getCanonicalizationPatterns(OwningRewritePatternList &results, |
| MLIRContext *context) { |
| results.insert< |
| RewriteCheckToCondFail<CheckEQOp, CmpEQI32Op, CmpEQI64Op, CmpEQRefOp>>( |
| context); |
| } |
| |
| void CheckNEOp::getCanonicalizationPatterns(OwningRewritePatternList &results, |
| MLIRContext *context) { |
| results.insert< |
| RewriteCheckToCondFail<CheckNEOp, CmpNEI32Op, CmpNEI64Op, CmpNERefOp>>( |
| context); |
| } |
| |
| void CheckNZOp::getCanonicalizationPatterns(OwningRewritePatternList &results, |
| MLIRContext *context) { |
| results.insert< |
| RewriteCheckToCondFail<CheckNZOp, CmpNZI32Op, CmpNZI64Op, CmpNZRefOp>>( |
| context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Async/fiber ops |
| //===----------------------------------------------------------------------===// |
| |
| //===----------------------------------------------------------------------===// |
| // Debugging |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| |
| template <typename T> |
| struct RemoveDisabledDebugOp : public OpRewritePattern<T> { |
| using OpRewritePattern<T>::OpRewritePattern; |
| LogicalResult matchAndRewrite(T op, |
| PatternRewriter &rewriter) const override { |
| // TODO(benvanik): if debug disabled then replace inputs -> outputs. |
| return failure(); |
| } |
| }; |
| |
| template <typename T> |
| struct RemoveDisabledDebugAsyncOp : public OpRewritePattern<T> { |
| using OpRewritePattern<T>::OpRewritePattern; |
| LogicalResult matchAndRewrite(T op, |
| PatternRewriter &rewriter) const override { |
| // TODO(benvanik): if debug disabled then replace with a branch to dest. |
| return failure(); |
| } |
| }; |
| |
| struct SimplifyConstCondBreakPred : public OpRewritePattern<CondBreakOp> { |
| using OpRewritePattern<CondBreakOp>::OpRewritePattern; |
| LogicalResult matchAndRewrite(CondBreakOp op, |
| PatternRewriter &rewriter) const override { |
| IntegerAttr condValue; |
| if (!matchPattern(op.condition(), m_Constant(&condValue))) { |
| return failure(); |
| } |
| |
| if (condValue.getValue() != 0) { |
| // True - always break (to the same destination). |
| rewriter.replaceOpWithNewOp<BreakOp>(op, op.getDest(), op.destOperands()); |
| } else { |
| // False - skip the break. |
| rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDest(), |
| op.destOperands()); |
| } |
| return success(); |
| } |
| }; |
| |
| } // namespace |
| |
| void TraceOp::getCanonicalizationPatterns(OwningRewritePatternList &results, |
| MLIRContext *context) { |
| results.insert<RemoveDisabledDebugOp<TraceOp>>(context); |
| } |
| |
| void PrintOp::getCanonicalizationPatterns(OwningRewritePatternList &results, |
| MLIRContext *context) { |
| results.insert<RemoveDisabledDebugOp<PrintOp>>(context); |
| } |
| |
| void BreakOp::getCanonicalizationPatterns(OwningRewritePatternList &results, |
| MLIRContext *context) { |
| results.insert<RemoveDisabledDebugAsyncOp<BreakOp>>(context); |
| } |
| |
| void CondBreakOp::getCanonicalizationPatterns(OwningRewritePatternList &results, |
| MLIRContext *context) { |
| results.insert<RemoveDisabledDebugAsyncOp<CondBreakOp>, |
| SimplifyConstCondBreakPred>(context); |
| } |
| |
| } // namespace VM |
| } // namespace IREE |
| } // namespace iree_compiler |
| } // namespace mlir |