Extending imports to have a context pointer. (#10580)

This allows imports to carry their own state without relying on TLS or
globals. A reserved pointer argument is also added to the import
signature in case more needs to be passed in the future.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
index 8816913..3fb5be6 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
@@ -101,7 +101,8 @@
   enum class EnvironmentField {
     constants,
     import_thunk,
-    imports,
+    import_funcs,
+    import_contexts,
     processor,
   };
 
@@ -124,13 +125,16 @@
     fieldTypes.push_back(uint32PtrType);
 
     // iree_hal_executable_import_thunk_v0_t import_thunk;
-    // const iree_hal_executable_import_v0_t* imports;
-    auto importType = LLVM::LLVMFunctionType::get(uint32Type, int8PtrType);
+    // const iree_hal_executable_import_v0_t* import_funcs;
+    // const void** import_contexts;
+    auto importType = LLVM::LLVMFunctionType::get(
+        uint32Type, {int8PtrType, int8PtrType, int8PtrType});
     auto importPtrType = LLVM::LLVMPointerType::get(importType);
-    auto importThunkType =
-        LLVM::LLVMFunctionType::get(uint32Type, {importPtrType, int8PtrType});
+    auto importThunkType = LLVM::LLVMFunctionType::get(
+        uint32Type, {importPtrType, int8PtrType, int8PtrType, int8PtrType});
     fieldTypes.push_back(LLVM::LLVMPointerType::get(importThunkType));
     fieldTypes.push_back(LLVM::LLVMPointerType::get(importPtrType));
+    fieldTypes.push_back(LLVM::LLVMPointerType::get(int8PtrType));
 
     // iree_hal_processor_v0_t processor;
     fieldTypes.push_back(processorType);
@@ -515,27 +519,35 @@
 
-  // Loads the import function pointer of the import |ordinal|.
+  // Loads the import function pointer and context of the import |ordinal|.
   // Equivalent to:
-  //   iree_hal_executable_import_v0_t func_ptr = state->imports[ordinal];
-  Value loadImportFuncPtr(Location loc, int64_t ordinal, OpBuilder &builder) {
-    auto importsPtrValue =
-        loadFieldValue(loc, EnvironmentField::imports, builder);
+  //   iree_hal_executable_import_v0_t fn_ptr = state->import_funcs[ordinal];
+  //   void* context = state->import_contexts[ordinal];
+  std::pair<Value, Value> loadImportFunc(Location loc, int64_t ordinal,
+                                         OpBuilder &builder) {
     auto ordinalValue = getIndexValue(loc, ordinal, builder);
-    auto elementPtrValue = builder.createOrFold<LLVM::GEPOp>(
-        loc, importsPtrValue.getType(), importsPtrValue, ordinalValue);
-    return builder.createOrFold<LLVM::LoadOp>(loc, elementPtrValue);
+    auto funcPtrsValue =
+        loadFieldValue(loc, EnvironmentField::import_funcs, builder);
+    auto funcPtrValue = builder.createOrFold<LLVM::GEPOp>(
+        loc, funcPtrsValue.getType(), funcPtrsValue, ordinalValue);
+    auto contextPtrsValue =
+        loadFieldValue(loc, EnvironmentField::import_contexts, builder);
+    auto contextPtrValue = builder.createOrFold<LLVM::GEPOp>(
+        loc, contextPtrsValue.getType(), contextPtrsValue, ordinalValue);
+    return std::make_pair(
+        builder.createOrFold<LLVM::LoadOp>(loc, funcPtrValue),
+        builder.createOrFold<LLVM::LoadOp>(loc, contextPtrValue));
   }
 
   // Returns an i1 indicating whether the optional import with |ordinal| is
   // defined. Equivalent to:
-  //   state->imports[ordinal] != NULL
+  //   state->import_funcs[ordinal] != NULL
   Value isImportFuncAvailable(Location loc, int64_t ordinal,
                               OpBuilder &builder) {
-    auto importPtrValue = loadImportFuncPtr(loc, ordinal, builder);
-    auto nullPtrValue =
-        builder.create<LLVM::NullOp>(loc, importPtrValue.getType()).getResult();
+    auto importFunc = loadImportFunc(loc, ordinal, builder);
+    Value nullPtrValue =
+        builder.create<LLVM::NullOp>(loc, importFunc.first.getType());
     return builder.create<LLVM::ICmpOp>(loc, builder.getI1Type(),
-                                        LLVM::ICmpPredicate::ne, importPtrValue,
-                                        nullPtrValue);
+                                        LLVM::ICmpPredicate::ne,
+                                        importFunc.first, nullPtrValue);
   }
 
   // Emits a call to the import with the given |importOrdinal|.
@@ -546,13 +558,17 @@
                    OpBuilder &builder) {
     auto thunkPtrValue =
         loadFieldValue(loc, EnvironmentField::import_thunk, builder);
-    auto importPtrValue = loadImportFuncPtr(loc, importOrdinal, builder);
+    auto importFunc = loadImportFunc(loc, importOrdinal, builder);
+    Value nullPtrValue = builder.create<LLVM::NullOp>(
+        loc, LLVM::LLVMPointerType::get(builder.getI8Type()));
     auto callOp =
         builder.create<LLVM::CallOp>(loc, TypeRange{builder.getI32Type()},
                                      ValueRange{
                                          /*thunk_func_ptr=*/thunkPtrValue,
-                                         /*import_func_ptr=*/importPtrValue,
-                                         /*import_params=*/params,
+                                         /*import_func_ptr=*/importFunc.first,
+                                         /*context=*/importFunc.second,
+                                         /*params=*/params,
+                                         /*reserved=*/nullPtrValue,
                                      });
     return callOp.getResult();
   }
