blob: fe756c89031f4e3369af57407f70cc46269c0b3f [file] [log] [blame] [edit]
// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/Patterns.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Modules/HAL/Inline/IR/HALInlineDialect.h"
#include "iree/compiler/Modules/HAL/Inline/IR/HALInlineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir::iree_compiler {
namespace {
// Returns an index-typed value holding the size of |resource|.
// Resources typed as IREE::HAL::BufferType are queried through
// IREE::HAL::Inline::BufferLengthOp; every other resource is queried through
// IREE::Util::BufferSizeOp. The query is folded when possible.
static Value getResourceSize(Location loc, Value resource, OpBuilder &builder) {
  auto indexType = builder.getIndexType();
  const bool isHALBuffer = isa<IREE::HAL::BufferType>(resource.getType());
  return isHALBuffer
             ? builder.createOrFold<IREE::HAL::Inline::BufferLengthOp>(
                   loc, indexType, resource)
             : builder.createOrFold<IREE::Util::BufferSizeOp>(loc, indexType,
                                                               resource);
}
// Pairs a buffer value with the value holding its total size, as produced by
// getResourceStorage() and consumed by the util.buffer.* op builders below.
struct Storage {
// Underlying storage buffer.
Value buffer;
// Total size of the storage buffer in bytes.
Value bufferSize;
};
// Resolves the storage backing |resource| along with the size of that storage.
//
// Resources that are not typed as IREE::HAL::BufferType are returned as-is
// together with the caller-provided |resourceSize|. For HAL buffers the
// storage is fetched with IREE::HAL::Inline::BufferStorageOp and the size is
// re-queried from the buffer itself (|resourceSize| is not used on that path).
static Storage getResourceStorage(Location loc, Value resource,
                                  Value resourceSize, OpBuilder &builder) {
  if (!isa<IREE::HAL::BufferType>(resource.getType())) {
    return {resource, resourceSize};
  }
  // Get the storage of the buffer; the returned buffer is already a subspan.
  Storage storage;
  storage.buffer =
      builder.createOrFold<IREE::HAL::Inline::BufferStorageOp>(loc, resource);
  storage.bufferSize = getResourceSize(loc, resource, builder);
  return storage;
}
// Lowers stream.resource.alloc to an inline HAL buffer allocation and replaces
// the op with the allocation's primary result.
struct ResourceAllocOpPattern
    : public OpConversionPattern<IREE::Stream::ResourceAllocOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Stream::ResourceAllocOp allocOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = allocOp.getLoc();
    auto halBufferType = rewriter.getType<IREE::HAL::BufferType>();
    auto utilBufferType = rewriter.getType<IREE::Util::BufferType>();

    // The required alignment is not known at this point so a conservative
    // value of 64 is used.
    Value alignment = arith::ConstantIndexOp::create(rewriter, loc, 64);

