// blob: 63c8f2da491420865f62c93e33e8bd54bd3d800a [file] [log] [blame]
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "iree/base/api.h"
#include "iree/base/internal/flags.h"
#include "iree/base/internal/math.h"
#include "iree/hal/api.h"
#include "iree/modules/hal/module.h"
#include "iree/tooling/context_util.h"
#include "iree/tooling/device_util.h"
#include "iree/vm/api.h"
#include "iree/vm/native_module_cc.h"
#include "tools/testing/e2e/test_utils.h"
//===----------------------------------------------------------------------===//
// Reference conv2d (NCHW-FCHW)
//===----------------------------------------------------------------------===//
// Conversion from 4D indices in row major order to 1D index.
static int convert_to_1d_index(iree_hal_dim_t channels, iree_hal_dim_t height,
iree_hal_dim_t width, iree_hal_dim_t n,
iree_hal_dim_t c, iree_hal_dim_t h,
iree_hal_dim_t w) {
return n * (channels * height * width) + c * (height * width) + h * width + w;
}
// Computes a single output element of an NCHW-FCHW conv2d for the
// [f16 <= f16 * f16 + f16] case.
//
// All f16 values are stored as raw uint16_t bit patterns and are widened to
// f32 for the multiply-accumulate; only the final sum is narrowed back to f16.
//
// Shape parameters:
//   c_size, h_size, w_size: input channels/height/width.
//   f_size:                 output channels.
//   kh_size, kw_size:       kernel height/width.
//   sh_size, sw_size:       strides along height/width.
//   dh_size, dw_size:       dilations along height/width.
//   oh_size, ow_size:       output height/width.
//   n_size is accepted for signature symmetry but not read here.
// |acc_data| may be NULL, in which case accumulation starts from zero.
// (n, oc, oh, ow) selects the output element written to |result_data|.
static void reference_conv2d_f16_f16_f16_f16(
    iree_hal_dim_t n_size, iree_hal_dim_t c_size, iree_hal_dim_t h_size,
    iree_hal_dim_t w_size, iree_hal_dim_t f_size, iree_hal_dim_t kh_size,
    iree_hal_dim_t kw_size, iree_hal_dim_t sh_size, iree_hal_dim_t sw_size,
    iree_hal_dim_t dh_size, iree_hal_dim_t dw_size, iree_hal_dim_t oh_size,
    iree_hal_dim_t ow_size, const uint16_t* input_data,
    const uint16_t* kernel_data, const uint16_t* acc_data,
    uint16_t* result_data, iree_hal_dim_t n, iree_hal_dim_t oc,
    iree_hal_dim_t oh, iree_hal_dim_t ow) {
  iree_hal_dim_t out_idx =
      convert_to_1d_index(f_size, oh_size, ow_size, n, oc, oh, ow);
  float acc = acc_data ? iree_math_f16_to_f32(acc_data[out_idx]) : 0.f;
  for (iree_hal_dim_t ic = 0; ic < c_size; ++ic) {
    for (iree_hal_dim_t kh = 0; kh < kh_size; ++kh) {
      for (iree_hal_dim_t kw = 0; kw < kw_size; ++kw) {
        iree_hal_dim_t inp_idx = convert_to_1d_index(
            c_size, h_size, w_size, n, ic, (oh * sh_size + kh * dh_size),
            (ow * sw_size + kw * dw_size));
        iree_hal_dim_t krnl_idx =
            convert_to_1d_index(c_size, kh_size, kw_size, oc, ic, kh, kw);
        acc += iree_math_f16_to_f32(input_data[inp_idx]) *
               iree_math_f16_to_f32(kernel_data[krnl_idx]);
      }
    }
  }
  // Store once after the full reduction. Storing inside the channel loop would
  // redo the f32->f16 conversion for every channel and would leave the output
  // element unwritten when there are no input channels.
  result_data[out_idx] = iree_math_f32_to_f16(acc);
}
// Computes a single output element of an NCHW-FCHW conv2d for the
// [f32 <= f32 * f32 + f32] case.
//
// Shape parameters:
//   c_size, h_size, w_size: input channels/height/width.
//   f_size:                 output channels.
//   kh_size, kw_size:       kernel height/width.
//   sh_size, sw_size:       strides along height/width.
//   dh_size, dw_size:       dilations along height/width.
//   oh_size, ow_size:       output height/width.
//   n_size is accepted for signature symmetry but not read here.
// |acc_data| may be NULL, in which case accumulation starts from zero.
// (n, oc, oh, ow) selects the output element written to |result_data|.
static void reference_conv2d_f32_f32_f32_f32(
    iree_hal_dim_t n_size, iree_hal_dim_t c_size, iree_hal_dim_t h_size,
    iree_hal_dim_t w_size, iree_hal_dim_t f_size, iree_hal_dim_t kh_size,
    iree_hal_dim_t kw_size, iree_hal_dim_t sh_size, iree_hal_dim_t sw_size,
    iree_hal_dim_t dh_size, iree_hal_dim_t dw_size, iree_hal_dim_t oh_size,
    iree_hal_dim_t ow_size, const float* input_data, const float* kernel_data,
    const float* acc_data, float* result_data, iree_hal_dim_t n,
    iree_hal_dim_t oc, iree_hal_dim_t oh, iree_hal_dim_t ow) {
  iree_hal_dim_t out_idx =
      convert_to_1d_index(f_size, oh_size, ow_size, n, oc, oh, ow);
  float acc = acc_data ? acc_data[out_idx] : 0;
  for (iree_hal_dim_t ic = 0; ic < c_size; ++ic) {
    for (iree_hal_dim_t kh = 0; kh < kh_size; ++kh) {
      for (iree_hal_dim_t kw = 0; kw < kw_size; ++kw) {
        iree_hal_dim_t inp_idx = convert_to_1d_index(
            c_size, h_size, w_size, n, ic, (oh * sh_size + kh * dh_size),
            (ow * sw_size + kw * dw_size));
        iree_hal_dim_t krnl_idx =
            convert_to_1d_index(c_size, kh_size, kw_size, oc, ic, kh, kw);
        acc += input_data[inp_idx] * kernel_data[krnl_idx];
      }
    }
  }
  // Store once after the full reduction so that the output element is written
  // even when there are no input channels to iterate over.
  result_data[out_idx] = acc;
}
// Computes one output element (n, oc, oh, ow) of the reference conv2d by
// dispatching to the typed kernel that matches the element types.
//
// Only the all-f32 and all-f16 combinations of |input_type|, |kernel_type| and
// |acc_type| are handled; any other combination yields an
// IREE_STATUS_INVALID_ARGUMENT status and leaves |result_data| untouched.
// The untyped data pointers are reinterpreted according to the matched type.
static iree_status_t reference_conv2d_element(
    iree_hal_dim_t n_size, iree_hal_dim_t c_size, iree_hal_dim_t h_size,
    iree_hal_dim_t w_size, iree_hal_dim_t f_size, iree_hal_dim_t kh_size,
    iree_hal_dim_t kw_size, iree_hal_dim_t sh_size, iree_hal_dim_t sw_size,
    iree_hal_dim_t dh_size, iree_hal_dim_t dw_size, iree_hal_dim_t oh_size,
    iree_hal_dim_t ow_size, iree_hal_element_type_t input_type,
    iree_hal_element_type_t kernel_type, iree_hal_element_type_t acc_type,
    void* input_data, void* kernel_data, void* acc_data, void* result_data,
    iree_hal_dim_t n, iree_hal_dim_t oc, iree_hal_dim_t oh, iree_hal_dim_t ow) {
  // f32 <= f32 * f32 + f32
  if (input_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32 &&
      kernel_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32 &&
      acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
    reference_conv2d_f32_f32_f32_f32(
        n_size, c_size, h_size, w_size, f_size, kh_size, kw_size, sh_size,
        sw_size, dh_size, dw_size, oh_size, ow_size, (const float*)input_data,
        (const float*)kernel_data, (const float*)acc_data, (float*)result_data,
        n, oc, oh, ow);
    return iree_ok_status();
  }

  // f16 <= f16 * f16 + f16
  if (input_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16 &&
      kernel_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16 &&
      acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16) {
    reference_conv2d_f16_f16_f16_f16(
        n_size, c_size, h_size, w_size, f_size, kh_size, kw_size, sh_size,
        sw_size, dh_size, dw_size, oh_size, ow_size,
        (const uint16_t*)input_data, (const uint16_t*)kernel_data,
        (const uint16_t*)acc_data, (uint16_t*)result_data, n, oc, oh, ow);
    return iree_ok_status();
  }

  // No typed kernel exists for this combination.
  return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                          "unhandled combination of element types in conv2d");
}
// Computes the extent of one spatial output dimension of an unpadded
// convolution:
//   out = (i_shape - window) / stride + 1
// where window = k_shape + (k_shape - 1) * (dilation - 1) is the extent of
// the input covered by the dilated kernel.
//
// The division is performed entirely in integer arithmetic, which already
// truncates; no round trip through floating point (and its loss of precision
// for large extents) is needed.
//
// Returns 0 when the input is smaller than the dilated kernel window, i.e.
// when the kernel cannot be placed anywhere, instead of evaluating
// i_shape - window with i_shape < window.
// |stride| must be non-zero.
static iree_hal_dim_t out_shape_calc(iree_hal_dim_t i_shape,
                                     iree_hal_dim_t k_shape,
                                     iree_hal_dim_t stride,
                                     iree_hal_dim_t dilation) {
  iree_hal_dim_t window = k_shape + (k_shape - 1) * (dilation - 1);
  if (i_shape < window) return 0;
  return (i_shape - window) / stride + 1;
}
// Reference conv2d (NCHW input, FCHW kernel) used to produce the expected
// results that actual conv2d outputs are compared against.
//
// The output height/width are derived from the input, kernel, stride and
// dilation sizes, and every output element is computed one at a time through
// reference_conv2d_element. The first failing element aborts the computation
// and its status is returned. |compute_every| is accepted but not read here.
static iree_status_t reference_conv2d(
    iree_hal_dim_t n_size, iree_hal_dim_t c_size, iree_hal_dim_t h_size,
    iree_hal_dim_t w_size, iree_hal_dim_t f_size, iree_hal_dim_t kh_size,
    iree_hal_dim_t kw_size, iree_hal_dim_t sh_size, iree_hal_dim_t sw_size,
    iree_hal_dim_t dh_size, iree_hal_dim_t dw_size,
    iree_hal_element_type_t input_type, iree_hal_element_type_t kernel_type,
    iree_hal_element_type_t acc_type, iree_byte_span_t input_contents,
    iree_byte_span_t kernel_contents, iree_byte_span_t acc_contents,
    iree_byte_span_t result_contents, int compute_every) {
  IREE_TRACE_ZONE_BEGIN(z0);
  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, n_size);
  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, c_size);
  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, h_size);
  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, w_size);

  // Spatial extents of the output tensor.
  iree_hal_dim_t out_h = out_shape_calc(h_size, kh_size, sh_size, dh_size);
  iree_hal_dim_t out_w = out_shape_calc(w_size, kw_size, sw_size, dw_size);

  // Raw storage handed to the per-element kernel.
  void* input_ptr = input_contents.data;
  void* kernel_ptr = kernel_contents.data;
  void* acc_ptr = acc_contents.data;
  void* result_ptr = result_contents.data;

  for (iree_hal_dim_t batch = 0; batch < n_size; ++batch) {
    for (iree_hal_dim_t filter = 0; filter < f_size; ++filter) {
      for (iree_hal_dim_t y = 0; y < out_h; ++y) {
        for (iree_hal_dim_t x = 0; x < out_w; ++x) {
          IREE_RETURN_AND_END_ZONE_IF_ERROR(
              z0, reference_conv2d_element(
                      n_size, c_size, h_size, w_size, f_size, kh_size, kw_size,
                      sh_size, sw_size, dh_size, dw_size, out_h, out_w,
                      input_type, kernel_type, acc_type, input_ptr, kernel_ptr,
                      acc_ptr, result_ptr, batch, filter, y, x));
        }
      }
    }
  }

  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}