diff --git a/runtime/src/iree/hal/local/executable_library.h b/runtime/src/iree/hal/local/executable_library.h
index 2757726..a9fc3b6 100644
--- a/runtime/src/iree/hal/local/executable_library.h
+++ b/runtime/src/iree/hal/local/executable_library.h
@@ -154,14 +154,16 @@
 // a useful failure though the HAL does not mandate that all overflows are
 // caught and only that they are not harmful - clamping byte ranges and never
 // returning a failure is sufficient.
-typedef int (*iree_hal_executable_import_v0_t)(void* import_params);
+typedef int (*iree_hal_executable_import_v0_t)(void* context, void* params,
+                                               void* reserved);
 
 // A thunk function used to call an import.
 // All imports must be called through this function by passing the import
 // function pointer as the first argument followed by the arguments of the
 // import function itself.
 typedef int (*iree_hal_executable_import_thunk_v0_t)(
-    iree_hal_executable_import_v0_t fn_ptr, void* import_params);
+    iree_hal_executable_import_v0_t fn_ptr, void* context, void* params,
+    void* reserved);
 
 // Declares imports available to the executable library at runtime.
 // To enable linker isolation, ABI shimming, and import multi-versioning we use
@@ -251,7 +253,8 @@
   // Optional imported functions available for use within the executable.
   // Contains one entry per imported function. If an import was marked as weak
   // then the corresponding entry may be NULL.
-  const iree_hal_executable_import_v0_t* imports;
+  const iree_hal_executable_import_v0_t* import_funcs;
+  const void** import_contexts;
 
   // Optional architecture-specific CPU information.
   // In heterogenous processors this may represent any of the subarchitecture
