blob: b42e765eb6aa0eb675ab60d70d7991be389927bb [file] [log] [blame]
// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/tooling/buffer_view_matchers.h"
#include <math.h>
#include "iree/base/internal/math.h"
//===----------------------------------------------------------------------===//
// iree_hal_buffer_equality_t
//===----------------------------------------------------------------------===//
static iree_hal_buffer_element_t iree_hal_buffer_element_at(
iree_hal_element_type_t element_type, iree_const_byte_span_t elements,
iree_host_size_t index) {
iree_host_size_t element_size =
iree_hal_element_dense_byte_count(element_type);
iree_const_byte_span_t element_data = iree_make_const_byte_span(
elements.data + index * element_size, element_size);
iree_hal_buffer_element_t element = {
.type = element_type,
};
memcpy(element.storage, element_data.data, element_size);
return element;
}
static iree_status_t iree_hal_append_element_string(
iree_hal_buffer_element_t value, iree_string_builder_t* builder) {
char temp[64];
iree_host_size_t temp_length = 0;
IREE_RETURN_IF_ERROR(iree_hal_format_element(
iree_make_const_byte_span(value.storage,
iree_hal_element_dense_byte_count(value.type)),
value.type, sizeof(temp), temp, &temp_length));
return iree_string_builder_append_string(
builder, iree_make_string_view(temp, temp_length));
}
static iree_status_t iree_hal_append_element_mismatch_string(
iree_host_size_t index, iree_hal_buffer_element_t expected_element,
iree_hal_buffer_element_t actual_element, iree_string_builder_t* builder) {
IREE_RETURN_IF_ERROR(iree_string_builder_append_format(
builder, "element at index %" PRIhsz " (", index));
IREE_RETURN_IF_ERROR(iree_hal_append_element_string(actual_element, builder));
IREE_RETURN_IF_ERROR(iree_string_builder_append_string(
builder, IREE_SV(") does not match the expected (")));
IREE_RETURN_IF_ERROR(
iree_hal_append_element_string(expected_element, builder));
IREE_RETURN_IF_ERROR(
iree_string_builder_append_string(builder, IREE_SV(")")));
return iree_ok_status();
}
static bool iree_hal_compare_strided_elements_exact(
iree_hal_element_type_t element_type, iree_host_size_t element_count,
iree_const_byte_span_t expected_elements, iree_host_size_t expected_stride,
iree_const_byte_span_t actual_elements, iree_host_size_t actual_stride,
iree_host_size_t* out_index) {
const iree_host_size_t element_size =
iree_hal_element_dense_byte_count(element_type);
const uint8_t* expected_ptr = expected_elements.data;
const uint8_t* actual_ptr = actual_elements.data;
for (iree_host_size_t i = 0; i < element_count; ++i) {
int cmp = memcmp(expected_ptr, actual_ptr, element_size);
if (cmp != 0) {
*out_index = i;
return false;
}
expected_ptr += expected_stride * element_size;
actual_ptr += actual_stride * element_size;
}
return true;
}
static bool iree_hal_compare_strided_elements_approximate_absolute_f16(
iree_hal_buffer_equality_t equality, iree_host_size_t element_count,
const uint16_t* expected_ptr, iree_host_size_t expected_stride,
const uint16_t* actual_ptr, iree_host_size_t actual_stride,
iree_host_size_t* out_index) {
for (iree_host_size_t i = 0; i < element_count; ++i) {
if (fabsf(iree_math_f16_to_f32(*expected_ptr) -
iree_math_f16_to_f32(*actual_ptr)) > equality.f16_threshold) {
*out_index = i;
return false;
}
expected_ptr += expected_stride;
actual_ptr += actual_stride;
}
return true;
}
static bool iree_hal_compare_strided_elements_approximate_absolute_f32(
iree_hal_buffer_equality_t equality, iree_host_size_t element_count,
const float* expected_ptr, iree_host_size_t expected_stride,
const float* actual_ptr, iree_host_size_t actual_stride,
iree_host_size_t* out_index) {
for (iree_host_size_t i = 0; i < element_count; ++i) {
if (fabsf(*expected_ptr - *actual_ptr) > equality.f32_threshold) {
*out_index = i;
return false;
}
expected_ptr += expected_stride;
actual_ptr += actual_stride;
}
return true;
}
static bool iree_hal_compare_strided_elements_approximate_absolute_f64(
iree_hal_buffer_equality_t equality, iree_host_size_t element_count,
const double* expected_ptr, iree_host_size_t expected_stride,
const double* actual_ptr, iree_host_size_t actual_stride,
iree_host_size_t* out_index) {
for (iree_host_size_t i = 0; i < element_count; ++i) {
if (fabs(*expected_ptr - *actual_ptr) > equality.f64_threshold) {
*out_index = i;
return false;
}
expected_ptr += expected_stride;
actual_ptr += actual_stride;
}
return true;
}
static bool iree_hal_compare_strided_elements_approximate_absolute(
iree_hal_buffer_equality_t equality, iree_hal_element_type_t element_type,
iree_host_size_t element_count, iree_const_byte_span_t expected_elements,
iree_host_size_t expected_stride, iree_const_byte_span_t actual_elements,
iree_host_size_t actual_stride, iree_host_size_t* out_index) {
switch (element_type) {
case IREE_HAL_ELEMENT_TYPE_FLOAT_16:
return iree_hal_compare_strided_elements_approximate_absolute_f16(
equality, element_count, (const uint16_t*)expected_elements.data,
expected_stride, (const uint16_t*)actual_elements.data, actual_stride,
out_index);
case IREE_HAL_ELEMENT_TYPE_FLOAT_32:
return iree_hal_compare_strided_elements_approximate_absolute_f32(
equality, element_count, (const float*)expected_elements.data,
expected_stride, (const float*)actual_elements.data, actual_stride,
out_index);
case IREE_HAL_ELEMENT_TYPE_FLOAT_64:
return iree_hal_compare_strided_elements_approximate_absolute_f64(
equality, element_count, (const double*)expected_elements.data,
expected_stride, (const double*)actual_elements.data, actual_stride,
out_index);
default:
return iree_hal_compare_strided_elements_exact(
element_type, element_count, expected_elements, expected_stride,
actual_elements, actual_stride, out_index);
}
}
// Compares two buffers element by element.
// The provided strides (in elements) are applied up to |element_count|.
static bool iree_hal_compare_strided_elements(
iree_hal_buffer_equality_t equality, iree_hal_element_type_t element_type,
iree_host_size_t element_count, iree_const_byte_span_t expected_elements,
iree_host_size_t expected_stride, iree_const_byte_span_t actual_elements,
iree_host_size_t actual_stride, iree_host_size_t* out_index) {
switch (equality.mode) {
case IREE_HAL_BUFFER_EQUALITY_EXACT:
return iree_hal_compare_strided_elements_exact(
element_type, element_count, expected_elements, expected_stride,
actual_elements, actual_stride, out_index);
case IREE_HAL_BUFFER_EQUALITY_APPROXIMATE_ABSOLUTE:
return iree_hal_compare_strided_elements_approximate_absolute(
equality, element_type, element_count, expected_elements,
expected_stride, actual_elements, actual_stride, out_index);
default:
IREE_ASSERT(false && "unhandled equality mode");
return false;
}
}
bool iree_hal_compare_buffer_elements_broadcast(
iree_hal_buffer_equality_t equality,
iree_hal_buffer_element_t expected_element, iree_host_size_t element_count,
iree_const_byte_span_t actual_elements, iree_host_size_t* out_index) {
return iree_hal_compare_strided_elements(
equality, expected_element.type, element_count,
iree_make_const_byte_span(
expected_element.storage,
iree_hal_element_dense_byte_count(expected_element.type)),
/*expected_stride=*/0, actual_elements, /*actual_stride=*/1, out_index);
}
bool iree_hal_compare_buffer_elements_elementwise(
iree_hal_buffer_equality_t equality, iree_hal_element_type_t element_type,
iree_host_size_t element_count, iree_const_byte_span_t expected_elements,
iree_const_byte_span_t actual_elements, iree_host_size_t* out_index) {
return iree_hal_compare_strided_elements(
equality, element_type, element_count, expected_elements,
/*expected_stride=*/1, actual_elements, /*actual_stride=*/1, out_index);
}
//===----------------------------------------------------------------------===//
// iree_hal_buffer_view_metadata_matcher_t
//===----------------------------------------------------------------------===//
iree_status_t iree_hal_buffer_view_metadata_matcher_initialize(
iree_host_size_t shape_rank, const iree_hal_dim_t* shape,
iree_hal_element_type_t element_type,
iree_hal_encoding_type_t encoding_type,
iree_hal_buffer_view_metadata_matcher_t* out_matcher) {
IREE_ASSERT_ARGUMENT(!shape_rank || shape);
memset(out_matcher, 0, sizeof(*out_matcher));
if (shape_rank > IREE_ARRAYSIZE(out_matcher->shape)) {
return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
"maximum shape rank exceeded");
}
out_matcher->shape_rank = shape_rank;
memcpy(out_matcher->shape, shape, shape_rank * sizeof(*shape));
out_matcher->element_type = element_type;
out_matcher->encoding_type = encoding_type;
return iree_ok_status();
}
void iree_hal_buffer_view_metadata_matcher_deinitialize(
iree_hal_buffer_view_metadata_matcher_t* matcher) {
IREE_ASSERT_ARGUMENT(matcher);
memset(matcher, 0, sizeof(*matcher));
}
iree_status_t iree_hal_buffer_view_metadata_matcher_describe(
iree_hal_buffer_view_metadata_matcher_t* matcher,
iree_string_builder_t* builder) {
IREE_ASSERT_ARGUMENT(matcher);
IREE_ASSERT_ARGUMENT(builder);
IREE_RETURN_IF_ERROR(
iree_string_builder_append_string(builder, IREE_SV("matches ")));
IREE_RETURN_IF_ERROR(iree_hal_append_shape_and_element_type_string(
matcher->shape_rank, matcher->shape, matcher->element_type, builder));
return iree_ok_status();
}
static bool iree_hal_buffer_view_shape_matches(
iree_host_size_t shape_rank, const iree_hal_dim_t* shape,
iree_hal_buffer_view_t* matchee) {
if (shape_rank != iree_hal_buffer_view_shape_rank(matchee)) return false;
for (iree_host_size_t i = 0; i < shape_rank; ++i) {
if (shape[i] != iree_hal_buffer_view_shape_dim(matchee, i)) return false;
}
return true;
}
iree_status_t iree_hal_buffer_view_metadata_matcher_match(
iree_hal_buffer_view_metadata_matcher_t* matcher,
iree_hal_buffer_view_t* matchee, iree_string_builder_t* builder,
bool* out_matched) {
IREE_ASSERT_ARGUMENT(matcher);
IREE_ASSERT_ARGUMENT(matchee);
IREE_ASSERT_ARGUMENT(builder);
IREE_ASSERT_ARGUMENT(out_matched);
*out_matched = false;
const bool shape_match = iree_hal_buffer_view_shape_matches(
matcher->shape_rank, matcher->shape, matchee);
const bool element_type_match =
matcher->element_type == IREE_HAL_ELEMENT_TYPE_NONE ||
matcher->element_type == iree_hal_buffer_view_element_type(matchee);
const bool encoding_type_match =
matcher->encoding_type == IREE_HAL_ENCODING_TYPE_OPAQUE ||
matcher->encoding_type == iree_hal_buffer_view_encoding_type(matchee);
if (shape_match && element_type_match && encoding_type_match) {
*out_matched = true;
return iree_ok_status();
}
IREE_RETURN_IF_ERROR(
iree_string_builder_append_string(builder, IREE_SV("metadata is ")));
IREE_RETURN_IF_ERROR(iree_hal_append_shape_and_element_type_string(
iree_hal_buffer_view_shape_rank(matchee),
iree_hal_buffer_view_shape_dims(matchee),
iree_hal_buffer_view_element_type(matchee), builder));
*out_matched = false;
return iree_ok_status();
}
iree_status_t iree_hal_buffer_view_match_metadata(
iree_host_size_t shape_rank, const iree_hal_dim_t* shape,
iree_hal_element_type_t element_type,
iree_hal_encoding_type_t encoding_type, iree_hal_buffer_view_t* matchee,
iree_string_builder_t* builder, bool* out_matched) {
iree_hal_buffer_view_metadata_matcher_t matcher;
IREE_RETURN_IF_ERROR(iree_hal_buffer_view_metadata_matcher_initialize(
shape_rank, shape, element_type, encoding_type, &matcher));
iree_status_t status = iree_hal_buffer_view_metadata_matcher_match(
&matcher, matchee, builder, out_matched);
if (iree_status_is_ok(status) && !*out_matched) {
IREE_RETURN_IF_ERROR(iree_string_builder_append_string(
builder, IREE_SV("; expected that the view ")));
IREE_RETURN_IF_ERROR(
iree_hal_buffer_view_metadata_matcher_describe(&matcher, builder));
}
iree_hal_buffer_view_metadata_matcher_deinitialize(&matcher);
return status;
}
iree_status_t iree_hal_buffer_view_match_metadata_like(
iree_hal_buffer_view_t* expected, iree_hal_buffer_view_t* matchee,
iree_string_builder_t* builder, bool* out_matched) {
IREE_ASSERT_ARGUMENT(expected);
return iree_hal_buffer_view_match_metadata(
iree_hal_buffer_view_shape_rank(expected),
iree_hal_buffer_view_shape_dims(expected),
iree_hal_buffer_view_element_type(expected),
iree_hal_buffer_view_encoding_type(expected), matchee, builder,
out_matched);
}
//===----------------------------------------------------------------------===//
// iree_hal_buffer_view_element_matcher_t
//===----------------------------------------------------------------------===//
iree_status_t iree_hal_buffer_view_element_matcher_initialize(
iree_hal_buffer_equality_t equality, iree_hal_buffer_element_t value,
iree_hal_buffer_view_element_matcher_t* out_matcher) {
memset(out_matcher, 0, sizeof(*out_matcher));
out_matcher->equality = equality;
out_matcher->value = value;
return iree_ok_status();
}
void iree_hal_buffer_view_element_matcher_deinitialize(
iree_hal_buffer_view_element_matcher_t* matcher) {
IREE_ASSERT_ARGUMENT(matcher);
memset(matcher, 0, sizeof(*matcher));
}
iree_status_t iree_hal_buffer_view_element_matcher_describe(
iree_hal_buffer_view_element_matcher_t* matcher,
iree_string_builder_t* builder) {
IREE_ASSERT_ARGUMENT(matcher);
IREE_ASSERT_ARGUMENT(builder);
IREE_RETURN_IF_ERROR(iree_string_builder_append_string(
builder, IREE_SV("has all elements match ")));
IREE_RETURN_IF_ERROR(
iree_hal_append_element_type_string(matcher->value.type, builder));
IREE_RETURN_IF_ERROR(
iree_string_builder_append_string(builder, IREE_SV("=")));
IREE_RETURN_IF_ERROR(iree_hal_append_element_string(matcher->value, builder));
return iree_ok_status();
}
iree_status_t iree_hal_buffer_view_element_matcher_match(
iree_hal_buffer_view_element_matcher_t* matcher,
iree_hal_buffer_view_t* matchee, iree_string_builder_t* builder,
bool* out_matched) {
IREE_ASSERT_ARGUMENT(matcher);
IREE_ASSERT_ARGUMENT(matchee);
IREE_ASSERT_ARGUMENT(builder);
IREE_ASSERT_ARGUMENT(out_matched);
*out_matched = false;
if (iree_hal_buffer_view_encoding_type(matchee) !=
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR) {
return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"non-dense encodings not supported for matching");
} else if (iree_hal_buffer_view_element_type(matchee) !=
matcher->value.type) {
IREE_RETURN_IF_ERROR(iree_string_builder_append_string(
builder, IREE_SV("whose element type (")));
IREE_RETURN_IF_ERROR(iree_hal_append_element_type_string(
iree_hal_buffer_view_element_type(matchee), builder));
IREE_RETURN_IF_ERROR(iree_string_builder_append_string(
builder, IREE_SV(") does not match expected (")));
IREE_RETURN_IF_ERROR(
iree_hal_append_element_type_string(matcher->value.type, builder));
IREE_RETURN_IF_ERROR(
iree_string_builder_append_string(builder, IREE_SV(")")));
*out_matched = false;
return iree_ok_status();
}
iree_hal_buffer_mapping_t actual_mapping;
IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range(
iree_hal_buffer_view_buffer(matchee), IREE_HAL_MAPPING_MODE_SCOPED,
IREE_HAL_MEMORY_ACCESS_READ, 0, IREE_HAL_WHOLE_BUFFER, &actual_mapping));
iree_const_byte_span_t actual_contents = iree_make_const_byte_span(
actual_mapping.contents.data, actual_mapping.contents.data_length);
iree_host_size_t i = 0;
const bool all_match = iree_hal_compare_buffer_elements_broadcast(
matcher->equality, matcher->value,
iree_hal_buffer_view_element_count(matchee), actual_contents, &i);
iree_hal_buffer_element_t actual_element = iree_hal_buffer_element_at(
iree_hal_buffer_view_element_type(matchee), actual_contents, i);
IREE_RETURN_IF_ERROR(iree_hal_buffer_unmap_range(&actual_mapping));
if (!all_match) {
IREE_RETURN_IF_ERROR(iree_hal_append_element_mismatch_string(
i, matcher->value, actual_element, builder));
}
*out_matched = all_match;
return iree_ok_status();
}
iree_status_t iree_hal_buffer_view_match_elements(
iree_hal_buffer_equality_t equality, iree_hal_buffer_element_t value,
iree_hal_buffer_view_t* matchee, iree_string_builder_t* builder,
bool* out_matched) {
iree_hal_buffer_view_element_matcher_t matcher;
IREE_RETURN_IF_ERROR(iree_hal_buffer_view_element_matcher_initialize(
equality, value, &matcher));
iree_status_t status = iree_hal_buffer_view_element_matcher_match(
&matcher, matchee, builder, out_matched);
if (iree_status_is_ok(status) && !*out_matched) {
IREE_RETURN_IF_ERROR(iree_string_builder_append_string(
builder, IREE_SV("; expected that the view ")));
IREE_RETURN_IF_ERROR(
iree_hal_buffer_view_element_matcher_describe(&matcher, builder));
}
iree_hal_buffer_view_element_matcher_deinitialize(&matcher);
return status;
}
//===----------------------------------------------------------------------===//
// iree_hal_buffer_view_array_matcher_t
//===----------------------------------------------------------------------===//
iree_status_t iree_hal_buffer_view_array_matcher_initialize(
iree_hal_buffer_equality_t equality, iree_hal_element_type_t element_type,
iree_host_size_t element_count, iree_const_byte_span_t elements,
iree_hal_buffer_view_array_matcher_t* out_matcher) {
IREE_ASSERT_ARGUMENT(!element_count ||
!iree_const_byte_span_is_empty(elements));
memset(out_matcher, 0, sizeof(*out_matcher));
out_matcher->equality = equality;
out_matcher->element_type = element_type;
out_matcher->element_count = element_count;
out_matcher->elements = elements;
return iree_ok_status();
}
void iree_hal_buffer_view_array_matcher_deinitialize(
iree_hal_buffer_view_array_matcher_t* matcher) {
IREE_ASSERT_ARGUMENT(matcher);
memset(matcher, 0, sizeof(*matcher));
}
iree_status_t iree_hal_buffer_view_array_matcher_describe(
iree_hal_buffer_view_array_matcher_t* matcher,
iree_string_builder_t* builder) {
IREE_ASSERT_ARGUMENT(matcher);
IREE_ASSERT_ARGUMENT(builder);
IREE_RETURN_IF_ERROR(iree_string_builder_append_format(
builder,
"has all elements match those in %" PRIhsz " element <array> of ",
matcher->element_count));
IREE_RETURN_IF_ERROR(
iree_hal_append_element_type_string(matcher->element_type, builder));
// TODO(benvanik): format array contents (elided)? make caller do?
return iree_ok_status();
}
iree_status_t iree_hal_buffer_view_array_matcher_match(
iree_hal_buffer_view_array_matcher_t* matcher,
iree_hal_buffer_view_t* matchee, iree_string_builder_t* builder,
bool* out_matched) {
IREE_ASSERT_ARGUMENT(matcher);
IREE_ASSERT_ARGUMENT(matchee);
IREE_ASSERT_ARGUMENT(builder);
IREE_ASSERT_ARGUMENT(out_matched);
*out_matched = false;
if (iree_hal_buffer_view_encoding_type(matchee) !=
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR) {
return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"non-dense encodings not supported for matching");
} else if (iree_hal_buffer_view_element_type(matchee) !=
matcher->element_type) {
IREE_RETURN_IF_ERROR(iree_string_builder_append_string(
builder, IREE_SV("whose element type (")));
IREE_RETURN_IF_ERROR(iree_hal_append_element_type_string(
iree_hal_buffer_view_element_type(matchee), builder));
IREE_RETURN_IF_ERROR(iree_string_builder_append_string(
builder, IREE_SV(") does not match expected (")));
IREE_RETURN_IF_ERROR(
iree_hal_append_element_type_string(matcher->element_type, builder));
IREE_RETURN_IF_ERROR(
iree_string_builder_append_string(builder, IREE_SV(")")));
*out_matched = false;
return iree_ok_status();
} else if (iree_hal_buffer_view_element_count(matchee) !=
matcher->element_count) {
IREE_RETURN_IF_ERROR(iree_string_builder_append_format(
builder,
"whose element count (%" PRIhsz ") does not match expected (%" PRIhsz
")",
iree_hal_buffer_view_element_count(matchee), matcher->element_count));
*out_matched = false;
return iree_ok_status();
}
iree_hal_buffer_mapping_t actual_mapping;
IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range(
iree_hal_buffer_view_buffer(matchee), IREE_HAL_MAPPING_MODE_SCOPED,
IREE_HAL_MEMORY_ACCESS_READ, 0, IREE_HAL_WHOLE_BUFFER, &actual_mapping));
iree_const_byte_span_t actual_contents = iree_make_const_byte_span(
actual_mapping.contents.data, actual_mapping.contents.data_length);
iree_host_size_t i = 0;
const bool all_match = iree_hal_compare_buffer_elements_elementwise(
matcher->equality, iree_hal_buffer_view_element_type(matchee),
iree_hal_buffer_view_element_count(matchee), matcher->elements,
actual_contents, &i);
iree_hal_buffer_element_t actual_element = iree_hal_buffer_element_at(
iree_hal_buffer_view_element_type(matchee), actual_contents, i);
iree_hal_buffer_element_t expected_element = iree_hal_buffer_element_at(
iree_hal_buffer_view_element_type(matchee), matcher->elements, i);
IREE_RETURN_IF_ERROR(iree_hal_buffer_unmap_range(&actual_mapping));
if (!all_match) {
IREE_RETURN_IF_ERROR(iree_hal_append_element_mismatch_string(
i, expected_element, actual_element, builder));
}
*out_matched = all_match;
return iree_ok_status();
}
iree_status_t iree_hal_buffer_view_match_array(
iree_hal_buffer_equality_t equality, iree_hal_element_type_t element_type,
iree_host_size_t element_count, iree_const_byte_span_t elements,
iree_hal_buffer_view_t* matchee, iree_string_builder_t* builder,
bool* out_matched) {
iree_hal_buffer_view_array_matcher_t matcher;
IREE_RETURN_IF_ERROR(iree_hal_buffer_view_array_matcher_initialize(
equality, element_type, element_count, elements, &matcher));
iree_status_t status = iree_hal_buffer_view_array_matcher_match(
&matcher, matchee, builder, out_matched);
if (iree_status_is_ok(status) && !*out_matched) {
IREE_RETURN_IF_ERROR(iree_string_builder_append_string(
builder, IREE_SV("; expected that the view ")));
IREE_RETURN_IF_ERROR(
iree_hal_buffer_view_array_matcher_describe(&matcher, builder));
}
iree_hal_buffer_view_array_matcher_deinitialize(&matcher);
return status;
}
//===----------------------------------------------------------------------===//
// iree_hal_buffer_view_matcher_t
//===----------------------------------------------------------------------===//
iree_status_t iree_hal_buffer_view_matcher_initialize(
iree_hal_buffer_equality_t equality, iree_hal_buffer_view_t* expected,
iree_hal_buffer_view_matcher_t* out_matcher) {
IREE_ASSERT_ARGUMENT(expected);
memset(out_matcher, 0, sizeof(*out_matcher));
out_matcher->equality = equality;
out_matcher->expected = expected;
iree_hal_buffer_view_retain(expected);
return iree_ok_status();
}
void iree_hal_buffer_view_matcher_deinitialize(
iree_hal_buffer_view_matcher_t* matcher) {
IREE_ASSERT_ARGUMENT(matcher);
iree_hal_buffer_view_release(matcher->expected);
memset(matcher, 0, sizeof(*matcher));
}
iree_status_t iree_hal_buffer_view_matcher_describe(
iree_hal_buffer_view_matcher_t* matcher, iree_string_builder_t* builder) {
IREE_ASSERT_ARGUMENT(matcher);
IREE_ASSERT_ARGUMENT(builder);
IREE_RETURN_IF_ERROR(iree_string_builder_append_string(
builder, IREE_SV("is equal to contents of a view of ")));
IREE_RETURN_IF_ERROR(iree_hal_append_shape_and_element_type_string(
iree_hal_buffer_view_shape_rank(matcher->expected),
iree_hal_buffer_view_shape_dims(matcher->expected),
iree_hal_buffer_view_element_type(matcher->expected), builder));
// TODO(benvanik): format buffer view contents (elided)? make caller do?
return iree_ok_status();
}
iree_status_t iree_hal_buffer_view_matcher_match(
iree_hal_buffer_view_matcher_t* matcher, iree_hal_buffer_view_t* matchee,
iree_string_builder_t* builder, bool* out_matched) {
IREE_ASSERT_ARGUMENT(matcher);
IREE_ASSERT_ARGUMENT(matchee);
IREE_ASSERT_ARGUMENT(builder);
IREE_ASSERT_ARGUMENT(out_matched);
*out_matched = false;
if (iree_hal_buffer_view_encoding_type(matchee) !=
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR) {
return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"non-dense encodings not supported for matching");
}
// Reuse metadata matching to ensure the buffer views are the same shape/type.
IREE_RETURN_IF_ERROR(iree_hal_buffer_view_match_metadata_like(
matcher->expected, matchee, builder, out_matched));
if (!*out_matched) return iree_ok_status();
iree_hal_buffer_mapping_t actual_mapping;
IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range(
iree_hal_buffer_view_buffer(matchee), IREE_HAL_MAPPING_MODE_SCOPED,
IREE_HAL_MEMORY_ACCESS_READ, 0, IREE_HAL_WHOLE_BUFFER, &actual_mapping));
iree_hal_buffer_mapping_t expected_mapping;
iree_status_t status = iree_hal_buffer_map_range(
iree_hal_buffer_view_buffer(matcher->expected),
IREE_HAL_MAPPING_MODE_SCOPED, IREE_HAL_MEMORY_ACCESS_READ, 0,
IREE_HAL_WHOLE_BUFFER, &expected_mapping);
if (!iree_status_is_ok(status)) {
iree_hal_buffer_unmap_range(&actual_mapping);
return status;
}
iree_const_byte_span_t actual_contents = iree_make_const_byte_span(
actual_mapping.contents.data, actual_mapping.contents.data_length);
iree_const_byte_span_t expected_contents = iree_make_const_byte_span(
expected_mapping.contents.data, expected_mapping.contents.data_length);
iree_host_size_t i = 0;
const bool all_match = iree_hal_compare_buffer_elements_elementwise(
matcher->equality, iree_hal_buffer_view_element_type(matchee),
iree_hal_buffer_view_element_count(matchee), expected_contents,
actual_contents, &i);
iree_hal_buffer_element_t actual_element = iree_hal_buffer_element_at(
iree_hal_buffer_view_element_type(matchee), actual_contents, i);
iree_hal_buffer_element_t expected_element = iree_hal_buffer_element_at(
iree_hal_buffer_view_element_type(matchee), expected_contents, i);
IREE_RETURN_IF_ERROR(iree_hal_buffer_unmap_range(&actual_mapping));
IREE_RETURN_IF_ERROR(iree_hal_buffer_unmap_range(&expected_mapping));
if (!all_match) {
IREE_RETURN_IF_ERROR(iree_hal_append_element_mismatch_string(
i, expected_element, actual_element, builder));
}
*out_matched = all_match;
return iree_ok_status();
}
iree_status_t iree_hal_buffer_view_match_equal(
iree_hal_buffer_equality_t equality, iree_hal_buffer_view_t* expected,
iree_hal_buffer_view_t* matchee, iree_string_builder_t* builder,
bool* out_matched) {
iree_hal_buffer_view_matcher_t matcher;
IREE_RETURN_IF_ERROR(
iree_hal_buffer_view_matcher_initialize(equality, expected, &matcher));
iree_status_t status = iree_hal_buffer_view_matcher_match(
&matcher, matchee, builder, out_matched);
if (iree_status_is_ok(status) && !*out_matched) {
IREE_RETURN_IF_ERROR(iree_string_builder_append_string(
builder, IREE_SV("; expected that the view ")));
IREE_RETURN_IF_ERROR(
iree_hal_buffer_view_matcher_describe(&matcher, builder));
}
iree_hal_buffer_view_matcher_deinitialize(&matcher);
return status;
}