blob: a77a52f87383315e0714aa50229db88534b42a9b [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/HAL/Conversion/StreamToHAL/ConvertStreamToHAL.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "iree/compiler/Dialect/HAL/Utils/DeviceSwitchBuilder.h"
#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace iree_compiler {
namespace {
// Builds (or folds to) a 32-bit integer constant at |loc| holding the HAL
// element type value that corresponds to |elementType|.
// Asserts if the element type has no HAL element type value.
static Value makeElementType(Location loc, Type elementType,
                             OpBuilder &builder) {
  auto encodedElementType = IREE::HAL::getElementTypeValue(elementType);
  assert(encodedElementType.hasValue() &&
         "unhandled element type for allocation");
  return builder.createOrFold<arith::ConstantIntOp>(
      loc, encodedElementType.getValue(), 32);
}
// Builds (or folds to) a 32-bit integer constant at |loc| holding the HAL
// encoding type value that corresponds to the |encodingType| attribute.
// Asserts if the encoding has no HAL encoding type value.
static Value makeEncodingType(Location loc, Attribute encodingType,
                              OpBuilder &builder) {
  auto encodedEncodingType = IREE::HAL::getEncodingTypeValue(encodingType);
  assert(encodedEncodingType.hasValue() &&
         "unhandled encoding type for allocation");
  return builder.createOrFold<arith::ConstantIntOp>(
      loc, encodedEncodingType.getValue(), 32);
}
// Creates an op at the current insertion point of |builder| that yields the
// device to use for |op|. Only the location of |op| is consulted today.
// TODO(benvanik): make this do multi-device lookup and other fancy things.
static Value lookupDeviceFor(Operation *op, OpBuilder &builder) {
  return builder.create<IREE::HAL::ExSharedDeviceOp>(op->getLoc()).result();
}
// Creates ops at the current insertion point of |builder| that look up the
// device for |op| and then query that device for its allocator.
static Value lookupAllocatorFor(Operation *op, OpBuilder &builder) {
  Value device = lookupDeviceFor(op, builder);
  return builder.create<IREE::HAL::DeviceAllocatorOp>(op->getLoc(), device)
      .result();
}
// Recursively walks every op in |region| (and all nested regions) and
// accumulates the command categories they require: stream.cmd.dispatch ops
// contribute Dispatch and every other op contributes Transfer.
static IREE::HAL::CommandCategoryBitfield deriveCommandCategories(
    Region &region) {
  auto categories = IREE::HAL::CommandCategoryBitfield::None;
  for (Operation &nestedOp : region.getOps()) {
    categories = categories | (isa<IREE::Stream::CmdDispatchOp>(nestedOp)
                                   ? IREE::HAL::CommandCategoryBitfield::Dispatch
                                   : IREE::HAL::CommandCategoryBitfield::Transfer);
    for (Region &childRegion : nestedOp.getRegions()) {
      categories = categories | deriveCommandCategories(childRegion);
    }
  }
  return categories;
}
// Derives the HAL memory types and buffer usage that a buffer *must* have in
// order to back a resource of |resourceType| for its defined lifetime within
// the program. Both |memoryTypes| and |bufferUsage| are reset to None before
// any bits are added.
//
// Emits an error at |loc| and returns failure when the resource lifetime has no
// direct mapping to HAL bits.
static LogicalResult deriveRequiredResourceBufferBits(
    Location loc, IREE::Stream::ResourceType resourceType,
    IREE::HAL::MemoryTypeBitfield &memoryTypes,
    IREE::HAL::BufferUsageBitfield &bufferUsage) {
  using MemoryBits = IREE::HAL::MemoryTypeBitfield;
  using UsageBits = IREE::HAL::BufferUsageBitfield;
  memoryTypes = MemoryBits::None;
  bufferUsage = UsageBits::None;
  const auto lifetime = resourceType.getLifetime();
  switch (lifetime) {
    case IREE::Stream::Lifetime::Constant:
      // Device local; copies required to get into external resources.
      memoryTypes = memoryTypes | MemoryBits::DeviceLocal;
      bufferUsage = bufferUsage | UsageBits::Constant;
      break;
    case IREE::Stream::Lifetime::Variable:
      // Device local; copies required to get into external resources.
      memoryTypes = memoryTypes | MemoryBits::DeviceLocal;
      break;
    case IREE::Stream::Lifetime::External:
      // Only device visibility is required for external buffers as the host
      // side of the program does nothing else with them today. They may be
      // mappable for user convenience and would ideally live in device-local
      // memory, but device-visible is sufficient for correct execution.
      memoryTypes = memoryTypes | MemoryBits::DeviceVisible;
      break;
    case IREE::Stream::Lifetime::Staging:
      // Host local; copies required to get into device resources.
      // This could vary based on staging usage (upload/download) by using
      // device-local|host-visible, but host-local gives a better chance of
      // being able to map the buffer during uploads.
      memoryTypes =
          memoryTypes | MemoryBits::HostLocal | MemoryBits::DeviceVisible;
      bufferUsage = bufferUsage | UsageBits::Transfer | UsageBits::Mapping;
      break;
    case IREE::Stream::Lifetime::Transient:
      // Device local; copies required to get into external resources.
      memoryTypes =
          memoryTypes | MemoryBits::DeviceLocal | MemoryBits::Transient;
      break;
    default:
      return mlir::emitError(loc)
             << "unsupported resource lifetime: "
             << IREE::Stream::stringifyLifetime(resourceType.getLifetime());
  }
  // All supported lifetimes may be used by both transfers and dispatches.
  // TODO(benvanik): refine usage based on analysis.
  bufferUsage = bufferUsage | UsageBits::Transfer | UsageBits::Dispatch;
  return success();
}
// Derives the HAL memory types and buffer usage that a buffer backing a
// resource of |resourceType| is *allowed* to have. This is the required set
// (see deriveRequiredResourceBufferBits) plus extra bits for lifetimes whose
// buffers are handed back to users via the ABI and may be used for more than
// what the program itself needs.
//
// Returns failure if the required bits cannot be derived.
static LogicalResult deriveAllowedResourceBufferBits(
    Location loc, IREE::Stream::ResourceType resourceType,
    IREE::HAL::MemoryTypeBitfield &memoryTypes,
    IREE::HAL::BufferUsageBitfield &bufferUsage) {
  memoryTypes = IREE::HAL::MemoryTypeBitfield::None;
  bufferUsage = IREE::HAL::BufferUsageBitfield::None;
  if (failed(deriveRequiredResourceBufferBits(loc, resourceType, memoryTypes,
                                              bufferUsage))) {
    return failure();
  }
  // Only external resources get additional bits beyond the required set.
  if (resourceType.getLifetime() == IREE::Stream::Lifetime::External) {
    // #yolo; these come from/go to outside the program.
    // They are assumed to be device-local|host-visible today for practical
    // purposes, though that need not hold. Ideally this would be analyzed and
    // handled on the edges (transferring between devices/etc if needed).
    memoryTypes = memoryTypes | IREE::HAL::MemoryTypeBitfield::DeviceLocal |
                  IREE::HAL::MemoryTypeBitfield::HostVisible;
    // NOTE: the program may not map the buffer but users may once they get it
    // back. This is another reason to annotate this: making a buffer mappable
    // is potentially expensive (may get a 2nd copy in memory!).
    bufferUsage = bufferUsage | IREE::HAL::BufferUsageBitfield::Mapping;
  }
  return success();
}
// Tracks which !hal.command_buffer SSA value is being recorded into for each
// stream op nested within a stream.cmd.execute op. Shared between conversion
// patterns so that nested stream.cmd.* ops can find their command buffer.
class StreamConversionMapping {
 public:
  // Maps the stream dialect |executeOp| to the hal dialect |commandBuffer|
  // value used during recording. Patterns can use this to find the SSA value
  // they need to make hal.command_buffer.* ops.
  //
  // Asserts (in debug builds) that |executeOp| has not already been mapped.
  void mapCommandBuffer(IREE::Stream::CmdExecuteOp executeOp,
                        Value commandBuffer) {
    // NOTE: the insertion must live outside of the assert expression; anything
    // inside assert(...) is compiled out when NDEBUG is defined.
    const bool didInsert =
        commandBuffers.insert(std::make_pair(executeOp, commandBuffer)).second;
    assert(didInsert &&
           "multiple command buffers cannot be registered for the same op");
    (void)didInsert;  // Silence unused-variable warnings in NDEBUG builds.
    // Map all ops nested within the command buffer so we can query later.
    executeOp.walk([&](Operation *op) {
      commandBuffers.insert(std::make_pair(op, commandBuffer));
      return WalkResult::advance();
    });
  }

  // Looks up a mapped command buffer SSA value that can be used by the given
  // stream.cmd.* op. Asserts if |cmdOp| was never registered through
  // mapCommandBuffer.
  Value lookupCommandBufferFor(Operation *cmdOp) const {
    auto it = commandBuffers.find(cmdOp);
    assert(it != commandBuffers.end() &&
           "command buffer must have been registered during conversion");
    return it->second;
  }

 private:
  // Ops within stream.cmd.execute ops -> !hal.command_buffer.
  DenseMap<Operation *, Value> commandBuffers;
};
// Base for all stream->HAL op conversion patterns. In addition to the usual
// OpConversionPattern state it carries a shared StreamConversionMapping that
// derived patterns read and write through the |mapping| member.
template <typename OpT>
struct StreamConversionPattern : public OpConversionPattern<OpT> {
  StreamConversionPattern(
      std::shared_ptr<StreamConversionMapping> sharedMapping,
      TypeConverter &typeConverter, MLIRContext *context,
      PatternBenefit benefit = 1)
      : OpConversionPattern<OpT>(typeConverter, context, benefit),
        mapping(std::move(sharedMapping)) {}

  // Conversion state shared across all pattern instances.
  std::shared_ptr<StreamConversionMapping> mapping;
};
// Converts stream.resource.alloc into one allocator allocation per result,
// using the allowed memory type/usage bits of each result's resource type.
struct ResourceAllocOpPattern
    : public StreamConversionPattern<IREE::Stream::ResourceAllocOp> {
  using StreamConversionPattern::StreamConversionPattern;
  LogicalResult matchAndRewrite(
      IREE::Stream::ResourceAllocOp allocOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    Value allocator = lookupAllocatorFor(allocOp, rewriter);
    auto bufferType = rewriter.getType<IREE::HAL::BufferType>();

    // Each (result, storage size) pair becomes its own allocation.
    SmallVector<Value> buffers;
    for (auto resultAndSize :
         llvm::zip(allocOp.results(), allocOp.storage_sizes())) {
      auto resourceType = std::get<0>(resultAndSize)
                              .getType()
                              .cast<IREE::Stream::ResourceType>();
      Value storageSize = std::get<1>(resultAndSize);

      auto memoryTypes = IREE::HAL::MemoryTypeBitfield::None;
      auto bufferUsage = IREE::HAL::BufferUsageBitfield::None;
      if (failed(deriveAllowedResourceBufferBits(
              allocOp.getLoc(), resourceType, memoryTypes, bufferUsage))) {
        return failure();
      }

      buffers.push_back(rewriter
                            .create<IREE::HAL::AllocatorAllocateOp>(
                                allocOp.getLoc(), bufferType, allocator,
                                memoryTypes, bufferUsage, storageSize)
                            .result());
    }

    rewriter.replaceOp(allocOp, buffers);
    return success();
  }
};
// Converts stream.resource.alloca into an immediate allocator allocation of a
// device-local transient buffer. The result timepoint is replaced by a
// constant index 0 as stream-ordered allocations are not yet implemented.
struct ResourceAllocaOpPattern
    : public StreamConversionPattern<IREE::Stream::ResourceAllocaOp> {
  using StreamConversionPattern::StreamConversionPattern;
  LogicalResult matchAndRewrite(
      IREE::Stream::ResourceAllocaOp allocaOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    auto loc = allocaOp.getLoc();
    Value allocator = lookupAllocatorFor(allocaOp, rewriter);

    // Transient allocations are device-local. Copies are required to get
    // their contents back on the host/another device.
    auto memoryTypes = IREE::HAL::MemoryTypeBitfield::DeviceLocal |
                       IREE::HAL::MemoryTypeBitfield::Transient;

    // TODO(benvanik): refine usage.
    // By construction transient buffers are not host visible and can only be
    // used by device commands. This could be narrowed to just dispatch or just
    // transfer.
    auto bufferUsage = IREE::HAL::BufferUsageBitfield::Dispatch |
                       IREE::HAL::BufferUsageBitfield::Transfer;

    Value buffer = rewriter
                       .create<IREE::HAL::AllocatorAllocateOp>(
                           loc, rewriter.getType<IREE::HAL::BufferType>(),
                           allocator, memoryTypes, bufferUsage,
                           allocaOp.storage_size())
                       .result();

    // TODO(benvanik): stream ordered allocations.
    Value resolvedTimepoint =
        rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();

    rewriter.replaceOp(allocaOp, {buffer, resolvedTimepoint});
    return success();
  }
};
// Converts stream.resource.dealloca by dropping the deallocation and replacing
// its result timepoint with a constant index 0.
struct ResourceDeallocaOpPattern
    : public StreamConversionPattern<IREE::Stream::ResourceDeallocaOp> {
  using StreamConversionPattern::StreamConversionPattern;
  LogicalResult matchAndRewrite(
      IREE::Stream::ResourceDeallocaOp deallocaOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // TODO(benvanik): stream ordered allocations.
    Value resolvedTimepoint =
        rewriter.create<arith::ConstantIndexOp>(deallocaOp.getLoc(), 0)
            .getResult();
    rewriter.replaceOp(deallocaOp, {resolvedTimepoint});
    return success();
  }
};
// Converts stream.resource.size into a buffer length query on the converted
// operand, producing an index.
struct ResourceSizeOpPattern
    : public StreamConversionPattern<IREE::Stream::ResourceSizeOp> {
  using StreamConversionPattern::StreamConversionPattern;
  LogicalResult matchAndRewrite(
      IREE::Stream::ResourceSizeOp sizeOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    auto indexType = rewriter.getIndexType();
    rewriter.replaceOpWithNewOp<IREE::HAL::BufferLengthOp>(sizeOp, indexType,
                                                           adaptor.operand());
    return success();
  }
};
// Converts stream.resource.map into an allocator map op that produces a
// host-local, device-visible buffer usable for mapping and transfers.
struct ResourceMapOpPattern
    : public StreamConversionPattern<IREE::Stream::ResourceMapOp> {
  using StreamConversionPattern::StreamConversionPattern;
  LogicalResult matchAndRewrite(
      IREE::Stream::ResourceMapOp mapOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    Value allocator = lookupAllocatorFor(mapOp, rewriter);

    // The result is known to be a staging buffer. Usage could be refined by
    // checking whether this is an upload or a download.
    auto memoryTypes = IREE::HAL::MemoryTypeBitfield::HostLocal |
                       IREE::HAL::MemoryTypeBitfield::DeviceVisible;
    auto bufferUsage = IREE::HAL::BufferUsageBitfield::Mapping |
                       IREE::HAL::BufferUsageBitfield::Transfer;

    rewriter.replaceOpWithNewOp<IREE::HAL::AllocatorMapOp>(
        mapOp, rewriter.getType<IREE::HAL::BufferType>(), allocator,
        memoryTypes, bufferUsage, adaptor.source(), adaptor.source_offset(),
        adaptor.result_size());
    return success();
  }
};
// Converts stream.resource.try_map into an allocator try-map op. Only constant
// and staging result lifetimes are supported; any other lifetime produces an
// op error.
struct ResourceTryMapOpPattern
    : public StreamConversionPattern<IREE::Stream::ResourceTryMapOp> {
  using StreamConversionPattern::StreamConversionPattern;
  LogicalResult matchAndRewrite(
      IREE::Stream::ResourceTryMapOp tryMapOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    using MemoryBits = IREE::HAL::MemoryTypeBitfield;
    using UsageBits = IREE::HAL::BufferUsageBitfield;

    Value allocator = lookupAllocatorFor(tryMapOp, rewriter);
    auto resultResourceType =
        tryMapOp.result().getType().cast<IREE::Stream::ResourceType>();
    auto bufferType = rewriter.getType<IREE::HAL::BufferType>();

    auto memoryTypes = MemoryBits::None;
    auto bufferUsage = UsageBits::None;
    switch (resultResourceType.getLifetime()) {
      case IREE::Stream::Lifetime::Constant:
        // Device local; copies required to get into external resources.
        memoryTypes = memoryTypes | MemoryBits::DeviceLocal;
        bufferUsage = bufferUsage | UsageBits::Constant;
        break;
      case IREE::Stream::Lifetime::Staging:
        // Host local; copies required to get into device resources.
        // This could vary based on staging usage (upload/download) by using
        // device-local|host-visible, but host-local gives a better chance of
        // being able to map the buffer during uploads.
        memoryTypes =
            memoryTypes | MemoryBits::HostLocal | MemoryBits::DeviceVisible;
        bufferUsage = bufferUsage | UsageBits::Transfer | UsageBits::Mapping;
        break;
      default:
        return tryMapOp.emitOpError()
               << "unsupported resource lifetime: "
               << IREE::Stream::stringifyLifetime(
                      resultResourceType.getLifetime());
    }

    // TODO(benvanik): refine usage based on analysis.
    bufferUsage = bufferUsage | UsageBits::Transfer | UsageBits::Dispatch;

    // The new op yields an i1 alongside the buffer.
    rewriter.replaceOpWithNewOp<IREE::HAL::AllocatorTryMapOp>(
        tryMapOp, rewriter.getI1Type(), bufferType, allocator, memoryTypes,
        bufferUsage, adaptor.source(), adaptor.source_offset(),
        adaptor.result_size());
    return success();
  }
};
// Converts stream.resource.load into a buffer load of the type-converted
// result type from the converted source at the given offset.
struct ResourceLoadOpPattern
    : public StreamConversionPattern<IREE::Stream::ResourceLoadOp> {
  using StreamConversionPattern::StreamConversionPattern;
  LogicalResult matchAndRewrite(
      IREE::Stream::ResourceLoadOp loadOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    Type originalType = loadOp.result().getType();
    auto convertedType = getTypeConverter()->convertType(originalType);
    rewriter.replaceOpWithNewOp<IREE::HAL::BufferLoadOp>(
        loadOp, convertedType, adaptor.source(), adaptor.source_offset());
    return success();
  }
};
// Converts stream.resource.store into a buffer store. The stream op's result
// is replaced with the converted target buffer that was stored into.
struct ResourceStoreOpPattern
    : public StreamConversionPattern<IREE::Stream::ResourceStoreOp> {
  using StreamConversionPattern::StreamConversionPattern;
  LogicalResult matchAndRewrite(
      IREE::Stream::ResourceStoreOp storeOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    Value targetBuffer = adaptor.target();
    rewriter.create<IREE::HAL::BufferStoreOp>(
        storeOp.getLoc(), adaptor.value(), targetBuffer,
        adaptor.target_offset());
    rewriter.replaceOp(storeOp, targetBuffer);
    return success();
  }
};
// Converts stream.resource.subview into a buffer subspan over the converted
// source at the given offset and size.
struct ResourceSubviewOpPattern
    : public StreamConversionPattern<IREE::Stream::ResourceSubviewOp> {
  using StreamConversionPattern::StreamConversionPattern;
  LogicalResult matchAndRewrite(
      IREE::Stream::ResourceSubviewOp subviewOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // NOTE: the subspan aliases its source! All useful alias analysis is
    // assumed to have already happened so losing the tie information is ok.
    rewriter.replaceOpWithNewOp<IREE::HAL::BufferSubspanOp>(
        subviewOp, rewriter.getType<IREE::HAL::BufferType>(),
        adaptor.source(), adaptor.source_offset(), adaptor.result_size());
    return success();
  }
};
// Emits a buffer assertion op at |loc| checking that |buffer| is usable as a
// resource of |resourceType|: it must be compatible with |allocator|, be at
// least |minimumLength| in size (additional padding is ok), and carry the
// memory type and usage bits required for the resource. |message| is attached
// to the assertion op.
//
// Returns failure if the required bits cannot be derived for |resourceType|.
static LogicalResult buildStorageAssertions(
    Location loc, Value buffer, StringAttr message, Value allocator,
    Value minimumLength, IREE::Stream::ResourceType resourceType,
    OpBuilder &builder) {
  auto requiredMemoryTypes = IREE::HAL::MemoryTypeBitfield::None;
  auto requiredBufferUsage = IREE::HAL::BufferUsageBitfield::None;
  if (failed(deriveRequiredResourceBufferBits(
          loc, resourceType, requiredMemoryTypes, requiredBufferUsage))) {
    return failure();
  }

  auto *context = builder.getContext();
  auto memoryTypesAttr =
      IREE::HAL::MemoryTypeBitfieldAttr::get(context, requiredMemoryTypes);
  auto bufferUsageAttr =
      IREE::HAL::BufferUsageBitfieldAttr::get(context, requiredBufferUsage);
  builder.create<IREE::HAL::BufferAssertOp>(loc, buffer, message, allocator,
                                            minimumLength, memoryTypesAttr,
                                            bufferUsageAttr);
  return success();
}
// Converts stream.tensor.import ops whose source is a !hal.buffer: the buffer
// is used directly as the result and a storage assertion is inserted to check
// it against the expected allocator, size, and resource type.
struct TensorImportBufferOpPattern
    : public StreamConversionPattern<IREE::Stream::TensorImportOp> {
  using StreamConversionPattern::StreamConversionPattern;
  LogicalResult matchAndRewrite(
      IREE::Stream::TensorImportOp importOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // Other source types are handled by other patterns.
    const bool isBufferSource =
        importOp.source().getType().isa<IREE::HAL::BufferType>();
    if (!isBufferSource) return failure();

    // TODO(benvanik): get a name for the tensor (argument name/etc).
    auto assertMessage = rewriter.getStringAttr("tensor");

    // Directly use the buffer.
    Value sourceBuffer = adaptor.source();
    rewriter.replaceOp(importOp, sourceBuffer);

    // Assert the storage is compatible with our expected device and usage.
    Value targetAllocator = lookupAllocatorFor(importOp, rewriter);
    auto resultResourceType =
        importOp.result().getType().cast<IREE::Stream::ResourceType>();
    return buildStorageAssertions(importOp.getLoc(), adaptor.source(),
                                  assertMessage, targetAllocator,
                                  adaptor.result_size(), resultResourceType,
                                  rewriter);
  }
};
// Converts stream.tensor.import ops whose source is a !hal.buffer_view (or a
// tensor type): the backing buffer is extracted from the buffer view and a
// storage assertion is inserted to check it against the expected allocator,
// size, and resource type.
struct TensorImportBufferViewOpPattern
    : public StreamConversionPattern<IREE::Stream::TensorImportOp> {
  using StreamConversionPattern::StreamConversionPattern;
  LogicalResult matchAndRewrite(
      IREE::Stream::TensorImportOp importOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // Other source types are handled by other patterns.
    Type originalSourceType = importOp.source().getType();
    const bool isSupportedSource =
        originalSourceType.isa<IREE::HAL::BufferViewType>() ||
        originalSourceType.isa<TensorType>();
    if (!isSupportedSource) return failure();

    auto loc = importOp.getLoc();

    // TODO(benvanik): get a name for the tensor (argument name/etc).
    auto assertMessage = rewriter.getStringAttr("tensor");

    // Replace the import with the buffer backing the buffer view.
    auto backingBufferOp =
        rewriter.replaceOpWithNewOp<IREE::HAL::BufferViewBufferOp>(
            importOp, rewriter.getType<IREE::HAL::BufferType>(),
            adaptor.source());

    // Assert the storage is compatible with our expected device and usage.
    Value targetAllocator = lookupAllocatorFor(importOp, rewriter);
    auto resultResourceType =
        importOp.result().getType().cast<IREE::Stream::ResourceType>();
    return buildStorageAssertions(loc, backingBufferOp.result(), assertMessage,
                                  targetAllocator, adaptor.result_size(),
                                  resultResourceType, rewriter);
  }
};
// Converts stream.tensor.export ops whose result is a !hal.buffer by
// forwarding the converted source buffer directly.
struct TensorExportBufferOpPattern
    : public StreamConversionPattern<IREE::Stream::TensorExportOp> {
  using StreamConversionPattern::StreamConversionPattern;
  LogicalResult matchAndRewrite(
      IREE::Stream::TensorExportOp exportOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // Other result types are handled by other patterns.
    const bool isBufferResult =
        exportOp.result().getType().isa<IREE::HAL::BufferType>();
    if (!isBufferResult) return failure();
    rewriter.replaceOp(exportOp, adaptor.source());
    return success();
  }
};
// Converts stream.tensor.export ops whose result is a !hal.buffer_view (or a
// tensor type) by wrapping the converted source buffer in a new buffer view
// carrying the element type, encoding type, and shape of the source encoding.
struct TensorExportBufferViewOpPattern
    : public StreamConversionPattern<IREE::Stream::TensorExportOp> {
  using StreamConversionPattern::StreamConversionPattern;
  LogicalResult matchAndRewrite(
      IREE::Stream::TensorExportOp exportOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // Other result types are handled by other patterns.
    Type originalResultType = exportOp.result().getType();
    const bool isSupportedResult =
        originalResultType.isa<IREE::HAL::BufferViewType>() ||
        originalResultType.isa<TensorType>();
    if (!isSupportedResult) return failure();

    auto loc = exportOp.getLoc();
    auto sourceTensorType =
        adaptor.source_encoding().cast<RankedTensorType>();
    auto sourceDynamicDims = adaptor.source_encoding_dims();

    // NOTE: supported encodings/types should have been verified at entry into
    // the HAL pipeline.
    auto encodingTypeValue =
        IREE::HAL::getEncodingTypeValue(sourceTensorType.getEncoding());
    assert(encodingTypeValue.hasValue() && "invalid tensor encoding");
    auto elementTypeValue =
        IREE::HAL::getElementTypeValue(sourceTensorType.getElementType());
    assert(elementTypeValue.hasValue() && "invalid tensor element type");

    // Flatten static + dynamic shape dimensions: dynamic dimensions consume
    // the next dynamic dim operand while static ones become index constants.
    SmallVector<Value> shapeDims;
    unsigned nextDynamicDim = 0;
    const int64_t rank = sourceTensorType.getRank();
    for (int64_t dim = 0; dim < rank; ++dim) {
      if (!sourceTensorType.isDynamicDim(dim)) {
        shapeDims.push_back(rewriter.create<arith::ConstantIndexOp>(
            loc, sourceTensorType.getDimSize(dim)));
        continue;
      }
      shapeDims.push_back(sourceDynamicDims[nextDynamicDim++]);
    }

    rewriter.replaceOpWithNewOp<IREE::HAL::BufferViewCreateOp>(
        exportOp, adaptor.source(), elementTypeValue.getValue(),
        encodingTypeValue.getValue(), shapeDims);
    return success();
  }
};
// Converts stream.tensor.trace into a buffer view trace op with the same key
// attribute and the converted operands.
struct TensorTraceOpPattern
    : public StreamConversionPattern<IREE::Stream::TensorTraceOp> {
  using StreamConversionPattern::StreamConversionPattern;
  LogicalResult matchAndRewrite(
      IREE::Stream::TensorTraceOp traceOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    auto traceKey = traceOp.keyAttr();
    rewriter.replaceOpWithNewOp<IREE::HAL::BufferViewTraceOp>(
        traceOp, traceKey, adaptor.operands());
    return success();
  }
};
// Converts stream.cmd.flush by erasing it; nothing is emitted in its place.
struct CmdFlushOpPattern
    : public StreamConversionPattern<IREE::Stream::CmdFlushOp> {
  using StreamConversionPattern::StreamConversionPattern;
  LogicalResult matchAndRewrite(
      IREE::Stream::CmdFlushOp flushOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // TODO(benvanik): HAL command buffer op for flush.
    rewriter.eraseOp(flushOp);
    return success();
  }
};
// Converts stream.cmd.invalidate by erasing it; nothing is emitted in its
// place.
struct CmdInvalidateOpPattern
    : public StreamConversionPattern<IREE::Stream::CmdInvalidateOp> {
  using StreamConversionPattern::StreamConversionPattern;
  LogicalResult matchAndRewrite(
      IREE::Stream::CmdInvalidateOp invalidateOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // TODO(benvanik): HAL command buffer op for invalidate.
    rewriter.eraseOp(invalidateOp);
    return success();
  }
};
// Converts stream.cmd.discard by erasing it; nothing is emitted in its place.
struct CmdDiscardOpPattern
    : public StreamConversionPattern<IREE::Stream::CmdDiscardOp> {
  using StreamConversionPattern::StreamConversionPattern;
  LogicalResult matchAndRewrite(
      IREE::Stream::CmdDiscardOp discardOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // TODO(benvanik): HAL command buffer op for discard.
    rewriter.eraseOp(discardOp);
    return success();
  }
};
// Converts stream.cmd.fill into a fill-buffer command recorded on the command
// buffer mapped for the enclosing execution region.
struct CmdFillOpPattern
    : public StreamConversionPattern<IREE::Stream::CmdFillOp> {
  using StreamConversionPattern::StreamConversionPattern;
  LogicalResult matchAndRewrite(
      IREE::Stream::CmdFillOp fillOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    Value recordingCommandBuffer = mapping->lookupCommandBufferFor(fillOp);
    rewriter.replaceOpWithNewOp<IREE::HAL::CommandBufferFillBufferOp>(
        fillOp, recordingCommandBuffer, adaptor.target(),
        adaptor.target_offset(), adaptor.target_length(), adaptor.value());
    return success();
  }
};
// Converts stream.cmd.copy into a copy-buffer command recorded on the command
// buffer mapped for the enclosing execution region.
struct CmdCopyOpPattern
    : public StreamConversionPattern<IREE::Stream::CmdCopyOp> {
  using StreamConversionPattern::StreamConversionPattern;
  LogicalResult matchAndRewrite(
      IREE::Stream::CmdCopyOp copyOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    Value recordingCommandBuffer = mapping->lookupCommandBufferFor(copyOp);
    rewriter.replaceOpWithNewOp<IREE::HAL::CommandBufferCopyBufferOp>(
        copyOp, recordingCommandBuffer, adaptor.source(),
        adaptor.source_offset(), adaptor.target(), adaptor.target_offset(),
        adaptor.length());
    return success();
  }
};
// Converts stream.cmd.dispatch into a device switch with one condition region
// per variant of the target hal.executable. Each region records the push
// constants and descriptor set bindings for the dispatch followed by the
// dispatch itself against that variant's entry point.
struct CmdDispatchOpPattern
    : public StreamConversionPattern<IREE::Stream::CmdDispatchOp> {
  using StreamConversionPattern::StreamConversionPattern;

  // Fails with a diagnostic if the dispatch target executable cannot be
  // resolved or if any of its variants lacks the referenced entry point.
  LogicalResult matchAndRewrite(
      IREE::Stream::CmdDispatchOp dispatchOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    auto loc = dispatchOp.getLoc();
    auto commandBuffer = mapping->lookupCommandBufferFor(dispatchOp);

    // Resolve the executable containing the dispatch entry point before any IR
    // is created. The symbol lookup returns null when the symbol is missing
    // and the symbol may not be a hal.executable at all, so this must use
    // dyn_cast_or_null: cast<> asserts in debug builds and dereferences null
    // in release builds.
    auto executableOp = dyn_cast_or_null<IREE::HAL::ExecutableOp>(
        SymbolTable::lookupNearestSymbolFrom(
            dispatchOp, dispatchOp.entry_point().getRootReference()));
    if (!executableOp) {
      return dispatchOp.emitOpError()
             << "dispatch target executable op not found for "
             << dispatchOp.entry_point();
    }

    // Get the device handle we're executing against in this execution region.
    // Note that this is a dynamic value: we have to treat the device as unknown
    // here.
    auto device = rewriter.create<IREE::HAL::CommandBufferDeviceOp>(
        loc, rewriter.getType<IREE::HAL::DeviceType>(), commandBuffer);

    // Ask each target backend to record their dispatch logic.
    IREE::HAL::DeviceSwitchRewriter switchRewriter(loc,
                                                   /*resultTypes=*/TypeRange{},
                                                   device, rewriter);
    for (auto variantOp :
         executableOp.getOps<IREE::HAL::ExecutableVariantOp>()) {
      // Find the entry point in this variant matching the dispatch target.
      auto entryPointOps =
          variantOp.getOps<IREE::HAL::ExecutableEntryPointOp>();
      auto entryPointIt = llvm::find_if(
          entryPointOps, [&](IREE::HAL::ExecutableEntryPointOp op) {
            return op.getNameAttr() ==
                   dispatchOp.entry_point().getLeafReference();
          });
      if (entryPointIt == entryPointOps.end()) {
        return variantOp.emitError()
               << "hal.executable.variant is missing the flow entry point for "
               << dispatchOp.entry_point();
      }
      auto entryPointOp = *entryPointIt;

      // Each variant gets its own region conditioned on the variant's target
      // match expression.
      auto *region = switchRewriter.addConditionRegion(
          variantOp.target().getMatchExpression());
      auto &entryBlock = region->front();
      auto caseBuilder = OpBuilder::atBlockBegin(&entryBlock);

      // Record push constants and buffer bindings.
      recordParameters(loc, device, commandBuffer, dispatchOp, adaptor,
                       entryPointOp.layout(), caseBuilder);

      // Dispatch with a target-specific workgroup count.
      auto entryPointSymRef =
          SymbolRefAttr::get(caseBuilder.getContext(), executableOp.getName(),
                             {SymbolRefAttr::get(entryPointOp->getParentOp()),
                              SymbolRefAttr::get(entryPointOp)});
      auto caseWorkgroupCount = entryPointOp.calculateWorkgroupCount(
          loc, adaptor.workgroup_count(), caseBuilder);
      caseBuilder.create<IREE::HAL::CommandBufferDispatchSymbolOp>(
          loc, commandBuffer, entryPointSymRef, caseWorkgroupCount[0],
          caseWorkgroupCount[1], caseWorkgroupCount[2]);
      caseBuilder.create<IREE::HAL::ReturnOp>(loc);
    }
    switchRewriter.build();

    rewriter.eraseOp(dispatchOp);
    return success();
  }

  // Records the push constants and descriptor set bindings of |dispatchOp|
  // into |commandBuffer| using |builder|. The executable layout is looked up
  // on |device| from |layoutAttr|.
  //
  // Asserts that all operands are i32 and that the dispatch carries a
  // "hal.interface.bindings" array attribute.
  void recordParameters(Location loc, Value device, Value commandBuffer,
                        IREE::Stream::CmdDispatchOp dispatchOp,
                        OpAdaptor adaptor,
                        IREE::HAL::ExecutableLayoutAttr layoutAttr,
                        OpBuilder &builder) const {
    auto executableLayout =
        builder
            .create<IREE::HAL::ExecutableLayoutLookupOp>(
                loc, IREE::HAL::ExecutableLayoutType::get(loc.getContext()),
                device, layoutAttr)
            .result();

    // Push constant values.
    // TODO(#5322): symbolic push constant names on the hal.interface so we can
    // sparsely pack these.
    if (!adaptor.operands().empty()) {
      int pushConstantBase = 0;  // always 0 today
      SmallVector<Value> pushConstants;
      for (auto operand : adaptor.operands()) {
        assert(
            operand.getType().isInteger(32) &&
            "expected only i32 values after iree-hal-pack-dispatch-operands");
        pushConstants.push_back(operand);
      }
      builder.create<IREE::HAL::CommandBufferPushConstantsOp>(
          loc, commandBuffer, executableLayout,
          builder.getIndexAttr(pushConstantBase), pushConstants);
    }

    // TODO(benvanik): typed accessors for bindings.
    auto bindingAttrs = dispatchOp->getAttr("hal.interface.bindings")
                            .dyn_cast_or_null<ArrayAttr>();
    assert(bindingAttrs &&
           "interface materialization must annotate dispatch sites");

    // Push descriptor bindings. Bindings are accumulated while the set ordinal
    // stays the same and are pushed as one descriptor set whenever the set
    // changes (and once more at the end for the final set).
    int64_t currentSet = -1;
    SmallVector<IREE::HAL::DescriptorSetBindingValue> bindings;
    auto flushSet = [&]() {
      builder.create<IREE::HAL::CommandBufferPushDescriptorSetOp>(
          loc, commandBuffer, executableLayout, currentSet, bindings);
      bindings.clear();
    };
    for (unsigned i = 0; i < adaptor.resources().size(); ++i) {
      auto bindingAttr =
          bindingAttrs[i].cast<IREE::HAL::InterfaceBindingAttr>();
      int64_t set = bindingAttr.getSet();
      if (currentSet != -1 && currentSet != set) flushSet();
      currentSet = set;
      IREE::HAL::DescriptorSetBindingValue binding;
      binding.ordinal =
          builder.create<arith::ConstantIndexOp>(loc, bindingAttr.getBinding());
      binding.buffer = adaptor.resources()[i];
      binding.byteOffset = adaptor.resource_offsets()[i];
      binding.byteLength = adaptor.resource_lengths()[i];
      bindings.push_back(binding);
    }
    // currentSet stays -1 only when there were no resources at all.
    if (currentSet != -1) flushSet();
  }
};
// Inserts an execution barrier on |commandBuffer| immediately after every
// non-terminator op that is in |block| at the time of the call.
static void insertSerializationBarriers(Location loc, Block &block,
                                        Value commandBuffer,
                                        OpBuilder builder) {
  // TODO(benvanik): derive based on the type of operations that surround the
  // barriers. Can use deriveCommandCategories on the ranges to see what kind
  // of ops happen above and below, but really some analysis is required.
  using Stage = IREE::HAL::ExecutionStageBitfield;
  auto sourceStage = Stage::CommandRetire | Stage::Transfer | Stage::Dispatch;
  auto targetStage = Stage::CommandIssue | Stage::Transfer | Stage::Dispatch;
  auto barrierFlags = IREE::HAL::ExecutionBarrierFlagBitfield::None;

  // The block cannot be mutated while it is being iterated so snapshot the
  // ops that need a trailing barrier first.
  SmallVector<Operation *> opsNeedingBarrier;
  for (Operation &blockOp : block) {
    if (!blockOp.hasTrait<OpTrait::IsTerminator>()) {
      opsNeedingBarrier.push_back(&blockOp);
    }
  }

  // Insert barriers after every snapshotted op.
  for (Operation *precedingOp : opsNeedingBarrier) {
    builder.setInsertionPointAfter(precedingOp);
    builder.create<IREE::HAL::CommandBufferExecutionBarrierOp>(
        loc, commandBuffer, sourceStage, targetStage, barrierFlags);
  }
}
// Converts stream.cmd.execute into the creation of a one-shot command buffer,
// the recording of the (serialized) execution region between begin/end ops,
// and an ExSubmitAndWaitOp on the device. The result timepoint is replaced by
// a constant index 0.
struct CmdExecuteOpPattern
    : public StreamConversionPattern<IREE::Stream::CmdExecuteOp> {
  using StreamConversionPattern::StreamConversionPattern;
  LogicalResult matchAndRewrite(
      IREE::Stream::CmdExecuteOp executeOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    auto loc = executeOp.getLoc();
    Value device = lookupDeviceFor(executeOp, rewriter);

    // TODO(benvanik): disable inline execution once we have semaphores.
    // We can look ahead to see if there's an await immediately to trigger the
    // inline execution.
    auto commandBufferModes =
        IREE::HAL::CommandBufferModeBitfield::OneShot |
        IREE::HAL::CommandBufferModeBitfield::AllowInlineExecution;

    // Derive the command buffer type based on the kind of operations present.
    // This can help the submission get routed to appropriate hardware queues
    // (like dedicated DMA controllers).
    auto commandCategories = deriveCommandCategories(executeOp.body());

    // Create a new command buffer for recording and register it so that the
    // patterns converting nested stream.cmd.* ops can find it.
    Value commandBuffer =
        rewriter
            .create<IREE::HAL::CommandBufferCreateOp>(
                loc, rewriter.getType<IREE::HAL::CommandBufferType>(), device,
                commandBufferModes, commandCategories)
            .result();
    mapping->mapCommandBuffer(executeOp, commandBuffer);

    // Run through the execution region and serialize execution by inserting
    // barriers. Nested regions may elide barriers as needed.
    Block *bodyBlock = &executeOp.body().front();
    insertSerializationBarriers(loc, *bodyBlock, commandBuffer,
                                OpBuilder::atBlockBegin(bodyBlock));

    // Begin/end recording and inline the execution region between them.
    rewriter.create<IREE::HAL::CommandBufferBeginOp>(loc, commandBuffer);
    auto endOp =
        rewriter.create<IREE::HAL::CommandBufferEndOp>(loc, commandBuffer);
    rewriter.mergeBlockBefore(&executeOp.body().front(), endOp,
                              adaptor.operands());

    // TODO(benvanik): we should queue a submit here with the semaphore instead.
    rewriter.create<IREE::HAL::ExSubmitAndWaitOp>(loc, device, commandBuffer);

    // TODO(benvanik): propagate semaphore information.
    Value resolvedTimepoint =
        rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
    rewriter.replaceOp(executeOp, resolvedTimepoint);
    return success();
  }
};
// Converts IREE::Stream::CmdSerialOp by inserting serialization barriers into
// its body and then inlining that body in place of the op.
struct CmdSerialOpPattern
    : public StreamConversionPattern<IREE::Stream::CmdSerialOp> {
  using StreamConversionPattern::StreamConversionPattern;
  LogicalResult matchAndRewrite(
      IREE::Stream::CmdSerialOp serialOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // The command buffer is looked up from the shared conversion mapping.
    auto commandBuffer = mapping->lookupCommandBufferFor(serialOp);
    Location loc = serialOp.getLoc();

    // Serialize execution of the region by inserting barriers into its body.
    // Nested regions may elide barriers as needed.
    Block &bodyBlock = serialOp.body().front();
    insertSerializationBarriers(loc, bodyBlock, commandBuffer,
                                OpBuilder::atBlockBegin(&bodyBlock));

    // Splice the serial region in front of the op and drop the now-empty op.
    rewriter.mergeBlockBefore(&bodyBlock, serialOp);
    rewriter.eraseOp(serialOp);
    return success();
  }
};
// Converts IREE::Stream::CmdConcurrentOp by inlining its body in place of the
// op; no barriers are inserted here.
struct CmdConcurrentOpPattern
    : public StreamConversionPattern<IREE::Stream::CmdConcurrentOp> {
  using StreamConversionPattern::StreamConversionPattern;
  LogicalResult matchAndRewrite(
      IREE::Stream::CmdConcurrentOp concurrentOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // TODO(benvanik): split barriers (event set/wait) when nesting.
    // Splice the concurrent region in front of the op, then remove the op.
    Block *bodyBlock = &concurrentOp.body().front();
    rewriter.mergeBlockBefore(bodyBlock, concurrentOp);
    rewriter.eraseOp(concurrentOp);
    return success();
  }
};
// Converts IREE::Stream::TimepointImmediateOp into a constant index 0.
struct TimepointImmediateOpPattern
    : public StreamConversionPattern<IREE::Stream::TimepointImmediateOp> {
  using StreamConversionPattern::StreamConversionPattern;
  LogicalResult matchAndRewrite(
      IREE::Stream::TimepointImmediateOp immediateOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // TODO(benvanik): model timepoints as semaphores.
    auto zeroOp =
        rewriter.create<arith::ConstantIndexOp>(immediateOp.getLoc(), 0);
    rewriter.replaceOp(immediateOp, zeroOp.getResult());
    return success();
  }
};
// Converts IREE::Stream::TimepointImportOp when its converted operands are a
// (HAL semaphore, int-or-index value) pair: awaits the semaphore, checks the
// resulting status, and replaces the op with a constant index 0.
struct TimepointImportOpPattern
    : public StreamConversionPattern<IREE::Stream::TimepointImportOp> {
  using StreamConversionPattern::StreamConversionPattern;
  LogicalResult matchAndRewrite(
      IREE::Stream::TimepointImportOp importOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // Anything other than a semaphore + sequence value pair is rejected.
    auto importOperands = adaptor.operands();
    const bool isSemaphoreTuple =
        importOperands.size() == 2 &&
        importOperands[0].getType().isa<IREE::HAL::SemaphoreType>() &&
        importOperands[1].getType().isIntOrIndex();
    if (!isSemaphoreTuple) {
      return rewriter.notifyMatchFailure(importOp,
                                         "only imports from HAL semaphore + "
                                         "sequence value tuples are supported");
    }

    // TODO(benvanik): model timepoints as semaphores.
    // Until then the import is a blocking wait on the semaphore followed by a
    // status check.
    Location loc = importOp.getLoc();
    auto waitOp = rewriter.create<IREE::HAL::SemaphoreAwaitOp>(
        loc, rewriter.getI32Type(), importOperands[0], importOperands[1]);
    rewriter.create<IREE::Util::StatusCheckOkOp>(
        loc, waitOp.status(), "failed to wait on imported semaphore");

    rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(importOp, 0);
    return success();
  }
};
// Converts IREE::Stream::TimepointExportOp when its results are a
// (HAL semaphore, int-or-index value) pair: creates a new semaphore with an
// initial value of constant index 0 and replaces the op's results with that
// semaphore and value.
struct TimepointExportOpPattern
    : public StreamConversionPattern<IREE::Stream::TimepointExportOp> {
  using StreamConversionPattern::StreamConversionPattern;
  LogicalResult matchAndRewrite(
      IREE::Stream::TimepointExportOp exportOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // Anything other than a semaphore + sequence value pair is rejected.
    const bool isSemaphoreTuple =
        exportOp.getNumResults() == 2 &&
        exportOp.getResult(0).getType().isa<IREE::HAL::SemaphoreType>() &&
        exportOp.getResult(1).getType().isIntOrIndex();
    if (!isSemaphoreTuple) {
      return rewriter.notifyMatchFailure(exportOp,
                                         "only exports to HAL semaphore + "
                                         "sequence value tuples are supported");
    }

    Location loc = exportOp.getLoc();
    Value device = lookupDeviceFor(exportOp, rewriter);

    // TODO(benvanik): model timepoints as semaphores.
    // Until then the export is a freshly created semaphore whose initial value
    // is the same constant 0 that is returned as the sequence value.
    auto signaledValue = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto signaledSemaphore = rewriter.create<IREE::HAL::SemaphoreCreateOp>(
        loc, rewriter.getType<IREE::HAL::SemaphoreType>(), device,
        signaledValue);
    rewriter.replaceOp(exportOp, {signaledSemaphore, signaledValue});
    return success();
  }
};
// Converts IREE::Stream::TimepointJoinOp into a constant index 0, ignoring the
// joined operands.
struct TimepointJoinOpPattern
    : public StreamConversionPattern<IREE::Stream::TimepointJoinOp> {
  using StreamConversionPattern::StreamConversionPattern;
  LogicalResult matchAndRewrite(
      IREE::Stream::TimepointJoinOp joinOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // TODO(benvanik): model timepoints as semaphores.
    // A real join would be max() over the operand timepoints (possibly via
    // affine expressions). Every timepoint is currently 0, so max(0,0)=0 and
    // a constant suffices.
    auto joinedOp = rewriter.create<arith::ConstantIndexOp>(joinOp.getLoc(), 0);
    rewriter.replaceOp(joinOp, joinedOp.getResult());
    return success();
  }
};
// Converts IREE::Stream::TimepointAwaitOp by replacing it with the values
// returned from adaptor.operands(); no wait is emitted.
struct TimepointAwaitOpPattern
    : public StreamConversionPattern<IREE::Stream::TimepointAwaitOp> {
  using StreamConversionPattern::StreamConversionPattern;
  LogicalResult matchAndRewrite(
      IREE::Stream::TimepointAwaitOp awaitOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // TODO(benvanik): model timepoints as semaphores.
    auto forwardedValues = adaptor.operands();
    rewriter.replaceOp(awaitOp, forwardedValues);
    return success();
  }
};
// Converts IREE::Stream::YieldOp by erasing it; nothing is emitted in its
// place.
struct ElideYieldOpPattern
    : public StreamConversionPattern<IREE::Stream::YieldOp> {
  using StreamConversionPattern::StreamConversionPattern;

  LogicalResult matchAndRewrite(
      IREE::Stream::YieldOp yieldOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // The yield is unconditionally removed.
    rewriter.eraseOp(yieldOp);
    return success();
  }
};
// Rewrites the initial value of an IREE::Util::GlobalOp from an
// IREE::Stream::TimepointAttr to an index attribute of 0.
//
// It is unfortunate that this has to live here, but there is no attribute
// converter equivalent available that would allow doing it generically.
struct GlobalTimepointConversionPattern
    : public OpConversionPattern<IREE::Util::GlobalOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult matchAndRewrite(
      IREE::Util::GlobalOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // Only globals that have an initial value, and whose initial value is a
    // timepoint attribute, are handled.
    auto initialValue = op.initial_value();
    if (!initialValue.hasValue() ||
        !initialValue->isa<IREE::Stream::TimepointAttr>()) {
      return failure();
    }

    // Swap the timepoint attribute for an index 0 in place.
    auto zeroAttr = rewriter.getIndexAttr(0);
    rewriter.updateRootInPlace(op, [&]() { op.initial_valueAttr(zeroAttr); });
    return success();
  }
};
} // namespace
// Configures a stream->HAL dialect conversion:
//  - marks the stream dialect illegal on |conversionTarget|,
//  - registers type conversions for stream resource and timepoint types on
//    |typeConverter|,
//  - inserts the conversion patterns listed below into |patterns|.
void populateStreamToHALPatterns(MLIRContext *context,
                                 ConversionTarget &conversionTarget,
                                 TypeConverter &typeConverter,
                                 RewritePatternSet &patterns) {
  conversionTarget.addIllegalDialect<IREE::Stream::StreamDialect>();

  // Resources are just buffers (no shape/encoding/etc).
  typeConverter.addConversion(
      [context](IREE::Stream::ResourceType resourceType,
                SmallVectorImpl<Type> &convertedTypes) {
        convertedTypes.push_back(IREE::HAL::BufferType::get(context));
        return success();
      });

  // TODO(benvanik): model timepoints as semaphores.
  // Timepoints are plain index values for now. This may become a
  // !hal.semaphore + index, or some !hal.timepoint that we then do more
  // analysis on once we know what devices are in use where.
  typeConverter.addConversion(
      [context](IREE::Stream::TimepointType timepointType,
                SmallVectorImpl<Type> &convertedTypes) {
        convertedTypes.push_back(IndexType::get(context));
        return success();
      });

  // Spooky action at a distance: util.global ops holding timepoint initial
  // values are patched here.
  patterns.insert<GlobalTimepointConversionPattern>(typeConverter, context);

  // One mapping instance is shared by every stream conversion pattern below.
  auto mapping = std::make_shared<StreamConversionMapping>();

  // Resource ops.
  patterns.insert<ResourceAllocOpPattern, ResourceAllocaOpPattern,
                  ResourceDeallocaOpPattern, ResourceSizeOpPattern,
                  ResourceMapOpPattern, ResourceTryMapOpPattern,
                  ResourceLoadOpPattern, ResourceStoreOpPattern,
                  ResourceSubviewOpPattern>(mapping, typeConverter, context);

  // Tensor import/export/trace ops.
  patterns.insert<TensorImportBufferOpPattern, TensorImportBufferViewOpPattern,
                  TensorExportBufferOpPattern, TensorExportBufferViewOpPattern,
                  TensorTraceOpPattern>(mapping, typeConverter, context);

  // Command ops.
  patterns
      .insert<CmdFlushOpPattern, CmdInvalidateOpPattern, CmdDiscardOpPattern,
              CmdFillOpPattern, CmdCopyOpPattern, CmdDispatchOpPattern,
              CmdExecuteOpPattern, CmdSerialOpPattern, CmdConcurrentOpPattern>(
          mapping, typeConverter, context);

  // Timepoint ops.
  patterns.insert<TimepointImmediateOpPattern, TimepointImportOpPattern,
                  TimepointExportOpPattern, TimepointJoinOpPattern,
                  TimepointAwaitOpPattern>(mapping, typeConverter, context);

  // Region terminators.
  patterns.insert<ElideYieldOpPattern>(mapping, typeConverter, context);
}
} // namespace iree_compiler
} // namespace mlir