diff --git a/runtime/src/iree/hal/local/executable_loader.c b/runtime/src/iree/hal/local/executable_loader.c
index f0b8c78..d0a8a31 100644
--- a/runtime/src/iree/hal/local/executable_loader.c
+++ b/runtime/src/iree/hal/local/executable_loader.c
@@ -26,9 +26,11 @@
 
 iree_status_t iree_hal_executable_import_provider_resolve(
     const iree_hal_executable_import_provider_t import_provider,
-    iree_string_view_t symbol_name, void** out_fn_ptr) {
+    iree_string_view_t symbol_name, void** out_fn_ptr, void** out_fn_context) {
   IREE_ASSERT_ARGUMENT(out_fn_ptr);
+  IREE_ASSERT_ARGUMENT(out_fn_context);
   *out_fn_ptr = NULL;
+  *out_fn_context = NULL;
 
   // A `?` suffix indicates the symbol is weakly linked and can be NULL.
   bool is_weak = false;
@@ -47,8 +49,8 @@
                             (int)symbol_name.size, symbol_name.data);
   }
 
-  iree_status_t status =
-      import_provider.resolve(import_provider.self, symbol_name, out_fn_ptr);
+  iree_status_t status = import_provider.resolve(
+      import_provider.self, symbol_name, out_fn_ptr, out_fn_context);
   if (!iree_status_is_ok(status) && is_weak) {
     status = iree_status_ignore(status);  // ok to fail on weak symbols
   }
diff --git a/runtime/src/iree/hal/local/executable_loader.h b/runtime/src/iree/hal/local/executable_loader.h
index c7a5967..5a095d9 100644
--- a/runtime/src/iree/hal/local/executable_loader.h
+++ b/runtime/src/iree/hal/local/executable_loader.h
@@ -37,7 +37,8 @@
   // to the function (or its context) in |out_fn_ptr|.
   iree_status_t(IREE_API_PTR* resolve)(void* self,
                                        iree_string_view_t symbol_name,
-                                       void** out_fn_ptr);
+                                       void** out_fn_ptr,
+                                       void** out_fn_context);
 } iree_hal_executable_import_provider_t;
 
 static inline iree_hal_executable_import_provider_t
@@ -63,7 +64,7 @@
 // allowed to be resolved to NULL. Such cases will always return OK.
 iree_status_t iree_hal_executable_import_provider_resolve(
     const iree_hal_executable_import_provider_t import_provider,
-    iree_string_view_t symbol_name, void** out_fn_ptr);
+    iree_string_view_t symbol_name, void** out_fn_ptr, void** out_fn_context);
 
 //===----------------------------------------------------------------------===//
 // iree_hal_executable_loader_t
diff --git a/runtime/src/iree/hal/local/loaders/embedded_elf_loader.c b/runtime/src/iree/hal/local/loaders/embedded_elf_loader.c
index ddcc1f4..2227137 100644
--- a/runtime/src/iree/hal/local/loaders/embedded_elf_loader.c
+++ b/runtime/src/iree/hal/local/loaders/embedded_elf_loader.c
@@ -98,15 +98,22 @@
   // All calls from the loaded ELF route through our thunk function so that we
   // can adapt to ABI differences.
   executable->base.environment.import_thunk =
-      (iree_hal_executable_import_thunk_v0_t)iree_elf_thunk_i_p;
+      (iree_hal_executable_import_thunk_v0_t)iree_elf_call_i_ppp;
 
   // Allocate storage for the imports.
+  // TODO(benvanik): allocate both as one block.
   IREE_RETURN_AND_END_ZONE_IF_ERROR(
-      z0,
-      iree_allocator_malloc(
-          executable->base.host_allocator,
-          import_table->count * sizeof(*executable->base.environment.imports),
-          (void**)&executable->base.environment.imports));
+      z0, iree_allocator_malloc(
+              executable->base.host_allocator,
+              import_table->count *
+                  sizeof(*executable->base.environment.import_funcs),
+              (void**)&executable->base.environment.import_funcs));
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_allocator_malloc(
+              executable->base.host_allocator,
+              import_table->count *
+                  sizeof(*executable->base.environment.import_contexts),
+              (void**)&executable->base.environment.import_contexts));
 
   // Try to resolve each import.
   // NOTE: imports are sorted alphabetically and if we cared we could use this
