| // Copyright 2019 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // https://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" |
| #include "iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertFlowToHAL.h" |
| #include "iree/compiler/Dialect/HAL/IR/HALOps.h" |
| #include "iree/compiler/Dialect/HAL/IR/HALTypes.h" |
| #include "iree/compiler/Dialect/HAL/Target/TargetBackend.h" |
| #include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h" |
| #include "iree/compiler/Dialect/HAL/Utils/DeviceSwitchBuilder.h" |
| #include "iree/compiler/Dialect/HAL/Utils/TypeUtils.h" |
| #include "iree/compiler/Dialect/IREE/IR/IREETypes.h" |
| #include "iree/compiler/Dialect/Shape/IR/ShapeOps.h" |
| #include "llvm/ADT/DenseMap.h" |
| #include "llvm/Support/Debug.h" |
| #include "mlir/Dialect/StandardOps/IR/Ops.h" |
| #include "mlir/IR/Attributes.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/Function.h" |
| #include "mlir/IR/Matchers.h" |
| #include "mlir/IR/Module.h" |
| #include "mlir/Transforms/DialectConversion.h" |
| |
| #define DEBUG_TYPE "iree-hal" |
| |
| namespace mlir { |
| namespace iree_compiler { |
| namespace { |
| |
| struct BufferRange { |
| BufferRange() = default; |
| explicit BufferRange(Value buffer) : buffer(buffer) {} |
| |
| Value buffer = nullptr; |
| }; |
| |
| // Allocated buffers used within the stream. |
| struct BufferSet { |
| explicit BufferSet(Value allocator) : allocator(allocator) {} |
| |
| // Allocator instance the buffers come from. |
| Value allocator = nullptr; |
| |
| // All output buffers in the same order as the original results. |
| SmallVector<Value, 4> outputBuffers; |
| |
| // Maps tensor values within the stream to a buffer range that stores them. |
| DenseMap<Value, BufferRange> rangeMap; |
| }; |
| |
| // If the op does no work and has no operands/results that impact buffer |
| // assignment, the it is a no-op. |
| static bool isNoOp(Operation *op) { return isa<Shape::MakeRankedShapeOp>(op); } |
| |
| // If the op's result is an identity of its first operand, it is an |
| // identity. |
| static bool isIdentityOp(Operation *op) { return isa<Shape::TieShapeOp>(op); } |
| |
| // Allocates a buffer for the given stream output value. |
| // |streamValue| is the Value used within the stream region and |
| // |externalValue| is the returned value from the stream region in the parent |
| // block. |
| static Value allocateOutputBuffer(Value streamValue, Value externalValue, |
| Value allocator, |
| ConversionPatternRewriter &rewriter) { |
| Location loc = externalValue.getLoc(); |
| // TODO(benvanik): compute from SSA use-def chain uses. |
| IREE::HAL::MemoryTypeBitfield memoryTypes = |
| IREE::HAL::MemoryTypeBitfield::DeviceLocal | |
| IREE::HAL::MemoryTypeBitfield::HostVisible; |
| IREE::HAL::BufferUsageBitfield bufferUsage = |
| IREE::HAL::BufferUsageBitfield::All; |
| |
| // Compute the allocation size for the value. |
| auto elementType = IREE::HAL::getElementTypeValue( |
| streamValue.getType().cast<ShapedType>().getElementType()); |
| if (!elementType) { |
| return {}; |
| } |
| auto shape = IREE::HAL::getShapeDims(loc, streamValue, rewriter); |
| if (!shape) { |
| return {}; |
| } |
| auto allocationSize = rewriter |
| .create<IREE::HAL::AllocatorComputeSizeOp>( |
| loc, allocator, *shape, elementType.getValue()) |
| .getResult(); |
| |
| auto buffer = |
| rewriter |
| .create<IREE::HAL::AllocatorAllocateOp>(loc, allocator, memoryTypes, |
| bufferUsage, allocationSize) |
| .getResult(); |
| |
| return buffer; |
| } |
| |
| // Allocates all output buffers for the stream and populates the |bufferSet| |
| // with the new mappings. |
| static void allocateOutputBuffers(IREE::Flow::ExStreamFragmentOp streamOp, |
| BufferSet &bufferSet, |
| ConversionPatternRewriter &rewriter) { |
| // Allocate output buffers and replace the original uses with the buffers. |
| auto returnOp = cast<IREE::Flow::ReturnOp>(streamOp.body().front().back()); |
| for (auto result : llvm::enumerate(streamOp.getResults())) { |
| auto streamValue = returnOp.getOperand(result.index()); |
| auto externalValue = result.value(); |
| auto buffer = allocateOutputBuffer(streamValue, externalValue, |
| bufferSet.allocator, rewriter); |
| auto bufferRange = BufferRange{buffer}; |
| bufferSet.rangeMap[externalValue] = bufferRange; |
| bufferSet.rangeMap[streamValue] = bufferRange; |
| bufferSet.outputBuffers.push_back(buffer); |
| } |
| } |
| |
| // Allocates a transient buffer for use entirely within the command buffer. |
| static Value allocateTransientBuffer(Value streamValue, Value allocator, |
| ConversionPatternRewriter &rewriter) { |
| Location loc = streamValue.getLoc(); |
| // TODO(benvanik): compute from SSA use-def chain uses. |
| IREE::HAL::MemoryTypeBitfield memoryTypes = |
| IREE::HAL::MemoryTypeBitfield::DeviceLocal; |
| IREE::HAL::BufferUsageBitfield bufferUsage = |
| IREE::HAL::BufferUsageBitfield::Dispatch | |
| IREE::HAL::BufferUsageBitfield::Transfer; |
| |
| // Compute the allocation size for the value. |
| auto elementType = IREE::HAL::getElementTypeValue( |
| streamValue.getType().cast<ShapedType>().getElementType()); |
| if (!elementType) { |
| return {}; |
| } |
| auto shape = IREE::HAL::getShapeDims(loc, streamValue, rewriter); |
| if (!shape) { |
| return {}; |
| } |
| auto allocationSize = rewriter |
| .create<IREE::HAL::AllocatorComputeSizeOp>( |
| loc, allocator, *shape, elementType.getValue()) |
| .getResult(); |
| |
| auto buffer = |
| rewriter |
| .create<IREE::HAL::AllocatorAllocateOp>(loc, allocator, memoryTypes, |
| bufferUsage, allocationSize) |
| .getResult(); |
| |
| return buffer; |
| } |
| |
| // Allocates transient buffers to store the intra-stream results and populates |
| // the |bufferSet| with the new mappings. |
| static void allocateTransientBuffers(IREE::Flow::ExStreamFragmentOp streamOp, |
| BufferSet &bufferSet, |
| ConversionPatternRewriter &rewriter) { |
| LLVM_DEBUG(llvm::dbgs() << ": HAL allocateTransientBuffers: " |
| << *streamOp.getOperation() << "\n"); |
| |
| auto propagateIdentityBuffers = [&]() { |
| bool madeChange = false; |
| // Pull outputs that terminate on identities to operands. |
| for (auto &op : llvm::reverse(streamOp.body().front())) { |
| if (isIdentityOp(&op)) { |
| auto result = op.getResult(0); |
| auto operand = op.getOperand(0); |
| if (bufferSet.rangeMap[result].buffer && |
| !bufferSet.rangeMap[operand].buffer) { |
| LLVM_DEBUG(llvm::dbgs() << " + PROPAGATE IDENTITY RESULT->OPERAND: " |
| << op << "\n"); |
| madeChange = true; |
| bufferSet.rangeMap[operand].buffer = |
| bufferSet.rangeMap[result].buffer; |
| } |
| } |
| } |
| |
| // Push inputs that originate on identities to results. |
| for (auto &op : streamOp.body().front()) { |
| if (isIdentityOp(&op)) { |
| auto operand = op.getOperand(0); |
| auto result = op.getResult(0); |
| if (bufferSet.rangeMap[operand].buffer && |
| !bufferSet.rangeMap[result].buffer) { |
| LLVM_DEBUG(llvm::dbgs() << " + PROPAGATE IDENTITY OPERAND->RESULT: " |
| << op << "\n"); |
| madeChange = true; |
| bufferSet.rangeMap[result].buffer = |
| bufferSet.rangeMap[operand].buffer; |
| } |
| } |
| } |
| return madeChange; |
| }; |
| |
| // Allocate any needed transient buffers. |
| // The idea here is that every non-identity-op result needs to be assigned a |
| // buffer; however, input and output buffers are already assigned to outer |
| // operands and results (which may be on identity ops). To handle this, |
| // we first propagate all buffers across identity ops, then allocate any |
| // transient buffers on non-identity ops that are still needed. Finally, |
| // propagate across identity ops again (to account for identity ops on |
| // the interior). |
| // Because there may be runs of identity ops, propagation loops until no |
| // changes are made. |
| while (propagateIdentityBuffers()) { |
| } |
| for (auto &op : streamOp.body().front()) { |
| if (isNoOp(&op) || isIdentityOp(&op)) continue; |
| for (auto it : llvm::enumerate(op.getResults())) { |
| auto result = it.value(); |
| // If the result is an output buffer we can just use that directly. |
| if (bufferSet.rangeMap[result].buffer) { |
| LLVM_DEBUG(llvm::dbgs() << " -- SKIP ALREADY SET BUFFER RESULT(" |
| << it.index() << "): " << op << "\n"); |
| continue; |
| } |
| LLVM_DEBUG(llvm::dbgs() << " -- ALLOCATE BUFFER FOR RESULT(" |
| << it.index() << "): " << op << "\n"); |
| auto buffer = |
| allocateTransientBuffer(result, bufferSet.allocator, rewriter); |
| bufferSet.rangeMap[result] = BufferRange{buffer}; |
| } |
| } |
| while (propagateIdentityBuffers()) { |
| } |
| } |
| |
| // Records a full execution barrier that forces visibility of all buffers. |
| static void recordFullExecutionBarrier(Value commandBuffer, Location loc, |
| ConversionPatternRewriter &rewriter) { |
| auto memoryBarrier = |
| rewriter |
| .create<IREE::HAL::MakeMemoryBarrierOp>( |
| loc, IREE::HAL::AccessScopeBitfield::DispatchWrite, |
| IREE::HAL::AccessScopeBitfield::DispatchRead) |
| .getResult(); |
| rewriter.create<IREE::HAL::CommandBufferExecutionBarrierOp>( |
| loc, commandBuffer, |
| IREE::HAL::ExecutionStageBitfield::CommandRetire | |
| IREE::HAL::ExecutionStageBitfield::Dispatch, |
| IREE::HAL::ExecutionStageBitfield::CommandIssue | |
| IREE::HAL::ExecutionStageBitfield::Dispatch, |
| ArrayRef<Value>{memoryBarrier}, ArrayRef<Value>{}); |
| } |
| |
| static void recordPushConstants(Value device, Value commandBuffer, |
| IREE::Flow::DispatchOp &dispatchOp, |
| IREE::HAL::InterfaceOp &interfaceOp, |
| Value executableLayout, |
| ConversionPatternRewriter &rewriter) { |
| SmallVector<Value, 4> pushConstantValues; |
| for (auto inputValue : dispatchOp.operands()) { |
| if (inputValue.getType().isa<IndexType>() || |
| inputValue.getType().isa<IntegerType>()) { |
| auto pushConstantValue = rewriter.getRemappedValue(inputValue); |
| // Need an explicit index cast to i32 since the |
| // CommandBufferPushConstantsOp is intrinsically i32 based. |
| if (inputValue.getType().isa<IndexType>()) { |
| pushConstantValue = rewriter.create<IndexCastOp>( |
| dispatchOp.getLoc(), rewriter.getIntegerType(32), |
| pushConstantValue); |
| } |
| pushConstantValues.push_back(pushConstantValue); |
| } |
| } |
| if (pushConstantValues.empty()) { |
| return; |
| } |
| |
| uint64_t maxPushConstants = interfaceOp.push_constants().getValueOr(0); |
| (void)maxPushConstants; |
| assert(pushConstantValues.size() <= maxPushConstants && |
| "uniform buffer spilling not yet implemented"); |
| |
| rewriter.create<IREE::HAL::CommandBufferPushConstantsOp>( |
| dispatchOp.getLoc(), commandBuffer, executableLayout, |
| rewriter.getI32IntegerAttr(0), pushConstantValues); |
| } |
| |
| static LogicalResult recordPushBindings(Value device, Value commandBuffer, |
| IREE::Flow::DispatchOp &dispatchOp, |
| Value executableLayout, |
| BufferSet &bufferSet, |
| ConversionPatternRewriter &rewriter) { |
| LLVM_DEBUG(llvm::dbgs() << "HAL recordPushBindings: " |
| << *dispatchOp.getOperation() << "\n"); |
| uint32_t setOrdinal = 0; |
| uint32_t bindingOrdinal = 0; |
| SmallVector<IREE::HAL::DescriptorSetBindingValue, 4> bindings; |
| auto zeroOffset = |
| rewriter.createOrFold<mlir::ConstantIndexOp>(dispatchOp.getLoc(), 0); |
| auto pushBinding = [&](Value tensorValue) -> LogicalResult { |
| auto &bufferRange = bufferSet.rangeMap[tensorValue]; |
| assert(bufferRange.buffer && "buffer not preallocated"); |
| auto value = IREE::HAL::TensorRewriteAdaptor::getChecked( |
| dispatchOp.getLoc(), tensorValue, bufferRange.buffer, rewriter); |
| if (!value.hasValue()) { |
| return dispatchOp.emitOpError() << "cannot create adaptor for tensor"; |
| } |
| auto byteLength = value->getByteLength(); |
| if (!byteLength) return failure(); |
| |
| bindings.push_back(std::make_tuple(bindingOrdinal++, value->getBuffer(), |
| zeroOffset, byteLength)); |
| return success(); |
| }; |
| for (auto it : llvm::enumerate(dispatchOp.operands())) { |
| LLVM_DEBUG(llvm::dbgs() |
| << " + OPERAND(" << it.index() << "): " << it.value() << "\n"); |
| if (it.value().getType().isa<TensorType>()) { |
| if (failed(pushBinding(it.value()))) { |
| return failure(); |
| } |
| } |
| } |
| for (auto it : llvm::enumerate(dispatchOp.results())) { |
| LLVM_DEBUG(llvm::dbgs() |
| << " + RESULT(" << it.index() << "): " << it.value() << "\n"); |
| if (failed(pushBinding(it.value()))) { |
| return failure(); |
| } |
| } |
| rewriter.create<IREE::HAL::CommandBufferPushDescriptorSetOp>( |
| dispatchOp.getLoc(), commandBuffer, executableLayout, setOrdinal, |
| bindings); |
| return success(); |
| } |
| |
| // Records a dispatch operation. |
| static LogicalResult recordDispatch(Value device, Value commandBuffer, |
| IREE::Flow::DispatchOp &dispatchOp, |
| BufferSet &bufferSet, |
| ConversionPatternRewriter &rewriter) { |
| // Get the handle to the executable that is compatible with our device. |
| auto executableOp = |
| cast<IREE::HAL::ExecutableOp>(SymbolTable::lookupNearestSymbolFrom( |
| dispatchOp, dispatchOp.executable())); |
| |
| // TODO(benvanik): support multiple interfaces. We'd probably want to |
| // store each executable+interface as a variable. |
| auto interfaceOp = executableOp.getInterfaceOp(); |
| auto executableLayout = |
| rewriter.createOrFold<IREE::HAL::ExecutableLayoutLookupOp>( |
| dispatchOp.getLoc(), |
| IREE::HAL::ExecutableLayoutType::get(device.getContext()), device, |
| interfaceOp.getExecutableSetLayoutsAttr(), |
| interfaceOp.push_constantsAttr()); |
| |
| // Setup push constants for any dynamic values we need to pass across at |
| // runtime. |
| recordPushConstants(device, commandBuffer, dispatchOp, interfaceOp, |
| executableLayout, rewriter); |
| |
| // Setup bindings, right now pushed immediately but soon to be replaced |
| // with descriptor sets (or something better, anyway). |
| if (failed(recordPushBindings(device, commandBuffer, dispatchOp, |
| executableLayout, bufferSet, rewriter))) { |
| return failure(); |
| } |
| // Marshal tensor operands/results in to the state so that backends can |
| // read/write them as they need. |
| SmallVector<Optional<IREE::HAL::TensorRewriteAdaptor>, 4> operandAdaptors; |
| for (int i = 0; i < dispatchOp.getNumOperands(); ++i) { |
| auto value = dispatchOp.getOperand(i); |
| if (!value.getType().isa<TensorType>()) continue; |
| auto &bufferRange = bufferSet.rangeMap[value]; |
| assert(bufferRange.buffer && "operand buffer not allocated"); |
| auto adaptor = IREE::HAL::TensorRewriteAdaptor::getChecked( |
| dispatchOp.getLoc(), value, bufferRange.buffer, rewriter); |
| if (!adaptor.hasValue()) { |
| return dispatchOp.emitOpError() |
| << "cannot create adaptor for tensor operand"; |
| } |
| operandAdaptors.emplace_back(adaptor.getValue()); |
| } |
| |
| SmallVector<Optional<IREE::HAL::TensorRewriteAdaptor>, 4> resultAdaptors; |
| for (int i = 0; i < dispatchOp.getNumResults(); ++i) { |
| auto value = dispatchOp.getResult(i); |
| if (!value.getType().isa<TensorType>()) continue; |
| auto &bufferRange = bufferSet.rangeMap[value]; |
| assert(bufferRange.buffer && "result buffer not preallocated"); |
| auto adaptor = IREE::HAL::TensorRewriteAdaptor::getChecked( |
| dispatchOp.getLoc(), value, bufferRange.buffer, rewriter); |
| if (!adaptor.hasValue()) { |
| return dispatchOp.emitOpError() |
| << "cannot create adaptor for tensor result"; |
| ; |
| } |
| resultAdaptors.emplace_back(adaptor.getValue()); |
| } |
| |
| IREE::HAL::TargetBackend::DispatchState dispatchState; |
| dispatchState.dispatchOp = dispatchOp; |
| dispatchState.executableOp = executableOp; |
| dispatchState.device = device; |
| dispatchState.commandBuffer = commandBuffer; |
| dispatchState.executableLayout = executableLayout; |
| dispatchState.workload = rewriter.getRemappedValue(dispatchOp.workload()); |
| // TODO(benvanik): support extended push constants. |
| dispatchState.basePushConstantOffset = 0; |
| dispatchState.operands = operandAdaptors; |
| dispatchState.results = resultAdaptors; |
| |
| // Ask each target backend to record their dispatch logic. |
| IREE::HAL::DeviceSwitchBuilder switchBuilder(dispatchOp.getLoc(), |
| /*resultTypes=*/TypeRange{}, |
| device, rewriter); |
| for (auto targetOp : |
| executableOp.getBlock().getOps<IREE::HAL::ExecutableTargetOp>()) { |
| for (auto &targetBackend : IREE::HAL::matchTargetBackends( |
| {targetOp.target_backend_filter().str()})) { |
| auto entryPointOps = |
| targetOp.getBlock().getOps<IREE::HAL::ExecutableEntryPointOp>(); |
| if (entryPointOps.empty()) { |
| return dispatchOp.emitOpError() << "need at least one entry point"; |
| } |
| // Use the first (possibly only) entry point op. If the target split the |
| // original entry point into multiple entry points then it should |
| // sequence them together during the call to |recordDispatch| below. |
| dispatchState.entryPointOp = *entryPointOps.begin(); |
| |
| if (failed(targetBackend->recordDispatch(dispatchOp.getLoc(), |
| dispatchState, switchBuilder))) { |
| return dispatchOp.emitError() |
| << "unable to record dispatch for target backend " |
| << targetBackend->name(); |
| } |
| } |
| } |
| switchBuilder.build(); |
| |
| // Full barriers for now as we aren't scheduling things. |
| // TODO(benvanik): don't add at the end of the command buffer (we could |
| // also do a canonicalization step that removed trailing barriers). |
| recordFullExecutionBarrier(commandBuffer, dispatchOp.getLoc(), rewriter); |
| return success(); |
| } |
| |
| static LogicalResult recordTensorUpdate(Value device, Value commandBuffer, |
| IREE::Flow::TensorUpdateOp &updateOp, |
| BufferSet &bufferSet, |
| ConversionPatternRewriter &rewriter) { |
| auto &updateBuffer = bufferSet.rangeMap[updateOp.update()]; |
| auto &targetBuffer = bufferSet.rangeMap[updateOp.target()]; |
| auto &resultBuffer = bufferSet.rangeMap[updateOp.result()]; |
| |
| // TODO(benvanik): use something other than the BufferRange::buffer? |
| // This may require us to subview the buffer first. |
| auto update = IREE::HAL::TensorRewriteAdaptor::getChecked( |
| updateOp.getLoc(), updateOp.update(), updateBuffer.buffer, rewriter); |
| auto target = IREE::HAL::TensorRewriteAdaptor::getChecked( |
| updateOp.getLoc(), updateOp.target(), targetBuffer.buffer, rewriter); |
| auto result = IREE::HAL::TensorRewriteAdaptor::getChecked( |
| updateOp.getLoc(), updateOp.result(), resultBuffer.buffer, rewriter); |
| if (!update.hasValue() || !target.hasValue() || !result.hasValue()) { |
| return updateOp.emitOpError() |
| << "cannot create adaptors for tensor update operands/results"; |
| } |
| |
| auto zeroOffset = |
| rewriter.createOrFold<mlir::ConstantIndexOp>(updateOp.getLoc(), 0); |
| |
| // Compute the size of the update range. |
| auto startIndices = llvm::to_vector<4>(llvm::map_range( |
| updateOp.start_indices(), |
| [&](Value value) { return rewriter.getRemappedValue(value); })); |
| auto shapeDims = update->getShapeDims(); |
| if (!shapeDims) return failure(); |
| auto targetRange = |
| target->computeRange(startIndices, *update->getShapeDims()); |
| if (!targetRange) return failure(); |
| |
| // TODO(benvanik): actual buffer allocation so we aren't doing this copy. |
| auto targetByteLength = target->getByteLength(); |
| if (!targetByteLength) return failure(); |
| |
| rewriter.create<IREE::HAL::CommandBufferCopyBufferOp>( |
| updateOp.getLoc(), commandBuffer, target->getBuffer(), zeroOffset, |
| result->getBuffer(), zeroOffset, targetByteLength); |
| // TODO(benvanik): slice left/mid/right, but really just don't do this. |
| recordFullExecutionBarrier(commandBuffer, updateOp.getLoc(), rewriter); |
| rewriter.create<IREE::HAL::CommandBufferCopyBufferOp>( |
| updateOp.getLoc(), commandBuffer, update->getBuffer(), zeroOffset, |
| result->getBuffer(), targetRange->offset, targetRange->length); |
| |
| // Full barriers for now as we aren't scheduling things. |
| // TODO(benvanik): don't add at the end of the command buffer (we could |
| // also do a canonicalization step that removed trailing barriers). |
| recordFullExecutionBarrier(commandBuffer, updateOp.getLoc(), rewriter); |
| return success(); |
| } |
| |
| static LogicalResult recordStreamCommands(Value device, Value commandBuffer, |
| Block &streamBlock, |
| BufferSet &bufferSet, |
| ConversionPatternRewriter &rewriter) { |
| for (auto &op : streamBlock) { |
| if (auto dispatchOp = dyn_cast<IREE::Flow::DispatchOp>(op)) { |
| if (failed(recordDispatch(device, commandBuffer, dispatchOp, bufferSet, |
| rewriter))) { |
| return failure(); |
| } |
| } else if (auto updateOp = dyn_cast<IREE::Flow::TensorUpdateOp>(op)) { |
| if (failed(recordTensorUpdate(device, commandBuffer, updateOp, bufferSet, |
| rewriter))) { |
| return failure(); |
| } |
| } else if (auto returnOp = dyn_cast<IREE::Flow::ReturnOp>(op)) { |
| // No-op; handled by the buffer allocation. |
| } else if (isNoOp(&op) || isIdentityOp(&op)) { |
| // No work to perform. For identity ops, all buffers have been pushed |
| // to "real" ops. |
| } else { |
| return op.emitOpError() << "unexpected in stream"; |
| } |
| } |
| return success(); |
| } |
| |
| class ExStreamFragmentOpConversion |
| : public OpConversionPattern<IREE::Flow::ExStreamFragmentOp> { |
| public: |
| using OpConversionPattern< |
| IREE::Flow::ExStreamFragmentOp>::OpConversionPattern; |
| LogicalResult matchAndRewrite( |
| IREE::Flow::ExStreamFragmentOp streamOp, llvm::ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| // TODO(benvanik): choose buffer mode/category based on stream commands. |
| auto mode = IREE::HAL::CommandBufferModeBitfield::OneShot; |
| auto category = IREE::HAL::CommandCategoryBitfield::Dispatch | |
| IREE::HAL::CommandCategoryBitfield::Transfer; |
| |
| // We'll use this buffer set to track the original and converted tensors |
| // and buffers during conversion. Ideally we'd run some fancy allocation |
| // analysis first to produce it. |
| auto device = |
| rewriter.createOrFold<IREE::HAL::ExSharedDeviceOp>(streamOp.getLoc()); |
| auto allocator = |
| rewriter.create<IREE::HAL::DeviceAllocatorOp>(streamOp.getLoc(), device) |
| .getResult(); |
| BufferSet bufferSet{allocator}; |
| |
| // Remap non-tensor operands (such as workloads). |
| auto &entryBlock = streamOp.body().front(); |
| for (int i = 0; i < operands.size(); ++i) { |
| if (streamOp.getOperand(i).getType().isa<TensorType>()) { |
| bufferSet.rangeMap[entryBlock.getArgument(i)] = |
| BufferRange{operands[i]}; |
| } else { |
| rewriter.replaceUsesOfBlockArgument(entryBlock.getArgument(i), |
| operands[i]); |
| } |
| } |
| |
| // Allocate buffers for outputs and transient buffers. |
| allocateOutputBuffers(streamOp, bufferSet, rewriter); |
| allocateTransientBuffers(streamOp, bufferSet, rewriter); |
| |
| // Allocate and begin the command buffer. |
| // In a real version we would want to pick the device based on the placement |
| // information attached to the stream. |
| auto commandBuffer = |
| rewriter.createOrFold<IREE::HAL::CommandBufferCreateOp>( |
| streamOp.getLoc(), device, mode, category); |
| rewriter.create<IREE::HAL::CommandBufferBeginOp>(streamOp.getLoc(), |
| commandBuffer); |
| |
| // Record all of the commands into the command buffer. |
| if (failed(recordStreamCommands(device, commandBuffer, entryBlock, |
| bufferSet, rewriter))) { |
| return failure(); |
| } |
| |
| // End and submit the command buffer. |
| // In a real version we'd want to setup a semaphore chain instead of |
| // submitting and waiting. |
| rewriter.create<IREE::HAL::CommandBufferEndOp>(streamOp.getLoc(), |
| commandBuffer); |
| rewriter.create<IREE::HAL::ExSubmitAndWaitOp>(streamOp.getLoc(), device, |
| commandBuffer); |
| |
| // It's annoying, but we need to do this replacement at the very end as |
| // otherwise we lose access to the original values (which we need for |
| // shape information). |
| for (int i = 0; i < operands.size(); ++i) { |
| if (operands[i].getType().isa<IREE::HAL::BufferType>()) { |
| rewriter.replaceUsesOfBlockArgument(entryBlock.getArgument(i), |
| operands[i]); |
| } |
| } |
| |
| rewriter.replaceOp(streamOp, bufferSet.outputBuffers); |
| return success(); |
| } |
| }; |
| |
| } // namespace |
| |
| void populateFlowStreamToHALPatterns(MLIRContext *context, |
| OwningRewritePatternList &patterns, |
| TypeConverter &converter) { |
| patterns.insert<ExStreamFragmentOpConversion>(context); |
| } |
| |
| } // namespace iree_compiler |
| } // namespace mlir |