| // Copyright 2023 Google, LLC. |
| // Copyright 2021 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| // Heavily based on the `runtime/src/iree/hal/local/elf/elf_module_test_main.c` |
| // sample from IREE. |
| |
#include <stdio.h>
#include <string.h>

#include "iree/base/api.h"
#include "iree/base/internal/cpu.h"
#include "iree/hal/local/elf/elf_module.h"
#include "iree/hal/local/executable_environment.h"
#include "iree/hal/local/executable_library.h"
#include "samples/branch_mul/branch_mul_arg0.h"
#include "samples/branch_mul/branch_mul_arg1.h"
#include "samples/branch_mul/branch_mul_arg2.h"
#include "samples/branch_mul/branch_mul_c.h"
#include "samples/branch_mul/branch_mul_expected.h"
| |
| static int8_t ret0[3][2][256] = { 0 }; |
| static iree_status_t run_test() { |
| iree_hal_executable_environment_v0_t environment; |
| iree_hal_executable_environment_initialize(iree_allocator_system(), |
| &environment); |
| |
| void* query_fn_ptr = &llvm_module_linked_llvm_cpu_library_query; |
| |
| union { |
| const iree_hal_executable_library_header_t** header; |
| const iree_hal_executable_library_v0_t* v0; |
| } library; |
| library.header = |
| (const iree_hal_executable_library_header_t**)iree_elf_call_p_ip( |
| query_fn_ptr, IREE_HAL_EXECUTABLE_LIBRARY_VERSION_LATEST, |
| &environment); |
| if (library.header == NULL) { |
| return iree_make_status(IREE_STATUS_NOT_FOUND, |
| "library header is empty (version mismatch?)"); |
| } |
| |
| const iree_hal_executable_library_header_t* header = *library.header; |
| if (header->version != IREE_HAL_EXECUTABLE_LIBRARY_VERSION_LATEST) { |
| return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, |
| "library version error"); |
| } |
| |
| if (strncmp(header->name, "llvm_module_linked_llvm_cpu", |
| strlen(header->name)) != 0) { |
| return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, |
| "library name mismatches"); |
| } |
| |
| if (library.v0->exports.count != 6) { |
| return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, |
| "entry point count mismatches"); |
| } |
| |
| size_t binding_lengths[4] = { |
| sizeof(arg0), |
| sizeof(arg1), |
| sizeof(arg2), |
| sizeof(ret0), |
| }; |
| void* binding_ptrs[4] = { |
| arg0, |
| arg1, |
| arg2, |
| ret0, |
| }; |
| const iree_hal_executable_dispatch_state_v0_t dispatch_state = { |
| .workgroup_size_x = 1, |
| .workgroup_size_y = 1, |
| .workgroup_size_z = 1, |
| .workgroup_count_x = 1, |
| .workgroup_count_y = 1, |
| .workgroup_count_z = 1, |
| .max_concurrency = 1, |
| .binding_count = 1, |
| .binding_lengths = binding_lengths, |
| .binding_ptrs = binding_ptrs, |
| }; |
| const iree_hal_executable_workgroup_state_v0_t workgroup_state = { |
| .workgroup_id_x = 0, |
| .workgroup_id_y = 0, |
| .workgroup_id_z = 0, |
| .processor_id = iree_cpu_query_processor_id(), |
| }; |
| int ret = iree_elf_call_i_ppp((const void*)library.v0->exports.ptrs[5], |
| (void*)&environment, (void*)&dispatch_state, |
| (void*)&workgroup_state); |
| if (ret != 0) { |
| return iree_make_status(IREE_STATUS_INTERNAL, |
| "dispatch function returned failure: %d", ret); |
| } |
| |
| int mismatch_count = 0; |
| iree_status_t status = iree_ok_status(); |
| for (int i = 0; i < 3; ++i) { |
| for (int j = 0; j < 2; ++j) { |
| for (int k = 0; k < 256; ++k) { |
| int8_t out = ret0[i][j][k]; |
| int8_t expect = expected[i][j][k]; |
| if (out != expect) { |
| mismatch_count++; |
| fprintf(stdout, "mismatch at [%d][%d][%d] %x != %x\n", i, j, k, out, |
| expect); |
| fprintf(stdout, "ins: %x %x %x\n", arg0[i][j][k], arg1[k], |
| arg2[i][j][k]); |
| } |
| } |
| } |
| } |
| if (mismatch_count) { |
| status = iree_make_status( |
| IREE_STATUS_INTERNAL, |
| "%d mismatches between actual and expected output", mismatch_count); |
| } |
| |
| return status; |
| } |
| |
| int main() { |
| const iree_status_t result = run_test(); |
| int ret = (int)iree_status_code(result); |
| if (!iree_status_is_ok(result)) { |
| iree_status_fprint(stderr, result); |
| iree_status_free(result); |
| } |
| return ret; |
| } |