// Source blob: 4381de39cbdd342a352e42e5fac2dc4120ac9afa
// Copyright 2019 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Tests for bytecode_module.cc implementations.
// This means mostly just FlatBuffer verification, module interface functions,
// etc. bytecode_dispatch_test.cc covers actual dispatch.
#include "iree/vm/bytecode/module.h"
#include <memory>
#include <vector>
#include "iree/base/api.h"
#include "iree/testing/gtest.h"
#include "iree/testing/status_matchers.h"
#include "iree/vm/api.h"
#include "iree/vm/bytecode/module_test_module_c.h"
// Equality for VM values so that gtest matchers can compare them.
// Values with different type tags are never equal. For matching tags the
// payload member that corresponds to the tag is compared; the NONE tag and any
// tag not listed below compare equal without looking at the payload.
static bool operator==(const iree_vm_value_t& lhs,
                       const iree_vm_value_t& rhs) noexcept {
  if (lhs.type != rhs.type) {
    return false;
  }
  bool payload_matches = true;
  switch (lhs.type) {
    case IREE_VM_VALUE_TYPE_I8:
      payload_matches = (lhs.i8 == rhs.i8);
      break;
    case IREE_VM_VALUE_TYPE_I16:
      payload_matches = (lhs.i16 == rhs.i16);
      break;
    case IREE_VM_VALUE_TYPE_I32:
      payload_matches = (lhs.i32 == rhs.i32);
      break;
    case IREE_VM_VALUE_TYPE_I64:
      payload_matches = (lhs.i64 == rhs.i64);
      break;
    case IREE_VM_VALUE_TYPE_F32:
      payload_matches = (lhs.f32 == rhs.f32);
      break;
    case IREE_VM_VALUE_TYPE_F64:
      payload_matches = (lhs.f64 == rhs.f64);
      break;
    case IREE_VM_VALUE_TYPE_NONE:
    default:
      // none == none (and unrecognized tags are treated the same way).
      break;
  }
  return payload_matches;
}
// Streams a VM value so that gtest can print it in failure messages.
// The payload member matching the value's type tag is written; the NONE tag
// and any tag not listed below are written as "??".
static std::ostream& operator<<(std::ostream& os,
                                const iree_vm_value_t& value) {
  switch (value.type) {
    default:
    case IREE_VM_VALUE_TYPE_NONE:
      return os << "??";
    case IREE_VM_VALUE_TYPE_I8:
      // Widen before streaming: 8-bit integer types (i8 is presumably int8_t)
      // are character types to iostreams and would otherwise be emitted as a
      // raw character instead of a number.
      return os << static_cast<int>(value.i8);
    case IREE_VM_VALUE_TYPE_I16:
      return os << value.i16;
    case IREE_VM_VALUE_TYPE_I32:
      return os << value.i32;
    case IREE_VM_VALUE_TYPE_I64:
      return os << value.i64;
    case IREE_VM_VALUE_TYPE_F32:
      return os << value.f32;
    case IREE_VM_VALUE_TYPE_F64:
      return os << value.f64;
  }
}
template <size_t N>
static std::vector<iree_vm_value_t> MakeValuesList(const int32_t (&values)[N]) {
std::vector<iree_vm_value_t> result;
result.resize(N);
for (size_t i = 0; i < N; ++i) result[i] = iree_vm_value_make_i32(values[i]);
return result;
}
static std::vector<iree_vm_value_t> MakeValueRangeList(int32_t start,
int32_t end) {
std::vector<iree_vm_value_t> result;
result.resize(abs(start - end) + 1);
int32_t value = start;
int32_t delta = start < end ? 1 : -1;
for (size_t i = 0; i < result.size(); ++i, value += delta) {
result[i] = iree_vm_value_make_i32(value);
}
return result;
}
// Equality for VM refs so that gtest matchers can compare them: two refs are
// equal only when both their type and their pointer are identical.
static bool operator==(const iree_vm_ref_t& lhs,
                       const iree_vm_ref_t& rhs) noexcept {
  if (lhs.type != rhs.type) {
    return false;
  }
  return lhs.ptr == rhs.ptr;
}
// Streams a VM ref so that gtest can print it in failure messages.
// Only null refs are used by these tests today, so a null ref is written as
// "(null)" and anything else as the placeholder "??".
static std::ostream& operator<<(std::ostream& os, const iree_vm_ref_t& value) {
  if (iree_vm_ref_is_null(&value)) {
    return os << "(null)";
  }
  return os << "??";
}
// Builds a list of |count| value-initialized (zeroed) refs.
static std::vector<iree_vm_ref_t> MakeNullRefList(size_t count) {
  return std::vector<iree_vm_ref_t>(count);
}
namespace {
using iree::StatusCode;
using iree::StatusOr;
using iree::testing::status::IsOkAndHolds;
using iree::testing::status::StatusIs;
using iree::vm::ref;
using testing::Eq;
class VMBytecodeModuleTest : public ::testing::Test {
protected:
virtual void SetUp() {
IREE_CHECK_OK(iree_vm_instance_create(IREE_VM_TYPE_CAPACITY_DEFAULT,
iree_allocator_system(), &instance_));
const auto* module_file_toc = iree_vm_bytecode_module_test_module_create();
IREE_CHECK_OK(iree_vm_bytecode_module_create(
instance_,
iree_const_byte_span_t{
reinterpret_cast<const uint8_t*>(module_file_toc->data),
static_cast<iree_host_size_t>(module_file_toc->size)},
iree_allocator_null(), iree_allocator_system(), &bytecode_module_));
std::vector<iree_vm_module_t*> modules = {bytecode_module_};
IREE_CHECK_OK(iree_vm_context_create_with_modules(
instance_, IREE_VM_CONTEXT_FLAG_NONE, modules.size(), modules.data(),
iree_allocator_system(), &context_));
}
virtual void TearDown() {
iree_vm_module_release(bytecode_module_);
iree_vm_context_release(context_);
iree_vm_instance_release(instance_);
}
StatusOr<std::vector<iree_vm_value_t>> RunFunction(
const char* function_name, std::vector<iree_vm_value_t> inputs) {
ref<iree_vm_list_t> input_list;
IREE_RETURN_IF_ERROR(
iree_vm_list_create(iree_vm_make_undefined_type_def(), inputs.size(),
iree_allocator_system(), &input_list));
IREE_RETURN_IF_ERROR(iree_vm_list_resize(input_list.get(), inputs.size()));
for (iree_host_size_t i = 0; i < inputs.size(); ++i) {
IREE_RETURN_IF_ERROR(
iree_vm_list_set_value(input_list.get(), i, &inputs[i]));
}
ref<iree_vm_list_t> output_list;
IREE_RETURN_IF_ERROR(iree_vm_list_create(iree_vm_make_undefined_type_def(),
8, iree_allocator_system(),
&output_list));
iree_vm_function_t function;
IREE_RETURN_IF_ERROR(iree_vm_module_lookup_function_by_name(
bytecode_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT,
iree_make_cstring_view(function_name), &function));
IREE_RETURN_IF_ERROR(
iree_vm_invoke(context_, function, IREE_VM_INVOCATION_FLAG_NONE,
/*policy=*/nullptr, input_list.get(), output_list.get(),
iree_allocator_system()));
std::vector<iree_vm_value_t> outputs;
outputs.resize(iree_vm_list_size(output_list.get()));
for (iree_host_size_t i = 0; i < outputs.size(); ++i) {
IREE_RETURN_IF_ERROR(
iree_vm_list_get_value(output_list.get(), i, &outputs[i]));
}
return outputs;
}
StatusOr<std::vector<iree_vm_ref_t>> RunFunction(
const char* function_name, std::vector<iree_vm_ref_t> inputs) {
ref<iree_vm_list_t> input_list;
IREE_RETURN_IF_ERROR(
iree_vm_list_create(iree_vm_make_undefined_type_def(), inputs.size(),
iree_allocator_system(), &input_list));
IREE_RETURN_IF_ERROR(iree_vm_list_resize(input_list.get(), inputs.size()));
for (iree_host_size_t i = 0; i < inputs.size(); ++i) {
IREE_RETURN_IF_ERROR(
iree_vm_list_set_ref_retain(input_list.get(), i, &inputs[i]));
}
ref<iree_vm_list_t> output_list;
IREE_RETURN_IF_ERROR(iree_vm_list_create(iree_vm_make_undefined_type_def(),
8, iree_allocator_system(),
&output_list));
iree_vm_function_t function;
IREE_RETURN_IF_ERROR(iree_vm_module_lookup_function_by_name(
bytecode_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT,
iree_make_cstring_view(function_name), &function));
IREE_RETURN_IF_ERROR(
iree_vm_invoke(context_, function, IREE_VM_INVOCATION_FLAG_NONE,
/*policy=*/nullptr, input_list.get(), output_list.get(),
iree_allocator_system()));
std::vector<iree_vm_ref_t> outputs;
outputs.resize(iree_vm_list_size(output_list.get()));
for (iree_host_size_t i = 0; i < outputs.size(); ++i) {
IREE_RETURN_IF_ERROR(
iree_vm_list_get_ref_retain(output_list.get(), i, &outputs[i]));
}
return outputs;
}
iree_vm_instance_t* instance_ = nullptr;
iree_vm_context_t* context_ = nullptr;
iree_vm_module_t* bytecode_module_ = nullptr;
};
// Calling "FuncIOEmpty" with no inputs must succeed and produce no outputs.
TEST_F(VMBytecodeModuleTest, FuncIOEmpty) {
  const std::vector<iree_vm_value_t> no_values;
  EXPECT_THAT(RunFunction("FuncIOEmpty", no_values),
              IsOkAndHolds(Eq(no_values)));
}
// Calling "FuncIO1" with the single i32 value 1 must succeed and return the
// single i32 value 1.
TEST_F(VMBytecodeModuleTest, FuncIO1) {
  const std::vector<iree_vm_value_t> inputs = MakeValuesList({1});
  const std::vector<iree_vm_value_t> expected = MakeValuesList({1});
  EXPECT_THAT(RunFunction("FuncIO1", inputs), IsOkAndHolds(Eq(expected)));
}
// Calling "FuncIO8" with the i32 values 0..7 must succeed and return the same
// eight values in descending order (7..0).
TEST_F(VMBytecodeModuleTest, FuncIO8) {
  const std::vector<iree_vm_value_t> inputs = MakeValueRangeList(0, 7);
  const std::vector<iree_vm_value_t> expected = MakeValueRangeList(7, 0);
  EXPECT_THAT(RunFunction("FuncIO8", inputs), IsOkAndHolds(Eq(expected)));
}
// Calling "FuncIO600" with 600 null refs must succeed and return 600 null
// refs.
TEST_F(VMBytecodeModuleTest, FuncIO600) {
  const std::vector<iree_vm_ref_t> inputs = MakeNullRefList(600);
  const std::vector<iree_vm_ref_t> expected = MakeNullRefList(600);
  EXPECT_THAT(RunFunction("FuncIO600", inputs), IsOkAndHolds(Eq(expected)));
}
} // namespace