| // Copyright 2019 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| // Tests for bytecode_module.cc implementations. |
| // This means mostly just FlatBuffer verification, module interface functions, |
| // etc. bytecode_dispatch_test.cc covers actual dispatch. |
| |
| #include "iree/vm/bytecode/module.h" |
| |
| #include <memory> |
| #include <vector> |
| |
| #include "iree/base/api.h" |
| #include "iree/testing/gtest.h" |
| #include "iree/testing/status_matchers.h" |
| #include "iree/vm/api.h" |
| #include "iree/vm/bytecode/module_test_module_c.h" |
| |
| static bool operator==(const iree_vm_value_t& lhs, |
| const iree_vm_value_t& rhs) noexcept { |
| if (lhs.type != rhs.type) return false; |
| switch (lhs.type) { |
| default: |
| case IREE_VM_VALUE_TYPE_NONE: |
| return true; // none == none |
| case IREE_VM_VALUE_TYPE_I8: |
| return lhs.i8 == rhs.i8; |
| case IREE_VM_VALUE_TYPE_I16: |
| return lhs.i16 == rhs.i16; |
| case IREE_VM_VALUE_TYPE_I32: |
| return lhs.i32 == rhs.i32; |
| case IREE_VM_VALUE_TYPE_I64: |
| return lhs.i64 == rhs.i64; |
| case IREE_VM_VALUE_TYPE_F32: |
| return lhs.f32 == rhs.f32; |
| case IREE_VM_VALUE_TYPE_F64: |
| return lhs.f64 == rhs.f64; |
| } |
| } |
| |
| static std::ostream& operator<<(std::ostream& os, |
| const iree_vm_value_t& value) { |
| switch (value.type) { |
| default: |
| case IREE_VM_VALUE_TYPE_NONE: |
| return os << "??"; |
| case IREE_VM_VALUE_TYPE_I8: |
| return os << value.i8; |
| case IREE_VM_VALUE_TYPE_I16: |
| return os << value.i16; |
| case IREE_VM_VALUE_TYPE_I32: |
| return os << value.i32; |
| case IREE_VM_VALUE_TYPE_I64: |
| return os << value.i64; |
| case IREE_VM_VALUE_TYPE_F32: |
| return os << value.f32; |
| case IREE_VM_VALUE_TYPE_F64: |
| return os << value.f64; |
| } |
| } |
| |
| template <size_t N> |
| static std::vector<iree_vm_value_t> MakeValuesList(const int32_t (&values)[N]) { |
| std::vector<iree_vm_value_t> result; |
| result.resize(N); |
| for (size_t i = 0; i < N; ++i) result[i] = iree_vm_value_make_i32(values[i]); |
| return result; |
| } |
| |
| static std::vector<iree_vm_value_t> MakeValueRangeList(int32_t start, |
| int32_t end) { |
| std::vector<iree_vm_value_t> result; |
| result.resize(abs(start - end) + 1); |
| int32_t value = start; |
| int32_t delta = start < end ? 1 : -1; |
| for (size_t i = 0; i < result.size(); ++i, value += delta) { |
| result[i] = iree_vm_value_make_i32(value); |
| } |
| return result; |
| } |
| |
| static bool operator==(const iree_vm_ref_t& lhs, |
| const iree_vm_ref_t& rhs) noexcept { |
| return lhs.type == rhs.type && lhs.ptr == rhs.ptr; |
| } |
| |
| static std::ostream& operator<<(std::ostream& os, const iree_vm_ref_t& value) { |
| // Just nulls today. |
| return os << (iree_vm_ref_is_null(&value) ? "(null)" : "??"); |
| } |
| |
| static std::vector<iree_vm_ref_t> MakeNullRefList(size_t count) { |
| std::vector<iree_vm_ref_t> result; |
| result.resize(count); |
| return result; |
| } |
| |
| namespace { |
| |
| using iree::StatusCode; |
| using iree::StatusOr; |
| using iree::testing::status::IsOkAndHolds; |
| using iree::testing::status::StatusIs; |
| using iree::vm::ref; |
| using testing::Eq; |
| |
| class VMBytecodeModuleTest : public ::testing::Test { |
| protected: |
| virtual void SetUp() { |
| IREE_CHECK_OK(iree_vm_instance_create(IREE_VM_TYPE_CAPACITY_DEFAULT, |
| iree_allocator_system(), &instance_)); |
| |
| const auto* module_file_toc = iree_vm_bytecode_module_test_module_create(); |
| IREE_CHECK_OK(iree_vm_bytecode_module_create( |
| instance_, |
| iree_const_byte_span_t{ |
| reinterpret_cast<const uint8_t*>(module_file_toc->data), |
| static_cast<iree_host_size_t>(module_file_toc->size)}, |
| iree_allocator_null(), iree_allocator_system(), &bytecode_module_)); |
| |
| std::vector<iree_vm_module_t*> modules = {bytecode_module_}; |
| IREE_CHECK_OK(iree_vm_context_create_with_modules( |
| instance_, IREE_VM_CONTEXT_FLAG_NONE, modules.size(), modules.data(), |
| iree_allocator_system(), &context_)); |
| } |
| |
| virtual void TearDown() { |
| iree_vm_module_release(bytecode_module_); |
| iree_vm_context_release(context_); |
| iree_vm_instance_release(instance_); |
| } |
| |
| StatusOr<std::vector<iree_vm_value_t>> RunFunction( |
| const char* function_name, std::vector<iree_vm_value_t> inputs) { |
| ref<iree_vm_list_t> input_list; |
| IREE_RETURN_IF_ERROR( |
| iree_vm_list_create(iree_vm_make_undefined_type_def(), inputs.size(), |
| iree_allocator_system(), &input_list)); |
| IREE_RETURN_IF_ERROR(iree_vm_list_resize(input_list.get(), inputs.size())); |
| for (iree_host_size_t i = 0; i < inputs.size(); ++i) { |
| IREE_RETURN_IF_ERROR( |
| iree_vm_list_set_value(input_list.get(), i, &inputs[i])); |
| } |
| |
| ref<iree_vm_list_t> output_list; |
| IREE_RETURN_IF_ERROR(iree_vm_list_create(iree_vm_make_undefined_type_def(), |
| 8, iree_allocator_system(), |
| &output_list)); |
| |
| iree_vm_function_t function; |
| IREE_RETURN_IF_ERROR(iree_vm_module_lookup_function_by_name( |
| bytecode_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT, |
| iree_make_cstring_view(function_name), &function)); |
| IREE_RETURN_IF_ERROR( |
| iree_vm_invoke(context_, function, IREE_VM_INVOCATION_FLAG_NONE, |
| /*policy=*/nullptr, input_list.get(), output_list.get(), |
| iree_allocator_system())); |
| |
| std::vector<iree_vm_value_t> outputs; |
| outputs.resize(iree_vm_list_size(output_list.get())); |
| for (iree_host_size_t i = 0; i < outputs.size(); ++i) { |
| IREE_RETURN_IF_ERROR( |
| iree_vm_list_get_value(output_list.get(), i, &outputs[i])); |
| } |
| return outputs; |
| } |
| |
| StatusOr<std::vector<iree_vm_ref_t>> RunFunction( |
| const char* function_name, std::vector<iree_vm_ref_t> inputs) { |
| ref<iree_vm_list_t> input_list; |
| IREE_RETURN_IF_ERROR( |
| iree_vm_list_create(iree_vm_make_undefined_type_def(), inputs.size(), |
| iree_allocator_system(), &input_list)); |
| IREE_RETURN_IF_ERROR(iree_vm_list_resize(input_list.get(), inputs.size())); |
| for (iree_host_size_t i = 0; i < inputs.size(); ++i) { |
| IREE_RETURN_IF_ERROR( |
| iree_vm_list_set_ref_retain(input_list.get(), i, &inputs[i])); |
| } |
| |
| ref<iree_vm_list_t> output_list; |
| IREE_RETURN_IF_ERROR(iree_vm_list_create(iree_vm_make_undefined_type_def(), |
| 8, iree_allocator_system(), |
| &output_list)); |
| |
| iree_vm_function_t function; |
| IREE_RETURN_IF_ERROR(iree_vm_module_lookup_function_by_name( |
| bytecode_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT, |
| iree_make_cstring_view(function_name), &function)); |
| IREE_RETURN_IF_ERROR( |
| iree_vm_invoke(context_, function, IREE_VM_INVOCATION_FLAG_NONE, |
| /*policy=*/nullptr, input_list.get(), output_list.get(), |
| iree_allocator_system())); |
| |
| std::vector<iree_vm_ref_t> outputs; |
| outputs.resize(iree_vm_list_size(output_list.get())); |
| for (iree_host_size_t i = 0; i < outputs.size(); ++i) { |
| IREE_RETURN_IF_ERROR( |
| iree_vm_list_get_ref_retain(output_list.get(), i, &outputs[i])); |
| } |
| return outputs; |
| } |
| |
| iree_vm_instance_t* instance_ = nullptr; |
| iree_vm_context_t* context_ = nullptr; |
| iree_vm_module_t* bytecode_module_ = nullptr; |
| }; |
| |
| TEST_F(VMBytecodeModuleTest, FuncIOEmpty) { |
| EXPECT_THAT(RunFunction("FuncIOEmpty", std::vector<iree_vm_value_t>()), |
| IsOkAndHolds(Eq(std::vector<iree_vm_value_t>()))); |
| } |
| |
| TEST_F(VMBytecodeModuleTest, FuncIO1) { |
| EXPECT_THAT(RunFunction("FuncIO1", MakeValuesList({1})), |
| IsOkAndHolds(Eq(MakeValuesList({1})))); |
| } |
| |
| TEST_F(VMBytecodeModuleTest, FuncIO8) { |
| EXPECT_THAT(RunFunction("FuncIO8", MakeValueRangeList(0, 7)), |
| IsOkAndHolds(Eq(MakeValueRangeList(7, 0)))); |
| } |
| |
| TEST_F(VMBytecodeModuleTest, FuncIO600) { |
| EXPECT_THAT(RunFunction("FuncIO600", MakeNullRefList(600)), |
| IsOkAndHolds(Eq(MakeNullRefList(600)))); |
| } |
| |
| } // namespace |