blob: db2a1b043d581eab2c86b9c2056cfe7843e79786 [file]
// Copyright 2020 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/tooling/function_io.h"
#include "iree/base/api.h"
#include "iree/hal/api.h"
#include "iree/io/vec_stream.h"
#include "iree/modules/hal/module.h"
#include "iree/testing/gtest.h"
#include "iree/testing/status_matchers.h"
#include "iree/vm/api.h"
namespace iree {
namespace {
struct FunctionIOTest : public ::testing::Test {
virtual void SetUp() {
host_allocator = iree_allocator_system();
IREE_ASSERT_OK(iree_vm_instance_create(IREE_VM_TYPE_CAPACITY_DEFAULT,
host_allocator, &instance));
IREE_ASSERT_OK(iree_hal_module_register_all_types(instance));
IREE_ASSERT_OK(iree_hal_allocator_create_heap(
IREE_SV("test"), host_allocator, host_allocator, &device_allocator));
}
virtual void TearDown() {
iree_hal_allocator_release(device_allocator);
iree_vm_instance_release(instance);
}
Status ParseToVariantList(iree_string_view_t cconv,
std::vector<std::string> input_strings,
iree_vm_list_t** out_list) {
std::vector<iree_string_view_t> input_string_views(input_strings.size());
for (size_t i = 0; i < input_strings.size(); ++i) {
input_string_views[i].data = input_strings[i].data();
input_string_views[i].size = input_strings[i].size();
}
return iree_tooling_parse_variants(
cconv,
iree_string_view_list_t{input_string_views.size(),
input_string_views.data()},
/*device=*/NULL, device_allocator, host_allocator, out_list);
}
Status PrintVariantList(iree_vm_list_t* variant_list,
std::string* out_string) {
iree_io_stream_t* stream = NULL;
IREE_RETURN_IF_ERROR(iree_io_vec_stream_create(
IREE_IO_STREAM_MODE_READABLE | IREE_IO_STREAM_MODE_WRITABLE |
IREE_IO_STREAM_MODE_SEEKABLE,
/*block_size=*/32 * 1024, host_allocator, &stream));
iree_status_t status = iree_tooling_print_variants(
IREE_SV("result"), variant_list, /*max_element_count=*/1024, stream,
host_allocator);
if (iree_status_is_ok(status)) {
status = iree_io_stream_seek(stream, IREE_IO_STREAM_SEEK_SET, 0);
}
if (iree_status_is_ok(status)) {
out_string->resize(iree_io_stream_length(stream));
status = iree_io_stream_read(stream, out_string->size(),
out_string->data(), NULL);
}
iree_io_stream_release(stream);
return status;
}
iree_allocator_t host_allocator;
iree_vm_instance_t* instance = nullptr;
iree_hal_allocator_t* device_allocator = nullptr;
};
// With the "r" convention, a '&'-prefixed input parses to a raw hal.buffer
// rather than a buffer view, and the printer reports it as "(no printer)".
TEST_F(FunctionIOTest, ParsePrintBuffer) {
  const std::string buffer_text = "&2x2xi32=[42 43][44 45]";
  const std::vector<std::string> inputs(1, buffer_text);

  vm::ref<iree_vm_list_t> parsed;
  IREE_ASSERT_OK(ParseToVariantList(IREE_SV("r"), inputs, &parsed));

  std::string printed;
  IREE_ASSERT_OK(PrintVariantList(parsed.get(), &printed));

  std::string expected = "result[0]: hal.buffer\n";
  expected += "(no printer)";
  expected += "\n";
  EXPECT_EQ(printed, expected);
}
// A 2-D buffer view parsed with the "r" convention round-trips. The printer
// emits a typed header line followed by the original input text.
TEST_F(FunctionIOTest, ParsePrintBufferView) {
  const std::string view_text = "2x2xi32=[42 43][44 45]";
  vm::ref<iree_vm_list_t> parsed;
  IREE_ASSERT_OK(ParseToVariantList(
      IREE_SV("r"), std::vector<std::string>(1, view_text), &parsed));

  std::string printed;
  IREE_ASSERT_OK(PrintVariantList(parsed.get(), &printed));

  const std::string expected =
      std::string("result[0]: hal.buffer_view\n") + view_text + "\n";
  EXPECT_EQ(printed, expected);
}
// An integer parsed with the "i" convention prints inline as "i32=<value>".
TEST_F(FunctionIOTest, ParsePrintScalar) {
  const std::string scalar_text = "42";
  vm::ref<iree_vm_list_t> parsed;
  std::string printed;

  IREE_ASSERT_OK(ParseToVariantList(
      IREE_SV("i"), std::vector<std::string>(1, scalar_text), &parsed));
  IREE_ASSERT_OK(PrintVariantList(parsed.get(), &printed));

  std::string expected("result[0]: i32=");
  expected.append(scalar_text).append("\n");
  EXPECT_EQ(printed, expected);
}
// A rank-0 (shapeless) buffer view also round-trips through parse and print
// unchanged.
TEST_F(FunctionIOTest, ParsePrintRank0BufferView) {
  const std::string rank0_text = "i32=42";
  const std::vector<std::string> inputs(1, rank0_text);

  vm::ref<iree_vm_list_t> parsed;
  IREE_ASSERT_OK(ParseToVariantList(IREE_SV("r"), inputs, &parsed));

  std::string printed;
  IREE_ASSERT_OK(PrintVariantList(parsed.get(), &printed));

  std::string expected = "result[0]: hal.buffer_view\n";
  expected += rank0_text;
  expected += "\n";
  EXPECT_EQ(printed, expected);
}
// Two ref inputs ("rr") with different shapes and element types are printed
// in order, each under its own indexed header line.
TEST_F(FunctionIOTest, ParsePrintMultipleBufferViews) {
  const std::string first_view = "2x2xi32=[42 43][44 45]";
  const std::string second_view = "2x3xf64=[1 2 3][4 5 6]";
  std::vector<std::string> inputs;
  inputs.push_back(first_view);
  inputs.push_back(second_view);

  vm::ref<iree_vm_list_t> parsed;
  IREE_ASSERT_OK(ParseToVariantList(IREE_SV("rr"), inputs, &parsed));

  std::string printed;
  IREE_ASSERT_OK(PrintVariantList(parsed.get(), &printed));

  std::string expected = "result[0]: hal.buffer_view\n";
  expected += first_view;
  expected += "\nresult[1]: hal.buffer_view\n";
  expected += second_view;
  expected += "\n";
  EXPECT_EQ(printed, expected);
}
} // namespace
} // namespace iree