    auto newAllocOp = IREE::HAL::Inline::BufferAllocateOp::create(
        rewriter, loc, halBufferType, utilBufferType, alignment,
        adaptor.getStorageSize());
    rewriter.replaceOp(allocOp, newAllocOp.getResult());
    return success();
  }
};
// Lowers stream.resource.alloca to an inline HAL buffer allocation. The op's
// two results are replaced with the allocated buffer and an i64 zero constant
// standing in for an already-resolved timepoint.
struct ResourceAllocaOpPattern
    : public OpConversionPattern<IREE::Stream::ResourceAllocaOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Stream::ResourceAllocaOp allocaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = allocaOp.getLoc();
    auto halBufferType = rewriter.getType<IREE::HAL::BufferType>();
    auto utilBufferType = rewriter.getType<IREE::Util::BufferType>();

    // The required alignment is not known at this point so a conservative
    // value of 64 is used.
    Value alignment = arith::ConstantIndexOp::create(rewriter, loc, 64);

    auto newAllocOp = IREE::HAL::Inline::BufferAllocateOp::create(
        rewriter, loc, halBufferType, utilBufferType, alignment,
        adaptor.getStorageSize());

    // Execution is inline so the result timepoint is immediately resolved.
    auto immediateTimepoint =
        arith::ConstantIntOp::create(rewriter, loc, 0, 64).getResult();
    rewriter.replaceOp(allocaOp, {newAllocOp.getResult(), immediateTimepoint});
    return success();
  }
};
// Drops stream.resource.dealloca: nothing is deallocated and the result
// timepoint is replaced with an i64 zero constant (already resolved).
struct ResourceDeallocaOpPattern
    : public OpConversionPattern<IREE::Stream::ResourceDeallocaOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Stream::ResourceDeallocaOp deallocaOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = deallocaOp.getLoc();
    // TODO(benvanik): discard op?
    auto immediateTimepoint =
        arith::ConstantIntOp::create(rewriter, loc, 0, 64).getResult();
    rewriter.replaceOp(deallocaOp, {immediateTimepoint});
    return success();
  }
};
// Removes stream.resource.retain ops entirely.
struct ResourceRetainOpPattern
    : public OpConversionPattern<IREE::Stream::ResourceRetainOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Stream::ResourceRetainOp retainOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // The inline HAL does not support allocation tracking so there is nothing
    // to retain.
    rewriter.eraseOp(retainOp);
    return success();
  }
};
// Replaces stream.resource.release with a constant i1 zero.
struct ResourceReleaseOpPattern
    : public OpConversionPattern<IREE::Stream::ResourceReleaseOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Stream::ResourceReleaseOp releaseOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // The inline HAL does not support allocation tracking so the result is
    // always the same constant.
    auto i1Type = rewriter.getI1Type();
    rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(releaseOp, i1Type, 0);
    return success();
  }
};
// Replaces stream.resource.is_terminal with a constant i1 zero.
struct ResourceIsTerminalOpPattern
    : public OpConversionPattern<IREE::Stream::ResourceIsTerminalOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Stream::ResourceIsTerminalOp isTerminalOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // The inline HAL does not support allocation tracking so the result is
    // always the same constant.
    auto i1Type = rewriter.getI1Type();
    rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(isTerminalOp, i1Type, 0);
    return success();
  }
};
// Lowers stream.resource.size to a size query on the converted resource
// (see getResourceSize for which query op is used).
struct ResourceSizeOpPattern
    : public OpConversionPattern<IREE::Stream::ResourceSizeOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Stream::ResourceSizeOp sizeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value convertedSize =
        getResourceSize(sizeOp.getLoc(), adaptor.getOperand(), rewriter);
    rewriter.replaceOp(sizeOp, convertedSize);
    return success();
  }
};
// Lowers stream.resource.try_map.
//
// The constant buffer produced by the mapping is always a !util.buffer, which
// means the buffer being mapped can be forwarded directly once a subspan
// covering the requested range has been taken. The mapping is always reported
// as successful (an i1 constant 1).
struct ResourceTryMapOpPattern
    : public OpConversionPattern<IREE::Stream::ResourceTryMapOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Stream::ResourceTryMapOp tryMapOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = tryMapOp.getLoc();
    Value sourceSize = getResourceSize(loc, adaptor.getSource(), rewriter);
    Value mappedRange = IREE::Util::BufferSubspanOp::create(
        rewriter, loc, adaptor.getSource(), sourceSize,
        adaptor.getSourceOffset(), adaptor.getResultSize());
    Value mapSucceeded = arith::ConstantIntOp::create(rewriter, loc, 1, 1);
    rewriter.replaceOp(tryMapOp, {mapSucceeded, mappedRange});
    return success();
  }
};
// Lowers stream.resource.load to util.buffer.load on the resource's storage.
// The loaded element size is derived from the converted result type.
struct ResourceLoadOpPattern
    : public OpConversionPattern<IREE::Stream::ResourceLoadOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Stream::ResourceLoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = loadOp.getLoc();
    auto resultType =
        getTypeConverter()->convertType(loadOp.getResult().getType());
    auto sourceStorage = getResourceStorage(loc, adaptor.getSource(),
                                            adaptor.getSourceSize(), rewriter);
    auto elementByteSize =
        rewriter.createOrFold<IREE::Util::SizeOfOp>(loc, resultType);
    rewriter.replaceOpWithNewOp<IREE::Util::BufferLoadOp>(
        loadOp, resultType, sourceStorage.buffer, sourceStorage.bufferSize,
        adaptor.getSourceOffset(), elementByteSize);
    return success();
  }
};
// Lowers stream.resource.store to util.buffer.store on the resource's storage.
// The stored element size is derived from the converted value's type.
struct ResourceStoreOpPattern
    : public OpConversionPattern<IREE::Stream::ResourceStoreOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Stream::ResourceStoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = storeOp.getLoc();
    auto targetStorage = getResourceStorage(loc, adaptor.getTarget(),
                                            adaptor.getTargetSize(), rewriter);
    auto storedValue = adaptor.getValue();
    auto elementByteSize = rewriter.createOrFold<IREE::Util::SizeOfOp>(
        loc, storedValue.getType());
    rewriter.replaceOpWithNewOp<IREE::Util::BufferStoreOp>(
        storeOp, storedValue, targetStorage.buffer, targetStorage.bufferSize,
        adaptor.getTargetOffset(), elementByteSize);
    return success();
  }
};
// Lowers stream.resource.subview to a subspan op matching the converted
// source type: IREE::Util::BufferSubspanOp for host buffers and
// IREE::HAL::Inline::BufferSubspanOp for HAL buffers.
struct ResourceSubviewOpPattern
    : public OpConversionPattern<IREE::Stream::ResourceSubviewOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Stream::ResourceSubviewOp subviewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto source = adaptor.getSource();
    if (!isa<IREE::HAL::BufferType>(source.getType())) {
      rewriter.replaceOpWithNewOp<IREE::Util::BufferSubspanOp>(
          subviewOp, source, adaptor.getSourceSize(), adaptor.getSourceOffset(),
          adaptor.getResultSize());
      return success();
    }

    // NOTE: this aliases! We assume at this point all useful alias analysis
    // has been performed and it's fine to lose the tie information here.
    auto halBufferType = rewriter.getType<IREE::HAL::BufferType>();
    rewriter.replaceOpWithNewOp<IREE::HAL::Inline::BufferSubspanOp>(
        subviewOp, halBufferType, source, adaptor.getSourceOffset(),
        adaptor.getResultSize());
    return success();
  }
};
// Lowers stream.file.constant to a util.buffer.subspan over the source buffer.
// Files are represented as host buffers in the inline HAL so a constant file
// is just a view of the requested [offset, offset + length) range.
struct FileConstantOpPattern
    : public OpConversionPattern<IREE::Stream::FileConstantOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Stream::FileConstantOp constantOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Operands are read from the adaptor (not the original op) so that values
    // already remapped by the conversion are used, matching every other
    // pattern in this file.
    rewriter.replaceOpWithNewOp<IREE::Util::BufferSubspanOp>(
        constantOp, adaptor.getSource(), adaptor.getSourceSize(),
        adaptor.getSourceOffset(), adaptor.getSourceLength());
    return success();
  }
};
// Lowers stream.file.read to a util.buffer.copy from the file's buffer into
// the target buffer. The result timepoint is replaced with an i64 zero
// constant (already resolved).
struct FileReadOpPattern
    : public OpConversionPattern<IREE::Stream::FileReadOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Stream::FileReadOp readOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = readOp.getLoc();
    Value fileSize =
        IREE::Util::BufferSizeOp::create(rewriter, loc, adaptor.getSource());
    // The file offset is cast to index for use as a buffer offset.
    Value fileOffset = rewriter.createOrFold<arith::IndexCastOp>(
        loc, rewriter.getIndexType(), adaptor.getSourceOffset());
    IREE::Util::BufferCopyOp::create(
        rewriter, loc, adaptor.getSource(), fileSize, fileOffset,
        adaptor.getTarget(), adaptor.getTargetSize(), adaptor.getTargetOffset(),
        adaptor.getLength());
    auto immediateTimepoint =
        arith::ConstantIntOp::create(rewriter, loc, 0, 64).getResult();
    rewriter.replaceOp(readOp, immediateTimepoint);
    return success();
  }
};
// Lowers stream.file.write to a util.buffer.copy from the source buffer into
// the file's buffer. The result timepoint is replaced with an i64 zero
// constant (already resolved).
struct FileWriteOpPattern
    : public OpConversionPattern<IREE::Stream::FileWriteOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Stream::FileWriteOp writeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = writeOp.getLoc();
    Value fileSize =
        IREE::Util::BufferSizeOp::create(rewriter, loc, adaptor.getTarget());
    // The file offset is cast to index for use as a buffer offset.
    Value fileOffset = rewriter.createOrFold<arith::IndexCastOp>(
        loc, rewriter.getIndexType(), adaptor.getTargetOffset());
    IREE::Util::BufferCopyOp::create(
        rewriter, loc, adaptor.getSource(), adaptor.getSourceSize(),
        adaptor.getSourceOffset(), adaptor.getTarget(), fileSize, fileOffset,
        adaptor.getLength());
    auto immediateTimepoint =
        arith::ConstantIntOp::create(rewriter, loc, 0, 64).getResult();
    rewriter.replaceOp(writeOp, immediateTimepoint);
    return success();
  }
};
// Lowers stream.tensor.import when the imported source is a !hal.buffer: the
// converted buffer is forwarded unchanged. Other source types are left for
// other patterns.
struct TensorImportBufferOpPattern
    : public OpConversionPattern<IREE::Stream::TensorImportOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Stream::TensorImportOp importOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const bool sourceIsBuffer =
        isa<IREE::HAL::BufferType>(importOp.getSource().getType());
    if (!sourceIsBuffer) {
      return failure();
    }
    // The buffer can be used directly.
    rewriter.replaceOp(importOp, adaptor.getSource());
    return success();
  }
};
// Lowers stream.tensor.import when the imported source is a buffer view or a
// tensor: the backing buffer is extracted from the converted source with
// IREE::HAL::Inline::BufferViewBufferOp.
struct TensorImportBufferViewOpPattern
    : public OpConversionPattern<IREE::Stream::TensorImportOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Stream::TensorImportOp importOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto originalSourceType = importOp.getSource().getType();
    const bool isSupportedSource =
        isa<IREE::HAL::BufferViewType>(originalSourceType) ||
        isa<TensorType>(originalSourceType);
    if (!isSupportedSource) {
      return failure();
    }
    auto halBufferType = rewriter.getType<IREE::HAL::BufferType>();
    rewriter.replaceOpWithNewOp<IREE::HAL::Inline::BufferViewBufferOp>(
        importOp, halBufferType, adaptor.getSource());
    return success();
  }
};
// Lowers stream.tensor.export when the exported result is a !hal.buffer: the
// converted source is forwarded unchanged. Other result types are left for
// other patterns.
struct TensorExportBufferOpPattern
    : public OpConversionPattern<IREE::Stream::TensorExportOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Stream::TensorExportOp exportOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const bool resultIsBuffer =
        isa<IREE::HAL::BufferType>(exportOp.getResult().getType());
    if (!resultIsBuffer) {
      return failure();
    }
    rewriter.replaceOp(exportOp, adaptor.getSource());
    return success();
  }
};
// Lowers stream.tensor.export when the exported result is a buffer view or a
// tensor: a buffer view is created over the whole converted source using the
// element type, encoding, and shape of the source encoding.
struct TensorExportBufferViewOpPattern
    : public OpConversionPattern<IREE::Stream::TensorExportOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Stream::TensorExportOp exportOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto originalResultType = exportOp.getResult().getType();
    const bool isSupportedResult =
        isa<IREE::HAL::BufferViewType>(originalResultType) ||
        isa<TensorType>(originalResultType);
    if (!isSupportedResult) {
      return failure();
    }

    auto loc = exportOp.getLoc();
    auto tensorType = cast<RankedTensorType>(adaptor.getSourceEncoding());
    auto dynamicDims = adaptor.getSourceEncodingDims();

    // NOTE: we should have verified supported encodings/types at entry into the
    // HAL pipeline.
    auto encodingType = IREE::HAL::EncodingTypeOp::create(
        rewriter, loc, tensorType.getEncoding());
    auto elementType = IREE::HAL::ElementTypeOp::create(
        rewriter, loc, tensorType.getElementType());

    // Build the full shape: static dimensions become index constants while
    // dynamic dimensions are consumed in order from |dynamicDims|.
    SmallVector<Value> dims;
    unsigned nextDynamicDim = 0;
    const int64_t rank = tensorType.getRank();
    for (int64_t dim = 0; dim < rank; ++dim) {
      if (!tensorType.isDynamicDim(dim)) {
        dims.push_back(arith::ConstantIndexOp::create(
            rewriter, loc, tensorType.getDimSize(dim)));
        continue;
      }
      dims.push_back(dynamicDims[nextDynamicDim++]);
    }

    // The view always starts at offset 0 and spans the source size.
    auto zeroOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
    rewriter.replaceOpWithNewOp<IREE::HAL::Inline::BufferViewCreateOp>(
        exportOp, adaptor.getSource(), zeroOffset, adaptor.getSourceSize(),
        elementType, encodingType, dims);
    return success();
  }
};
// Lowers stream.tensor.trace to IREE::HAL::Inline::BufferViewTraceOp.
//
// Each traced resource is wrapped as a HAL buffer and then exported as a
// buffer view via a new stream.tensor.export op; those export ops are expected
// to be converted by the export patterns in this file.
struct TensorTraceOpPattern
    : public OpConversionPattern<IREE::Stream::TensorTraceOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Stream::TensorTraceOp traceOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = traceOp.getLoc();
    auto halBufferType = rewriter.getType<IREE::HAL::BufferType>();
    auto halBufferViewType = rewriter.getType<IREE::HAL::BufferViewType>();
    auto zeroOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);

    // Dynamic dims for all resources are stored flattened; each resource
    // consumes as many leading entries as its encoding has dynamic dims.
    auto remainingDims = adaptor.getResourceEncodingDims();
    SmallVector<Value> tracedViews;
    for (auto [buffer, bufferSize, encodingAttr] : llvm::zip_equal(
             adaptor.getResources(), adaptor.getResourceSizes(),
             adaptor.getResourceEncodings().getAsRange<TypeAttr>())) {
      Value wrappedBuffer = IREE::HAL::Inline::BufferWrapOp::create(
          rewriter, loc, halBufferType, buffer,
          /*offset=*/zeroOffset,
          /*length=*/bufferSize);
      int64_t dimCount =
          cast<ShapedType>(encodingAttr.getValue()).getNumDynamicDims();
      tracedViews.push_back(IREE::Stream::TensorExportOp::create(
          rewriter, loc, halBufferViewType, wrappedBuffer, encodingAttr,
          remainingDims.take_front(dimCount), bufferSize,
          /*affinity=*/IREE::Stream::AffinityAttr{}));
      remainingDims = remainingDims.drop_front(dimCount);
    }

    rewriter.replaceOpWithNewOp<IREE::HAL::Inline::BufferViewTraceOp>(
        traceOp, traceOp.getKeyAttr(), tracedViews);
    return success();
  }
};
// Removes stream.cmd.flush ops; the conversion emits nothing in their place.
struct CmdFlushOpPattern
    : public OpConversionPattern<IREE::Stream::CmdFlushOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Stream::CmdFlushOp flushOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.eraseOp(flushOp);
    return success();
  }
};
// Removes stream.cmd.invalidate ops; the conversion emits nothing in their
// place.
struct CmdInvalidateOpPattern
    : public OpConversionPattern<IREE::Stream::CmdInvalidateOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Stream::CmdInvalidateOp invalidateOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.eraseOp(invalidateOp);
    return success();
  }
};
// Removes stream.cmd.discard ops; the conversion emits nothing in their place.
struct CmdDiscardOpPattern
    : public OpConversionPattern<IREE::Stream::CmdDiscardOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Stream::CmdDiscardOp discardOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.eraseOp(discardOp);
    return success();
  }
};
// Lowers stream.cmd.fill to util.buffer.fill on the target resource's storage.
struct CmdFillOpPattern : public OpConversionPattern<IREE::Stream::CmdFillOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Stream::CmdFillOp fillOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto targetStorage =
        getResourceStorage(fillOp.getLoc(), adaptor.getTarget(),
                           adaptor.getTargetSize(), rewriter);
    rewriter.replaceOpWithNewOp<IREE::Util::BufferFillOp>(
        fillOp, adaptor.getValue(), targetStorage.buffer,
        targetStorage.bufferSize, adaptor.getTargetOffset(),
        adaptor.getTargetLength());
    return success();
  }
};
// Lowers stream.cmd.copy to util.buffer.copy between the storage of the source
// and target resources.
struct CmdCopyOpPattern : public OpConversionPattern<IREE::Stream::CmdCopyOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Stream::CmdCopyOp copyOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = copyOp.getLoc();
    // Source storage is resolved before target storage.
    auto from = getResourceStorage(loc, adaptor.getSource(),
                                   adaptor.getSourceSize(), rewriter);
    auto to = getResourceStorage(loc, adaptor.getTarget(),
                                 adaptor.getTargetSize(), rewriter);
    rewriter.replaceOpWithNewOp<IREE::Util::BufferCopyOp>(
        copyOp, from.buffer, from.bufferSize, adaptor.getSourceOffset(),
        to.buffer, to.bufferSize, adaptor.getTargetOffset(),
        adaptor.getLength());
    return success();
  }
};
// Lowers stream.cmd.dispatch to a util.call of the function named by the
// op's "hal_inline.target" attribute.
//
// Call arguments are laid out as: workload values, uniform operands, one
// storage buffer per binding, one offset per binding, and then the binding
// lengths.
struct CmdDispatchOpPattern
    : public OpConversionPattern<IREE::Stream::CmdDispatchOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Stream::CmdDispatchOp dispatchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto targetRef =
        dispatchOp->getAttrOfType<SymbolRefAttr>("hal_inline.target");
    if (!targetRef) {
      return rewriter.notifyMatchFailure(
          dispatchOp, "missing hal_inline.target annotation from the "
                      "--iree-hal-inline-executables pass");
    }
    auto loc = dispatchOp.getLoc();

    // Resolve the storage buffer for each binding; offsets are passed through
    // unchanged.
    SmallVector<Value> buffers;
    SmallVector<Value> offsets;
    for (auto [binding, bindingSize, bindingOffset] :
         llvm::zip_equal(adaptor.getResources(), adaptor.getResourceSizes(),
                         adaptor.getResourceOffsets())) {
      auto bindingStorage =
          getResourceStorage(loc, binding, bindingSize, rewriter);
      buffers.push_back(bindingStorage.buffer);
      offsets.push_back(bindingOffset);
    }

    // The InlineExecutables pass has already done the hard work here; all that
    // remains is calling the annotated target function with every
    // operand/binding.
    SmallVector<Value> callArgs;
    llvm::append_range(callArgs, adaptor.getWorkload());
    llvm::append_range(callArgs, adaptor.getUniformOperands());
    llvm::append_range(callArgs, buffers);
    llvm::append_range(callArgs, offsets);
    llvm::append_range(callArgs, adaptor.getResourceLengths());
    rewriter.replaceOpWithNewOp<IREE::Util::CallOp>(
        dispatchOp, TypeRange{}, targetRef.getLeafReference(), callArgs,
        /*tied_operands=*/ArrayAttr{},
        /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr);
    return success();
  }
};
// Lowers stream.cmd.func to a util.func with converted argument and result
// types. Visibility, argument/result attributes, and dialect attributes are
// carried over from the original function.
struct CmdFuncOpPattern : public OpConversionPattern<IREE::Stream::CmdFuncOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Stream::CmdFuncOp funcOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto *converter = getTypeConverter();

    SmallVector<Type> convertedArgTypes;
    if (failed(converter->convertTypes(funcOp.getArgumentTypes(),
                                       convertedArgTypes))) {
      return rewriter.notifyMatchFailure(funcOp, "failed to convert types");
    }
    SmallVector<Type> convertedResultTypes;
    if (failed(converter->convertTypes(funcOp.getResultTypes(),
                                       convertedResultTypes))) {
      return rewriter.notifyMatchFailure(funcOp, "failed to convert types");
    }

    auto convertedFuncType =
        rewriter.getFunctionType(convertedArgTypes, convertedResultTypes);
    auto utilFuncOp = rewriter.replaceOpWithNewOp<IREE::Util::FuncOp>(
        funcOp, funcOp.getName(), convertedFuncType,
        /*tied_operands=*/ArrayAttr{}, funcOp.getSymVisibilityAttr(),
        funcOp.getAllArgAttrs(), funcOp.getAllResultAttrs(),
        IREE::Util::InliningPolicyAttrInterface{});
    utilFuncOp->setDialectAttrs(funcOp->getDialectAttrs());
    return success();
  }
};
// Lowers stream.cmd.call to util.call.
//
// Each operand that was originally a stream resource is expanded into three
// call operands (storage buffer, offset, length); all other operands are
// passed through as converted. Result types go through the type converter.
struct CmdCallOpPattern : public OpConversionPattern<IREE::Stream::CmdCallOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Stream::CmdCallOp callOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = callOp.getLoc();

    SmallVector<Value> callOperands;
    // Sizes/offsets/lengths only exist for resource operands so they are
    // indexed by a separate counter that advances per resource.
    size_t nextResource = 0;
    for (auto [original, converted] : llvm::zip_equal(
             callOp.getResourceOperands(), adaptor.getResourceOperands())) {
      if (!isa<IREE::Stream::ResourceType>(original.getType())) {
        // Primitive/custom type.
        callOperands.push_back(converted);
        continue;
      }
      // Resource type, add offset/length.
      auto size = adaptor.getResourceOperandSizes()[nextResource];
      auto resourceStorage = getResourceStorage(loc, converted, size, rewriter);
      callOperands.push_back(resourceStorage.buffer);
      callOperands.push_back(adaptor.getResourceOperandOffsets()[nextResource]);
      callOperands.push_back(adaptor.getResourceOperandLengths()[nextResource]);
      ++nextResource;
    }

    SmallVector<Type> convertedResultTypes;
    for (auto result : callOp.getResults()) {
      SmallVector<Type> expandedTypes;
      if (failed(getTypeConverter()->convertType(result.getType(),
                                                 expandedTypes))) {
        return rewriter.notifyMatchFailure(callOp.getLoc(),
                                           "unconvertable result type");
      }
      llvm::append_range(convertedResultTypes, expandedTypes);
    }

    rewriter.replaceOpWithNewOp<IREE::Util::CallOp>(
        callOp, convertedResultTypes, callOp.getCallee(), callOperands,
        /*tied_operands=*/ArrayAttr{},
        /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr);
    return success();
  }
};
// Lowers stream.cmd.execute by inlining its body in place of the op, with the
// converted resource operands substituted for the block arguments. The result
// timepoint is replaced with an i64 zero constant (already resolved).
struct CmdExecuteOpPattern
    : public OpConversionPattern<IREE::Stream::CmdExecuteOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Stream::CmdExecuteOp executeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Splice the execution region's entry block right before the op.
    auto *bodyBlock = &executeOp.getBody().front();
    rewriter.inlineBlockBefore(bodyBlock, executeOp,
                               adaptor.getResourceOperands());

    // Everything ran inline so the timepoint resolves immediately.
    auto immediateTimepoint =
        arith::ConstantIntOp::create(rewriter, executeOp.getLoc(), 0, 64)
            .getResult();
    rewriter.replaceOp(executeOp, immediateTimepoint);
    return success();
  }
};
// Lowers stream.cmd.serial by inlining its body in place of the op and then
// erasing the op.
struct CmdSerialOpPattern
    : public OpConversionPattern<IREE::Stream::CmdSerialOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Stream::CmdSerialOp serialOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Splice the serial region's entry block right before the op.
    auto *bodyBlock = &serialOp.getBody().front();
    rewriter.inlineBlockBefore(bodyBlock, serialOp);
    rewriter.eraseOp(serialOp);
    return success();
  }
};
// Lowers stream.cmd.concurrent by inlining its body in place of the op and
// then erasing the op.
struct CmdConcurrentOpPattern
    : public OpConversionPattern<IREE::Stream::CmdConcurrentOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Stream::CmdConcurrentOp concurrentOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Splice the concurrent region's entry block right before the op.
    auto *bodyBlock = &concurrentOp.getBody().front();
    rewriter.inlineBlockBefore(bodyBlock, concurrentOp);
    rewriter.eraseOp(concurrentOp);
    return success();
  }
};
// Rewrites util.global ops whose initial value is a stream timepoint attribute
// so that the initial value becomes an i64 zero attribute instead.
//
// It is unfortunate that this needs a dedicated pattern, but there is no
// attribute converter equivalent available that would allow handling it
// generically.
struct GlobalTimepointConversionPattern
    : public OpConversionPattern<IREE::Util::GlobalOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Util::GlobalOp globalOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto initialValue = globalOp.getInitialValue();
    // Only globals that carry a timepoint initializer are handled here.
    if (!initialValue.has_value() ||
        !isa<IREE::Stream::TimepointAttr>(*initialValue)) {
      return failure();
    }
    rewriter.modifyOpInPlace(globalOp, [&]() {
      globalOp.setInitialValueAttr(rewriter.getI64IntegerAttr(0));
    });
    return success();
  }
};
// Replaces stream.timepoint.immediate with a 64-bit integer constant 0.
struct TimepointImmediateOpPattern
    : public OpConversionPattern<IREE::Stream::TimepointImmediateOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Stream::TimepointImmediateOp immediateOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Arguments are (value, bit width).
    rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(immediateOp, /*value=*/0,
                                                      /*width=*/64);
    return success();
  }
};
// Always fails to match stream.timepoint.import, reporting that timepoints
// cannot cross the ABI when executing inline.
struct TimepointImportOpPattern
    : public OpConversionPattern<IREE::Stream::TimepointImportOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Stream::TimepointImportOp timepointImportOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    return rewriter.notifyMatchFailure(
        timepointImportOp,
        "timepoints are not supported across the ABI with inline execution");
  }
};
// Always fails to match stream.timepoint.export, reporting that timepoints
// cannot cross the ABI when executing inline.
struct TimepointExportOpPattern
    : public OpConversionPattern<IREE::Stream::TimepointExportOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Stream::TimepointExportOp timepointExportOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    return rewriter.notifyMatchFailure(
        timepointExportOp,
        "timepoints are not supported across the ABI with inline execution");
  }
};
// Always fails to match stream.timepoint.chain_external, reporting that
// timepoints cannot cross the ABI when executing inline.
struct TimepointChainExternalOpPattern
    : public OpConversionPattern<IREE::Stream::TimepointChainExternalOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Stream::TimepointChainExternalOp chainExternalOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    return rewriter.notifyMatchFailure(
        chainExternalOp,
        "timepoints are not supported across the ABI with inline execution");
  }
};
// Replaces stream.timepoint.join with a 64-bit integer constant 0; the joined
// timepoints are not consulted.
struct TimepointJoinOpPattern
    : public OpConversionPattern<IREE::Stream::TimepointJoinOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Stream::TimepointJoinOp joinOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Arguments are (value, bit width).
    rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(joinOp, /*value=*/0,
                                                      /*width=*/64);
    return success();
  }
};
// Lowers stream.timepoint.barrier by forwarding the converted resource and
// replacing the result timepoint with a 64-bit integer constant 0.
struct TimepointBarrierOpPattern
    : public OpConversionPattern<IREE::Stream::TimepointBarrierOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Stream::TimepointBarrierOp barrierOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto immediateTimepoint =
        arith::ConstantIntOp::create(rewriter, barrierOp.getLoc(), 0, 64);
    rewriter.replaceOp(barrierOp, {
                                      adaptor.getResource(),
                                      immediateTimepoint,
                                  });
    return success();
  }
};
// Lowers stream.timepoint.await by forwarding the converted resource operands
// as the op's results; no wait is emitted.
struct TimepointAwaitOpPattern
    : public OpConversionPattern<IREE::Stream::TimepointAwaitOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Stream::TimepointAwaitOp awaitOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto awaitedResources = adaptor.getResourceOperands();
    rewriter.replaceOp(awaitOp, awaitedResources);
    return success();
  }
};
// Removes stream.yield ops; the conversion emits nothing in their place.
struct ElideYieldOpPattern : public OpConversionPattern<IREE::Stream::YieldOp> {
  using Base::Base;
  LogicalResult
  matchAndRewrite(IREE::Stream::YieldOp terminatorOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.eraseOp(terminatorOp);
    return success();
  }
};
} // namespace
// Registers the stream -> HAL inline type conversions on |typeConverter| and
// appends all conversion patterns defined in this file to |patterns|.
// |conversionTarget| is not modified here.
void populateStreamToHALInlinePatterns(MLIRContext *context,
                                       ConversionTarget &conversionTarget,
                                       TypeConverter &typeConverter,
                                       RewritePatternSet &patterns) {
  // Resources carry no shape/encoding/etc and are just buffers. External
  // resources (those crossing the ABI boundary) use !hal.buffer while all
  // other lifetimes use the host buffer type.
  typeConverter.addConversion([context](IREE::Stream::ResourceType type,
                                        SmallVectorImpl<Type> &results) {
    const bool isExternal =
        type.getLifetime() == IREE::Stream::Lifetime::External;
    results.push_back(isExternal ? Type(IREE::HAL::BufferType::get(context))
                                 : Type(IREE::Util::BufferType::get(context)));
    return success();
  });

  // Today files all originate from host buffers and are treated identically to
  // them. Note that file initialization from buffers may require subviews.
  typeConverter.addConversion(
      [context](IREE::Stream::FileType type, SmallVectorImpl<Type> &results) {
        results.push_back(IREE::Util::BufferType::get(context));
        return success();
      });

  // Timepoints are represented as plain 64-bit integers.
  typeConverter.addConversion([context](IREE::Stream::TimepointType type,
                                        SmallVectorImpl<Type> &results) {
    results.push_back(IntegerType::get(context, 64));
    return success();
  });

  // stream.resource.*
  patterns.insert<ResourceAllocOpPattern, ResourceAllocaOpPattern,
                  ResourceDeallocaOpPattern, ResourceRetainOpPattern,
                  ResourceReleaseOpPattern, ResourceIsTerminalOpPattern,
                  ResourceSizeOpPattern, ResourceTryMapOpPattern,
                  ResourceLoadOpPattern, ResourceStoreOpPattern,
                  ResourceSubviewOpPattern>(typeConverter, context);

  // stream.file.*
  patterns.insert<FileConstantOpPattern, FileReadOpPattern, FileWriteOpPattern>(
      typeConverter, context);

  // stream.tensor.*
  patterns.insert<TensorImportBufferOpPattern, TensorImportBufferViewOpPattern,
                  TensorExportBufferOpPattern, TensorExportBufferViewOpPattern,
                  TensorTraceOpPattern>(typeConverter, context);

  // stream.cmd.*
  patterns
      .insert<CmdFlushOpPattern, CmdInvalidateOpPattern, CmdDiscardOpPattern,
              CmdFillOpPattern, CmdCopyOpPattern, CmdDispatchOpPattern,
              CmdFuncOpPattern, CmdCallOpPattern, CmdExecuteOpPattern,
              CmdSerialOpPattern, CmdConcurrentOpPattern>(typeConverter,
                                                          context);

  // Globals initialized with timepoint attributes.
  patterns.insert<GlobalTimepointConversionPattern>(typeConverter, context);

  // stream.timepoint.*
  patterns.insert<TimepointImmediateOpPattern, TimepointImportOpPattern,
                  TimepointExportOpPattern, TimepointChainExternalOpPattern,
                  TimepointJoinOpPattern, TimepointBarrierOpPattern,
                  TimepointAwaitOpPattern>(typeConverter, context);

  // Region terminators.
  patterns.insert<ElideYieldOpPattern>(typeConverter, context);
}
} // namespace mlir::iree_compiler