blob: 71f884e3070f3f186513cd3ee53ec848d8e5855d [file]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "absl/strings/match.h"
#include "absl/strings/str_replace.h"
#include "iree/base/logging.h"
#include "iree/base/status.h"
#include "iree/testing/gtest.h"
#include "iree/vm/api.h"
#include "iree/vm/test/emitc/arithmetic_ops.h"
#include "iree/vm/test/emitc/arithmetic_ops_f32.h"
#include "iree/vm/test/emitc/arithmetic_ops_i64.h"
#include "iree/vm/test/emitc/assignment_ops.h"
#include "iree/vm/test/emitc/assignment_ops_i64.h"
#include "iree/vm/test/emitc/comparison_ops.h"
#include "iree/vm/test/emitc/comparison_ops_f32.h"
#include "iree/vm/test/emitc/comparison_ops_i64.h"
#include "iree/vm/test/emitc/control_flow_ops.h"
#include "iree/vm/test/emitc/conversion_ops.h"
#include "iree/vm/test/emitc/conversion_ops_i64.h"
#include "iree/vm/test/emitc/global_ops.h"
#include "iree/vm/test/emitc/list_ops.h"
#include "iree/vm/test/emitc/shift_ops.h"
#include "iree/vm/test/emitc/shift_ops_i64.h"
namespace {
typedef iree_status_t (*create_function_t)(iree_allocator_t,
iree_vm_module_t**);
// Parameters for a single test instance: one exported function of one
// EmitC-compiled module.
struct TestParams {
  // Module name as reported by the module's native descriptor.
  std::string module_name;
  // Unqualified name of the exported function to invoke.
  std::string local_name;
  // Factory used to instantiate the module for the test's VM context.
  // Initialized so a default-constructed TestParams never carries an
  // indeterminate function pointer.
  create_function_t create_function = nullptr;
};
// Pairs a generated module's native descriptor (used to enumerate its exports)
// with the factory that instantiates that module.
struct ModuleDescription {
  // Value-initialized so a default-constructed description holds no garbage.
  iree_vm_native_module_descriptor_t descriptor = {};
  create_function_t create_function = nullptr;
};
// Formats TestParams as "<module_name>_<local_name>". The result is used as the
// parameterized test-name suffix (via ::testing::PrintToStringParamName).
//
// GoogleTest only accepts parameter names made of ASCII letters, digits and
// underscores, and aborts instantiation on anything else. Every other
// character, including the '.' separator and any ':' in module names, is
// therefore mapped to '_' rather than only a fixed set of known characters.
std::ostream& operator<<(std::ostream& os, const TestParams& params) {
  std::string qualified_name = params.module_name + "." + params.local_name;
  for (char& c : qualified_name) {
    const bool is_name_char = (c >= 'a' && c <= 'z') ||
                              (c >= 'A' && c <= 'Z') ||
                              (c >= '0' && c <= '9') || c == '_';
    if (!is_name_char) c = '_';
  }
  return os << qualified_name;
}
std::vector<TestParams> GetModuleTestParams() {
std::vector<TestParams> test_params;
// TODO(simon-camp): get these automatically
std::vector<ModuleDescription> modules = {
{arithmetic_ops_descriptor_, arithmetic_ops_create},
{arithmetic_ops_f32_descriptor_, arithmetic_ops_f32_create},
{arithmetic_ops_i64_descriptor_, arithmetic_ops_i64_create},
{assignment_ops_descriptor_, assignment_ops_create},
{assignment_ops_i64_descriptor_, assignment_ops_i64_create},
{comparison_ops_descriptor_, comparison_ops_create},
{comparison_ops_f32_descriptor_, comparison_ops_f32_create},
{comparison_ops_i64_descriptor_, comparison_ops_i64_create},
{control_flow_ops_descriptor_, control_flow_ops_create},
{conversion_ops_descriptor_, conversion_ops_create},
{conversion_ops_i64_descriptor_, conversion_ops_i64_create},
{global_ops_descriptor_, global_ops_create},
{list_ops_descriptor_, list_ops_create},
{shift_ops_descriptor_, shift_ops_create},
{shift_ops_i64_descriptor_, shift_ops_i64_create}};
for (size_t i = 0; i < modules.size(); i++) {
iree_vm_native_module_descriptor_t descriptor = modules[i].descriptor;
create_function_t function = modules[i].create_function;
std::string module_name =
std::string(descriptor.module_name.data, descriptor.module_name.size);
for (iree_host_size_t i = 0; i < descriptor.export_count; i++) {
iree_vm_native_export_descriptor_t export_descriptor =
descriptor.exports[i];
std::string local_name = std::string(export_descriptor.local_name.data,
export_descriptor.local_name.size);
test_params.push_back({module_name, local_name, function});
}
}
return test_params;
}
// Fixture owning a VM instance and a context that contains exactly one
// EmitC-compiled module, selected by the TestParams parameter.
class VMCModuleTest : public ::testing::Test,
                      public ::testing::WithParamInterface<TestParams> {
 protected:
  // Creates the VM instance, instantiates the module under test through its
  // create function, and registers it in a fresh context. Any failure aborts
  // via IREE_CHECK_OK.
  void SetUp() override {
    const auto& test_params = GetParam();
    IREE_CHECK_OK(iree_vm_instance_create(iree_allocator_system(), &instance_));

    // Plain local (no trailing underscore): this is not a fixture member.
    iree_vm_module_t* test_module = nullptr;
    IREE_CHECK_OK(
        test_params.create_function(iree_allocator_system(), &test_module))
        << "Module failed to load";

    std::vector<iree_vm_module_t*> modules = {test_module};
    IREE_CHECK_OK(iree_vm_context_create_with_modules(
        instance_, modules.data(), modules.size(), iree_allocator_system(),
        &context_));
    // Drop this function's reference. The context is used afterwards, so it
    // presumably retains its own reference to the module (NOTE(review): relies
    // on iree_vm_context_create_with_modules retaining it).
    iree_vm_module_release(test_module);
  }

  // Releases the context before the instance it was created from.
  void TearDown() override {
    iree_vm_context_release(context_);
    iree_vm_instance_release(instance_);
  }

  // Resolves "<module_name>.<local_name>" in the fixture's context and invokes
  // it with no inputs or outputs. Returns the invocation status; the caller
  // takes ownership of it. A missing export aborts via IREE_CHECK_OK.
  iree_status_t RunFunction(const std::string& module_name,
                            const std::string& local_name) {
    const std::string qualified_name = module_name + "." + local_name;
    iree_vm_function_t function;
    IREE_CHECK_OK(iree_vm_context_resolve_function(
        context_,
        iree_string_view_t{qualified_name.data(), qualified_name.size()},
        &function))
        << "Exported function '" << local_name << "' not found";
    return iree_vm_invoke(context_, function,
                          /*policy=*/nullptr, /*inputs=*/nullptr,
                          /*outputs=*/nullptr, iree_allocator_system());
  }

  iree_vm_instance_t* instance_ = nullptr;
  iree_vm_context_t* context_ = nullptr;
};
// Runs one exported function and checks its outcome against the naming
// convention: exports whose name starts with "fail_" must return an error;
// every other export must succeed.
TEST_P(VMCModuleTest, Check) {
  const auto& params = GetParam();
  const bool should_fail = absl::StartsWith(params.local_name, "fail_");
  iree::Status status = RunFunction(params.module_name, params.local_name);

  const bool did_fail = !status.ok();
  if (did_fail == should_fail) {
    // Observed outcome matches the expectation encoded in the export name.
    GTEST_SUCCEED();
    return;
  }

  if (should_fail) {
    GTEST_FAIL() << "Function expected failure but succeeded";
  } else {
    GTEST_FAIL() << "Function expected success but failed with error: "
                 << status.ToString();
  }
}
// Instantiates VMCModuleTest.Check once per entry returned by
// GetModuleTestParams(). Each instance is named by PrintToStringParamName,
// which formats its TestParams value.
INSTANTIATE_TEST_SUITE_P(
    VMIRFunctions, VMCModuleTest, testing::ValuesIn(GetModuleTestParams()),
    testing::PrintToStringParamName());
} // namespace