| // Copyright 2019 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" |
| #include "iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertFlowToHAL.h" |
| #include "iree/compiler/Dialect/HAL/IR/HALOps.h" |
| #include "iree/compiler/Dialect/HAL/IR/HALTypes.h" |
| #include "iree/compiler/Dialect/HAL/Target/TargetBackend.h" |
| #include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h" |
| #include "iree/compiler/Dialect/HAL/Utils/DeviceSwitchBuilder.h" |
| #include "iree/compiler/Dialect/HAL/Utils/TypeUtils.h" |
| #include "iree/compiler/Dialect/Shape/IR/Builders.h" |
| #include "iree/compiler/Dialect/Shape/IR/ShapeOps.h" |
| #include "iree/compiler/Dialect/Util/IR/UtilTypes.h" |
| #include "llvm/ADT/DenseMap.h" |
| #include "llvm/Support/Debug.h" |
| #include "mlir/Analysis/Liveness.h" |
| #include "mlir/Dialect/StandardOps/IR/Ops.h" |
| #include "mlir/IR/AsmState.h" |
| #include "mlir/IR/Attributes.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/BuiltinOps.h" |
| #include "mlir/IR/Matchers.h" |
| #include "mlir/IR/OperationSupport.h" |
| #include "mlir/Transforms/DialectConversion.h" |
| |
| #define DEBUG_TYPE "iree-hal" |
| |
| namespace mlir { |
| namespace iree_compiler { |
| namespace { |
| |
| //===----------------------------------------------------------------------===// |
| // Alias analysis |
| //===----------------------------------------------------------------------===// |
| |
| using ValueAliasingMap = llvm::MapVector<Value, SmallPtrSet<Value, 16>>; |
| |
| // Builds a map of value aliases from aliasee to a set of aliasers. |
| // Only values that alias will be present in the map. |
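| //
| // For example (illustrative IR; names hypothetical), given
| //   %1 = flow.dispatch @ex::@entry(%0) : (tensor<4xf32>) -> %0
| // the result %1 is tied to the operand %0, so the returned map will contain
| // both %0 -> {%1} and %1 -> {%0}.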
| static ValueAliasingMap computeValueAliases( |
| IREE::Flow::ExStreamFragmentOp streamOp) { |
| auto *streamBlock = &streamOp.body().front(); |
| ValueAliasingMap valueAliases; |
| |
| std::function<void(Value streamValue, Value aliasedValue)> propagateAlias;
| propagateAlias = [&](Value streamValue, Value aliasedValue) {
| // NOTE: operator[] may insert into the MapVector and reallocate its
| // storage; don't hold references across lookups of other keys.
| valueAliases[streamValue].insert(aliasedValue);
| auto aliasedSet = valueAliases[aliasedValue]; // copy; map mutated below
| valueAliases[streamValue].insert(aliasedSet.begin(), aliasedSet.end());
| valueAliases[aliasedValue].insert(streamValue);
| };
| |
| // Start with outputs so that we handle tied values that may lead all the way |
| // back up the chain to the stream inputs. |
| auto tiedStreamOp = |
| cast<IREE::Util::TiedOpInterface>(streamOp.getOperation()); |
| auto returnOp = cast<IREE::Flow::ReturnOp>(streamBlock->back()); |
| for (auto result : llvm::enumerate(streamOp.getResults())) { |
| auto streamValue = returnOp.getOperand(result.index()); |
| |
| // Tied stream results reuse their stream operand buffer. |
| auto tiedOperandIndex = |
| tiedStreamOp.getTiedResultOperandIndex(result.index()); |
| if (tiedOperandIndex.hasValue()) { |
| auto operand = streamBlock->getArgument(tiedOperandIndex.getValue()); |
| propagateAlias(streamValue, operand); |
| } |
| } |
| |
| for (auto &op : *streamBlock) { |
| auto tiedOp = dyn_cast<IREE::Util::TiedOpInterface>(op); |
| for (auto it : llvm::enumerate(op.getResults())) { |
| auto result = it.value(); |
| if (!result.getType().isa<ShapedType>()) continue; |
| |
| // Tied results reuse their operand buffer. |
| if (tiedOp) { |
| auto tiedOperandIndex = tiedOp.getTiedResultOperandIndex(it.index()); |
| if (tiedOperandIndex.hasValue()) { |
| auto operand = op.getOperand(tiedOperandIndex.getValue()); |
| propagateAlias(result, operand); |
| } |
| } |
| } |
| } |
| |
| // Invert the aliaser->aliasee map so that for any particular value we have
| // the list of all other values that alias it.
| for (auto it : valueAliases) { |
| for (auto aliasee : it.second) { |
| for (auto aliaser : it.second) { |
| if (aliaser != aliasee) { |
| valueAliases[aliasee].insert(aliaser); |
| } |
| } |
| } |
| } |
| |
| return valueAliases; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Liveness interval analysis |
| //===----------------------------------------------------------------------===// |
| |
| static constexpr int LIVE_IN = INT_MIN; |
| static constexpr int LIVE_OUT = INT_MAX; |
| struct LivenessInterval { |
| int start = 0; |
| int end = 0; |
| int ordinal = -1; // unique per value |
| Value value; |
| bool operator<(const LivenessInterval &rhs) const { |
| return ordinal < rhs.ordinal; |
| } |
| }; |
| using LivenessIntervalMap = DenseMap<Value, LivenessInterval>; |
| using LivenessIntervalList = SmallVector<LivenessInterval>; |
| |
| // Computes the liveness intervals for each value in the stream. |
| // Returns a closed range over an arbitrary operation ordering. The LIVE_IN and |
| // LIVE_OUT sentinels will be used to indicate values that are live-in and |
| // live-out to the stream (captured input arguments and escaping output |
| // results). |
| // |
| // All values, including block arguments, will have an interval, with aliased
| // values sharing the union of their constituent ranges. Note that not all
| // values will have buffers allocated to them - we are just tracking
| // transitive SSA value lifetime.
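| //
| // As a sketch (ordinals are the made-up op ordering below): for a body
| //   (0) %t = flow.dispatch ...(%arg0)
| //   (1) flow.return %t
| // %arg0 gets the interval [LIVE_IN, 0] and %t gets [0, LIVE_OUT].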
| static LivenessIntervalList computeLivenessIntervals( |
| IREE::Flow::ExStreamFragmentOp streamOp, |
| const ValueAliasingMap &valueAliases) { |
| // Perform a liveness analysis on the stream fragment. |
| // Fragments have a single block and as such the live-in/live-out block |
| // information derived here applies to the entire stream region. |
| assert(streamOp.body().getBlocks().size() == 1); |
| auto *streamBlock = &streamOp.body().front(); |
| Liveness streamLiveness(streamOp); |
| auto *livenessInfo = streamLiveness.getLiveness(streamBlock); |
| |
| // Operations don't allow us to get their already computed order so we make up |
| // our own. We have a single block and thus the ordering is complete. |
| DenseMap<Operation *, int> opOrdering;
| for (auto &op : *streamBlock) {
| // Read the size before operator[] may insert the new entry; the relative
| // evaluation order of the two is unspecified prior to C++17.
| int nextOrder = opOrdering.size();
| opOrdering[&op] = nextOrder;
| }
| |
| // Liveness doesn't track return values as live-outs so we do that here. |
| SmallPtrSet<Value, 16> liveOuts; |
| auto returnOp = cast<IREE::Flow::ReturnOp>(streamBlock->back()); |
| for (auto returnValue : returnOp.operands()) { |
| if (!returnValue.getType().isa<ShapedType>()) continue; |
| liveOuts.insert(returnValue); |
| } |
| |
| // Compute live-in intervals specially as block arguments won't be caught in
| // the op walk below.
| LivenessIntervalMap valueIntervals; |
| int ordinal = 0; |
| for (Value value : streamBlock->getArguments()) { |
| if (!value.getType().isa<ShapedType>()) continue; |
| LivenessInterval interval; |
| interval.start = LIVE_IN; |
| if (liveOuts.contains(value)) { |
| interval.end = LIVE_OUT; |
| } else { |
| auto *endOp = livenessInfo->getEndOperation(value, &streamBlock->front()); |
| interval.end = opOrdering[endOp]; |
| } |
| interval.value = value; |
| interval.ordinal = ++ordinal; |
| valueIntervals[value] = interval; |
| } |
| |
| // Compute ranges for all values independently (ignoring aliasing). |
| for (auto &op : *streamBlock) { |
| int start = opOrdering[&op]; |
| for (auto value : op.getResults()) { |
| if (!value.getType().isa<ShapedType>()) continue; |
| LivenessInterval interval; |
| interval.start = start; |
| if (liveOuts.contains(value)) { |
| interval.end = LIVE_OUT; |
| } else { |
| interval.end = start; |
| for (auto &use : value.getUses()) { |
| interval.end = std::max(interval.end, opOrdering[use.getOwner()]); |
| } |
| } |
| interval.value = value; |
| interval.ordinal = ++ordinal; |
| valueIntervals[value] = interval; |
| } |
| } |
| |
| // Walk the alias map and union intervals and propagate back. |
| for (auto it : valueAliases) { |
| auto &aliasee = it.first; |
| auto &aliasers = it.second; |
| auto &aliaseeInterval = valueIntervals[aliasee]; |
| int start = aliaseeInterval.start; |
| int end = aliaseeInterval.end; |
| for (auto aliaser : aliasers) { |
| auto &aliaserInterval = valueIntervals[aliaser]; |
| start = std::min(start, aliaserInterval.start); |
| end = std::max(end, aliaserInterval.end); |
| } |
| aliaseeInterval.start = start; |
| aliaseeInterval.end = end; |
| for (auto aliaser : aliasers) { |
| auto &aliaserInterval = valueIntervals[aliaser]; |
| aliaserInterval.start = start; |
| aliaserInterval.end = end; |
| } |
| } |
| |
| // Sort all intervals by their assigned ordinal, which ascends in program
| // order. This makes the intervals easier to read and deterministic across
| // runs.
| SmallVector<LivenessInterval> sortedIntervals; |
| for (auto it : valueIntervals) { |
| sortedIntervals.push_back(it.second); |
| } |
| std::sort(sortedIntervals.begin(), sortedIntervals.end()); |
| return sortedIntervals; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Stateful recording storage |
| //===----------------------------------------------------------------------===// |
| |
| struct BufferRange { |
| BufferRange() = default; |
| explicit BufferRange(Value buffer, Value length) |
| : buffer(buffer), length(length) {} |
| |
| Value buffer = nullptr; |
| Value length = nullptr; |
| }; |
| |
| // State cache used during stream scheduling. |
| // |
| // This contains caches used to memoize commonly occurring values such as
| // global loads, shapes, and computed sizes. These caches just lighten the
| // load on CSE/canonicalization; it would otherwise be fine to materialize
| // IR for everything.
| // |
| // Any allocations made are also tracked here so that the tensor->!hal.buffer |
| // mappings are available at any time a buffer may be required during the |
| // scheduling. |
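| //
| // For example, repeated calls to lookupOrCreateIndex(0, builder) all return
| // the same SSA value instead of materializing a new arith.constant each
| // time.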
| class StreamSchedulingState { |
| public: |
| explicit StreamSchedulingState(Location loc, Value device, Value allocator, |
| ValueAliasingMap &valueAliases) |
| : loc(loc), |
| device_(device), |
| allocator_(allocator), |
| valueAliases(valueAliases) {} |
| |
| Value device() { return device_; } |
| Value allocator() { return allocator_; } |
| |
| // Returns an arith::ConstantIndexOp of |value|.
| Value lookupOrCreateIndex(int64_t value, OpBuilder &builder) { |
| auto it = indexConstantMap.find(value); |
| if (it != indexConstantMap.end()) return it->second; |
| auto constantValue = |
| builder.createOrFold<arith::ConstantIndexOp>(loc, value); |
| indexConstantMap.insert(std::make_pair(value, constantValue)); |
| return constantValue; |
| } |
| |
| // Loads the util.global with the given |symName|.
| Value loadGlobal(Type resultType, StringRef symName, OpBuilder &builder) { |
| auto it = loadedGlobalMap.find(symName); |
| if (it != loadedGlobalMap.end()) { |
| assert(it->second.getType() == resultType && "variable type mismatch"); |
| return it->second; |
| } |
| auto value = builder.createOrFold<IREE::Util::GlobalLoadOp>(loc, resultType, |
| symName); |
| loadedGlobalMap.insert(std::make_pair(symName, value)); |
| return value; |
| } |
| |
| // Returns an executable layout with the given attributes. |
| Value lookupExecutableLayout(Type resultType, IntegerAttr pushConstantsAttr, |
| ArrayAttr layoutsAttr, OpBuilder &builder) { |
| auto keyAttr = builder.getArrayAttr({pushConstantsAttr, layoutsAttr}); |
| auto it = executableLayoutMap.find(keyAttr); |
| if (it != executableLayoutMap.end()) { |
| assert(it->second.getType() == resultType && "executable layout type mismatch");
| return it->second; |
| } |
| auto value = builder.createOrFold<IREE::HAL::ExecutableLayoutLookupOp>( |
| loc, IREE::HAL::ExecutableLayoutType::get(device().getContext()), |
| device(), pushConstantsAttr, layoutsAttr); |
| executableLayoutMap.insert(std::make_pair(keyAttr, value)); |
| return value; |
| } |
| |
| // Returns a computed byte size value inserted at |builder| based on the
| // shape of the given |streamValue|. Returns an existing size if one matching
| // the parameters has already been inserted.
| Value lookupOrComputeSize(Value streamValue, OpBuilder &builder) { |
| return lookupOrComputeSize(streamValue.getType().cast<ShapedType>(), |
| Shape::buildOrFindDynamicDimsForValue( |
| streamValue.getLoc(), streamValue, builder), |
| builder); |
| } |
| |
| // Returns a computed byte size value inserted at |builder| based on the
| // given shaped type and its dynamic dimensions. Returns an existing size if
| // one matching the parameters has already been inserted.
| Value lookupOrComputeSize(ShapedType shapedType, ValueRange dynamicDims, |
| OpBuilder &builder) { |
| if (shapedType.hasStaticShape()) { |
| auto it = staticShapeToSizeMap.find(shapedType); |
| if (it != staticShapeToSizeMap.end()) return it->second; |
| } else { |
| auto typeIt = dynamicShapeToSizeMap.find(shapedType); |
| if (typeIt != dynamicShapeToSizeMap.end()) { |
| for (auto dimsIt : typeIt->second) { |
| if (std::equal(dimsIt.first.begin(), dimsIt.first.end(), |
| dynamicDims.begin())) { |
| return dimsIt.second; |
| } |
| } |
| } |
| } |
| |
| auto elementType = getElementType(shapedType.getElementType(), builder); |
| assert(elementType && "unhandled element type for allocation"); |
| auto encodingType = getEncodingType({}, builder); |
| assert(encodingType && "unhandled encoding type for allocation"); |
| |
| SmallVector<Value> shapeDims(shapedType.getRank()); |
| int64_t dynamicDimIndex = 0; |
| for (int64_t i = 0; i < shapedType.getRank(); ++i) { |
| if (shapedType.isDynamicDim(i)) { |
| shapeDims[i] = dynamicDims[dynamicDimIndex++]; |
| } else { |
| shapeDims[i] = lookupOrCreateIndex(shapedType.getDimSize(i), builder); |
| } |
| } |
| |
| auto size = builder.createOrFold<IREE::HAL::AllocatorComputeSizeOp>( |
| loc, allocator(), shapeDims, elementType, encodingType); |
| if (shapedType.hasStaticShape()) { |
| staticShapeToSizeMap[shapedType] = size; |
| } else { |
| dynamicShapeToSizeMap[shapedType].push_back( |
| std::make_pair(dynamicDims, size)); |
| } |
| return size; |
| } |
| |
| // Maps |tensorValue| to the backing storage buffer defined by |bufferRange|. |
| LogicalResult mapTensorToBufferRange(Value tensorValue, |
| BufferRange bufferRange) { |
| if (bufferRangeMap.count(tensorValue)) { |
| return failure(); |
| } |
| bufferRangeMap.insert(std::make_pair(tensorValue, bufferRange)); |
| |
| // TODO(#5410): make alias propagation map through an indexing map for |
| // slices/updates. Right now we assume all aliases are 1:1 full maps. |
| for (auto alias : valueAliases[tensorValue]) { |
| bufferRangeMap.insert(std::make_pair(alias, bufferRange)); |
| } |
| return success(); |
| } |
| |
| // Returns a buffer range backing the given stream |tensorValue|. |
| BufferRange lookupTensorBufferRange(Value tensorValue) { |
| auto it = bufferRangeMap.find(tensorValue); |
| assert(it != bufferRangeMap.end() && "buffer not pre-allocated for tensor"); |
| return it->second; |
| } |
| |
| // Returns true if the given |tensorValue| has a buffer range mapped to it. |
| bool hasTensorBufferRange(Value tensorValue) { |
| return bufferRangeMap.count(tensorValue) != 0; |
| } |
| |
| // Calls |callback| for |tensorValue| and each value aliasing it. |
| void forEachEquivalentTensorValue(Value tensorValue, |
| std::function<void(Value)> callback) { |
| callback(tensorValue); |
| for (auto alias : valueAliases[tensorValue]) { |
| callback(alias); |
| } |
| } |
| |
| private: |
| Value getElementType(Type elementType, OpBuilder &builder) { |
| auto it = memoizedElementTypesConstants.find(elementType); |
| if (it != memoizedElementTypesConstants.end()) return it->second; |
| auto i32Value = IREE::HAL::getElementTypeValue(elementType); |
| assert(i32Value.hasValue() && "unhandled element type for allocation"); |
| auto constantValue = builder.createOrFold<arith::ConstantIntOp>( |
| loc, i32Value.getValue(), 32); |
| memoizedElementTypesConstants[elementType] = constantValue; |
| return constantValue; |
| } |
| |
| Value getEncodingType(Attribute encodingType, OpBuilder &builder) { |
| auto it = memoizedEncodingTypesConstants.find(encodingType); |
| if (it != memoizedEncodingTypesConstants.end()) return it->second; |
| auto i32Value = IREE::HAL::getEncodingTypeValue(encodingType); |
| assert(i32Value.hasValue() && "unhandled encoding type for allocation"); |
| auto constantValue = builder.createOrFold<arith::ConstantIntOp>( |
| loc, i32Value.getValue(), 32); |
| memoizedEncodingTypesConstants[encodingType] = constantValue; |
| return constantValue; |
| } |
| |
| Location loc; |
| |
| // !hal.device used throughout the stream. |
| Value device_; |
| // !hal.allocator used throughout the stream. |
| Value allocator_; |
| |
| // All values that have aliases, mapped to the set of values they alias
| // with. Note that aliasing does not imply the values can be treated as
| // equivalent: some values may be subranges of others.
| ValueAliasingMap valueAliases; |
| |
| // Index value -> arith.constant index value.
| DenseMap<int64_t, Value> indexConstantMap; |
| |
| // Global sym name -> loaded value. |
| DenseMap<StringRef, Value> loadedGlobalMap; |
| |
| // Key of [push constants, set layouts] -> cached executable layout value.
| DenseMap<Attribute, Value> executableLayoutMap; |
| |
| // Small cache of constants used for element types. |
| DenseMap<Type, Value> memoizedElementTypesConstants; |
| |
| // Small cache of constants used for encoding types. |
| DenseMap<Attribute, Value> memoizedEncodingTypesConstants; |
| |
| // Map of static shaped types to computed size values. |
| DenseMap<Type, Value> staticShapeToSizeMap; |
| |
| // Map of dynamic shaped types to a set of dynamic dimension SSA values and |
| // the corresponding computed size. A single shaped type such as `?x4` may |
| // have multiple unique sizes if differing dimension values are used (such as |
| // `{%dimA}x4` and `{%dimB}x4`). |
| DenseMap<Type, SmallVector<std::pair<SmallVector<Value>, Value>>> |
| dynamicShapeToSizeMap; |
| |
| // Maps tensor values inside the stream to a buffer range that stores them. |
| DenseMap<Value, BufferRange> bufferRangeMap; |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // Buffer allocation |
| //===----------------------------------------------------------------------===// |
| |
| // Allocates a buffer for the given stream output value. |
| // |streamValue| is the Value used within the stream region and |
| // |externalValue| is the returned value from the stream region in the parent |
| // block. |
| static BufferRange allocateOutputBuffer(Value streamValue, Value externalValue, |
| StreamSchedulingState &schedulingState, |
| ConversionPatternRewriter &rewriter) { |
| Location loc = externalValue.getLoc(); |
| |
| // TODO(benvanik): compute from SSA use-def chain uses. |
| // HACK: we have no idea right now whether the buffer escaping the stream will |
| // be used on the host or the device and have to allocate it as HOST_VISIBLE. |
| // Improvements to the flow dialect to track tensor lifetime and hal.stream |
| // for tracking usage will reduce this to just what is required. |
| IREE::HAL::MemoryTypeBitfield memoryTypes = |
| IREE::HAL::MemoryTypeBitfield::DeviceLocal | |
| IREE::HAL::MemoryTypeBitfield::HostVisible; |
| IREE::HAL::BufferUsageBitfield bufferUsage = |
| IREE::HAL::BufferUsageBitfield::All; |
| |
| // Compute the allocation size for the value. |
| auto allocationSize = schedulingState.lookupOrComputeSize( |
| streamValue.getType().cast<ShapedType>(), |
| Shape::buildOrFindDynamicDimsForValue(streamValue.getLoc(), streamValue, |
| rewriter), |
| rewriter); |
| |
| auto buffer = rewriter |
| .create<IREE::HAL::AllocatorAllocateOp>( |
| loc, IREE::HAL::BufferType::get(rewriter.getContext()), |
| schedulingState.allocator(), memoryTypes, bufferUsage, |
| allocationSize) |
| .getResult(); |
| |
| return BufferRange{buffer, allocationSize}; |
| } |
| |
| // Allocates all output buffers for the stream and populates the
| // |schedulingState| with the new mappings. Populates |output| with buffers
| // mapping 1:1 with the |streamOp| results.
| static LogicalResult allocateOutputBuffers( |
| IREE::Flow::ExStreamFragmentOp streamOp, |
| StreamSchedulingState &schedulingState, ConversionPatternRewriter &rewriter, |
| SmallVectorImpl<Value> &output) { |
| auto tiedStreamOp = |
| cast<IREE::Util::TiedOpInterface>(streamOp.getOperation()); |
| auto &entryBlock = streamOp.body().front(); |
| |
| SmallVector<Value> outputBuffers; |
| |
| // Allocate output buffers and map the returned stream values to them.
| auto returnOp = cast<IREE::Flow::ReturnOp>(entryBlock.back());
| for (auto result : llvm::enumerate(streamOp.getResults())) { |
| auto streamValue = returnOp.getOperand(result.index()); |
| auto externalValue = result.value(); |
| |
| // Ignore already allocated buffers. |
| if (schedulingState.hasTensorBufferRange(streamValue)) { |
| outputBuffers.push_back( |
| schedulingState.lookupTensorBufferRange(streamValue).buffer); |
| continue; |
| } |
| |
| // Tied stream results reuse their operand buffer. |
| BufferRange bufferRange; |
| auto tiedOperandIndex = |
| tiedStreamOp.getTiedResultOperandIndex(result.index()); |
| if (tiedOperandIndex.hasValue()) { |
| LLVM_DEBUG(llvm::dbgs() |
| << " -- REUSING TIED OPERAND(" |
| << tiedOperandIndex.getValue() << ") BUFFER FOR STREAM RESULT(" |
| << result.index() << "): " << streamOp << "\n"); |
| auto operand = entryBlock.getArgument(tiedOperandIndex.getValue()); |
| bufferRange = schedulingState.lookupTensorBufferRange(operand); |
| } else { |
| LLVM_DEBUG(llvm::dbgs() |
| << " -- ALLOCATE BUFFER FOR STREAM ESCAPE RESULT(" |
| << result.index() << ")\n"); |
| bufferRange = allocateOutputBuffer(streamValue, externalValue, |
| schedulingState, rewriter); |
| } |
| if (!bufferRange.buffer) { |
| return streamOp.emitOpError() << "buffer range has no buffer"; |
| } |
| outputBuffers.push_back(bufferRange.buffer); |
| if (failed( |
| schedulingState.mapTensorToBufferRange(streamValue, bufferRange))) { |
| return streamOp.emitOpError() << "tensor was mapped to multiple buffer " |
| "ranges while allocating output buffers"; |
| } |
| } |
| |
| output = outputBuffers; |
| return success(); |
| } |
| |
| // Allocates transient buffers to store the intra-stream results and populates |
| // the |schedulingState| with the new mappings. |
| static LogicalResult allocateTransientBuffers( |
| IREE::Flow::ExStreamFragmentOp streamOp, |
| LivenessIntervalList &livenessIntervals, |
| StreamSchedulingState &schedulingState, |
| ConversionPatternRewriter &rewriter) { |
| // TODO(#5410): unify with slice/update handling below. We should have a |
| // more generic way of handling these special ops and need to be able to hook |
| // into ones that directly control aliasing behavior like slice/update. |
| SmallPtrSet<Value, 16> coveredValues; |
| auto walkResult = streamOp.walk([&](IREE::HAL::ConstantSubspanOp subspanOp) { |
| auto tensorValue = subspanOp.result(); |
| auto bufferValue = schedulingState.loadGlobal( |
| IREE::HAL::BufferType::get(rewriter.getContext()), |
| subspanOp.runtime_buffer().getLeafReference().getValue(), rewriter); |
| auto runtimeRange = subspanOp.runtime_range(); |
| auto offsetValue = |
| schedulingState.lookupOrCreateIndex(runtimeRange.getOffset(), rewriter); |
| auto lengthValue = |
| schedulingState.lookupOrCreateIndex(runtimeRange.getLength(), rewriter); |
| auto subspanValue = rewriter.createOrFold<IREE::HAL::BufferSubspanOp>( |
| subspanOp.getLoc(), bufferValue.getType(), bufferValue, offsetValue, |
| lengthValue); |
| auto bufferRange = BufferRange{subspanValue, lengthValue}; |
| if (failed( |
| schedulingState.mapTensorToBufferRange(tensorValue, bufferRange))) { |
| return WalkResult::interrupt(); |
| } |
| schedulingState.forEachEquivalentTensorValue( |
| tensorValue, [&](Value alias) { coveredValues.insert(alias); }); |
| return WalkResult::advance(); |
| }); |
| |
| if (walkResult.wasInterrupted()) { |
| return streamOp.emitOpError() << "constant subspan op was mapped to " |
| "multiple buffer ranges while allocating " |
| "transient buffers"; |
| } |
| |
| // Gather all of the transient values we need to allocate buffers for. |
| SmallVector<Value> transientValues; |
| SmallVector<int64_t> lifetimeIntervals; |
| SmallVector<Value> dynamicSliceSizes; |
| for (auto valueInterval : livenessIntervals) { |
| auto value = valueInterval.value; |
| auto valueType = value.getType().dyn_cast<ShapedType>(); |
| if (!valueType) continue; |
| |
| // Only handle transient buffers (created/used/dropped within the stream). |
| if (valueInterval.start == LIVE_IN || valueInterval.end == LIVE_OUT) { |
| continue; |
| } |
| |
| // Ignore covered values. |
| if (schedulingState.hasTensorBufferRange(value) || |
| coveredValues.contains(value)) { |
| continue; |
| } |
| |
| transientValues.push_back(value); |
| lifetimeIntervals.push_back(valueInterval.start); |
| lifetimeIntervals.push_back(valueInterval.end); |
| |
| // Compute the allocation size for the value. |
| auto allocationSize = schedulingState.lookupOrComputeSize( |
| valueType, |
| Shape::buildOrFindDynamicDimsForValue(value.getLoc(), value, rewriter), |
| rewriter); |
| dynamicSliceSizes.push_back(allocationSize); |
| |
| // Mark as covered so we don't allocate it again. |
| schedulingState.forEachEquivalentTensorValue( |
| value, [&](Value alias) { coveredValues.insert(alias); }); |
| } |
| if (transientValues.empty()) { |
| // No transients required. |
| return success(); |
| } |
| |
| // Insert the hal.allocator.pack op to compute the packed offsets and total |
| // buffer size required for all transients. |
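| // Schematically (exact syntax elided) the op packs lifetime-annotated
| // slices into one allocation:
| //   %total, %off0, %off1 = hal.allocator.pack %allocator, {
| //     [start0, end0] = %size0, [start1, end1] = %size1, ...
| //   }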
| auto indexType = rewriter.getIndexType(); |
| SmallVector<Type> packedOffsetTypes(dynamicSliceSizes.size(), indexType); |
| auto packOp = rewriter.create<IREE::HAL::AllocatorPackOp>( |
| streamOp.getLoc(), indexType, packedOffsetTypes, |
| schedulingState.allocator(), |
| /*offset=*/nullptr, rewriter.getIndexArrayAttr(lifetimeIntervals), |
| dynamicSliceSizes); |
| |
| // Allocate the transient storage buffer. |
| // TODO(benvanik): compute from SSA use-def chain uses. |
| IREE::HAL::MemoryTypeBitfield memoryTypes = |
| IREE::HAL::MemoryTypeBitfield::DeviceLocal; |
| IREE::HAL::BufferUsageBitfield bufferUsage = |
| IREE::HAL::BufferUsageBitfield::Dispatch | |
| IREE::HAL::BufferUsageBitfield::Transfer; |
| auto allocateOp = rewriter.create<IREE::HAL::AllocatorAllocateOp>( |
| streamOp.getLoc(), IREE::HAL::BufferType::get(rewriter.getContext()), |
| schedulingState.allocator(), memoryTypes, bufferUsage, |
| packOp.total_length()); |
| |
| // Map each transient value to a buffer range referencing a subspan of the
| // transient storage buffer.
| for (size_t i = 0; i < transientValues.size(); ++i) { |
| auto value = transientValues[i]; |
| auto offset = packOp.packed_offsets()[i]; |
| auto subspanValue = rewriter.createOrFold<IREE::HAL::BufferSubspanOp>( |
| value.getLoc(), allocateOp.result().getType(), allocateOp.result(), |
| offset, dynamicSliceSizes[i]); |
| auto bufferRange = BufferRange{subspanValue, dynamicSliceSizes[i]}; |
| if (failed(schedulingState.mapTensorToBufferRange(value, bufferRange))) { |
| return streamOp.emitOpError() |
| << "tensor for buffer subspan was mapped to multiple buffer " |
| "ranges while allocating transient buffers"; |
| } |
| } |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Conversion |
| //===----------------------------------------------------------------------===// |
| |
| // Records a full execution barrier that forces visibility of all buffers. |
| static void recordFullExecutionBarrier(Value commandBuffer, Location loc, |
| ConversionPatternRewriter &rewriter) { |
| rewriter.create<IREE::HAL::CommandBufferExecutionBarrierOp>( |
| loc, commandBuffer, |
| IREE::HAL::ExecutionStageBitfield::CommandRetire | |
| IREE::HAL::ExecutionStageBitfield::Dispatch, |
| IREE::HAL::ExecutionStageBitfield::CommandIssue | |
| IREE::HAL::ExecutionStageBitfield::Dispatch, |
| IREE::HAL::ExecutionBarrierFlagBitfield::None); |
| } |
| |
| // Records the push constants and descriptor set bindings for a dispatch
| // using the given bindings attribute populated by the
| // -iree-hal-materialize-interfaces pass.
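| //
| // The bindings attribute is an array mixing (schematically; mnemonics and
| // parameters abbreviated) entries such as #hal.ex.push_constant<...>,
| // #hal.ex.operand_buffer<...>, #hal.ex.result_buffer<...>, and
| // #hal.ex.constant_storage<...>.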
| static void recordInterfaceBindings( |
| Value device, Value commandBuffer, IREE::Flow::DispatchOp &dispatchOp, |
| IREE::HAL::InterfaceOp &interfaceOp, Value executableLayout, |
| ArrayAttr bindingsAttr, StreamSchedulingState &schedulingState, |
| ConversionPatternRewriter &rewriter, OpBuilder &builder) { |
| // Accumulate a potentially sparse set of push constants. |
| // If we had canonicalizers for hal.command_buffer.push_constants then we |
| // would instead just emit each constant individually and let that collapse |
| // things later on. |
| int pushConstantBase = 0; // always 0 today |
| SmallVector<Value> pushConstantValues; |
| pushConstantValues.resize( |
| interfaceOp.push_constants().getValueOr(APInt(64, 0)).getSExtValue()); |
| |
| // Accumulate a potentially sparse set of bindings. |
| int setOrdinal = 0; // always 0 today |
| SmallVector<IREE::HAL::DescriptorSetBindingValue, 4> bindings; |
| |
| auto zeroOffset = schedulingState.lookupOrCreateIndex(0, rewriter); |
| auto pushBufferBinding = [&](StringRef bindingName, Value tensorValue) {
| auto bindingOp = |
| interfaceOp.lookupSymbol<IREE::HAL::InterfaceBindingOp>(bindingName); |
| assert(bindingOp); |
| assert(bindingOp.set().getSExtValue() == 0); |
| |
| auto bufferRange = schedulingState.lookupTensorBufferRange(tensorValue); |
| assert(bufferRange.buffer && "buffer not preallocated"); |
| assert(bufferRange.length && "buffer has no precomputed size"); |
| bindings.push_back( |
| std::make_tuple(schedulingState.lookupOrCreateIndex( |
| bindingOp.binding().getSExtValue(), rewriter), |
| bufferRange.buffer, zeroOffset, bufferRange.length)); |
| }; |
| |
| for (auto bindingAttr : bindingsAttr) { |
| if (auto constantStorageAttr = |
| bindingAttr.dyn_cast<IREE::HAL::ExConstantStorageAttr>()) { |
| auto bindingOp = interfaceOp.lookupSymbol<IREE::HAL::InterfaceBindingOp>( |
| constantStorageAttr.binding()); |
| assert(bindingOp); |
| assert(bindingOp.set().getSExtValue() == setOrdinal); |
| auto storageBuffer = schedulingState.loadGlobal( |
| IREE::HAL::BufferType::get(builder.getContext()), |
| constantStorageAttr.storage(), rewriter); |
| bindings.push_back(std::make_tuple( |
| schedulingState.lookupOrCreateIndex( |
| bindingOp.binding().getSExtValue(), rewriter), |
| storageBuffer, |
| schedulingState.lookupOrCreateIndex( |
| constantStorageAttr.offset().getSExtValue(), rewriter), |
| schedulingState.lookupOrCreateIndex( |
| constantStorageAttr.length().getSExtValue(), rewriter))); |
| } else if (auto pushConstantAttr = |
| bindingAttr.dyn_cast<IREE::HAL::ExPushConstantAttr>()) { |
| auto inputValue = |
| dispatchOp.operands()[pushConstantAttr.operand().getSExtValue()]; |
| auto pushConstantValue = rewriter.getRemappedValue(inputValue); |
| // Need an explicit index cast to i32 since the |
| // CommandBufferPushConstantsOp is intrinsically i32 based. |
| if (inputValue.getType().isa<IndexType>()) { |
| pushConstantValue = rewriter.create<mlir::arith::IndexCastOp>( |
| dispatchOp.getLoc(), rewriter.getIntegerType(32), |
| pushConstantValue); |
| } |
| pushConstantValues[pushConstantAttr.ordinal().getSExtValue()] = |
| pushConstantValue; |
| } else if (auto operandBufferAttr = |
| bindingAttr.dyn_cast<IREE::HAL::ExOperandBufferAttr>()) { |
| auto tensorValue = |
| dispatchOp.operands()[operandBufferAttr.operand().getSExtValue()]; |
| pushBufferBinding(operandBufferAttr.binding(), tensorValue);
| } else if (auto resultBufferAttr = |
| bindingAttr.dyn_cast<IREE::HAL::ExResultBufferAttr>()) { |
| auto tensorValue = |
| dispatchOp.results()[resultBufferAttr.result().getSExtValue()]; |
| pushBufferBinding(resultBufferAttr.binding(), tensorValue);
| } |
| } |
| |
| builder.create<IREE::HAL::CommandBufferPushDescriptorSetOp>( |
| dispatchOp.getLoc(), commandBuffer, executableLayout, |
| schedulingState.lookupOrCreateIndex(setOrdinal, rewriter), bindings); |
| |
| if (!pushConstantValues.empty()) { |
| builder.create<IREE::HAL::CommandBufferPushConstantsOp>( |
| dispatchOp.getLoc(), commandBuffer, executableLayout, |
| rewriter.getIndexAttr(pushConstantBase), pushConstantValues); |
| } |
| } |
| |
| // Calculates the workgroup size (x, y, z). These are the dimension numbers |
| // for a single workgroup. |
| static std::array<Value, 3> calculateDispatchWorkgroupSize( |
| Location loc, IREE::HAL::ExecutableOp executableOp, |
| IREE::HAL::ExecutableEntryPointOp entryPointOp, ValueRange workload, |
| OpBuilder &builder) { |
| // When no workgroup size is specified we just assume [1,1,1]. |
| // This yields a workgroup count that models the extents of the workload. |
| return { |
| builder.createOrFold<mlir::arith::ConstantIndexOp>(loc, 1), |
| builder.createOrFold<mlir::arith::ConstantIndexOp>(loc, 1), |
| builder.createOrFold<mlir::arith::ConstantIndexOp>(loc, 1), |
| }; |
| } |
| |
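| // Calculates the workgroup count (x, y, z) by inlining the workgroup count
| // region of |entryPointOp| at the builder's insertion point, mapping the
| // region arguments to the |workload| values and returning the region's
| // three returned values.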
| static std::array<Value, 3> calculateDispatchWorkgroupCountFromRegion( |
| Location loc, IREE::HAL::ExecutableEntryPointOp entryPointOp, |
| ValueRange workload, OpBuilder &builder) { |
| Block *body = entryPointOp.getBlock(); |
| BlockAndValueMapping bvm; |
| for (auto args : llvm::enumerate(workload)) { |
| bvm.map(body->getArgument(args.index()), args.value()); |
| } |
| for (Operation &op : body->without_terminator()) { |
| builder.clone(op, bvm); |
| } |
| auto returnOp = cast<IREE::HAL::ReturnOp>(body->getTerminator()); |
| // Verifier of EntryPointOp checks that the return has 3 values. |
| SmallVector<Value, 4> count = llvm::to_vector<4>(llvm::map_range( |
| returnOp.operands(), [&bvm](Value v) { return bvm.lookup(v); })); |
| return {count[0], count[1], count[2]}; |
| } |
| |
| // Calculates the workgroup count (x, y, z) given the total N-dimensional |
| // |workload| and specific |workgroupSize|. |
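| //
| // For example (a worked sketch): a 1-D workload of {7} with workgroup size
| // {4, 1, 1} yields ceil(7/4) = 2 and thus a workgroup count of {2, 1, 1}.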
| static std::array<Value, 3> calculateWorkloadWorkgroupCount( |
| Location loc, ValueRange workload, |
| const std::array<Value, 3> &workgroupSize, OpBuilder &builder) { |
| std::array<Value, 3> result; |
| |
| auto constantOne = builder.createOrFold<mlir::arith::ConstantIndexOp>(loc, 1); |
| if (workload.size() <= 3) { |
| // 1-D to 3-D are easy (pad with 0 to 2 extra dimensions) and divide by
| // workgroup size.
| for (int i = 0; i < 3; ++i) { |
| // Round up: (workload[i] + workgroup_size - 1) / workgroup_size; |
| Value workloadI = i < workload.size() ? workload[i] : constantOne; |
| workloadI = builder.createOrFold<mlir::arith::SubIOp>( |
| loc, |
| builder.createOrFold<mlir::arith::AddIOp>(loc, workloadI, |
| workgroupSize[i]), |
| constantOne); |
| result[i] = builder.createOrFold<arith::DivUIOp>(loc, workloadI, |
| workgroupSize[i]); |
| } |
| } else { |
| // TODO(#4140): remapping of N-D to 3-D: this is not how you do this! |
| Value flatWorkload = constantOne; |
| for (auto workloadI : workload) { |
| flatWorkload = |
| builder.createOrFold<arith::MulIOp>(loc, flatWorkload, workloadI); |
| } |
| for (int i = 0; i < 3; ++i) { |
| // Round up: (workload[i] + workgroup_size - 1) / workgroup_size; |
| auto rounded = builder.createOrFold<mlir::arith::SubIOp>( |
| loc, |
| builder.createOrFold<mlir::arith::AddIOp>(loc, flatWorkload, |
| workgroupSize[i]), |
| constantOne); |
| auto workgroupCountI = builder.createOrFold<mlir::arith::DivUIOp>( |
| loc, rounded, workgroupSize[i]); |
| result[i] = workgroupCountI; |
| |
| // Multiply back out and subtract from invocations. |
| flatWorkload = builder.createOrFold<arith::SubIOp>( |
| loc, flatWorkload, |
| builder.createOrFold<arith::MulIOp>(loc, workgroupCountI, rounded)); |
| } |
| } |
| |
| return result; |
| } |
| |
| // Calculates the workgroup count (x, y, z) for dispatching to the given |
| // |entryPointOp|. The provided N-dimensional |workload| is the total number |
| // of invocations required as calculated by the generic workload logic |
| // (basically, number of output elements in tensors). |
| static std::array<Value, 3> calculateDispatchWorkgroupCount( |
| Location loc, IREE::HAL::ExecutableOp executableOp, |
| IREE::HAL::ExecutableEntryPointOp entryPointOp, ValueRange workload, |
| OpBuilder &builder) { |
| Region *region = entryPointOp.getBody(); |
| if (region) { |
| return calculateDispatchWorkgroupCountFromRegion(loc, entryPointOp, |
| workload, builder); |
| } |
| auto workgroupSize = calculateDispatchWorkgroupSize( |
| loc, executableOp, entryPointOp, workload, builder); |
| return calculateWorkloadWorkgroupCount(loc, workload, workgroupSize, builder); |
| } |
| |
| // Records a dispatch operation. |
| static LogicalResult recordDispatch(Value device, Value commandBuffer, |
| IREE::Flow::DispatchOp &dispatchOp, |
| StreamSchedulingState &schedulingState, |
| ConversionPatternRewriter &rewriter) { |
| auto loc = dispatchOp.getLoc(); |
| |
| // Get the handle to the executable that is compatible with our device. |
| auto executableOp = |
| cast<IREE::HAL::ExecutableOp>(SymbolTable::lookupNearestSymbolFrom( |
| dispatchOp, dispatchOp.executable())); |
| |
| SmallVector<Value> workgroupCount; |
| for (auto dim : dispatchOp.workgroup_count()) { |
| workgroupCount.push_back(rewriter.getRemappedValue(dim)); |
| } |
| |
| // Ask each target backend to record their dispatch logic. |
| IREE::HAL::DeviceSwitchRewriter switchRewriter(loc, |
| /*resultTypes=*/TypeRange{}, |
| device, rewriter); |
| for (auto variantOp : |
| executableOp.getBlock().getOps<IREE::HAL::ExecutableVariantOp>()) { |
| auto entryPointOps = |
| variantOp.getBlock().getOps<IREE::HAL::ExecutableEntryPointOp>(); |
| auto entryPointIt = |
| llvm::find_if(entryPointOps, [&](IREE::HAL::ExecutableEntryPointOp op) { |
| return op.getNameAttr() == |
| dispatchOp.entry_point().getLeafReference(); |
| }); |
| if (entryPointIt == entryPointOps.end()) { |
| return variantOp.emitError() |
| << "hal.executable.variant is missing the flow entry point for " |
| << dispatchOp.entry_point(); |
| } |
| auto entryPointOp = *entryPointIt; |
| auto interfaceOp = |
| dyn_cast<IREE::HAL::InterfaceOp>(SymbolTable::lookupSymbolIn( |
| executableOp, entryPointOp.interfaceAttr())); |
| auto executableLayout = schedulingState.lookupExecutableLayout( |
| IREE::HAL::ExecutableLayoutType::get(interfaceOp.getContext()), |
| interfaceOp.push_constantsAttr(), |
| interfaceOp.getExecutableSetLayoutsAttr(), rewriter); |
| |
| auto *region = switchRewriter.addConditionRegion( |
| variantOp.target().getMatchExpression()); |
| auto &entryBlock = region->front(); |
| auto caseBuilder = OpBuilder::atBlockBegin(&entryBlock); |
| |
| auto bindingsAttr = dispatchOp->getAttrOfType<ArrayAttr>("hal.bindings"); |
| assert(bindingsAttr); |
| recordInterfaceBindings(device, commandBuffer, dispatchOp, interfaceOp, |
| executableLayout, bindingsAttr, schedulingState, |
| rewriter, caseBuilder); |
| |
| auto entryPointSymRef = |
| SymbolRefAttr::get(caseBuilder.getContext(), executableOp.getName(), |
| {SymbolRefAttr::get(entryPointOp->getParentOp()), |
| SymbolRefAttr::get(entryPointOp)}); |
| auto caseWorkgroupCount = calculateDispatchWorkgroupCount( |
| loc, executableOp, entryPointOp, workgroupCount, caseBuilder); |
| caseBuilder.create<IREE::HAL::CommandBufferDispatchSymbolOp>( |
| loc, commandBuffer, entryPointSymRef, caseWorkgroupCount[0], |
| caseWorkgroupCount[1], caseWorkgroupCount[2]); |
| |
| caseBuilder.create<IREE::HAL::ReturnOp>(loc); |
| } |
| switchRewriter.build(); |
| |
| // Full barriers for now as we aren't scheduling things in waves. |
| recordFullExecutionBarrier(commandBuffer, dispatchOp.getLoc(), rewriter); |
| return success(); |
| } |
| |
| // Splats a pattern value of 1, 2, or 4 bytes out to a 4 byte integer value. |
| // The bit representation of |baseValue| will be repeated as many times as |
| // needed in the returned value to use 4 bytes of storage. For example, |
| // a 16-bit value (int or float) will have its native bit representation |
| // repeated twice. |
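| //
| // For example, an 8-bit pattern 0xAB becomes 0xABABABAB and a 16-bit pattern
| // 0x1234 becomes 0x12341234; 32-bit patterns are returned unchanged.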
| static Value splatFillPattern(Location loc, Value baseValue, |
| OpBuilder &builder) { |
| // Bitcast to an integer, then use integer math for the rest of the pattern. |
| auto baseBitWidth = baseValue.getType().getIntOrFloatBitWidth(); |
| baseValue = builder.createOrFold<arith::BitcastOp>( |
| loc, builder.getIntegerType(baseBitWidth), baseValue); |
| |
| switch (baseBitWidth) { |
| case 8: { |
| // (v << 24) | (v << 16) | (v << 8) | v |
| auto b0 = builder.createOrFold<arith::ExtUIOp>( |
| loc, baseValue, builder.getIntegerType(32)); |
| auto c8 = builder.create<arith::ConstantIntOp>(loc, 8, 32); |
| auto b1 = builder.createOrFold<arith::ShLIOp>(loc, b0, c8); |
| auto c16 = builder.create<arith::ConstantIntOp>(loc, 16, 32); |
| auto b2 = builder.createOrFold<arith::ShLIOp>(loc, b0, c16); |
| auto c24 = builder.create<arith::ConstantIntOp>(loc, 24, 32); |
| auto b3 = builder.createOrFold<arith::ShLIOp>(loc, b0, c24); |
| return builder.createOrFold<arith::OrIOp>( |
| loc, b0, |
| builder.createOrFold<arith::OrIOp>( |
| loc, b1, builder.createOrFold<arith::OrIOp>(loc, b2, b3))); |
| } |
| case 16: { |
| // (v << 16) | v |
| auto c16 = builder.create<arith::ConstantIntOp>(loc, 16, 32); |
| auto b0 = builder.createOrFold<arith::ExtUIOp>( |
| loc, baseValue, builder.getIntegerType(32)); |
| auto b1 = builder.createOrFold<arith::ShLIOp>(loc, b0, c16); |
| return builder.createOrFold<arith::OrIOp>(loc, b0, b1); |
| } |
| case 32: |
| return baseValue; |
| default: |
| return {}; // Unsupported (so far) |
| } |
| } |
| |
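| // Records a flow.tensor.splat by filling the entire result buffer with a
| // 4-byte splat pattern derived from the splatted value.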
| static LogicalResult recordTensorSplat(Value device, Value commandBuffer, |
| IREE::Flow::TensorSplatOp &splatOp, |
| StreamSchedulingState &schedulingState, |
| ConversionPatternRewriter &rewriter) { |
| auto resultBuffer = schedulingState.lookupTensorBufferRange(splatOp.result()); |
| |
| auto pattern = splatFillPattern(splatOp.getLoc(), splatOp.value(), rewriter); |
| if (!pattern) { |
| return splatOp.emitError() << ">4 byte/non-byte-aligned fills are not yet " |
| "implemented (require special emulation)"; |
| } |
| |
| auto zeroOffset = schedulingState.lookupOrCreateIndex(0, rewriter); |
| rewriter.create<IREE::HAL::CommandBufferFillBufferOp>( |
| splatOp.getLoc(), commandBuffer, resultBuffer.buffer, zeroOffset, |
| resultBuffer.length, pattern); |
| |
| // Full barriers for now as we aren't scheduling things. |
| recordFullExecutionBarrier(commandBuffer, splatOp.getLoc(), rewriter); |
| return success(); |
| } |
| |
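| // Records a flow.tensor.clone as a full-length buffer-to-buffer copy between
| // the operand and result buffers.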
| static LogicalResult recordTensorClone(Value device, Value commandBuffer, |
| IREE::Flow::TensorCloneOp &cloneOp, |
| StreamSchedulingState &schedulingState, |
| ConversionPatternRewriter &rewriter) { |
| auto operandBuffer = |
| schedulingState.lookupTensorBufferRange(cloneOp.operand()); |
| auto resultBuffer = schedulingState.lookupTensorBufferRange(cloneOp.result()); |
| |
| auto zeroOffset = schedulingState.lookupOrCreateIndex(0, rewriter); |
| rewriter.create<IREE::HAL::CommandBufferCopyBufferOp>( |
| cloneOp.getLoc(), commandBuffer, operandBuffer.buffer, zeroOffset, |
| // Note: we use the result buffer's length here deliberately to handle the |
| // case where the source buffer can be the constant pool buffer. |
| resultBuffer.buffer, zeroOffset, resultBuffer.length); |
| |
| // Full barriers for now as we aren't scheduling things. |
| recordFullExecutionBarrier(commandBuffer, cloneOp.getLoc(), rewriter); |
| return success(); |
| } |
| |
| // TODO(#5410): make this an aliasing operation in allocateTransientBuffers. |
| static LogicalResult recordTensorSlice(Value device, Value commandBuffer, |
| IREE::Flow::TensorSliceOp &sliceOp, |
| StreamSchedulingState &schedulingState, |
| ConversionPatternRewriter &rewriter) { |
| auto sourceBuffer = schedulingState.lookupTensorBufferRange(sliceOp.source()); |
| auto resultBuffer = schedulingState.lookupTensorBufferRange(sliceOp.result()); |
| |
| // TODO(benvanik): use something other than the BufferRange::buffer? |
| // This may require us to subview the buffer first. |
| auto source = IREE::HAL::TensorRewriteAdaptor::getChecked( |
| sliceOp.getLoc(), sliceOp.source(), sourceBuffer.buffer, rewriter); |
| auto result = IREE::HAL::TensorRewriteAdaptor::getChecked( |
| sliceOp.getLoc(), sliceOp.result(), resultBuffer.buffer, rewriter); |
| if (!source.hasValue() || !result.hasValue()) { |
| return sliceOp.emitOpError() |
| << "cannot create adaptors for tensor slice operands/results"; |
| } |
| |
| // Compute the range of the source slice.
| auto startIndices = llvm::to_vector<4>(llvm::map_range( |
| sliceOp.start_indices(), |
| [&](Value value) { return rewriter.getRemappedValue(value); })); |
| auto shapeDims = result->getShapeDims(); |
| if (!shapeDims) return failure(); |
| auto sourceRange = source->computeRange(startIndices, *shapeDims); |
| if (!sourceRange) return failure(); |
| |
| // TODO(benvanik): slice left/mid/right, but really just don't do this. |
| auto zeroOffset = schedulingState.lookupOrCreateIndex(0, rewriter); |
| rewriter.create<IREE::HAL::CommandBufferCopyBufferOp>( |
| sliceOp.getLoc(), commandBuffer, sourceBuffer.buffer, sourceRange->offset, |
| resultBuffer.buffer, zeroOffset, sourceRange->length); |
| |
| // Full barriers for now as we aren't scheduling things. |
| // TODO(benvanik): don't add at the end of the command buffer (we could |
| // also do a canonicalization step that removed trailing barriers). |
| recordFullExecutionBarrier(commandBuffer, sliceOp.getLoc(), rewriter); |
| return success(); |
| } |
| |
| // TODO(#5410): make this an aliasing operation in allocateTransientBuffers. |
| static LogicalResult recordTensorUpdate(Value device, Value commandBuffer, |
| IREE::Flow::TensorUpdateOp &updateOp, |
| StreamSchedulingState &schedulingState, |
| ConversionPatternRewriter &rewriter) { |
| auto updateBuffer = |
| schedulingState.lookupTensorBufferRange(updateOp.update()); |
| auto targetBuffer = |
| schedulingState.lookupTensorBufferRange(updateOp.target()); |
| |
| // TODO(benvanik): use something other than the BufferRange::buffer? |
| // This may require us to subview the buffer first. |
| auto update = IREE::HAL::TensorRewriteAdaptor::getChecked( |
| updateOp.getLoc(), updateOp.update(), updateBuffer.buffer, rewriter); |
| auto target = IREE::HAL::TensorRewriteAdaptor::getChecked( |
| updateOp.getLoc(), updateOp.target(), targetBuffer.buffer, rewriter); |
| if (!update.hasValue() || !target.hasValue()) { |
| return updateOp.emitOpError() |
| << "cannot create adaptors for tensor update operands/results"; |
| } |
| |
| // Compute the size of the update range. |
| auto startIndices = llvm::to_vector<4>(llvm::map_range( |
| updateOp.start_indices(), |
| [&](Value value) { return rewriter.getRemappedValue(value); })); |
| auto shapeDims = update->getShapeDims();
| if (!shapeDims) return failure();
| auto targetRange = target->computeRange(startIndices, *shapeDims);
| if (!targetRange) return failure(); |
| |
| // TODO(benvanik): slice left/mid/right, but really just don't do this. |
| auto zeroOffset = schedulingState.lookupOrCreateIndex(0, rewriter); |
| rewriter.create<IREE::HAL::CommandBufferCopyBufferOp>( |
| updateOp.getLoc(), commandBuffer, updateBuffer.buffer, zeroOffset, |
| targetBuffer.buffer, targetRange->offset, targetRange->length); |
| |
| // Full barriers for now as we aren't scheduling things. |
| // TODO(benvanik): don't add at the end of the command buffer (we could |
| // also do a canonicalization step that removed trailing barriers). |
| recordFullExecutionBarrier(commandBuffer, updateOp.getLoc(), rewriter); |
| return success(); |
| } |
| |
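| // Clones all arith.constant ops in |streamBlock| to the rewriter's current
| // insertion point (outside of the stream) and redirects their uses so that
| // dominance holds for any size computations that reference them.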
| static void hoistConstants(Block &streamBlock, |
| ConversionPatternRewriter &rewriter) { |
| for (auto &op : streamBlock) { |
| if (isa<arith::ConstantOp>(op)) { |
| auto newOp = rewriter.clone(op); |
| op.replaceAllUsesWith(newOp); |
| } |
| } |
| } |
| |
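| // Records each op in |streamBlock| into |commandBuffer|, dispatching on op
| // type. Fails on any op that is not expected inside a stream fragment.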
| static LogicalResult recordStreamCommands( |
| Value device, Value commandBuffer, Block &streamBlock, |
| StreamSchedulingState &schedulingState, |
| ConversionPatternRewriter &rewriter) { |
| for (auto &op : streamBlock) { |
| if (auto dispatchOp = dyn_cast<IREE::Flow::DispatchOp>(op)) { |
| if (failed(recordDispatch(device, commandBuffer, dispatchOp, |
| schedulingState, rewriter))) { |
| return failure(); |
| } |
| } else if (auto splatOp = dyn_cast<IREE::Flow::TensorSplatOp>(op)) { |
| if (failed(recordTensorSplat(device, commandBuffer, splatOp, |
| schedulingState, rewriter))) { |
| return failure(); |
| } |
| } else if (auto cloneOp = dyn_cast<IREE::Flow::TensorCloneOp>(op)) { |
| if (failed(recordTensorClone(device, commandBuffer, cloneOp, |
| schedulingState, rewriter))) { |
| return failure(); |
| } |
| } else if (auto sliceOp = dyn_cast<IREE::Flow::TensorSliceOp>(op)) { |
| if (failed(recordTensorSlice(device, commandBuffer, sliceOp, |
| schedulingState, rewriter))) { |
| return failure(); |
| } |
| } else if (auto updateOp = dyn_cast<IREE::Flow::TensorUpdateOp>(op)) { |
| if (failed(recordTensorUpdate(device, commandBuffer, updateOp, |
| schedulingState, rewriter))) { |
| return failure(); |
| } |
| } else if (auto returnOp = dyn_cast<IREE::Flow::ReturnOp>(op)) { |
| // No-op; handled by the buffer allocation. |
| } else if (isa<arith::ConstantOp>(op)) { |
| // Note that even though constants were hoisted early, they can be |
| // materialized as part of various conversions so do it again to get |
| // any new ones. |
| auto newOp = rewriter.clone(op); |
| op.replaceAllUsesWith(newOp); |
| } else if (isa<IREE::HAL::ConstantSubspanOp>(op) || |
| isa<IREE::Flow::TensorReshapeOp>(op)) { |
| // No work to perform. |
| } else { |
| return op.emitOpError() << "unexpected in stream"; |
| } |
| } |
| return success(); |
| } |
| |
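| // Converts a flow.ex.stream.fragment op into an explicit HAL sequence:
| // allocate output and transient buffers, create and begin a command buffer,
| // record the stream ops into it, then end and submit. Schematically (ops
| // abbreviated):
| //   %buffer = hal.allocator.allocate ...
| //   %cmd = hal.command_buffer.create ...
| //   hal.command_buffer.begin %cmd
| //   ... recorded fills/copies/dispatches ...
| //   hal.command_buffer.end %cmd
| //   hal.ex.submit_and_wait %device, %cmd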
| class ExStreamFragmentOpConversion |
| : public OpConversionPattern<IREE::Flow::ExStreamFragmentOp> { |
| public: |
| using OpConversionPattern< |
| IREE::Flow::ExStreamFragmentOp>::OpConversionPattern; |
| LogicalResult matchAndRewrite( |
| IREE::Flow::ExStreamFragmentOp streamOp, ArrayRef<Value> newOperands, |
| ConversionPatternRewriter &rewriter) const override { |
| IREE::Flow::ExStreamFragmentOp::Adaptor adaptor( |
| newOperands, streamOp->getAttrDictionary()); |
| |
| auto valueAliases = computeValueAliases(streamOp); |
| auto livenessIntervals = computeLivenessIntervals(streamOp, valueAliases); |
| |
| auto device = |
| rewriter.createOrFold<IREE::HAL::ExSharedDeviceOp>(streamOp.getLoc()); |
| auto allocator = |
| rewriter.create<IREE::HAL::DeviceAllocatorOp>(streamOp.getLoc(), device) |
| .getResult(); |
| StreamSchedulingState schedulingState(streamOp.getLoc(), device, allocator, |
| valueAliases); |
| |
| // Map stream captures to their external buffers or SSA values. |
| // This covers all of the live-in stream values. |
| auto &entryBlock = streamOp.body().front(); |
| |
| // Since constants can be tied to shapes, which are used in the size
| // computations below, and since hoisting them is a simple RAUW transform,
| // do it first so that dominance works out when recording.
| hoistConstants(entryBlock, rewriter); |
| |
| for (int i = 0; i < adaptor.operands().size(); ++i) { |
| auto streamValue = entryBlock.getArgument(i); |
| auto bufferValue = adaptor.operands()[i]; |
| if (auto shapedType = streamValue.getType().dyn_cast<TensorType>()) { |
| BufferRange bufferRange; |
| if (bufferValue.getType().isa<IREE::HAL::BufferViewType>()) { |
| bufferRange = BufferRange{ |
| rewriter.createOrFold<IREE::HAL::BufferViewBufferOp>( |
| streamOp.getLoc(), |
| IREE::HAL::BufferType::get(rewriter.getContext()), |
| bufferValue), |
| schedulingState.lookupOrComputeSize(streamValue, rewriter)}; |
| } else { |
| bufferRange = BufferRange{ |
| bufferValue, |
| schedulingState.lookupOrComputeSize(streamValue, rewriter)}; |
| } |
| if (failed(schedulingState.mapTensorToBufferRange(streamValue, |
| bufferRange))) { |
| return streamOp.emitOpError() |
| << "tensor was mapped to multiple buffer ranges"; |
| } |
| |
| } else { |
| rewriter.replaceUsesOfBlockArgument(streamValue, bufferValue); |
| } |
| } |
| |
| // Allocate buffers for values that escape the stream via return. |
| // These may alias input buffers above such as when an input is returned or |
| // a return value is tied. |
| SmallVector<Value> outputBuffers; |
| if (failed(allocateOutputBuffers(streamOp, schedulingState, rewriter, |
| outputBuffers))) { |
| return failure(); |
| } |
| |
| // Allocate all of the transient buffers used entirely within the stream. |
| // These all end up aliased from a single slab allocation and use the |
| // computed liveness information to know the lifetime intervals. Note that |
| // after we perform this allocation we can no longer safely rearrange the |
| // ops as buffers will start to alias. All reordering must have happened |
| // prior to this conversion. |
| if (failed(allocateTransientBuffers(streamOp, livenessIntervals, |
| schedulingState, rewriter))) { |
| return failure(); |
| } |
| |
| // Allocate and begin the command buffer. |
| // In a real version we would want to pick the device based on the placement |
| // information attached to the stream. |
| // TODO(benvanik): choose buffer mode/category based on stream commands. |
| // NOTE: we are not doing any overlapping work today and can always allow |
| // inline execution. |
| auto mode = IREE::HAL::CommandBufferModeBitfield::OneShot | |
| IREE::HAL::CommandBufferModeBitfield::AllowInlineExecution; |
| auto category = IREE::HAL::CommandCategoryBitfield::Dispatch | |
| IREE::HAL::CommandCategoryBitfield::Transfer; |
| auto commandBuffer = |
| rewriter.createOrFold<IREE::HAL::CommandBufferCreateOp>( |
| streamOp.getLoc(), |
| IREE::HAL::CommandBufferType::get(rewriter.getContext()), device, |
| mode, category); |
| rewriter.create<IREE::HAL::CommandBufferBeginOp>(streamOp.getLoc(), |
| commandBuffer); |
| |
| // Record all of the commands into the command buffer. |
| if (failed(recordStreamCommands(device, commandBuffer, entryBlock, |
| schedulingState, rewriter))) { |
| return failure(); |
| } |
| |
| // End and submit the command buffer. |
| // In a real version we'd want to setup a semaphore chain instead of |
| // submitting and waiting. |
| rewriter.create<IREE::HAL::CommandBufferEndOp>(streamOp.getLoc(), |
| commandBuffer); |
| rewriter.create<IREE::HAL::ExSubmitAndWaitOp>(streamOp.getLoc(), device, |
| commandBuffer); |
| |
| // It's annoying but we need to do this replacement at the very end as |
| // otherwise we lose access to the original values (which we need for |
| // shape information). |
| for (int i = 0; i < adaptor.operands().size(); ++i) { |
| if (adaptor.operands()[i].getType().isa<IREE::HAL::BufferType>()) { |
| rewriter.replaceUsesOfBlockArgument(entryBlock.getArgument(i), |
| adaptor.operands()[i]); |
| } |
| } |
| |
| rewriter.replaceOp(streamOp, outputBuffers); |
| return success(); |
| } |
| }; |
| |
| } // namespace |
| |
| void populateFlowStreamToHALPatterns(MLIRContext *context, |
| OwningRewritePatternList &patterns, |
| TypeConverter &converter) { |
| patterns.insert<ExStreamFragmentOpConversion>(context); |
| } |
| |
| } // namespace iree_compiler |
| } // namespace mlir |