blob: c54c7190cdb6d1adac5b61558fb83e56bf9048f5 [file] [log] [blame]
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "tools/testing/e2e/test_utils.h"
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "iree/base/api.h"
#include "iree/base/internal/cpu.h"
#include "iree/base/internal/flags.h"
#include "iree/base/internal/math.h"
#include "iree/base/internal/path.h"
#include "iree/hal/api.h"
#include "iree/modules/hal/module.h"
#include "iree/tooling/context_util.h"
#include "iree/tooling/device_util.h"
#include "iree/vm/api.h"
IREE_FLAG(bool, require_exact_results, true,
"Requires floating point result elements to match exactly.");
bool iree_test_utils_require_exact_results(void) {
return FLAG_require_exact_results;
}
IREE_FLAG(
float, acceptable_fp_delta, 1e-5f,
"Maximum absolute difference allowed with inexact floating point results.");
float iree_test_utils_acceptable_fb_delta(void) {
return FLAG_acceptable_fp_delta;
}
IREE_FLAG(
int32_t, max_elements_to_check, 10000,
"Maximum number of tensor elements to check for the given test. For larger "
"buffers, only every n-th element will be checked for some n chosed to "
"stay just under that threshold and to avoid being a divisor of the inner "
"dimension size to avoid special patterns. As the check uses a slow "
"reference implementation, this is a trade-off between test latency and "
"coverage. The value 0 means check all elements.");
int32_t iree_test_utils_max_elements_to_check(void) {
return FLAG_max_elements_to_check;
}
const char* iree_test_utils_emoji(bool good) { return good ? "🦄" : "🐞"; }
int iree_test_utils_calculate_check_every(iree_hal_dim_t tot_elements,
iree_hal_dim_t no_div_of) {
int check_every = 1;
if (iree_test_utils_max_elements_to_check()) {
check_every =
((tot_elements) + iree_test_utils_max_elements_to_check() - 1) /
iree_test_utils_max_elements_to_check();
if (check_every < 1) check_every = 1;
if (check_every > 1)
while ((no_div_of % check_every) == 0) ++check_every;
}
return check_every;
}
iree_test_utils_e2e_value_t iree_test_utils_value_make_none() {
iree_test_utils_e2e_value_t result;
result.type = IREE_TEST_UTILS_VALUE_TYPE_NONE;
return result;
}
iree_test_utils_e2e_value_t iree_test_utils_value_make_i8(int8_t value) {
iree_test_utils_e2e_value_t result;
result.type = IREE_TEST_UTILS_VALUE_TYPE_I8;
result.i8 = value;
return result;
}
iree_test_utils_e2e_value_t iree_test_utils_value_make_i16(int16_t value) {
iree_test_utils_e2e_value_t result;
result.type = IREE_TEST_UTILS_VALUE_TYPE_I16;
result.i16 = value;
return result;
}
iree_test_utils_e2e_value_t iree_test_utils_value_make_i32(int32_t value) {
iree_test_utils_e2e_value_t result;
result.type = IREE_TEST_UTILS_VALUE_TYPE_I32;
result.i32 = value;
return result;
}
iree_test_utils_e2e_value_t iree_test_utils_value_make_f8E5M2(uint8_t value) {
iree_test_utils_e2e_value_t result;
result.type = IREE_TEST_UTILS_VALUE_TYPE_F8E5M2;
result.f8_u8 = value;
return result;
}
iree_test_utils_e2e_value_t iree_test_utils_value_make_f8E4M3(uint8_t value) {
iree_test_utils_e2e_value_t result;
result.type = IREE_TEST_UTILS_VALUE_TYPE_F8E4M3;
result.f8_u8 = value;
return result;
}
iree_test_utils_e2e_value_t iree_test_utils_value_make_f8E5M2FNUZ(
uint16_t value) {
iree_test_utils_e2e_value_t result;
result.type = IREE_TEST_UTILS_VALUE_TYPE_F8E5M2FNUZ;
result.f8_u8 = value;
return result;
}
iree_test_utils_e2e_value_t iree_test_utils_value_make_f8E4M3FNUZ(
uint16_t value) {
iree_test_utils_e2e_value_t result;
result.type = IREE_TEST_UTILS_VALUE_TYPE_F8E4M3FNUZ;
result.f8_u8 = value;
return result;
}
iree_test_utils_e2e_value_t iree_test_utils_value_make_f16(uint16_t value) {
iree_test_utils_e2e_value_t result;
result.type = IREE_TEST_UTILS_VALUE_TYPE_F16;
result.f16_u16 = value;
return result;
}
iree_test_utils_e2e_value_t iree_test_utils_value_make_bf16(uint16_t value) {
iree_test_utils_e2e_value_t result;
result.type = IREE_TEST_UTILS_VALUE_TYPE_BF16;
result.bf16_u16 = value;
return result;
}
iree_test_utils_e2e_value_t iree_test_utils_value_make_f32(float value) {
iree_test_utils_e2e_value_t result;
result.type = IREE_TEST_UTILS_VALUE_TYPE_F32;
result.f32 = value;
return result;
}
iree_test_utils_e2e_value_t iree_test_utils_read_buffer_element(
iree_hal_dim_t index, iree_hal_element_type_t result_type,
const void* data) {
if (iree_hal_element_type_is_integer(result_type, 8)) {
return iree_test_utils_value_make_i8(((int8_t*)data)[index]);
} else if (iree_hal_element_type_is_integer(result_type, 16)) {
return iree_test_utils_value_make_i16(((int16_t*)data)[index]);
} else if (iree_hal_element_type_is_integer(result_type, 32)) {
return iree_test_utils_value_make_i32(((int32_t*)data)[index]);
} else if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2) {
return iree_test_utils_value_make_f8E5M2(((uint8_t*)data)[index]);
} else if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3) {
return iree_test_utils_value_make_f8E4M3(((uint8_t*)data)[index]);
} else if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2_FNUZ) {
return iree_test_utils_value_make_f8E5M2FNUZ(((uint8_t*)data)[index]);
} else if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3_FNUZ) {
return iree_test_utils_value_make_f8E4M3FNUZ(((uint8_t*)data)[index]);
} else if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16) {
return iree_test_utils_value_make_f16(((uint16_t*)data)[index]);
} else if (result_type == IREE_HAL_ELEMENT_TYPE_BFLOAT_16) {
return iree_test_utils_value_make_bf16(((uint16_t*)data)[index]);
} else if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
return iree_test_utils_value_make_f32(((float*)data)[index]);
}
iree_status_abort(iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"unhandled matmul result type"));
return iree_test_utils_value_make_none();
}
int iree_test_utils_snprintf_value(char* buf, size_t bufsize,
iree_test_utils_e2e_value_t value,
precision_t precision) {
switch (value.type) {
case IREE_TEST_UTILS_VALUE_TYPE_I8:
return snprintf(buf, bufsize, "%" PRIi8, value.i8);
case IREE_TEST_UTILS_VALUE_TYPE_I16:
return snprintf(buf, bufsize, "%" PRIi16, value.i16);
case IREE_TEST_UTILS_VALUE_TYPE_I32:
return snprintf(buf, bufsize, "%" PRIi32, value.i32);
case IREE_TEST_UTILS_VALUE_TYPE_I64:
return snprintf(buf, bufsize, "%" PRIi64, value.i64);
case IREE_TEST_UTILS_VALUE_TYPE_F8E5M2:
return snprintf(buf, bufsize,
precision == PRECISION_HIGH ? "%.3g" : "%.2g",
iree_math_f8e5m2_to_f32(value.f8_u8));
case IREE_TEST_UTILS_VALUE_TYPE_F8E4M3:
return snprintf(buf, bufsize,
precision == PRECISION_HIGH ? "%.3g" : "%.2g",
iree_math_f8e4m3_to_f32(value.f8_u8));
case IREE_TEST_UTILS_VALUE_TYPE_F8E5M2FNUZ:
return snprintf(buf, bufsize,
precision == PRECISION_HIGH ? "%.3g" : "%.2g",
iree_math_f8e5m2fnuz_to_f32(value.f8_u8));
case IREE_TEST_UTILS_VALUE_TYPE_F8E4M3FNUZ:
return snprintf(buf, bufsize,
precision == PRECISION_HIGH ? "%.3g" : "%.2g",
iree_math_f8e4m3fnuz_to_f32(value.f8_u8));
case IREE_TEST_UTILS_VALUE_TYPE_F16:
return snprintf(buf, bufsize,
precision == PRECISION_HIGH ? "%.5g" : "%.4g",
iree_math_f16_to_f32(value.f16_u16));
case IREE_TEST_UTILS_VALUE_TYPE_BF16:
return snprintf(buf, bufsize,
precision == PRECISION_HIGH ? "%.5g" : "%.4g",
iree_math_bf16_to_f32(value.bf16_u16));
case IREE_TEST_UTILS_VALUE_TYPE_F32:
return snprintf(buf, bufsize,
precision == PRECISION_HIGH ? "%.8g" : "%.4g", value.f32);
case IREE_TEST_UTILS_VALUE_TYPE_F64:
return snprintf(buf, bufsize,
precision == PRECISION_HIGH ? "%.16g" : "%.4g",
value.f64);
default:
iree_status_abort(iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"unhandled value type"));
return 0;
}
}
bool iree_test_utils_result_elements_agree(iree_test_utils_e2e_value_t expected,
iree_test_utils_e2e_value_t actual) {
float acceptable_fp_delta = iree_test_utils_acceptable_fb_delta();
if (expected.type != actual.type) {
iree_status_abort(
iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "mismatched types"));
return false;
}
if (acceptable_fp_delta < 0.0f) {
iree_status_abort(iree_make_status(
IREE_STATUS_INVALID_ARGUMENT,
"negative tolerance (acceptable_fp_delta=%.8g)", acceptable_fp_delta));
return false;
}
switch (expected.type) {
case IREE_TEST_UTILS_VALUE_TYPE_I32:
return actual.i32 == expected.i32;
// Since we fill buffers with small integers for floating point GEMMs
// functional testing, we can test for bit-exactness on the actual and
// expected values. Inexact results are only permitted when the
// `require_exact_results` flag is set to `false`.
case IREE_TEST_UTILS_VALUE_TYPE_F16:
if (actual.f16_u16 == expected.f16_u16) return true;
if (iree_test_utils_max_elements_to_check()) return false;
return fabsf(iree_math_f16_to_f32(actual.f16_u16) -
iree_math_f16_to_f32(expected.f16_u16)) <
acceptable_fp_delta;
case IREE_TEST_UTILS_VALUE_TYPE_BF16: {
if (actual.bf16_u16 == expected.bf16_u16) return true;
// This is the rare case where the accumulator itself (not just LHS/RHS)
// is bf16. This doesn't really happen in practice and is mostly just in
// some CPU tests for completeness. Accumulators grow outside of the
// narrow range of bf16 exact representation of integers, forcing fuzzy
// compares.
float actual_f32 = iree_math_bf16_to_f32(actual.bf16_u16);
float expected_f32 = iree_math_bf16_to_f32(expected.bf16_u16);
if (fabsf(actual_f32) > 127.0f || fabsf(expected_f32) > 127.0f) {
if (fabsf(actual_f32 - expected_f32) < 10.0f) {
return true;
}
}
if (iree_test_utils_require_exact_results()) return false;
return fabsf(actual_f32 - expected_f32) < acceptable_fp_delta;
}
case IREE_TEST_UTILS_VALUE_TYPE_F32:
if (actual.f32 == expected.f32) return true;
if (iree_test_utils_require_exact_results()) return false;
return fabsf(actual.f32 - expected.f32) < acceptable_fp_delta;
default:
iree_status_abort(iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"unhandled value type"));
return false;
}
}
//===----------------------------------------------------------------------===//
// RNG utilities
//===----------------------------------------------------------------------===//
void iree_test_utils_write_element(iree_hal_element_type_t element_type,
int32_t value, void* dst) {
#define WRITE_ELEMENT_CASE(ETYPE, CTYPE) \
case IREE_HAL_ELEMENT_TYPE_##ETYPE: \
*(CTYPE*)dst = (CTYPE)value; \
break;
switch (element_type) {
WRITE_ELEMENT_CASE(INT_8, int8_t)
WRITE_ELEMENT_CASE(INT_16, int16_t)
WRITE_ELEMENT_CASE(INT_32, int32_t)
WRITE_ELEMENT_CASE(INT_64, int64_t)
WRITE_ELEMENT_CASE(SINT_8, int8_t)
WRITE_ELEMENT_CASE(SINT_16, int16_t)
WRITE_ELEMENT_CASE(SINT_32, int32_t)
WRITE_ELEMENT_CASE(SINT_64, int64_t)
WRITE_ELEMENT_CASE(UINT_8, uint8_t)
WRITE_ELEMENT_CASE(UINT_16, uint16_t)
WRITE_ELEMENT_CASE(UINT_32, uint32_t)
WRITE_ELEMENT_CASE(UINT_64, uint64_t)
// clang-format off
case IREE_HAL_ELEMENT_TYPE_FLOAT_16:
*(uint16_t*)dst = iree_math_f32_to_f16((float)value);
break;
case IREE_HAL_ELEMENT_TYPE_BFLOAT_16:
*(uint16_t*)dst = iree_math_f32_to_bf16((float)value);
break;
case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2:
*(uint8_t*)dst = iree_math_f32_to_f8e5m2((float)value);
break;
case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3:
*(uint8_t*)dst = iree_math_f32_to_f8e4m3((float)value);
break;
case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2_FNUZ:
*(uint8_t*)dst = iree_math_f32_to_f8e5m2fnuz((float)value);
break;
case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3_FNUZ:
*(uint8_t*)dst = iree_math_f32_to_f8e4m3fnuz((float)value);
break;
WRITE_ELEMENT_CASE(FLOAT_32, float)
WRITE_ELEMENT_CASE(FLOAT_64, double)
// clang-format on
default:
IREE_ASSERT(false, "unhandled element type");
break;
}
#undef WRITE_ELEMENT_CASE
}
uint32_t iree_test_utils_pseudorandom_uint32(uint32_t* state) {
*state = (*state * IREE_PRNG_MULTIPLIER) % IREE_PRNG_MODULUS;
return *state;
}
uint32_t iree_test_utils_pseudorandom_range(uint32_t* state, uint32_t range) {
return iree_test_utils_pseudorandom_uint32(state) % range;
}
void iree_test_utils_get_min_max_for_element_type(
iree_hal_element_type_t element_type, int32_t* min, int32_t* max) {
switch (element_type) {
case IREE_HAL_ELEMENT_TYPE_INT_8:
case IREE_HAL_ELEMENT_TYPE_SINT_8:
*min = -2;
*max = +2;
break;
case IREE_HAL_ELEMENT_TYPE_UINT_8:
*min = 0;
*max = +2;
break;
case IREE_HAL_ELEMENT_TYPE_INT_16:
case IREE_HAL_ELEMENT_TYPE_SINT_16:
case IREE_HAL_ELEMENT_TYPE_FLOAT_16:
*min = -4;
*max = +4;
break;
case IREE_HAL_ELEMENT_TYPE_BFLOAT_16:
case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2:
case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3:
case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2_FNUZ:
case IREE_HAL_ELEMENT_TYPE_FLOAT_8_E4M3_FNUZ:
*min = -2;
*max = +2;
break;
case IREE_HAL_ELEMENT_TYPE_UINT_16:
*min = 0;
*max = +4;
break;
case IREE_HAL_ELEMENT_TYPE_INT_32:
case IREE_HAL_ELEMENT_TYPE_SINT_32:
case IREE_HAL_ELEMENT_TYPE_FLOAT_32:
*min = -8;
*max = +8;
break;
case IREE_HAL_ELEMENT_TYPE_UINT_32:
*min = 0;
*max = +8;
break;
case IREE_HAL_ELEMENT_TYPE_INT_64:
case IREE_HAL_ELEMENT_TYPE_SINT_64:
case IREE_HAL_ELEMENT_TYPE_FLOAT_64:
*min = -16;
*min = +16;
break;
case IREE_HAL_ELEMENT_TYPE_UINT_64:
*min = 0;
*max = +16;
break;
default:
IREE_ASSERT(false, "unhandled element type");
break;
}
}
//===----------------------------------------------------------------------===//
// Test runner
//===----------------------------------------------------------------------===//
iree_status_t iree_test_utils_check_test_function(iree_vm_function_t function,
bool* out_is_valid) {
*out_is_valid = true;
iree_string_view_t function_name = iree_vm_function_name(&function);
if (iree_string_view_starts_with(function_name,
iree_make_cstring_view("__"))) {
// Internal compiler/runtime support function.
*out_is_valid = false;
}
iree_vm_function_signature_t function_signature =
iree_vm_function_signature(&function);
iree_host_size_t argument_count = 0;
iree_host_size_t result_count = 0;
IREE_RETURN_IF_ERROR(iree_vm_function_call_count_arguments_and_results(
&function_signature, &argument_count, &result_count));
if (argument_count || result_count) {
// Takes args or has results we don't expect.
*out_is_valid = false;
}
return iree_ok_status();
}
iree_status_t iree_test_utils_run_test_function(
iree_vm_context_t* context, iree_vm_function_t function,
iree_allocator_t host_allocator) {
IREE_TRACE_ZONE_BEGIN(z0);
iree_string_view_t function_name = iree_vm_function_name(&function);
IREE_TRACE_ZONE_APPEND_TEXT(z0, function_name.data, function_name.size);
fprintf(stderr, "--- TEST[%.*s] ---\n", (int)function_name.size,
function_name.data);
iree_string_view_t function_desc =
iree_vm_function_lookup_attr_by_name(&function, IREE_SV("description"));
if (!iree_string_view_is_empty(function_desc)) {
fprintf(stderr, "%.*s\n", (int)function_desc.size, function_desc.data);
}
iree_status_t status = iree_vm_invoke(
context, function, IREE_VM_INVOCATION_FLAG_NONE, /*policy=*/NULL,
/*inputs=*/NULL, /*outputs=*/NULL, host_allocator);
IREE_TRACE_ZONE_END(z0);
return status;
}
iree_status_t iree_test_utils_run_all_test_functions(
iree_vm_context_t* context, iree_vm_module_t* test_module,
iree_allocator_t host_allocator) {
IREE_TRACE_ZONE_BEGIN(z0);
// Walk all functions and find the ones we can run (no args, non-internal).
const iree_vm_module_signature_t module_signature =
iree_vm_module_signature(test_module);
for (iree_host_size_t i = 0; i < module_signature.export_function_count;
++i) {
// Get the function and filter to just the public user exports.
iree_vm_function_t function;
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_vm_module_lookup_function_by_ordinal(
test_module, IREE_VM_FUNCTION_LINKAGE_EXPORT, i, &function));
bool is_valid = false;
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_test_utils_check_test_function(function, &is_valid));
if (is_valid) {
// Try to run the function and fail on mismatch.
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0,
iree_test_utils_run_test_function(context, function, host_allocator));
}
}
IREE_TRACE_ZONE_END(z0);
return iree_ok_status();
}
iree_status_t iree_test_utils_check_module_requirements(
iree_vm_module_t* module) {
iree_string_view_t target_features =
iree_vm_module_lookup_attr_by_name(module, IREE_SV("target_features"));
while (!iree_string_view_is_empty(target_features)) {
iree_string_view_t required_feature;
iree_string_view_split(target_features, ',', &required_feature,
&target_features);
if (iree_string_view_is_empty(required_feature)) continue;
int64_t feature_is_supported = 0;
IREE_RETURN_IF_ERROR(
iree_cpu_lookup_data_by_key(required_feature, &feature_is_supported));
if (!feature_is_supported) {
return iree_make_status(
// The error status matters. We distinguish "feature not supported"
// which is a normal thing to happen from actual errors.
IREE_STATUS_NOT_FOUND,
"target device does not have the required feature '%.*s'",
(int)required_feature.size, required_feature.data);
}
}
return iree_ok_status();
}
iree_status_t iree_test_utils_load_and_run_e2e_tests(
iree_allocator_t host_allocator,
iree_status_t (*test_module_create)(iree_vm_instance_t*, iree_allocator_t,
iree_vm_module_t**)) {
IREE_TRACE_ZONE_BEGIN(z0);
iree_cpu_initialize(host_allocator);
iree_vm_instance_t* instance = NULL;
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_tooling_create_instance(host_allocator, &instance));
iree_tooling_module_list_t module_list;
iree_tooling_module_list_initialize(&module_list);
// Create the test module providing helper functions used by test programs.
iree_vm_module_t* custom_test_module = NULL;
iree_status_t status =
test_module_create(instance, host_allocator, &custom_test_module);
if (iree_status_is_ok(status)) {
status =
iree_tooling_module_list_push_back(&module_list, custom_test_module);
}
iree_vm_module_release(custom_test_module);
// Load all modules specified by --module= flags.
if (iree_status_is_ok(status)) {
status = iree_tooling_load_modules_from_flags(instance, host_allocator,
&module_list);
}
iree_vm_module_t* test_module = iree_tooling_module_list_back(&module_list);
// Create the context with our support module and all --module= flags.
iree_vm_context_t* context = NULL;
iree_hal_device_t* device = NULL;
if (iree_status_is_ok(status)) {
status = iree_tooling_create_context_from_flags(
instance, module_list.count, module_list.values,
/*default_device_uri=*/iree_string_view_empty(), host_allocator,
&context, &device, /*out_device_allocator=*/NULL);
}
// Ensure the test module is possible to run.
if (iree_status_is_ok(status)) {
status = iree_test_utils_check_module_requirements(test_module);
}
iree_tooling_module_list_reset(&module_list);
// Begin profiling (if enabled).
if (iree_status_is_ok(status)) {
status = iree_hal_begin_profiling_from_flags(device);
}
// Run all of the tests in the test module.
if (iree_status_is_ok(status)) {
status = iree_test_utils_run_all_test_functions(context, test_module,
host_allocator);
}
// End profiling (if enabled).
if (iree_status_is_ok(status)) {
status = iree_hal_end_profiling_from_flags(device);
}
iree_hal_device_release(device);
iree_vm_context_release(context);
iree_vm_instance_release(instance);
IREE_TRACE_ZONE_END(z0);
return status;
}