// blob: a4857b78742322f77964029f3302d938c5695726 [file] [log] [blame]
// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/tooling/buffer_view_matchers.h"
#include "iree/base/api.h"
#include "iree/base/internal/math.h"
#include "iree/base/internal/span.h"
#include "iree/testing/gtest.h"
#include "iree/testing/status_matchers.h"
namespace iree {
namespace {
using iree::testing::status::IsOk;
using iree::testing::status::StatusIs;
using ::testing::HasSubstr;
// TODO(benvanik): move this handle type to a base cc helper.
struct StringBuilder {
static StringBuilder MakeSystem() {
iree_string_builder_t builder;
iree_string_builder_initialize(iree_allocator_system(), &builder);
return StringBuilder(builder);
}
static StringBuilder MakeEmpty() {
iree_string_builder_t builder;
iree_string_builder_initialize(iree_allocator_null(), &builder);
return StringBuilder(builder);
}
explicit StringBuilder(iree_string_builder_t builder)
: builder(std::move(builder)) {}
~StringBuilder() { iree_string_builder_deinitialize(&builder); }
operator iree_string_builder_t*() { return &builder; }
std::string ToString() const {
return std::string(builder.buffer, builder.size);
}
iree_string_builder_t builder;
};
// TODO(benvanik): move this handle type to a hal cc helper.
// C API iree_*_retain/iree_*_release function pointer.
// Pointer to a function that takes a T* and returns nothing, declared with the
// IREE_API_PTR calling convention macro. Used as the type of the retain and
// release callbacks supplied to Handle.
template <typename T>
using HandleRefFn = void(IREE_API_PTR*)(T*);
// C++ RAII wrapper for an IREE C reference object.
// Behaves the same as a thread-safe intrusive pointer.
template <typename T, HandleRefFn<T> retain_fn, HandleRefFn<T> release_fn>
class Handle {
 public:
  using handle_type = Handle<T, retain_fn, release_fn>;

  // Adopts |value| without calling retain_fn on it. The returned handle calls
  // release_fn on |value| when it is destroyed or reset.
  static Handle Wrap(T* value) noexcept { return Handle(value, false); }

  // Constructs an empty handle.
  Handle() noexcept = default;
  Handle(std::nullptr_t) noexcept {}

  // Holds |value| and calls retain_fn on it. A null |value| produces an empty
  // handle without calling retain_fn, matching the null checks performed
  // before every other retain/release in this class.
  Handle(T* value) noexcept : value_(value) {
    if (value_) retain_fn(value_);
  }

  ~Handle() noexcept {
    if (value_) release_fn(value_);
  }

  // Copying retains the shared pointer (if any).
  Handle(const Handle& rhs) noexcept : value_(rhs.value_) {
    if (value_) retain_fn(value_);
  }
  Handle& operator=(const Handle& rhs) noexcept {
    // Nothing to do when both handles already reference the same pointer;
    // this also covers self-assignment.
    if (value_ != rhs.value_) {
      if (value_) release_fn(value_);
      value_ = rhs.get();
      if (value_) retain_fn(value_);
    }
    return *this;
  }

  // Moving takes over the pointer held by |rhs| without retaining it.
  Handle(Handle&& rhs) noexcept : value_(rhs.release()) {}
  Handle& operator=(Handle&& rhs) noexcept {
    // When both handles reference the same pointer |rhs| is left untouched
    // and keeps its own reference.
    if (value_ != rhs.value_) {
      if (value_) release_fn(value_);
      value_ = rhs.release();
    }
    return *this;
  }

  // Gets the pointer referenced by this instance.
  constexpr T* get() const noexcept { return value_; }
  constexpr operator T*() const noexcept { return value_; }

  // Resets the object to nullptr and calls release_fn on the previously held
  // pointer, if any.
  void reset() noexcept {
    if (value_) {
      release_fn(value_);
      value_ = nullptr;
    }
  }

  // Returns the current pointer held by this object without calling
  // release_fn on it and resets the handle to empty. Returns nullptr if the
  // handle holds no value. To re-wrap the pointer in a handle without
  // retaining it a second time use Wrap(value) or assign(); the T*
  // constructor would retain it again.
  T* release() noexcept {
    auto* p = value_;
    value_ = nullptr;
    return p;
  }

  // Assigns a pointer.
  // Any pointer currently held is released first. The new pointer is accepted
  // by the handle as-is: retain_fn is not called on it.
  void assign(T* value) noexcept {
    reset();
    value_ = value;
  }

  // Returns a pointer to the inner pointer storage.
  // This allows passing a pointer to the handle as an output argument to
  // C-style creation functions. Note that a pointer already held by the
  // handle is not released by this operator.
  constexpr T** operator&() noexcept { return &value_; }

  // Support boolean expression evaluation ala unique_ptr/shared_ptr:
  // https://en.cppreference.com/w/cpp/memory/shared_ptr/operator_bool
  using unspecified_bool_type = T* Handle::*;
  constexpr operator unspecified_bool_type() const noexcept {
    return value_ ? &Handle::value_ : nullptr;
  }

