| // Copyright 2022 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "iree/compiler/ExternalInterfaces/UtilExternalModels.h" |
| |
| #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h" |
| #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h" |
| #include "iree/compiler/Dialect/Encoding/IR/EncodingDialect.h" |
| #include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h" |
| #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h" |
| #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" |
| #include "iree/compiler/Dialect/Util/IR/UtilDialect.h" |
| #include "iree/compiler/Dialect/Util/IR/UtilOps.h" |
| #include "iree/compiler/Dialect/Util/IR/UtilTypes.h" |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| #include "mlir/Dialect/MLProgram/IR/MLProgram.h" |
| #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| #include "mlir/IR/Matchers.h" |
| #include "mlir/Interfaces/ValueBoundsOpInterface.h" |
| |
| namespace mlir::iree_compiler { |
| |
| namespace { |
| |
| //===----------------------------------------------------------------------===// |
| // InferIntDivisibilityOpInterface |
| //===----------------------------------------------------------------------===// |
| |
| static IREE::Util::ConstantIntDivisibility |
| getDivisibilityOfOperand(Value v, |
| IREE::Util::IntegerDivisibility divisibility) { |
| if (!divisibility.isUninitialized()) { |
| return divisibility.getValue(); |
| } |
| APInt intVal; |
| if (matchPattern(v, m_ConstantInt(&intVal))) { |
| uint64_t udiv = intVal.getZExtValue(); |
| uint64_t sdiv = std::abs(intVal.getSExtValue()); |
| return IREE::Util::ConstantIntDivisibility(udiv, sdiv); |
| } |
| return IREE::Util::ConstantIntDivisibility(1, 1); |
| } |
| |
| struct ArithConstantInferIntDivisibilityOpInterface |
| : public IREE::Util::InferIntDivisibilityOpInterface::ExternalModel< |
| ArithConstantInferIntDivisibilityOpInterface, arith::ConstantOp> { |
| |
| void inferResultDivisibility( |
| Operation *op, ArrayRef<IREE::Util::IntegerDivisibility> argDivs, |
| IREE::Util::SetIntDivisibilityFn setResultDivs) const { |
| auto constOp = cast<arith::ConstantOp>(op); |
| auto constAttr = llvm::dyn_cast_or_null<IntegerAttr>(constOp.getValue()); |
| if (constAttr) { |
| const APInt &value = constAttr.getValue(); |
| uint64_t udiv = value.getZExtValue(); |
| uint64_t sdiv = std::abs(value.getSExtValue()); |
| setResultDivs(constOp.getResult(), |
| IREE::Util::ConstantIntDivisibility(udiv, sdiv)); |
| } |
| } |
| }; |
| |
| struct ArithMulIInferIntDivisibilityOpInterface |
| : public IREE::Util::InferIntDivisibilityOpInterface::ExternalModel< |
| ArithMulIInferIntDivisibilityOpInterface, arith::MulIOp> { |
| |
| void inferResultDivisibility( |
| Operation *op, ArrayRef<IREE::Util::IntegerDivisibility> argDivs, |
| IREE::Util::SetIntDivisibilityFn setResultDivs) const { |
| auto mulOp = cast<arith::MulIOp>(op); |
| |
| auto lhsDivisibility = getDivisibilityOfOperand(mulOp.getLhs(), argDivs[0]); |
| auto rhsDivisibility = getDivisibilityOfOperand(mulOp.getRhs(), argDivs[1]); |
| |
| uint64_t mulUDiv = lhsDivisibility.udiv() * rhsDivisibility.udiv(); |
| uint64_t mulSDiv = lhsDivisibility.sdiv() * rhsDivisibility.sdiv(); |
| |
| setResultDivs(mulOp.getResult(), |
| IREE::Util::ConstantIntDivisibility(mulUDiv, mulSDiv)); |
| } |
| }; |
| |
| struct ArithDivUIInferIntDivisibilityOpInterface |
| : public IREE::Util::InferIntDivisibilityOpInterface::ExternalModel< |
| ArithDivUIInferIntDivisibilityOpInterface, arith::DivUIOp> { |
| |
| void inferResultDivisibility( |
| Operation *op, ArrayRef<IREE::Util::IntegerDivisibility> argDivs, |
| IREE::Util::SetIntDivisibilityFn setResultDivs) const { |
| auto divOp = cast<arith::DivUIOp>(op); |
| |
| APInt intVal; |
| if (!matchPattern(divOp.getRhs(), m_ConstantInt(&intVal))) { |
| return; |
| } |
| |
| auto lhsDivisibility = getDivisibilityOfOperand(divOp.getLhs(), argDivs[0]); |
| |
| uint64_t divUDiv = lhsDivisibility.udiv() / intVal.getZExtValue(); |
| uint64_t divSDiv = lhsDivisibility.sdiv() / std::abs(intVal.getSExtValue()); |
| |
| setResultDivs(divOp, IREE::Util::ConstantIntDivisibility(divUDiv, divSDiv)); |
| } |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // ValueBoundsOpInterface |
| //===----------------------------------------------------------------------===// |
| |
| /// For some reason, this interface has to be done as an external model. |
| struct UtilAssumeIntValueBoundsOpInterface |
| : public ValueBoundsOpInterface::ExternalModel< |
| UtilAssumeIntValueBoundsOpInterface, IREE::Util::AssumeIntOp> { |
| void populateBoundsForIndexValue(Operation *op, Value value, |
| ValueBoundsConstraintSet &cstr) const { |
| auto assumeOp = cast<IREE::Util::AssumeIntOp>(op); |
| auto result = cast<OpResult>(value); |
| assert(result.getOwner() == op && "value is a result of this index op"); |
| auto [min, max] = |
| assumeOp.getUnionedUnsignedRange(result.getResultNumber()); |
| |
| if (min) { |
| cstr.bound(result) >= *min; |
| } |
| if (max) { |
| cstr.bound(result) <= *max; |
| } |
| } |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // GlobalOpInterface |
| //===----------------------------------------------------------------------===// |
| |
| struct GlobalOpInterfaceExternalModel |
| : public IREE::Util::GlobalOpInterface::ExternalModel< |
| GlobalOpInterfaceExternalModel, ml_program::GlobalOp> { |
| Attribute getGlobalInitialValue(Operation *op) const { |
| return cast<ml_program::GlobalOp>(op).getValueAttr(); |
| } |
| void setGlobalInitialValue(Operation *op, Attribute value) const { |
| if (value) { |
| cast<ml_program::GlobalOp>(op).setValueAttr(value); |
| } else { |
| cast<ml_program::GlobalOp>(op).removeValueAttr(); |
| } |
| } |
| |
| IREE::Util::InliningPolicyAttrInterface |
| getGlobalInliningPolicy(Operation *op) const { |
| if (op->hasAttr("noinline")) |
| return IREE::Util::InlineNeverAttr::get(op->getContext()); |
| return {}; |
| } |
| void |
| setGlobalInliningPolicy(Operation *op, |
| IREE::Util::InliningPolicyAttrInterface value) const { |
| if (isa_and_nonnull<IREE::Util::InlineNeverAttr>(value)) { |
| op->setAttr("noinline", UnitAttr::get(op->getContext())); |
| } else { |
| op->removeAttr("noinline"); |
| } |
| } |
| |
| IREE::Util::GlobalLoadOpInterface createLoadOp(Operation *op, Location loc, |
| OpBuilder &builder) const { |
| auto globalOp = cast<ml_program::GlobalOp>(op); |
| if (globalOp.getIsMutable()) { |
| return cast<IREE::Util::GlobalLoadOpInterface>( |
| builder |
| .create<ml_program::GlobalLoadOp>( |
| loc, globalOp.getType(), FlatSymbolRefAttr::get(globalOp)) |
| .getOperation()); |
| } else { |
| return cast<IREE::Util::GlobalLoadOpInterface>( |
| builder |
| .create<ml_program::GlobalLoadConstOp>( |
| loc, globalOp.getType(), FlatSymbolRefAttr::get(globalOp)) |
| .getOperation()); |
| } |
| } |
| |
| IREE::Util::GlobalStoreOpInterface createStoreOp(Operation *op, Location loc, |
| Value value, |
| OpBuilder &builder) const { |
| auto globalOp = cast<ml_program::GlobalOp>(op); |
| return cast<IREE::Util::GlobalStoreOpInterface>( |
| builder |
| .create<ml_program::GlobalStoreOp>( |
| loc, FlatSymbolRefAttr::get(globalOp), value) |
| .getOperation()); |
| } |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // NumericCastOpInterface |
| //===----------------------------------------------------------------------===// |
| |
| // Since all details of the interface are provided via default implementations, |
| // we can just have one templated external model to apply per op, vs one |
| // explicit model per op. |
| struct GenericNumericCastExternalModel { |
| template <typename OpTy> |
| struct ExternalModel |
| : public IREE::Util::NumericCastOpInterface::ExternalModel< |
| ExternalModel<OpTy>, OpTy> {}; |
| |
| template <typename OpTy> |
| static void add(MLIRContext *context) { |
| OpTy::template attachInterface<ExternalModel<OpTy>>(*context); |
| } |
| |
| template <typename OpTy1, typename OpTy2, typename... More> |
| static void add(MLIRContext *context) { |
| add<OpTy1>(context); |
| add<OpTy2, More...>(context); |
| } |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // TiedOpInterface |
| //===----------------------------------------------------------------------===// |
| |
| struct InsertSliceOpTiedOpInterface |
| : public IREE::Util::TiedOpInterface::ExternalModel< |
| InsertSliceOpTiedOpInterface, tensor::InsertSliceOp> { |
| Value getTiedResult(Operation *op, unsigned resultIndex) const { |
| auto insertSliceOp = cast<tensor::InsertSliceOp>(op); |
| return IREE::Util::TiedOpInterface::findTiedBaseValue( |
| insertSliceOp.getDest()); |
| } |
| |
| ::std::optional<unsigned> |
| getTiedResultOperandIndex(Operation *op, unsigned resultIndex) const { |
| return {1}; // dest |
| } |
| |
| SmallVector<int64_t> getTiedResultOperandIndices(Operation *op) const { |
| return {1}; // dest |
| } |
| }; |
| |
| template <typename OpTy> |
| struct LinalgOpTiedOpInterface |
| : public IREE::Util::TiedOpInterface::ExternalModel< |
| LinalgOpTiedOpInterface<OpTy>, OpTy> { |
| Value getTiedResult(Operation *op, unsigned resultIndex) const { |
| auto linalgOp = cast<OpTy>(op); |
| return IREE::Util::TiedOpInterface::findTiedBaseValue( |
| linalgOp.getDpsInits()[resultIndex]); |
| } |
| |
| ::std::optional<unsigned> |
| getTiedResultOperandIndex(Operation *op, unsigned resultIndex) const { |
| auto linalgOp = cast<OpTy>(op); |
| return {linalgOp.getDpsInitsMutable()[resultIndex].getOperandNumber()}; |
| } |
| |
| SmallVector<int64_t> getTiedResultOperandIndices(Operation *op) const { |
| SmallVector<int64_t> result; |
| for (unsigned i = 0; i < op->getNumResults(); ++i) |
| result.push_back(*getTiedResultOperandIndex(op, i)); |
| return result; |
| } |
| }; |
| |
| /// Helper structure that iterates over all LinalgOps in `OpTys` and registers |
| /// the `TiedOpInterface` with each of them. |
| template <typename... Ops> |
| struct LinalgOpTiedOpInterfaceHelper { |
| static void registerOpInterface(MLIRContext *context) { |
| (void)std::initializer_list<int>{ |
| 0, |
| (Ops::template attachInterface<LinalgOpTiedOpInterface<Ops>>(*context), |
| 0)...}; |
| } |
| }; |
| |
| // TODO(Max191): Remove this interface once GPU data tiling stops using early |
| // materialization. This only exists for handling multi_mma ops before dispatch |
| // workgroups are created, which only happens with early materialization. |
| struct MultiMmaOpTiedOpInterface |
| : public IREE::Util::TiedOpInterface::ExternalModel< |
| MultiMmaOpTiedOpInterface, IREE::GPU::MultiMmaOp> { |
| Value getTiedResult(Operation *op, unsigned resultIndex) const { |
| auto linalgOp = cast<IREE::GPU::MultiMmaOp>(op); |
| return IREE::Util::TiedOpInterface::findTiedBaseValue(linalgOp.getAcc()); |
| } |
| |
| ::std::optional<unsigned> |
| getTiedResultOperandIndex(Operation *op, unsigned resultIndex) const { |
| return {2}; // acc |
| } |
| |
| SmallVector<int64_t> getTiedResultOperandIndices(Operation *op) const { |
| return {2}; // acc |
| } |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // HoistableOpInterface |
| //===----------------------------------------------------------------------===// |
| |
| template <typename OpTy> |
| struct UnhoistableOpInterface |
| : public IREE::Util::HoistableOpInterface::ExternalModel< |
| UnhoistableOpInterface<OpTy>, OpTy> { |
| bool isHoistableOp(Operation *) const { return false; } |
| bool isHoistableLeafOp(Operation *) const { return false; } |
| }; |
| |
| template <typename OpTy> |
| struct HoistableNonLeafOpInterface |
| : public IREE::Util::HoistableOpInterface::ExternalModel< |
| HoistableNonLeafOpInterface<OpTy>, OpTy> { |
| bool isHoistableLeafOp(Operation *) const { return false; } |
| }; |
| |
| // The default interface is always hoistable. This acts as an override |
| // for other default hoistability checks as the interface is checked |
| // first. |
| template <typename OpTy> |
| struct AlwaysHoistableOpInterface |
| : public IREE::Util::HoistableOpInterface::ExternalModel< |
| AlwaysHoistableOpInterface<OpTy>, OpTy> {}; |
| |
| template <typename OpTy> |
| struct HoistableLinalgOpInterface |
| : public IREE::Util::HoistableOpInterface::ExternalModel< |
| HoistableLinalgOpInterface<OpTy>, OpTy> { |
| bool isHoistableOp(Operation *) const { return true; } |
| |
| /// FillOp and broadcasting ops are not hoistableLeaf ops, since it is |
| /// typically better to fuse them with their consumers. |
| bool isHoistableLeafOp(Operation *op) const { |
| auto genericOp = llvm::dyn_cast<linalg::GenericOp>(op); |
| if (!genericOp) |
| return !isa<linalg::FillOp>(op); |
| // Generally, we prefer to not hoist broadcasts. |
| // Detect op that only broadcast input as fusing them makes the new |
| // op cheaper. |
| if (genericOp.getNumParallelLoops() == genericOp.getNumLoops() && |
| isa<linalg::YieldOp>(genericOp.getBody()->front())) { |
| for (OpOperand *opOperand : genericOp.getDpsInputOperands()) { |
| AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand); |
| if (indexingMap.isProjectedPermutation() && |
| indexingMap.getNumDims() != indexingMap.getNumResults()) { |
| return false; |
| } |
| } |
| } |
| return !linalg::isaFillOpInterface(genericOp).has_value(); |
| } |
| bool isAtomicallyHoistableOp(Operation *) const { return true; } |
| bool isOperandHoistable(Operation *, OpOperand *) const { return true; } |
| }; |
| |
| /// Helper structures that iterates over all Op types in `OpTys` and registers |
| /// the associated Hoistable___OpInterface. |
| template <typename... Ops> |
| struct UnhoistableOpInterfaceHelper { |
| static void registerOpInterface(MLIRContext *context) { |
| (Ops::template attachInterface<UnhoistableOpInterface<Ops>>(*context), ...); |
| } |
| }; |
| |
| template <typename... Ops> |
| struct HoistableNonLeafOpInterfaceHelper { |
| static void registerOpInterface(MLIRContext *context) { |
| (Ops::template attachInterface<HoistableNonLeafOpInterface<Ops>>(*context), |
| ...); |
| } |
| }; |
| |
| template <typename... Ops> |
| struct AlwaysHoistableOpInterfaceHelper { |
| static void registerOpInterface(MLIRContext *context) { |
| (Ops::template attachInterface<AlwaysHoistableOpInterface<Ops>>(*context), |
| ...); |
| } |
| }; |
| |
| template <typename... Ops> |
| struct HoistableLinalgOpInterfaceHelper { |
| static void registerOpInterface(MLIRContext *context) { |
| (Ops::template attachInterface<HoistableLinalgOpInterface<Ops>>(*context), |
| ...); |
| } |
| }; |
| |
| } // namespace |
| |
| void registerUtilExternalModels(DialectRegistry ®istry) { |
| // Must ensure that any dependent dialects are registered. |
| registry.insert<arith::ArithDialect>(); |
| registry.insert<linalg::LinalgDialect>(); |
| registry.insert<ml_program::MLProgramDialect>(); |
| registry.insert<tensor::TensorDialect>(); |
| |
| registry.addExtension( |
| +[](MLIRContext *context, ml_program::MLProgramDialect *dialect) { |
| ml_program::GlobalOp::attachInterface<GlobalOpInterfaceExternalModel>( |
| *context); |
| }); |
| |
| registry.addExtension(+[](MLIRContext *context, |
| arith::ArithDialect *dialect) { |
| GenericNumericCastExternalModel::add< |
| arith::BitcastOp, arith::ExtFOp, arith::ExtUIOp, arith::ExtSIOp, |
| arith::FPToSIOp, arith::FPToUIOp, arith::IndexCastOp, arith::TruncFOp, |
| arith::TruncIOp, arith::SIToFPOp, arith::UIToFPOp>(context); |
| arith::ConstantOp::attachInterface< |
| ArithConstantInferIntDivisibilityOpInterface>(*context); |
| arith::MulIOp::attachInterface<ArithMulIInferIntDivisibilityOpInterface>( |
| *context); |
| arith::DivUIOp::attachInterface<ArithDivUIInferIntDivisibilityOpInterface>( |
| *context); |
| }); |
| |
| registry.addExtension( |
| +[](MLIRContext *context, tensor::TensorDialect *dialect) { |
| tensor::InsertSliceOp::attachInterface<InsertSliceOpTiedOpInterface>( |
| *context); |
| }); |
| |
| registry.addExtension(+[](MLIRContext *context, |
| IREE::GPU::IREEGPUDialect *dialect) { |
| IREE::GPU::MultiMmaOp::attachInterface<MultiMmaOpTiedOpInterface>(*context); |
| }); |
| |
| registry.addExtension( |
| +[](MLIRContext *context, linalg::LinalgDialect *dialect) { |
| // Register all Linalg structured ops. `LinalgOp` is an interface and it |
| // is not possible to attach an external interface to an existing |
| // interface. Therefore, attach the `TiedOpInterface` to all ops |
| // one-by-one. |
| LinalgOpTiedOpInterfaceHelper< |
| #define GET_OP_LIST |
| #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" |
| >::registerOpInterface(context); |
| }); |
| |
| // TODO(matthias-springer): Use a helper instead of listing all ops. This is |
| // tricky because LinalgExtOps.td includes YieldOp. |
| registry.addExtension(+[](MLIRContext *context, |
| IREE::LinalgExt::IREELinalgExtDialect *dialect) { |
| IREE::LinalgExt::ScatterOp::attachInterface< |
| LinalgOpTiedOpInterface<IREE::LinalgExt::ScatterOp>>(*context); |
| IREE::LinalgExt::SortOp::attachInterface< |
| LinalgOpTiedOpInterface<IREE::LinalgExt::SortOp>>(*context); |
| IREE::LinalgExt::FftOp::attachInterface< |
| LinalgOpTiedOpInterface<IREE::LinalgExt::FftOp>>(*context); |
| IREE::LinalgExt::ScanOp::attachInterface< |
| LinalgOpTiedOpInterface<IREE::LinalgExt::ScanOp>>(*context); |
| IREE::LinalgExt::TopkOp::attachInterface< |
| LinalgOpTiedOpInterface<IREE::LinalgExt::TopkOp>>(*context); |
| IREE::LinalgExt::WinogradInputTransformOp::attachInterface< |
| LinalgOpTiedOpInterface<IREE::LinalgExt::WinogradInputTransformOp>>( |
| *context); |
| IREE::LinalgExt::WinogradFilterTransformOp::attachInterface< |
| LinalgOpTiedOpInterface<IREE::LinalgExt::WinogradFilterTransformOp>>( |
| *context); |
| IREE::LinalgExt::WinogradOutputTransformOp::attachInterface< |
| LinalgOpTiedOpInterface<IREE::LinalgExt::WinogradOutputTransformOp>>( |
| *context); |
| IREE::LinalgExt::Im2colOp::attachInterface< |
| LinalgOpTiedOpInterface<IREE::LinalgExt::Im2colOp>>(*context); |
| IREE::LinalgExt::AttentionOp::attachInterface< |
| LinalgOpTiedOpInterface<IREE::LinalgExt::AttentionOp>>(*context); |
| }); |
| |
| // Hoistable Op Interface registration. |
| |
| // Register hoistable type interfaces for LinalgExt ops. |
| registry.addExtension( |
| +[](MLIRContext *context, IREE::Encoding::IREEEncodingDialect *dialect) { |
| UnhoistableOpInterfaceHelper< |
| IREE::Encoding::SetEncodingOp>::registerOpInterface(context); |
| }); |
| // Register hoistable type interfaces for linalg ops. |
| // We have a specific allow-list for Linalg ops because we want to consider |
| // new additions carefully. |
| registry.addExtension( |
| +[](MLIRContext *context, linalg::LinalgDialect *dialect) { |
| // Structured op implementations and a handful of pure ops are included. |
| // Notably: IndexOp is not included because it establishes a hidden |
| // dependency to the iterator and is non-const. |
| |
| // Register all LinalgOps ops. `LinalgOp` is an interface and it is |
| // not possible to attach an external interface to an existing |
| // interface. Therefore, attach the `HoistableLinalgOpInterface` to all |
| // ops one-by-one. |
| HoistableLinalgOpInterfaceHelper< |
| #define GET_OP_LIST |
| #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" |
| >::registerOpInterface(context); |
| UnhoistableOpInterfaceHelper< |
| #define GET_OP_LIST |
| #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc" |
| >::registerOpInterface(context); |
| }); |
| // Register hoistable type interfaces for tensor ops. |
| registry.addExtension( |
| +[](MLIRContext *context, tensor::TensorDialect *dialect) { |
| // Never hoist empty and other pure metadata ops as a leaf. It's fine to |
| // hoist them as a part of a larger constant tree that does actual work. |
| HoistableNonLeafOpInterfaceHelper< |
| tensor::EmptyOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, |
| tensor::ExtractSliceOp>::registerOpInterface(context); |
| // Cases of trivial pack/unpack should be handled as canonicalizations |
| // before we get here, thus we're safe to always hoist. |
| AlwaysHoistableOpInterfaceHelper< |
| tensor::PadOp, tensor::PackOp, |
| tensor::UnPackOp>::registerOpInterface(context); |
| }); |
| registry.addExtension( |
| +[](MLIRContext *context, IREE::Util::UtilDialect *dialect) { |
| IREE::Util::AssumeIntOp::attachInterface< |
| UtilAssumeIntValueBoundsOpInterface>(*context); |
| }); |
| } |
| |
| } // namespace mlir::iree_compiler |