| // Copyright 2019 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // https://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include "third_party/mlir_edge/iree/hal/buffer_view_string_util.h" |
| |
| #include <functional> |
| #include <sstream> |
| #include <type_traits> |
| |
| #include "third_party/absl/strings/ascii.h" |
| #include "third_party/absl/strings/escaping.h" |
| #include "third_party/absl/strings/numbers.h" |
| #include "third_party/absl/strings/str_join.h" |
| #include "third_party/absl/strings/str_split.h" |
| #include "third_party/absl/strings/strip.h" |
| #include "third_party/absl/types/optional.h" |
| #include "third_party/absl/types/source_location.h" |
| #include "third_party/mlir_edge/iree/base/status.h" |
| #include "third_party/mlir_edge/iree/hal/heap_buffer.h" |
| |
| namespace iree { |
| namespace hal { |
| |
| namespace { |
| |
| // Returns true if the given type is represented as binary hex data. |
| bool IsBinaryType(absl::string_view type_str) { |
| return !type_str.empty() && absl::ascii_isdigit(type_str[0]); |
| } |
| |
| // Parses binary hex data. |
| Status ParseBinaryData(absl::string_view data_str, Buffer* buffer) { |
| data_str = absl::StripAsciiWhitespace(data_str); |
| ASSIGN_OR_RETURN(auto mapping, |
| buffer->MapMemory<uint8_t>(MemoryAccess::kDiscardWrite)); |
| auto contents = mapping.mutable_contents(); |
| size_t dst_i = 0; |
| size_t src_i = 0; |
| while (src_i < data_str.size() && dst_i < contents.size()) { |
| char c = data_str[src_i]; |
| if (absl::ascii_isspace(c) || c == ',') { |
| ++src_i; |
| continue; |
| } |
| if (src_i + 1 >= data_str.size()) { |
| return InvalidArgumentErrorBuilder(ABSL_LOC) |
| << "Invalid input hex data (offset=" << src_i << ")"; |
| } |
| strings::a2b_hex(data_str.data() + src_i, contents.data() + dst_i, 1); |
| src_i += 2; |
| ++dst_i; |
| } |
| if (dst_i < contents.size()) { |
| return InvalidArgumentErrorBuilder(ABSL_LOC) |
| << "Too few elements to fill type; expected " << contents.size() |
| << " but only read " << dst_i; |
| } else if (data_str.size() - src_i > 0) { |
| return InvalidArgumentErrorBuilder(ABSL_LOC) |
| << "Input data string contains more elements than the underlying " |
| "buffer (" |
| << contents.size() << ")"; |
| } |
| return OkStatus(); |
| } |
| |
| // Prints binary hex data. |
| Status PrintBinaryData(int element_size, Buffer* buffer, size_t max_entries, |
| std::ostream* stream) { |
| max_entries *= element_size; // Counting bytes, but treat them as elements. |
| ASSIGN_OR_RETURN(auto mapping, |
| buffer->MapMemory<uint8_t>(MemoryAccess::kRead)); |
| auto contents = mapping.contents(); |
| char hex_buffer[8 * 2]; |
| for (size_t i = 0; i < std::min(max_entries, mapping.size()); |
| i += element_size) { |
| if (i > 0) *stream << " "; |
| strings::b2a_hex(contents.data() + i, hex_buffer, element_size); |
| *stream << hex_buffer; |
| } |
| if (mapping.size() > max_entries) *stream << "..."; |
| return OkStatus(); |
| } |
| |
| template <typename ElementType, typename Enabled = void> |
| struct SimpleStrToValue { |
| absl::optional<ElementType> operator()(absl::string_view text) const = delete; |
| }; |
| |
| template <typename IntegerType> |
| struct SimpleStrToValue< |
| IntegerType, |
| typename std::enable_if<(sizeof(IntegerType) < 4), void>::type> { |
| absl::optional<IntegerType> operator()(absl::string_view text) const { |
| int32_t value; |
| return absl::SimpleAtoi(text, &value) ? absl::optional<IntegerType>{value} |
| : absl::nullopt; |
| } |
| }; |
| |
| template <typename IntegerType> |
| struct SimpleStrToValue< |
| IntegerType, |
| typename std::enable_if<(sizeof(IntegerType) >= 4), void>::type> { |
| absl::optional<IntegerType> operator()(absl::string_view text) const { |
| IntegerType value; |
| return absl::SimpleAtoi(text, &value) ? absl::optional<IntegerType>{value} |
| : absl::nullopt; |
| } |
| }; |
| |
| template <> |
| struct SimpleStrToValue<float, void> { |
| absl::optional<float> operator()(absl::string_view text) const { |
| float value; |
| return absl::SimpleAtof(text, &value) ? absl::optional<float>{value} |
| : absl::nullopt; |
| } |
| }; |
| |
| template <> |
| struct SimpleStrToValue<double, void> { |
| absl::optional<double> operator()(absl::string_view text) const { |
| double value; |
| return absl::SimpleAtod(text, &value) ? absl::optional<double>{value} |
| : absl::nullopt; |
| } |
| }; |
| |
| template <typename T> |
| Status ParseNumericalDataElement(absl::string_view data_str, size_t token_start, |
| size_t token_end, absl::Span<T> contents, |
| int dst_i) { |
| if (dst_i >= contents.size()) { |
| return InvalidArgumentErrorBuilder(ABSL_LOC) |
| << "Input data string contains more elements than the underlying " |
| "buffer (" |
| << contents.size() << ")"; |
| } |
| auto element_str = data_str.substr(token_start, token_end - token_start + 1); |
| auto element = SimpleStrToValue<T>()(element_str); |
| if (!element.has_value()) { |
| return InvalidArgumentErrorBuilder(ABSL_LOC) |
| << "Unable to parse element " << dst_i << " = '" << element_str |
| << "'"; |
| } |
| contents[dst_i] = element.value(); |
| return OkStatus(); |
| } |
| |
| template <typename T> |
| Status ParseNumericalDataAsType(absl::string_view data_str, Buffer* buffer) { |
| ASSIGN_OR_RETURN(auto mapping, |
| buffer->MapMemory<T>(MemoryAccess::kDiscardWrite)); |
| auto contents = mapping.mutable_contents(); |
| size_t src_i = 0; |
| size_t dst_i = 0; |
| size_t token_start = std::string::npos; |
| while (src_i < data_str.size()) { |
| char c = data_str[src_i++]; |
| bool is_separator = |
| absl::ascii_isspace(c) || c == ',' || c == '[' || c == ']'; |
| if (token_start == std::string::npos) { |
| if (!is_separator) { |
| token_start = src_i - 1; |
| } |
| continue; |
| } else if (token_start != std::string::npos && !is_separator) { |
| continue; |
| } |
| RETURN_IF_ERROR(ParseNumericalDataElement<T>(data_str, token_start, |
| src_i - 2, contents, dst_i++)); |
| token_start = std::string::npos; |
| } |
| if (token_start != std::string::npos) { |
| RETURN_IF_ERROR(ParseNumericalDataElement<T>( |
| data_str, token_start, data_str.size() - 1, contents, dst_i++)); |
| } |
| if (dst_i < contents.size()) { |
| return InvalidArgumentErrorBuilder(ABSL_LOC) |
| << "Input data string contains fewer elements than the underlying " |
| "buffer (expected " |
| << contents.size() << ")"; |
| } |
| return OkStatus(); |
| } |
| |
| // Parses numerical data (ints, floats, etc) in some typed form. |
| Status ParseNumericalData(absl::string_view type_str, |
| absl::string_view data_str, Buffer* buffer) { |
| if (type_str == "i8") { |
| return ParseNumericalDataAsType<int8_t>(data_str, buffer); |
| } else if (type_str == "u8") { |
| return ParseNumericalDataAsType<uint8_t>(data_str, buffer); |
| } else if (type_str == "i16") { |
| return ParseNumericalDataAsType<int16_t>(data_str, buffer); |
| } else if (type_str == "u16") { |
| return ParseNumericalDataAsType<uint16_t>(data_str, buffer); |
| } else if (type_str == "i32") { |
| return ParseNumericalDataAsType<int32_t>(data_str, buffer); |
| } else if (type_str == "u32") { |
| return ParseNumericalDataAsType<uint32_t>(data_str, buffer); |
| } else if (type_str == "i64") { |
| return ParseNumericalDataAsType<int64_t>(data_str, buffer); |
| } else if (type_str == "u64") { |
| return ParseNumericalDataAsType<uint64_t>(data_str, buffer); |
| } else if (type_str == "f32") { |
| return ParseNumericalDataAsType<float>(data_str, buffer); |
| } else if (type_str == "f64") { |
| return ParseNumericalDataAsType<double>(data_str, buffer); |
| } else { |
| return InvalidArgumentErrorBuilder(ABSL_LOC) |
| << "Unsupported type: " << type_str; |
| } |
| } |
| |
| template <typename T> |
| void PrintElementList(const Shape& shape, absl::Span<const T> data, |
| size_t* max_entries, std::ostream* stream) { |
| if (shape.empty()) { |
| // Scalar value. |
| PrintElementList({1}, data, max_entries, stream); |
| return; |
| } else if (shape.size() == 1) { |
| // Leaf dimension; output data. |
| size_t max_count = std::min(*max_entries, static_cast<size_t>(shape[0])); |
| *stream << absl::StrJoin(data.subspan(0, max_count), " "); |
| if (max_count < shape[0]) { |
| *stream << "..."; |
| } |
| *max_entries -= max_count; |
| } else { |
| // Nested; recurse into next dimension. |
| Shape nested_shape = Shape(shape.subspan(1)); |
| size_t length = nested_shape.element_count(); |
| size_t offset = 0; |
| for (int i = 0; i < shape[0]; ++i) { |
| *stream << "["; |
| PrintElementList<T>(nested_shape, data.subspan(offset, length), |
| max_entries, stream); |
| offset += length; |
| *stream << "]"; |
| } |
| } |
| } |
| |
| template <typename T> |
| Status PrintNumericalDataAsType(const Shape& shape, Buffer* buffer, |
| size_t max_entries, std::ostream* stream) { |
| ASSIGN_OR_RETURN(auto mapping, buffer->MapMemory<T>(MemoryAccess::kRead)); |
| PrintElementList(shape, mapping.contents(), &max_entries, stream); |
| return OkStatus(); |
| } |
| |
| // Prints numerical data (ints, floats, etc) from some typed form. |
| Status PrintNumericalData(const Shape& shape, absl::string_view type_str, |
| Buffer* buffer, size_t max_entries, |
| std::ostream* stream) { |
| if (type_str == "i8") { |
| return PrintNumericalDataAsType<int8_t>(shape, buffer, max_entries, stream); |
| } else if (type_str == "u8") { |
| return PrintNumericalDataAsType<uint8_t>(shape, buffer, max_entries, |
| stream); |
| } else if (type_str == "i16") { |
| return PrintNumericalDataAsType<int16_t>(shape, buffer, max_entries, |
| stream); |
| } else if (type_str == "u16") { |
| return PrintNumericalDataAsType<uint16_t>(shape, buffer, max_entries, |
| stream); |
| } else if (type_str == "i32") { |
| return PrintNumericalDataAsType<int32_t>(shape, buffer, max_entries, |
| stream); |
| } else if (type_str == "u32") { |
| return PrintNumericalDataAsType<uint32_t>(shape, buffer, max_entries, |
| stream); |
| } else if (type_str == "i64") { |
| return PrintNumericalDataAsType<int64_t>(shape, buffer, max_entries, |
| stream); |
| } else if (type_str == "u64") { |
| return PrintNumericalDataAsType<uint64_t>(shape, buffer, max_entries, |
| stream); |
| } else if (type_str == "f32") { |
| return PrintNumericalDataAsType<float>(shape, buffer, max_entries, stream); |
| } else if (type_str == "f64") { |
| return PrintNumericalDataAsType<double>(shape, buffer, max_entries, stream); |
| } else { |
| return InvalidArgumentErrorBuilder(ABSL_LOC) |
| << "Unsupported type: " << type_str; |
| } |
| } |
| |
| } // namespace |
| |
| StatusOr<int> GetTypeElementSize(absl::string_view type_str) { |
| if (type_str.empty()) { |
| return InvalidArgumentErrorBuilder(ABSL_LOC) << "Type is empty"; |
| } else if (IsBinaryType(type_str)) { |
| // If the first character is a digit then we are dealign with binary data. |
| // The type is just the number of bytes per element. |
| int element_size = 0; |
| if (!absl::SimpleAtoi(type_str, &element_size)) { |
| return InvalidArgumentErrorBuilder(ABSL_LOC) |
| << "Unable to parse element size type '" << type_str << "'"; |
| } |
| return element_size; |
| } |
| // We know that our types are single characters followed by bit counts. |
| // If we start to support other types we may need to do something more clever. |
| int bit_count = 0; |
| if (!absl::SimpleAtoi(type_str.substr(1), &bit_count)) { |
| return InvalidArgumentErrorBuilder(ABSL_LOC) |
| << "Unable to parse type bit count from '" << type_str |
| << "'; expecting something like 'i32'"; |
| } |
| return bit_count / 8; |
| } |
| |
| StatusOr<Shape> ParseShape(absl::string_view shape_str) { |
| std::vector<int> dims; |
| for (auto dim_str : absl::StrSplit(shape_str, 'x', absl::SkipWhitespace())) { |
| int dim_value = 0; |
| if (!absl::SimpleAtoi(dim_str, &dim_value)) { |
| return InvalidArgumentErrorBuilder(ABSL_LOC) |
| << "Invalid shape dimension '" << dim_str |
| << "' while parsing shape '" << shape_str << "'"; |
| } |
| dims.push_back(dim_value); |
| } |
| return Shape{dims}; |
| } |
| |
| StatusOr<BufferView> ParseBufferViewFromString( |
| absl::string_view buffer_view_str, hal::Allocator* allocator) { |
| // Strip whitespace that may come along (linefeeds/etc). |
| buffer_view_str = absl::StripAsciiWhitespace(buffer_view_str); |
| if (buffer_view_str.empty()) { |
| // Empty lines denote empty buffer_views. |
| return BufferView{}; |
| } |
| |
| // Split into the components we can work with: shape, type, and data. |
| absl::string_view shape_and_type_str; |
| absl::string_view data_str; |
| auto equal_index = buffer_view_str.find('='); |
| if (equal_index == std::string::npos) { |
| // Treat a lack of = as defaulting the data to zeros. |
| shape_and_type_str = buffer_view_str; |
| } else { |
| shape_and_type_str = buffer_view_str.substr(0, equal_index); |
| data_str = buffer_view_str.substr(equal_index + 1); |
| } |
| absl::string_view shape_str; |
| absl::string_view type_str; |
| auto last_x_index = shape_and_type_str.rfind('x'); |
| if (last_x_index == std::string::npos) { |
| // Scalar. |
| type_str = shape_and_type_str; |
| } else { |
| // Has a shape. |
| shape_str = shape_and_type_str.substr(0, last_x_index); |
| type_str = shape_and_type_str.substr(last_x_index + 1); |
| } |
| |
| // Populate BufferView metadata required for allocation. |
| BufferView result; |
| ASSIGN_OR_RETURN(result.element_size, GetTypeElementSize(type_str)); |
| ASSIGN_OR_RETURN(result.shape, ParseShape(shape_str)); |
| |
| // Allocate the host buffer. |
| size_t allocation_size = result.shape.element_count() * result.element_size; |
| if (allocator) { |
| ASSIGN_OR_RETURN( |
| result.buffer, |
| allocator->Allocate(MemoryType::kHostLocal | MemoryType::kDeviceVisible, |
| BufferUsage::kAll | BufferUsage::kConstant, |
| allocation_size)); |
| } else { |
| result.buffer = HeapBuffer::Allocate( |
| MemoryType::kHostLocal, BufferUsage::kAll | BufferUsage::kConstant, |
| allocation_size); |
| } |
| |
| if (!data_str.empty()) { |
| // Parse the data from the string right into the buffer. |
| if (IsBinaryType(type_str)) { |
| // Parse as binary hex. |
| RETURN_IF_ERROR(ParseBinaryData(data_str, result.buffer.get())); |
| } else { |
| // Parse as some nicely formatted type. |
| RETURN_IF_ERROR( |
| ParseNumericalData(type_str, data_str, result.buffer.get())); |
| } |
| } |
| |
| return result; |
| } |
| |
| StatusOr<BufferViewPrintMode> ParseBufferViewPrintMode(absl::string_view str) { |
| char str_char = str.empty() ? '?' : str[0]; |
| switch (str_char) { |
| case 'b': |
| return BufferViewPrintMode::kBinary; |
| case 'i': |
| return BufferViewPrintMode::kSignedInteger; |
| case 'u': |
| return BufferViewPrintMode::kUnsignedInteger; |
| case 'f': |
| return BufferViewPrintMode::kFloatingPoint; |
| default: |
| return InvalidArgumentErrorBuilder(ABSL_LOC) |
| << "Unsupported output type '" << str << "'"; |
| } |
| } |
| |
| StatusOr<std::string> PrintBufferViewToString(const BufferView& buffer_view, |
| BufferViewPrintMode print_mode, |
| size_t max_entries) { |
| std::string result; |
| RETURN_IF_ERROR( |
| PrintBufferViewToString(buffer_view, print_mode, max_entries, &result)); |
| return result; |
| } |
| |
| Status PrintBufferViewToString(const BufferView& buffer_view, |
| BufferViewPrintMode print_mode, |
| size_t max_entries, std::string* out_result) { |
| std::ostringstream stream; |
| RETURN_IF_ERROR( |
| PrintBufferViewToStream(buffer_view, print_mode, max_entries, &stream)); |
| *out_result = stream.str(); |
| return OkStatus(); |
| } |
| |
| Status PrintBufferViewToStream(const BufferView& buffer_view, |
| BufferViewPrintMode print_mode, |
| size_t max_entries, std::ostream* stream) { |
| if (!buffer_view.buffer) { |
| // No buffer means the buffer_view is empty. We use the empty string to |
| // denote this (as we have no useful information). |
| return OkStatus(); |
| } |
| |
| // Pick a type based on the element size and the printing mode. |
| std::string type_str; |
| switch (print_mode) { |
| case BufferViewPrintMode::kBinary: |
| type_str = std::to_string(buffer_view.element_size); |
| break; |
| case BufferViewPrintMode::kSignedInteger: |
| absl::StrAppend(&type_str, "i", buffer_view.element_size * 8); |
| break; |
| case BufferViewPrintMode::kUnsignedInteger: |
| absl::StrAppend(&type_str, "u", buffer_view.element_size * 8); |
| break; |
| case BufferViewPrintMode::kFloatingPoint: |
| absl::StrAppend(&type_str, "f", buffer_view.element_size * 8); |
| break; |
| } |
| |
| // [shape]x[type]= prefix (taking into account scalar values). |
| *stream << absl::StrJoin(buffer_view.shape.begin(), buffer_view.shape.end(), |
| "x"); |
| if (!buffer_view.shape.empty()) *stream << "x"; |
| *stream << type_str; |
| *stream << "="; |
| |
| if (IsBinaryType(type_str)) { |
| return PrintBinaryData(buffer_view.element_size, buffer_view.buffer.get(), |
| max_entries, stream); |
| } else { |
| return PrintNumericalData(buffer_view.shape, type_str, |
| buffer_view.buffer.get(), max_entries, stream); |
| } |
| } |
| |
| } // namespace hal |
| } // namespace iree |