//===----------------------------------------------------------------------===//
// Conv2d comparison/logging
//===----------------------------------------------------------------------===//
// Host-side snapshot of one conv2d invocation: its problem sizes, element
// types and the contents of every tensor involved, plus storage for the
// reference result to compare against.
typedef struct {
// Allocator used to allocate and free all of the *_contents spans below.
iree_allocator_t host_allocator;
// Problem sizes, as provided by the caller of conv2d_results_initialize.
iree_hal_dim_t n; // batch dim
iree_hal_dim_t c; // input channels
iree_hal_dim_t h; // input height
iree_hal_dim_t w; // input width
iree_hal_dim_t f; // output channels
iree_hal_dim_t kh; // kernel height
iree_hal_dim_t kw; // kernel width
iree_hal_dim_t sh; // stride along height dim
iree_hal_dim_t sw; // stride along width dim
iree_hal_dim_t dh; // dilation along height dim
iree_hal_dim_t dw; // dilation along width dim
// Element types queried from the corresponding buffer views.
iree_hal_element_type_t input_type;
iree_hal_element_type_t kernel_type;
iree_hal_element_type_t acc_type;
iree_hal_element_type_t result_type;
// Host copies of the device buffers. |acc_contents| stays empty when no
// accumulator buffer view was provided.
iree_byte_span_t input_contents;
iree_byte_span_t kernel_contents;
iree_byte_span_t acc_contents;
// Result produced by the conv2d under test, copied back to the host.
iree_byte_span_t actual_contents;
// Storage for the reference result; same byte length as |actual_contents|.
iree_byte_span_t expected_contents;
} conv2d_results_t;
static void conv2d_results_deinitialize(conv2d_results_t* results);
// Allocates |out_contents| from |host_allocator| with the byte length of
// |buffer| and copies the whole buffer from |device| into it.
// On failure |out_contents| may hold a partially initialized span; the caller
// is expected to release it together with the rest of the results.
static iree_status_t conv2d_results_read_buffer(
    iree_hal_device_t* device, iree_hal_buffer_t* buffer,
    iree_allocator_t host_allocator, iree_byte_span_t* out_contents) {
  out_contents->data_length = iree_hal_buffer_byte_length(buffer);
  IREE_RETURN_IF_ERROR(iree_allocator_malloc(host_allocator,
                                             out_contents->data_length,
                                             (void**)&out_contents->data));
  return iree_hal_device_transfer_d2h(
      device, buffer, 0, out_contents->data, out_contents->data_length,
      IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout());
}

// Populates |out_results| with the problem sizes, the element types of the
// buffer views and host copies of the input, kernel, accumulator and result
// buffers. Storage for the expected (reference) result is allocated with the
// same byte length as the actual result but is not filled in here.
//
// |acc| is optional: when it is NULL no accumulator contents are downloaded
// and the accumulator element type falls back to the element type of
// |result|, matching the reference kernels, which accumulate in the type of
// the result.
//
// On failure everything allocated so far is released through
// conv2d_results_deinitialize and the failing status is returned. On success
// the caller must call conv2d_results_deinitialize when done.
static iree_status_t conv2d_results_initialize(
    iree_hal_device_t* device, iree_hal_dim_t n_size, iree_hal_dim_t c_size,
    iree_hal_dim_t h_size, iree_hal_dim_t w_size, iree_hal_dim_t f_size,
    iree_hal_dim_t kh_size, iree_hal_dim_t kw_size, iree_hal_dim_t sh_size,
    iree_hal_dim_t sw_size, iree_hal_dim_t dh_size, iree_hal_dim_t dw_size,
    iree_hal_buffer_view_t* input, iree_hal_buffer_view_t* kernel,
    iree_hal_buffer_view_t* acc, iree_hal_buffer_view_t* result,
    iree_allocator_t host_allocator, conv2d_results_t* out_results) {
  IREE_TRACE_ZONE_BEGIN(z0);

  memset(out_results, 0, sizeof(*out_results));
  out_results->host_allocator = host_allocator;

  out_results->n = n_size;
  out_results->c = c_size;
  out_results->h = h_size;
  out_results->w = w_size;
  out_results->f = f_size;
  out_results->kh = kh_size;
  out_results->kw = kw_size;
  out_results->sh = sh_size;
  out_results->sw = sw_size;
  out_results->dh = dh_size;
  out_results->dw = dw_size;

  out_results->input_type = iree_hal_buffer_view_element_type(input);
  out_results->kernel_type = iree_hal_buffer_view_element_type(kernel);
  // |acc| may be NULL (see the acc_buffer handling below), so it must not be
  // queried unconditionally.
  out_results->acc_type = acc ? iree_hal_buffer_view_element_type(acc)
                              : iree_hal_buffer_view_element_type(result);
  out_results->result_type = iree_hal_buffer_view_element_type(result);

  iree_hal_buffer_t* input_buffer = iree_hal_buffer_view_buffer(input);
  iree_hal_buffer_t* kernel_buffer = iree_hal_buffer_view_buffer(kernel);
  iree_hal_buffer_t* acc_buffer = acc ? iree_hal_buffer_view_buffer(acc) : NULL;
  iree_hal_buffer_t* result_buffer = iree_hal_buffer_view_buffer(result);

  iree_status_t status = conv2d_results_read_buffer(
      device, input_buffer, host_allocator, &out_results->input_contents);
  if (iree_status_is_ok(status)) {
    status = conv2d_results_read_buffer(device, kernel_buffer, host_allocator,
                                        &out_results->kernel_contents);
  }
  if (iree_status_is_ok(status) && acc_buffer) {
    status = conv2d_results_read_buffer(device, acc_buffer, host_allocator,
                                        &out_results->acc_contents);
  }
  if (iree_status_is_ok(status)) {
    status = conv2d_results_read_buffer(device, result_buffer, host_allocator,
                                        &out_results->actual_contents);
  }
  if (iree_status_is_ok(status)) {
    // The expected result has the same size as the actual one; its contents
    // are left for the reference implementation to fill in.
    out_results->expected_contents.data_length =
        iree_hal_buffer_byte_length(result_buffer);
    status = iree_allocator_malloc(
        host_allocator, out_results->expected_contents.data_length,
        (void**)&out_results->expected_contents.data);
  }

  if (!iree_status_is_ok(status)) {
    conv2d_results_deinitialize(out_results);
  }
  IREE_TRACE_ZONE_END(z0);
  return status;
}
// Releases every host allocation held by |results| using the allocator that
// was recorded in it. The accumulator contents are only released when the
// span is non-empty. The struct itself is neither freed nor reset.
static void conv2d_results_deinitialize(conv2d_results_t* results) {
  IREE_TRACE_ZONE_BEGIN(z0);

  iree_allocator_t allocator = results->host_allocator;
  iree_allocator_free(allocator, results->input_contents.data);
  iree_allocator_free(allocator, results->kernel_contents.data);
  if (!iree_byte_span_is_empty(results->acc_contents)) {
    iree_allocator_free(allocator, results->acc_contents.data);
  }
  iree_allocator_free(allocator, results->actual_contents.data);
  iree_allocator_free(allocator, results->expected_contents.data);

  IREE_TRACE_ZONE_END(z0);
}
// Helper for check_conv2d_results: computes the reference result into
// |results->expected_contents| and compares it element by element against
// |results->actual_contents|, looking at every |check_every|-th element.
//
// On the first mismatch the (n, oc, oh, ow) coordinates are printed to |file|
// and an IREE_STATUS_ABORTED status is returned. A failure of the reference
// computation itself is returned as is.
// TODO: Add detailed logging to |file|.
static iree_status_t check_conv2d_results_impl(FILE* file,
                                               const conv2d_results_t* results,
                                               int check_every) {
  IREE_TRACE_ZONE_BEGIN(z0);

  // reference_conv2d takes the element types in (input, kernel, acc) order.
  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, reference_conv2d(results->n, results->c, results->h, results->w,
                           results->f, results->kh, results->kw, results->sh,
                           results->sw, results->dh, results->dw,
                           results->input_type, results->kernel_type,
                           results->acc_type, results->input_contents,
                           results->kernel_contents, results->acc_contents,
                           results->expected_contents, check_every));

  int count = 0;
  iree_hal_dim_t oh_size =
      out_shape_calc(results->h, results->kh, results->sh, results->dh);
  iree_hal_dim_t ow_size =
      out_shape_calc(results->w, results->kw, results->sw, results->dw);
  for (iree_hal_dim_t n = 0; n < results->n; ++n) {
    for (iree_hal_dim_t oc = 0; oc < results->f; ++oc) {
      for (iree_hal_dim_t oh = 0; oh < oh_size; ++oh) {
        for (iree_hal_dim_t ow = 0; ow < ow_size; ++ow) {
          // Only every |check_every|-th element is compared.
          if (++count < check_every) continue;
          count = 0;
          iree_hal_dim_t idx =
              convert_to_1d_index(results->f, oh_size, ow_size, n, oc, oh, ow);
          iree_test_utils_e2e_value_t actual_value =
              iree_test_utils_read_buffer_element(
                  idx, results->result_type, results->actual_contents.data);
          iree_test_utils_e2e_value_t expected_value =
              iree_test_utils_read_buffer_element(
                  idx, results->result_type, results->expected_contents.data);
          if (!iree_test_utils_result_elements_agree(actual_value,
                                                     expected_value)) {
            fprintf(
                file,
                "\n\nerror: the actual and expected result tensors disagree "
                "at n %" PRIdim ", oc %" PRIdim ", oh %" PRIdim ", ow %" PRIdim
                ".\n\n",
                n, oc, oh, ow);
            IREE_TRACE_ZONE_END(z0);
            return iree_make_status(IREE_STATUS_ABORTED);
          }
        }
      }
    }
  }

  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}
// Verifies the output of a conv2d against a reference implementation run on
// the same (host-local) inputs held in |results|. A mismatch is reported to
// |file| by check_conv2d_results_impl and surfaces here as a failing status.
// TODO: Add detailed logging to |file|.
static iree_status_t check_conv2d_results(FILE* file,
                                          const conv2d_results_t* results) {
  IREE_TRACE_ZONE_BEGIN(z0);

  // TODO: Increase the check every param to reduce the number of comparisons.
  int check_every = 1;
  iree_status_t status = check_conv2d_results_impl(file, results, check_every);

  // A failure found while sampling (check_every > 1) is re-checked with every
  // element compared so that the report comes from an exhaustive pass.
  bool sampled_failure = check_every > 1 && !iree_status_is_ok(status);
  if (sampled_failure) {
    iree_status_ignore(status);
    status = check_conv2d_results_impl(file, results, 1);
  }

  IREE_TRACE_ZONE_END(z0);
  return status;
}
//===----------------------------------------------------------------------===//
// `conv2d_test` custom module
//===----------------------------------------------------------------------===//
// This uses the C++ wrapper to keep things simple. Though easier to use, the
// wrapper adds overhead and code-size bloat, which does not matter in a test
// like this. A C module builder API that removes the boilerplate is TBD, so
// everything in this file except this module is written in C; that way the
// file can be swapped back to pure C in the future.
namespace iree {
// Per-context state of the conv2d test module. Holds the host allocator used
// for the temporary host copies made while checking results.
class Conv2dTestModuleState final {
 public:
  explicit Conv2dTestModuleState(iree_allocator_t host_allocator)
      : host_allocator_(host_allocator) {}
  ~Conv2dTestModuleState() = default;

  // Creates a dense row-major 4D buffer view of shape
  // [dim0, dim1, dim2, dim3] on |device| and fills it with pseudorandom,
  // integer-valued elements of |element_type|. The generator is seeded with
  // |seed|, so the values are reproducible both across runs and across
  // machines.
  StatusOr<vm::ref<iree_hal_buffer_view_t>> GenerateRandom4dTensor(
      const vm::ref<iree_hal_device_t> device, int64_t dim0, int64_t dim1,
      int64_t dim2, int64_t dim3, iree_hal_element_type_t element_type,
      int32_t seed) {
    iree_hal_dim_t shape[4] = {
        (iree_hal_dim_t)dim0,
        (iree_hal_dim_t)dim1,
        (iree_hal_dim_t)dim2,
        (iree_hal_dim_t)dim3,
    };

    iree_hal_buffer_params_t buffer_params = {0};
    buffer_params.type = IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_DEVICE;
    buffer_params.access = IREE_HAL_MEMORY_ACCESS_ALL;
    buffer_params.usage = IREE_HAL_BUFFER_USAGE_DEFAULT;

    // Everything the capture-less fill callback needs, passed via user_data.
    struct callback_state_t {
      iree_hal_element_type_t element_type;
      int32_t seed;
    } callback_state = {
        element_type,
        seed,
    };

    vm::ref<iree_hal_buffer_view_t> result_view;
    IREE_RETURN_IF_ERROR(iree_hal_buffer_view_generate_buffer(
        device.get(), iree_hal_device_allocator(device.get()),
        IREE_ARRAYSIZE(shape), shape, element_type,
        IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, buffer_params,
        +[](iree_hal_buffer_mapping_t* mapping, void* user_data) {
          const callback_state_t* params = (const callback_state_t*)user_data;

          // Generate "uniform" integer-valued numbers in the range [min,
          // max].
          int32_t min = 0;
          int32_t max = 0;
          iree_test_utils_get_min_max_for_element_type(params->element_type,
                                                       &min, &max);
          uint32_t range = (max - min + 1);

          iree_host_size_t element_byte_count =
              iree_hal_element_dense_byte_count(params->element_type);
          iree_byte_span_t span = mapping->contents;
          uint32_t state = params->seed;

          // Walk the mapped bytes one element at a time.
          for (iree_host_size_t offset = 0; offset < span.data_length;
               offset += element_byte_count) {
            int32_t value =
                (int32_t)iree_test_utils_pseudorandom_range(&state, range) +
                min;
            iree_test_utils_write_element(params->element_type, value,
                                          span.data + offset);
          }
          return iree_ok_status();
        },
        &callback_state, &result_view));
    return std::move(result_view);
  }

  // Downloads the tensors of a conv2d invocation to the host, recomputes the
  // result with the reference implementation and compares it with
  // |actual_result|. Mismatches are reported on stderr and through the
  // returned status. The host copies are released before returning.
  Status CheckConv2dResults(
      const vm::ref<iree_hal_device_t> device, int64_t n, int64_t c, int64_t h,
      int64_t w, int64_t f, int64_t kh, int64_t kw, int64_t sh, int64_t sw,
      int64_t dh, int64_t dw, const vm::ref<iree_hal_buffer_view_t> input,
      const vm::ref<iree_hal_buffer_view_t> kernel,
      const vm::ref<iree_hal_buffer_view_t> acc,
      const vm::ref<iree_hal_buffer_view_t> actual_result) {
    conv2d_results_t results = {};
    // On failure the initializer has already released whatever it allocated.
    IREE_RETURN_IF_ERROR(conv2d_results_initialize(
        device.get(), static_cast<iree_hal_dim_t>(n),
        static_cast<iree_hal_dim_t>(c), static_cast<iree_hal_dim_t>(h),
        static_cast<iree_hal_dim_t>(w), static_cast<iree_hal_dim_t>(f),
        static_cast<iree_hal_dim_t>(kh), static_cast<iree_hal_dim_t>(kw),
        static_cast<iree_hal_dim_t>(sh), static_cast<iree_hal_dim_t>(sw),
        static_cast<iree_hal_dim_t>(dh), static_cast<iree_hal_dim_t>(dw),
        input.get(), kernel.get(), acc.get(), actual_result.get(),
        host_allocator_, &results));

    iree_status_t check_status = check_conv2d_results(stderr, &results);
    conv2d_results_deinitialize(&results);
    return check_status;
  }

