Adding support for parsing/printing bfloat16 values in tools. (#14869)
diff --git a/runtime/src/iree/hal/string_util.c b/runtime/src/iree/hal/string_util.c
index c3420da..11cd2ce 100644
--- a/runtime/src/iree/hal/string_util.c
+++ b/runtime/src/iree/hal/string_util.c
@@ -366,6 +366,14 @@
return iree_string_view_atoi_uint64(data_str, (uint64_t*)out_data)
? iree_ok_status()
: iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT);
+ case IREE_HAL_ELEMENT_TYPE_BFLOAT_16: {
+ float temp = 0;
+ if (!iree_string_view_atof(data_str, &temp)) {
+ return iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT);
+ }
+ *(uint16_t*)out_data = iree_math_f32_to_bf16(temp);
+ return iree_ok_status();
+ }
case IREE_HAL_ELEMENT_TYPE_FLOAT_16: {
float temp = 0;
if (!iree_string_view_atof(data_str, &temp)) {
@@ -489,6 +497,10 @@
n = snprintf(buffer, buffer ? buffer_capacity : 0, "%" PRIu64,
*(const uint64_t*)data.data);
break;
+ case IREE_HAL_ELEMENT_TYPE_BFLOAT_16:
+ n = snprintf(buffer, buffer ? buffer_capacity : 0, "%G",
+ iree_math_bf16_to_f32(*(const uint16_t*)data.data));
+ break;
case IREE_HAL_ELEMENT_TYPE_FLOAT_16:
n = snprintf(buffer, buffer ? buffer_capacity : 0, "%G",
iree_math_f16_to_f32(*(const uint16_t*)data.data));
diff --git a/runtime/src/iree/hal/string_util_test.cc b/runtime/src/iree/hal/string_util_test.cc
index 94a2fa1..2d134fd 100644
--- a/runtime/src/iree/hal/string_util_test.cc
+++ b/runtime/src/iree/hal/string_util_test.cc
@@ -31,7 +31,7 @@
// Parses a serialized set of shape dimensions using the canonical shape format
// (the same as produced by FormatShape).
-StatusOr<Shape> ParseShape(const std::string& value) {
+static StatusOr<Shape> ParseShape(const std::string& value) {
Shape shape(6);
iree_host_size_t actual_rank = 0;
iree_status_t status = iree_ok_status();
@@ -46,7 +46,8 @@
}
// Converts shape dimensions into a `4x5x6` format.
-StatusOr<std::string> FormatShape(iree::span<const iree_hal_dim_t> value) {
+static StatusOr<std::string> FormatShape(
+ iree::span<const iree_hal_dim_t> value) {
std::string buffer(16, '\0');
iree_host_size_t actual_length = 0;
iree_status_t status = iree_ok_status();
@@ -62,7 +63,8 @@
// Parses a serialized iree_hal_element_type_t. The format is the same as
// produced by FormatElementType.
-StatusOr<iree_hal_element_type_t> ParseElementType(const std::string& value) {
+static StatusOr<iree_hal_element_type_t> ParseElementType(
+ const std::string& value) {
iree_hal_element_type_t element_type = IREE_HAL_ELEMENT_TYPE_NONE;
iree_status_t status = iree_hal_parse_element_type(
iree_string_view_t{value.data(), (iree_host_size_t)value.size()},
@@ -74,7 +76,7 @@
// Converts an iree_hal_element_type_t enum value to a canonical string
// representation, like `IREE_HAL_ELEMENT_TYPE_FLOAT_16` to `f16`.
-StatusOr<std::string> FormatElementType(iree_hal_element_type_t value) {
+static StatusOr<std::string> FormatElementType(iree_hal_element_type_t value) {
std::string buffer(16, '\0');
iree_host_size_t actual_length = 0;
iree_status_t status = iree_ok_status();
@@ -100,7 +102,8 @@
}
// Parses a serialized set of shape dimensions and an element type.
-StatusOr<ShapeAndType> ParseShapeAndElementType(const std::string& value) {
+static StatusOr<ShapeAndType> ParseShapeAndElementType(
+ const std::string& value) {
Shape shape(6);
iree_host_size_t actual_rank = 0;
iree_hal_element_type_t element_type = IREE_HAL_ELEMENT_TYPE_NONE;
@@ -120,9 +123,9 @@
// For example, "1.2" of type IREE_HAL_ELEMENT_TYPE_FLOAT32 will write the 4
// byte float value of 1.2 to |buffer|.
template <typename T>
-Status ParseElement(const std::string& value,
- iree_hal_element_type_t element_type,
- iree::span<T> buffer) {
+static Status ParseElement(const std::string& value,
+ iree_hal_element_type_t element_type,
+ iree::span<T> buffer) {
return iree_hal_parse_element(
iree_string_view_t{value.data(), (iree_host_size_t)value.size()},
element_type,
@@ -133,8 +136,8 @@
// Converts a single element of |element_type| to a string.
template <typename T>
-StatusOr<std::string> FormatElement(T value,
- iree_hal_element_type_t element_type) {
+static StatusOr<std::string> FormatElement(
+ T value, iree_hal_element_type_t element_type) {
std::string result(16, '\0');
iree_status_t status;
do {
@@ -155,9 +158,9 @@
// produced by FormatBufferElements. Supports additional inputs of
// empty to denote a 0 fill and a single element to denote a splat.
template <typename T>
-Status ParseBufferElements(const std::string& value,
- iree_hal_element_type_t element_type,
- iree::span<T> buffer) {
+static Status ParseBufferElements(const std::string& value,
+ iree_hal_element_type_t element_type,
+ iree::span<T> buffer) {
IREE_RETURN_IF_ERROR(
iree_hal_parse_buffer_elements(
iree_string_view_t{value.data(), (iree_host_size_t)value.size()},
@@ -177,10 +180,9 @@
// |max_element_count| can be used to limit the total number of elements printed
// when the count may be large. Elided elements will be replaced with `...`.
template <typename T>
-StatusOr<std::string> FormatBufferElements(iree::span<const T> data,
- const Shape& shape,
- iree_hal_element_type_t element_type,
- size_t max_element_count) {
+static StatusOr<std::string> FormatBufferElements(
+ iree::span<const T> data, const Shape& shape,
+ iree_hal_element_type_t element_type, size_t max_element_count) {
std::string result(255, '\0');
iree_status_t status;
do {
@@ -263,12 +265,21 @@
iree::span<T>(&result, 1)));
return result;
}
+inline StatusOr<uint16_t> ParseElementBF16(const std::string& value) {
+ uint16_t result = uint16_t();
+ IREE_RETURN_IF_ERROR(ParseElement(value, IREE_HAL_ELEMENT_TYPE_BFLOAT_16,
+ iree::span<uint16_t>(&result, 1)));
+ return result;
+}
// Converts a single element of to a string value.
template <typename T>
inline StatusOr<std::string> FormatElement(T value) {
return FormatElement(value, ElementTypeFromCType<T>::value);
}
+inline StatusOr<std::string> FormatElementBF16(uint16_t value) {
+ return FormatElement(value, IREE_HAL_ELEMENT_TYPE_BFLOAT_16);
+}
// Parses a serialized set of elements of type T.
// The resulting parsed data is written to |buffer|, which must be at least
@@ -695,6 +706,7 @@
IsOkAndHolds(Eq(INT64_MAX)));
EXPECT_THAT(ParseElement<uint64_t>("18446744073709551615"),
IsOkAndHolds(Eq(UINT64_MAX)));
+ EXPECT_THAT(ParseElementBF16("1.5"), IsOkAndHolds(Eq(0x3FC0u)));
EXPECT_THAT(ParseElement<float>("1.5"), IsOkAndHolds(Eq(1.5f)));
EXPECT_THAT(ParseElement<double>("1.567890123456789"),
IsOkAndHolds(Eq(1.567890123456789)));
@@ -760,6 +772,8 @@
StatusIs(StatusCode::kInvalidArgument));
EXPECT_THAT(ParseElement<uint32_t>("asdfasdf"),
StatusIs(StatusCode::kInvalidArgument));
+ EXPECT_THAT(ParseElementBF16("asdfasdf"),
+ StatusIs(StatusCode::kInvalidArgument));
EXPECT_THAT(ParseElement<float>("asdfasdf"),
StatusIs(StatusCode::kInvalidArgument));
EXPECT_THAT(ParseElement<double>("asdfasdf"),
@@ -832,6 +846,7 @@
IsOkAndHolds(Eq("9223372036854775807")));
EXPECT_THAT(FormatElement<uint64_t>(UINT64_MAX),
IsOkAndHolds(Eq("18446744073709551615")));
+ EXPECT_THAT(FormatElementBF16(0x3FC0u), IsOkAndHolds(Eq("1.5")));
EXPECT_THAT(FormatElement<float>(1.5f), IsOkAndHolds(Eq("1.5")));
EXPECT_THAT(FormatElement<double>(1123.56789456789),
IsOkAndHolds(Eq("1123.57")));