blob: ce174bc1b052823b8cbe1aedb1d27c4b26add1dd [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace ABI {
// Wraps all entry points in a function that is compatible with the
// expected invocation semantics of bindings following the native IREE ABI.
//
// For every public function not already tagged `iree.abi.stub` this pass:
//   1. renames the original function to `_<name>` (rewriting all symbol uses
//      within the module) and makes it private;
//   2. inserts, immediately before it, a new public function named `<name>`
//      whose tensor arguments/results are exposed as !hal.buffer_view and which
//      marshals them to/from the original via hal.tensor.import/export;
//   3. tags that new wrapper with the `iree.abi.stub` unit attribute.
class WrapEntryPointsPass
    : public PassWrapper<WrapEntryPointsPass, OperationPass<ModuleOp>> {
 public:
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<func::FuncDialect, mlir::arith::ArithmeticDialect,
                    mlir::tensor::TensorDialect, IREE::HAL::HALDialect>();
  }

  StringRef getArgument() const override {
    return "iree-abi-wrap-entry-points";
  }

  StringRef getDescription() const override {
    return "Wraps all entry points in a function that is compatible with the "
           "expected invocation semantics of bindings following the native "
           "IREE ABI.";
  }

  void runOnOperation() override {
    auto moduleOp = getOperation();

    // Collect entry points up front so that inserting wrapper functions into
    // the module below does not interfere with this iteration.
    SmallVector<FuncOp, 4> entryFuncOps;
    for (auto funcOp : moduleOp.getOps<FuncOp>()) {
      if (funcOp.isPublic() && !funcOp->hasAttr("iree.abi.stub")) {
        entryFuncOps.push_back(funcOp);
      }
    }

    SymbolTable symbolTable(moduleOp);

    // Create a wrapper function for each entry point.
    for (auto entryFuncOp : entryFuncOps) {
      // Rename the original function so that our wrapper can use the original
      // name in its public definition.
      auto publicName = entryFuncOp.getName().str();
      auto privateName = "_" + publicName;
      auto privateNameAttr =
          mlir::StringAttr::get(entryFuncOp.getContext(), privateName);
      if (failed(symbolTable.replaceAllSymbolUses(entryFuncOp, privateNameAttr,
                                                  moduleOp))) {
        entryFuncOp.emitError() << "unknown symbol table op encountered; "
                                   "cannot fix up symbol names";
        return signalPassFailure();
      }
      entryFuncOp.setName(privateNameAttr);
      entryFuncOp.setPrivate();

      // Create the wrapper function that conforms to the IREE native ABI and
      // marshals arguments/results to the original function.
      auto wrapperFuncOp = createWrapperFunc(entryFuncOp);
      if (!wrapperFuncOp) return signalPassFailure();
      wrapperFuncOp.setPublic();
      wrapperFuncOp.setName(
          mlir::StringAttr::get(entryFuncOp.getContext(), publicName));
      moduleOp.insert(Block::iterator(entryFuncOp), wrapperFuncOp);
      wrapperFuncOp->setAttr("iree.abi.stub", UnitAttr::get(&getContext()));
    }
  }

 private:
  // Maps a type from the original entry point signature to the type used in
  // the wrapper signature: any tensor becomes !hal.buffer_view, everything
  // else passes through unchanged.
  Type mapToABIType(Type type) {
    if (type.isa<TensorType>()) {
      return IREE::HAL::BufferViewType::get(type.getContext());
    }
    return type;
  }

  // Creates the corresponding wrapper function for the given entry point.
  //
  // We do this by creating a new function just for the bindings and calling the
  // existing entry point. This allows us to support multiple binding schemes as
  // transforms from other bindings can also perform their own equivalent
  // wrapping.
  //
  // The returned function is detached (not yet inserted into any block); the
  // caller is responsible for inserting it. On failure an error is emitted on
  // |entryFuncOp|, the partially-built wrapper is destroyed, and a null FuncOp
  // is returned.
  //
  // NOTE(review): runOnOperation invokes this once per public entry point; an
  // older note here claimed only a single entry point was supported — confirm
  // nothing downstream still relies on that assumption.
  FuncOp createWrapperFunc(FuncOp entryFuncOp) {
    // Convert argument types to those required by the binding ABI.
    //
    // NOTE: this is where we could change our signature to provide additional
    // values from the runtime bindings as may be required - like semaphores for
    // async behavior or cancellation.
    auto entryFuncType = entryFuncOp.getType();
    SmallVector<Type> inputTypes;
    for (auto oldType : entryFuncType.getInputs()) {
      inputTypes.push_back(mapToABIType(oldType));
    }
    SmallVector<Type> resultTypes;
    for (auto oldType : entryFuncType.getResults()) {
      resultTypes.push_back(mapToABIType(oldType));
    }
    auto wrapperFuncType =
        FunctionType::get(entryFuncOp.getContext(), inputTypes, resultTypes);

    // FuncOp::create produces a detached operation that nothing owns yet; every
    // early-return path below must erase it or it leaks.
    auto wrapperFuncOp = FuncOp::create(entryFuncOp.getLoc(),
                                        entryFuncOp.getName(), wrapperFuncType);
    SmallVector<DictionaryAttr, 4> argAttrDict;
    entryFuncOp.getAllArgAttrs(argAttrDict);
    wrapperFuncOp.setAllArgAttrs(argAttrDict);
    SmallVector<DictionaryAttr, 4> resultAttrDict;
    entryFuncOp.getAllResultAttrs(resultAttrDict);
    wrapperFuncOp.setAllResultAttrs(resultAttrDict);
    populateReflectionAttrs(entryFuncOp, wrapperFuncOp);

    auto *entryBlock = wrapperFuncOp.addEntryBlock();
    auto entryBuilder = OpBuilder::atBlockBegin(entryBlock);

    // Build a map of result index to the argument that has its backing storage.
    // An argument opts in via `iree.abi.output = <result index>`; results with
    // no storage argument keep a null Value.
    SmallVector<Value> resultStorages;
    resultStorages.resize(resultTypes.size());
    for (unsigned i = 0; i < inputTypes.size(); ++i) {
      auto outputAttr =
          entryFuncOp.getArgAttrOfType<IntegerAttr>(i, "iree.abi.output");
      if (!outputAttr) continue;
      // Today all outputs need to be a !hal.buffer - we could change this
      // in the future to be something more generalized.
      auto storageArg = entryBlock->getArgument(i);
      if (!storageArg.getType().isa<IREE::HAL::BufferType>()) {
        entryFuncOp.emitError()
            << "storage argument " << i << " has an invalid type "
            << storageArg.getType() << "; must be a !hal.buffer";
        wrapperFuncOp.erase();
        return {};
      }
      // The attribute value is a result index; reject out-of-range values
      // instead of writing past the end of |resultStorages|.
      int64_t resultIndex = outputAttr.getInt();
      if (resultIndex < 0 ||
          resultIndex >= static_cast<int64_t>(resultStorages.size())) {
        entryFuncOp.emitError()
            << "storage argument " << i
            << " has an out-of-range iree.abi.output result index "
            << resultIndex << "; function has " << resultStorages.size()
            << " result(s)";
        wrapperFuncOp.erase();
        return {};
      }
      resultStorages[resultIndex] = storageArg;
    }

    // Marshal arguments. Use the same TensorType predicate as mapToABIType so
    // that every argument rewritten to !hal.buffer_view is imported back to
    // its original tensor type before the call.
    SmallVector<Value> arguments;
    for (auto arg : llvm::enumerate(entryBlock->getArguments())) {
      auto oldType = entryFuncType.getInput(arg.index());
      if (oldType.isa<TensorType>()) {
        auto argLoc = arg.value().getLoc();
        auto importOp = entryBuilder.create<IREE::HAL::TensorImportOp>(
            argLoc, oldType, arg.value());
        arguments.push_back(importOp.target());
      } else {
        arguments.push_back(arg.value());
      }
    }

    // Make the call with the original types.
    auto callOp = entryBuilder.create<func::CallOp>(entryFuncOp.getLoc(),
                                                    entryFuncOp, arguments);

    // Marshal results: tensors are exported (optionally into their storage
    // argument); everything else is returned as-is.
    SmallVector<Value> results;
    for (auto result : llvm::enumerate(callOp.getResults())) {
      auto oldType = entryFuncType.getResult(result.index());
      auto newType = wrapperFuncType.getResult(result.index());
      if (oldType.isa<TensorType>()) {
        auto dynamicDims = IREE::Util::buildDynamicDimsForValue(
            entryFuncOp.getLoc(), result.value(), entryBuilder);
        results.push_back(entryBuilder.create<IREE::HAL::TensorExportOp>(
            entryFuncOp.getLoc(), newType, result.value(),
            TypeAttr::get(result.value().getType()), dynamicDims,
            resultStorages[result.index()]));
      } else {
        results.push_back(result.value());
      }
    }
    entryBuilder.create<func::ReturnOp>(entryFuncOp.getLoc(), results);
    return wrapperFuncOp;
  }

  // Populates attributes on |wrapperFuncOp| to support runtime reflection.
  // Currently this copies the entry point's `iree.abi` attribute (if present)
  // into an `iree.reflection` dictionary on the wrapper; nothing is set when
  // there is nothing to reflect.
  void populateReflectionAttrs(FuncOp entryFuncOp, FuncOp wrapperFuncOp) {
    SmallVector<NamedAttribute, 4> attrs;
    auto abiAttr = entryFuncOp->getAttr("iree.abi");
    if (abiAttr) {
      attrs.emplace_back(StringAttr::get(entryFuncOp.getContext(), "iree.abi"),
                         abiAttr);
    }
    if (!attrs.empty()) {
      auto reflectionAttr = DictionaryAttr::get(&getContext(), attrs);
      wrapperFuncOp->setAttr("iree.reflection", reflectionAttr);
    }
  }
};
// Factory for the entry-point wrapping pass; the caller takes ownership of the
// returned pass instance.
std::unique_ptr<OperationPass<ModuleOp>> createWrapEntryPointsPass() {
  std::unique_ptr<OperationPass<ModuleOp>> wrapPass =
      std::make_unique<WrapEntryPointsPass>();
  return wrapPass;
}

// Registers WrapEntryPointsPass with the MLIR pass registry during static
// initialization.
static PassRegistration<WrapEntryPointsPass> pass;
} // namespace ABI
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir