Use half lib to support printing/parsing of f16
PiperOrigin-RevId: 343917287
diff --git a/BUILD.bazel b/BUILD.bazel
index 38e40b4..6b3ee5c 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -24,6 +24,7 @@
# "@com_github_pytorch_cpuinfo//"
# "@com_github_google_flatbuffers//"
# "@com_github_dvidelabs_flatcc//"
+# "@half//"
# "@com_google_googletest//"
# "@llvm-project//"
# "@pffft//"
diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
index 468ab6e..c8f1dbf 100644
--- a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
+++ b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
@@ -73,7 +73,7 @@
"@pffft": ["pffft"],
"@sdl2//:SDL2": ["SDL2-static"],
"@com_github_pytorch_cpuinfo//:cpuinfo": ["cpuinfo"],
- "@half//:half": ["half"],
+ "@half": ["half"],
}
diff --git a/build_tools/third_party/half/BUILD.overlay b/build_tools/third_party/half/BUILD.overlay
index 33671b7..b27851b 100644
--- a/build_tools/third_party/half/BUILD.overlay
+++ b/build_tools/third_party/half/BUILD.overlay
@@ -16,7 +16,7 @@
package(default_visibility = ["//visibility:public"])
cc_library(
- name = "half",
+ name = "includes",
hdrs = ["half.hpp"],
include_prefix = "third_party/half",
)
diff --git a/build_tools/third_party/half/CMakeLists.txt b/build_tools/third_party/half/CMakeLists.txt
index 84f9ade..4ea1374 100644
--- a/build_tools/third_party/half/CMakeLists.txt
+++ b/build_tools/third_party/half/CMakeLists.txt
@@ -18,7 +18,7 @@
PACKAGE
half
NAME
- half
+ includes
ROOT
${HALF_API_ROOT}
HDRS
diff --git a/iree/hal/BUILD b/iree/hal/BUILD
index 154a8df..39c8c8f 100644
--- a/iree/hal/BUILD
+++ b/iree/hal/BUILD
@@ -60,6 +60,7 @@
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
+ "@half//:includes",
],
)
diff --git a/iree/hal/api.cc b/iree/hal/api.cc
index d87362b..7ec77af 100644
--- a/iree/hal/api.cc
+++ b/iree/hal/api.cc
@@ -39,6 +39,7 @@
#include "iree/hal/heap_buffer.h"
#include "iree/hal/host/host_local_allocator.h"
#include "iree/hal/semaphore.h"
+#include "third_party/half/half.hpp"
namespace iree {
namespace hal {
@@ -294,9 +295,16 @@
reinterpret_cast<uint64_t*>(out_data))
? iree_ok_status()
: iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT);
- case IREE_HAL_ELEMENT_TYPE_FLOAT_16:
- return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
- "float16 parsing not implemented");
+ case IREE_HAL_ELEMENT_TYPE_FLOAT_16: {
+ float temp = 0;
+ if (!absl::SimpleAtof(absl::string_view(data_str.data, data_str.size),
+ &temp)) {
+ return iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT);
+ }
+ *reinterpret_cast<uint16_t*>(out_data) =
+ half_float::detail::float2half<std::round_to_nearest>(temp);
+ return iree_ok_status();
+ }
case IREE_HAL_ELEMENT_TYPE_FLOAT_32:
return absl::SimpleAtof(absl::string_view(data_str.data, data_str.size),
reinterpret_cast<float*>(out_data))
@@ -407,8 +415,10 @@
*reinterpret_cast<const uint64_t*>(data.data));
break;
case IREE_HAL_ELEMENT_TYPE_FLOAT_16:
- return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
- "parser for float16 not yet implemented");
+ n = std::snprintf(buffer, buffer ? buffer_capacity : 0, "%G",
+ half_float::detail::half2float<float>(
+ *reinterpret_cast<const uint16_t*>(data.data)));
+ break;
case IREE_HAL_ELEMENT_TYPE_FLOAT_32:
n = std::snprintf(buffer, buffer ? buffer_capacity : 0, "%G",
*reinterpret_cast<const float*>(data.data));
diff --git a/iree/hal/api_string_util_test.cc b/iree/hal/api_string_util_test.cc
index 44a5711..a148af1 100644
--- a/iree/hal/api_string_util_test.cc
+++ b/iree/hal/api_string_util_test.cc
@@ -554,6 +554,8 @@
IsOkAndHolds(Eq(IREE_HAL_ELEMENT_TYPE_UINT_16)));
EXPECT_THAT(ParseElementType("f32"),
IsOkAndHolds(Eq(IREE_HAL_ELEMENT_TYPE_FLOAT_32)));
+ EXPECT_THAT(ParseElementType("f16"),
+ IsOkAndHolds(Eq(IREE_HAL_ELEMENT_TYPE_FLOAT_16)));
EXPECT_THAT(ParseElementType("x64"),
IsOkAndHolds(Eq(IREE_HAL_ELEMENT_TYPE_OPAQUE_64)));
EXPECT_THAT(ParseElementType("*64"),
@@ -1000,6 +1002,7 @@
expect_round_trip("4xi16=0 -1 2 3");
expect_round_trip("4xu16=0 1 2 3");
expect_round_trip("2x2xi32=[0 1][2 3]");
+ expect_round_trip("4xf16=0 0.5 2 3");
expect_round_trip("4xf32=0 1.1 2 3");
expect_round_trip("4xf64=0 1.1 2 3");
expect_round_trip("1x2x3xi8=[[0 1 2][3 4 5]]");