  // Supports unary expression evaluation.
  constexpr bool operator!() const noexcept { return !value_; }

  // Swap support.
  void swap(Handle& rhs) noexcept { std::swap(value_, rhs.value_); }

 protected:
  // Adopting constructor used by Wrap(): stores |value| without retaining it.
  Handle(T* value, bool) noexcept : value_(value) {}

 private:
  T* value_ = nullptr;
};
// C++ wrapper for iree_hal_buffer_view_t.
struct BufferView final
    : Handle<iree_hal_buffer_view_t, iree_hal_buffer_view_retain,
             iree_hal_buffer_view_release> {
  // Reuse the constructors declared by the Handle base class; this type adds
  // no state or behavior of its own.
  using handle_type::handle_type;
};
// Equality configuration selecting IREE_HAL_BUFFER_EQUALITY_EXACT.
// The struct is value-initialized first so that the fields not set here are
// zero instead of indeterminate when the struct is copied and passed by value.
static const iree_hal_buffer_equality_t kExactEquality = ([]() {
  iree_hal_buffer_equality_t equality = {};
  equality.mode = IREE_HAL_BUFFER_EQUALITY_EXACT;
  return equality;
})();
// Equality configuration selecting
// IREE_HAL_BUFFER_EQUALITY_APPROXIMATE_ABSOLUTE with per-float-type
// thresholds. The struct is value-initialized first so that any field not
// assigned below is zero rather than indeterminate.
static const iree_hal_buffer_equality_t kApproximateEquality = ([]() {
  iree_hal_buffer_equality_t equality = {};
  equality.mode = IREE_HAL_BUFFER_EQUALITY_APPROXIMATE_ABSOLUTE;
  equality.f16_threshold = 0.001f;
  equality.f32_threshold = 0.0001f;
  equality.f64_threshold = 0.0001;
  return equality;
})();
class BufferViewMatchersTest : public ::testing::Test {
protected:
iree_hal_allocator_t* device_allocator_ = nullptr;
virtual void SetUp() {
IREE_CHECK_OK(iree_hal_allocator_create_heap(
IREE_SV("heap"), iree_allocator_system(), iree_allocator_system(),
&device_allocator_));
}
virtual void TearDown() { iree_hal_allocator_release(device_allocator_); }
template <typename T>
StatusOr<BufferView> CreateBufferView(iree::span<const iree_hal_dim_t> shape,
iree_hal_element_type_t element_type,
const T* contents) {
iree_hal_dim_t num_elements = 1;
for (iree_hal_dim_t dim : shape) num_elements *= dim;
iree_hal_buffer_params_t params = {0};
params.type =
IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
params.usage =
IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING;
BufferView buffer_view;
IREE_RETURN_IF_ERROR(iree_hal_buffer_view_allocate_buffer(
device_allocator_, shape.size(), shape.data(), element_type,
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, params,
iree_make_const_byte_span(contents, num_elements * sizeof(T)),
&buffer_view));
return std::move(buffer_view);
}
};
//===----------------------------------------------------------------------===//
// iree_hal_buffer_equality_t
//===----------------------------------------------------------------------===//
// Broadcasting the i8 value 1 against {1, 1, 1} must compare equal.
TEST_F(BufferViewMatchersTest, CompareBroadcastI8EQ) {
  const int8_t broadcast_value = 1;
  const int8_t elements[] = {1, 1, 1};
  iree_host_size_t mismatch_index = 0;
  const bool all_equal = iree_hal_compare_buffer_elements_broadcast(
      kApproximateEquality, iree_hal_make_buffer_element_i8(broadcast_value),
      IREE_ARRAYSIZE(elements),
      iree_make_const_byte_span(elements, sizeof(elements)), &mismatch_index);
  EXPECT_TRUE(all_equal);
}
// Broadcasting the i8 value 1 against {1, 2, 3} must fail; the reported index
// is expected to be 1, the position of the first element that differs.
TEST_F(BufferViewMatchersTest, CompareBroadcastI8NE) {
  const int8_t broadcast_value = 1;
  const int8_t elements[] = {1, 2, 3};
  iree_host_size_t mismatch_index = 0;
  const bool all_equal = iree_hal_compare_buffer_elements_broadcast(
      kApproximateEquality, iree_hal_make_buffer_element_i8(broadcast_value),
      IREE_ARRAYSIZE(elements),
      iree_make_const_byte_span(elements, sizeof(elements)), &mismatch_index);
  EXPECT_FALSE(all_equal);
  EXPECT_EQ(mismatch_index, 1);
}
// Broadcasting the i64 value 1 against {1, 1, 1} must compare equal.
TEST_F(BufferViewMatchersTest, CompareBroadcastI64EQ) {
  const int64_t broadcast_value = 1;
  const int64_t elements[] = {1, 1, 1};
  iree_host_size_t mismatch_index = 0;
  const bool all_equal = iree_hal_compare_buffer_elements_broadcast(
      kApproximateEquality, iree_hal_make_buffer_element_i64(broadcast_value),
      IREE_ARRAYSIZE(elements),
      iree_make_const_byte_span(elements, sizeof(elements)), &mismatch_index);
  EXPECT_TRUE(all_equal);
}
// Broadcasting the i64 value 1 against {1, 2, 3} must fail; the reported index
// is expected to be 1, the position of the first element that differs.
TEST_F(BufferViewMatchersTest, CompareBroadcastI64NE) {
  const int64_t broadcast_value = 1;
  const int64_t elements[] = {1, 2, 3};
  iree_host_size_t mismatch_index = 0;
  const bool all_equal = iree_hal_compare_buffer_elements_broadcast(
      kApproximateEquality, iree_hal_make_buffer_element_i64(broadcast_value),
      IREE_ARRAYSIZE(elements),
      iree_make_const_byte_span(elements, sizeof(elements)), &mismatch_index);
  EXPECT_FALSE(all_equal);
  EXPECT_EQ(mismatch_index, 1);
}
// Broadcasting 1.0 as an f16 element against three f16-encoded 1.0 values
// must compare equal.
TEST_F(BufferViewMatchersTest, CompareBroadcastF16EQ) {
  const float broadcast_value = 1.0f;
  const uint16_t elements[] = {
      iree_math_f32_to_f16(1.0f),
      iree_math_f32_to_f16(1.0f),
      iree_math_f32_to_f16(1.0f),
  };
  iree_host_size_t mismatch_index = 0;
  const bool all_equal = iree_hal_compare_buffer_elements_broadcast(
      kApproximateEquality, iree_hal_make_buffer_element_f16(broadcast_value),
      IREE_ARRAYSIZE(elements),
      iree_make_const_byte_span(elements, sizeof(elements)), &mismatch_index);
  EXPECT_TRUE(all_equal);
}
// Broadcasting 1.0 as an f16 element against f16-encoded {1.0, 3.0, 4.0} must
// fail; the reported index is expected to be 1.
TEST_F(BufferViewMatchersTest, CompareBroadcastF16NE) {
  const float broadcast_value = 1.0f;
  const uint16_t elements[] = {
      iree_math_f32_to_f16(1.0f),
      iree_math_f32_to_f16(3.0f),
      iree_math_f32_to_f16(4.0f),
  };
  iree_host_size_t mismatch_index = 0;
  const bool all_equal = iree_hal_compare_buffer_elements_broadcast(
      kApproximateEquality, iree_hal_make_buffer_element_f16(broadcast_value),
      IREE_ARRAYSIZE(elements),
      iree_make_const_byte_span(elements, sizeof(elements)), &mismatch_index);
  EXPECT_FALSE(all_equal);
  EXPECT_EQ(mismatch_index, 1);
}
// Broadcasting the f32 value 1.0 against {1.0, 1.0, 1.0} must compare equal.
TEST_F(BufferViewMatchersTest, CompareBroadcastF32EQ) {
  const float broadcast_value = 1.0f;
  const float elements[] = {1.0f, 1.0f, 1.0f};
  iree_host_size_t mismatch_index = 0;
  const bool all_equal = iree_hal_compare_buffer_elements_broadcast(
      kApproximateEquality, iree_hal_make_buffer_element_f32(broadcast_value),
      IREE_ARRAYSIZE(elements),
      iree_make_const_byte_span(elements, sizeof(elements)), &mismatch_index);
  EXPECT_TRUE(all_equal);
}
// Broadcasting the f32 value 1.0 against {1.0, 3.0, 4.0} must fail; the
// reported index is expected to be 1.
TEST_F(BufferViewMatchersTest, CompareBroadcastF32NE) {
  const float broadcast_value = 1.0f;
  const float elements[] = {1.0f, 3.0f, 4.0f};
  iree_host_size_t mismatch_index = 0;
  const bool all_equal = iree_hal_compare_buffer_elements_broadcast(
      kApproximateEquality, iree_hal_make_buffer_element_f32(broadcast_value),
      IREE_ARRAYSIZE(elements),
      iree_make_const_byte_span(elements, sizeof(elements)), &mismatch_index);
  EXPECT_FALSE(all_equal);
  EXPECT_EQ(mismatch_index, 1);
}
// Broadcasting the f64 value 1.0 against {1.0, 1.0, 1.0} must compare equal.
TEST_F(BufferViewMatchersTest, CompareBroadcastF64EQ) {
  const double broadcast_value = 1.0;
  const double elements[] = {1.0, 1.0, 1.0};
  iree_host_size_t mismatch_index = 0;
  const bool all_equal = iree_hal_compare_buffer_elements_broadcast(
      kApproximateEquality, iree_hal_make_buffer_element_f64(broadcast_value),
      IREE_ARRAYSIZE(elements),
      iree_make_const_byte_span(elements, sizeof(elements)), &mismatch_index);
  EXPECT_TRUE(all_equal);
}
// Broadcasting the f64 value 1.0 against {1.0, 3.0, 4.0} must fail; the
// reported index is expected to be 1.
TEST_F(BufferViewMatchersTest, CompareBroadcastF64NE) {
  const double broadcast_value = 1.0;
  const double elements[] = {1.0, 3.0, 4.0};
  iree_host_size_t mismatch_index = 0;
  const bool all_equal = iree_hal_compare_buffer_elements_broadcast(
      kApproximateEquality, iree_hal_make_buffer_element_f64(broadcast_value),
      IREE_ARRAYSIZE(elements),
      iree_make_const_byte_span(elements, sizeof(elements)), &mismatch_index);
  EXPECT_FALSE(all_equal);
  EXPECT_EQ(mismatch_index, 1);
}
// Two identical f16 arrays {1.0, 2.0, 3.0} must compare equal element by
// element.
TEST_F(BufferViewMatchersTest, CompareElementwiseF16EQ) {
  const uint16_t first[] = {
      iree_math_f32_to_f16(1.0f),
      iree_math_f32_to_f16(2.0f),
      iree_math_f32_to_f16(3.0f),
  };
  const uint16_t second[] = {
      iree_math_f32_to_f16(1.0f),
      iree_math_f32_to_f16(2.0f),
      iree_math_f32_to_f16(3.0f),
  };
  iree_host_size_t mismatch_index = 0;
  const bool all_equal = iree_hal_compare_buffer_elements_elementwise(
      kApproximateEquality, IREE_HAL_ELEMENT_TYPE_FLOAT_16,
      IREE_ARRAYSIZE(first), iree_make_const_byte_span(first, sizeof(first)),
      iree_make_const_byte_span(second, sizeof(second)), &mismatch_index);
  EXPECT_TRUE(all_equal);
}
// f16 arrays whose elements differ only by tiny amounts (e.g. 1.99999 vs 2.0)
// must still compare equal under kApproximateEquality.
TEST_F(BufferViewMatchersTest, CompareElementwiseF16NearEQ) {
  const uint16_t first[] = {
      iree_math_f32_to_f16(1.0f),
      iree_math_f32_to_f16(1.99999f),
      iree_math_f32_to_f16(0.00001f),
      iree_math_f32_to_f16(4.0f),
  };
  const uint16_t second[] = {
      iree_math_f32_to_f16(1.00001f),
      iree_math_f32_to_f16(2.0f),
      iree_math_f32_to_f16(0.0f),
      iree_math_f32_to_f16(4.0f),
  };
  iree_host_size_t mismatch_index = 0;
  const bool all_equal = iree_hal_compare_buffer_elements_elementwise(
      kApproximateEquality, IREE_HAL_ELEMENT_TYPE_FLOAT_16,
      IREE_ARRAYSIZE(first), iree_make_const_byte_span(first, sizeof(first)),
      iree_make_const_byte_span(second, sizeof(second)), &mismatch_index);
  EXPECT_TRUE(all_equal);
}
// f16 arrays {1.0, 2.0, 4.0} and {1.0, 3.0, 4.0} must not compare equal; the
// reported index is expected to be 1.
TEST_F(BufferViewMatchersTest, CompareElementwiseF16NE) {
  const uint16_t first[] = {
      iree_math_f32_to_f16(1.0f),
      iree_math_f32_to_f16(2.0f),
      iree_math_f32_to_f16(4.0f),
  };
  const uint16_t second[] = {
      iree_math_f32_to_f16(1.0f),
      iree_math_f32_to_f16(3.0f),
      iree_math_f32_to_f16(4.0f),
  };
  iree_host_size_t mismatch_index = 0;
  const bool all_equal = iree_hal_compare_buffer_elements_elementwise(
      kApproximateEquality, IREE_HAL_ELEMENT_TYPE_FLOAT_16,
      IREE_ARRAYSIZE(first), iree_make_const_byte_span(first, sizeof(first)),
      iree_make_const_byte_span(second, sizeof(second)), &mismatch_index);
  EXPECT_FALSE(all_equal);
  EXPECT_EQ(mismatch_index, 1);
}
// Two identical f32 arrays {1.0, 2.0, 3.0} must compare equal element by
// element.
TEST_F(BufferViewMatchersTest, CompareElementwiseF32EQ) {
  const float first[] = {1.0f, 2.0f, 3.0f};
  const float second[] = {1.0f, 2.0f, 3.0f};
  iree_host_size_t mismatch_index = 0;
  const bool all_equal = iree_hal_compare_buffer_elements_elementwise(
      kApproximateEquality, IREE_HAL_ELEMENT_TYPE_FLOAT_32,
      IREE_ARRAYSIZE(first), iree_make_const_byte_span(first, sizeof(first)),
      iree_make_const_byte_span(second, sizeof(second)), &mismatch_index);
  EXPECT_TRUE(all_equal);
}
// f32 arrays {1.0, 2.0, 4.0} and {1.0, 3.0, 4.0} must not compare equal; the
// reported index is expected to be 1.
TEST_F(BufferViewMatchersTest, CompareElementwiseF32NE) {
  const float first[] = {1.0f, 2.0f, 4.0f};
  const float second[] = {1.0f, 3.0f, 4.0f};
  iree_host_size_t mismatch_index = 0;
  const bool all_equal = iree_hal_compare_buffer_elements_elementwise(
      kApproximateEquality, IREE_HAL_ELEMENT_TYPE_FLOAT_32,
      IREE_ARRAYSIZE(first), iree_make_const_byte_span(first, sizeof(first)),
      iree_make_const_byte_span(second, sizeof(second)), &mismatch_index);
  EXPECT_FALSE(all_equal);
  EXPECT_EQ(mismatch_index, 1);
}
// Two identical f64 arrays {1.0, 2.0, 3.0} must compare equal element by
// element.
TEST_F(BufferViewMatchersTest, CompareElementwiseF64EQ) {
  const double first[] = {1.0, 2.0, 3.0};
  const double second[] = {1.0, 2.0, 3.0};
  iree_host_size_t mismatch_index = 0;
  const bool all_equal = iree_hal_compare_buffer_elements_elementwise(
      kApproximateEquality, IREE_HAL_ELEMENT_TYPE_FLOAT_64,
      IREE_ARRAYSIZE(first), iree_make_const_byte_span(first, sizeof(first)),
      iree_make_const_byte_span(second, sizeof(second)), &mismatch_index);
  EXPECT_TRUE(all_equal);
}
// f64 arrays {1.0, 2.0, 4.0} and {1.0, 3.0, 4.0} must not compare equal; the
// reported index is expected to be 1.
TEST_F(BufferViewMatchersTest, CompareElementwiseF64NE) {
  const double first[] = {1.0, 2.0, 4.0};
  const double second[] = {1.0, 3.0, 4.0};
  iree_host_size_t mismatch_index = 0;
  const bool all_equal = iree_hal_compare_buffer_elements_elementwise(
      kApproximateEquality, IREE_HAL_ELEMENT_TYPE_FLOAT_64,
      IREE_ARRAYSIZE(first), iree_make_const_byte_span(first, sizeof(first)),
      iree_make_const_byte_span(second, sizeof(second)), &mismatch_index);
  EXPECT_FALSE(all_equal);
  EXPECT_EQ(mismatch_index, 1);
}
//===----------------------------------------------------------------------===//
// iree_hal_buffer_view_metadata_matcher_t
//===----------------------------------------------------------------------===//
// Two f32 buffer views with the zero-element shape {0} must have matching
// metadata.
TEST_F(BufferViewMatchersTest, MetadataEmpty) {
  const float contents[1] = {0};
  iree_hal_dim_t shape[] = {0};
  IREE_ASSERT_OK_AND_ASSIGN(
      BufferView expected_view,
      CreateBufferView(shape, IREE_HAL_ELEMENT_TYPE_FLOAT_32, contents));
  IREE_ASSERT_OK_AND_ASSIGN(
      BufferView actual_view,
      CreateBufferView(shape, IREE_HAL_ELEMENT_TYPE_FLOAT_32, contents));

  StringBuilder description = StringBuilder::MakeSystem();
  bool matched = false;
  IREE_ASSERT_OK(iree_hal_buffer_view_match_metadata_like(
      expected_view, actual_view, description, &matched));
  EXPECT_TRUE(matched);
}
// A 1xf32 view and a 1x2xf32 view must not have matching metadata, and the
// description must mention both shapes.
TEST_F(BufferViewMatchersTest, MetadataShapesDiffer) {
  const float expected_contents[] = {1.0f};
  const float actual_contents[] = {1.0f, 2.0f};
  iree_hal_dim_t expected_shape[] = {1};
  iree_hal_dim_t actual_shape[] = {1, 2};
  IREE_ASSERT_OK_AND_ASSIGN(
      BufferView expected_view,
      CreateBufferView(expected_shape, IREE_HAL_ELEMENT_TYPE_FLOAT_32,
                       expected_contents));
  IREE_ASSERT_OK_AND_ASSIGN(
      BufferView actual_view,
      CreateBufferView(actual_shape, IREE_HAL_ELEMENT_TYPE_FLOAT_32,
                       actual_contents));

  StringBuilder description = StringBuilder::MakeSystem();
  bool matched = false;
  IREE_ASSERT_OK(iree_hal_buffer_view_match_metadata_like(
      expected_view, actual_view, description, &matched));
  EXPECT_FALSE(matched);

  const std::string message = description.ToString();
  EXPECT_THAT(message, HasSubstr("is 1x2xf32"));
  EXPECT_THAT(message, HasSubstr("matches 1xf32"));
}
// A 1xi32 view and a 1xf32 view must not have matching metadata, and the
// description must mention both element types.
TEST_F(BufferViewMatchersTest, MetadataElementTypesDiffer) {
  const float contents[] = {1.0f};
  iree_hal_dim_t shape[] = {1};
  IREE_ASSERT_OK_AND_ASSIGN(
      BufferView expected_view,
      CreateBufferView(shape, IREE_HAL_ELEMENT_TYPE_INT_32, contents));
  IREE_ASSERT_OK_AND_ASSIGN(
      BufferView actual_view,
      CreateBufferView(shape, IREE_HAL_ELEMENT_TYPE_FLOAT_32, contents));

  StringBuilder description = StringBuilder::MakeSystem();
  bool matched = false;
  IREE_ASSERT_OK(iree_hal_buffer_view_match_metadata_like(
      expected_view, actual_view, description, &matched));
  EXPECT_FALSE(matched);

  const std::string message = description.ToString();
  EXPECT_THAT(message, HasSubstr("is 1xf32"));
  EXPECT_THAT(message, HasSubstr("matches 1xi32"));
}
//===----------------------------------------------------------------------===//
// iree_hal_buffer_view_element_matcher_t
//===----------------------------------------------------------------------===//
// Matching an f32 element against a 1x3xi32 view must fail and describe the
// element type mismatch.
TEST_F(BufferViewMatchersTest, ElementTypesDiffer) {
  const float expected_value = 1;
  const int32_t actual_contents[] = {1, 1, 1};
  const iree_hal_dim_t shape[] = {1, 3};
  IREE_ASSERT_OK_AND_ASSIGN(
      BufferView actual_view,
      CreateBufferView(shape, IREE_HAL_ELEMENT_TYPE_INT_32, actual_contents));

  StringBuilder description = StringBuilder::MakeSystem();
  bool matched = false;
  IREE_ASSERT_OK(iree_hal_buffer_view_match_elements(
      kExactEquality, iree_hal_make_buffer_element_f32(expected_value),
      actual_view, description, &matched));
  EXPECT_FALSE(matched);

  const std::string message = description.ToString();
  EXPECT_THAT(message, HasSubstr("type (i32)"));
  EXPECT_THAT(message, HasSubstr("expected (f32)"));
}
// Matching the i32 element 1 against a 1x3xi32 view of all ones must succeed
// and leave the description empty.
TEST_F(BufferViewMatchersTest, MatchElementContentsI32) {
  const int32_t expected_value = 1;
  const int32_t actual_contents[] = {1, 1, 1};
  const iree_hal_dim_t shape[] = {1, 3};
  IREE_ASSERT_OK_AND_ASSIGN(
      BufferView actual_view,
      CreateBufferView(shape, IREE_HAL_ELEMENT_TYPE_INT_32, actual_contents));

  StringBuilder description = StringBuilder::MakeSystem();
  bool matched = false;
  IREE_ASSERT_OK(iree_hal_buffer_view_match_elements(
      kExactEquality, iree_hal_make_buffer_element_i32(expected_value),
      actual_view, description, &matched));
  EXPECT_TRUE(matched);

  const std::string message = description.ToString();
  EXPECT_TRUE(message.empty());
}
// Matching the i32 element 1 against a 1x3xi32 view holding {1, 2, 3} must
// fail and name index 1 in the description.
TEST_F(BufferViewMatchersTest, MismatchElementContentsI32) {
  const int32_t expected_value = 1;
  const int32_t actual_contents[] = {1, 2, 3};
  const iree_hal_dim_t shape[] = {1, 3};
  IREE_ASSERT_OK_AND_ASSIGN(
      BufferView actual_view,
      CreateBufferView(shape, IREE_HAL_ELEMENT_TYPE_INT_32, actual_contents));

  StringBuilder description = StringBuilder::MakeSystem();
  bool matched = false;
  IREE_ASSERT_OK(iree_hal_buffer_view_match_elements(
      kExactEquality, iree_hal_make_buffer_element_i32(expected_value),
      actual_view, description, &matched));
  EXPECT_FALSE(matched);

  const std::string message = description.ToString();
  EXPECT_THAT(message, HasSubstr("element at index 1"));
}
//===----------------------------------------------------------------------===//
// iree_hal_buffer_view_array_matcher_t
//===----------------------------------------------------------------------===//
// Matching an f32 array against a 1x3xi32 view must fail and describe the
// element type mismatch.
TEST_F(BufferViewMatchersTest, MatchArrayTypesDiffer) {
  const float expected_contents[] = {1, 1, 1};
  const int32_t actual_contents[] = {1, 1, 1};
  const iree_hal_dim_t shape[] = {1, 3};
  IREE_ASSERT_OK_AND_ASSIGN(
      BufferView actual_view,
      CreateBufferView(shape, IREE_HAL_ELEMENT_TYPE_INT_32, actual_contents));

  StringBuilder description = StringBuilder::MakeSystem();
  bool matched = false;
  IREE_ASSERT_OK(iree_hal_buffer_view_match_array(
      kExactEquality, IREE_HAL_ELEMENT_TYPE_FLOAT_32,
      IREE_ARRAYSIZE(expected_contents),
      iree_make_const_byte_span(expected_contents, sizeof(expected_contents)),
      actual_view, description, &matched));
  EXPECT_FALSE(matched);

  const std::string message = description.ToString();
  EXPECT_THAT(message, HasSubstr("type (i32)"));
  EXPECT_THAT(message, HasSubstr("expected (f32)"));
}
// Matching a 2-element i32 array against a 1x3xi32 view must fail and
// describe the element count mismatch.
TEST_F(BufferViewMatchersTest, MatchArrayCountsDiffer) {
  const int32_t expected_contents[] = {1, 1};
  const int32_t actual_contents[] = {1, 1, 1};
  const iree_hal_dim_t shape[] = {1, 3};
  IREE_ASSERT_OK_AND_ASSIGN(
      BufferView actual_view,
      CreateBufferView(shape, IREE_HAL_ELEMENT_TYPE_INT_32, actual_contents));

  StringBuilder description = StringBuilder::MakeSystem();
  bool matched = false;
  IREE_ASSERT_OK(iree_hal_buffer_view_match_array(
      kExactEquality, IREE_HAL_ELEMENT_TYPE_INT_32,
      IREE_ARRAYSIZE(expected_contents),
      iree_make_const_byte_span(expected_contents, sizeof(expected_contents)),
      actual_view, description, &matched));
  EXPECT_FALSE(matched);

  const std::string message = description.ToString();
  EXPECT_THAT(message, HasSubstr("count (3)"));
  EXPECT_THAT(message, HasSubstr("expected (2)"));
}
// Matching the i32 array {1, 1, 1} against a 1x3xi32 view with the same
// values must succeed and leave the description empty.
TEST_F(BufferViewMatchersTest, MatchArrayContentsI32) {
  const int32_t expected_contents[] = {1, 1, 1};
  const int32_t actual_contents[] = {1, 1, 1};
  const iree_hal_dim_t shape[] = {1, 3};
  IREE_ASSERT_OK_AND_ASSIGN(
      BufferView actual_view,
      CreateBufferView(shape, IREE_HAL_ELEMENT_TYPE_INT_32, actual_contents));

  StringBuilder description = StringBuilder::MakeSystem();
  bool matched = false;
  IREE_ASSERT_OK(iree_hal_buffer_view_match_array(
      kExactEquality, IREE_HAL_ELEMENT_TYPE_INT_32,
      IREE_ARRAYSIZE(expected_contents),
      iree_make_const_byte_span(expected_contents, sizeof(expected_contents)),
      actual_view, description, &matched));
  EXPECT_TRUE(matched);

  const std::string message = description.ToString();
  EXPECT_TRUE(message.empty());
}
// Matching the i32 array {1, 1, 1} against a 1x3xi32 view holding {1, 2, 3}
// must fail and name index 1 in the description.
TEST_F(BufferViewMatchersTest, MismatchArrayContentsI32) {
  const int32_t expected_contents[] = {1, 1, 1};
  const int32_t actual_contents[] = {1, 2, 3};
  const iree_hal_dim_t shape[] = {1, 3};
  IREE_ASSERT_OK_AND_ASSIGN(
      BufferView actual_view,
      CreateBufferView(shape, IREE_HAL_ELEMENT_TYPE_INT_32, actual_contents));

  StringBuilder description = StringBuilder::MakeSystem();
  bool matched = false;
  IREE_ASSERT_OK(iree_hal_buffer_view_match_array(
      kExactEquality, IREE_HAL_ELEMENT_TYPE_INT_32,
      IREE_ARRAYSIZE(expected_contents),
      iree_make_const_byte_span(expected_contents, sizeof(expected_contents)),
      actual_view, description, &matched));
  EXPECT_FALSE(matched);

  const std::string message = description.ToString();
  EXPECT_THAT(message, HasSubstr("element at index 1"));
}
//===----------------------------------------------------------------------===//
// iree_hal_buffer_view_matcher_t
//===----------------------------------------------------------------------===//
// Two f32 buffer views with the zero-element shape {0} must match exactly and
// leave the description empty.
TEST_F(BufferViewMatchersTest, MatchEmpty) {
  const float contents[1] = {0};
  iree_hal_dim_t shape[] = {0};
  IREE_ASSERT_OK_AND_ASSIGN(
      BufferView expected_view,
      CreateBufferView(shape, IREE_HAL_ELEMENT_TYPE_FLOAT_32, contents));
  IREE_ASSERT_OK_AND_ASSIGN(
      BufferView actual_view,
      CreateBufferView(shape, IREE_HAL_ELEMENT_TYPE_FLOAT_32, contents));

  StringBuilder description = StringBuilder::MakeSystem();
  bool matched = false;
  IREE_ASSERT_OK(iree_hal_buffer_view_match_equal(
      kExactEquality, expected_view, actual_view, description, &matched));
  EXPECT_TRUE(matched);

  const std::string message = description.ToString();
  EXPECT_TRUE(message.empty());
}
// A 1xf32 view and a 1x2xf32 view must not match, and the description must
// mention both shapes.
TEST_F(BufferViewMatchersTest, MatchShapesDiffer) {
  const float expected_contents[] = {1.0f};
  const float actual_contents[] = {1.0f, 2.0f};
  iree_hal_dim_t expected_shape[] = {1};
  iree_hal_dim_t actual_shape[] = {1, 2};
  IREE_ASSERT_OK_AND_ASSIGN(
      BufferView expected_view,
      CreateBufferView(expected_shape, IREE_HAL_ELEMENT_TYPE_FLOAT_32,
                       expected_contents));
  IREE_ASSERT_OK_AND_ASSIGN(
      BufferView actual_view,
      CreateBufferView(actual_shape, IREE_HAL_ELEMENT_TYPE_FLOAT_32,
                       actual_contents));

  StringBuilder description = StringBuilder::MakeSystem();
  bool matched = false;
  IREE_ASSERT_OK(iree_hal_buffer_view_match_equal(
      kExactEquality, expected_view, actual_view, description, &matched));
  EXPECT_FALSE(matched);

  const std::string message = description.ToString();
  EXPECT_THAT(message, HasSubstr("is 1x2xf32"));
  EXPECT_THAT(message, HasSubstr("matches 1xf32"));
}
// A 1xi32 view and a 1xf32 view must not match, and the description must
// mention both element types.
TEST_F(BufferViewMatchersTest, MatchElementTypesDiffer) {
  const float contents[] = {1.0f};
  iree_hal_dim_t shape[] = {1};
  IREE_ASSERT_OK_AND_ASSIGN(
      BufferView expected_view,
      CreateBufferView(shape, IREE_HAL_ELEMENT_TYPE_INT_32, contents));
  IREE_ASSERT_OK_AND_ASSIGN(
      BufferView actual_view,
      CreateBufferView(shape, IREE_HAL_ELEMENT_TYPE_FLOAT_32, contents));

  StringBuilder description = StringBuilder::MakeSystem();
  bool matched = false;
  IREE_ASSERT_OK(iree_hal_buffer_view_match_equal(
      kExactEquality, expected_view, actual_view, description, &matched));
  EXPECT_FALSE(matched);

  const std::string message = description.ToString();
  EXPECT_THAT(message, HasSubstr("is 1xf32"));
  EXPECT_THAT(message, HasSubstr("matches 1xi32"));
}
// Two 1xf16 views that both hold 2.0 must match exactly and leave the
// description empty.
TEST_F(BufferViewMatchersTest, MatchContentsF16) {
  const uint16_t expected_contents[] = {iree_math_f32_to_f16(2.0f)};
  const uint16_t actual_contents[] = {iree_math_f32_to_f16(2.0f)};
  iree_hal_dim_t shape[] = {1};
  IREE_ASSERT_OK_AND_ASSIGN(
      BufferView expected_view,
      CreateBufferView(shape, IREE_HAL_ELEMENT_TYPE_FLOAT_16,
                       expected_contents));
  IREE_ASSERT_OK_AND_ASSIGN(
      BufferView actual_view,
      CreateBufferView(shape, IREE_HAL_ELEMENT_TYPE_FLOAT_16,
                       actual_contents));

  StringBuilder description = StringBuilder::MakeSystem();
  bool matched = false;
  IREE_ASSERT_OK(iree_hal_buffer_view_match_equal(
      kExactEquality, expected_view, actual_view, description, &matched));
  EXPECT_TRUE(matched);

  const std::string message = description.ToString();
  EXPECT_TRUE(message.empty());
}
// 1xf16 views holding 1.0 and 2.0 must not match, and the description must
// name index 0.
TEST_F(BufferViewMatchersTest, MismatchContentsF16) {
  const uint16_t expected_contents[] = {iree_math_f32_to_f16(1.0f)};
  const uint16_t actual_contents[] = {iree_math_f32_to_f16(2.0f)};
  const iree_hal_dim_t shape[] = {1};
  IREE_ASSERT_OK_AND_ASSIGN(
      BufferView expected_view,
      CreateBufferView(shape, IREE_HAL_ELEMENT_TYPE_FLOAT_16,
                       expected_contents));
  IREE_ASSERT_OK_AND_ASSIGN(
      BufferView actual_view,
      CreateBufferView(shape, IREE_HAL_ELEMENT_TYPE_FLOAT_16,
                       actual_contents));

  StringBuilder description = StringBuilder::MakeSystem();
  bool matched = false;
  IREE_ASSERT_OK(iree_hal_buffer_view_match_equal(
      kExactEquality, expected_view, actual_view, description, &matched));
  EXPECT_FALSE(matched);

  const std::string message = description.ToString();
  EXPECT_THAT(message, HasSubstr("element at index 0"));
}
} // namespace
} // namespace iree