Extending imports to have a context pointer. (#10580)
This will allow for imports to carry state that doesn't require TLS or
globals. Also sprinkling in a reserved pointer in case we need more
stuff in the future.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
index 8816913..3fb5be6 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
@@ -101,7 +101,8 @@
enum class EnvironmentField {
constants,
import_thunk,
- imports,
+ import_funcs,
+ import_contexts,
processor,
};
@@ -124,13 +125,16 @@
fieldTypes.push_back(uint32PtrType);
// iree_hal_executable_import_thunk_v0_t import_thunk;
- // const iree_hal_executable_import_v0_t* imports;
- auto importType = LLVM::LLVMFunctionType::get(uint32Type, int8PtrType);
+ // const iree_hal_executable_import_v0_t* import_funcs;
+ // const void** import_contexts;
+ auto importType = LLVM::LLVMFunctionType::get(
+ uint32Type, {int8PtrType, int8PtrType, int8PtrType});
auto importPtrType = LLVM::LLVMPointerType::get(importType);
- auto importThunkType =
- LLVM::LLVMFunctionType::get(uint32Type, {importPtrType, int8PtrType});
+ auto importThunkType = LLVM::LLVMFunctionType::get(
+ uint32Type, {importPtrType, int8PtrType, int8PtrType, int8PtrType});
fieldTypes.push_back(LLVM::LLVMPointerType::get(importThunkType));
fieldTypes.push_back(LLVM::LLVMPointerType::get(importPtrType));
+ fieldTypes.push_back(LLVM::LLVMPointerType::get(int8PtrType));
// iree_hal_processor_v0_t processor;
fieldTypes.push_back(processorType);
@@ -515,27 +519,35 @@
// Loads the import function pointer of the import |ordinal|.
// Equivalent to:
- // iree_hal_executable_import_v0_t func_ptr = state->imports[ordinal];
- Value loadImportFuncPtr(Location loc, int64_t ordinal, OpBuilder &builder) {
- auto importsPtrValue =
- loadFieldValue(loc, EnvironmentField::imports, builder);
+ // iree_hal_executable_import_v0_t fn_ptr = state->import_funcs[ordinal];
+ // void* context = state->import_contexts[ordinal];
+ std::pair<Value, Value> loadImportFunc(Location loc, int64_t ordinal,
+ OpBuilder &builder) {
auto ordinalValue = getIndexValue(loc, ordinal, builder);
- auto elementPtrValue = builder.createOrFold<LLVM::GEPOp>(
- loc, importsPtrValue.getType(), importsPtrValue, ordinalValue);
- return builder.createOrFold<LLVM::LoadOp>(loc, elementPtrValue);
+ auto funcPtrsValue =
+ loadFieldValue(loc, EnvironmentField::import_funcs, builder);
+ auto funcPtrValue = builder.createOrFold<LLVM::GEPOp>(
+ loc, funcPtrsValue.getType(), funcPtrsValue, ordinalValue);
+ auto contextPtrsValue =
+ loadFieldValue(loc, EnvironmentField::import_contexts, builder);
+ auto contextPtrValue = builder.createOrFold<LLVM::GEPOp>(
+ loc, contextPtrsValue.getType(), contextPtrsValue, ordinalValue);
+ return std::make_pair(
+ builder.createOrFold<LLVM::LoadOp>(loc, funcPtrValue),
+ builder.createOrFold<LLVM::LoadOp>(loc, contextPtrValue));
}
// Returns an i1 indicating whether the optional import with |ordinal| is
// defined. Equivalent to:
- // state->imports[ordinal] != NULL
+ // state->import_funcs[ordinal] != NULL
Value isImportFuncAvailable(Location loc, int64_t ordinal,
OpBuilder &builder) {
- auto importPtrValue = loadImportFuncPtr(loc, ordinal, builder);
- auto nullPtrValue =
- builder.create<LLVM::NullOp>(loc, importPtrValue.getType()).getResult();
+ auto importFunc = loadImportFunc(loc, ordinal, builder);
+ Value nullPtrValue =
+ builder.create<LLVM::NullOp>(loc, importFunc.first.getType());
return builder.create<LLVM::ICmpOp>(loc, builder.getI1Type(),
- LLVM::ICmpPredicate::ne, importPtrValue,
- nullPtrValue);
+ LLVM::ICmpPredicate::ne,
+ importFunc.first, nullPtrValue);
}
// Emits a call to the import with the given |importOrdinal|.
@@ -546,13 +558,17 @@
OpBuilder &builder) {
auto thunkPtrValue =
loadFieldValue(loc, EnvironmentField::import_thunk, builder);
- auto importPtrValue = loadImportFuncPtr(loc, importOrdinal, builder);
+ auto importFunc = loadImportFunc(loc, importOrdinal, builder);
+ Value nullPtrValue = builder.create<LLVM::NullOp>(
+ loc, LLVM::LLVMPointerType::get(builder.getI8Type()));
auto callOp =
builder.create<LLVM::CallOp>(loc, TypeRange{builder.getI32Type()},
ValueRange{
/*thunk_func_ptr=*/thunkPtrValue,
- /*import_func_ptr=*/importPtrValue,
- /*import_params=*/params,
+ /*import_func_ptr=*/importFunc.first,
+ /*context=*/importFunc.second,
+ /*params=*/params,
+ /*reserved=*/nullPtrValue,
});
return callOp.getResult();
}
diff --git a/runtime/src/iree/hal/local/executable_library.h b/runtime/src/iree/hal/local/executable_library.h
index 2757726..a9fc3b6 100644
--- a/runtime/src/iree/hal/local/executable_library.h
+++ b/runtime/src/iree/hal/local/executable_library.h
@@ -154,14 +154,16 @@
// a useful failure though the HAL does not mandate that all overflows are
// caught and only that they are not harmful - clamping byte ranges and never
// returning a failure is sufficient.
-typedef int (*iree_hal_executable_import_v0_t)(void* import_params);
+typedef int (*iree_hal_executable_import_v0_t)(void* context, void* params,
+ void* reserved);
// A thunk function used to call an import.
// All imports must be called through this function by passing the import
// function pointer as the first argument followed by the arguments of the
// import function itself.
typedef int (*iree_hal_executable_import_thunk_v0_t)(
- iree_hal_executable_import_v0_t fn_ptr, void* import_params);
+ iree_hal_executable_import_v0_t fn_ptr, void* context, void* params,
+ void* reserved);
// Declares imports available to the executable library at runtime.
// To enable linker isolation, ABI shimming, and import multi-versioning we use
@@ -251,7 +253,8 @@
// Optional imported functions available for use within the executable.
// Contains one entry per imported function. If an import was marked as weak
// then the corresponding entry may be NULL.
- const iree_hal_executable_import_v0_t* imports;
+ const iree_hal_executable_import_v0_t* import_funcs;
+ const void** import_contexts;
// Optional architecture-specific CPU information.
// In heterogenous processors this may represent any of the subarchitecture
diff --git a/runtime/src/iree/hal/local/executable_loader.c b/runtime/src/iree/hal/local/executable_loader.c
index f0b8c78..d0a8a31 100644
--- a/runtime/src/iree/hal/local/executable_loader.c
+++ b/runtime/src/iree/hal/local/executable_loader.c
@@ -26,9 +26,11 @@
iree_status_t iree_hal_executable_import_provider_resolve(
const iree_hal_executable_import_provider_t import_provider,
- iree_string_view_t symbol_name, void** out_fn_ptr) {
+ iree_string_view_t symbol_name, void** out_fn_ptr, void** out_fn_context) {
IREE_ASSERT_ARGUMENT(out_fn_ptr);
+ IREE_ASSERT_ARGUMENT(out_fn_context);
*out_fn_ptr = NULL;
+ *out_fn_context = NULL;
// A `?` suffix indicates the symbol is weakly linked and can be NULL.
bool is_weak = false;
@@ -47,8 +49,8 @@
(int)symbol_name.size, symbol_name.data);
}
- iree_status_t status =
- import_provider.resolve(import_provider.self, symbol_name, out_fn_ptr);
+ iree_status_t status = import_provider.resolve(
+ import_provider.self, symbol_name, out_fn_ptr, out_fn_context);
if (!iree_status_is_ok(status) && is_weak) {
status = iree_status_ignore(status); // ok to fail on weak symbols
}
diff --git a/runtime/src/iree/hal/local/executable_loader.h b/runtime/src/iree/hal/local/executable_loader.h
index c7a5967..5a095d9 100644
--- a/runtime/src/iree/hal/local/executable_loader.h
+++ b/runtime/src/iree/hal/local/executable_loader.h
@@ -37,7 +37,8 @@
// to the function (or its context) in |out_fn_ptr|.
iree_status_t(IREE_API_PTR* resolve)(void* self,
iree_string_view_t symbol_name,
- void** out_fn_ptr);
+ void** out_fn_ptr,
+ void** out_fn_context);
} iree_hal_executable_import_provider_t;
static inline iree_hal_executable_import_provider_t
@@ -63,7 +64,7 @@
// allowed to be resolved to NULL. Such cases will always return OK.
iree_status_t iree_hal_executable_import_provider_resolve(
const iree_hal_executable_import_provider_t import_provider,
- iree_string_view_t symbol_name, void** out_fn_ptr);
+ iree_string_view_t symbol_name, void** out_fn_ptr, void** out_fn_context);
//===----------------------------------------------------------------------===//
// iree_hal_executable_loader_t
diff --git a/runtime/src/iree/hal/local/loaders/embedded_elf_loader.c b/runtime/src/iree/hal/local/loaders/embedded_elf_loader.c
index ddcc1f4..2227137 100644
--- a/runtime/src/iree/hal/local/loaders/embedded_elf_loader.c
+++ b/runtime/src/iree/hal/local/loaders/embedded_elf_loader.c
@@ -98,15 +98,22 @@
// All calls from the loaded ELF route through our thunk function so that we
// can adapt to ABI differences.
executable->base.environment.import_thunk =
- (iree_hal_executable_import_thunk_v0_t)iree_elf_thunk_i_p;
+ (iree_hal_executable_import_thunk_v0_t)iree_elf_call_i_ppp;
// Allocate storage for the imports.
+ // TODO(benvanik): allocate both as one block.
IREE_RETURN_AND_END_ZONE_IF_ERROR(
- z0,
- iree_allocator_malloc(
- executable->base.host_allocator,
- import_table->count * sizeof(*executable->base.environment.imports),
- (void**)&executable->base.environment.imports));
+ z0, iree_allocator_malloc(
+ executable->base.host_allocator,
+ import_table->count *
+ sizeof(*executable->base.environment.import_funcs),
+ (void**)&executable->base.environment.import_funcs));
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_allocator_malloc(
+ executable->base.host_allocator,
+ import_table->count *
+ sizeof(*executable->base.environment.import_contexts),
+ (void**)&executable->base.environment.import_contexts));
// Try to resolve each import.
// NOTE: imports are sorted alphabetically and if we cared we could use this
@@ -117,7 +124,8 @@
z0,
iree_hal_executable_import_provider_resolve(
import_provider, iree_make_cstring_view(import_table->symbols[i]),
- (void**)&executable->base.environment.imports[i]));
+ (void**)&executable->base.environment.import_funcs[i],
+ (void**)&executable->base.environment.import_contexts[i]));
}
IREE_TRACE_ZONE_END(z0);
@@ -231,9 +239,13 @@
iree_elf_module_deinitialize(&executable->module);
- if (executable->base.environment.imports != NULL) {
+ if (executable->base.environment.import_funcs != NULL) {
iree_allocator_free(host_allocator,
- (void*)executable->base.environment.imports);
+ (void*)executable->base.environment.import_funcs);
+ }
+ if (executable->base.environment.import_contexts != NULL) {
+ iree_allocator_free(host_allocator,
+ (void*)executable->base.environment.import_contexts);
}
iree_hal_local_executable_deinitialize(
diff --git a/runtime/src/iree/hal/local/loaders/system_library_loader.c b/runtime/src/iree/hal/local/loaders/system_library_loader.c
index 14d2c04..6fb73f8 100644
--- a/runtime/src/iree/hal/local/loaders/system_library_loader.c
+++ b/runtime/src/iree/hal/local/loaders/system_library_loader.c
@@ -206,8 +206,9 @@
}
static int iree_hal_system_executable_import_thunk_v0(
- iree_hal_executable_import_v0_t fn_ptr, void* import_params) {
- return fn_ptr(import_params);
+ iree_hal_executable_import_v0_t fn_ptr, void* context, void* params,
+ void* reserved) {
+ return fn_ptr(context, params, reserved);
}
// Resolves all of the imports declared by the executable using the given
@@ -225,12 +226,19 @@
iree_hal_system_executable_import_thunk_v0;
// Allocate storage for the imports.
+ // TODO(benvanik): allocate both as one block.
IREE_RETURN_AND_END_ZONE_IF_ERROR(
- z0,
- iree_allocator_malloc(
- executable->base.host_allocator,
- import_table->count * sizeof(*executable->base.environment.imports),
- (void**)&executable->base.environment.imports));
+ z0, iree_allocator_malloc(
+ executable->base.host_allocator,
+ import_table->count *
+ sizeof(*executable->base.environment.import_funcs),
+ (void**)&executable->base.environment.import_funcs));
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_allocator_malloc(
+ executable->base.host_allocator,
+ import_table->count *
+ sizeof(*executable->base.environment.import_contexts),
+ (void**)&executable->base.environment.import_contexts));
// Try to resolve each import.
// NOTE: imports are sorted alphabetically and if we cared we could use this
@@ -241,7 +249,8 @@
z0,
iree_hal_executable_import_provider_resolve(
import_provider, iree_make_cstring_view(import_table->symbols[i]),
- (void**)&executable->base.environment.imports[i]));
+ (void**)&executable->base.environment.import_funcs[i],
+ (void**)&executable->base.environment.import_contexts[i]));
}
IREE_TRACE_ZONE_END(z0);
@@ -351,9 +360,13 @@
iree_dynamic_library_release(executable->handle);
- if (executable->base.environment.imports != NULL) {
+ if (executable->base.environment.import_funcs != NULL) {
iree_allocator_free(host_allocator,
- (void*)executable->base.environment.imports);
+ (void*)executable->base.environment.import_funcs);
+ }
+ if (executable->base.environment.import_contexts != NULL) {
+ iree_allocator_free(host_allocator,
+ (void*)executable->base.environment.import_contexts);
}
iree_hal_local_executable_deinitialize(