blob: b60e96e406a9dd68b6b44304653088f87bb639da [file] [log] [blame]
// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/tooling/comparison.h"
#include "iree/base/api.h"
#include "iree/hal/api.h"
#include "iree/modules/hal/module.h"
#include "iree/testing/gtest.h"
#include "iree/testing/status_matchers.h"
#include "iree/tooling/vm_util.h"
#include "iree/vm/api.h"
#include "iree/vm/ref_cc.h"
namespace iree {
namespace {
using ::testing::HasSubstr;
class ComparisonTest : public ::testing::Test {
protected:
virtual void SetUp() {
IREE_ASSERT_OK(iree_vm_instance_create(host_allocator_, &instance_));
IREE_ASSERT_OK(iree_hal_module_register_all_types(instance_));
IREE_ASSERT_OK(iree_hal_allocator_create_heap(
IREE_SV("heap"), host_allocator_, host_allocator_, &device_allocator_));
}
virtual void TearDown() {
iree_hal_allocator_release(device_allocator_);
iree_vm_instance_release(instance_);
}
bool ParseAndCompareVariantLists(
iree::span<const std::string> expected_strings,
iree::span<const std::string> actual_strings, std::ostream* os) {
vm::ref<iree_vm_list_t> expected_list;
IREE_CHECK_OK(ParseToVariantList(device_allocator_, expected_strings,
host_allocator_, &expected_list));
vm::ref<iree_vm_list_t> actual_list;
IREE_CHECK_OK(ParseToVariantList(device_allocator_, actual_strings,
host_allocator_, &actual_list));
return iree_tooling_compare_variant_lists(
expected_list.get(), actual_list.get(), host_allocator_, os);
}
iree_vm_instance_t* instance_ = nullptr;
iree_allocator_t host_allocator_ = iree_allocator_system();
iree_hal_allocator_t* device_allocator_ = nullptr;
};
TEST_F(ComparisonTest, CompareEqualLists) {
std::string buf_string1 = "2x2xi32=[42 43][44 45]";
std::string buf_string2 = "2x3xf64=[1 2 3][4 5 6]";
auto buf_strings = std::vector<std::string>{buf_string1, buf_string2};
std::stringstream os;
EXPECT_TRUE(ParseAndCompareVariantLists(buf_strings, buf_strings, &os));
EXPECT_EQ(os.str(), "");
}
TEST_F(ComparisonTest, CompareListsWithIgnored) {
std::string buf_string1 = "2x2xi32=[42 43][44 45]";
std::string buf_string2 = "2x3xf64=[1 2 999][4 5 6]";
std::string buf_string2_ignored = "(ignored)";
auto actual_strings = std::vector<std::string>{buf_string1, buf_string2};
auto expected_strings =
std::vector<std::string>{buf_string1, buf_string2_ignored};
std::stringstream os;
EXPECT_TRUE(
ParseAndCompareVariantLists(expected_strings, actual_strings, &os));
EXPECT_EQ(os.str(), "");
}
TEST_F(ComparisonTest, CompareTruncatedLists) {
std::string buf_string1 = "2x2xi32=[42 43][44 45]";
std::string buf_string2 = "2x3xf64=[1 2 3][4 5 6]";
auto actual_strings = std::vector<std::string>{buf_string1, buf_string2};
auto expected_strings = std::vector<std::string>{buf_string1};
std::stringstream os;
EXPECT_FALSE(
ParseAndCompareVariantLists(expected_strings, actual_strings, &os));
EXPECT_THAT(os.str(), HasSubstr("expected 1 list elements but 2 provided"));
}
TEST_F(ComparisonTest, CompareDifferingLists) {
std::string buf_string1 = "2x2xi32=[42 43][44 45]";
std::string buf_string2 = "2x3xf64=[1 2 999][4 5 6]";
std::string buf_string2_good = "2x3xf64=[1 2 3][4 5 6]";
auto actual_strings = std::vector<std::string>{buf_string1, buf_string2};
auto expected_strings =
std::vector<std::string>{buf_string1, buf_string2_good};
std::stringstream os;
EXPECT_FALSE(
ParseAndCompareVariantLists(expected_strings, actual_strings, &os));
EXPECT_THAT(
os.str(),
HasSubstr("element at index 2 (999) does not match the expected (3)"));
}
TEST_F(ComparisonTest, CompareListsWithDifferingTypes) {
std::string buf_string1 = "2x2xi32=[42 43][44 45]";
std::string buf_string2 = "123";
std::string buf_string2_good = "2x3xf64=[1 2 3][4 5 6]";
auto actual_strings = std::vector<std::string>{buf_string1, buf_string2};
auto expected_strings =
std::vector<std::string>{buf_string1, buf_string2_good};
std::stringstream os;
EXPECT_FALSE(
ParseAndCompareVariantLists(expected_strings, actual_strings, &os));
EXPECT_THAT(os.str(), HasSubstr("variant types mismatch"));
}
} // namespace
} // namespace iree