// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "compiler/plugins/input/Torch/InputConversion/Passes.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
namespace Torch = mlir::torch::Torch;
namespace TorchConversion = mlir::torch::TorchConversion;
namespace mlir::iree_compiler::TorchInput {
#define GEN_PASS_DEF_FUNCCONVERSIONPASS
#include "compiler/plugins/input/Torch/InputConversion/Passes.h.inc"
namespace {
//===----------------------------------------------------------------------===//
// Overall Approach
// ----------------
// This pass converts from the "torch" programming model to the "iree"
// programming model by rewriting all functions and calls to operate on native
// IREE types. In the process, synchronization is added as appropriate for any
// mutable and immutable torch-level arguments.
//
// Currently, the result of this pass is that every torch func is rewritten
// to implement IREE's "coarse fences" ABI, in which a (wait, signal) fence
// pair is appended to every function's argument list. Since torch functions
// are single-exit, in practice this involves:
// * Adding preambles to convert function arguments to `tensor` (and native
// torch types via `torch_c`), adding coarse synchronization on import
// (presently for all buffer arguments but in the future could only be
// those which are not tied to fine grained fences).
// * Adding a postamble with a synchronization barrier on any produced
// or mutated tensors and appropriate exports/in-place tying to buffers.
// * Generation of a synchronous wrapper function with the original name
// (the async function is named with an `$async` suffix) which internally
// sets up/waits on fences while delegating to the async function.
//
// Immutable tensor types
// ----------------------
//
// Immutable tensor types are mapped to a buffer_view and subject to
// `hal.tensor.import` on use. On return, they will be placed behind a
// synchronization barrier and exported.
//
// Mutable types
// -------------
// Here we rely on the characteristic that at the torch level, conversion to
// and from the value domain is only legal at certain well defined points in
// the program (currently at graph edges but potentially in the future at
// various control flow points). These conversions are modeled by:
// * `torch.copy.to_vtensor`: Copy from a mutable tensor (torch.tensor) to
// an immutable value (torch.vtensor).
// * `torch.copy.to_tensor`: Allocates a new mutable tensor and initializes it
// with the value of the given immutable tensor. Presently unused.
// * `torch.overwrite.tensor.contents`: Updates the contents of a mutable
// tensor from a given immutable tensor.
//
// Note that when importing from Torch, these ops cannot be introduced at
// arbitrary points: they are only created as part of structural conversions.
// We therefore rely on this invariant and treat any usage outside of it as
// an invalid program.
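//
// Example (illustrative)
// ----------------------
// As a rough sketch, with the function name, shapes, and SSA names below
// purely illustrative and the exact printed IR syntax elided, an input such
// as:
//
//   func.func @predict(%arg0: !torch.vtensor<[4],f32>)
//       -> !torch.vtensor<[4],f32>
//
// is rewritten to an async function with the coarse fence pair appended:
//
//   util.func public @predict$async(%arg0: !hal.buffer_view,
//       %wait: !hal.fence, %signal: !hal.fence) -> !hal.buffer_view {
//     %t = hal.tensor.import wait(%wait) => %arg0 ... -> tensor<4xf32>
//     ... original body operating on tensor<4xf32> ...
//     %b = hal.tensor.barrier join(%result ...) => %signal
//     %v = hal.tensor.export %b ... -> !hal.buffer_view
//     util.return %v
//   }
//
// plus a synchronous wrapper @predict that creates the fences, calls
// @predict$async, and blocks on the signal fence before returning.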
//===----------------------------------------------------------------------===//
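// Returns the coarse (wait, signal) fence pair of the async function
// enclosing |op|, which by construction are its two trailing arguments.
// Returns nullopt if there is no enclosing util.func.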
std::optional<std::pair<Value, Value>>
getEnclosingWaitSignalFences(Operation *op) {
auto parentFuncOp = dyn_cast<IREE::Util::FuncOp>(op);
if (!parentFuncOp) {
    parentFuncOp = op->getParentOfType<IREE::Util::FuncOp>();
if (!parentFuncOp)
return {};
}
Block *entryBlock = &parentFuncOp.front();
auto numArguments = entryBlock->getNumArguments();
Value coarseWaitFence = entryBlock->getArgument(numArguments - 2);
Value coarseSignalFence = entryBlock->getArgument(numArguments - 1);
return std::make_pair(coarseWaitFence, coarseSignalFence);
}
std::optional<std::pair<Value, Value>>
getEnclosingWaitSignalFences(Value value) {
return getEnclosingWaitSignalFences(value.getParentRegion()->getParentOp());
}
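// Returns |possibleTorchTensor| as a builtin tensor value: builtin tensors
// pass through unchanged, torch_c casts from builtin tensors are unwrapped,
// and any remaining torch.vtensor is converted, normalizing integer element
// types to their signless form.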
Value convertToBuiltinTensor(OpBuilder &builder, Value possibleTorchTensor) {
Type ty = possibleTorchTensor.getType();
if (isa<TensorType>(ty))
return possibleTorchTensor;
if (auto defining = dyn_cast_or_null<TorchConversion::FromBuiltinTensorOp>(
possibleTorchTensor.getDefiningOp())) {
return defining.getOperand();
}
Torch::ValueTensorType vtensorType = cast<Torch::ValueTensorType>(ty);
TensorType builtinTy = vtensorType.toBuiltinTensor();
if (auto intTy = dyn_cast<IntegerType>(builtinTy.getElementType())) {
builtinTy =
builtinTy.clone(builder.getIntegerType(intTy.getIntOrFloatBitWidth()));
}
return builder.create<TorchConversion::ToBuiltinTensorOp>(
possibleTorchTensor.getLoc(), builtinTy, possibleTorchTensor);
}
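// Classifies how a torch-level input or result type is materialized by the
// conversion.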
enum class TypeDisposition {
IMMUTABLE_TENSOR,
MUTABLE_TENSOR,
TORCH_PRIMITIVE,
PASSTHROUGH,
FENCE,
};
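// Metadata for one value captured by the coarse barrier: the mutable
// storage it aliases (if any), its original torch type, and the index of
// the return operand it feeds (or -1 if it is not returned).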
struct BarrierResult {
BlockArgument storage;
Type torchType;
int returnIndex = -1;
};
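// State accumulated while converting one torch func into an async
// util.func; postProcess() uses it to materialize argument imports, the
// coarse barrier, and result exports after the signature has been
// rewritten.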
struct ConvertedAsyncFunctionInfo {
IREE::Util::FuncOp funcOp;
SmallVector<IREE::Util::ReturnOp> returnOps;
SmallVector<DictionaryAttr> torchArgAttrs;
SmallVector<DictionaryAttr> torchResultAttrs;
SmallVector<Type> torchInputTypes;
SmallVector<Type> torchResultTypes;
SmallVector<TypeDisposition> inputDispositions;
SmallVector<TypeDisposition> resultDispositions;
// Post processing state.
// Values that must be captured in the coarse barrier.
SmallVector<Value> barrierInputs;
  // Metadata per barrier input: storage, torchType, returnIndex (or -1).
SmallVector<BarrierResult> barrierResultMeta;
LogicalResult postProcess();
LogicalResult convertImmutableTensorArg(BlockArgument argValue,
Type torchType, OpBuilder &builder);
LogicalResult convertMutableTensorArg(BlockArgument argValue, Type torchType,
OpBuilder &builder);
void addBarrierInput(Value inputTensor, BlockArgument storage, Type torchType,
int returnIndex) {
barrierInputs.push_back(inputTensor);
barrierResultMeta.emplace_back(BarrierResult{
storage,
torchType,
returnIndex,
});
}
Attribute getTorchArgAttr(BlockArgument argValue, StringRef attrName) {
return torchArgAttrs.empty()
? Attribute{}
: torchArgAttrs[argValue.getArgNumber()].get(attrName);
}
Attribute getTorchResultAttr(int returnIndex, StringRef attrName) {
return torchResultAttrs.empty()
? Attribute{}
: torchResultAttrs[returnIndex].get(attrName);
}
};
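// Materializes the preamble (argument imports and primitive conversions)
// and postamble (coarse barrier, exports, and return rewriting) of the
// converted async function.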
LogicalResult ConvertedAsyncFunctionInfo::postProcess() {
if (funcOp.isExternal())
return success();
if (returnOps.size() != 1) {
// Multi-exit/CFG could be supported but requires more complicated dominance
// analysis with respect to where the exit happens relative to mutated
// buffers.
return emitError(funcOp.getLoc())
<< "currently only single exit torch funcs are supported";
}
Block *entryBlock = &funcOp.getBlocks().front();
// Materialize argument conversions.
OpBuilder preambleBuilder = OpBuilder::atBlockBegin(entryBlock);
auto entryArgs = entryBlock->getArguments();
for (auto [disp, argValue, torchType] :
llvm::zip_equal(inputDispositions, entryArgs, torchInputTypes)) {
switch (disp) {
case TypeDisposition::IMMUTABLE_TENSOR: {
if (failed(
convertImmutableTensorArg(argValue, torchType, preambleBuilder)))
return failure();
break;
}
case TypeDisposition::MUTABLE_TENSOR: {
if (failed(convertMutableTensorArg(argValue, torchType, preambleBuilder)))
return failure();
break;
}
case TypeDisposition::TORCH_PRIMITIVE: {
Location loc = argValue.getLoc();
Operation *convertUser = nullptr;
Value convertResult;
if (isa<Torch::BoolType>(torchType)) {
convertUser =
preambleBuilder.create<TorchConversion::FromI1Op>(loc, argValue);
convertResult = convertUser->getResult(0);
} else if (isa<Torch::FloatType>(torchType)) {
convertUser =
preambleBuilder.create<TorchConversion::FromF64Op>(loc, argValue);
convertResult = convertUser->getResult(0);
} else if (isa<Torch::IntType>(torchType)) {
convertUser =
preambleBuilder.create<TorchConversion::FromI64Op>(loc, argValue);
convertResult = convertUser->getResult(0);
} else {
emitError(loc) << "unhandled torch primitive materialization: "
<< torchType;
return failure();
}
argValue.replaceAllUsesExcept(convertResult, convertUser);
break;
}
case TypeDisposition::PASSTHROUGH:
// Do nothing.
break;
case TypeDisposition::FENCE:
// Do nothing.
break;
}
}
// Materialize synchronization postamble and conversions.
IREE::Util::ReturnOp returnOp = returnOps.front();
SmallVector<Value> newReturnOperands;
OpBuilder postambleBuilder(returnOp);
for (auto [disp, returnValue, torchType] : llvm::zip_equal(
resultDispositions, returnOp.getOperands(), torchResultTypes)) {
size_t returnIndex = newReturnOperands.size();
newReturnOperands.emplace_back(returnValue);
switch (disp) {
case TypeDisposition::IMMUTABLE_TENSOR: {
bool needsBarrier = true;
if (auto blockArg = dyn_cast<BlockArgument>(returnValue)) {
// Trivial return of input. Just pass it through.
needsBarrier = blockArg.getOwner() != entryBlock;
}
if (needsBarrier) {
Value source = convertToBuiltinTensor(postambleBuilder, returnValue);
addBarrierInput(source, /*storage=*/BlockArgument{}, torchType,
returnIndex);
}
break;
}
case TypeDisposition::TORCH_PRIMITIVE: {
Location loc = returnValue.getLoc();
if (isa<Torch::BoolType>(torchType)) {
newReturnOperands.back() =
postambleBuilder.create<TorchConversion::ToI1Op>(loc, returnValue);
} else if (isa<Torch::FloatType>(torchType)) {
newReturnOperands.back() =
postambleBuilder.create<TorchConversion::ToF64Op>(loc, returnValue);
} else if (isa<Torch::IntType>(torchType)) {
newReturnOperands.back() =
postambleBuilder.create<TorchConversion::ToI64Op>(loc, returnValue);
} else if (isa<Torch::GeneratorType>(torchType)) {
newReturnOperands.back() =
postambleBuilder.create<TorchConversion::GeneratorToI64Op>(
loc, returnValue);
} else {
emitError(loc) << "unhandled torch primitive materialization: "
<< torchType;
return failure();
}
break;
}
default: {
      // Non-tensor or already-converted value. Just preserve it.
}
}
}
// Emit the barrier and exports.
// If any of the exports are in-place we need to alias their storage to the
// provided buffers.
Value coarseSignalFence =
entryBlock->getArgument(entryBlock->getNumArguments() - 1);
if (barrierInputs.empty()) {
postambleBuilder.create<IREE::HAL::FenceSignalOp>(funcOp.getLoc(),
coarseSignalFence);
} else {
SmallVector<Value> aliasedResults;
for (auto [barrierInput, meta] :
llvm::zip_equal(barrierInputs, barrierResultMeta)) {
if (meta.storage) {
        // Use the wait fence to indicate when the storage is available for
        // mutation: no writes may be made to the storage until the fence
        // signals that it is safe to do so.
auto storageAffinityAttr =
getTorchArgAttr(meta.storage, "iree.abi.affinity");
auto waitSignalFences = getEnclosingWaitSignalFences(meta.storage);
assert(waitSignalFences && "async function missing fences");
Value waitFence = waitSignalFences->first;
auto barrierInputDims = IREE::Util::buildDynamicDimsForValue(
barrierInput.getLoc(), barrierInput, postambleBuilder);
aliasedResults.push_back(
postambleBuilder.create<IREE::HAL::TensorAliasOp>(
barrierInput.getLoc(), barrierInput.getType(), barrierInput,
barrierInputDims, meta.storage, waitFence,
storageAffinityAttr));
} else {
aliasedResults.push_back(barrierInput);
}
}
auto barrierOp = postambleBuilder.create<IREE::HAL::TensorBarrierOp>(
funcOp.getLoc(), aliasedResults, coarseSignalFence);
for (auto [barrierResult, meta] :
llvm::zip_equal(barrierOp.getResults(), barrierResultMeta)) {
Attribute exportAffinityAttr;
if (meta.storage) {
exportAffinityAttr = getTorchArgAttr(meta.storage, "iree.abi.affinity");
} else if (meta.returnIndex >= 0) {
exportAffinityAttr =
getTorchResultAttr(meta.returnIndex, "iree.abi.affinity");
}
Value exportedValue = postambleBuilder.create<IREE::HAL::TensorExportOp>(
funcOp.getLoc(),
postambleBuilder.getType<IREE::HAL::BufferViewType>(), barrierResult,
TypeAttr::get(barrierResult.getType()), /*name=*/nullptr,
exportAffinityAttr);
if (meta.returnIndex >= 0) {
newReturnOperands[meta.returnIndex] = exportedValue;
}
}
}
// New return operands are all collected.
returnOp->setOperands(newReturnOperands);
return success();
}
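// Snapshot of the uses of a value taken at construction time so that they
// can later be redirected to a new value, excluding any uses created in
// between (such as the op that produces the replacement).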
class OriginalUses {
public:
OriginalUses(Value value) {
for (auto &use : value.getUses()) {
originalUses.push_back(&use);
}
}
void assign(Value newValue) {
for (OpOperand *originalUse : originalUses) {
originalUse->assign(newValue);
}
}
private:
SmallVector<OpOperand *> originalUses;
};
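// Imports an immutable tensor argument (a hal.buffer_view at the ABI
// boundary) as a builtin tensor via hal.tensor.import, gated on the coarse
// wait fence, and redirects all pre-existing uses to the imported value.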
LogicalResult ConvertedAsyncFunctionInfo::convertImmutableTensorArg(
BlockArgument argValue, Type torchType, OpBuilder &builder) {
Location loc = argValue.getLoc();
// If the arg is just directly returned, then don't do anything special with
// it.
bool hasNonTrivialUse = false;
for (auto *userOp : argValue.getUsers()) {
if (isa<IREE::Util::ReturnOp>(userOp))
continue;
hasNonTrivialUse = true;
}
if (!hasNonTrivialUse)
return success();
// Remember original uses so we can redirect them.
OriginalUses originalUses(argValue);
// The type can either be a builtin TensorType or a Torch::ValueTensorType.
TensorType builtinTensorType;
if (auto tType = dyn_cast<TensorType>(torchType)) {
builtinTensorType = tType;
} else if (auto vtType = dyn_cast<Torch::ValueTensorType>(torchType)) {
builtinTensorType = vtType.toBuiltinTensor();
if (auto intTy =
dyn_cast<IntegerType>(builtinTensorType.getElementType())) {
builtinTensorType = builtinTensorType.clone(
builder.getIntegerType(intTy.getIntOrFloatBitWidth()));
}
} else {
return emitError(loc) << "unsupported immutable tensor argument: "
<< torchType;
}
// Propagate explicit affinities to the read.
auto affinityAttr = getTorchArgAttr(argValue, "iree.abi.affinity");
auto waitSignalFences = getEnclosingWaitSignalFences(argValue);
assert(waitSignalFences && "async function missing fences");
Value waitFence = waitSignalFences->first;
Value importedTensor = builder.create<IREE::HAL::TensorImportOp>(
loc, builtinTensorType, argValue, TypeAttr::get(builtinTensorType),
waitFence,
/*name=*/nullptr, affinityAttr);
if (builtinTensorType != torchType) {
importedTensor = builder.create<TorchConversion::FromBuiltinTensorOp>(
loc, torchType, importedTensor);
}
originalUses.assign(importedTensor);
return success();
}
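// Lowers the uses of a mutable tensor argument: each torch.copy.to_vtensor
// becomes a fence-gated hal.tensor.import, and each
// torch.overwrite.tensor.contents is folded into the coarse barrier as an
// in-place update of the argument's storage. Any other use is an error.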
LogicalResult ConvertedAsyncFunctionInfo::convertMutableTensorArg(
BlockArgument argValue, Type torchType, OpBuilder &builder) {
Location loc = argValue.getLoc();
auto fences = getEnclosingWaitSignalFences(argValue);
assert(fences && "could not find async fences on func");
TensorType builtinTensorType;
if (auto t = dyn_cast<TensorType>(torchType)) {
builtinTensorType = t;
} else {
builtinTensorType = cast<Torch::NonValueTensorType>(torchType)
.getWithValueSemantics()
.toBuiltinTensor();
}
// Propagate explicit affinities to the read and write.
auto affinityAttr = getTorchArgAttr(argValue, "iree.abi.affinity");
  // Only a small set of operations may legally use a mutable tensor.
  // Handle each kind of user here.
SmallVector<Operation *> users(argValue.getUsers());
for (auto *userOp : users) {
IRRewriter rewriter(loc.getContext());
rewriter.setInsertionPoint(userOp);
if (auto copyToVtOp = dyn_cast<Torch::CopyToValueTensorOp>(userOp)) {
Value imported = rewriter.create<IREE::HAL::TensorImportOp>(
loc, builtinTensorType, argValue,
/*target_encoding=*/TypeAttr::get(builtinTensorType),
/*wait_fence*/ fences->first,
/*name=*/nullptr, affinityAttr);
rewriter.replaceOpWithNewOp<TorchConversion::FromBuiltinTensorOp>(
userOp, copyToVtOp.getResult().getType(), imported);
} else if (auto overwriteOp =
dyn_cast<Torch::OverwriteTensorContentsOp>(userOp)) {
Value overwriteValue =
convertToBuiltinTensor(rewriter, overwriteOp.getValue());
addBarrierInput(overwriteValue, /*storage=*/argValue, torchType,
/*returnIndex=*/-1);
rewriter.eraseOp(overwriteOp);
} else {
return emitError(userOp->getLoc())
<< "unsupported operation on coarse signaling mutable tensor: "
<< *userOp;
}
}
return success();
}
void retainFunctionAttributes(Operation *srcOp, IREE::Util::FuncOp destOp) {
// Allowlist of function attributes to retain when importing funcs.
constexpr const char *kRetainedAttributes[] = {
"iree.reflection",
};
auto retainedAttributes = ArrayRef<const char *>(
kRetainedAttributes,
sizeof(kRetainedAttributes) / sizeof(kRetainedAttributes[0]));
for (auto retainAttrName : retainedAttributes) {
StringRef attrName(retainAttrName);
Attribute attr = srcOp->getAttr(attrName);
if (attr)
destOp->setAttr(attrName, attr);
}
}
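// Creates a synchronous wrapper with the same signature as |asyncFuncOp|
// minus the trailing (wait, signal) fence pair. Roughly (exact IR syntax
// elided), the generated body is:
//
//   %wait = util.null : !hal.fence
//   %signal = hal.fence.create device(...) flags("None")
//   %results = util.call @fn$async(%args..., %wait, %signal)
//   hal.fence.await ...(%signal)   // block until the async call completes
//   util.return %results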
void createCoarseFencesSyncWrapper(StringRef syncFunctionName,
IREE::Util::FuncOp asyncFuncOp,
IRRewriter &rewriter) {
Location loc = asyncFuncOp.getLoc();
// The coarse fences wrapper has the same signature as the async variant
// but with the last two inputs (wait, signal fence) sliced off.
FunctionType asyncFuncType = asyncFuncOp.getFunctionType();
SmallVector<Type> inputTypes(asyncFuncType.getInputs().begin(),
asyncFuncType.getInputs().end() - 2);
// Create the function.
auto syncFuncType = rewriter.getType<mlir::FunctionType>(
inputTypes, asyncFuncType.getResults());
auto syncFuncOp =
rewriter.create<IREE::Util::FuncOp>(loc, syncFunctionName, syncFuncType,
/*tiedOperandsAttr=*/nullptr);
syncFuncOp.setSymVisibilityAttr(asyncFuncOp.getSymVisibilityAttr());
retainFunctionAttributes(asyncFuncOp, syncFuncOp);
syncFuncOp->setAttr("iree.abi.stub", rewriter.getUnitAttr());
if (auto affinityAttr = asyncFuncOp->getAttr("iree.abi.affinity")) {
syncFuncOp->setAttr("iree.abi.affinity", affinityAttr);
}
Block *entryBlock = syncFuncOp.addEntryBlock();
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToEnd(entryBlock);
// HACK: this is relying on the fact that there's only one HAL device.
// We should instead have a way of creating fences on the device that
// is used to produce the tensors we're wrapping.
//
// TODO(multi-device): emit get with derived ordinal or lookup with attr. We
// could always say device 0 for now but could instead look for an
// iree.abi.affinity/iree.abi.device/etc.
Value timeoutMillis = rewriter.create<arith::ConstantIntOp>(loc, -1, 32);
Value device = IREE::HAL::DeviceType::resolveAny(loc, rewriter);
Value waitFence = rewriter.create<IREE::Util::NullOp>(
loc, rewriter.getType<IREE::HAL::FenceType>());
Value signalFence = rewriter.create<IREE::HAL::FenceCreateOp>(
loc, rewriter.getType<IREE::HAL::FenceType>(), device,
IREE::HAL::FenceFlagBitfield::None);
SmallVector<Value> callOperands(entryBlock->getArguments());
callOperands.push_back(waitFence);
callOperands.push_back(signalFence);
std::optional<ArrayAttr> targetTiedOperands = asyncFuncOp.getTiedOperands();
auto callResults =
rewriter
.create<IREE::Util::CallOp>(loc, asyncFuncOp, callOperands,
targetTiedOperands ? *targetTiedOperands
: ArrayAttr{})
.getResults();
// Wait forever for signal.
rewriter.create<IREE::HAL::FenceAwaitOp>(loc, rewriter.getI32Type(),
timeoutMillis, signalFence);
rewriter.create<IREE::Util::ReturnOp>(loc, callResults);
}
} // namespace
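// Rewrites torch-level func.func ops (and their contained return ops) into
// async util.func ops implementing the coarse-fences ABI, each paired with
// a synchronous wrapper that carries the original function name.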
class FuncConversionPass final
: public impl::FuncConversionPassBase<FuncConversionPass> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<arith::ArithDialect>();
    registry.insert<mlir::tensor::TensorDialect>();
registry.insert<IREE::HAL::HALDialect>();
registry.insert<IREE::Util::UtilDialect>();
registry.insert<TorchConversion::TorchConversionDialect>();
}
void runOnOperation() override {
auto moduleOp = getOperation();
    // Convert all functions in the module to IREE funcs. At this stage we
    // convert the contained return ops and argument/result types, but we
    // have not yet converted anything inside the bodies, so the functions
    // are most likely still illegal.
SmallVector<Operation *> eraseFuncOps;
std::vector<ConvertedAsyncFunctionInfo> convertedFuncInfos;
for (auto funcOp : moduleOp.getOps<func::FuncOp>()) {
if (!shouldConvertFunc(funcOp))
continue;
ConvertedAsyncFunctionInfo &convertedFuncInfo =
convertedFuncInfos.emplace_back();
if (failed(convertFuncOp(funcOp, convertedFuncInfo))) {
signalPassFailure();
return;
}
eraseFuncOps.push_back(funcOp);
}
for (auto op : eraseFuncOps) {
op->erase();
}
// Now post-process async functions.
for (auto &info : convertedFuncInfos) {
if (failed(info.postProcess())) {
signalPassFailure();
return;
}
}
}
bool shouldConvertFunc(func::FuncOp torchFunc) {
    // For now, we don't touch externals and assume they are in the proper
    // calling convention. In the future, we may support "torch externals"
    // that are converted to match up with a torch module; we can remove or
    // adapt this once that support is fleshed out.
if (torchFunc.isExternal())
return false;
// Something has already converted this and told us not to touch it.
if (torchFunc->hasAttr("iree.abi.stub"))
return false;
return true;
}
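  // Converts |torchFunc| into an async util.func with the coarse-fences
  // signature, recording enough information in |convertedFuncInfo| for the
  // body to be post-processed once all functions have been converted.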
LogicalResult convertFuncOp(func::FuncOp torchFunc,
ConvertedAsyncFunctionInfo &convertedFuncInfo) {
IRRewriter rewriter(torchFunc.getContext());
rewriter.setInsertionPoint(torchFunc);
Location loc = torchFunc.getLoc();
// Determine whether to build pure async or async + sync wrapper.
bool generateSyncWrapper = true;
StringRef originalName = torchFunc.getName();
std::string asyncFunctionName = originalName.str();
if (generateSyncWrapper) {
asyncFunctionName.append("$async");
}
// Stash arg/result attrs so they can be referenced during conversion.
torchFunc.getAllArgAttrs(convertedFuncInfo.torchArgAttrs);
torchFunc.getAllResultAttrs(convertedFuncInfo.torchResultAttrs);
// Convert function signature.
Type fenceType = rewriter.getType<IREE::HAL::FenceType>();
FunctionType torchFuncType = torchFunc.getFunctionType();
convertedFuncInfo.torchInputTypes.append(torchFuncType.getInputs().begin(),
torchFuncType.getInputs().end());
convertedFuncInfo.torchResultTypes.append(
torchFuncType.getResults().begin(), torchFuncType.getResults().end());
// For the coarse-fences ABI, we add two fences to the end. Treat these as
// original types so that the lists line up.
convertedFuncInfo.torchInputTypes.push_back(fenceType);
convertedFuncInfo.torchInputTypes.push_back(fenceType);
SmallVector<Type> ireeInputTypes(convertedFuncInfo.torchInputTypes);
SmallVector<Type> ireeResultTypes(convertedFuncInfo.torchResultTypes);
convertedFuncInfo.inputDispositions.resize(ireeInputTypes.size());
convertedFuncInfo.resultDispositions.resize(ireeResultTypes.size());
for (size_t i = 0; i < convertedFuncInfo.torchInputTypes.size(); ++i) {
if (failed(convertType(loc, convertedFuncInfo.torchInputTypes[i],
ireeInputTypes[i],
convertedFuncInfo.inputDispositions[i])))
return failure();
}
for (size_t i = 0; i < convertedFuncInfo.torchResultTypes.size(); ++i) {
if (failed(convertType(loc, convertedFuncInfo.torchResultTypes[i],
ireeResultTypes[i],
convertedFuncInfo.resultDispositions[i])))
return failure();
}
// Build tied operands index mapping results back to operands.
SmallVector<int64_t> tiedOperands;
bool anyTiedOperands = false;
for (unsigned i = 0; i < torchFuncType.getNumResults(); ++i) {
auto tiedAttr =
torchFunc.getResultAttrOfType<IntegerAttr>(i, "iree.abi.tied");
if (tiedAttr) {
tiedOperands.push_back(tiedAttr.getInt());
anyTiedOperands = true;
} else {
tiedOperands.push_back(-1);
}
}
auto tiedOperandsAttr = anyTiedOperands
? rewriter.getIndexArrayAttr(tiedOperands)
: ArrayAttr{};
// Create new func.
FunctionType asyncFuncType =
FunctionType::get(loc.getContext(), ireeInputTypes, ireeResultTypes);
auto asyncFuncOp = rewriter.create<IREE::Util::FuncOp>(
torchFunc.getLoc(), asyncFunctionName, asyncFuncType, tiedOperandsAttr);
convertedFuncInfo.funcOp = asyncFuncOp;
asyncFuncOp.setSymVisibilityAttr(torchFunc.getSymVisibilityAttr());
    // Convert de facto attrs to specialized ones.
asyncFuncOp.setInliningPolicyAttr(
rewriter.getAttr<IREE::Util::InlineNeverAttr>());
retainFunctionAttributes(torchFunc, asyncFuncOp);
asyncFuncOp->setAttr("iree.abi.stub", rewriter.getUnitAttr());
asyncFuncOp->setAttr("iree.abi.model",
rewriter.getStringAttr("coarse-fences"));
if (auto affinityAttr = torchFunc->getAttr("iree.abi.affinity")) {
asyncFuncOp->setAttr("iree.abi.affinity", affinityAttr);
}
rewriter.inlineRegionBefore(
torchFunc.getBody(), asyncFuncOp.getFunctionBody(), asyncFuncOp.end());
// Convert block arguments.
Block *entryBlock = &asyncFuncOp.getBlocks().front();
for (size_t i = 0; i < ireeInputTypes.size(); ++i) {
      // Append new arguments (the trailing fences) where the signature was
      // extended.
if (i >= entryBlock->getNumArguments()) {
entryBlock->addArgument(ireeInputTypes[i], loc);
continue;
}
// Convert.
entryBlock->getArgument(i).setType(ireeInputTypes[i]);
}
// Replace return ops.
asyncFuncOp->walk([&](func::ReturnOp returnOp) {
rewriter.setInsertionPoint(returnOp);
auto ireeReturnOp = rewriter.replaceOpWithNewOp<IREE::Util::ReturnOp>(
returnOp, returnOp.getOperands());
convertedFuncInfo.returnOps.push_back(ireeReturnOp);
});
// Create the sync variant.
rewriter.setInsertionPoint(torchFunc);
createCoarseFencesSyncWrapper(originalName, asyncFuncOp, rewriter);
return success();
}
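  // Maps a torch-level type to its IREE ABI type and records how values of
  // that type must be materialized in the function body.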
LogicalResult convertType(Location loc, Type torchType, Type &ireeType,
TypeDisposition &disp) {
if (isa<TensorType, Torch::ValueTensorType>(torchType)) {
ireeType = IREE::HAL::BufferViewType::get(torchType.getContext());
disp = TypeDisposition::IMMUTABLE_TENSOR;
return success();
}
if (isa<Torch::NonValueTensorType>(torchType)) {
ireeType = IREE::HAL::BufferViewType::get(torchType.getContext());
disp = TypeDisposition::MUTABLE_TENSOR;
return success();
}
if (isa<IREE::HAL::FenceType>(torchType)) {
ireeType = torchType;
disp = TypeDisposition::FENCE;
return success();
}
if (isa<Torch::BoolType>(torchType)) {
ireeType = IntegerType::get(torchType.getContext(), 1);
disp = TypeDisposition::TORCH_PRIMITIVE;
return success();
}
if (isa<Torch::IntType, Torch::GeneratorType>(torchType)) {
ireeType = IntegerType::get(torchType.getContext(), 64);
disp = TypeDisposition::TORCH_PRIMITIVE;
return success();
}
if (isa<Torch::FloatType>(torchType)) {
ireeType = Float64Type::get(torchType.getContext());
disp = TypeDisposition::TORCH_PRIMITIVE;
return success();
}
if (isa<IntegerType, FloatType, IndexType>(torchType)) {
ireeType = torchType;
disp = TypeDisposition::PASSTHROUGH;
return success();
}
return emitError(loc) << "unhandled torch type: " << torchType;
}
};
} // namespace mlir::iree_compiler::TorchInput