| // Copyright 2025 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "iree/compiler/ExternalInterfaces/EncodingExternalModels.h" |
| #include "iree/compiler/Dialect/Encoding/IR/EncodingDialect.h" |
| #include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h" |
| #include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h" |
| #include "iree/compiler/Dialect/Encoding/Utils/Utils.h" |
| #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h" |
| #include "llvm/ADT/TypeSwitch.h" |
| #include "llvm/Support/Debug.h" |
| #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| #include "mlir/Dialect/Utils/StaticValueUtils.h" |
| #include "mlir/IR/AffineMap.h" |
| #include "mlir/IR/Builders.h" |
| |
| #define DEBUG_TYPE "iree-encoding-external-models" |
| |
| namespace mlir::iree_compiler { |
| namespace { |
| |
/// Propagate an encoding through an "encoding castable" op. Encoding castable
/// means that the op can be encoded by casting its types to the encoded types.
/// This transform adds iree_encoding.set_encoding ops to the operands of the
/// `op`, and clones the `op` with the new encoded operands and encoded result
/// types. If the `propagationSource` comes from an iree_encoding.unset_encoding
/// op, and it is consumed by the `op`, then take the source of the unset
/// encoding instead of re-setting the encoding. If the `propagationSource` is
/// produced by the `op`, then do not unset the encoding after cloning the op,
/// because the encoded result will be used for propagation.
///
/// Use this function for ops that:
/// 1. Are encoded by casting their types to the encoded types.
///    for propagation, and do not need to re-set the encoding.
/// 2. Are able to directly use the source of any producer unset_encoding ops
///    for propagation, and do not need to re-set the encoding.
///
/// Returns failure if the encoding dims cannot be rematerialized at the op's
/// insertion point; otherwise returns the generated set/unset encoding ops and
/// the replacement values for each of the op's results.
static FailureOr<IREE::Encoding::PropagationResult>
propagateThroughEncodingCastableOp(
    RewriterBase &builder, Operation *op,
    IREE::Encoding::PropagationEncoding encodings,
    OpOperand *propagationSource) {
  SmallVector<Value> encodedOperands;
  IREE::Encoding::PropagationResult result;
  // Non-null iff the propagation source is produced by an unset_encoding op.
  // In that case, its already-encoded source can be consumed directly instead
  // of inserting a redundant set_encoding below.
  auto maybeUnsetEncodingProducer =
      propagationSource->get().getDefiningOp<IREE::Encoding::UnsetEncodingOp>();
  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPoint(op);

  // Try to rematerialize encoding dims at the new insertion point.
  // For bubble-up propagation through encoding castable ops (tensor.cast,
  // tensor.collapse_shape, etc.), tensor.dim ops on the result can be
  // recreated from the source operand. These ops have a single tensor
  // input as their first operand.
  Value sourceOperand = op->getOperand(0);
  FailureOr<SmallVector<Value>> rematerializedDims =
      IREE::Encoding::rematerializeEncodingDims(
          builder, op, encodings.encodingDims, propagationSource->get(),
          sourceOperand);
  if (failed(rematerializedDims)) {
    return failure();
  }
  encodings.encodingDims = *rematerializedDims;

  // Encode every operand, pairing each operand with its encoding positionally.
  for (auto [operand, encoding] :
       llvm::zip_equal(op->getOperands(), encodings.operandEncodings)) {
    // Scalar operands do not need encodings.
    auto operandType = dyn_cast<RankedTensorType>(operand.getType());
    if (!operandType) {
      encodedOperands.push_back(operand);
      continue;
    }
    // If the operand comes from the provided producer unset_encoding op, then
    // we don't need to set the encoding again, because this operand is the
    // source of propagation.
    if (operand == propagationSource->get() && maybeUnsetEncodingProducer) {
      encodedOperands.push_back(maybeUnsetEncodingProducer.getSource());
      continue;
    }
    auto encodedOperandType = operandType.cloneWithEncoding(encoding);
    // Special case for tensor.empty ops, which can simply be cloned with the
    // encoding, instead of creating a new set_encoding op.
    if (auto emptyOp = operand.getDefiningOp<tensor::EmptyOp>()) {
      auto encodedEmptyOp = tensor::EmptyOp::create(
          builder, op->getLoc(), encodedOperandType.getShape(),
          encodedOperandType.getElementType(), emptyOp.getDynamicSizes(),
          encoding);
      encodedOperands.push_back(encodedEmptyOp.getResult());
      continue;
    }
    // Otherwise, we need to create a new set_encoding op.
    auto setEncodingOp = IREE::Encoding::SetEncodingOp::create(
        builder, op->getLoc(), encodedOperandType, operand,
        encodings.encodingDims);
    encodedOperands.push_back(setEncodingOp.getResult());
    result.generatedEncodingOps.push_back(setEncodingOp);
  }
  // Compute the encoded result types by attaching each result's encoding.
  SmallVector<Type> encodedResultTypes;
  // NOTE: the structured binding `result` below shadows the outer
  // PropagationResult `result`; only the type of each op result is used here.
  for (auto [result, encoding] :
       llvm::zip_equal(op->getResults(), encodings.resultEncodings)) {
    auto resultType = cast<RankedTensorType>(result.getType());
    auto encodedResultType = resultType.cloneWithEncoding(encoding);
    encodedResultTypes.push_back(encodedResultType);
  }
  // Clone the op with the encoded operands/result types, then reconcile each
  // result with its original users.
  Operation *encodedOp =
      clone(builder, op, encodedResultTypes, encodedOperands);
  for (OpResult encodedResult : encodedOp->getOpResults()) {
    // If this encoded result is coming from the source of propagation, we want
    // to return the encoded result.
    OpResult originalResult = op->getOpResult(encodedResult.getResultNumber());
    if (originalResult == propagationSource->get()) {
      result.replacements.push_back(encodedResult);
      continue;
    }
    // Otherwise, we need to unset the encoding so the types are consistent with
    // the other results' users.
    SmallVector<OpFoldResult> mixedSizes =
        tensor::getMixedSizes(builder, op->getLoc(), encodedResult);
    SmallVector<Value> resultDynamicDims;
    std::tie(std::ignore, resultDynamicDims) = decomposeMixedValues(mixedSizes);
    auto unsetEncodingOp = IREE::Encoding::UnsetEncodingOp::create(
        builder, op->getLoc(), originalResult.getType(), encodedResult,
        resultDynamicDims, encodings.encodingDims);
    result.generatedEncodingOps.push_back(unsetEncodingOp);
    result.replacements.push_back(unsetEncodingOp.getResult());
  }
  return result;
}
| |
| struct EncodingAttrPropagationInterface final |
| : IREE::Encoding::EncodingPropagationAttrInterface::ExternalModel< |
| EncodingAttrPropagationInterface, IREE::Encoding::EncodingAttr> { |
| bool isPropagableDown(Attribute attr, OpOperand *target) const { |
| return TypeSwitch<Operation *, bool>(target->getOwner()) |
| .Case([&](linalg::GenericOp genericOp) { |
| // Only support parallel generic ops. |
| if (genericOp.getNumReductionLoops() != 0) { |
| return false; |
| } |
| // The unset encoding should not be on one of the inits. |
| if (genericOp.isDpsInit(target)) { |
| return false; |
| } |
| // Non-projected permutation indexing maps will unlikely get lowered |
| // correctly with the encoding. |
| if (llvm::any_of(genericOp->getOpOperands(), [&](OpOperand &operand) { |
| AffineMap map = genericOp.getMatchingIndexingMap(&operand); |
| return !map.isProjectedPermutation(); |
| })) { |
| return false; |
| } |
| // Only support permutation for now. Projected permutations mean that |
| // there are some broadcast dimensions, and it is unclear how to |
| // represent encodings for this case. Bail out for now. |
| if (!genericOp.getMatchingIndexingMap(target).isPermutation()) { |
| return false; |
| } |
| return true; |
| }) |
| .Default(false); |
| } |
| |
| FailureOr<IREE::Encoding::PropagationEncoding> |
| generateSinkingEncodings(Attribute attr, OpOperand *target) const { |
| auto encoding = cast<IREE::Encoding::EncodingAttr>(attr); |
| return TypeSwitch<Operation *, |
| FailureOr<IREE::Encoding::PropagationEncoding>>( |
| target->getOwner()) |
| .Case([&](linalg::GenericOp genericOp) { |
| IREE::Encoding::PropagationEncoding propEncoding; |
| propEncoding.operandEncodings.reserve(genericOp->getNumOperands()); |
| // Append the target and respective operand's indexing maps to the |
| // encoding's indexing maps to create the new encoding. |
| AffineMap invTargetIndexingMap = mlir::inversePermutation( |
| genericOp.getMatchingIndexingMap(target)); |
| auto createNewEncoding = |
| [&](AffineMap operandMap) -> IREE::Encoding::EncodingAttr { |
| IREE::Encoding::EncodingAttr newEncoding = encoding; |
| if (!invTargetIndexingMap.isIdentity()) { |
| newEncoding = newEncoding.cloneWithNewOperandIndexingMap( |
| invTargetIndexingMap); |
| } |
| if (!operandMap.isIdentity()) { |
| newEncoding = |
| newEncoding.cloneWithNewOperandIndexingMap(operandMap); |
| } |
| return newEncoding; |
| }; |
| for (OpOperand *operand : genericOp.getDpsInputOperands()) { |
| if (operand != target) { |
| AffineMap operandMap = genericOp.getMatchingIndexingMap(operand); |
| IREE::Encoding::EncodingAttr newEncoding = |
| createNewEncoding(operandMap); |
| propEncoding.operandEncodings.push_back(newEncoding); |
| } else { |
| propEncoding.operandEncodings.push_back(encoding); |
| } |
| } |
| for (OpOperand &operand : genericOp.getDpsInitsMutable()) { |
| AffineMap operandMap = genericOp.getMatchingIndexingMap(&operand); |
| IREE::Encoding::EncodingAttr newEncoding = |
| createNewEncoding(operandMap); |
| propEncoding.operandEncodings.push_back(newEncoding); |
| propEncoding.resultEncodings.push_back(newEncoding); |
| } |
| return propEncoding; |
| }) |
| .Default(failure()); |
| } |
| }; |
| |
| struct LayoutAttrPropagationInterface final |
| : IREE::Encoding::EncodingPropagationAttrInterface::ExternalModel< |
| LayoutAttrPropagationInterface, IREE::Encoding::LayoutAttr> { |
| bool isPropagableUp(Attribute attr, OpResult target) const { |
| auto layoutAttr = cast<IREE::Encoding::LayoutAttr>(attr); |
| return TypeSwitch<Operation *, bool>(target.getOwner()) |
| .Case([&](tensor::CastOp castOp) { |
| // CastOp is propagable if it is casting between compatible shapes, |
| // because the dimensions need to be consistent with the |
| // user_indexing_maps carried by the encoding. The tensor.cast op |
| // verifier already guarantees that the shapes are compatible. |
| return layoutAttr.isSerialized(); |
| }) |
| .Default(false); |
| } |
| |
| FailureOr<IREE::Encoding::PropagationEncoding> |
| generateBubblingEncodings(Attribute attr, OpResult target) const { |
| auto encoding = cast<IREE::Encoding::LayoutAttr>(attr); |
| return TypeSwitch<Operation *, |
| FailureOr<IREE::Encoding::PropagationEncoding>>( |
| target.getOwner()) |
| .Case([&](tensor::CastOp) { |
| IREE::Encoding::PropagationEncoding propEncoding; |
| propEncoding.resultEncodings.push_back(encoding); |
| propEncoding.operandEncodings.push_back(encoding); |
| return propEncoding; |
| }) |
| .Default(failure()); |
| } |
| }; |
| |
/// External model implementing EncodingPropagationOpInterface for an op type
/// `OpTy` that is "encoding castable" (see
/// propagateThroughEncodingCastableOp): the op is encoded simply by casting
/// its operand/result types to the encoded types.
template <typename OpTy>
struct EncodingCastableOpPropagationInterface final
    : IREE::Encoding::EncodingPropagationOpInterface::ExternalModel<
          EncodingCastableOpPropagationInterface<OpTy>, OpTy> {
  // Forwards directly to the shared encoding-castable propagation logic.
  FailureOr<IREE::Encoding::PropagationResult>
  propagateEncoding(Operation *op, RewriterBase &rewriter,
                    IREE::Encoding::PropagationEncoding encodings,
                    OpOperand *propagationSource) const {
    return propagateThroughEncodingCastableOp(rewriter, op, encodings,
                                              propagationSource);
  }
};
| |
/// Helper structure that iterates over all op types in `Ops` and registers
/// the associated EncodingPropagationOpInterface external model for each.
template <typename... Ops>
struct EncodingCastableOpPropagationInterfaceHelper {
  static void registerOpInterface(MLIRContext *context) {
    // Fold expression: attach the interface to every op type in the pack.
    (Ops::template attachInterface<EncodingCastableOpPropagationInterface<Ops>>(
         *context),
     ...);
  }
};
| |
| //===----------------------------------------------------------------------===// |
| // LinalgFusionOpInterface for Encoding Ops |
| //===----------------------------------------------------------------------===// |
| |
| /// External model for encoding ops implementing LinalgFusionOpInterface. |
| /// All dimensions are parallel with identity indexing maps. |
| template <typename ConcreteType> |
| struct EncodingFusionOpInterfaceAdapter |
| : IREE::LinalgExt::LinalgFusionOpInterface::ExternalModel< |
| EncodingFusionOpInterfaceAdapter<ConcreteType>, ConcreteType> { |
| SmallVector<AffineMap> getIndexingMapsForOperands(Operation *op) const { |
| int64_t rank = cast<ConcreteType>(op).getResultType().getRank(); |
| return {AffineMap::getMultiDimIdentityMap(rank, op->getContext())}; |
| } |
| |
| SmallVector<AffineMap> getIndexingMapsForResults(Operation *op) const { |
| int64_t rank = cast<ConcreteType>(op).getResultType().getRank(); |
| return {AffineMap::getMultiDimIdentityMap(rank, op->getContext())}; |
| } |
| |
| SmallVector<AffineMap> getIndexingMapsArray(Operation *op) const { |
| auto operandMaps = getIndexingMapsForOperands(op); |
| llvm::append_range(operandMaps, getIndexingMapsForResults(op)); |
| return operandMaps; |
| } |
| |
| unsigned getNumParallelLoops(Operation *op) const { |
| return cast<ConcreteType>(op).getResultType().getRank(); |
| } |
| |
| unsigned getNumLoops(Operation *op) const { |
| return cast<ConcreteType>(op).getResultType().getRank(); |
| } |
| |
| SmallVector<int64_t> getStaticLoopRanges(Operation *op) const { |
| auto type = cast<ConcreteType>(op).getResultType(); |
| return SmallVector<int64_t>(type.getShape()); |
| } |
| |
| AffineMap getIndexingMapMatchingResult(Operation *op, OpResult result) const { |
| assert(result.getOwner() == op); |
| return getIndexingMapsForResults(op)[result.getResultNumber()]; |
| } |
| |
| AffineMap getMatchingIndexingMap(Operation *op, OpOperand *opOperand) const { |
| assert(opOperand->getOwner() == op); |
| return getIndexingMapsArray(op)[opOperand->getOperandNumber()]; |
| } |
| }; |
| |
| } // namespace |
| |
| void registerEncodingExternalModels(DialectRegistry ®istry) { |
| registry.addExtension(+[](MLIRContext *ctx, |
| IREE::Encoding::IREEEncodingDialect *dialect) { |
| IREE::Encoding::EncodingAttr::attachInterface< |
| EncodingAttrPropagationInterface>(*ctx); |
| IREE::Encoding::LayoutAttr::attachInterface<LayoutAttrPropagationInterface>( |
| *ctx); |
| IREE::Encoding::SetEncodingOp::attachInterface< |
| EncodingFusionOpInterfaceAdapter<IREE::Encoding::SetEncodingOp>>(*ctx); |
| IREE::Encoding::UnsetEncodingOp::attachInterface< |
| EncodingFusionOpInterfaceAdapter<IREE::Encoding::UnsetEncodingOp>>( |
| *ctx); |
| }); |
| registry.addExtension( |
| +[](MLIRContext *ctx, mlir::tensor::TensorDialect *dialect) { |
| EncodingCastableOpPropagationInterfaceHelper< |
| tensor::CollapseShapeOp, tensor::CastOp>::registerOpInterface(ctx); |
| }); |
| registry.addExtension( |
| +[](MLIRContext *ctx, mlir::linalg::LinalgDialect *dialect) { |
| EncodingCastableOpPropagationInterfaceHelper< |
| linalg::GenericOp>::registerOpInterface(ctx); |
| }); |
| } |
| |
| } // namespace mlir::iree_compiler |