Completing the wiring up of CPU imports. (#7503)
Confirmed working with both system and embedded libraries.
This is just the low-level infrastructure for emitting the dynamic import calls at codegen time and passing those calls through at runtime.
Another layer on top of this is still required in the compiler to make imports generally usable; that work is tracked in #7504.
diff --git a/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp b/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
index 9cc77ba..d89fd17 100644
--- a/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
+++ b/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
@@ -76,6 +76,8 @@
binding_count = 4,
binding_ptrs = 5,
binding_lengths = 6,
+ import_thunk = 7,
+ imports = 8,
};
// Returns a Type representing iree_hal_executable_dispatch_state_v0_t.
@@ -88,6 +90,8 @@
auto indexType = typeConverter->convertType(IndexType::get(context));
auto int8Type = IntegerType::get(context, 8);
auto uint32Type = IntegerType::get(context, 32);
+ auto int8PtrType = LLVM::LLVMPointerType::get(int8Type);
+ auto uint32PtrType = LLVM::LLVMPointerType::get(uint32Type);
auto vec3Type = getVec3Type(context);
SmallVector<Type, 4> fieldTypes;
@@ -99,17 +103,23 @@
// size_t push_constant_count;
// const uint32_t * push_constants;
fieldTypes.push_back(indexType);
- fieldTypes.push_back(LLVM::LLVMPointerType::get(uint32Type));
+ fieldTypes.push_back(uint32PtrType);
// size_t binding_count;
// void *const * binding_ptrs;
// const size_t * binding_lengths;
fieldTypes.push_back(indexType);
- fieldTypes.push_back(
- LLVM::LLVMPointerType::get(LLVM::LLVMPointerType::get(int8Type)));
+ fieldTypes.push_back(LLVM::LLVMPointerType::get(int8PtrType));
fieldTypes.push_back(LLVM::LLVMPointerType::get(indexType));
- // TODO(benvanik): import_thunk/import and a callImport() helper function.
+ // iree_hal_executable_import_thunk_v0_t import_thunk;
+ // const iree_hal_executable_import_v0_t* imports;
+ auto importType = LLVM::LLVMFunctionType::get(uint32Type, int8PtrType);
+ auto importPtrType = LLVM::LLVMPointerType::get(importType);
+ auto importThunkType =
+ LLVM::LLVMFunctionType::get(uint32Type, {importPtrType, int8PtrType});
+ fieldTypes.push_back(LLVM::LLVMPointerType::get(importThunkType));
+ fieldTypes.push_back(LLVM::LLVMPointerType::get(importPtrType));
LogicalResult bodySet = structType.setBody(fieldTypes, /*isPacked=*/false);
assert(succeeded(bodySet) &&
@@ -143,6 +153,8 @@
dispatchStateType(
getDispatchStateType(funcOp.getContext(), typeConverter)) {}
+ LLVM::LLVMFuncOp getFuncOp() { return funcOp; }
+
// Loads the workgroup_id[dim] value (XYZ) and casts it to |resultType|.
Value loadWorkgroupID(Location loc, int32_t dim, Type resultType,
OpBuilder &builder) {
@@ -299,6 +311,48 @@
}
}
+ // Loads the import function pointer of the import |ordinal|.
+ // Equivalent to:
+ // iree_hal_executable_import_v0_t func_ptr = state->imports[ordinal];
+ Value loadImportFuncPtr(Location loc, int64_t ordinal, OpBuilder &builder) {
+ auto importsPtrValue = loadFieldValue(loc, Field::imports, builder);
+ auto ordinalValue = getIndexValue(loc, ordinal, builder);
+ auto elementPtrValue = builder.createOrFold<LLVM::GEPOp>(
+ loc, importsPtrValue.getType(), importsPtrValue, ordinalValue);
+ return builder.createOrFold<LLVM::LoadOp>(loc, elementPtrValue);
+ }
+
+ // Returns an i1 indicating whether the weak import with |ordinal| is defined.
+ // Equivalent to:
+ // state->imports[ordinal] != NULL
+ Value isImportFuncAvailable(Location loc, int64_t ordinal,
+ OpBuilder &builder) {
+ auto importPtrValue = loadImportFuncPtr(loc, ordinal, builder);
+ auto nullPtrValue =
+ builder.create<LLVM::NullOp>(loc, importPtrValue.getType()).getResult();
+ return builder.create<LLVM::ICmpOp>(loc, builder.getI1Type(),
+ LLVM::ICmpPredicate::ne, importPtrValue,
+ nullPtrValue);
+ }
+
+ // Emits a call to the import with the given |importOrdinal|.
+ // The provided |params| struct containing the function-specific arguments
+ // is passed without modification.
+ // Returns 0 on success and non-zero otherwise.
+ Value callImport(Location loc, unsigned importOrdinal, Value params,
+ OpBuilder &builder) {
+ auto thunkPtrValue = loadFieldValue(loc, Field::import_thunk, builder);
+ auto importPtrValue = loadImportFuncPtr(loc, importOrdinal, builder);
+ auto callOp =
+ builder.create<LLVM::CallOp>(loc, TypeRange{builder.getI32Type()},
+ ValueRange{
+ /*thunk_func_ptr=*/thunkPtrValue,
+ /*import_func_ptr=*/importPtrValue,
+ /*import_params=*/params,
+ });
+ return callOp.getResult(0);
+ }
+
private:
Value loadFieldValue(Location loc, Field field, OpBuilder &builder) {
auto statePtrValue = funcOp.getArgument(0);
diff --git a/iree/compiler/Codegen/LLVMCPU/test/hal_interface_bindings.mlir b/iree/compiler/Codegen/LLVMCPU/test/hal_interface_bindings.mlir
index 210fd6d..ef3a8d9 100644
--- a/iree/compiler/Codegen/LLVMCPU/test/hal_interface_bindings.mlir
+++ b/iree/compiler/Codegen/LLVMCPU/test/hal_interface_bindings.mlir
@@ -7,35 +7,35 @@
// CHECK-DAG: %[[C72:.+]] = llvm.mlir.constant(72 : index) : i64
%c72 = arith.constant 72 : index
- // CHECK: %[[STATE:.+]] = llvm.load %arg0 : !llvm.ptr<struct<"iree_hal_executable_dispatch_state_v0_t", (array<3 x i32>, array<3 x i32>, i64, ptr<i32>, i64, ptr<ptr<i8>>, ptr<i64>)>>
- // CHECK: %[[PC:.+]] = llvm.extractvalue %[[STATE]][3] : !llvm.struct<"iree_hal_executable_dispatch_state_v0_t", (array<3 x i32>, array<3 x i32>, i64, ptr<i32>, i64, ptr<ptr<i8>>, ptr<i64>)>
+ // CHECK: %[[STATE:.+]] = llvm.load %arg0 : !llvm.ptr<struct<[[DISPATCH_STATE_TYPE:.+]]>>
+ // CHECK: %[[PC:.+]] = llvm.extractvalue %[[STATE]][3] : !llvm.struct<[[DISPATCH_STATE_TYPE]]>
// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[DIM_PTR:.+]] = llvm.getelementptr %[[PC]][%[[C0]]] : (!llvm.ptr<i32>, i64) -> !llvm.ptr<i32>
// CHECK: %[[DIM_I32:.+]] = llvm.load %[[DIM_PTR]] : !llvm.ptr<i32>
// CHECK: %[[DIM:.+]] = llvm.zext %[[DIM_I32]] : i32 to i64
%dim = hal.interface.load.constant offset = 0 : index
- // CHECK: %[[STATE:.+]] = llvm.load %arg0 : !llvm.ptr<struct<"iree_hal_executable_dispatch_state_v0_t", (array<3 x i32>, array<3 x i32>, i64, ptr<i32>, i64, ptr<ptr<i8>>, ptr<i64>)>>
- // CHECK: %[[BINDING_PTRS:.+]] = llvm.extractvalue %[[STATE]][5] : !llvm.struct<"iree_hal_executable_dispatch_state_v0_t", (array<3 x i32>, array<3 x i32>, i64, ptr<i32>, i64, ptr<ptr<i8>>, ptr<i64>)>
+ // CHECK: %[[STATE:.+]] = llvm.load %arg0 : !llvm.ptr<struct<[[DISPATCH_STATE_TYPE]]>>
+ // CHECK: %[[BINDING_PTRS:.+]] = llvm.extractvalue %[[STATE]][5] : !llvm.struct<[[DISPATCH_STATE_TYPE]]>
// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[ARRAY_PTR:.+]] = llvm.getelementptr %[[BINDING_PTRS]][%[[C1]]] : (!llvm.ptr<ptr<i8>>, i64) -> !llvm.ptr<ptr<i8>>
// CHECK: %[[BASE_PTR_I8:.+]] = llvm.load %[[ARRAY_PTR]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[BUFFER_I8:.+]] = llvm.getelementptr %[[BASE_PTR_I8]][%[[C72]]] : (!llvm.ptr<i8>, i64) -> !llvm.ptr<i8>
// CHECK: %[[BUFFER_F32:.+]] = llvm.bitcast %[[BUFFER_I8]] : !llvm.ptr<i8> to !llvm.ptr<f32>
// CHECK: %[[DESC_A:.+]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
- // CHECK: %[[DESC_B:.+]] = llvm.insertvalue %[[BUFFER_F32]], %[[DESC_A]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
- // CHECK: %[[DESC_C:.+]] = llvm.insertvalue %[[BUFFER_F32]], %[[DESC_B]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: %[[DESC_B:.+]] = llvm.insertvalue %[[BUFFER_F32]], %[[DESC_A]][0]
+ // CHECK: %[[DESC_C:.+]] = llvm.insertvalue %[[BUFFER_F32]], %[[DESC_B]][1]
// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : index) : i64
- // CHECK: %[[DESC_D:.+]] = llvm.insertvalue %[[C0]], %[[DESC_C]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
- // CHECK: %[[DESC_E:.+]] = llvm.insertvalue %[[DIM]], %[[DESC_D]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: %[[DESC_D:.+]] = llvm.insertvalue %[[C0]], %[[DESC_C]][2]
+ // CHECK: %[[DESC_E:.+]] = llvm.insertvalue %[[DIM]], %[[DESC_D]][3, 0]
// CHECK: %[[C2:.+]] = llvm.mlir.constant(2 : index) : i64
- // CHECK: %[[DESC_F:.+]] = llvm.insertvalue %[[C2]], %[[DESC_E]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: %[[DESC_F:.+]] = llvm.insertvalue %[[C2]], %[[DESC_E]][3, 1]
// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : index) : i64
- // CHECK: %[[DESC_G:.+]] = llvm.insertvalue %[[C1]], %[[DESC_F]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
- // CHECK: %[[STRIDE1:.+]] = llvm.extractvalue %[[DESC_G]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
- // CHECK: %[[DIM1:.+]] = llvm.extractvalue %[[DESC_G]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: %[[DESC_G:.+]] = llvm.insertvalue %[[C1]], %[[DESC_F]][4, 1]
+ // CHECK: %[[STRIDE1:.+]] = llvm.extractvalue %[[DESC_G]][4, 1]
+ // CHECK: %[[DIM1:.+]] = llvm.extractvalue %[[DESC_G]][3, 1]
// CHECK: %[[STRIDE0:.+]] = llvm.mul %[[STRIDE1]], %[[DIM1]] : i64
- // CHECK: %[[DESC_H:.+]] = llvm.insertvalue %[[STRIDE0]], %[[DESC_G]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: %[[DESC_H:.+]] = llvm.insertvalue %[[STRIDE0]], %[[DESC_G]][4, 0]
%memref = hal.interface.binding.subspan @io::@ret0[%c72] : memref<?x2xf32>{%dim}
// CHECK: %[[VAL:.+]] = llvm.load
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LibraryBuilder.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/LibraryBuilder.cpp
index 41be21d..8bcd8a4 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/LibraryBuilder.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/LibraryBuilder.cpp
@@ -276,20 +276,39 @@
std::string libraryName) {
auto &context = module->getContext();
auto *importTableType = makeImportTableType(context);
- auto *i8PtrType = llvm::IntegerType::getInt8Ty(context);
+ auto *i8Type = llvm::IntegerType::getInt8Ty(context);
auto *i32Type = llvm::IntegerType::getInt32Ty(context);
+ llvm::Constant *zero = llvm::ConstantInt::get(i32Type, 0);
- // Not yet implemented; we'd want to sort all the imports alphabetically first
- // before encoding and add the `?` suffix for weak symbols.
+ llvm::Constant *symbolNames =
+ llvm::Constant::getNullValue(i8Type->getPointerTo());
+ if (!imports.empty()) {
+ SmallVector<llvm::Constant *, 4> symbolNameValues;
+ for (auto &import : imports) {
+ auto symbolName = import.symbol_name;
+ if (import.weak) {
+ symbolName += "?";
+ }
+ symbolNameValues.push_back(getStringConstant(symbolName, module));
+ }
+ auto *symbolNamesType =
+ llvm::ArrayType::get(i8Type->getPointerTo(), symbolNameValues.size());
+ auto *global = new llvm::GlobalVariable(
+ *module, symbolNamesType, /*isConstant=*/true,
+ llvm::GlobalVariable::PrivateLinkage,
+ llvm::ConstantArray::get(symbolNamesType, symbolNameValues),
+ /*Name=*/libraryName + "_import_names");
+ symbolNames = llvm::ConstantExpr::getInBoundsGetElementPtr(
+ symbolNamesType, global, ArrayRef<llvm::Constant *>{zero, zero});
+ }
return llvm::ConstantStruct::get(
- importTableType,
- {
- // count=
- llvm::ConstantInt::get(i32Type, 0),
- // symbols=
- llvm::Constant::getNullValue(i8PtrType->getPointerTo()),
- });
+ importTableType, {
+ // count=
+ llvm::ConstantInt::get(i32Type, imports.size()),
+ // symbols=
+ symbolNames,
+ });
}
llvm::Constant *LibraryBuilder::buildLibraryV0ExportTable(
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LibraryBuilder.h b/iree/compiler/Dialect/HAL/Target/LLVM/LibraryBuilder.h
index 7c6a0ce..aa2cb9c 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/LibraryBuilder.h
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/LibraryBuilder.h
@@ -96,6 +96,12 @@
this->sanitizerKind = sanitizerKind;
}
+ // Defines a new runtime import function and returns its ordinal.
+ unsigned addImport(StringRef name, bool weak) {
+ imports.push_back({name.str(), weak});
+ return imports.size() - 1;
+ }
+
// Defines a new entry point on the library implemented by |func|.
// |name| will be used as the library export and an optional |tag| will be
// attached.
@@ -126,13 +132,19 @@
Features features = Features::NONE;
SanitizerKind sanitizerKind = SanitizerKind::NONE;
+ struct Import {
+ std::string symbol_name;
+ bool weak = false;
+ };
+ SmallVector<Import> imports;
+
struct Dispatch {
std::string name;
std::string tag;
DispatchAttrs attrs;
llvm::Function *func;
};
- std::vector<Dispatch> exports;
+ SmallVector<Dispatch> exports;
};
} // namespace HAL
diff --git a/iree/hal/local/elf/arch.h b/iree/hal/local/elf/arch.h
index eb15851..3933c95 100644
--- a/iree/hal/local/elf/arch.h
+++ b/iree/hal/local/elf/arch.h
@@ -44,19 +44,22 @@
// TODO(benvanik): add thunk functions (iree_elf_thunk_*) to be used by imports
// for marshaling from linux ABI in the ELF to host ABI.
-// void(*)(void)
+// Host -> ELF: void(*)(void)
void iree_elf_call_v_v(const void* symbol_ptr);
-// void*(*)(int)
+// Host -> ELF: void*(*)(int)
void* iree_elf_call_p_i(const void* symbol_ptr, int a0);
-// void*(*)(int, void*)
+// Host -> ELF: void*(*)(int, void*)
void* iree_elf_call_p_ip(const void* symbol_ptr, int a0, void* a1);
-// int(*)(void*)
+// Host -> ELF: int(*)(void*)
int iree_elf_call_i_p(const void* symbol_ptr, void* a0);
-// int(*)(void*, void*, void*)
+// Host -> ELF: int(*)(void*, void*, void*)
int iree_elf_call_i_ppp(const void* symbol_ptr, void* a0, void* a1, void* a2);
+// ELF -> Host: int(*)(void*)
+int iree_elf_thunk_i_p(const void* symbol_ptr, void* a0);
+
#endif // IREE_HAL_LOCAL_ELF_ARCH_H_
diff --git a/iree/hal/local/elf/arch/arm_32.c b/iree/hal/local/elf/arch/arm_32.c
index 17f7d8d..4044fbf 100644
--- a/iree/hal/local/elf/arch/arm_32.c
+++ b/iree/hal/local/elf/arch/arm_32.c
@@ -144,4 +144,9 @@
return ((ptr_t)symbol_ptr)(a0, a1, a2);
}
+int iree_elf_thunk_i_p(const void* symbol_ptr, void* a0) {
+ typedef int (*ptr_t)(void*);
+ return ((ptr_t)symbol_ptr)(a0);
+}
+
#endif // IREE_ARCH_ARM_32
diff --git a/iree/hal/local/elf/arch/arm_64.c b/iree/hal/local/elf/arch/arm_64.c
index b677862..cc8398a 100644
--- a/iree/hal/local/elf/arch/arm_64.c
+++ b/iree/hal/local/elf/arch/arm_64.c
@@ -141,4 +141,9 @@
return ((ptr_t)symbol_ptr)(a0, a1, a2);
}
+int iree_elf_thunk_i_p(const void* symbol_ptr, void* a0) {
+ typedef int (*ptr_t)(void*);
+ return ((ptr_t)symbol_ptr)(a0);
+}
+
#endif // IREE_ARCH_ARM_64
diff --git a/iree/hal/local/elf/arch/riscv.c b/iree/hal/local/elf/arch/riscv.c
index 8d8f627..807b62d 100644
--- a/iree/hal/local/elf/arch/riscv.c
+++ b/iree/hal/local/elf/arch/riscv.c
@@ -184,4 +184,9 @@
return ((ptr_t)symbol_ptr)(a0, a1, a2);
}
+int iree_elf_thunk_i_p(const void* symbol_ptr, void* a0) {
+ typedef int (*ptr_t)(void*);
+ return ((ptr_t)symbol_ptr)(a0);
+}
+
#endif // IREE_ARCH_RISCV_*
diff --git a/iree/hal/local/elf/arch/x86_32.c b/iree/hal/local/elf/arch/x86_32.c
index 05d08d7..9d8d885 100644
--- a/iree/hal/local/elf/arch/x86_32.c
+++ b/iree/hal/local/elf/arch/x86_32.c
@@ -165,6 +165,11 @@
return ((ptr_t)symbol_ptr)(a0, a1, a2);
}
+int iree_elf_thunk_i_p(const void* symbol_ptr, void* a0) {
+ typedef int (*ptr_t)(void*);
+ return ((ptr_t)symbol_ptr)(a0);
+}
+
#endif // IREE_PLATFORM_WINDOWS
#endif // IREE_ARCH_X86_32
diff --git a/iree/hal/local/elf/arch/x86_64.c b/iree/hal/local/elf/arch/x86_64.c
index 0d67517..1e3adfc 100644
--- a/iree/hal/local/elf/arch/x86_64.c
+++ b/iree/hal/local/elf/arch/x86_64.c
@@ -206,6 +206,11 @@
return ((ptr_t)symbol_ptr)(a0, a1, a2);
}
+int iree_elf_thunk_i_p(const void* symbol_ptr, void* a0) {
+ typedef int (*ptr_t)(void*);
+ return ((ptr_t)symbol_ptr)(a0);
+}
+
#endif // IREE_PLATFORM_WINDOWS
#endif // IREE_ARCH_X86_64
diff --git a/iree/hal/local/elf/arch/x86_64_msvc.asm b/iree/hal/local/elf/arch/x86_64_msvc.asm
index 337924f..6e25c29 100644
--- a/iree/hal/local/elf/arch/x86_64_msvc.asm
+++ b/iree/hal/local/elf/arch/x86_64_msvc.asm
@@ -185,5 +185,18 @@
ret
iree_elf_call_i_ppp ENDP
+; int iree_elf_thunk_i_p(const void* symbol_ptr, void* a0)
+iree_elf_thunk_i_p PROC FRAME
+ _sysv_interop_prolog
+
+ ; RDI = symbol_ptr
+ ; RSI = a0
+ mov rcx, rsi
+ call rdi
+
+ _sysv_interop_epilog
+ ret
+iree_elf_thunk_i_p ENDP
+
_TEXT ENDS
END
diff --git a/iree/hal/local/executable_loader.c b/iree/hal/local/executable_loader.c
index 5c6d93f..1630b0c 100644
--- a/iree/hal/local/executable_loader.c
+++ b/iree/hal/local/executable_loader.c
@@ -34,7 +34,7 @@
iree_status_t status =
import_provider.resolve(import_provider.self, symbol_name, out_fn_ptr);
if (!iree_status_is_ok(status) && is_weak) {
- iree_status_ignore(status); // ok to fail on weak symbols
+ status = iree_status_ignore(status); // ok to fail on weak symbols
}
return status;
diff --git a/iree/hal/local/loaders/embedded_library_loader.c b/iree/hal/local/loaders/embedded_library_loader.c
index 74d48fe..e85988f 100644
--- a/iree/hal/local/loaders/embedded_library_loader.c
+++ b/iree/hal/local/loaders/embedded_library_loader.c
@@ -99,7 +99,7 @@
// All calls from the loaded ELF route through our thunk function so that we
// can adapt to ABI differences.
executable->base.import_thunk =
- (iree_hal_executable_import_thunk_v0_t)iree_elf_call_i_p;
+ (iree_hal_executable_import_thunk_v0_t)iree_elf_thunk_i_p;
// Allocate storage for the imports.
IREE_RETURN_AND_END_ZONE_IF_ERROR(