blob: cd8275b95b90e9b77365d9da103d6ab6e9374c6a [file] [log] [blame]
// Copyright 2019 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/hal/string_util.h"
#include <ctype.h>
#include <inttypes.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include "iree/base/api.h"
#include "iree/base/internal/math.h"
#include "iree/hal/buffer_view.h"
// Parses a shape string of the form `AxBxC...` (e.g. "1x2x3") into |out_shape|.
//
// An empty |value| is a rank-0 shape. |out_shape_rank| receives the number of
// dimensions in |value| even when it exceeds |shape_capacity|; in that case
// IREE_STATUS_OUT_OF_RANGE is returned so callers can query the required
// capacity. Each dimension must parse as a non-negative integer that fits in
// iree_hal_dim_t, otherwise IREE_STATUS_INVALID_ARGUMENT is returned.
IREE_API_EXPORT iree_status_t iree_hal_parse_shape(
    iree_string_view_t value, iree_host_size_t shape_capacity,
    iree_host_size_t* out_shape_rank, iree_hal_dim_t* out_shape) {
  IREE_ASSERT_ARGUMENT(out_shape_rank);
  *out_shape_rank = 0;
  if (iree_string_view_is_empty(value)) {
    return iree_ok_status();  // empty shape
  }

  // Count the number of dimensions to see if we have capacity.
  iree_host_size_t shape_rank = 1;  // always at least one if we are not empty
  for (iree_host_size_t i = 0; i < value.size; ++i) {
    if (value.data[i] == 'x') ++shape_rank;
  }
  *out_shape_rank = shape_rank;
  if (shape_rank > shape_capacity) {
    // NOTE: fast return for capacity queries.
    return iree_status_from_code(IREE_STATUS_OUT_OF_RANGE);
  }

  iree_host_size_t dim_index = 0;
  iree_string_view_t lhs;
  iree_string_view_t rhs = value;
  while (iree_string_view_split(rhs, 'x', &lhs, &rhs) &&
         !iree_string_view_is_empty(lhs)) {
    iree_hal_dim_t dim_value = 0;
    // sizeof yields bytes, not bits: comparing against 32 would never match
    // and 32-bit dims would be parsed as int64 and silently truncated.
    if (sizeof(iree_hal_dim_t) == sizeof(int32_t)) {
      int32_t parsed_value = 0;
      if (!iree_string_view_atoi_int32(lhs, &parsed_value) ||
          parsed_value < 0) {
        return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                                "shape[%zu] invalid value '%.*s' of '%.*s'",
                                dim_index, (int)lhs.size, lhs.data,
                                (int)value.size, value.data);
      }
      dim_value = parsed_value;
    } else {
      int64_t parsed_value = 0;
      if (!iree_string_view_atoi_int64(lhs, &parsed_value) ||
          parsed_value < 0) {
        return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                                "shape[%zu] invalid value '%.*s' of '%.*s'",
                                dim_index, (int)lhs.size, lhs.data,
                                (int)value.size, value.data);
      }
      dim_value = parsed_value;
    }
    out_shape[dim_index++] = dim_value;
  }

  // An empty component (e.g. "1xx2" or a trailing 'x') stops the loop early.
  if (dim_index != shape_rank) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "invalid shape specification: '%.*s'",
                            (int)value.size, value.data);
  }
  return iree_ok_status();
}
// Formats |shape| as `AxBxC...` into |buffer|.
//
// |out_buffer_length| receives the number of characters required (excluding
// the NUL terminator) even when the output does not fit. In that case, or when
// |buffer| is NULL (a size query), IREE_STATUS_OUT_OF_RANGE is returned.
IREE_API_EXPORT iree_status_t
iree_hal_format_shape(iree_host_size_t shape_rank, const iree_hal_dim_t* shape,
                      iree_host_size_t buffer_capacity, char* buffer,
                      iree_host_size_t* out_buffer_length) {
  if (out_buffer_length) {
    *out_buffer_length = 0;
  }
  // Leave a valid (empty) string behind even for rank-0 shapes, where the
  // loop below never writes to |buffer|.
  if (buffer && buffer_capacity > 0) {
    buffer[0] = '\0';
  }
  iree_host_size_t buffer_length = 0;
  for (iree_host_size_t i = 0; i < shape_rank; ++i) {
    int n =
        snprintf(buffer ? buffer + buffer_length : NULL,
                 buffer ? buffer_capacity - buffer_length : 0,
                 (i < shape_rank - 1) ? "%" PRIdim "x" : "%" PRIdim, shape[i]);
    if (IREE_UNLIKELY(n < 0)) {
      return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
                              "snprintf failed to write dimension %zu", i);
    } else if (buffer &&
               (iree_host_size_t)n >= buffer_capacity - buffer_length) {
      // Truncated: stop writing but keep counting the required length.
      buffer = NULL;
    }
    buffer_length += (iree_host_size_t)n;
  }
  if (out_buffer_length) {
    *out_buffer_length = buffer_length;
  }
  return buffer ? iree_ok_status()
                : iree_status_from_code(IREE_STATUS_OUT_OF_RANGE);
}
// Parses an element type string such as "i32", "si8", "ui16", "f32", "bf16"
// or "*8" into an iree_hal_element_type_t.
//
// The leading prefix selects the numerical type. "x" and "*" both select
// IREE_HAL_NUMERICAL_TYPE_UNKNOWN. The remainder must be an unsigned bit count
// no larger than 255. |out_element_type| is reset to
// IREE_HAL_ELEMENT_TYPE_NONE up front and is only overwritten on success.
IREE_API_EXPORT iree_status_t iree_hal_parse_element_type(
    iree_string_view_t value, iree_hal_element_type_t* out_element_type) {
  IREE_ASSERT_ARGUMENT(out_element_type);
  *out_element_type = IREE_HAL_ELEMENT_TYPE_NONE;

  // Strip the numerical-type prefix; |bits_str| is left holding the bit count.
  iree_string_view_t bits_str = value;
  iree_hal_numerical_type_t numerical_type = IREE_HAL_NUMERICAL_TYPE_UNKNOWN;
  bool recognized_prefix = true;
  if (iree_string_view_consume_prefix(&bits_str, IREE_SV("i"))) {
    numerical_type = IREE_HAL_NUMERICAL_TYPE_INTEGER;
  } else if (iree_string_view_consume_prefix(&bits_str, IREE_SV("si"))) {
    numerical_type = IREE_HAL_NUMERICAL_TYPE_INTEGER_SIGNED;
  } else if (iree_string_view_consume_prefix(&bits_str, IREE_SV("ui"))) {
    numerical_type = IREE_HAL_NUMERICAL_TYPE_INTEGER_UNSIGNED;
  } else if (iree_string_view_consume_prefix(&bits_str, IREE_SV("f"))) {
    numerical_type = IREE_HAL_NUMERICAL_TYPE_FLOAT_IEEE;
  } else if (iree_string_view_consume_prefix(&bits_str, IREE_SV("bf"))) {
    numerical_type = IREE_HAL_NUMERICAL_TYPE_FLOAT_BRAIN;
  } else if (!iree_string_view_consume_prefix(&bits_str, IREE_SV("x")) &&
             !iree_string_view_consume_prefix(&bits_str, IREE_SV("*"))) {
    recognized_prefix = false;
  }
  // The wildcard prefixes ("x"/"*") leave numerical_type as UNKNOWN.
  if (!recognized_prefix) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "unhandled element type prefix in '%.*s'",
                            (int)value.size, value.data);
  }

  uint32_t bit_count = 0;
  const bool bit_count_ok =
      iree_string_view_atoi_uint32(bits_str, &bit_count) && bit_count <= 0xFFu;
  if (!bit_count_ok) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "out of range bit count in '%.*s'",
                            (int)value.size, value.data);
  }
  *out_element_type = iree_hal_make_element_type(numerical_type, bit_count);
  return iree_ok_status();
}
// Formats |element_type| as a string such as "f32", "si8" or "bf16". Unknown
// numerical types use the "*" prefix, e.g. "*16".
//
// |out_buffer_length| receives the number of characters required (excluding
// the NUL terminator). IREE_STATUS_OUT_OF_RANGE is returned when |buffer| is
// NULL (a size query) or too small to hold the full string.
IREE_API_EXPORT iree_status_t iree_hal_format_element_type(
    iree_hal_element_type_t element_type, iree_host_size_t buffer_capacity,
    char* buffer, iree_host_size_t* out_buffer_length) {
  if (out_buffer_length) {
    *out_buffer_length = 0;
  }
  const char* prefix;
  switch (iree_hal_element_numerical_type(element_type)) {
    case IREE_HAL_NUMERICAL_TYPE_INTEGER:
      prefix = "i";
      break;
    case IREE_HAL_NUMERICAL_TYPE_INTEGER_SIGNED:
      prefix = "si";
      break;
    case IREE_HAL_NUMERICAL_TYPE_INTEGER_UNSIGNED:
      prefix = "ui";
      break;
    case IREE_HAL_NUMERICAL_TYPE_FLOAT_IEEE:
      prefix = "f";
      break;
    case IREE_HAL_NUMERICAL_TYPE_FLOAT_BRAIN:
      prefix = "bf";
      break;
    default:
      prefix = "*";
      break;
  }
  // snprintf with a NULL buffer is only defined for a zero size, so a NULL
  // |buffer| is treated as a pure size query (as the other formatters do).
  int n = snprintf(buffer, buffer ? buffer_capacity : 0, "%s%d", prefix,
                   (int)iree_hal_element_bit_count(element_type));
  if (n < 0) {
    return iree_make_status(IREE_STATUS_FAILED_PRECONDITION, "snprintf failed");
  }
  if (out_buffer_length) {
    *out_buffer_length = (iree_host_size_t)n;
  }
  return (!buffer || (iree_host_size_t)n >= buffer_capacity)
             ? iree_status_from_code(IREE_STATUS_OUT_OF_RANGE)
             : iree_ok_status();
}
// Parses a buffer view description such as "2x3xf32" or "2x3xf32=1 2 3 4 5 6"
// into its shape and element type. Any "=..." data suffix is ignored here.
//
// Surrounding whitespace and a leading and/or trailing double quote are
// stripped first. A description without an 'x' (e.g. "i32") is a scalar with
// rank 0. Shape capacity semantics match iree_hal_parse_shape.
IREE_API_EXPORT iree_status_t iree_hal_parse_shape_and_element_type(
    iree_string_view_t value, iree_host_size_t shape_capacity,
    iree_host_size_t* out_shape_rank, iree_hal_dim_t* out_shape,
    iree_hal_element_type_t* out_element_type) {
  IREE_ASSERT_ARGUMENT(out_shape_rank);
  IREE_ASSERT_ARGUMENT(out_element_type);
  *out_shape_rank = 0;
  *out_element_type = IREE_HAL_ELEMENT_TYPE_NONE;

  // Strip whitespace that may come along (linefeeds/etc).
  value = iree_string_view_trim(value);
  value = iree_string_view_strip_prefix(value, IREE_SV("\""));
  value = iree_string_view_strip_suffix(value, IREE_SV("\""));
  if (iree_string_view_is_empty(value)) {
    // Empty lines are invalid; need at least the shape/type information.
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "empty string input");
  }

  // The part of the string corresponding to the shape, e.g. 1x2x3.
  iree_string_view_t shape_str = iree_string_view_empty();
  // The part of the string corresponding to the type, e.g. f32
  iree_string_view_t type_str = iree_string_view_empty();
  // The part of the string corresponding to the buffer data, e.g. 1 2 3 4 5 6
  // We ignore this.
  iree_string_view_t data_str = iree_string_view_empty();
  iree_string_view_t shape_and_type_str = value;
  iree_string_view_split(value, '=', &shape_and_type_str, &data_str);
  // Tolerate whitespace before '=' (e.g. "2x2xi32 = 1 2 3 4") or just inside
  // the quotes; otherwise it would end up in the shape or type string.
  shape_and_type_str = iree_string_view_trim(shape_and_type_str);

  // The element type follows the last 'x'; everything before it is the shape.
  iree_host_size_t last_x_index = iree_string_view_find_last_of(
      shape_and_type_str, IREE_SV("x"), IREE_STRING_VIEW_NPOS);
  if (last_x_index == IREE_STRING_VIEW_NPOS) {
    // Scalar.
    type_str = shape_and_type_str;
  } else {
    // Has a shape.
    shape_str = iree_string_view_substr(shape_and_type_str, 0, last_x_index);
    type_str = iree_string_view_substr(shape_and_type_str, last_x_index + 1,
                                       IREE_STRING_VIEW_NPOS);
  }

  // AxBxC...
  IREE_RETURN_IF_ERROR(iree_hal_parse_shape(shape_str, shape_capacity,
                                            out_shape_rank, out_shape));
  // f32, i32, etc
  IREE_RETURN_IF_ERROR(iree_hal_parse_element_type(type_str, out_element_type));
  return iree_ok_status();
}
// Decodes |num| bytes from the 2*|num| hex characters at |from| into |to|.
// Each character pair is one byte, most-significant nibble first. Upper- and
// lower-case digits are both accepted. Characters that are not hex digits
// decode as 0; callers are responsible for validating the input.
static void iree_hal_hex_string_to_bytes(const char* from, uint8_t* to,
                                         ptrdiff_t num) {
  // Every entry not named below is zero-initialized.
  static const uint8_t kHexValue[256] = {
      ['0'] = 0,  ['1'] = 1,  ['2'] = 2,  ['3'] = 3,  ['4'] = 4,
      ['5'] = 5,  ['6'] = 6,  ['7'] = 7,  ['8'] = 8,  ['9'] = 9,
      ['A'] = 10, ['B'] = 11, ['C'] = 12, ['D'] = 13, ['E'] = 14, ['F'] = 15,
      ['a'] = 10, ['b'] = 11, ['c'] = 12, ['d'] = 13, ['e'] = 14, ['f'] = 15,
  };
  // Index with ptrdiff_t to match |num|; an int index could overflow. Look up
  // through unsigned char so high-bit characters never form a negative index.
  for (ptrdiff_t i = 0; i < num; ++i) {
    const uint8_t hi = kHexValue[(unsigned char)from[2 * i]];
    const uint8_t lo = kHexValue[(unsigned char)from[2 * i + 1]];
    to[i] = (uint8_t)((hi << 4) | lo);
  }
}
// Parses a single element string. The caller must already have validated
// that |out_data| has room for the parsed element, i.e. at least
// iree_hal_element_dense_byte_count(element_type) bytes.
//
// Integer and floating-point element types are parsed from text. The result
// is IREE_STATUS_INVALID_ARGUMENT if the text does not parse or, for 8- and
// 16-bit integers, does not fit the destination type. Any other element type
// is decoded as raw bytes from exactly two hex characters per byte. Values
// are stored with memcpy, so |out_data| needs no particular alignment.
static iree_status_t iree_hal_parse_element_unsafe(
    iree_string_view_t data_str, iree_hal_element_type_t element_type,
    uint8_t* out_data) {
  switch (element_type) {
    case IREE_HAL_ELEMENT_TYPE_INT_8:
    case IREE_HAL_ELEMENT_TYPE_SINT_8: {
      int32_t temp = 0;
      // Check both bounds; negative values such as -200 must not wrap.
      if (!iree_string_view_atoi_int32(data_str, &temp) || temp < INT8_MIN ||
          temp > INT8_MAX) {
        return iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT);
      }
      const int8_t value = (int8_t)temp;
      memcpy(out_data, &value, sizeof(value));
      return iree_ok_status();
    }
    case IREE_HAL_ELEMENT_TYPE_UINT_8: {
      uint32_t temp = 0;
      if (!iree_string_view_atoi_uint32(data_str, &temp) || temp > UINT8_MAX) {
        return iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT);
      }
      const uint8_t value = (uint8_t)temp;
      memcpy(out_data, &value, sizeof(value));
      return iree_ok_status();
    }
    case IREE_HAL_ELEMENT_TYPE_INT_16:
    case IREE_HAL_ELEMENT_TYPE_SINT_16: {
      int32_t temp = 0;
      if (!iree_string_view_atoi_int32(data_str, &temp) || temp < INT16_MIN ||
          temp > INT16_MAX) {
        return iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT);
      }
      const int16_t value = (int16_t)temp;
      memcpy(out_data, &value, sizeof(value));
      return iree_ok_status();
    }
    case IREE_HAL_ELEMENT_TYPE_UINT_16: {
      uint32_t temp = 0;
      if (!iree_string_view_atoi_uint32(data_str, &temp) || temp > UINT16_MAX) {
        return iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT);
      }
      const uint16_t value = (uint16_t)temp;
      memcpy(out_data, &value, sizeof(value));
      return iree_ok_status();
    }
    case IREE_HAL_ELEMENT_TYPE_INT_32:
    case IREE_HAL_ELEMENT_TYPE_SINT_32: {
      int32_t value = 0;
      if (!iree_string_view_atoi_int32(data_str, &value)) {
        return iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT);
      }
      memcpy(out_data, &value, sizeof(value));
      return iree_ok_status();
    }
    case IREE_HAL_ELEMENT_TYPE_UINT_32: {
      uint32_t value = 0;
      if (!iree_string_view_atoi_uint32(data_str, &value)) {
        return iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT);
      }
      memcpy(out_data, &value, sizeof(value));
      return iree_ok_status();
    }
    case IREE_HAL_ELEMENT_TYPE_INT_64:
    case IREE_HAL_ELEMENT_TYPE_SINT_64: {
      int64_t value = 0;
      if (!iree_string_view_atoi_int64(data_str, &value)) {
        return iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT);
      }
      memcpy(out_data, &value, sizeof(value));
      return iree_ok_status();
    }
    case IREE_HAL_ELEMENT_TYPE_UINT_64: {
      uint64_t value = 0;
      if (!iree_string_view_atoi_uint64(data_str, &value)) {
        return iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT);
      }
      memcpy(out_data, &value, sizeof(value));
      return iree_ok_status();
    }
    case IREE_HAL_ELEMENT_TYPE_FLOAT_16: {
      // Parsed as f32, then converted to half-precision bits.
      float temp = 0;
      if (!iree_string_view_atof(data_str, &temp)) {
        return iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT);
      }
      const uint16_t value = iree_math_f32_to_f16(temp);
      memcpy(out_data, &value, sizeof(value));
      return iree_ok_status();
    }
    case IREE_HAL_ELEMENT_TYPE_FLOAT_32: {
      float value = 0;
      if (!iree_string_view_atof(data_str, &value)) {
        return iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT);
      }
      memcpy(out_data, &value, sizeof(value));
      return iree_ok_status();
    }
    case IREE_HAL_ELEMENT_TYPE_FLOAT_64: {
      double value = 0;
      if (!iree_string_view_atod(data_str, &value)) {
        return iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT);
      }
      memcpy(out_data, &value, sizeof(value));
      return iree_ok_status();
    }
    default: {
      // Treat any unknown format as binary: two hex characters per byte.
      iree_host_size_t element_size =
          iree_hal_element_dense_byte_count(element_type);
      if (data_str.size != element_size * 2) {
        return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                                "binary hex element count mismatch: buffer "
                                "length=%zu != expected=%zu",
                                data_str.size, element_size * 2);
      }
      iree_hal_hex_string_to_bytes(data_str.data, out_data,
                                   (ptrdiff_t)element_size);
      return iree_ok_status();
    }
  }
}
// Parses the single element in |data_str| as |element_type| into |data_ptr|.
// Returns IREE_STATUS_INVALID_ARGUMENT without writing anything if |data_ptr|
// is smaller than one dense element of |element_type|.
IREE_API_EXPORT iree_status_t iree_hal_parse_element(
    iree_string_view_t data_str, iree_hal_element_type_t element_type,
    iree_byte_span_t data_ptr) {
  const iree_host_size_t required_size =
      iree_hal_element_dense_byte_count(element_type);
  const bool fits = data_ptr.data_length >= required_size;
  if (fits) {
    // Capacity verified above, so the unchecked parser is safe to call.
    return iree_hal_parse_element_unsafe(data_str, element_type, data_ptr.data);
  }
  return iree_make_status(
      IREE_STATUS_INVALID_ARGUMENT,
      "output data buffer overflow: data_length=%zu < element_size=%zu",
      data_ptr.data_length, required_size);
}
// Encodes |num| bytes from |src| as 2*|num| upper-case hex characters at
// |dest|, most-significant nibble first. No NUL terminator is written.
static void iree_hal_bytes_to_hex_string(const uint8_t* src, char* dest,
                                         ptrdiff_t num) {
  static const char kHexDigits[] = "0123456789ABCDEF";
  for (ptrdiff_t i = 0; i < num; ++i) {
    const uint8_t byte = src[i];
    dest[2 * i] = kHexDigits[byte >> 4];
    dest[2 * i + 1] = kHexDigits[byte & 0x0F];
  }
}
// Formats the single element at the start of |data| as |element_type| into
// |buffer|.
//
// Integers are printed in decimal and floats with "%G". Any other element
// type is printed as upper-case hex bytes. |out_buffer_length| receives the
// number of characters required (excluding the NUL terminator).
// IREE_STATUS_OUT_OF_RANGE is returned when |data| is shorter than one
// element, when |buffer| is NULL (a size query), or when |buffer| is too small.
IREE_API_EXPORT iree_status_t iree_hal_format_element(
    iree_const_byte_span_t data, iree_hal_element_type_t element_type,
    iree_host_size_t buffer_capacity, char* buffer,
    iree_host_size_t* out_buffer_length) {
  if (out_buffer_length) {
    *out_buffer_length = 0;
  }
  iree_host_size_t element_size =
      iree_hal_element_dense_byte_count(element_type);
  if (data.data_length < element_size) {
    return iree_make_status(
        IREE_STATUS_OUT_OF_RANGE,
        "data buffer underflow: data_length=%zu < element_size=%zu",
        data.data_length, element_size);
  }
  // snprintf must never receive a NULL buffer with a nonzero size.
  const iree_host_size_t capacity = buffer ? buffer_capacity : 0;
  int n = 0;
  // Values are loaded with memcpy because |data| is an arbitrary byte span
  // that need not be aligned for the element type.
  switch (element_type) {
    case IREE_HAL_ELEMENT_TYPE_INT_8:
    case IREE_HAL_ELEMENT_TYPE_SINT_8: {
      int8_t value;
      memcpy(&value, data.data, sizeof(value));
      n = snprintf(buffer, capacity, "%" PRIi8, value);
      break;
    }
    case IREE_HAL_ELEMENT_TYPE_UINT_8: {
      uint8_t value;
      memcpy(&value, data.data, sizeof(value));
      n = snprintf(buffer, capacity, "%" PRIu8, value);
      break;
    }
    case IREE_HAL_ELEMENT_TYPE_INT_16:
    case IREE_HAL_ELEMENT_TYPE_SINT_16: {
      int16_t value;
      memcpy(&value, data.data, sizeof(value));
      n = snprintf(buffer, capacity, "%" PRIi16, value);
      break;
    }
    case IREE_HAL_ELEMENT_TYPE_UINT_16: {
      uint16_t value;
      memcpy(&value, data.data, sizeof(value));
      n = snprintf(buffer, capacity, "%" PRIu16, value);
      break;
    }
    case IREE_HAL_ELEMENT_TYPE_INT_32:
    case IREE_HAL_ELEMENT_TYPE_SINT_32: {
      int32_t value;
      memcpy(&value, data.data, sizeof(value));
      n = snprintf(buffer, capacity, "%" PRIi32, value);
      break;
    }
    case IREE_HAL_ELEMENT_TYPE_UINT_32: {
      uint32_t value;
      memcpy(&value, data.data, sizeof(value));
      n = snprintf(buffer, capacity, "%" PRIu32, value);
      break;
    }
    case IREE_HAL_ELEMENT_TYPE_INT_64:
    case IREE_HAL_ELEMENT_TYPE_SINT_64: {
      int64_t value;
      memcpy(&value, data.data, sizeof(value));
      n = snprintf(buffer, capacity, "%" PRIi64, value);
      break;
    }
    case IREE_HAL_ELEMENT_TYPE_UINT_64: {
      uint64_t value;
      memcpy(&value, data.data, sizeof(value));
      n = snprintf(buffer, capacity, "%" PRIu64, value);
      break;
    }
    case IREE_HAL_ELEMENT_TYPE_FLOAT_16: {
      uint16_t value;
      memcpy(&value, data.data, sizeof(value));
      n = snprintf(buffer, capacity, "%G", iree_math_f16_to_f32(value));
      break;
    }
    case IREE_HAL_ELEMENT_TYPE_FLOAT_32: {
      float value;
      memcpy(&value, data.data, sizeof(value));
      n = snprintf(buffer, capacity, "%G", value);
      break;
    }
    case IREE_HAL_ELEMENT_TYPE_FLOAT_64: {
      double value;
      memcpy(&value, data.data, sizeof(value));
      n = snprintf(buffer, capacity, "%G", value);
      break;
    }
    default: {
      // Treat any unknown format as binary.
      n = 2 * (int)element_size;
      if (buffer && buffer_capacity > (iree_host_size_t)n) {
        iree_hal_bytes_to_hex_string(data.data, buffer,
                                     (ptrdiff_t)element_size);
        buffer[n] = 0;
      }
      break;
    }
  }
  if (n < 0) {
    return iree_make_status(IREE_STATUS_FAILED_PRECONDITION, "snprintf failed");
  } else if (buffer && (iree_host_size_t)n >= buffer_capacity) {
    // Truncated (or, for binary, not written at all).
    buffer = NULL;
  }
  if (out_buffer_length) {
    *out_buffer_length = (iree_host_size_t)n;
  }
  return buffer ? iree_ok_status()
                : iree_status_from_code(IREE_STATUS_OUT_OF_RANGE);
}
// Parses element values from |data_str| into |data_ptr|, which is treated as
// an array of dense |element_type| elements. Whitespace, ',', '[' and ']' all
// act as separators.
//
// - An empty |data_str| zero-fills the whole of |data_ptr|.
// - A single value is splatted to every element of |data_ptr|.
// - Otherwise the number of values must equal the element capacity of
//   |data_ptr| (data_length / element size). Too many or too few values
//   produce IREE_STATUS_OUT_OF_RANGE.
// Element types with no dense byte size are rejected with
// IREE_STATUS_INVALID_ARGUMENT.
IREE_API_EXPORT iree_status_t iree_hal_parse_buffer_elements(
    iree_string_view_t data_str, iree_hal_element_type_t element_type,
    iree_byte_span_t data_ptr) {
  iree_host_size_t element_size =
      iree_hal_element_dense_byte_count(element_type);
  if (IREE_UNLIKELY(element_size == 0)) {
    // Guard the division below; a zero-sized element has no capacity.
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "element type has no dense byte size");
  }
  iree_host_size_t element_capacity = data_ptr.data_length / element_size;
  if (iree_string_view_is_empty(data_str)) {
    memset(data_ptr.data, 0, data_ptr.data_length);
    return iree_ok_status();
  }

  size_t src_i = 0;
  size_t dst_i = 0;
  size_t token_start = IREE_STRING_VIEW_NPOS;
  while (src_i < data_str.size) {
    char c = data_str.data[src_i++];
    // isspace requires a value representable as unsigned char (or EOF);
    // passing a negative char is undefined behavior.
    bool is_separator =
        isspace((unsigned char)c) || c == ',' || c == '[' || c == ']';
    if (token_start == IREE_STRING_VIEW_NPOS) {
      if (!is_separator) {
        token_start = src_i - 1;  // start of a new token
      }
      continue;
    } else if (!is_separator) {
      continue;  // still inside the current token
    }
    // |c| terminates the token spanning [token_start, src_i - 1).
    if (dst_i >= element_capacity) {
      return iree_make_status(
          IREE_STATUS_OUT_OF_RANGE,
          "output data buffer overflow: element_capacity=%zu < dst_i=%zu+",
          element_capacity, dst_i);
    }
    IREE_RETURN_IF_ERROR(iree_hal_parse_element_unsafe(
        iree_make_string_view(data_str.data + token_start,
                              src_i - 1 - token_start),
        element_type, data_ptr.data + dst_i * element_size));
    ++dst_i;
    token_start = IREE_STRING_VIEW_NPOS;
  }

  // Flush a final token that runs to the end of the string.
  if (token_start != IREE_STRING_VIEW_NPOS) {
    if (dst_i >= element_capacity) {
      return iree_make_status(
          IREE_STATUS_OUT_OF_RANGE,
          "output data overflow: element_capacity=%zu < dst_i=%zu",
          element_capacity, dst_i);
    }
    IREE_RETURN_IF_ERROR(iree_hal_parse_element_unsafe(
        iree_make_string_view(data_str.data + token_start,
                              data_str.size - token_start),
        element_type, data_ptr.data + dst_i * element_size));
    ++dst_i;
  }

  if (dst_i == 1 && element_capacity > 1) {
    // Splat the single value we got to the entire buffer.
    uint8_t* p = data_ptr.data + element_size;
    for (iree_host_size_t i = 1; i < element_capacity; ++i, p += element_size) {
      memcpy(p, data_ptr.data, element_size);
    }
  } else if (dst_i < element_capacity) {
    return iree_make_status(
        IREE_STATUS_OUT_OF_RANGE,
        "input data string underflow: dst_i=%zu < element_capacity=%zu", dst_i,
        element_capacity);
  }
  return iree_ok_status();
}
// Appends the character |c| to the output being built by the enclosing
// function. That function must have locals named `buffer` (char*),
// `buffer_capacity` and `buffer_length`. The buffer is kept NUL-terminated;
// once |c| plus the terminator no longer fit, `buffer` is set to NULL so later
// appends only count. `buffer_length` is always incremented, so it reports the
// total size required.
//
// The capacity test is written as `length + 1 < capacity` rather than
// `length < capacity - 1`: with a zero capacity the latter underflows and
// permits an out-of-bounds write. The do/while(0) wrapper makes the macro
// behave as a single statement (e.g. in an unbraced if/else).
#define APPEND_CHAR(c)                           \
  do {                                           \
    if (buffer) {                                \
      if (buffer_length + 1 < buffer_capacity) { \
        buffer[buffer_length] = (c);             \
        buffer[buffer_length + 1] = '\0';        \
      } else {                                   \
        buffer = NULL;                           \
      }                                          \
    }                                            \
    ++buffer_length;                             \
  } while (0)
