blob: 4a234e4670717c207e25f5a4abbb8ccb75cdec73 [file] [log] [blame]
/*
* Copyright 2023 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Heavily based on the `runtime/src/iree/hal/local/elf/elf_module_test_main.c`
// sample from IREE.
#include <stdint.h>
#include <stdio.h>
#include <string.h>

#include "iree/base/api.h"
#include "iree/base/internal/cpu.h"
#include "iree/hal/local/elf/elf_module.h"
#include "iree/hal/local/executable_environment.h"
#include "iree/hal/local/executable_library.h"
#include "samples/branch_mul/branch_mul_arg0.h"
#include "samples/branch_mul/branch_mul_arg1.h"
#include "samples/branch_mul/branch_mul_arg2.h"
#include "samples/branch_mul/branch_mul_c.h"
#include "samples/branch_mul/branch_mul_expected.h"
// Zero-initialized 3x2x256 int8 buffer. It is handed to the dispatch as the
// last binding and afterwards compared element-wise against `expected`.
static int8_t ret0[3][2][256] = {{{0}}};
static iree_status_t run_test() {
iree_hal_executable_environment_v0_t environment;
iree_hal_executable_environment_initialize(iree_allocator_system(),
&environment);
void* query_fn_ptr = &llvm_module_linked_llvm_cpu_library_query;
union {
const iree_hal_executable_library_header_t** header;
const iree_hal_executable_library_v0_t* v0;
} library;
library.header =
(const iree_hal_executable_library_header_t**)iree_elf_call_p_ip(
query_fn_ptr, IREE_HAL_EXECUTABLE_LIBRARY_VERSION_LATEST,
&environment);
if (library.header == NULL) {
return iree_make_status(IREE_STATUS_NOT_FOUND,
"library header is empty (version mismatch?)");
}
const iree_hal_executable_library_header_t* header = *library.header;
if (header->version != IREE_HAL_EXECUTABLE_LIBRARY_VERSION_LATEST) {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"library version error");
}
if (strncmp(header->name, "llvm_module_linked_llvm_cpu",
strlen(header->name)) != 0) {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"library name mismatches");
}
if (library.v0->exports.count != 6) {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"entry point count mismatches");
}
size_t binding_lengths[4] = {
sizeof(arg0),
sizeof(arg1),
sizeof(arg2),
sizeof(ret0),
};
void* binding_ptrs[4] = {
arg0,
arg1,
arg2,
ret0,
};
const iree_hal_executable_dispatch_state_v0_t dispatch_state = {
.workgroup_size_x = 1,
.workgroup_size_y = 1,
.workgroup_size_z = 1,
.workgroup_count_x = 1,
.workgroup_count_y = 1,
.workgroup_count_z = 1,
.max_concurrency = 1,
.binding_count = 1,
.binding_lengths = binding_lengths,
.binding_ptrs = binding_ptrs,
};
const iree_hal_executable_workgroup_state_v0_t workgroup_state = {
.workgroup_id_x = 0,
.workgroup_id_y = 0,
.workgroup_id_z = 0,
.processor_id = iree_cpu_query_processor_id(),
};
int ret = iree_elf_call_i_ppp((const void*)library.v0->exports.ptrs[5],
(void*)&environment, (void*)&dispatch_state,
(void*)&workgroup_state);
if (ret != 0) {
return iree_make_status(IREE_STATUS_INTERNAL,
"dispatch function returned failure: %d", ret);
}
int mismatch_count = 0;
iree_status_t status = iree_ok_status();
for (int i = 0; i < 3; ++i) {
for (int j = 0; j < 2; ++j) {
for (int k = 0; k < 256; ++k) {
int8_t out = ret0[i][j][k];
int8_t expect = expected[i][j][k];
if (out != expect) {
mismatch_count++;
fprintf(stdout, "mismatch at [%d][%d][%d] %x != %x\n", i, j, k, out,
expect);
fprintf(stdout, "ins: %x %x %x\n", arg0[i][j][k], arg1[k],
arg2[i][j][k]);
}
}
}
}
if (mismatch_count) {
status = iree_make_status(
IREE_STATUS_INTERNAL,
"%d mismatches between actual and expected output", mismatch_count);
}
return status;
}
// Program entry point: runs the test and returns the numeric status code as
// the process exit code. A non-OK status is printed to stderr and freed
// before returning.
int main() {
  const iree_status_t status = run_test();
  const int exit_code = (int)iree_status_code(status);
  if (iree_status_is_ok(status)) {
    return exit_code;
  }
  iree_status_fprint(stderr, status);
  iree_status_free(status);
  return exit_code;
}