@@ -117,7 +124,8 @@
         z0,
         iree_hal_executable_import_provider_resolve(
             import_provider, iree_make_cstring_view(import_table->symbols[i]),
-            (void**)&executable->base.environment.imports[i]));
+            (void**)&executable->base.environment.import_funcs[i],
+            (void**)&executable->base.environment.import_contexts[i]));
   }
 
   IREE_TRACE_ZONE_END(z0);
@@ -231,9 +239,13 @@
 
   iree_elf_module_deinitialize(&executable->module);
 
-  if (executable->base.environment.imports != NULL) {
+  if (executable->base.environment.import_funcs != NULL) {
     iree_allocator_free(host_allocator,
-                        (void*)executable->base.environment.imports);
+                        (void*)executable->base.environment.import_funcs);
+  }
+  if (executable->base.environment.import_contexts != NULL) {
+    iree_allocator_free(host_allocator,
+                        (void*)executable->base.environment.import_contexts);
   }
 
   iree_hal_local_executable_deinitialize(
diff --git a/runtime/src/iree/hal/local/loaders/system_library_loader.c b/runtime/src/iree/hal/local/loaders/system_library_loader.c
index 14d2c04..6fb73f8 100644
--- a/runtime/src/iree/hal/local/loaders/system_library_loader.c
+++ b/runtime/src/iree/hal/local/loaders/system_library_loader.c
@@ -206,8 +206,9 @@
 }
 
 static int iree_hal_system_executable_import_thunk_v0(
-    iree_hal_executable_import_v0_t fn_ptr, void* import_params) {
-  return fn_ptr(import_params);
+    iree_hal_executable_import_v0_t fn_ptr, void* context, void* params,
+    void* reserved) {
+  return fn_ptr(context, params, reserved);
 }
 
 // Resolves all of the imports declared by the executable using the given
@@ -225,12 +226,19 @@
       iree_hal_system_executable_import_thunk_v0;
 
   // Allocate storage for the imports.
+  // TODO(benvanik): allocate both as one block.
   IREE_RETURN_AND_END_ZONE_IF_ERROR(
-      z0,
-      iree_allocator_malloc(
-          executable->base.host_allocator,
-          import_table->count * sizeof(*executable->base.environment.imports),
-          (void**)&executable->base.environment.imports));
+      z0, iree_allocator_malloc(
+              executable->base.host_allocator,
+              import_table->count *
+                  sizeof(*executable->base.environment.import_funcs),
+              (void**)&executable->base.environment.import_funcs));
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_allocator_malloc(
+              executable->base.host_allocator,
+              import_table->count *
+                  sizeof(*executable->base.environment.import_contexts),
+              (void**)&executable->base.environment.import_contexts));
 
   // Try to resolve each import.
   // NOTE: imports are sorted alphabetically and if we cared we could use this
@@ -241,7 +249,8 @@
         z0,
         iree_hal_executable_import_provider_resolve(
             import_provider, iree_make_cstring_view(import_table->symbols[i]),
-            (void**)&executable->base.environment.imports[i]));
+            (void**)&executable->base.environment.import_funcs[i],
+            (void**)&executable->base.environment.import_contexts[i]));
   }
 
   IREE_TRACE_ZONE_END(z0);
@@ -351,9 +360,13 @@
 
   iree_dynamic_library_release(executable->handle);
 
-  if (executable->base.environment.imports != NULL) {
+  if (executable->base.environment.import_funcs != NULL) {
     iree_allocator_free(host_allocator,
-                        (void*)executable->base.environment.imports);
+                        (void*)executable->base.environment.import_funcs);
+  }
+  if (executable->base.environment.import_contexts != NULL) {
+    iree_allocator_free(host_allocator,
+                        (void*)executable->base.environment.import_contexts);
   }
 
   iree_hal_local_executable_deinitialize(