 private:
  iree_allocator_t host_allocator_;
};
// Table of the native functions exported by the conv2d test module, mapping
// each exported name to the Conv2dTestModuleState method implementing it.
static const vm::NativeFunction<Conv2dTestModuleState>
kConv2dTestModuleFunctions[] = {
// Note: exported as "generate_random_tensor" although the implementation
// always produces a 4D tensor.
vm::MakeNativeFunction("generate_random_tensor",
&Conv2dTestModuleState::GenerateRandom4dTensor),
vm::MakeNativeFunction("check_conv2d_results",
&Conv2dTestModuleState::CheckConv2dResults),
};
// Native module whose per-context state is a Conv2dTestModuleState.
struct Conv2dTestModule final : public vm::NativeModule<Conv2dTestModuleState> {
  using vm::NativeModule<Conv2dTestModuleState>::NativeModule;

  // Creates a fresh state bound to |host_allocator|.
  StatusOr<std::unique_ptr<Conv2dTestModuleState>> CreateState(
      iree_allocator_t host_allocator) override {
    return std::make_unique<Conv2dTestModuleState>(host_allocator);
  }

  // Forking does not copy anything from |parent_state|: the forked state is
  // simply a new state created with |host_allocator|.
  StatusOr<std::unique_ptr<Conv2dTestModuleState>> ForkState(
      Conv2dTestModuleState* parent_state,
      iree_allocator_t host_allocator) override {
    return std::make_unique<Conv2dTestModuleState>(host_allocator);
  }
};
} // namespace iree
// Creates the "conv2d_test" native module (version 0) for |instance| and
// stores its VM module interface in |out_module|.
// NOTE(review): the C++ module object is released from its unique_ptr here, so
// its lifetime is presumably managed through the returned interface - confirm.
static iree_status_t conv2d_test_module_create(iree_vm_instance_t* instance,
                                               iree_allocator_t host_allocator,
                                               iree_vm_module_t** out_module) {
  IREE_ASSERT_ARGUMENT(out_module);
  *out_module = NULL;

  using FunctionTable =
      iree::span<const iree::vm::NativeFunction<iree::Conv2dTestModuleState>>;
  std::unique_ptr<iree::Conv2dTestModule> module =
      std::make_unique<iree::Conv2dTestModule>(
          "conv2d_test", /*version=*/0, instance, host_allocator,
          FunctionTable(iree::kConv2dTestModuleFunctions));

  iree::Conv2dTestModule* released_module = module.release();
  *out_module = released_module->interface();
  return iree_ok_status();
}
// Entry point of the conv2d e2e test runner. Parses the IREE flags, rejects
// leftover positional arguments, then loads and runs the e2e tests with the
// conv2d test module registered.
int main(int argc, char** argv) {
  IREE_TRACE_APP_ENTER();

  iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_DEFAULT, &argc, &argv);
  if (argc != 1) {
    fprintf(stderr, "use --module= flags to specify the modules to run\n");
    IREE_TRACE_APP_EXIT(EXIT_FAILURE);
    return EXIT_FAILURE;
  }

  // Run the tests. Note that some modules may be compiled for other platforms
  // and not have the required architectures for execution within them - to keep
  // the test runner dumber we gracefully fail those cases by returning success.
  iree_status_t status = iree_test_utils_load_and_run_e2e_tests(
      iree_allocator_system(), conv2d_test_module_create);

  if (iree_status_is_ok(status)) {
    IREE_TRACE_APP_EXIT(EXIT_SUCCESS);
    return EXIT_SUCCESS;
  }

  // Report the failure. A not-found status is treated as an unavailable
  // device and still exits successfully. The status has to be classified
  // before it is freed.
  iree_status_fprint(stderr, status);
  int exit_code = iree_status_is_not_found(status) ? EXIT_SUCCESS : EXIT_FAILURE;
  iree_status_free(status);

  IREE_TRACE_APP_EXIT(exit_code);
  return exit_code;
}