blob: 46502f9f59e2339d78b7e3abec06cdb65bf5084e [file] [log] [blame] [edit]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/ConstEval/Passes.h"
#include "iree/compiler/ConstEval/Runtime.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/HAL/Target/TargetOptions.h"
#include "iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.h"
#include "iree/compiler/Dialect/Util/Analysis/Constant/OpOracle.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Pipelines/Pipelines.h"
#include "iree/compiler/Utils/PassUtils.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Timer.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/SymbolTable.h"
#include <cstdlib>
#define DEBUG_TYPE "iree-const-eval"
namespace mlir::iree_compiler::ConstEval {
#define GEN_PASS_DEF_JITGLOBALSPASS
#include "iree/compiler/ConstEval/Passes.h.inc"
static llvm::cl::opt<std::string> clJitTargetDevice(
"iree-consteval-jit-target-device",
llvm::cl::desc("Overrides the target device used for JIT'ing."),
llvm::cl::init(""));
static llvm::cl::opt<bool> clEnableDebug(
"iree-consteval-jit-debug",
llvm::cl::desc(
"Prints debugging information to stderr (useful since when consteval "
"has issues, it is often in production on the largest models where we "
"don't want to run a debug compiler)."),
llvm::cl::init(false));
namespace {
static bool isDebugEnabled() {
if (clEnableDebug)
return true;
if (std::getenv("IREE_COMPILER_DEBUG_CONSTEVAL"))
return true;
return false;
}
static void
emitDebugWarning(Location loc,
llvm::function_ref<void(InFlightDiagnostic &)> emit) {
if (isDebugEnabled()) {
auto diagnostic = mlir::emitWarning(loc);
emit(diagnostic);
}
}
// These options structs are not copy-constructable so we have to allocate them
// shared.
// TODO: See if we can make them copyable?
struct CompileOptions {
GlobalPipelineOptions pipelineOptions;
BindingOptions bindingOptions;
InputDialectOptions inputOptions;
PreprocessingOptions preprocessingOptions;
GlobalOptimizationOptions globalOptimizationOptions;
DispatchCreationOptions dispatchCreationOptions;
SchedulingOptions schedulingOptions;
IREE::HAL::TargetOptions executableOptions;
IREE::VM::TargetOptions targetOptions;
IREEVMPipelineHooks hooks;
};
static inline bool isAttrParameterized(Attribute attr) {
if (!attr)
return false;
return !isa<IntegerAttr>(attr) && !isa<FloatAttr>(attr) &&
!isa<IREE::Util::SerializableAttrInterface>(attr);
}
template <typename AccessorTy>
static inline bool isAccessorParameterized(const SymbolTable &moduleSymbols,
AccessorTy op) {
auto global =
moduleSymbols.lookup<IREE::Util::GlobalOpInterface>(op.getGlobalName());
if (!global)
return true;
return isAttrParameterized(global.getGlobalInitialValue());
}
// Today the only way to interact with a global is with loads, stores, and
// addresses, and globals are the only way to reference parameters given where
// const-eval is run today. This is a workaround until we have proper dialect
// interfaces for detecting whether something is evaluatable at compile time.
static bool isParameterized(const SymbolTable &moduleSymbols,
IREE::Util::InitializerOpInterface initializerOp) {
WalkResult res = initializerOp->walk([&](Operation *op) {
const bool parameterized =
llvm::TypeSwitch<Operation *, bool>(op)
.Case([=](IREE::Util::GlobalLoadOpInterface accessor) {
return isAccessorParameterized(moduleSymbols, accessor);
})
.Case([=](IREE::Util::GlobalStoreOpInterface accessor) {
return isAccessorParameterized(moduleSymbols, accessor);
})
.Case([=](IREE::Flow::TensorConstantOp accessor) {
return isAttrParameterized(accessor.getValueAttr());
})
.Default([=](auto) { return false; });
if (parameterized)
return WalkResult::interrupt();
return WalkResult::advance();
});
return res.wasInterrupted();
}
// WIP specialized analysis for tracking initialization order in a module.
// This attempts to provide a "is this value initialized?" query with the
// differentiation of whether that initialization is possible within the
// compiler or if it relies on runtime information.
//
// This is currently fairly limited and bails on many common cases that we don't
// naturally generate in early phases of program compilation. More sophisticated
// analysis is required to use this elsewhere once calls, control flow, and
// more dynamic values are used.
class InitializationAnalysis {
public:
enum class Availability {
// Analysis failure, assume runtime.
Unknown = 0,
// Can only be evaluated fully at runtime. May depend on runtime-derived
// values from the HAL, custom modules, or parameters.
Runtime,
// Can be entirely evaluated at compile-time.
Compiler,
};
InitializationAnalysis(
Operation *rootOp, SymbolTable &symbolTable,
const IREE::Util::ConstExprAnalysis &constExprAnalysis) {
run(rootOp, symbolTable, constExprAnalysis);
}
// Returns the calculated availability of an initializer indicating when it is
// able to be evaluated.
Availability
getInitializerAvailability(IREE::Util::InitializerOpInterface initializerOp) {
auto it = initializerAvailability.find(initializerOp);
if (it == initializerAvailability.end())
return Availability::Unknown;
return it->second;
}
private:
void run(Operation *rootOp, SymbolTable &symbolTable,
const IREE::Util::ConstExprAnalysis &constExprAnalysis) {
unsigned nextOpOrdinal = 0;
for (auto &region : rootOp->getRegions()) {
for (auto &op : region.getOps()) {
if (auto globalOp = dyn_cast<IREE::Util::GlobalOpInterface>(op)) {
// Globals with initial values are initialized in order with where
// they are in the module.
auto &timeline = globalTimelines[globalOp.getGlobalName().getValue()];
assert(timeline.empty() && "out-of-order global store");
timeline.push_back(
std::make_pair(nextOpOrdinal++, Availability::Compiler));
} else if (auto initializerOp =
dyn_cast<IREE::Util::InitializerOpInterface>(op)) {
// Initializer availability depends on all dependent initialized
// values.
initializerAvailability[initializerOp] =
calculateInitializerAvailability(
initializerOp, symbolTable, constExprAnalysis, nextOpOrdinal);
}
}
}
}
// Returns the availability of |globalName| by the time |opOrdinal| is
// executed. Note that some globals may be initialized multiple times (yuck,
// but valid).
Availability queryGlobalInitializationStatus(StringRef globalName,
unsigned opOrdinal) {
auto &timeline = globalTimelines[globalName];
if (timeline.empty())
return Availability::Unknown;
for (auto &timepoint : timeline) {
if (timepoint.first > opOrdinal)
return timepoint.second;
}
return timeline.back().second;
}
// Returns true if the given |initializerOp| is a constant expression that is
// able to be evaluated by this pass.
Availability calculateInitializerAvailability(
IREE::Util::InitializerOpInterface initializerOp,
SymbolTable &symbolTable,
const IREE::Util::ConstExprAnalysis &constExprAnalysis,
unsigned &nextOpOrdinal) {
SmallVector<std::pair<IREE::Util::GlobalStoreOpInterface, unsigned>>
globalStoreOps;
// Assume compile-time availability unless we see anything that may prevent
// it. As we analyze the initializer we may "lower" the availability from
// the most available (compile-time) to least available (run-time/unknown).
auto availability = Availability::Compiler;
auto lowerAvailability = [&](Availability newAvailability,
StringRef reason) {
auto previousAvailability = availability;
availability = static_cast<Availability>(
std::min(static_cast<unsigned>(availability),
static_cast<unsigned>(newAvailability)));
if (previousAvailability != availability)
emitDebugWarning(
initializerOp.getLoc(),
[&](InFlightDiagnostic &diagnostic) { diagnostic << reason; });
};
if (initializerOp->getRegions().size() != 1 ||
!initializerOp->getRegion(0).hasOneBlock()) {
// Skip if multiple blocks. It would be possible to support these in
// theory but unclear if worth it in practice given the predominance of
// SCF at the levels we run things. What we'd require is adding a single
// exit block that stored to the globals unconditionally.
lowerAvailability(Availability::Unknown,
"skipping consteval initializer: initializers with >1 "
"block not yet supported");
} else if (isParameterized(symbolTable, initializerOp)) {
// We don't allow anything with parameters today. We could handle these by
// passing in the parameter file for use but would likely also want to
// bind a writeable parameter file to produce into.
lowerAvailability(Availability::Runtime,
"skipping consteval initializer: uses parameters or "
"other runtime-dependent values");
}
// Today we require that all values are constant expressions. We could slice
// out just the ones that are.
for (auto &op : initializerOp.getInitializerRegion().getOps()) {
if (op.hasTrait<OpTrait::ConstantLike>() ||
isa<IREE::Util::ReturnOp>(op)) {
continue;
} else if (isa<RegionBranchOpInterface>(op)) {
// Control flow currently isn't evaluated properly; we'd need much
// better analysis for things like conditional stores to globals. We
// could make this more permissive for cases where the globals are
// stored unconditionally/once but still allow control flow in other
// places.
lowerAvailability(
Availability::Unknown,
"skipping consteval initializer: has control flow ops");
} else if (isa<CallOpInterface>(op)) {
// Calls aren't currently analyzed - we need to rewrite this to use DFX
// and walk the call graph to do that.
lowerAvailability(Availability::Unknown,
"skipping consteval initializer: has call");
} else if (isa<IREE::Util::GlobalLoadIndirectOpInterface>(op) ||
isa<IREE::Util::GlobalStoreIndirectOpInterface>(op)) {
// Pessimistic case as we need analysis to know if the global
// being loaded may potentially be a parameter.
lowerAvailability(
Availability::Unknown,
"skipping consteval initializer: has indirect global accesses");
} else if (auto loadOp =
dyn_cast<IREE::Util::GlobalLoadOpInterface>(op)) {
// Globals must be initialized prior to this initializer and if they are
// initialized at runtime it means this initializer must be too.
auto globalStatus = queryGlobalInitializationStatus(
loadOp.getGlobalName(), nextOpOrdinal++);
if (globalStatus != Availability::Compiler) {
lowerAvailability(globalStatus, "skipping consteval initializer: has "
"runtime-dependent global load");
}
} else if (auto storeOp =
dyn_cast<IREE::Util::GlobalStoreOpInterface>(op)) {
// Only allow stores to immutable globals (ones we initialize).
auto globalOp = symbolTable.lookup<IREE::Util::GlobalOpInterface>(
storeOp.getGlobalAttr().getAttr());
if (!globalOp || globalOp.isGlobalMutable()) {
lowerAvailability(
Availability::Runtime,
"skipping consteval initializer: has mutable global store");
}
globalStoreOps.push_back(std::make_pair(storeOp, nextOpOrdinal++));
} else if (!constExprAnalysis.isConstExprOperation(&op)) {
lowerAvailability(
Availability::Runtime,
"skipping consteval initializer: has non-const-expr values");
}
}
// Record global availability produced by this initializer.
for (auto [storeOp, opOrdinal] : globalStoreOps) {
auto &timeline = globalTimelines[storeOp.getGlobalName()];
timeline.push_back(std::make_pair(opOrdinal, availability));
}
return availability;
}
// An initialization-ordered sequence denoting changes in availability.
// Example:
// * [-1, Compiler]: initialized with a constant primitive at startup
// * [2, Runtime]: reinitialized with a value computed at runtime
// * [4, Compiler]: reinitialized with a value available at compile time
// * [8, Unknown]: reinitialized with a value that failed analysis
// The timeline can be queried by walking in order looking for any ordinal
// under the requested point. A query at 3 would return Runtime as it is after
// the first initialization but prior to the subsequent reinitializations.
using AvailabilityTimeline = SmallVector<std::pair<unsigned, Availability>>;
DenseMap<StringRef, AvailabilityTimeline> globalTimelines;
DenseMap<Operation *, Availability> initializerAvailability;
};
// JIT functions take arguments, generally from the source program. We capture
// them here.
class ArgumentBinding {
public:
enum class Type {
// An ElementsAttr.
ElementsAttr,
// The value of a GlobalOp. It may not be set at the start of the run
// if there is a dependency that evaluates first.
GlobalOp,
};
ArgumentBinding(ElementsAttr attr)
: type(Type::ElementsAttr), elementsAttr(attr) {}
ArgumentBinding(IREE::Util::GlobalOpInterface globalOp)
: type(Type::GlobalOp), globalOp(globalOp) {}
Type getType() { return type; }
ElementsAttr getElementsAttr() {
assert(type == Type::ElementsAttr);
return elementsAttr;
}
IREE::Util::GlobalOpInterface getGlobalOp() {
assert(type == Type::GlobalOp);
return globalOp;
}
private:
Type type;
ElementsAttr elementsAttr;
IREE::Util::GlobalOpInterface globalOp;
};
// How to bind results to the original program.
class ResultBinding {
public:
enum class Type {
// Set the result on the global op.
GlobalOp,
};
ResultBinding(IREE::Util::GlobalOpInterface globalOp)
: type(Type::GlobalOp), globalOp(globalOp) {}
Type getType() { return type; }
IREE::Util::GlobalOpInterface getGlobalOp() {
assert(type == Type::GlobalOp);
return globalOp;
}
private:
Type type;
ElementsAttr elementsAttr;
IREE::Util::GlobalOpInterface globalOp;
};
// Description of a JIT function that we have created for doing some
// initialization work.
struct JitFunctionDesc {
JitFunctionDesc(Location loc, std::string name)
: loc(loc), name(std::move(name)) {}
Location loc;
std::string name;
llvm::SmallVector<ArgumentBinding> argumentBindings;
llvm::SmallVector<ResultBinding> resultBindings;
};
// Clones all object-like symbols used within the function.
// Objects are only cloned once if used by multiple functions.
// All object contents are cloned and symbol DCE is relied on to remove any
// unused nested symbols later on.
static LogicalResult cloneUsedObjects(FunctionOpInterface funcOp,
SymbolTable &sourceSymbolTable,
SymbolTable &targetSymbolTable,
OpBuilder &moduleBuilder) {
// Gather all symbol uses within the function.
auto uses = SymbolTable::getSymbolUses(funcOp);
if (!uses.has_value())
return success();
// Verify that all uses are to object-like types we can clone.
for (auto use : uses.value()) {
// Lookup the (maybe) object in the source module.
auto objectNameAttr = use.getSymbolRef().getRootReference();
auto *objectOp = sourceSymbolTable.lookup(objectNameAttr);
if (!objectOp) {
return use.getUser()->emitOpError()
<< "references undefined symbol " << use.getSymbolRef();
}
if (!objectOp->hasTrait<OpTrait::IREE::Util::ObjectLike>())
continue;
// Check if the object exists in the target yet. Since we create the
// target we know there should be no conflicts: the only symbols with the
// same name will be already cloned copies of the same source.
if (targetSymbolTable.lookup(objectNameAttr))
continue;
// Clone the object. It's isolated and safe to copy wholesale.
auto *clonedOp = moduleBuilder.clone(*objectOp);
targetSymbolTable.insert(clonedOp);
}
return success();
}
class ProgramBuilder {
public:
ProgramBuilder(ModuleOp sourceModuleOp,
IREE::HAL::DeviceTargetAttr deviceTargetAttr,
const IREE::HAL::TargetBackend::SupportedTypes &supportedTypes,
const IREE::Util::ConstExprAnalysis &constExprAnalysis)
: targetModuleOp(createInnerModule(sourceModuleOp)),
sourceSymbolTable(sourceModuleOp), targetSymbolTable(targetModuleOp),
supportedTypes(supportedTypes), constExprAnalysis(constExprAnalysis),
initializationAnalysis(sourceModuleOp, sourceSymbolTable,
constExprAnalysis) {
targetModuleOp->setAttr(
"hal.device.targets",
ArrayAttr::get(sourceModuleOp.getContext(),
{static_cast<Attribute>(deviceTargetAttr)}));
}
llvm::SmallVector<JitFunctionDesc> &getJitFunctions() { return jitFunctions; }
ModuleOp getTargetModule() { return targetModuleOp; }
LogicalResult importInitializer(IREE::Util::InitializerOp initializerOp) {
// We convert each initializer into a public FuncOp by converting each:
// - Tensor constant into an argument
// - util.global.load into an argument
// - util.global.store into a result
// It is considered an eval'able initializer if it contains stores
// into immutable global(s). In the future, we will also want to
// condition this on an attribute so as to not try to statically
// compile dynamic initializers.
auto availability =
initializationAnalysis.getInitializerAvailability(initializerOp);
if (availability != InitializationAnalysis::Availability::Compiler)
return failure();
OpBuilder moduleBuilder = OpBuilder::atBlockEnd(targetModuleOp.getBody());
// Find any object-like symbol references used by the initializer and
// clone them.
if (failed(cloneUsedObjects(initializerOp, sourceSymbolTable,
targetSymbolTable, moduleBuilder)))
return failure();
auto funcOp = IREE::Util::FuncOp::create(
moduleBuilder, initializerOp.getLoc(), "jit_eval",
moduleBuilder.getFunctionType({}, {}));
targetSymbolTable.insert(funcOp);
IRMapping unusedMapping;
initializerOp.getBody().cloneInto(&funcOp.getBody(), unusedMapping);
if (failed(transformToJitFunction(funcOp))) {
funcOp.erase();
return failure();
}
return success();
}
private:
static ModuleOp createInnerModule(ModuleOp sourceModuleOp) {
OpBuilder builder = OpBuilder::atBlockEnd(sourceModuleOp.getBody());
auto m = ModuleOp::create(builder, sourceModuleOp.getLoc());
m->setAttr("iree.consteval", builder.getUnitAttr());
return m;
}
LogicalResult transformToJitFunction(IREE::Util::FuncOp funcOp) {
JitFunctionDesc desc(funcOp.getLoc(), funcOp.getName().str());
llvm::SmallVector<Type> argumentTypes;
llvm::SmallVector<Type> returnTypes;
llvm::SmallVector<Value> returns;
llvm::SmallVector<Operation *> eraseOps;
Block *entryBlock = &funcOp.getBody().front();
// Find immutable loads.
for (auto loadOp : funcOp.getOps<IREE::Util::GlobalLoadOpInterface>()) {
auto globalOp = dyn_cast_if_present<IREE::Util::GlobalOpInterface>(
sourceSymbolTable.lookup(loadOp.getGlobalAttr().getAttr()));
if (!globalOp || globalOp.isGlobalMutable()) {
emitDebugWarning(loadOp.getLoc(), [&](InFlightDiagnostic &diagnostic) {
diagnostic << "skipping consteval initializer: load from mutable "
"globals not supported";
});
return failure();
}
Type t = loadOp.getLoadedGlobalValue().getType();
if (!supportedTypes.supportsType(t)) {
emitDebugWarning(funcOp.getLoc(), [&](InFlightDiagnostic &diagnostic) {
diagnostic << "skipping consteval initializer: unsupported type for "
"current jit configuration: "
<< t;
});
return failure();
}
argumentTypes.push_back(t);
BlockArgument entryArg = entryBlock->addArgument(t, loadOp.getLoc());
loadOp.getLoadedGlobalValue().replaceAllUsesWith(entryArg);
eraseOps.push_back(loadOp);
desc.argumentBindings.emplace_back(globalOp);
}
// And loose tensor constants.
for (auto constantOp : funcOp.getOps<arith::ConstantOp>()) {
auto tensorType = dyn_cast<TensorType>(constantOp.getResult().getType());
auto elementsAttr = dyn_cast<ElementsAttr>(constantOp.getValue());
if (!tensorType || !elementsAttr)
continue;
if (!supportedTypes.supportsType(tensorType)) {
emitDebugWarning(funcOp.getLoc(), [&](InFlightDiagnostic &diagnostic) {
diagnostic << "skipping consteval initializer: unsupported type for "
"current jit configuration: "
<< tensorType;
});
return failure();
}
argumentTypes.push_back(tensorType);
BlockArgument entryArg =
entryBlock->addArgument(tensorType, constantOp.getLoc());
constantOp.getResult().replaceAllUsesWith(entryArg);
eraseOps.push_back(constantOp);
desc.argumentBindings.emplace_back(elementsAttr);
}
// Find immutable stores, early exiting if not supported.
// The consumers must come after rewrites of the producers above.
for (auto storeOp : funcOp.getOps<IREE::Util::GlobalStoreOpInterface>()) {
auto globalOp = dyn_cast_if_present<IREE::Util::GlobalOpInterface>(
sourceSymbolTable.lookup(storeOp.getGlobalAttr().getAttr()));
assert(globalOp && "should have been checked in isConstExpr");
Type t = storeOp.getStoredGlobalValue().getType();
if (!supportedTypes.supportsType(t)) {
emitDebugWarning(funcOp.getLoc(), [&](InFlightDiagnostic &diagnostic) {
diagnostic << "skipping consteval initializer: unsupported type for "
"current jit configuration: "
<< t;
});
return failure();
}
returns.push_back(storeOp.getStoredGlobalValue());
returnTypes.push_back(t);
eraseOps.push_back(storeOp);
desc.resultBindings.emplace_back(globalOp);
}
// Cleanup.
for (auto *op : eraseOps) {
op->erase();
}
// Rewrite the terminator and the function type.
entryBlock->getTerminator()->erase();
OpBuilder termBuilder = OpBuilder::atBlockEnd(entryBlock);
IREE::Util::ReturnOp::create(termBuilder, funcOp.getLoc(), returns);
funcOp.setType(termBuilder.getFunctionType(argumentTypes, returnTypes));
jitFunctions.push_back(std::move(desc));
return success();
}
ModuleOp targetModuleOp;
SymbolTable sourceSymbolTable;
SymbolTable targetSymbolTable;
llvm::SmallVector<JitFunctionDesc> jitFunctions;
const IREE::HAL::TargetBackend::SupportedTypes supportedTypes;
const IREE::Util::ConstExprAnalysis &constExprAnalysis;
InitializationAnalysis initializationAnalysis;
};
class JitGlobalsPass final : public impl::JitGlobalsPassBase<JitGlobalsPass> {
public:
JitGlobalsPass() : JitGlobalsPass(JitGlobalsPassOptions{}) {}
JitGlobalsPass(const JitGlobalsPassOptions &options)
: compileOptions(std::make_shared<CompileOptions>()),
compilePipeline("builtin.module") {
targetRegistry = options.targetRegistry;
// Detect backend.
compileOptions->targetOptions.f32Extension = true;
compileOptions->targetOptions.f64Extension = true;
compileOptions->targetOptions.indexBits = 64;
compileOptions->targetOptions.truncateUnsupportedFloats = false;
compileOptions->inputOptions.demoteF64ToF32 = false;
requestedTargetDevice = resolveTargetDevice(*targetRegistry.value);
if (targetRegistry->getTargetDevice(requestedTargetDevice) == nullptr) {
targetDevice = targetRegistry->getTargetDevice("local");
} else {
targetDevice = targetRegistry->getTargetDevice(requestedTargetDevice);
}
// Disable constant evaluation for our Jit compilation pipeline.
// It would make no sense to recursively do constant evaluation, and since
// we omit the necessary hooks, it is unsupported anyway.
compileOptions->pipelineOptions.constExprHoisting = false;
compileOptions->globalOptimizationOptions.constEval = false;
buildIREEVMTransformPassPipeline(
*targetRegistry.value, compileOptions->pipelineOptions,
compileOptions->bindingOptions, compileOptions->inputOptions,
compileOptions->preprocessingOptions,
compileOptions->globalOptimizationOptions,
compileOptions->dispatchCreationOptions,
compileOptions->schedulingOptions, compileOptions->executableOptions,
compileOptions->targetOptions, compileOptions->hooks, compilePipeline);
}
void getDependentDialects(DialectRegistry &registry) const override {
compilePipeline.getDependentDialects(registry);
}
static std::string
resolveTargetDevice(const IREE::HAL::TargetRegistry &targetRegistry) {
if (clJitTargetDevice.empty()) {
return std::string("local");
}
return clJitTargetDevice;
}
LogicalResult
processFunctions(CompiledBinary &binary,
llvm::SmallVector<JitFunctionDesc> &jitFunctions,
ModuleOp module, llvm::TimerGroup &tg) {
// Process each function through the runtime.
for (JitFunctionDesc &jitFunction : jitFunctions) {
std::optional<llvm::Timer> invokeTimer;
if (debugEnabled) {
std::string timerName("Invoke ");
timerName.append(jitFunction.name);
invokeTimer.emplace(timerName, timerName, tg);
invokeTimer->startTimer();
llvm::dbgs() << "::: Invoking " << jitFunction.name << "\n";
}
FunctionCall call(binary, jitFunction.argumentBindings.size(),
jitFunction.resultBindings.size());
if (failed(call.initialize(jitFunction.loc)))
return failure();
// Convert arguments.
for (ArgumentBinding &arg : jitFunction.argumentBindings) {
switch (arg.getType()) {
case ArgumentBinding::Type::ElementsAttr: {
if (failed(call.addArgument(jitFunction.loc, arg.getElementsAttr())))
return failure();
break;
}
case ArgumentBinding::Type::GlobalOp: {
auto globalValue = arg.getGlobalOp().getGlobalInitialValue();
if (!globalValue) {
return emitError(jitFunction.loc)
<< "internal error: jit global source initialization order "
"invalid: global "
<< arg.getGlobalOp().getGlobalName() << " has no value";
}
if (failed(call.addArgument(arg.getGlobalOp().getLoc(), globalValue)))
return failure();
break;
}
}
}
if (failed(call.invoke(jitFunction.loc, jitFunction.name))) {
return failure();
}
// Process results.
for (auto it : llvm::enumerate(jitFunction.resultBindings)) {
ResultBinding &resultBinding = it.value();
switch (resultBinding.getType()) {
case ResultBinding::Type::GlobalOp: {
TypedAttr attr;
if (failed(call.getResultAsAttr(
resultBinding.getGlobalOp().getLoc(), it.index(),
resultBinding.getGlobalOp().getGlobalType(), attr)))
return failure();
resultBinding.getGlobalOp().setGlobalInitialValue(attr);
break;
}
}
}
if (debugEnabled) {
invokeTimer->stopTimer();
}
}
return success();
}
void runOnOperation() override {
llvm::TimerGroup tg("iree-consteval-jit", "Consteval Jit");
mlir::ModuleOp outerModuleOp = getOperation();
// Set the target.
if (!targetDevice) {
emitError(UnknownLoc::get(&getContext()))
<< "consteval jit could not find a usable backend (requested '"
<< requestedTargetDevice << "')";
signalPassFailure();
return;
}
auto deviceTargetAttr =
targetDevice->getHostDeviceTarget(&getContext(), *targetRegistry.value);
if (!deviceTargetAttr) {
emitError(UnknownLoc::get(&getContext()))
<< "consteval requested device " << requestedTargetDevice
<< " cannot target the host";
signalPassFailure();
return;
}
IREE::HAL::TargetBackend::SupportedTypes supportedTypes;
for (auto executableTargetAttr : deviceTargetAttr->getExecutableTargets()) {
auto targetBackend = targetRegistry->getTargetBackend(
executableTargetAttr.getBackend().getValue());
if (targetBackend) {
supportedTypes = targetBackend->getSupportedTypes(&getContext());
break;
} else {
emitError(UnknownLoc::get(&getContext()))
<< "consteval requested device " << requestedTargetDevice
<< " compilation backend " << executableTargetAttr.getBackend()
<< " not registered with the TargetRegistry";
signalPassFailure();
return;
}
}
// Build the program.
ProgramBuilder programBuilder(outerModuleOp, *deviceTargetAttr,
supportedTypes,
getAnalysis<IREE::Util::ConstExprAnalysis>());
// Iterate over initializers.
llvm::SmallVector<IREE::Util::InitializerOp> initializerOps;
llvm::SmallVector<IREE::Util::InitializerOp> deadInitOps;
for (auto childOp : outerModuleOp.getOps<IREE::Util::InitializerOp>()) {
initializerOps.push_back(childOp);
}
for (auto initializerOp : initializerOps) {
if (succeeded(programBuilder.importInitializer(initializerOp))) {
deadInitOps.push_back(initializerOp);
} else if (debugEnabled) {
llvm::dbgs() << "::: Rejected consteval initializer:\n"
<< initializerOp << "\n";
}
}
if (programBuilder.getJitFunctions().empty()) {
programBuilder.getTargetModule()->erase();
return;
}
std::optional<llvm::Timer> compileTimer;
if (debugEnabled) {
llvm::dbgs() << "::: COMPILING JIT (" << requestedTargetDevice
<< "): " << programBuilder.getTargetModule() << "\n";
compileTimer.emplace("iree-consteval-jit-compile", "Compiling", tg);
compileTimer->startTimer();
}
if (failed(
runPipeline(compilePipeline, programBuilder.getTargetModule()))) {
return signalPassFailure();
}
// Generate a binary.
InMemoryCompiledBinary binary;
if (failed(binary.translateFromModule(programBuilder.getTargetModule()))) {
return signalPassFailure();
}
if (debugEnabled) {
compileTimer->stopTimer();
}
// Kill the temporary program.
programBuilder.getTargetModule()->erase();
// Process the functions.
if (failed(processFunctions(binary, programBuilder.getJitFunctions(),
outerModuleOp, tg))) {
signalPassFailure();
return;
}
// Cleanup any initializers we replaced.
// We do this after running the JIT-ed functions because we have deep
// references into ops and attributes that need to be converted to
// arguments.
for (auto deadOp : deadInitOps) {
deadOp.erase();
}
}
private:
std::shared_ptr<CompileOptions> compileOptions;
OpPassManager compilePipeline;
std::string requestedTargetDevice;
std::shared_ptr<IREE::HAL::TargetDevice> targetDevice;
bool debugEnabled = isDebugEnabled();
};
} // namespace
} // namespace mlir::iree_compiler::ConstEval