// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
| |
| #include "iree/base/buffer_string_util.h" |
| |
| #include <functional> |
| #include <sstream> |
| #include <string> |
| #include <type_traits> |
| |
| #include "absl/container/flat_hash_map.h" |
| #include "absl/strings/ascii.h" |
| #include "absl/strings/numbers.h" |
| #include "absl/strings/str_join.h" |
| #include "absl/strings/str_split.h" |
| #include "absl/strings/string_view.h" |
| #include "absl/strings/strip.h" |
| #include "absl/types/optional.h" |
| #include "iree/base/memory.h" |
| #include "iree/base/source_location.h" |
| #include "iree/base/status.h" |
| |
| namespace iree { |
| |
| namespace { |
| |
| /* clang-format off */ |
| constexpr char kHexValue[256] = { |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0, 0, 0, 0, 0, // '0'..'9' |
| 0, 10, 11, 12, 13, 14, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 'A'..'F' |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 10, 11, 12, 13, 14, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 'a'..'f' |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 |
| }; |
| /* clang-format on */ |
| |
| template <typename T> |
| void HexStringToBytes(const char* from, T to, ptrdiff_t num) { |
| for (int i = 0; i < num; i++) { |
| to[i] = (kHexValue[from[i * 2] & 0xFF] << 4) + |
| (kHexValue[from[i * 2 + 1] & 0xFF]); |
| } |
| } |
| |
| constexpr char kHexTable[513] = |
| "000102030405060708090a0b0c0d0e0f" |
| "101112131415161718191a1b1c1d1e1f" |
| "202122232425262728292a2b2c2d2e2f" |
| "303132333435363738393a3b3c3d3e3f" |
| "404142434445464748494a4b4c4d4e4f" |
| "505152535455565758595a5b5c5d5e5f" |
| "606162636465666768696a6b6c6d6e6f" |
| "707172737475767778797a7b7c7d7e7f" |
| "808182838485868788898a8b8c8d8e8f" |
| "909192939495969798999a9b9c9d9e9f" |
| "a0a1a2a3a4a5a6a7a8a9aaabacadaeaf" |
| "b0b1b2b3b4b5b6b7b8b9babbbcbdbebf" |
| "c0c1c2c3c4c5c6c7c8c9cacbcccdcecf" |
| "d0d1d2d3d4d5d6d7d8d9dadbdcdddedf" |
| "e0e1e2e3e4e5e6e7e8e9eaebecedeeef" |
| "f0f1f2f3f4f5f6f7f8f9fafbfcfdfeff"; |
| |
| // Like the absl method, but works in-place. |
| template <typename T> |
| void BytesToHexString(const unsigned char* src, T dest, ptrdiff_t num) { |
| auto dest_ptr = &dest[0]; |
| for (auto src_ptr = src; src_ptr != (src + num); ++src_ptr, dest_ptr += 2) { |
| const char* hex_p = &kHexTable[*src_ptr * 2]; |
| std::copy(hex_p, hex_p + 2, dest_ptr); |
| } |
| } |
| |
| // Returns true if the given type is represented as binary hex data. |
| bool IsBinaryType(absl::string_view type_str) { |
| return !type_str.empty() && absl::ascii_isdigit(type_str[0]); |
| } |
| |
| // Parses binary hex data. |
| Status ParseBinaryData(absl::string_view data_str, absl::Span<uint8_t> output) { |
| data_str = absl::StripAsciiWhitespace(data_str); |
| size_t dst_i = 0; |
| size_t src_i = 0; |
| while (src_i < data_str.size() && dst_i < output.size()) { |
| char c = data_str[src_i]; |
| if (absl::ascii_isspace(c) || c == ',') { |
| ++src_i; |
| continue; |
| } |
| if (src_i + 1 >= data_str.size()) { |
| return InvalidArgumentErrorBuilder(IREE_LOC) |
| << "Invalid input hex data (offset=" << src_i << ")"; |
| } |
| HexStringToBytes(data_str.data() + src_i, output.data() + dst_i, 1); |
| src_i += 2; |
| ++dst_i; |
| } |
| if (dst_i < output.size()) { |
| return InvalidArgumentErrorBuilder(IREE_LOC) |
| << "Too few elements to fill type; expected " << output.size() |
| << " but only read " << dst_i; |
| } else if (data_str.size() - src_i > 0) { |
| return InvalidArgumentErrorBuilder(IREE_LOC) |
| << "Input data string contains more elements than the underlying " |
| "buffer (" |
| << output.size() << "): " << data_str; |
| } |
| return OkStatus(); |
| } |
| |
| template <typename ElementType, typename Enabled = void> |
| struct SimpleStrToValue { |
| absl::optional<ElementType> operator()(absl::string_view text) const = delete; |
| }; |
| |
| template <typename IntegerType> |
| struct SimpleStrToValue< |
| IntegerType, |
| typename std::enable_if<(sizeof(IntegerType) < 4), void>::type> { |
| absl::optional<IntegerType> operator()(absl::string_view text) const { |
| int32_t value; |
| return absl::SimpleAtoi(text, &value) ? absl::optional<IntegerType>{value} |
| : absl::nullopt; |
| } |
| }; |
| |
| template <typename IntegerType> |
| struct SimpleStrToValue< |
| IntegerType, |
| typename std::enable_if<(sizeof(IntegerType) >= 4), void>::type> { |
| absl::optional<IntegerType> operator()(absl::string_view text) const { |
| IntegerType value; |
| return absl::SimpleAtoi(text, &value) ? absl::optional<IntegerType>{value} |
| : absl::nullopt; |
| } |
| }; |
| |
| template <> |
| struct SimpleStrToValue<float, void> { |
| absl::optional<float> operator()(absl::string_view text) const { |
| float value; |
| return absl::SimpleAtof(text, &value) ? absl::optional<float>{value} |
| : absl::nullopt; |
| } |
| }; |
| |
| template <> |
| struct SimpleStrToValue<double, void> { |
| absl::optional<double> operator()(absl::string_view text) const { |
| double value; |
| return absl::SimpleAtod(text, &value) ? absl::optional<double>{value} |
| : absl::nullopt; |
| } |
| }; |
| |
| template <typename T> |
| Status ParseNumericalDataElement(absl::string_view data_str, size_t token_start, |
| size_t token_end, absl::Span<T> output, |
| int dst_i) { |
| if (dst_i >= output.size()) { |
| return InvalidArgumentErrorBuilder(IREE_LOC) |
| << "Input data string contains more elements than the underlying " |
| "buffer (" |
| << output.size() << "): " << data_str; |
| } |
| auto element_str = data_str.substr(token_start, token_end - token_start + 1); |
| auto element = SimpleStrToValue<T>()(element_str); |
| if (!element.has_value()) { |
| return InvalidArgumentErrorBuilder(IREE_LOC) |
| << "Unable to parse element " << dst_i << " = '" << element_str |
| << "'"; |
| } |
| output[dst_i] = element.value(); |
| return OkStatus(); |
| } |
| |
| template <typename T> |
| Status ParseNumericalDataAsType(absl::string_view data_str, |
| absl::Span<uint8_t> output) { |
| auto cast_output = ReinterpretSpan<T>(output); |
| size_t src_i = 0; |
| size_t dst_i = 0; |
| size_t token_start = std::string::npos; |
| while (src_i < data_str.size()) { |
| char c = data_str[src_i++]; |
| bool is_separator = |
| absl::ascii_isspace(c) || c == ',' || c == '[' || c == ']'; |
| if (token_start == std::string::npos) { |
| if (!is_separator) { |
| token_start = src_i - 1; |
| } |
| continue; |
| } else if (token_start != std::string::npos && !is_separator) { |
| continue; |
| } |
| RETURN_IF_ERROR(ParseNumericalDataElement<T>( |
| data_str, token_start, src_i - 2, cast_output, dst_i++)); |
| token_start = std::string::npos; |
| } |
| if (token_start != std::string::npos) { |
| RETURN_IF_ERROR(ParseNumericalDataElement<T>( |
| data_str, token_start, data_str.size() - 1, cast_output, dst_i++)); |
| } |
| if (dst_i == 1 && cast_output.size() > 1) { |
| // Splat the single value we got to the entire tensor. |
| for (int i = 1; i < cast_output.size(); ++i) { |
| cast_output[i] = cast_output[0]; |
| } |
| } else if (dst_i < cast_output.size()) { |
| return InvalidArgumentErrorBuilder(IREE_LOC) |
| << "Input data string contains fewer elements than the underlying " |
| "buffer (expected " |
| << cast_output.size() << ") " << data_str; |
| } |
| return OkStatus(); |
| } |
| |
| // Parses numerical data (ints, floats, etc) in some typed form. |
| Status ParseNumericalData(absl::string_view type_str, |
| absl::string_view data_str, |
| absl::Span<uint8_t> output) { |
| if (type_str == "i8") { |
| return ParseNumericalDataAsType<int8_t>(data_str, output); |
| } else if (type_str == "u8") { |
| return ParseNumericalDataAsType<uint8_t>(data_str, output); |
| } else if (type_str == "i16") { |
| return ParseNumericalDataAsType<int16_t>(data_str, output); |
| } else if (type_str == "u16") { |
| return ParseNumericalDataAsType<uint16_t>(data_str, output); |
| } else if (type_str == "i32") { |
| return ParseNumericalDataAsType<int32_t>(data_str, output); |
| } else if (type_str == "u32") { |
| return ParseNumericalDataAsType<uint32_t>(data_str, output); |
| } else if (type_str == "i64") { |
| return ParseNumericalDataAsType<int64_t>(data_str, output); |
| } else if (type_str == "u64") { |
| return ParseNumericalDataAsType<uint64_t>(data_str, output); |
| } else if (type_str == "f32") { |
| return ParseNumericalDataAsType<float>(data_str, output); |
| } else if (type_str == "f64") { |
| return ParseNumericalDataAsType<double>(data_str, output); |
| } else { |
| return InvalidArgumentErrorBuilder(IREE_LOC) |
| << "Unsupported type: " << type_str; |
| } |
| } |
| |
| template <typename T> |
| void PrintElementList(const Shape& shape, absl::Span<const T> data, |
| size_t* max_entries, std::ostream* stream) { |
| if (shape.empty()) { |
| // Scalar value. |
| PrintElementList({1}, data, max_entries, stream); |
| return; |
| } else if (shape.size() == 1) { |
| // Leaf dimension; output data. |
| size_t max_count = std::min(*max_entries, static_cast<size_t>(shape[0])); |
| *stream << absl::StrJoin(data.subspan(0, max_count), " "); |
| if (max_count < shape[0]) { |
| *stream << "..."; |
| } |
| *max_entries -= max_count; |
| } else { |
| // Nested; recurse into next dimension. |
| Shape nested_shape = Shape(shape.subspan(1)); |
| size_t length = nested_shape.element_count(); |
| size_t offset = 0; |
| for (int i = 0; i < shape[0]; ++i) { |
| *stream << "["; |
| PrintElementList<T>(nested_shape, data.subspan(offset, length), |
| max_entries, stream); |
| offset += length; |
| *stream << "]"; |
| } |
| } |
| } |
| |
| template <typename T> |
| Status PrintNumericalDataToStreamAsType(const Shape& shape, |
| absl::Span<const uint8_t> contents, |
| size_t max_entries, |
| std::ostream* stream) { |
| auto cast_contents = ReinterpretSpan<T>(contents); |
| PrintElementList(shape, cast_contents, &max_entries, stream); |
| return OkStatus(); |
| } |
| |
| } // namespace |
| |
| StatusOr<BufferDataPrintMode> ParseBufferDataPrintMode(absl::string_view str) { |
| if (str.empty()) { |
| return InvalidArgumentErrorBuilder(IREE_LOC) |
| << "Cannot get print mode of empty string"; |
| } |
| switch (str[0]) { |
| case 'b': |
| return BufferDataPrintMode::kBinary; |
| case 'i': |
| return BufferDataPrintMode::kSignedInteger; |
| case 'u': |
| return BufferDataPrintMode::kUnsignedInteger; |
| case 'f': |
| return BufferDataPrintMode::kFloatingPoint; |
| default: |
| return InvalidArgumentErrorBuilder(IREE_LOC) |
| << "Unsupported output type '" << str << "'"; |
| } |
| } |
| |
| StatusOr<int> ParseBufferTypeElementSize(absl::string_view type_str) { |
| static const auto* const type_sizes = |
| new absl::flat_hash_map<absl::string_view, int>{ |
| {"1", 1}, {"2", 2}, {"4", 4}, {"8", 8}, {"i8", 1}, |
| {"u8", 1}, {"i16", 2}, {"u16", 2}, {"i32", 4}, {"u32", 4}, |
| {"i64", 8}, {"u64", 8}, {"f32", 4}, {"f64", 8}, |
| }; |
| auto type_size = type_sizes->find(type_str); |
| if (type_size != type_sizes->end()) { |
| return type_size->second; |
| } else { |
| return InvalidArgumentErrorBuilder(IREE_LOC) |
| << "Unsupported type '" << type_str << "'"; |
| } |
| } |
| |
| StatusOr<std::string> MakeBufferTypeString(int element_size, |
| BufferDataPrintMode print_mode) { |
| if (element_size <= 0) { |
| return InvalidArgumentErrorBuilder(IREE_LOC) |
| << "Invalid element size '" << element_size << "'"; |
| } |
| switch (print_mode) { |
| case BufferDataPrintMode::kBinary: |
| return absl::StrCat(element_size); |
| case BufferDataPrintMode::kSignedInteger: |
| return absl::StrCat("i", element_size * 8); |
| case BufferDataPrintMode::kUnsignedInteger: |
| return absl::StrCat("u", element_size * 8); |
| case BufferDataPrintMode::kFloatingPoint: |
| return absl::StrCat("f", element_size * 8); |
| default: |
| return InvalidArgumentErrorBuilder(IREE_LOC) |
| << "Unsupported print mode (" << static_cast<int>(print_mode) |
| << ")"; |
| } |
| } |
| |
| StatusOr<Shape> ParseShape(absl::string_view shape_str) { |
| if (shape_str.empty()) { |
| return Shape{}; |
| } |
| std::vector<int> dims; |
| for (auto dim_str : absl::StrSplit(shape_str, 'x')) { |
| int dim_value = 0; |
| if (!absl::SimpleAtoi(dim_str, &dim_value)) { |
| return InvalidArgumentErrorBuilder(IREE_LOC) |
| << "Invalid shape dimension '" << dim_str |
| << "' while parsing shape '" << shape_str << "'"; |
| } |
| if (dim_value < 0) { |
| return InvalidArgumentErrorBuilder(IREE_LOC) |
| << "Unsupported shape dimension '" << dim_str << "'"; |
| } |
| dims.push_back(dim_value); |
| } |
| if (dims.size() > kMaxRank) { |
| return InvalidArgumentErrorBuilder(IREE_LOC) |
| << "Shape rank '" << dims.size() << "' exceeds maximum rank '" |
| << kMaxRank << "'"; |
| } |
| return Shape{dims}; |
| } |
| |
| std::string PrintShapedTypeToString(const Shape& shape, |
| absl::string_view type_str) { |
| std::string result; |
| PrintShapedTypeToString(shape, type_str, &result); |
| return result; |
| } |
| |
| void PrintShapedTypeToString(const Shape& shape, absl::string_view type_str, |
| std::string* out_result) { |
| std::ostringstream stream; |
| PrintShapedTypeToStream(shape, type_str, &stream); |
| *out_result = stream.str(); |
| } |
| |
| void PrintShapedTypeToStream(const Shape& shape, absl::string_view type_str, |
| std::ostream* stream) { |
| *stream << absl::StrJoin(shape.begin(), shape.end(), "x"); |
| if (!shape.empty()) *stream << "x"; |
| *stream << type_str; |
| } |
| |
| // Prints binary hex data. |
| StatusOr<std::string> PrintBinaryDataToString( |
| int element_size, absl::Span<const uint8_t> contents, size_t max_entries) { |
| std::string result; |
| RETURN_IF_ERROR( |
| PrintBinaryDataToString(element_size, contents, max_entries, &result)); |
| return result; |
| } |
| |
| Status PrintBinaryDataToString(int element_size, |
| absl::Span<const uint8_t> contents, |
| size_t max_entries, std::string* out_result) { |
| std::ostringstream stream; |
| RETURN_IF_ERROR( |
| PrintBinaryDataToStream(element_size, contents, max_entries, &stream)); |
| *out_result = stream.str(); |
| return OkStatus(); |
| } |
| |
| Status PrintBinaryDataToStream(int element_size, |
| absl::Span<const uint8_t> contents, |
| size_t max_entries, std::ostream* stream) { |
| // TODO(gcmn) Can we avoid this fiddly byte counting? |
| max_entries *= element_size; // Counting bytes, but treat them as elements. |
| constexpr size_t hex_chars_per_byte = 2; |
| constexpr size_t max_bytes = sizeof(int64_t); |
| if (element_size != 1 && element_size != 2 && element_size != 4 && |
| element_size != 8) { |
| return InvalidArgumentErrorBuilder(IREE_LOC) |
| << "Invalid element size '" << element_size |
| << "'; only '1', '2', '4' and '8' are supported"; |
| } |
| if (contents.size() % element_size != 0) { |
| return InvalidArgumentErrorBuilder(IREE_LOC) |
| << "Element size '" << element_size |
| << "' doesn't divide contents size " << contents.size(); |
| } |
| // Plus one char for the null terminator. |
| char hex_buffer[hex_chars_per_byte * max_bytes + 1] = {0}; |
| for (size_t i = 0; i < std::min(max_entries, contents.size()); |
| i += element_size) { |
| if (i > 0) *stream << " "; |
| BytesToHexString(contents.data() + i, hex_buffer, element_size); |
| *stream << hex_buffer; |
| } |
| if (contents.size() > max_entries) *stream << "..."; |
| return OkStatus(); |
| } |
| |
| StatusOr<std::string> PrintNumericalDataToString( |
| const Shape& shape, absl::string_view type_str, |
| absl::Span<const uint8_t> contents, size_t max_entries) { |
| std::string result; |
| RETURN_IF_ERROR(PrintNumericalDataToString(shape, type_str, contents, |
| max_entries, &result)); |
| return result; |
| } |
| |
| Status PrintNumericalDataToString(const Shape& shape, |
| absl::string_view type_str, |
| absl::Span<const uint8_t> contents, |
| size_t max_entries, std::string* out_result) { |
| std::ostringstream stream; |
| RETURN_IF_ERROR(PrintNumericalDataToStream(shape, type_str, contents, |
| max_entries, &stream)); |
| *out_result = stream.str(); |
| return OkStatus(); |
| } |
| |
| // Prints numerical data (ints, floats, etc) from some typed form. |
| Status PrintNumericalDataToStream(const Shape& shape, |
| absl::string_view type_str, |
| absl::Span<const uint8_t> contents, |
| size_t max_entries, std::ostream* stream) { |
| if (type_str == "i8") { |
| return PrintNumericalDataToStreamAsType<int8_t>(shape, contents, |
| max_entries, stream); |
| } else if (type_str == "u8") { |
| return PrintNumericalDataToStreamAsType<uint8_t>(shape, contents, |
| max_entries, stream); |
| } else if (type_str == "i16") { |
| return PrintNumericalDataToStreamAsType<int16_t>(shape, contents, |
| max_entries, stream); |
| } else if (type_str == "u16") { |
| return PrintNumericalDataToStreamAsType<uint16_t>(shape, contents, |
| max_entries, stream); |
| } else if (type_str == "i32") { |
| return PrintNumericalDataToStreamAsType<int32_t>(shape, contents, |
| max_entries, stream); |
| } else if (type_str == "u32") { |
| return PrintNumericalDataToStreamAsType<uint32_t>(shape, contents, |
| max_entries, stream); |
| } else if (type_str == "i64") { |
| return PrintNumericalDataToStreamAsType<int64_t>(shape, contents, |
| max_entries, stream); |
| } else if (type_str == "u64") { |
| return PrintNumericalDataToStreamAsType<uint64_t>(shape, contents, |
| max_entries, stream); |
| } else if (type_str == "f32") { |
| return PrintNumericalDataToStreamAsType<float>(shape, contents, max_entries, |
| stream); |
| } else if (type_str == "f64") { |
| return PrintNumericalDataToStreamAsType<double>(shape, contents, |
| max_entries, stream); |
| } else { |
| return InvalidArgumentErrorBuilder(IREE_LOC) |
| << "Unsupported type: " << type_str; |
| } |
| } |
| |
| Status ParseBufferDataAsType(absl::string_view data_str, |
| absl::string_view type_str, |
| absl::Span<uint8_t> output) { |
| // Parse the data from the string right into the buffer. |
| if (IsBinaryType(type_str)) { |
| // Parse as binary hex. |
| return ParseBinaryData(data_str, output); |
| } |
| // Parse as some nicely formatted type. |
| return ParseNumericalData(type_str, data_str, output); |
| } |
| |
| // static |
| BufferStringParts BufferStringParts::ExtractFrom( |
| absl::string_view shaped_buf_str) { |
| BufferStringParts parts; |
| absl::string_view shape_and_type_str; |
| auto equal_index = shaped_buf_str.find('='); |
| if (equal_index == std::string::npos) { |
| // Treat a lack of = as defaulting the data to zeros. |
| shape_and_type_str = shaped_buf_str; |
| } else { |
| shape_and_type_str = shaped_buf_str.substr(0, equal_index); |
| parts.data_str = shaped_buf_str.substr(equal_index + 1); |
| } |
| auto last_x_index = shape_and_type_str.rfind('x'); |
| if (last_x_index == std::string::npos) { |
| // Scalar. |
| parts.type_str = shape_and_type_str; |
| } else { |
| // Has a shape. |
| parts.shape_str = shape_and_type_str.substr(0, last_x_index); |
| parts.type_str = shape_and_type_str.substr(last_x_index + 1); |
| } |
| return parts; |
| } |
| |
| } // namespace iree |