| // Copyright 2021 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| //===- FlattenMemRefSubspanPass.cpp - Flatten n-D MemRef subspan ----------===// |
| // |
// This file implements a pass to flatten n-D MemRef subspan ops to 1-D MemRef
// ones and fold the byte offsets on subspan ops into the consumer load/store
// ops, in preparation for lowering to the final target.
| // |
| // This pass is needed because of how MemRef is used by subspan ops: |
| // |
// 1) Normally a MemRef should capture the mapping to the underlying buffer
// with its internal strides and offsets. However, although subspan ops in
// IREE are subview-like constructs, they carry the offset directly on the ops
// themselves and return MemRefs with the identity layout map. This is because
// IREE can perform various optimizations over buffer allocation and decide,
// for example, to use the same underlying buffer for two MemRefs, which are
// converted from disjoint tensors initially.
| // 2) The byte offset on subspan ops is an offset into the final planned 1-D |
| // byte buffer, while the MemRef can be n-D without considering a backing |
| // buffer and its data layout. |
| // |
// So to bridge the gap, we need to linearize the MemRef dimensions to bring
// them onto the same view as IREE: buffers are just a bag of bytes. Then we
// need to fold the byte offset on subspan ops into the consumer load/store
// ops, so that we can rely on transformations in MLIR core, because they
// assume MemRefs map to the underlying buffers with their internal strides
// and offsets.
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include <memory> |
| |
| #include "iree/compiler/Codegen/PassDetail.h" |
| #include "iree/compiler/Codegen/Passes.h" |
| #include "iree/compiler/Dialect/HAL/IR/HALOps.h" |
| #include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h" |
| #include "iree/compiler/Dialect/Shape/IR/ShapeOps.h" |
| #include "llvm/Support/Debug.h" |
| #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| #include "mlir/Dialect/StandardOps/IR/Ops.h" |
| #include "mlir/Dialect/Vector/VectorOps.h" |
| #include "mlir/IR/AffineExpr.h" |
| #include "mlir/IR/AffineMap.h" |
| #include "mlir/IR/BuiltinAttributes.h" |
| #include "mlir/IR/BuiltinOps.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/IR/Dialect.h" |
| #include "mlir/IR/Matchers.h" |
| #include "mlir/IR/PatternMatch.h" |
| #include "mlir/Pass/Pass.h" |
| #include "mlir/Transforms/DialectConversion.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| |
| #define DEBUG_TYPE "iree-flatten-memref-subspan" |
| |
| namespace mlir { |
| namespace iree_compiler { |
| |
| namespace { |
| |
| //===----------------------------------------------------------------------===// |
| // Type Conversion |
| //===----------------------------------------------------------------------===// |
| |
| /// Returns true if the given `type` is a MemRef of rank 1. |
| static bool isRankOneMemRef(Type type) { |
| if (auto memrefType = type.dyn_cast<MemRefType>()) { |
| return memrefType.hasRank() && memrefType.getRank() == 1; |
| } |
| return false; |
| } |
| |
/// Flattens n-D MemRef types to 1-D MemRef types and passes all other types
/// through unchanged.
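/// For example, a type like memref<4x?xf32, 3> would (illustratively) become
/// memref<?xf32, 3>; MemRef types that are already rank-1 are left untouched.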
| struct FlattenMemRefTypeConverter final : public TypeConverter { |
| FlattenMemRefTypeConverter() { |
| // Allow all other types. |
| addConversion([](Type type) -> Optional<Type> { return type; }); |
| |
| // Convert n-D MemRef to 1-D MemRef. |
| addConversion([](MemRefType type) -> Optional<Type> { |
| // 1-D MemRef types are okay. |
| if (isRankOneMemRef(type)) return type; |
| |
      // Convert to a MemRef with one unknown dimension. This is actually more
      // akin to how IREE uses memref types: they represent a view into a byte
      // buffer with potentially unknown total size, as transformation passes
      // can concatenate buffers, etc.
| return MemRefType::get(ShapedType::kDynamicSize, type.getElementType(), |
| ArrayRef<AffineMap>(), type.getMemorySpace()); |
| }); |
| } |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // Flattening Patterns |
| //===----------------------------------------------------------------------===// |
| |
/// Creates a Value carrying the total element count of the given `type`,
/// whose dynamic dimensions are provided in `dynamicDims`.
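/// For example (a sketch only; the exact affine map may differ after
/// folding), for a memref<4x?x8xf32> whose dynamic size is %d this roughly
/// builds
///   affine.apply affine_map<()[s0, s1, s2] -> (s0*s1*s2)>()[%c4, %d, %c8]
/// i.e., 4 * %d * 8 total elements.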
| static Value createTotalElementCountValue(ShapedType type, |
| ValueRange dynamicDims, Location loc, |
| OpBuilder &builder) { |
| MLIRContext *context = builder.getContext(); |
| |
| if (type.hasStaticShape()) { |
| assert(dynamicDims.empty()); |
| return builder.create<arith::ConstantIndexOp>(loc, type.getNumElements()); |
| } |
| |
| int dynamicDimIndex = 0; |
| SmallVector<Value, 4> dims; |
| auto shape = type.getShape(); |
| AffineExpr sizeExpr = getAffineConstantExpr(1, context); |
| for (int i = 0; i < shape.size(); ++i) { |
| sizeExpr = sizeExpr * getAffineSymbolExpr(i, context); |
| if (ShapedType::isDynamic(shape[i])) { |
| dims.push_back(dynamicDims[dynamicDimIndex++]); |
| } else { |
| dims.push_back(builder.create<arith::ConstantIndexOp>(loc, shape[i])); |
| } |
| } |
| return makeComposedAffineApply(builder, loc, sizeExpr, dims); |
| } |
| |
/// Flattens memref allocation ops with rank > 1 into 1-D ones.
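/// For example (illustrative), this would rewrite
///   %0 = memref.alloc(%d) : memref<4x?xf32>
/// into something like
///   %n = <IR computing 4 * %d>
///   %0 = memref.alloc(%n) : memref<?xf32>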
| template <typename AllocOpTy> |
| struct FlattenAlloc final : public OpConversionPattern<AllocOpTy> { |
| using OpConversionPattern<AllocOpTy>::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| AllocOpTy allocOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| auto oldType = allocOp.getType().template dyn_cast<MemRefType>(); |
| if (!oldType || !oldType.getAffineMaps().empty()) return failure(); |
| |
| Value dynamicDim = createTotalElementCountValue( |
| oldType, allocOp.getDynamicSizes(), allocOp.getLoc(), rewriter); |
| Type newType = this->getTypeConverter()->convertType(oldType); |
| |
| rewriter.replaceOpWithNewOp<AllocOpTy>(allocOp, newType.cast<MemRefType>(), |
| ValueRange{dynamicDim}); |
| |
| return success(); |
| } |
| }; |
| |
/// Flattens memref global ops with rank > 1 into 1-D ones.
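/// For example (illustrative), this would rewrite
///   memref.global "private" constant @data : memref<2x3xf32> = dense<...>
/// into
///   memref.global "private" constant @data : memref<6xf32> = dense<...>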
| struct FlattenGlobal final : public OpConversionPattern<memref::GlobalOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| static Attribute flattenAttribute(Attribute value, ShapedType newType) { |
| if (!value) return value; |
| if (auto splatAttr = value.dyn_cast<SplatElementsAttr>()) { |
| return splatAttr.reshape(newType); |
| } else if (auto denseAttr = value.dyn_cast<DenseElementsAttr>()) { |
| return denseAttr.reshape(newType); |
| } |
| return {}; |
| } |
| |
| LogicalResult matchAndRewrite( |
| memref::GlobalOp globalOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| auto oldType = globalOp.type().dyn_cast<MemRefType>(); |
| if (!oldType || !oldType.getAffineMaps().empty()) return failure(); |
| |
| auto tensorType = RankedTensorType::get({oldType.getNumElements()}, |
| oldType.getElementType()); |
| auto memRefType = |
| MemRefType::get({oldType.getNumElements()}, oldType.getElementType(), |
| {}, oldType.getMemorySpace()); |
| auto newInitialValue = |
| flattenAttribute(globalOp.initial_valueAttr(), tensorType); |
| rewriter.replaceOpWithNewOp<memref::GlobalOp>( |
| globalOp, globalOp.sym_name(), globalOp.sym_visibilityAttr(), |
| memRefType, newInitialValue, globalOp.constant(), |
| /*alignment=*/IntegerAttr()); |
| return success(); |
| } |
| }; |
| |
/// Flattens memref.get_global ops with rank > 1 into 1-D ones.
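/// For example (illustrative), a use like
///   %0 = memref.get_global @data : memref<2x3xf32>
/// roughly becomes a get_global of the flattened global followed by a cast to
/// the converted dynamic 1-D type:
///   %1 = memref.get_global @data : memref<6xf32>
///   %0 = memref.cast %1 : memref<6xf32> to memref<?xf32>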
| struct FlattenGetGlobal final |
| : public OpConversionPattern<memref::GetGlobalOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| memref::GetGlobalOp getOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| auto oldType = getOp.getType().dyn_cast<MemRefType>(); |
| if (!oldType || !oldType.getAffineMaps().empty()) return failure(); |
| |
| auto globalOp = dyn_cast_or_null<memref::GlobalOp>( |
| SymbolTable::lookupNearestSymbolFrom(getOp, getOp.nameAttr())); |
| if (!globalOp) return failure(); |
| |
| auto loadedValue = rewriter.createOrFold<memref::GetGlobalOp>( |
| getOp.getLoc(), globalOp.type(), getOp.nameAttr()); |
| |
| auto newType = getTypeConverter()->convertType(oldType).cast<ShapedType>(); |
| rewriter.replaceOpWithNewOp<memref::CastOp>(getOp, newType, loadedValue); |
| return success(); |
| } |
| }; |
| |
/// Flattens HAL interface binding subspan ops with rank > 1 into 1-D ones.
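/// For example (illustrative), a subspan producing a memref<4x?xf32> with
/// dynamic size %d is rewritten to produce memref<?xf32> whose single dynamic
/// size is the total element count (4 * %d); the byte offset stays on the op
/// and is folded into consumer loads/stores by a later pattern.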
| struct FlattenBindingSubspan final |
| : public OpConversionPattern<IREE::HAL::InterfaceBindingSubspanOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| IREE::HAL::InterfaceBindingSubspanOp subspanOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| auto oldType = subspanOp.getType().dyn_cast<MemRefType>(); |
| // IREE subspan ops only use memref types with the default identity |
| // layout maps. |
| if (!oldType || !oldType.getAffineMaps().empty()) return failure(); |
| |
| Value dynamicDim = createTotalElementCountValue( |
| oldType, subspanOp.dynamic_dims(), subspanOp.getLoc(), rewriter); |
| Type newType = getTypeConverter()->convertType(oldType); |
| |
| rewriter.replaceOpWithNewOp<IREE::HAL::InterfaceBindingSubspanOp>( |
| subspanOp, newType, subspanOp.binding(), subspanOp.byte_offset(), |
| subspanOp.byte_length(), dynamicDim); |
| return success(); |
| } |
| }; |
| |
/// Generates IR that linearizes the given `indices`, which index into the
/// given memref `sourceValue`, into a single 1-D index.
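/// For example (a sketch), for a memref<4x8xf32> and indices (%i, %j) this
/// generates IR computing %i * 8 + %j; for dynamic shapes the dimension sizes
/// are recovered from the defining op where possible.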
| static Value linearizeIndices(Value sourceValue, ValueRange indices, |
| Location loc, OpBuilder &builder) { |
| MemRefType sourceType = sourceValue.getType().cast<MemRefType>(); |
| assert(sourceType.hasRank()); |
| |
| int64_t rank = sourceType.getRank(); |
| if (rank == 0) { |
    // 0-D MemRefs are converted into 1-D MemRefs with an unknown dimension
    // size. To convert their consumer load/store ops, we also need to create
    // an index for them.
| return builder.create<arith::ConstantIndexOp>(loc, 0); |
| } |
| |
  // First try to get the strides from the MemRef type itself. This applies to
  // cases where the shape is fully static or only the leading dimension is
  // dynamic.
| if (AffineMap linearLayoutMap = getStridedLinearLayoutMap(sourceType)) { |
| // Dynamic strides/offset will create symbols. There should be none for the |
| // static case. |
| if (linearLayoutMap.getNumSymbols() == 0) { |
| return makeComposedAffineApply(builder, loc, linearLayoutMap, indices); |
| } |
| } |
| |
  // Then try to see if the source op carries the dynamic dimensions itself.
  // If so, we can still recover the dimension sizes needed for linearization.
| Operation *sourceOp = sourceValue.getDefiningOp(); |
| SmallVector<Value, 4> dims; |
| dims.reserve(rank); |
| if (auto shapeCarryOp = dyn_cast<ShapeCarryingInterface>(sourceOp)) { |
| Value shapeOp = |
| shapeCarryOp.buildResultValueRankedShape(sourceValue, builder); |
| for (int i = 0; i < rank; ++i) { |
| dims.push_back(builder.create<Shape::RankedDimOp>(loc, shapeOp, i)); |
| } |
| } else { |
| auto getDimValues = [&](MemRefType type, ValueRange dynamicDims) { |
| auto shape = type.getShape(); |
| int dynamicDimIndex = 0; |
| for (int i = 0; i < shape.size(); ++i) { |
| if (ShapedType::isDynamic(shape[i])) { |
| dims.push_back(dynamicDims[dynamicDimIndex++]); |
| } else { |
| dims.push_back(builder.create<arith::ConstantIndexOp>(loc, shape[i])); |
| } |
| } |
| }; |
| |
| if (auto allocOp = dyn_cast<memref::AllocOp>(sourceOp)) { |
| getDimValues(sourceType, allocOp.getDynamicSizes()); |
| } else if (auto allocaOp = dyn_cast<memref::AllocaOp>(sourceOp)) { |
| getDimValues(sourceType, allocaOp.getDynamicSizes()); |
| } else { |
| return nullptr; |
| } |
| } |
| |
| AffineExpr sym0, sym1, sym2; |
| bindSymbols(builder.getContext(), sym0, sym1, sym2); |
| MLIRContext *context = builder.getContext(); |
| auto mulAddMap = AffineMap::get(0, 3, {sym0 * sym1 + sym2}, context); |
| |
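  // Compute the linear index via Horner's scheme:
  //   index = (...((i0 * d1 + i1) * d2 + i2)...) * d_{rank-1} + i_{rank-1}
  // Each iteration folds one more dimension into the running linear index.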
| Value linearIndex = indices.front(); |
| for (int i = 1; i < indices.size(); ++i) { |
| linearIndex = builder.create<AffineApplyOp>( |
| loc, mulAddMap, ValueRange{linearIndex, dims[i], indices[i]}); |
| } |
| return linearIndex; |
| } |
| |
| /// Linearizes indices in memref.load ops. |
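/// For example (illustrative), memref.load %m[%i, %j] : memref<4x8xf32> turns
/// into memref.load %flat[%idx] : memref<?xf32> with %idx = %i * 8 + %j.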
| struct LinearizeLoadIndices final : public OpConversionPattern<memref::LoadOp> { |
| using OpConversionPattern<memref::LoadOp>::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| memref::LoadOp loadOp, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| if (!isRankOneMemRef(adaptor.memref().getType())) { |
| return rewriter.notifyMatchFailure( |
| loadOp, "expected converted memref of rank <= 1"); |
| } |
| |
| Value linearIndex = linearizeIndices(loadOp.memref(), loadOp.getIndices(), |
| loadOp.getLoc(), rewriter); |
| if (!linearIndex) { |
| return loadOp.emitOpError() << "failed to linearize index"; |
| } |
| |
| rewriter.replaceOpWithNewOp<memref::LoadOp>(loadOp, adaptor.memref(), |
| linearIndex); |
| return success(); |
| } |
| }; |
| |
| /// Linearizes indices in memref.store ops. |
| struct LinearizeStoreIndices final |
| : public OpConversionPattern<memref::StoreOp> { |
| using OpConversionPattern<memref::StoreOp>::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| memref::StoreOp storeOp, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| if (!isRankOneMemRef(adaptor.memref().getType())) { |
| return rewriter.notifyMatchFailure( |
| storeOp, "expected converted memref of rank <= 1"); |
| } |
| |
| Value linearIndex = linearizeIndices(storeOp.memref(), storeOp.getIndices(), |
| storeOp.getLoc(), rewriter); |
| if (!linearIndex) { |
| return storeOp.emitOpError() << "failed to linearize index"; |
| } |
| |
| rewriter.replaceOpWithNewOp<memref::StoreOp>(storeOp, adaptor.value(), |
| adaptor.memref(), linearIndex); |
| return success(); |
| } |
| }; |
| |
| /// Linearizes indices in vector.transfer_read ops. |
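/// For example (illustrative), with a minor identity permutation map,
///   %v = vector.transfer_read %m[%i, %j], %pad
///          : memref<4x8xf32>, vector<4xf32>
/// becomes a transfer_read from the flattened memref at the linearized index
/// %i * 8 + %j, using a rank-1 identity permutation map.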
| struct LinearizeTransferReadIndices final |
| : public OpConversionPattern<vector::TransferReadOp> { |
| using OpConversionPattern<vector::TransferReadOp>::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| vector::TransferReadOp transferReadOp, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| if (!transferReadOp.permutation_map().isMinorIdentity()) { |
| return rewriter.notifyMatchFailure( |
| transferReadOp, "cannot convert op with non-minor identity map"); |
| } |
| if (!isRankOneMemRef(adaptor.source().getType())) { |
| return rewriter.notifyMatchFailure( |
| transferReadOp, "expected converted memref of rank <= 1"); |
| } |
| Value linearIndex = |
| linearizeIndices(transferReadOp.source(), transferReadOp.indices(), |
| transferReadOp.getLoc(), rewriter); |
| if (!linearIndex) { |
| return transferReadOp.emitOpError() << "failed to linearize index"; |
| } |
| |
| rewriter.replaceOpWithNewOp<vector::TransferReadOp>( |
| transferReadOp, transferReadOp.getVectorType(), adaptor.source(), |
| linearIndex, rewriter.getDimIdentityMap(), transferReadOp.padding(), |
| transferReadOp.in_boundsAttr()); |
| return success(); |
| } |
| }; |
| |
| /// Linearizes indices in vector.transfer_write ops. |
| struct LinearizeTransferWriteIndices final |
| : public OpConversionPattern<vector::TransferWriteOp> { |
| using OpConversionPattern<vector::TransferWriteOp>::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| vector::TransferWriteOp transferWriteOp, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| if (!transferWriteOp.permutation_map().isMinorIdentity()) { |
| return rewriter.notifyMatchFailure( |
| transferWriteOp, "cannot convert op with non-minor identity map"); |
| } |
| if (!isRankOneMemRef(adaptor.source().getType())) { |
| return rewriter.notifyMatchFailure( |
| transferWriteOp, "expected converted memref of rank <= 1"); |
| } |
| Value linearIndex = |
| linearizeIndices(transferWriteOp.source(), transferWriteOp.indices(), |
| transferWriteOp.getLoc(), rewriter); |
| if (!linearIndex) { |
| return transferWriteOp.emitOpError() << "failed to linearize index"; |
| } |
| |
| rewriter.replaceOpWithNewOp<vector::TransferWriteOp>( |
| transferWriteOp, adaptor.vector(), adaptor.source(), linearIndex, |
| rewriter.getDimIdentityMap(), transferWriteOp.in_boundsAttr()); |
| return success(); |
| } |
| }; |
| |
| /// Adjusts unrealized_conversion_cast ops' inputs to flattened memref values. |
| struct AdjustConversionCast final |
| : public OpConversionPattern<UnrealizedConversionCastOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| UnrealizedConversionCastOp castOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| if (castOp->getNumOperands() != 1) return failure(); |
| |
| Value input = operands.front(); |
| // We only want to handle cases where the cast op handles memref types. |
| if (!input.getType().isa<BaseMemRefType>()) return failure(); |
| |
| if (!isRankOneMemRef(input.getType())) { |
| return rewriter.notifyMatchFailure( |
| castOp, "expected converted memref of rank <= 1"); |
| } |
| rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>( |
| castOp, castOp.getResultTypes(), input); |
| return success(); |
| } |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // Folding Patterns |
| //===----------------------------------------------------------------------===// |
| |
/// Returns the size in bytes of the given `type`. Returns llvm::None if the
/// size cannot be deduced.
| /// |
| /// Note that this should be kept consistent with how the byte offset was |
| /// calculated in the subspan ops! |
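/// For example, f32 yields 4, i1 yields 1 (rounded up to a whole byte), and
/// vector<4xf32> yields 16.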
| Optional<int64_t> getNumBytes(Type type) { |
| if (type.isIntOrFloat()) return (type.getIntOrFloatBitWidth() + 7) / 8; |
| if (auto vectorType = type.dyn_cast<VectorType>()) { |
| auto elementBytes = getNumBytes(vectorType.getElementType()); |
| if (!elementBytes) return llvm::None; |
| return elementBytes.getValue() * vectorType.getNumElements(); |
| } |
| return llvm::None; |
| } |
| |
| /// Folds the byte offset on subspan ops into the consumer load/store ops. |
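/// For example (illustrative), a load whose memref comes from a subspan with
/// byte offset %o and f32 element type is rewritten to load from an otherwise
/// identical subspan with byte offset 0, at index %old_index + %o floordiv 4.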
| template <typename OpType> |
| struct FoldSubspanOffsetIntoLoadStore final : public OpRewritePattern<OpType> { |
| using OpRewritePattern<OpType>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(OpType op, |
| PatternRewriter &rewriter) const override { |
| auto memrefType = op.memref().getType().template cast<MemRefType>(); |
| if (!isRankOneMemRef(memrefType)) { |
| return rewriter.notifyMatchFailure(op, "expected 1-D memref"); |
| } |
| |
| auto subspanOp = |
| op.memref() |
| .template getDefiningOp<IREE::HAL::InterfaceBindingSubspanOp>(); |
| if (!subspanOp) return failure(); |
| |
| // If the subspan op has a zero byte offset then we are done. |
| if (matchPattern(subspanOp.byte_offset(), m_Zero())) return failure(); |
| // Byte length is unsupported for now. |
| if (subspanOp.byte_length()) { |
| return rewriter.notifyMatchFailure(op, "byte length unsupported"); |
| } |
| |
    // Calculate the offset we need to add to the load/store op's index,
    // measured in number of elements.
| Optional<int64_t> numBytes = getNumBytes(memrefType.getElementType()); |
| if (!numBytes) { |
| return rewriter.notifyMatchFailure(op, |
| "cannot deduce element byte count"); |
| } |
| // Create a new subspan op with zero byte offset at the original location. |
| auto ip = rewriter.saveInsertionPoint(); |
| rewriter.setInsertionPointAfter(subspanOp); |
| Value zero = |
| rewriter.create<arith::ConstantIndexOp>(op.memref().getLoc(), 0); |
| Value newSubspan = rewriter.create<IREE::HAL::InterfaceBindingSubspanOp>( |
| op.memref().getLoc(), subspanOp.getType(), subspanOp.binding(), zero, |
| subspanOp.byte_length(), subspanOp.dynamic_dims()); |
| rewriter.restoreInsertionPoint(ip); |
| |
| MLIRContext *context = rewriter.getContext(); |
| AffineExpr sym0, sym1; |
| bindSymbols(context, sym0, sym1); |
| auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context); |
| auto divMap = AffineMap::get(0, 2, {sym0.floorDiv(sym1)}, context); |
| |
| Value byteValue = rewriter.create<arith::ConstantIndexOp>( |
| op.memref().getLoc(), numBytes.getValue()); |
| // We assume that upper layers guarantee the byte offset is perfectly |
| // divisible by the element byte count so the content is well aligned. |
| Value offset = rewriter.create<AffineApplyOp>( |
| op.getLoc(), divMap, ValueRange{subspanOp.byte_offset(), byteValue}); |
| |
| // Get the new index by adding the old index with the offset. |
| Value newIndex = rewriter.create<AffineApplyOp>( |
| op.getLoc(), addMap, ValueRange{op.indices().front(), offset}); |
| |
| if (std::is_same<OpType, memref::LoadOp>::value) { |
| rewriter.replaceOpWithNewOp<memref::LoadOp>( |
| op, memrefType.getElementType(), ValueRange{newSubspan, newIndex}); |
| } else { |
| rewriter.replaceOpWithNewOp<memref::StoreOp>( |
| op, TypeRange{}, ValueRange{op.getOperand(0), newSubspan, newIndex}); |
| } |
| |
| return success(); |
| } |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // Pass |
| //===----------------------------------------------------------------------===// |
| |
| struct FlattenMemRefSubspanPass |
| : public FlattenMemRefSubspanBase<FlattenMemRefSubspanPass> { |
| FlattenMemRefSubspanPass() {} |
| FlattenMemRefSubspanPass(const FlattenMemRefSubspanPass &pass) {} |
| |
| void getDependentDialects(DialectRegistry ®istry) const override { |
| registry.insert<AffineDialect, memref::MemRefDialect, ShapeDialect>(); |
| } |
| |
| void runOnOperation() override { |
    // First flatten the dimensions of subspan ops and their consumer
    // load/store ops. This requires setting up a conversion target with a
    // type converter.
| |
| MLIRContext &context = getContext(); |
| FlattenMemRefTypeConverter typeConverter; |
| RewritePatternSet flattenPatterns(&context); |
| flattenPatterns |
| .add<FlattenAlloc<memref::AllocaOp>, FlattenAlloc<memref::AllocOp>, |
| FlattenGlobal, FlattenGetGlobal, FlattenBindingSubspan, |
| LinearizeLoadIndices, LinearizeStoreIndices, |
| LinearizeTransferReadIndices, LinearizeTransferWriteIndices, |
| AdjustConversionCast>(typeConverter, &context); |
| |
| ConversionTarget target(context); |
| target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); |
| target.addDynamicallyLegalOp<IREE::HAL::InterfaceBindingSubspanOp, |
| memref::AllocaOp, memref::AllocOp, |
| memref::GetGlobalOp>([](Operation *op) { |
| return isRankOneMemRef(op->getResultTypes().front()); |
| }); |
| target.addDynamicallyLegalOp<memref::GlobalOp>( |
| [](memref::GlobalOp op) { return isRankOneMemRef(op.type()); }); |
| target.addDynamicallyLegalOp<memref::LoadOp>([](memref::LoadOp loadOp) { |
| return isRankOneMemRef(loadOp.getMemRefType()); |
| }); |
| target.addDynamicallyLegalOp<memref::StoreOp>([](memref::StoreOp storeOp) { |
| return isRankOneMemRef(storeOp.getMemRefType()); |
| }); |
| target.addDynamicallyLegalOp<vector::TransferReadOp>( |
| [](vector::TransferReadOp readOp) { |
| return isRankOneMemRef(readOp.source().getType().cast<MemRefType>()); |
| }); |
| target.addDynamicallyLegalOp<vector::TransferWriteOp>( |
| [](vector::TransferWriteOp writeOp) { |
| return isRankOneMemRef(writeOp.source().getType().cast<MemRefType>()); |
| }); |
| target.addDynamicallyLegalOp<UnrealizedConversionCastOp>( |
| [](UnrealizedConversionCastOp castOp) { |
| if (castOp->getNumOperands() != 1) return false; |
| |
| Type inputType = castOp->getOperandTypes().front(); |
| return !inputType.isa<BaseMemRefType>() || isRankOneMemRef(inputType); |
| }); |
| |
| // Use partial conversion here so that we can ignore allocations created by |
| // promotion and their load/store ops. |
| if (failed(applyPartialConversion(getOperation(), target, |
| std::move(flattenPatterns)))) { |
| return signalPassFailure(); |
| } |
| |
    // Then fold the byte offsets on subspan ops into consumer load/store ops.
| |
| RewritePatternSet foldPatterns(&context); |
| foldPatterns.add<FoldSubspanOffsetIntoLoadStore<memref::LoadOp>, |
| FoldSubspanOffsetIntoLoadStore<memref::StoreOp>>(&context); |
| |
| (void)applyPatternsAndFoldGreedily(getOperation(), std::move(foldPatterns)); |
| } |
| }; |
| |
| } // namespace |
| |
| std::unique_ptr<OperationPass<ModuleOp>> createFlattenMemRefSubspanPass() { |
| return std::make_unique<FlattenMemRefSubspanPass>(); |
| } |
| |
| } // namespace iree_compiler |
| } // namespace mlir |