// Recursive worker for iree_hal_format_buffer_elements.
//
// Formats |data| as a dense row-major array with |shape_rank| dimensions given
// by |shape|. Each outer dimension is printed as a run of bracketed sub-lists
// and the innermost dimension as space-separated elements, e.g. shape [2, 2]
// formats as "[1 2][3 4]". A rank-0 (scalar) shape is formatted as one
// element.
//
// |max_element_count| is an in/out budget shared across the whole recursion.
// Each innermost row prints at most the remaining budget, subtracts what it
// printed, and appends "..." if it stopped before the end of the row.
//
// Once the output no longer fits in |buffer_capacity| (or |buffer| is NULL),
// writing stops but |buffer_length| keeps counting. The function then returns
// IREE_STATUS_OUT_OF_RANGE with |out_buffer_length| set to the total number of
// characters required. If |data| is too short for the shape being printed, an
// OUT_OF_RANGE error carrying a message is returned and |out_buffer_length| is
// not written.
static iree_status_t iree_hal_format_buffer_elements_recursive(
    iree_const_byte_span_t data, iree_host_size_t shape_rank,
    const iree_hal_dim_t* shape, iree_hal_element_type_t element_type,
    iree_host_size_t* max_element_count, iree_host_size_t buffer_capacity,
    char* buffer, iree_host_size_t* out_buffer_length) {
  // Characters produced so far. APPEND_CHAR reads and updates this together
  // with |buffer| and |buffer_capacity|.
  iree_host_size_t buffer_length = 0;
  if (shape_rank == 0) {
    // Scalar value; recurse to get on to the leaf dimension path.
    const iree_hal_dim_t one = 1;
    return iree_hal_format_buffer_elements_recursive(
        data, 1, &one, element_type, max_element_count, buffer_capacity, buffer,
        out_buffer_length);
  } else if (shape_rank > 1) {
    // Nested dimension; recurse into the next innermost dimension.
    // dim_length is the element count of one slice along shape[0].
    iree_hal_dim_t dim_length = 1;
    for (iree_host_size_t i = 1; i < shape_rank; ++i) {
      dim_length *= shape[i];
    }
    iree_device_size_t dim_stride =
        dim_length * iree_hal_element_dense_byte_count(element_type);
    // All of shape[0] slices must be present, even if the element budget
    // means not every slice will be printed in full.
    if (data.data_length < dim_stride * shape[0]) {
      return iree_make_status(
          IREE_STATUS_OUT_OF_RANGE,
          "input data underflow: data_length=%zu < expected=%zu",
          data.data_length, (iree_host_size_t)(dim_stride * shape[0]));
    }
    iree_const_byte_span_t subdata;
    subdata.data = data.data;
    subdata.data_length = dim_stride;
    for (iree_hal_dim_t i = 0; i < shape[0]; ++i) {
      APPEND_CHAR('[');
      iree_host_size_t actual_length = 0;
      // Format the slice immediately after what has been written so far,
      // giving it only the remaining capacity.
      iree_status_t status = iree_hal_format_buffer_elements_recursive(
          subdata, shape_rank - 1, shape + 1, element_type, max_element_count,
          buffer ? buffer_capacity - buffer_length : 0,
          buffer ? buffer + buffer_length : NULL, &actual_length);
      buffer_length += actual_length;
      if (iree_status_is_out_of_range(status)) {
        // The slice did not fit: switch to length-only counting.
        // NOTE(review): this status is dropped without being freed. That
        // presumably relies on the data-length check above, which keeps the
        // child from returning its message-carrying underflow error. Confirm
        // before relaxing that check.
        buffer = NULL;
      } else if (!iree_status_is_ok(status)) {
        return status;
      }
      subdata.data += dim_stride;
      APPEND_CHAR(']');
    }
  } else {
    // Leaf dimension; output data.
    // Only as many elements as the remaining budget allows are printed, and
    // only those need to be present in |data|.
    iree_host_size_t max_count =
        iree_min(*max_element_count, (iree_host_size_t)shape[0]);
    iree_device_size_t element_stride =
        iree_hal_element_dense_byte_count(element_type);
    if (data.data_length < max_count * element_stride) {
      return iree_make_status(
          IREE_STATUS_OUT_OF_RANGE,
          "input data underflow; data_length=%zu < expected=%zu",
          data.data_length, (iree_host_size_t)(max_count * element_stride));
    }
    *max_element_count -= max_count;
    iree_const_byte_span_t subdata;
    subdata.data = data.data;
    subdata.data_length = element_stride;
    for (iree_hal_dim_t i = 0; i < max_count; ++i) {
      if (i > 0) APPEND_CHAR(' ');
      iree_host_size_t actual_length = 0;
      iree_status_t status = iree_hal_format_element(
          subdata, element_type, buffer ? buffer_capacity - buffer_length : 0,
          buffer ? buffer + buffer_length : NULL, &actual_length);
      subdata.data += element_stride;
      buffer_length += actual_length;
      if (iree_status_is_out_of_range(status)) {
        // Element did not fit; keep counting without writing. subdata holds
        // exactly one element, so the element formatter's data-underflow
        // error is not expected here.
        buffer = NULL;
      } else if (!iree_status_is_ok(status)) {
        return status;
      }
    }
    // Mark rows cut short by the element budget.
    if (max_count < shape[0]) {
      APPEND_CHAR('.');
      APPEND_CHAR('.');
      APPEND_CHAR('.');
    }
  }
  if (out_buffer_length) {
    *out_buffer_length = buffer_length;
  }
  return buffer ? iree_ok_status()
                : iree_status_from_code(IREE_STATUS_OUT_OF_RANGE);
}
// Formats the dense |data| array with the given |shape| and |element_type| as
// nested bracketed lists into |buffer|. At most |max_element_count| elements
// are printed in total, and "..." marks rows that were cut short.
//
// |out_buffer_length| receives the number of characters required (excluding
// the NUL terminator). IREE_STATUS_OUT_OF_RANGE is returned when |buffer| is
// NULL or too small for the full output.
IREE_API_EXPORT iree_status_t iree_hal_format_buffer_elements(
    iree_const_byte_span_t data, iree_host_size_t shape_rank,
    const iree_hal_dim_t* shape, iree_hal_element_type_t element_type,
    iree_host_size_t max_element_count, iree_host_size_t buffer_capacity,
    char* buffer, iree_host_size_t* out_buffer_length) {
  // Start from an empty string and a zero length, so callers see sane values
  // even if formatting fails early.
  if (buffer != NULL && buffer_capacity != 0) {
    buffer[0] = '\0';
  }
  if (out_buffer_length != NULL) {
    *out_buffer_length = 0;
  }
  // The recursion decrements the element budget in place, so it works on
  // this by-value copy.
  iree_host_size_t remaining_elements = max_element_count;
  return iree_hal_format_buffer_elements_recursive(
      data, shape_rank, shape, element_type, &remaining_elements,
      buffer_capacity, buffer, out_buffer_length);
}