Cleaning up unused Dawn HAL backend and legacy shape/shaped buffer printing/parsing
diff --git a/.bazelignore b/.bazelignore index 0cdaee2..743cde4 100644 --- a/.bazelignore +++ b/.bazelignore
@@ -22,10 +22,6 @@ third_party build_tools/bazel/third_party_import/llvm-project/overlay -# TODO(scotttodd): enable when Dawn HAL implementation is functional -iree/hal/dawn -iree/tools/web - # TODO: enable this when Java bindings are wired up. bindings/java bindings/javatests
diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py b/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py index 112bacc..34e09a8 100644 --- a/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py +++ b/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py
@@ -43,8 +43,6 @@ self.FLATBUFFER_SUPPORTS_REFLECTIONS = False self.PLATFORM_VULKAN_LOADER_COPTS = [] self.IREE_DRIVER_MODULES = [ - # TODO(b/142004903): enable when Dawn HAL implementation is functional - # "//iree/hal/dawn:dawn_driver_module", "//iree/hal/vmla:vmla_driver_module", "//iree/hal/vulkan:vulkan_driver_module", "//iree/hal/llvmjit:llvmjit_driver_module",
diff --git a/iree/base/BUILD b/iree/base/BUILD index 375fab2..b8c81ed 100644 --- a/iree/base/BUILD +++ b/iree/base/BUILD
@@ -112,35 +112,6 @@ ) cc_library( - name = "buffer_string_util", - srcs = ["buffer_string_util.cc"], - hdrs = ["buffer_string_util.h"], - deps = [ - ":memory", - ":shape", - ":source_location", - ":status", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", - "@com_google_absl//absl/types:span", - ], -) - -cc_test( - name = "buffer_string_util_test", - srcs = ["buffer_string_util_test.cc"], - deps = [ - ":buffer_string_util", - ":memory", - ":status", - ":status_matchers", - "//iree/testing:gtest_main", - "@com_google_absl//absl/strings", - ], -) - -cc_library( name = "dynamic_library", srcs = [ "dynamic_library_posix.cc", @@ -398,76 +369,6 @@ ) cc_library( - name = "shape", - srcs = ["shape.cc"], - hdrs = ["shape.h"], - deps = [ - ":logging", - ":source_location", - ":status", - "@com_google_absl//absl/meta:type_traits", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - ], -) - -cc_test( - name = "shape_test", - srcs = ["shape_test.cc"], - deps = [ - ":shape", - ":status", - ":status_matchers", - "//iree/testing:gtest_main", - ], -) - -cc_library( - name = "shaped_buffer", - hdrs = ["shaped_buffer.h"], - deps = [ - ":logging", - ":memory", - ":shape", - "@com_google_absl//absl/container:fixed_array", - "@com_google_absl//absl/types:span", - ], -) - -cc_library( - name = "shaped_buffer_string_util", - srcs = ["shaped_buffer_string_util.cc"], - hdrs = ["shaped_buffer_string_util.h"], - deps = [ - ":buffer_string_util", - ":memory", - ":shape", - ":shaped_buffer", - ":source_location", - ":status", - "@com_google_absl//absl/container:fixed_array", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", - "@com_google_absl//absl/types:span", - ], -) - -cc_test( - name = "shaped_buffer_string_util_test", - srcs = ["shaped_buffer_string_util_test.cc"], - deps = [ - ":buffer_string_util", - ":memory", - ":shaped_buffer", - 
":shaped_buffer_string_util", - ":status", - ":status_matchers", - "//iree/testing:gtest_main", - "@com_google_absl//absl/strings", - ], -) - -cc_library( name = "signature_mangle", srcs = ["signature_mangle.cc"], hdrs = ["signature_mangle.h"],
diff --git a/iree/base/CMakeLists.txt b/iree/base/CMakeLists.txt index 64e0b5e..3a6ca54 100644 --- a/iree/base/CMakeLists.txt +++ b/iree/base/CMakeLists.txt
@@ -120,39 +120,6 @@ iree::testing::gtest_main ) -iree_cc_library( - NAME - buffer_string_util - HDRS - "buffer_string_util.h" - SRCS - "buffer_string_util.cc" - DEPS - ::memory - ::shape - ::source_location - ::status - absl::flat_hash_map - absl::optional - absl::span - absl::strings - PUBLIC -) - -iree_cc_test( - NAME - buffer_string_util_test - SRCS - "buffer_string_util_test.cc" - DEPS - ::buffer_string_util - ::memory - ::status - ::status_matchers - absl::strings - iree::testing::gtest_main -) - iree_select_compiler_opts(_DYNAMIC_LIBRARY_LINKOPTS CLANG_OR_GCC "-ldl" @@ -470,86 +437,6 @@ iree_cc_library( NAME - shape - HDRS - "shape.h" - SRCS - "shape.cc" - DEPS - ::logging - ::source_location - ::status - absl::span - absl::strings - absl::type_traits - PUBLIC -) - -iree_cc_test( - NAME - shape_test - SRCS - "shape_test.cc" - DEPS - ::shape - ::status - ::status_matchers - iree::testing::gtest_main -) - -iree_cc_library( - NAME - shaped_buffer - HDRS - "shaped_buffer.h" - DEPS - ::logging - ::memory - ::shape - absl::fixed_array - absl::span - PUBLIC -) - -iree_cc_library( - NAME - shaped_buffer_string_util - HDRS - "shaped_buffer_string_util.h" - SRCS - "shaped_buffer_string_util.cc" - DEPS - ::buffer_string_util - ::memory - ::shape - ::shaped_buffer - ::source_location - ::status - absl::fixed_array - absl::optional - absl::span - absl::strings - PUBLIC -) - -iree_cc_test( - NAME - shaped_buffer_string_util_test - SRCS - "shaped_buffer_string_util_test.cc" - DEPS - ::buffer_string_util - ::memory - ::shaped_buffer - ::shaped_buffer_string_util - ::status - ::status_matchers - absl::strings - iree::testing::gtest_main -) - -iree_cc_library( - NAME signature_mangle HDRS "signature_mangle.h"
diff --git a/iree/base/buffer_string_util.cc b/iree/base/buffer_string_util.cc deleted file mode 100644 index 5317349..0000000 --- a/iree/base/buffer_string_util.cc +++ /dev/null
@@ -1,563 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/base/buffer_string_util.h" - -#include <functional> -#include <sstream> -#include <string> -#include <type_traits> - -#include "absl/container/flat_hash_map.h" -#include "absl/strings/ascii.h" -#include "absl/strings/numbers.h" -#include "absl/strings/str_join.h" -#include "absl/strings/str_split.h" -#include "absl/strings/string_view.h" -#include "absl/strings/strip.h" -#include "absl/types/optional.h" -#include "iree/base/memory.h" -#include "iree/base/source_location.h" -#include "iree/base/status.h" - -namespace iree { - -namespace { - -/* clang-format off */ -constexpr char kHexValue[256] = { - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0, 0, 0, 0, 0, // '0'..'9' - 0, 10, 11, 12, 13, 14, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 'A'..'F' - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 10, 11, 12, 13, 14, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 'a'..'f' - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 -}; -/* clang-format on */ - -template <typename T> -void HexStringToBytes(const char* from, T to, ptrdiff_t num) { - for (int i = 0; i < num; i++) { - to[i] = (kHexValue[from[i * 2] & 0xFF] << 4) + - (kHexValue[from[i * 2 + 1] & 0xFF]); - } -} - -constexpr char kHexTable[513] = - "000102030405060708090a0b0c0d0e0f" - "101112131415161718191a1b1c1d1e1f" - "202122232425262728292a2b2c2d2e2f" - "303132333435363738393a3b3c3d3e3f" - "404142434445464748494a4b4c4d4e4f" - "505152535455565758595a5b5c5d5e5f" - "606162636465666768696a6b6c6d6e6f" - "707172737475767778797a7b7c7d7e7f" - "808182838485868788898a8b8c8d8e8f" - "909192939495969798999a9b9c9d9e9f" - "a0a1a2a3a4a5a6a7a8a9aaabacadaeaf" - "b0b1b2b3b4b5b6b7b8b9babbbcbdbebf" - "c0c1c2c3c4c5c6c7c8c9cacbcccdcecf" - "d0d1d2d3d4d5d6d7d8d9dadbdcdddedf" - "e0e1e2e3e4e5e6e7e8e9eaebecedeeef" - "f0f1f2f3f4f5f6f7f8f9fafbfcfdfeff"; - -// Like the absl method, but works in-place. -template <typename T> -void BytesToHexString(const unsigned char* src, T dest, ptrdiff_t num) { - auto dest_ptr = &dest[0]; - for (auto src_ptr = src; src_ptr != (src + num); ++src_ptr, dest_ptr += 2) { - const char* hex_p = &kHexTable[*src_ptr * 2]; - std::copy(hex_p, hex_p + 2, dest_ptr); - } -} - -// Returns true if the given type is represented as binary hex data. -bool IsBinaryType(absl::string_view type_str) { - return !type_str.empty() && absl::ascii_isdigit(type_str[0]); -} - -// Parses binary hex data. 
-Status ParseBinaryData(absl::string_view data_str, absl::Span<uint8_t> output) { - data_str = absl::StripAsciiWhitespace(data_str); - size_t dst_i = 0; - size_t src_i = 0; - while (src_i < data_str.size() && dst_i < output.size()) { - char c = data_str[src_i]; - if (absl::ascii_isspace(c) || c == ',') { - ++src_i; - continue; - } - if (src_i + 1 >= data_str.size()) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Invalid input hex data (offset=" << src_i << ")"; - } - HexStringToBytes(data_str.data() + src_i, output.data() + dst_i, 1); - src_i += 2; - ++dst_i; - } - if (dst_i < output.size()) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Too few elements to fill type; expected " << output.size() - << " but only read " << dst_i; - } else if (data_str.size() - src_i > 0) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Input data string contains more elements than the underlying " - "buffer (" - << output.size() << "): " << data_str; - } - return OkStatus(); -} - -template <typename ElementType, typename Enabled = void> -struct SimpleStrToValue { - absl::optional<ElementType> operator()(absl::string_view text) const = delete; -}; - -template <typename IntegerType> -struct SimpleStrToValue< - IntegerType, - typename std::enable_if<(sizeof(IntegerType) < 4), void>::type> { - absl::optional<IntegerType> operator()(absl::string_view text) const { - int32_t value; - return absl::SimpleAtoi(text, &value) ? absl::optional<IntegerType>{value} - : absl::nullopt; - } -}; - -template <typename IntegerType> -struct SimpleStrToValue< - IntegerType, - typename std::enable_if<(sizeof(IntegerType) >= 4), void>::type> { - absl::optional<IntegerType> operator()(absl::string_view text) const { - IntegerType value; - return absl::SimpleAtoi(text, &value) ? 
absl::optional<IntegerType>{value} - : absl::nullopt; - } -}; - -template <> -struct SimpleStrToValue<float, void> { - absl::optional<float> operator()(absl::string_view text) const { - float value; - return absl::SimpleAtof(text, &value) ? absl::optional<float>{value} - : absl::nullopt; - } -}; - -template <> -struct SimpleStrToValue<double, void> { - absl::optional<double> operator()(absl::string_view text) const { - double value; - return absl::SimpleAtod(text, &value) ? absl::optional<double>{value} - : absl::nullopt; - } -}; - -template <typename T> -Status ParseNumericalDataElement(absl::string_view data_str, size_t token_start, - size_t token_end, absl::Span<T> output, - int dst_i) { - if (dst_i >= output.size()) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Input data string contains more elements than the underlying " - "buffer (" - << output.size() << "): " << data_str; - } - auto element_str = data_str.substr(token_start, token_end - token_start + 1); - auto element = SimpleStrToValue<T>()(element_str); - if (!element.has_value()) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Unable to parse element " << dst_i << " = '" << element_str - << "'"; - } - output[dst_i] = element.value(); - return OkStatus(); -} - -template <typename T> -Status ParseNumericalDataAsType(absl::string_view data_str, - absl::Span<uint8_t> output) { - auto cast_output = ReinterpretSpan<T>(output); - size_t src_i = 0; - size_t dst_i = 0; - size_t token_start = std::string::npos; - while (src_i < data_str.size()) { - char c = data_str[src_i++]; - bool is_separator = - absl::ascii_isspace(c) || c == ',' || c == '[' || c == ']'; - if (token_start == std::string::npos) { - if (!is_separator) { - token_start = src_i - 1; - } - continue; - } else if (token_start != std::string::npos && !is_separator) { - continue; - } - RETURN_IF_ERROR(ParseNumericalDataElement<T>( - data_str, token_start, src_i - 2, cast_output, dst_i++)); - token_start = std::string::npos; - } - if 
(token_start != std::string::npos) { - RETURN_IF_ERROR(ParseNumericalDataElement<T>( - data_str, token_start, data_str.size() - 1, cast_output, dst_i++)); - } - if (dst_i == 1 && cast_output.size() > 1) { - // Splat the single value we got to the entire tensor. - for (int i = 1; i < cast_output.size(); ++i) { - cast_output[i] = cast_output[0]; - } - } else if (dst_i < cast_output.size()) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Input data string contains fewer elements than the underlying " - "buffer (expected " - << cast_output.size() << ") " << data_str; - } - return OkStatus(); -} - -// Parses numerical data (ints, floats, etc) in some typed form. -Status ParseNumericalData(absl::string_view type_str, - absl::string_view data_str, - absl::Span<uint8_t> output) { - if (type_str == "i8") { - return ParseNumericalDataAsType<int8_t>(data_str, output); - } else if (type_str == "u8") { - return ParseNumericalDataAsType<uint8_t>(data_str, output); - } else if (type_str == "i16") { - return ParseNumericalDataAsType<int16_t>(data_str, output); - } else if (type_str == "u16") { - return ParseNumericalDataAsType<uint16_t>(data_str, output); - } else if (type_str == "i32") { - return ParseNumericalDataAsType<int32_t>(data_str, output); - } else if (type_str == "u32") { - return ParseNumericalDataAsType<uint32_t>(data_str, output); - } else if (type_str == "i64") { - return ParseNumericalDataAsType<int64_t>(data_str, output); - } else if (type_str == "u64") { - return ParseNumericalDataAsType<uint64_t>(data_str, output); - } else if (type_str == "f32") { - return ParseNumericalDataAsType<float>(data_str, output); - } else if (type_str == "f64") { - return ParseNumericalDataAsType<double>(data_str, output); - } else { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Unsupported type: " << type_str; - } -} - -template <typename T> -void PrintElementList(const Shape& shape, absl::Span<const T> data, - size_t* max_entries, std::ostream* stream) { - if 
(shape.empty()) { - // Scalar value. - PrintElementList({1}, data, max_entries, stream); - return; - } else if (shape.size() == 1) { - // Leaf dimension; output data. - size_t max_count = std::min(*max_entries, static_cast<size_t>(shape[0])); - *stream << absl::StrJoin(data.subspan(0, max_count), " "); - if (max_count < shape[0]) { - *stream << "..."; - } - *max_entries -= max_count; - } else { - // Nested; recurse into next dimension. - Shape nested_shape = Shape(shape.subspan(1)); - size_t length = nested_shape.element_count(); - size_t offset = 0; - for (int i = 0; i < shape[0]; ++i) { - *stream << "["; - PrintElementList<T>(nested_shape, data.subspan(offset, length), - max_entries, stream); - offset += length; - *stream << "]"; - } - } -} - -template <typename T> -Status PrintNumericalDataToStreamAsType(const Shape& shape, - absl::Span<const uint8_t> contents, - size_t max_entries, - std::ostream* stream) { - auto cast_contents = ReinterpretSpan<T>(contents); - PrintElementList(shape, cast_contents, &max_entries, stream); - return OkStatus(); -} - -} // namespace - -StatusOr<BufferDataPrintMode> ParseBufferDataPrintMode(absl::string_view str) { - if (str.empty()) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Cannot get print mode of empty string"; - } - switch (str[0]) { - case 'b': - return BufferDataPrintMode::kBinary; - case 'i': - return BufferDataPrintMode::kSignedInteger; - case 'u': - return BufferDataPrintMode::kUnsignedInteger; - case 'f': - return BufferDataPrintMode::kFloatingPoint; - default: - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Unsupported output type '" << str << "'"; - } -} - -StatusOr<int> ParseBufferTypeElementSize(absl::string_view type_str) { - static const auto* const type_sizes = - new absl::flat_hash_map<absl::string_view, int>{ - {"1", 1}, {"2", 2}, {"4", 4}, {"8", 8}, {"i8", 1}, - {"u8", 1}, {"i16", 2}, {"u16", 2}, {"i32", 4}, {"u32", 4}, - {"i64", 8}, {"u64", 8}, {"f32", 4}, {"f64", 8}, - }; - auto type_size = 
type_sizes->find(type_str); - if (type_size != type_sizes->end()) { - return type_size->second; - } else { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Unsupported type '" << type_str << "'"; - } -} - -StatusOr<std::string> MakeBufferTypeString(int element_size, - BufferDataPrintMode print_mode) { - if (element_size <= 0) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Invalid element size '" << element_size << "'"; - } - switch (print_mode) { - case BufferDataPrintMode::kBinary: - return absl::StrCat(element_size); - case BufferDataPrintMode::kSignedInteger: - return absl::StrCat("i", element_size * 8); - case BufferDataPrintMode::kUnsignedInteger: - return absl::StrCat("u", element_size * 8); - case BufferDataPrintMode::kFloatingPoint: - return absl::StrCat("f", element_size * 8); - default: - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Unsupported print mode (" << static_cast<int>(print_mode) - << ")"; - } -} - -StatusOr<Shape> ParseShape(absl::string_view shape_str) { - if (shape_str.empty()) { - return Shape{}; - } - std::vector<int> dims; - for (auto dim_str : absl::StrSplit(shape_str, 'x')) { - int dim_value = 0; - if (!absl::SimpleAtoi(dim_str, &dim_value)) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Invalid shape dimension '" << dim_str - << "' while parsing shape '" << shape_str << "'"; - } - if (dim_value < 0) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Unsupported shape dimension '" << dim_str << "'"; - } - dims.push_back(dim_value); - } - if (dims.size() > kMaxRank) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Shape rank '" << dims.size() << "' exceeds maximum rank '" - << kMaxRank << "'"; - } - return Shape{dims}; -} - -std::string PrintShapedTypeToString(const Shape& shape, - absl::string_view type_str) { - std::string result; - PrintShapedTypeToString(shape, type_str, &result); - return result; -} - -void PrintShapedTypeToString(const Shape& shape, absl::string_view type_str, - std::string* 
out_result) { - std::ostringstream stream; - PrintShapedTypeToStream(shape, type_str, &stream); - *out_result = stream.str(); -} - -void PrintShapedTypeToStream(const Shape& shape, absl::string_view type_str, - std::ostream* stream) { - *stream << absl::StrJoin(shape.begin(), shape.end(), "x"); - if (!shape.empty()) *stream << "x"; - *stream << type_str; -} - -// Prints binary hex data. -StatusOr<std::string> PrintBinaryDataToString( - int element_size, absl::Span<const uint8_t> contents, size_t max_entries) { - std::string result; - RETURN_IF_ERROR( - PrintBinaryDataToString(element_size, contents, max_entries, &result)); - return result; -} - -Status PrintBinaryDataToString(int element_size, - absl::Span<const uint8_t> contents, - size_t max_entries, std::string* out_result) { - std::ostringstream stream; - RETURN_IF_ERROR( - PrintBinaryDataToStream(element_size, contents, max_entries, &stream)); - *out_result = stream.str(); - return OkStatus(); -} - -Status PrintBinaryDataToStream(int element_size, - absl::Span<const uint8_t> contents, - size_t max_entries, std::ostream* stream) { - // TODO(gcmn) Can we avoid this fiddly byte counting? - max_entries *= element_size; // Counting bytes, but treat them as elements. - constexpr size_t hex_chars_per_byte = 2; - constexpr size_t max_bytes = sizeof(int64_t); - if (element_size != 1 && element_size != 2 && element_size != 4 && - element_size != 8) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Invalid element size '" << element_size - << "'; only '1', '2', '4' and '8' are supported"; - } - if (contents.size() % element_size != 0) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Element size '" << element_size - << "' doesn't divide contents size " << contents.size(); - } - // Plus one char for the null terminator. 
- char hex_buffer[hex_chars_per_byte * max_bytes + 1] = {0}; - for (size_t i = 0; i < std::min(max_entries, contents.size()); - i += element_size) { - if (i > 0) *stream << " "; - BytesToHexString(contents.data() + i, hex_buffer, element_size); - *stream << hex_buffer; - } - if (contents.size() > max_entries) *stream << "..."; - return OkStatus(); -} - -StatusOr<std::string> PrintNumericalDataToString( - const Shape& shape, absl::string_view type_str, - absl::Span<const uint8_t> contents, size_t max_entries) { - std::string result; - RETURN_IF_ERROR(PrintNumericalDataToString(shape, type_str, contents, - max_entries, &result)); - return result; -} - -Status PrintNumericalDataToString(const Shape& shape, - absl::string_view type_str, - absl::Span<const uint8_t> contents, - size_t max_entries, std::string* out_result) { - std::ostringstream stream; - RETURN_IF_ERROR(PrintNumericalDataToStream(shape, type_str, contents, - max_entries, &stream)); - *out_result = stream.str(); - return OkStatus(); -} - -// Prints numerical data (ints, floats, etc) from some typed form. 
-Status PrintNumericalDataToStream(const Shape& shape, - absl::string_view type_str, - absl::Span<const uint8_t> contents, - size_t max_entries, std::ostream* stream) { - if (type_str == "i8") { - return PrintNumericalDataToStreamAsType<int8_t>(shape, contents, - max_entries, stream); - } else if (type_str == "u8") { - return PrintNumericalDataToStreamAsType<uint8_t>(shape, contents, - max_entries, stream); - } else if (type_str == "i16") { - return PrintNumericalDataToStreamAsType<int16_t>(shape, contents, - max_entries, stream); - } else if (type_str == "u16") { - return PrintNumericalDataToStreamAsType<uint16_t>(shape, contents, - max_entries, stream); - } else if (type_str == "i32") { - return PrintNumericalDataToStreamAsType<int32_t>(shape, contents, - max_entries, stream); - } else if (type_str == "u32") { - return PrintNumericalDataToStreamAsType<uint32_t>(shape, contents, - max_entries, stream); - } else if (type_str == "i64") { - return PrintNumericalDataToStreamAsType<int64_t>(shape, contents, - max_entries, stream); - } else if (type_str == "u64") { - return PrintNumericalDataToStreamAsType<uint64_t>(shape, contents, - max_entries, stream); - } else if (type_str == "f32") { - return PrintNumericalDataToStreamAsType<float>(shape, contents, max_entries, - stream); - } else if (type_str == "f64") { - return PrintNumericalDataToStreamAsType<double>(shape, contents, - max_entries, stream); - } else { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Unsupported type: " << type_str; - } -} - -Status ParseBufferDataAsType(absl::string_view data_str, - absl::string_view type_str, - absl::Span<uint8_t> output) { - // Parse the data from the string right into the buffer. - if (IsBinaryType(type_str)) { - // Parse as binary hex. - return ParseBinaryData(data_str, output); - } - // Parse as some nicely formatted type. 
- return ParseNumericalData(type_str, data_str, output); -} - -// static -BufferStringParts BufferStringParts::ExtractFrom( - absl::string_view shaped_buf_str) { - BufferStringParts parts; - absl::string_view shape_and_type_str; - auto equal_index = shaped_buf_str.find('='); - if (equal_index == std::string::npos) { - // Treat a lack of = as defaulting the data to zeros. - shape_and_type_str = shaped_buf_str; - } else { - shape_and_type_str = shaped_buf_str.substr(0, equal_index); - parts.data_str = shaped_buf_str.substr(equal_index + 1); - } - auto last_x_index = shape_and_type_str.rfind('x'); - if (last_x_index == std::string::npos) { - // Scalar. - parts.type_str = shape_and_type_str; - } else { - // Has a shape. - parts.shape_str = shape_and_type_str.substr(0, last_x_index); - parts.type_str = shape_and_type_str.substr(last_x_index + 1); - } - return parts; -} - -} // namespace iree
diff --git a/iree/base/buffer_string_util.h b/iree/base/buffer_string_util.h deleted file mode 100644 index 62add6e..0000000 --- a/iree/base/buffer_string_util.h +++ /dev/null
@@ -1,141 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Utilities for working with strings defining buffers and associated shapes, -// mostly useful for testing. -// -// The canonical shaped buffer string format is: -// [shape]x[type]=value,value,... -// For example: -// 2x2xi32=0,1,2,3 -// Characters like [] are optional and will be ignored during parsing: -// 2x2xi32=[[0 1][2 3]] -// -// The type may be one of the following: -// * 1/2/4/8 = 1/2/4/8 byte elements in binary hex format. -// * i8/u8 = signed/unsigned 8-bit integers. -// * i16/u16 = signed/unsigned 16-bit integers. -// * i32/u32 = signed/unsigned 32-bit integers. -// * i64/u64 = signed/unsigned 64-bit integers. -// * f32 = 32-bit floating-point number. -// * f64 = 64-bit floating-point number. - -#ifndef IREE_BASE_BUFFER_STRING_UTIL_H_ -#define IREE_BASE_BUFFER_STRING_UTIL_H_ - -#include <stddef.h> -#include <stdint.h> - -#include <ostream> -#include <string> - -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "iree/base/shape.h" -#include "iree/base/status.h" - -namespace iree { - -// Defines how the elements within a buffer representation are interpreted -// during printing. -enum class BufferDataPrintMode { - // Interpret the data as if it were serialized bytes. - // In this mode no conversion is performed and the bytes in memory are printed - // as hex in groupings based on the element size. Shortened to 'b'. 
- kBinary, - // Interpret elements as signed integers; shortened to 'i'. - kSignedInteger, - // Interpret elements as unsigned integers; shortened to 'u'. - kUnsignedInteger, - // Interpret elements as floating-point values; shortened to 'f'. - kFloatingPoint, -}; - -// Returns the BufferDataPrintMode based on the shortened char in |str|. -StatusOr<BufferDataPrintMode> ParseBufferDataPrintMode(absl::string_view str); - -// Returns the size, in bytes, of the given type. -StatusOr<int> ParseBufferTypeElementSize(absl::string_view type_str); - -// Returns the canonical representation of a type based on its size in bytes and -// the specified printing mode. For example, with a size of 4 and a printing -// mode of kFloatingPoint it returns "f32". -StatusOr<std::string> MakeBufferTypeString(int element_size, - BufferDataPrintMode print_mode); - -// Returns a Shape parsed from the given NxMx... string. -StatusOr<Shape> ParseShape(absl::string_view shape_str); - -// Prints a shape and element type, e.g. 2x3xf32 -std::string PrintShapedTypeToString(const Shape& shape, - absl::string_view type_str); -void PrintShapedTypeToString(const Shape& shape, absl::string_view type_str, - std::string* out_result); -void PrintShapedTypeToStream(const Shape& shape, absl::string_view type_str, - std::ostream* stream); - -// Prints the given bytes as binary hex data. -// Bytes are grouped as elements according to the |element_size| in bytes. If -// the size of |contents| exceeds |max_entries|, the output will be truncated to -// that many entries followed by an ellipses. 
-StatusOr<std::string> PrintBinaryDataToString( - int element_size, absl::Span<const uint8_t> contents, size_t max_entries); -Status PrintBinaryDataToString(int element_size, - absl::Span<const uint8_t> contents, - size_t max_entries, std::string* out_result); -Status PrintBinaryDataToStream(int element_size, - absl::Span<const uint8_t> contents, - size_t max_entries, std::ostream* stream); - -// Prints a list of elements in a format indicated by the given shape. -// For example: [1 2 3][4 5 6] for a shape of 2x3. -// The bytes in contents will be interpreted as the type specified by -// |type_str|. If the size of |contents| exceeds |max_entries|, the output will -// be truncated to that many entries followed by an ellipses. -StatusOr<std::string> PrintNumericalDataToString( - const Shape& shape, absl::string_view type_str, - absl::Span<const uint8_t> contents, size_t max_entries); -Status PrintNumericalDataToString(const Shape& shape, - absl::string_view type_str, - absl::Span<const uint8_t> contents, - size_t max_entries, std::string* out_result); -Status PrintNumericalDataToStream(const Shape& shape, - absl::string_view type_str, - absl::Span<const uint8_t> contents, - size_t max_entries, std::ostream* stream); - -// Parses |data_str| as elements of the type specified by |type_str| and writes -// them into |output|. -Status ParseBufferDataAsType(absl::string_view data_str, - absl::string_view type_str, - absl::Span<uint8_t> output); - -// A non-owning struct for referencing parts of a string that describes a shaped -// buffer type, e.g. 1x2x3xf32=1 2 3 4 5 6 -struct BufferStringParts { - // The part of the string corresponding to the shape, e.g. 1x2x3. - absl::string_view shape_str; - // The part of the string corresponding to the type, e.g. f32 - absl::string_view type_str; - // The part of the string corresponding to the buffer data, e.g. 
1 2 3 4 5 6 - absl::string_view data_str; - - // Extract the corresponding string parts from a string describing the entire - // buffer. - static BufferStringParts ExtractFrom(absl::string_view shaped_buf_str); -}; - -} // namespace iree - -#endif // IREE_BASE_BUFFER_STRING_UTIL_H_
diff --git a/iree/base/buffer_string_util_test.cc b/iree/base/buffer_string_util_test.cc deleted file mode 100644 index 77d1fac..0000000 --- a/iree/base/buffer_string_util_test.cc +++ /dev/null
@@ -1,221 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/base/buffer_string_util.h" - -#include "absl/strings/string_view.h" -#include "iree/base/memory.h" -#include "iree/base/status.h" -#include "iree/base/status_matchers.h" -#include "iree/testing/gtest.h" - -namespace iree { -namespace { - -using ::iree::testing::status::IsOkAndHolds; -using ::iree::testing::status::StatusIs; -using ::testing::ElementsAre; -using ::testing::Eq; - -TEST(BufferStringUtilTest, ParseBufferDataPrintMode) { - EXPECT_THAT(ParseBufferDataPrintMode("b"), - IsOkAndHolds(Eq(BufferDataPrintMode::kBinary))); - EXPECT_THAT(ParseBufferDataPrintMode("i"), - IsOkAndHolds(Eq(BufferDataPrintMode::kSignedInteger))); - EXPECT_THAT(ParseBufferDataPrintMode("u"), - IsOkAndHolds(Eq(BufferDataPrintMode::kUnsignedInteger))); - EXPECT_THAT(ParseBufferDataPrintMode("f"), - IsOkAndHolds(Eq(BufferDataPrintMode::kFloatingPoint))); - - EXPECT_THAT(ParseBufferDataPrintMode("bb"), - IsOkAndHolds(Eq(BufferDataPrintMode::kBinary))); - EXPECT_THAT(ParseBufferDataPrintMode("ii"), - IsOkAndHolds(Eq(BufferDataPrintMode::kSignedInteger))); - EXPECT_THAT(ParseBufferDataPrintMode("uu"), - IsOkAndHolds(Eq(BufferDataPrintMode::kUnsignedInteger))); - EXPECT_THAT(ParseBufferDataPrintMode("ff"), - IsOkAndHolds(Eq(BufferDataPrintMode::kFloatingPoint))); - - EXPECT_THAT(ParseBufferDataPrintMode(""), - StatusIs(StatusCode::kInvalidArgument)); - 
EXPECT_THAT(ParseBufferDataPrintMode("s"), - StatusIs(StatusCode::kInvalidArgument)); - EXPECT_THAT(ParseBufferDataPrintMode("asdfasdf"), - StatusIs(StatusCode::kInvalidArgument)); -} - -TEST(BufferStringUtilTest, ParseBufferTypeElementSize) { - EXPECT_THAT(ParseBufferTypeElementSize("1"), IsOkAndHolds(Eq(1))); - EXPECT_THAT(ParseBufferTypeElementSize("2"), IsOkAndHolds(Eq(2))); - EXPECT_THAT(ParseBufferTypeElementSize("4"), IsOkAndHolds(Eq(4))); - EXPECT_THAT(ParseBufferTypeElementSize("8"), IsOkAndHolds(Eq(8))); - EXPECT_THAT(ParseBufferTypeElementSize("i8"), IsOkAndHolds(Eq(1))); - EXPECT_THAT(ParseBufferTypeElementSize("u8"), IsOkAndHolds(Eq(1))); - EXPECT_THAT(ParseBufferTypeElementSize("i16"), IsOkAndHolds(Eq(2))); - EXPECT_THAT(ParseBufferTypeElementSize("u16"), IsOkAndHolds(Eq(2))); - EXPECT_THAT(ParseBufferTypeElementSize("i32"), IsOkAndHolds(Eq(4))); - EXPECT_THAT(ParseBufferTypeElementSize("u32"), IsOkAndHolds(Eq(4))); - EXPECT_THAT(ParseBufferTypeElementSize("i64"), IsOkAndHolds(Eq(8))); - EXPECT_THAT(ParseBufferTypeElementSize("u64"), IsOkAndHolds(Eq(8))); - EXPECT_THAT(ParseBufferTypeElementSize("f32"), IsOkAndHolds(Eq(4))); - EXPECT_THAT(ParseBufferTypeElementSize("f64"), IsOkAndHolds(Eq(8))); - - EXPECT_THAT(ParseBufferTypeElementSize(""), - StatusIs(StatusCode::kInvalidArgument)); - EXPECT_THAT(ParseBufferTypeElementSize(" "), - StatusIs(StatusCode::kInvalidArgument)); - EXPECT_THAT(ParseBufferTypeElementSize("a"), - StatusIs(StatusCode::kInvalidArgument)); - EXPECT_THAT(ParseBufferTypeElementSize("ib"), - StatusIs(StatusCode::kInvalidArgument)); - EXPECT_THAT(ParseBufferTypeElementSize("i"), - StatusIs(StatusCode::kInvalidArgument)); - EXPECT_THAT(ParseBufferTypeElementSize("i543ff"), - StatusIs(StatusCode::kInvalidArgument)); - EXPECT_THAT(ParseBufferTypeElementSize("i33"), - StatusIs(StatusCode::kInvalidArgument)); - EXPECT_THAT(ParseBufferTypeElementSize("x32"), - StatusIs(StatusCode::kInvalidArgument)); - 
EXPECT_THAT(ParseBufferTypeElementSize("f16"), - StatusIs(StatusCode::kInvalidArgument)); - EXPECT_THAT(ParseBufferTypeElementSize("i1"), - StatusIs(StatusCode::kInvalidArgument)); - EXPECT_THAT(ParseBufferTypeElementSize("i24"), - StatusIs(StatusCode::kInvalidArgument)); - EXPECT_THAT(ParseBufferTypeElementSize("i128"), - StatusIs(StatusCode::kInvalidArgument)); -} - -TEST(BufferStringUtilTest, MakeBufferTypeString) { - EXPECT_THAT(MakeBufferTypeString(1, BufferDataPrintMode::kBinary), - IsOkAndHolds(Eq("1"))); - EXPECT_THAT(MakeBufferTypeString(1, BufferDataPrintMode::kSignedInteger), - IsOkAndHolds(Eq("i8"))); - EXPECT_THAT(MakeBufferTypeString(2, BufferDataPrintMode::kUnsignedInteger), - IsOkAndHolds(Eq("u16"))); - EXPECT_THAT(MakeBufferTypeString(4, BufferDataPrintMode::kFloatingPoint), - IsOkAndHolds(Eq("f32"))); - - EXPECT_THAT(MakeBufferTypeString(0, BufferDataPrintMode::kBinary), - StatusIs(StatusCode::kInvalidArgument)); - EXPECT_THAT(MakeBufferTypeString(-1, BufferDataPrintMode::kBinary), - StatusIs(StatusCode::kInvalidArgument)); - EXPECT_THAT(MakeBufferTypeString(-1, BufferDataPrintMode::kSignedInteger), - StatusIs(StatusCode::kInvalidArgument)); - EXPECT_THAT(MakeBufferTypeString(-2, BufferDataPrintMode::kUnsignedInteger), - StatusIs(StatusCode::kInvalidArgument)); - EXPECT_THAT(MakeBufferTypeString(-4, BufferDataPrintMode::kFloatingPoint), - StatusIs(StatusCode::kInvalidArgument)); -} - -TEST(BufferStringUtilTest, ParseShape) { - EXPECT_THAT(ParseShape(""), IsOkAndHolds(Eq(Shape{}))); - EXPECT_THAT(ParseShape("0"), IsOkAndHolds(Eq(Shape{0}))); - EXPECT_THAT(ParseShape("1"), IsOkAndHolds(Eq(Shape{1}))); - EXPECT_THAT(ParseShape("1x2"), IsOkAndHolds(Eq(Shape{1, 2}))); - EXPECT_THAT(ParseShape(" 1 x 2 "), IsOkAndHolds(Eq(Shape{1, 2}))); - EXPECT_THAT(ParseShape("1x2x3x4x5"), IsOkAndHolds(Eq(Shape{1, 2, 3, 4, 5}))); - - EXPECT_THAT(ParseShape("abc"), StatusIs(StatusCode::kInvalidArgument)); - EXPECT_THAT(ParseShape("1xf"), 
StatusIs(StatusCode::kInvalidArgument)); - EXPECT_THAT(ParseShape("1xff23"), StatusIs(StatusCode::kInvalidArgument)); - EXPECT_THAT(ParseShape("1xf32"), StatusIs(StatusCode::kInvalidArgument)); - EXPECT_THAT(ParseShape("x"), StatusIs(StatusCode::kInvalidArgument)); - EXPECT_THAT(ParseShape("x1"), StatusIs(StatusCode::kInvalidArgument)); - EXPECT_THAT(ParseShape("1x"), StatusIs(StatusCode::kInvalidArgument)); - EXPECT_THAT(ParseShape("x1x2"), StatusIs(StatusCode::kInvalidArgument)); - EXPECT_THAT(ParseShape("1xx2"), StatusIs(StatusCode::kInvalidArgument)); - EXPECT_THAT(ParseShape("1x2x"), StatusIs(StatusCode::kInvalidArgument)); - EXPECT_THAT(ParseShape("0x-1"), StatusIs(StatusCode::kInvalidArgument)); - EXPECT_THAT(ParseShape("1x2x3x4x5x6"), - StatusIs(StatusCode::kInvalidArgument)); -} - -TEST(BufferStringUtilTest, PrintShapedTypeToString) { - EXPECT_EQ("f32", PrintShapedTypeToString(Shape{}, "f32")); - EXPECT_EQ("0xi32", PrintShapedTypeToString(Shape{0}, "i32")); - EXPECT_EQ("1xi32", PrintShapedTypeToString(Shape{1}, "i32")); - EXPECT_EQ("1x2xi8", PrintShapedTypeToString(Shape{1, 2}, "i8")); -} - -TEST(BufferStringUtilTest, PrintBinaryDataToString) { - EXPECT_THAT(PrintBinaryDataToString(1, {0, 1, 2, 3}, 10), - IsOkAndHolds(Eq("00 01 02 03"))); - EXPECT_THAT(PrintBinaryDataToString( - 2, {0x01, 0x02, 0x03, 0x04, 0xcc, 0xdd, 0xee, 0xff}, 10), - IsOkAndHolds(Eq("0102 0304 ccdd eeff"))); - EXPECT_THAT(PrintBinaryDataToString(4, {0xfa, 0xbc, 0xfa, 0xbc}, 10), - IsOkAndHolds(Eq("fabcfabc"))); - EXPECT_THAT(PrintBinaryDataToString( - 8, {0xfa, 0xbc, 0xfa, 0xbc, 0xfa, 0xbc, 0xfa, 0xbc}, 10), - IsOkAndHolds(Eq("fabcfabcfabcfabc"))); - - EXPECT_THAT(PrintBinaryDataToString(1, {0, 1, 2, 3}, 0), - IsOkAndHolds(Eq("..."))); - EXPECT_THAT(PrintBinaryDataToString(1, {0, 1, 2, 3}, 1), - IsOkAndHolds(Eq("00..."))); - EXPECT_THAT(PrintBinaryDataToString(1, {0, 1, 2, 3}, 2), - IsOkAndHolds(Eq("00 01..."))); - - EXPECT_THAT(PrintBinaryDataToString(-1, {0, 1, 2, 3}, 10), - 
StatusIs(StatusCode::kInvalidArgument)); - EXPECT_THAT( - PrintBinaryDataToString( - 16, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, 10), - StatusIs(StatusCode::kInvalidArgument)); - EXPECT_THAT(PrintBinaryDataToString(3, {0, 1, 2, 3}, 10), - StatusIs(StatusCode::kInvalidArgument)); -} - -TEST(BufferStringUtilTest, PrintNumericalDataToString) { - EXPECT_EQ("0 1 2 3", - PrintNumericalDataToString({4}, "u8", {0, 1, 2, 3}, 10).value()); - EXPECT_EQ("[0 1][2 3]", - PrintNumericalDataToString({2, 2}, "u8", {0, 1, 2, 3}, 10).value()); - std::vector<int32_t> data = {0, -1, 2, 3}; - auto bytes = ReinterpretSpan<uint8_t>(absl::MakeSpan(data)); - EXPECT_EQ("0 -1 2 3", - PrintNumericalDataToString({4}, "i32", bytes, 10).value()); -} - -TEST(BufferStringUtilTest, ParseBufferDatai8) { - std::vector<uint8_t> data(4); - auto data_span = absl::MakeSpan(data); - ASSERT_OK(ParseBufferDataAsType("0 1 2 3", "i8", data_span)); - EXPECT_THAT(ReinterpretSpan<int8_t>(data_span), ElementsAre(0, 1, 2, 3)); -} - -TEST(BufferStringUtilTest, ParseBufferDatai32) { - std::vector<uint8_t> data(4 * sizeof(int32_t)); - auto data_span = absl::MakeSpan(data); - ASSERT_OK(ParseBufferDataAsType("0 1 2 3", "i32", data_span)); - EXPECT_THAT(ReinterpretSpan<int32_t>(data_span), ElementsAre(0, 1, 2, 3)); -} - -TEST(BufferStringUtilTest, ParseBufferDataf32) { - std::vector<uint8_t> data(4 * sizeof(float)); - auto data_span = absl::MakeSpan(data); - ASSERT_OK(ParseBufferDataAsType("0 1.1 2 3", "f32", data_span)); - EXPECT_THAT(ReinterpretSpan<float>(data_span), ElementsAre(0, 1.1, 2, 3)); -} - -TEST(BufferStringUtilTest, ParseBufferDataBinary) { - std::vector<uint8_t> data(4); - auto data_span = absl::MakeSpan(data); - ASSERT_OK(ParseBufferDataAsType("00 01 02 03", "8", data_span)); - EXPECT_THAT(data_span, ElementsAre(0, 1, 2, 3)); -} - -} // namespace -} // namespace iree
diff --git a/iree/base/shape.cc b/iree/base/shape.cc deleted file mode 100644 index 875d119..0000000 --- a/iree/base/shape.cc +++ /dev/null
@@ -1,100 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/base/shape.h" - -#include <cstddef> - -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" -#include "iree/base/source_location.h" -#include "iree/base/status.h" - -namespace iree { - -Shape::Shape(const int* values, int size) : rank_(size) { - QCHECK_LE(size, kMaxRank) - << "Max rank of " << kMaxRank << ", shape has " << size; - std::memcpy(value_, values, size * sizeof(int)); -} - -std::string Shape::DebugString() const { - return absl::StrCat("[", absl::StrJoin(subspan(), ","), "]"); -} - -absl::Span<const int> Shape::subspan(size_type pos, size_type len) const { - if (len == npos) { - len = rank_ - pos; - } - return absl::MakeConstSpan(&value_[pos], len); -} - -void Shape::push_back(int dim) { - DCHECK_LE(rank_ + 1, kMaxRank); - value_[rank_++] = dim; -} - -void Shape::insert(iterator pos, int dim) { - int axis = static_cast<int>(pos - value_); - DCHECK_GE(axis, 0); - DCHECK_LE(axis, rank_); - DCHECK_LE(rank_ + 1, kMaxRank); - ++rank_; - for (int i = rank_ - 1; i > axis; --i) { - value_[i] = value_[i - 1]; - } - value_[axis] = dim; -} - -void Shape::erase(iterator pos) { - int axis = static_cast<int>(pos - value_); - DCHECK_GE(axis, 0); - DCHECK_LE(axis, rank_); - for (int i = axis; i < rank_ - 1; ++i) { - value_[i] = value_[i + 1]; - } - --rank_; -} - -int Shape::element_count() const { - size_t element_count = 1; - for 
(int i = 0; i < rank_; ++i) { - int dim = value_[i]; - if (dim == -1) { - return 0; - } - element_count *= dim; - } - return element_count; -} - -StatusOr<int> Shape::ResolveAxis(int axis) const { - if (rank_ == 0 && (axis == -1 || axis == 0)) { - // Scalar axes resolves to 0. - return 0; - } - - int new_axis = axis; - if (new_axis < 0) { - new_axis += rank_; - } - if (new_axis < 0 || new_axis >= rank_) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Axis " << new_axis << " (orig " << axis - << ") out of bounds of rank " << rank_; - } - return new_axis; -} - -} // namespace iree
diff --git a/iree/base/shape.h b/iree/base/shape.h deleted file mode 100644 index 758d66f..0000000 --- a/iree/base/shape.h +++ /dev/null
@@ -1,156 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_BASE_SHAPE_H_ -#define IREE_BASE_SHAPE_H_ - -#include <array> -#include <cstring> -#include <initializer_list> -#include <iterator> -#include <string> -#include <type_traits> -#include <vector> - -#include "absl/meta/type_traits.h" -#include "absl/types/span.h" -#include "iree/base/logging.h" -#include "iree/base/status.h" - -namespace iree { - -// For simplicity we limit our shapes to a max of rank-N (shape.size() == N) as -// this prevents dynamic allocations and rarely are there greater ranks. -constexpr int kMaxRank = 5; - -// Represent indices and lengths of tensors. -using Index = std::array<int, kMaxRank>; -using Length = std::array<int, kMaxRank>; - -// Represents the number of elements in multiple dimensions. -// Can be rank-0 (scalar) to rank-kMaxRank. Tries to match the API of -// std::vector and can be converted to a Span via subspan(). 
-// -// https://www.tensorflow.org/guide/tensors#shape -class Shape { - public: - using size_type = int; - static constexpr size_type npos = ~(size_type(0)); // NOLINT - using iterator = int*; - using const_iterator = const int*; - - Shape() = default; - Shape(const int* values, int size); - Shape(std::initializer_list<int> values) - : Shape(values.begin(), values.size()) {} - explicit Shape(absl::Span<const int> values) - : Shape(values.data(), values.size()) {} - - template <typename Iterator> - using EnableIfForwardIterator = absl::enable_if_t<std::is_convertible< - typename std::iterator_traits<Iterator>::iterator_category, - std::forward_iterator_tag>::value>; - template <typename Iterator, EnableIfForwardIterator<Iterator>* = nullptr> - Shape(Iterator first, Iterator last) { - rank_ = std::distance(first, last); - QCHECK_LE(rank_, kMaxRank); - for (int i = 0; first != last; ++i, static_cast<void>(++first)) { - value_[i] = *first; - } - } - - // Returns a string representation of the given shape. - std::string DebugString() const; - - // Size (aka 'rank') of the shape, counting the number of dimensions. - constexpr size_type size() const noexcept { return rank_; } - - // Whether the shape is rank-0 (scalar). - constexpr bool empty() const noexcept { return rank_ == 0; } - - // Returns the total elements in the tensor shape. - // Returns 0 if the tensor shape is not complete and 1 if the shape is a - // scalar value. - int element_count() const; - - // Resolves an axis in [-R,R) to the real axis value and verifies the range. - StatusOr<int> ResolveAxis(int axis) const; - - // Compares two shapes for equality. 
- inline static bool Equal(const Shape& a, const Shape& b) { - return a.rank_ == b.rank_ && - std::memcmp(a.value_, b.value_, a.rank_ * sizeof(value_[0])) == 0; - } - - int& operator[](size_type i) noexcept { - DCHECK_GE(i, 0); - DCHECK_LT(i, rank_); - return value_[i]; - } - - const int& operator[](size_type i) const noexcept { - DCHECK_GE(i, 0); - DCHECK_LT(i, rank_); - return value_[i]; - } - - int front() const noexcept { - DCHECK_GE(rank_, 1); - return value_[0]; - } - - int back() const noexcept { - DCHECK_GE(rank_, 1); - return value_[rank_ - 1]; - } - - constexpr iterator begin() const noexcept { - return const_cast<iterator>(&value_[0]); - } - constexpr iterator end() const noexcept { - return const_cast<iterator>(&value_[rank_]); - } - constexpr const_iterator cbegin() const noexcept { return &value_[0]; } - constexpr const_iterator cend() const noexcept { return &value_[rank_]; } - - absl::Span<const int> subspan(size_type pos = 0, size_type len = npos) const; - absl::Span<const int> data() const { return subspan(); } - - void push_back(int dim); - - void insert(iterator pos, int dim); - - void erase(iterator pos); - - void clear() { rank_ = 0; } - - private: - size_type rank_ = 0; - int value_[kMaxRank]; -}; - -inline bool operator==(const Shape& a, const Shape& b) { - return Shape::Equal(a, b); -} - -inline bool operator!=(const Shape& a, const Shape& b) { return !(a == b); } - -inline std::ostream& operator<<(std::ostream& stream, const Shape& shape) { - stream << shape.DebugString(); - return stream; -} - -} // namespace iree - -#endif // IREE_BASE_SHAPE_H_
diff --git a/iree/base/shape_test.cc b/iree/base/shape_test.cc deleted file mode 100644 index 0bd88d3..0000000 --- a/iree/base/shape_test.cc +++ /dev/null
@@ -1,221 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/base/shape.h" - -#include "iree/base/status.h" -#include "iree/base/status_matchers.h" -#include "iree/testing/gtest.h" - -namespace iree { -namespace { - -using ::testing::ElementsAre; - -// Tests shapes that represent 0-D scalar values. -TEST(ShapeTest, Scalar) { - Shape shape; - EXPECT_EQ(0, shape.size()); - EXPECT_TRUE(shape.empty()); - EXPECT_EQ(1, shape.element_count()); - EXPECT_EQ(shape, shape); - EXPECT_EQ(0, shape.subspan().size()); - for (const int dim : shape) { - FAIL() << "Should have no dimensions, have: " << dim; - } - EXPECT_EQ(shape.begin(), shape.end()); - EXPECT_EQ(shape.cbegin(), shape.cend()); - shape.clear(); - EXPECT_EQ(0, shape.size()); -} - -// Tests the various ways of constructing a 1+D shape. 
-TEST(ShapeTest, NonScalarConstruction) { - EXPECT_EQ(0, Shape().size()); - EXPECT_EQ(0, Shape({}).size()); - EXPECT_EQ(1, Shape({10}).size()); - EXPECT_EQ(4, Shape({10, 20, 30, 40}).size()); - - std::vector<int> empty_data = {}; - EXPECT_EQ(0, Shape(empty_data.data(), empty_data.size()).size()); - EXPECT_EQ(0, Shape(empty_data.begin(), empty_data.end()).size()); - EXPECT_EQ(0, Shape(absl::MakeConstSpan(empty_data)).size()); - - EXPECT_THAT(Shape({}).subspan(), ElementsAre()); - EXPECT_THAT(Shape({10}).subspan(), ElementsAre(10)); - EXPECT_THAT(Shape({10, 20, 30, 40}).subspan(), ElementsAre(10, 20, 30, 40)); - - std::vector<int> valid_data = {10, 20, 30, 40}; - EXPECT_THAT(Shape(valid_data.begin(), valid_data.end()).subspan(), - ElementsAre(10, 20, 30, 40)); - EXPECT_THAT(Shape(absl::MakeConstSpan(valid_data)).subspan(), - ElementsAre(10, 20, 30, 40)); -} - -// Tests shapes that represent 1+D multidimensional values. -TEST(ShapeTest, NonScalarAccess) { - Shape shape = {1, 2, 3, 4}; - EXPECT_EQ(4, shape.size()); - EXPECT_FALSE(shape.empty()); - EXPECT_EQ(1 * 2 * 3 * 4, shape.element_count()); - EXPECT_EQ(shape, shape); - EXPECT_NE(shape, Shape({4, 3, 2, 1})); - EXPECT_THAT(shape.subspan(), ElementsAre(1, 2, 3, 4)); - std::vector<int> readout; - for (const int dim : shape) { - readout.push_back(dim); - } - EXPECT_THAT(readout, ElementsAre(1, 2, 3, 4)); - EXPECT_EQ(1, shape[0]); - EXPECT_EQ(2, shape[1]); - EXPECT_EQ(3, shape[2]); - EXPECT_EQ(4, shape[3]); - EXPECT_EQ(1, shape.front()); - EXPECT_EQ(4, shape.back()); -} - -TEST(ShapeTest, PushBack) { - Shape shape; - EXPECT_EQ(0, shape.size()); - - shape.push_back(10); - EXPECT_EQ(1, shape.size()); - EXPECT_EQ(10, shape.front()); - EXPECT_EQ(10, shape.back()); - EXPECT_EQ(10, shape[0]); - EXPECT_THAT(shape.subspan(), ElementsAre(10)); - - shape.push_back(20); - EXPECT_EQ(2, shape.size()); - EXPECT_EQ(10, shape.front()); - EXPECT_EQ(20, shape.back()); - EXPECT_EQ(10, shape[0]); - EXPECT_EQ(20, shape[1]); - 
EXPECT_THAT(shape.subspan(), ElementsAre(10, 20)); -} - -TEST(ShapeTest, Insert) { - Shape shape; - EXPECT_EQ(0, shape.size()); - - shape.insert(shape.begin(), 20); - EXPECT_THAT(shape.subspan(), ElementsAre(20)); - shape.insert(shape.begin(), 10); - EXPECT_THAT(shape.subspan(), ElementsAre(10, 20)); - shape.insert(shape.end(), 40); - EXPECT_THAT(shape.subspan(), ElementsAre(10, 20, 40)); - shape.insert(shape.begin() + 2, 30); - EXPECT_THAT(shape.subspan(), ElementsAre(10, 20, 30, 40)); - - Shape ex_shape{72, 4}; - ex_shape.insert(ex_shape.begin(), 144); - EXPECT_THAT(ex_shape.subspan(), ElementsAre(144, 72, 4)); -} - -TEST(ShapeTest, Erase) { - Shape shape = {1, 2, 3, 4}; - EXPECT_THAT(shape.subspan(), ElementsAre(1, 2, 3, 4)); - shape.erase(shape.begin()); - EXPECT_THAT(shape.subspan(), ElementsAre(2, 3, 4)); - shape.erase(shape.end()); - EXPECT_THAT(shape.subspan(), ElementsAre(2, 3)); - shape.erase(shape.begin() + 1); - EXPECT_THAT(shape.subspan(), ElementsAre(2)); - shape.erase(shape.end()); - EXPECT_THAT(shape.subspan(), ElementsAre()); -} - -TEST(ShapeTest, Clear) { - Shape shape; - EXPECT_EQ(0, shape.size()); - shape.clear(); - EXPECT_EQ(0, shape.size()); - - shape = Shape({1}); - shape.clear(); - EXPECT_EQ(0, shape.size()); - - shape = Shape({1, 2, 3, 4}); - shape.clear(); - EXPECT_EQ(0, shape.size()); -} - -TEST(ShapeTest, DebugString) { - EXPECT_EQ("[]", Shape({}).DebugString()); - EXPECT_EQ("[1]", Shape({1}).DebugString()); - EXPECT_EQ("[1,2]", Shape({1, 2}).DebugString()); -} - -TEST(ShapeTest, ElementCount) { - EXPECT_EQ(1, Shape({}).element_count()); - EXPECT_EQ(0, Shape({0}).element_count()); - EXPECT_EQ(1, Shape({1}).element_count()); - EXPECT_EQ(2, Shape({2, 1}).element_count()); - EXPECT_EQ(10, Shape({2, 5}).element_count()); - EXPECT_EQ(9216, Shape({72, 1, 128}).element_count()); - EXPECT_EQ(9216, Shape({1, 72, 128}).element_count()); - - // Partial shaping should yield no elements. 
- EXPECT_EQ(0, Shape({1, -1, 2, 3}).element_count()); -} - -TEST(ShapeTest, ResolveAxis) { - int axis; - ASSERT_OK_AND_ASSIGN(axis, Shape({0}).ResolveAxis(0)); - EXPECT_EQ(0, axis); - ASSERT_OK_AND_ASSIGN(axis, Shape({0, 1, 2}).ResolveAxis(1)); - EXPECT_EQ(1, axis); - ASSERT_OK_AND_ASSIGN(axis, Shape({0, 1, 2}).ResolveAxis(2)); - EXPECT_EQ(2, axis); - - EXPECT_TRUE(IsInvalidArgument(Shape({0, 1, 2}).ResolveAxis(3).status())); -} - -TEST(ShapeTest, ResolveAxisNegative) { - int axis; - ASSERT_OK_AND_ASSIGN(axis, Shape({0, 1, 2}).ResolveAxis(-3)); - EXPECT_EQ(0, axis); - ASSERT_OK_AND_ASSIGN(axis, Shape({0, 1, 2}).ResolveAxis(-2)); - EXPECT_EQ(1, axis); - ASSERT_OK_AND_ASSIGN(axis, Shape({0, 1, 2}).ResolveAxis(-1)); - EXPECT_EQ(2, axis); - - EXPECT_TRUE(IsInvalidArgument(Shape({0, 1, 2}).ResolveAxis(-4).status())); -} - -TEST(ShapeTest, ResolveAxisScalar) { - int axis; - ASSERT_OK_AND_ASSIGN(axis, Shape({}).ResolveAxis(0)); - EXPECT_EQ(0, axis); - ASSERT_OK_AND_ASSIGN(axis, Shape({}).ResolveAxis(-1)); - EXPECT_EQ(0, axis); - - EXPECT_TRUE(IsInvalidArgument(Shape({}).ResolveAxis(1).status())); -} - -TEST(ShapeTest, Equality) { - EXPECT_EQ(Shape({}), Shape({})); - EXPECT_EQ(Shape({0}), Shape({0})); - EXPECT_EQ(Shape({1}), Shape({1})); - EXPECT_EQ(Shape({1, 2}), Shape({1, 2})); - - EXPECT_NE(Shape({}), Shape({1})); - EXPECT_NE(Shape({-1}), Shape({1})); - EXPECT_NE(Shape({1}), Shape({})); - EXPECT_NE(Shape({1}), Shape({2})); - EXPECT_NE(Shape({1, 2}), Shape({3, 4})); -} - -} // namespace -} // namespace iree
diff --git a/iree/base/shaped_buffer.h b/iree/base/shaped_buffer.h deleted file mode 100644 index 14f2df8..0000000 --- a/iree/base/shaped_buffer.h +++ /dev/null
@@ -1,77 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_IREE_BASE_SHAPED_BUFFER_H_ -#define IREE_IREE_BASE_SHAPED_BUFFER_H_ - -#include <stdint.h> - -#include <vector> - -#include "absl/types/span.h" -#include "iree/base/logging.h" -#include "iree/base/memory.h" -#include "iree/base/shape.h" - -namespace iree { - -// A struct representing a buffer of bytes that should be interpreted as being -// of the specified shape with elements of the specified size. 
-class ShapedBuffer { - public: - ShapedBuffer() = default; - // move only - ShapedBuffer(ShapedBuffer&& other) = default; - ShapedBuffer& operator=(ShapedBuffer&& other) = default; - ShapedBuffer(int8_t element_size, Shape shape, std::vector<uint8_t> contents) - : element_size_(element_size), - shape_(shape), - contents_(std::move(contents)) {} - - template <typename T> - static ShapedBuffer Create(Shape shape, absl::Span<const T> contents) { - CHECK_EQ(contents.size(), shape.element_count()); - auto byte_span = ReinterpretSpan<uint8_t>(contents); - return ShapedBuffer( - sizeof(T), shape, - std::vector<uint8_t>(byte_span.begin(), byte_span.end())); - } - - static inline bool Equal(const ShapedBuffer& a, const ShapedBuffer& b) { - return a.element_size_ == b.element_size_ && a.shape_ == b.shape_ && - a.contents_ == b.contents_; - } - - int8_t element_size() const { return element_size_; } - Shape shape() const { return shape_; } - absl::Span<const uint8_t> contents() const { return contents_; } - - private: - // Size of the buffer elements, in bytes. - int8_t element_size_; - Shape shape_; - std::vector<uint8_t> contents_; -}; - -inline bool operator==(const ShapedBuffer& a, const ShapedBuffer& b) { - return ShapedBuffer::Equal(a, b); -} - -inline bool operator!=(const ShapedBuffer& a, const ShapedBuffer& b) { - return !(a == b); -} - -} // namespace iree - -#endif // IREE_IREE_BASE_SHAPED_BUFFER_H_
diff --git a/iree/base/shaped_buffer_string_util.cc b/iree/base/shaped_buffer_string_util.cc deleted file mode 100644 index 3464e56..0000000 --- a/iree/base/shaped_buffer_string_util.cc +++ /dev/null
@@ -1,112 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/base/shaped_buffer_string_util.h" - -#include <stddef.h> -#include <stdint.h> - -#include <sstream> -#include <string> -#include <vector> - -#include "absl/container/fixed_array.h" -#include "absl/strings/ascii.h" -#include "absl/strings/string_view.h" -#include "absl/strings/strip.h" -#include "absl/types/optional.h" -#include "absl/types/span.h" -#include "iree/base/buffer_string_util.h" -#include "iree/base/memory.h" -#include "iree/base/shape.h" -#include "iree/base/shaped_buffer.h" -#include "iree/base/source_location.h" -#include "iree/base/status.h" - -namespace iree { - -StatusOr<ShapedBuffer> ParseShapedBufferFromString( - absl::string_view shaped_buf_str) { - // Strip whitespace that may come along (linefeeds/etc). - shaped_buf_str = absl::StripAsciiWhitespace(shaped_buf_str); - shaped_buf_str = absl::StripPrefix(shaped_buf_str, "\""); - shaped_buf_str = absl::StripSuffix(shaped_buf_str, "\""); - if (shaped_buf_str.empty()) { - // Empty lines denote empty shaped_buffers. - return ShapedBuffer{}; - } - - // Split into the components we can work with: shape, type, and data. 
- auto str_parts = BufferStringParts::ExtractFrom(shaped_buf_str); - - ASSIGN_OR_RETURN(int element_size, - ParseBufferTypeElementSize(str_parts.type_str)); - ASSIGN_OR_RETURN(Shape shape, ParseShape(str_parts.shape_str)); - - int buffer_size = element_size * shape.element_count(); - std::vector<uint8_t> contents(buffer_size); - - if (!str_parts.data_str.empty()) { - RETURN_IF_ERROR(ParseBufferDataAsType( - str_parts.data_str, str_parts.type_str, absl::MakeSpan(contents))); - } - return ShapedBuffer(element_size, shape, std::move(contents)); -} - -StatusOr<std::string> PrintShapedBufferToString(const ShapedBuffer& shaped_buf, - BufferDataPrintMode print_mode, - size_t max_entries) { - std::string result; - RETURN_IF_ERROR( - PrintShapedBufferToString(shaped_buf, print_mode, max_entries, &result)); - return result; -} - -Status PrintShapedBufferToString(const ShapedBuffer& shaped_buf, - BufferDataPrintMode print_mode, - size_t max_entries, std::string* out_result) { - std::ostringstream stream; - RETURN_IF_ERROR( - PrintShapedBufferToStream(shaped_buf, print_mode, max_entries, &stream)); - *out_result = stream.str(); - return OkStatus(); -} - -Status PrintShapedBufferToStream(const ShapedBuffer& shaped_buffer, - BufferDataPrintMode print_mode, - size_t max_entries, std::ostream* stream) { - if (shaped_buffer.contents().empty()) { - // No data means the shaped_buffer is empty. We use the empty string to - // denote this (as we have no useful information). 
- return OkStatus(); - } - - ASSIGN_OR_RETURN( - std::string type_str, - MakeBufferTypeString(shaped_buffer.element_size(), print_mode)); - - PrintShapedTypeToStream(shaped_buffer.shape(), type_str, stream); - *stream << "="; - - if (print_mode == BufferDataPrintMode::kBinary) { - return PrintBinaryDataToStream(shaped_buffer.element_size(), - shaped_buffer.contents(), max_entries, - stream); - } - return PrintNumericalDataToStream(shaped_buffer.shape(), type_str, - shaped_buffer.contents(), max_entries, - stream); -} - -} // namespace iree
diff --git a/iree/base/shaped_buffer_string_util.h b/iree/base/shaped_buffer_string_util.h deleted file mode 100644 index a351750..0000000 --- a/iree/base/shaped_buffer_string_util.h +++ /dev/null
@@ -1,55 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Utilities for parsing and printing ShapedBuffer, mostly useful for -// testing. The format is as described in -// https://github.com/google/iree/tree/main/iree/base/buffer_string_util.h - -#ifndef IREE_BASE_SHAPED_BUFFER_STRING_UTIL_H_ -#define IREE_BASE_SHAPED_BUFFER_STRING_UTIL_H_ - -#include <stddef.h> - -#include <ostream> -#include <string> - -#include "absl/strings/string_view.h" -#include "iree/base/buffer_string_util.h" -#include "iree/base/shape.h" -#include "iree/base/shaped_buffer.h" -#include "iree/base/status.h" - -namespace iree { - -// Parses a ShapedBuffer encoded in a string. -// The format accepted matches that produced by PrintShapedBufferToString. -StatusOr<ShapedBuffer> ParseShapedBufferFromString( - absl::string_view shaped_buf_str); - -// Prints a ShapedBuffer to a string encoded in the canonical format. -StatusOr<std::string> PrintShapedBufferToString(const ShapedBuffer& shaped_buf, - BufferDataPrintMode print_mode, - size_t max_entries); -Status PrintShapedBufferToString(const ShapedBuffer& shaped_buf, - BufferDataPrintMode print_mode, - size_t max_entries, std::string* out_result); - -// Prints a ShapedBuffer to a stream encoded in the canonical format. 
-Status PrintShapedBufferToStream(const ShapedBuffer& shaped_buf, - BufferDataPrintMode print_mode, - size_t max_entries, std::ostream* stream); - -} // namespace iree - -#endif // IREE_BASE_SHAPED_BUFFER_STRING_UTIL_H_
diff --git a/iree/base/shaped_buffer_string_util_test.cc b/iree/base/shaped_buffer_string_util_test.cc deleted file mode 100644 index 42077f7..0000000 --- a/iree/base/shaped_buffer_string_util_test.cc +++ /dev/null
@@ -1,191 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/base/shaped_buffer_string_util.h" - -#include "absl/strings/string_view.h" -#include "iree/base/buffer_string_util.h" -#include "iree/base/memory.h" -#include "iree/base/shaped_buffer.h" -#include "iree/base/status.h" -#include "iree/base/status_matchers.h" -#include "iree/testing/gtest.h" - -namespace iree { -namespace { - -using ::iree::testing::status::StatusIs; -using ::testing::ElementsAre; - -template <typename T> -absl::Span<const T> ReadAs(absl::Span<const uint8_t> data) { - return ReinterpretSpan<T>(data); -} - -void RoundTripTest(absl::string_view buffer_string, - BufferDataPrintMode print_mode) { - ASSERT_OK_AND_ASSIGN(auto shaped_buf, - ParseShapedBufferFromString(buffer_string)); - ASSERT_OK_AND_ASSIGN(auto new_string, PrintShapedBufferToString( - shaped_buf, print_mode, SIZE_MAX)); - EXPECT_EQ(buffer_string, new_string); -} - -void RoundTripTest(ShapedBuffer shaped_buf, BufferDataPrintMode print_mode) { - ASSERT_OK_AND_ASSIGN(auto new_string, PrintShapedBufferToString( - shaped_buf, print_mode, SIZE_MAX)); - ASSERT_OK_AND_ASSIGN(auto new_shaped_buf, - ParseShapedBufferFromString(new_string)); - EXPECT_EQ(shaped_buf, new_shaped_buf); -} - -TEST(ShapedBufferStringUtilTest, ParseShapedBufferFromStringEmpty) { - // Empty string = empty buffer_view. 
- ASSERT_OK_AND_ASSIGN(auto m0, ParseShapedBufferFromString("")); - EXPECT_TRUE(m0.contents().empty()); - EXPECT_EQ(Shape{}, m0.shape()); - EXPECT_EQ(0, m0.element_size()); - - // No = means no data. - ASSERT_OK_AND_ASSIGN(auto m1, ParseShapedBufferFromString("4x2xf32")); - EXPECT_EQ(4 * 2 * 4, m1.contents().size()); - EXPECT_EQ(Shape({4, 2}), m1.shape()); - EXPECT_EQ(4, m1.element_size()); - EXPECT_THAT(ReadAs<float>(m1.contents()), - ElementsAre(0, 0, 0, 0, 0, 0, 0, 0)); - - // No data after = means no data. - ASSERT_OK_AND_ASSIGN(auto m2, ParseShapedBufferFromString("4x2xf32=")); - EXPECT_EQ(4 * 2 * 4, m2.contents().size()); - EXPECT_EQ(Shape({4, 2}), m2.shape()); - EXPECT_EQ(4, m2.element_size()); - EXPECT_THAT(ReadAs<float>(m2.contents()), - ElementsAre(0, 0, 0, 0, 0, 0, 0, 0)); -} - -TEST(ShapedBufferStringUtilTest, ParseShapedBufferFromStringBinary) { - ASSERT_OK_AND_ASSIGN(auto m0, ParseShapedBufferFromString("4x1=00 01 02 03")); - EXPECT_EQ(Shape({4}), m0.shape()); - EXPECT_EQ(1, m0.element_size()); - EXPECT_THAT(ReadAs<uint8_t>(m0.contents()), ElementsAre(0, 1, 2, 3)); - - // Whitespace shouldn't matter. - ASSERT_OK_AND_ASSIGN(auto m1, ParseShapedBufferFromString("4x1=00,010203")); - EXPECT_EQ(Shape({4}), m1.shape()); - EXPECT_EQ(1, m1.element_size()); - EXPECT_THAT(ReadAs<uint8_t>(m1.contents()), ElementsAre(0, 1, 2, 3)); - - // Should fail on malformed hex bytes. 
- EXPECT_THAT(ParseShapedBufferFromString("4x1=1"), - StatusIs(StatusCode::kInvalidArgument)); - EXPECT_THAT(ParseShapedBufferFromString("4x1=00003"), - StatusIs(StatusCode::kInvalidArgument)); - EXPECT_THAT(ParseShapedBufferFromString("4x1=%0123%\1"), - StatusIs(StatusCode::kInvalidArgument)); - EXPECT_THAT(ParseShapedBufferFromString("4x1=00010203040506"), - StatusIs(StatusCode::kInvalidArgument)); -} - -TEST(ShapedBufferStringUtilTest, ParseShapedBufferFromStringAllowBrackets) { - ASSERT_OK_AND_ASSIGN(auto m0, - ParseShapedBufferFromString("4xi16=[[0][ 1 ][2]][3]")); - EXPECT_EQ(Shape({4}), m0.shape()); - EXPECT_EQ(2, m0.element_size()); - EXPECT_THAT(ReadAs<int16_t>(m0.contents()), ElementsAre(0, 1, 2, 3)); -} - -TEST(ShapedBufferStringUtilTest, ParseShapedBufferFromStringInteger) { - // Signed int16. - ASSERT_OK_AND_ASSIGN(auto m0, - ParseShapedBufferFromString("4xi16=0 12345 65535 -2")); - EXPECT_EQ(Shape({4}), m0.shape()); - EXPECT_EQ(2, m0.element_size()); - EXPECT_THAT(ReadAs<int16_t>(m0.contents()), ElementsAre(0, 12345, -1, -2)); - - // Unsigned int16. - ASSERT_OK_AND_ASSIGN(auto m1, - ParseShapedBufferFromString("4xu16=0 12345 65535 -2")); - EXPECT_EQ(Shape({4}), m1.shape()); - EXPECT_EQ(2, m1.element_size()); - EXPECT_THAT(ReadAs<uint16_t>(m1.contents()), - ElementsAre(0, 12345, 65535, 65534)); - - // Mixing separator types is ok. - ASSERT_OK_AND_ASSIGN( - auto m2, ParseShapedBufferFromString("4xu16=0, 12345, 65535, -2")); - EXPECT_EQ(Shape({4}), m2.shape()); - EXPECT_EQ(2, m2.element_size()); - EXPECT_THAT(ReadAs<uint16_t>(m2.contents()), - ElementsAre(0, 12345, 65535, 65534)); - - // Should fail on malformed integers bytes and out of bounds values. - EXPECT_FALSE(ParseShapedBufferFromString("4xi32=asodfj").ok()); - EXPECT_FALSE(ParseShapedBufferFromString("4xi32=0 1 2 3 4").ok()); -} - -TEST(ShapedBufferStringUtilTest, ParseShapedBufferFromStringFloat) { - // Float. 
- ASSERT_OK_AND_ASSIGN(auto m0, - ParseShapedBufferFromString("4xf32=0 1.0 1234 -2.0e-5")); - EXPECT_EQ(Shape({4}), m0.shape()); - EXPECT_EQ(4, m0.element_size()); - EXPECT_THAT(ReadAs<float>(m0.contents()), - ElementsAre(0.0f, 1.0f, 1234.0f, -2.0e-5f)); - - // Double. - ASSERT_OK_AND_ASSIGN(auto m1, ParseShapedBufferFromString( - "4xf64=0 1.0 123456789012345 -2.0e-5")); - EXPECT_EQ(Shape({4}), m1.shape()); - EXPECT_EQ(8, m1.element_size()); - EXPECT_THAT(ReadAs<double>(m1.contents()), - ElementsAre(0.0, 1.0, 123456789012345.0, -2.0e-5)); - - // Splat (repeating single element value). - ASSERT_OK_AND_ASSIGN(auto m2, ParseShapedBufferFromString("4xf32=2.2")); - EXPECT_EQ(Shape({4}), m2.shape()); - EXPECT_EQ(4, m2.element_size()); - EXPECT_THAT(ReadAs<float>(m2.contents()), - ElementsAre(2.2f, 2.2f, 2.2f, 2.2f)); - - // Should fail on malformed floats and out of bounds values. - EXPECT_FALSE(ParseShapedBufferFromString("4xf32=asodfj").ok()); - EXPECT_FALSE(ParseShapedBufferFromString("4xf32=0 1 2 3 4").ok()); -} - -TEST(ShapedBufferStringUtilTest, RoundTripParsePrint) { - RoundTripTest("4xi8=0 -1 2 3", BufferDataPrintMode::kSignedInteger); - RoundTripTest("4xi16=0 -1 2 3", BufferDataPrintMode::kSignedInteger); - RoundTripTest("4xu16=0 1 2 3", BufferDataPrintMode::kUnsignedInteger); - RoundTripTest("4xf32=0 1.1 2 3", BufferDataPrintMode::kFloatingPoint); - RoundTripTest("1x2x3xi8=[[0 1 2][3 4 5]]", - BufferDataPrintMode::kSignedInteger); -} - -TEST(ShapedBufferStringUtilTest, RoundTripPrintParse) { - RoundTripTest(ShapedBuffer::Create<int8_t>({4}, {0, 1, 2, 3}), - BufferDataPrintMode::kSignedInteger); - RoundTripTest(ShapedBuffer::Create<int16_t>({4}, {0, 1, 2, 3}), - BufferDataPrintMode::kSignedInteger); - RoundTripTest(ShapedBuffer::Create<uint16_t>({4}, {0, 1, 2, 3}), - BufferDataPrintMode::kSignedInteger); - RoundTripTest(ShapedBuffer::Create<float>({4}, {0, 1.1, 2, 3}), - BufferDataPrintMode::kSignedInteger); - RoundTripTest(ShapedBuffer::Create<int8_t>({1, 2, 
3}, {0, 1, 2, 3, 4, 5}), - BufferDataPrintMode::kSignedInteger); - RoundTripTest(ShapedBuffer(1, {4}, {0, 1, 2, 3}), - BufferDataPrintMode::kBinary); -} - -} // namespace -} // namespace iree
diff --git a/iree/build_defs.oss.bzl b/iree/build_defs.oss.bzl index 6b1184d..a8586b2 100644 --- a/iree/build_defs.oss.bzl +++ b/iree/build_defs.oss.bzl
@@ -54,8 +54,6 @@ # Driver modules that register themselves at link time. IREE_DRIVER_MODULES = [ - # TODO(b/142004903): enable when Dawn HAL implementation is functional - # "//iree/hal/dawn:dawn_driver_module", "//iree/hal/dylib:dylib_driver_module", "//iree/hal/vmla:vmla_driver_module", "//iree/hal/vulkan:vulkan_driver_module",
diff --git a/iree/hal/api.cc b/iree/hal/api.cc index 25681e0..088f055 100644 --- a/iree/hal/api.cc +++ b/iree/hal/api.cc
@@ -391,11 +391,11 @@ LOG(ERROR) << "Unimplemented parser for element format FLOAT_16"; return IREE_STATUS_UNIMPLEMENTED; case IREE_HAL_ELEMENT_TYPE_FLOAT_32: - n = std::snprintf(buffer, buffer ? buffer_capacity : 0, "%F", + n = std::snprintf(buffer, buffer ? buffer_capacity : 0, "%G", *reinterpret_cast<const float*>(data.data)); break; case IREE_HAL_ELEMENT_TYPE_FLOAT_64: - n = std::snprintf(buffer, buffer ? buffer_capacity : 0, "%E", + n = std::snprintf(buffer, buffer ? buffer_capacity : 0, "%G", *reinterpret_cast<const double*>(data.data)); break; default: {
diff --git a/iree/hal/api_string_util_test.cc b/iree/hal/api_string_util_test.cc index 21aa960..8e9efc1 100644 --- a/iree/hal/api_string_util_test.cc +++ b/iree/hal/api_string_util_test.cc
@@ -762,11 +762,10 @@ IsOkAndHolds(Eq("9223372036854775807"))); EXPECT_THAT(FormatElement<uint64_t>(UINT64_MAX), IsOkAndHolds(Eq("18446744073709551615"))); - EXPECT_THAT(FormatElement<float>(1.5f), IsOkAndHolds(Eq("1.500000"))); + EXPECT_THAT(FormatElement<float>(1.5f), IsOkAndHolds(Eq("1.5"))); EXPECT_THAT(FormatElement<double>(1123.56789456789), - IsOkAndHolds(Eq("1.123568E+03"))); - EXPECT_THAT(FormatElement<double>(-1.5e-10), - IsOkAndHolds(Eq("-1.500000E-10"))); + IsOkAndHolds(Eq("1123.57"))); + EXPECT_THAT(FormatElement<double>(-1.5e-10), IsOkAndHolds(Eq("-1.5E-10"))); } TEST(ElementStringUtilTest, FormatOpaqueElement) { @@ -1029,9 +1028,8 @@ expect_round_trip("4xi16=0 -1 2 3"); expect_round_trip("4xu16=0 1 2 3"); expect_round_trip("2x2xi32=[0 1][2 3]"); - expect_round_trip("4xf32=0.000000 1.100000 2.000000 3.000000"); - expect_round_trip( - "4xf64=0.000000E+00 1.100000E+00 2.000000E+00 3.000000E+00"); + expect_round_trip("4xf32=0 1.1 2 3"); + expect_round_trip("4xf64=0 1.1 2 3"); expect_round_trip("1x2x3xi8=[[0 1 2][3 4 5]]"); expect_round_trip("2x*16=AABB CCDD"); expect_round_trip(
diff --git a/iree/hal/cts/BUILD b/iree/hal/cts/BUILD index 9c14754..4fa423d 100644 --- a/iree/hal/cts/BUILD +++ b/iree/hal/cts/BUILD
@@ -40,7 +40,6 @@ "//iree/hal/llvmjit:llvmjit_driver_module", # build-cleaner: keep "//iree/hal/vmla:vmla_driver_module", # build-cleaner: keep "//iree/hal/vulkan:vulkan_driver_module", # build-cleaner: keep - # "//iree/hal/dawn:dawn_driver_module", # build-cleaner: keep ] + PLATFORM_VULKAN_TEST_DEPS, )
diff --git a/iree/hal/dawn/BUILD b/iree/hal/dawn/BUILD deleted file mode 100644 index 0b50c82..0000000 --- a/iree/hal/dawn/BUILD +++ /dev/null
@@ -1,89 +0,0 @@ -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# HAL implementation using Dawn and SPIR-V executables. -# https://dawn.googlesource.com/dawn - -package( - default_visibility = ["//visibility:public"], - features = ["layering_check"], - licenses = ["notice"], # Apache 2.0 -) - -cc_library( - name = "dawn_device", - srcs = ["dawn_device.cc"], - hdrs = ["dawn_device.h"], - deps = [ - "//iree/base:memory", - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal:command_queue", - "//iree/hal:device", - "//iree/hal:executable_cache", - "//iree/hal:semaphore", - "//iree/hal/host:host_local_allocator", - "//third_party/dawn:dawn_native", - "//third_party/dawn:dawncpp_headers", - "//third_party/dawn:libdawn_proc", # build-cleaner: keep - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - ], -) - -cc_library( - name = "dawn_driver", - srcs = ["dawn_driver.cc"], - hdrs = ["dawn_driver.h"], - deps = [ - ":dawn_device", - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal:device_info", - "//iree/hal:driver", - "//third_party/dawn:dawn_headers", - "//third_party/dawn:dawn_native", - "//third_party/dawn:dawncpp", - "//third_party/dawn:dawncpp_headers", - "//third_party/dawn:libdawn_proc", # build-cleaner: keep - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - ], -) - -# 
TODO(scotttodd): Use SwiftShader to test Vulkan backend -cc_test( - name = "dawn_driver_test", - srcs = ["dawn_driver_test.cc"], - deps = [ - ":dawn_driver", - "//iree/base:status_matchers", - "//iree/testing:gtest_main", - ], -) - -cc_library( - name = "dawn_driver_module", - srcs = ["dawn_driver_module.cc"], - deps = [ - ":dawn_driver", - "//iree/base:init", - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal:driver_registry", - ], - alwayslink = 1, -)
diff --git a/iree/hal/dawn/dawn_device.cc b/iree/hal/dawn/dawn_device.cc deleted file mode 100644 index 4aa3b53..0000000 --- a/iree/hal/dawn/dawn_device.cc +++ /dev/null
@@ -1,153 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/dawn/dawn_device.h" - -#include "absl/memory/memory.h" -#include "absl/strings/str_cat.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/hal/command_queue.h" -#include "iree/hal/executable_cache.h" -#include "iree/hal/semaphore.h" - -namespace iree { -namespace hal { -namespace dawn { - -namespace { - -// ExecutableCache implementation that compiles but does nothing. -// This will be replaced with something functional soon. -class NoopExecutableCache final : public ExecutableCache { - public: - explicit NoopExecutableCache() {} - ~NoopExecutableCache() override = default; - - bool CanPrepareFormat(ExecutableFormat format) const override { - return false; - } - - StatusOr<ref_ptr<Executable>> PrepareExecutable( - ExecutableLayout* executable_layout, ExecutableCachingModeBitfield mode, - const ExecutableSpec& spec) override { - return UnimplementedErrorBuilder(IREE_LOC) << "PrepareExecutable NYI"; - } -}; - -} // namespace - -DawnDevice::DawnDevice(const DeviceInfo& device_info, - ::wgpu::Device backend_device) - : Device(device_info), backend_device_(backend_device) { - IREE_TRACE_SCOPE0("DawnDevice::ctor"); - - // TODO(scotttodd): construct command queues, perform other initialization - - // Log some basic device info. 
- std::string backend_type_str; - auto* adapter = - reinterpret_cast<dawn_native::Adapter*>(device_info.device_id()); - switch (adapter->GetBackendType()) { - case dawn_native::BackendType::D3D12: - backend_type_str = "D3D12"; - break; - case dawn_native::BackendType::Metal: - backend_type_str = "Metal"; - break; - case dawn_native::BackendType::Null: - backend_type_str = "Null"; - break; - case dawn_native::BackendType::OpenGL: - backend_type_str = "OpenGL"; - break; - case dawn_native::BackendType::Vulkan: - backend_type_str = "Vulkan"; - break; - } - LOG(INFO) << "Created DawnDevice '" << device_info.name() << "' (" - << backend_type_str << ")"; -} - -DawnDevice::~DawnDevice() = default; - -std::string DawnDevice::DebugString() const { - return absl::StrCat(Device::DebugString(), // - "\n[DawnDevice]", // - "\n Command Queues: ", command_queues_.size()); -} - -StatusOr<ref_ptr<DescriptorSetLayout>> DawnDevice::CreateDescriptorSetLayout( - DescriptorSetLayout::UsageType usage_type, - absl::Span<const DescriptorSetLayout::Binding> bindings) { - IREE_TRACE_SCOPE0("DawnDevice::CreateDescriptorSetLayout"); - return UnimplementedErrorBuilder(IREE_LOC) << "CreateDescriptorSetLayout NYI"; -} - -StatusOr<ref_ptr<DescriptorSet>> DawnDevice::CreateDescriptorSet( - DescriptorSetLayout* set_layout, - absl::Span<const DescriptorSet::Binding> bindings) { - IREE_TRACE_SCOPE0("DawnDevice::CreateDescriptorSet"); - return UnimplementedErrorBuilder(IREE_LOC) << "CreateDescriptorSet NYI"; -} - -ref_ptr<ExecutableCache> DawnDevice::CreateExecutableCache() { - IREE_TRACE_SCOPE0("DawnDevice::CreateExecutableCache"); - return make_ref<NoopExecutableCache>(); -} - -StatusOr<ref_ptr<ExecutableLayout>> DawnDevice::CreateExecutableLayout( - absl::Span<DescriptorSetLayout* const> set_layouts, size_t push_constants) { - IREE_TRACE_SCOPE0("DawnDevice::CreateExecutableLayout"); - return UnimplementedErrorBuilder(IREE_LOC) << "CreateExecutableLayout NYI"; -} - -StatusOr<ref_ptr<CommandBuffer>> 
DawnDevice::CreateCommandBuffer( - CommandBufferModeBitfield mode, - CommandCategoryBitfield command_categories) { - return UnimplementedErrorBuilder(IREE_LOC) << "CreateCommandBuffer NYI"; -} - -StatusOr<ref_ptr<Event>> DawnDevice::CreateEvent() { - return UnimplementedErrorBuilder(IREE_LOC) << "CreateEvent NYI"; -} - -StatusOr<ref_ptr<Semaphore>> DawnDevice::CreateSemaphore( - uint64_t initial_value) { - IREE_TRACE_SCOPE0("DawnDevice::CreateSemaphore"); - - return UnimplementedErrorBuilder(IREE_LOC) << "CreateSemaphore NYI"; -} - -Status DawnDevice::WaitAllSemaphores( - absl::Span<const SemaphoreValue> semaphores, absl::Time deadline) { - IREE_TRACE_SCOPE0("DawnDevice::WaitAllSemaphores"); - - return UnimplementedErrorBuilder(IREE_LOC) << "WaitAllSemaphores NYI"; -} - -StatusOr<int> DawnDevice::WaitAnySemaphore( - absl::Span<const SemaphoreValue> semaphores, absl::Time deadline) { - IREE_TRACE_SCOPE0("DawnDevice::WaitAnySemaphore"); - - return UnimplementedErrorBuilder(IREE_LOC) << "WaitAnySemaphore NYI"; -} - -Status DawnDevice::WaitIdle(absl::Time deadline) { - return UnimplementedErrorBuilder(IREE_LOC) << "WaitIdle"; -} - -} // namespace dawn -} // namespace hal -} // namespace iree
diff --git a/iree/hal/dawn/dawn_device.h b/iree/hal/dawn/dawn_device.h deleted file mode 100644 index a20da84..0000000 --- a/iree/hal/dawn/dawn_device.h +++ /dev/null
@@ -1,88 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_DAWN_DAWN_DEVICE_H_ -#define IREE_HAL_DAWN_DAWN_DEVICE_H_ - -#include "absl/container/inlined_vector.h" -#include "absl/types/span.h" -#include "iree/base/memory.h" -#include "iree/hal/device.h" -#include "iree/hal/host/host_local_allocator.h" -#include "third_party/dawn/src/include/dawn/webgpu_cpp.h" -#include "third_party/dawn/src/include/dawn_native/DawnNative.h" - -namespace iree { -namespace hal { -namespace dawn { - -class DawnDevice final : public Device { - public: - explicit DawnDevice(const DeviceInfo& device_info, - ::wgpu::Device backend_device); - ~DawnDevice() override; - - std::string DebugString() const override; - - Allocator* allocator() const override { return &allocator_; } - - absl::Span<CommandQueue*> dispatch_queues() const override { - return RawPtrSpan(absl::MakeSpan(command_queues_)); - } - - absl::Span<CommandQueue*> transfer_queues() const override { - return RawPtrSpan(absl::MakeSpan(command_queues_)); - } - - ref_ptr<ExecutableCache> CreateExecutableCache() override; - - StatusOr<ref_ptr<DescriptorSetLayout>> CreateDescriptorSetLayout( - DescriptorSetLayout::UsageType usage_type, - absl::Span<const DescriptorSetLayout::Binding> bindings) override; - - StatusOr<ref_ptr<ExecutableLayout>> CreateExecutableLayout( - absl::Span<DescriptorSetLayout* const> set_layouts, - size_t push_constants) override; - - 
StatusOr<ref_ptr<DescriptorSet>> CreateDescriptorSet( - DescriptorSetLayout* set_layout, - absl::Span<const DescriptorSet::Binding> bindings) override; - - StatusOr<ref_ptr<CommandBuffer>> CreateCommandBuffer( - CommandBufferModeBitfield mode, - CommandCategoryBitfield command_categories) override; - - StatusOr<ref_ptr<Event>> CreateEvent() override; - - StatusOr<ref_ptr<Semaphore>> CreateSemaphore(uint64_t initial_value) override; - - Status WaitAllSemaphores(absl::Span<const SemaphoreValue> semaphores, - absl::Time deadline) override; - StatusOr<int> WaitAnySemaphore(absl::Span<const SemaphoreValue> semaphores, - absl::Time deadline) override; - - Status WaitIdle(absl::Time deadline) override; - - private: - mutable host::HostLocalAllocator allocator_; - mutable absl::InlinedVector<std::unique_ptr<CommandQueue>, 1> command_queues_; - - ::wgpu::Device backend_device_; -}; - -} // namespace dawn -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_DAWN_DAWN_DEVICE_H_
diff --git a/iree/hal/dawn/dawn_driver.cc b/iree/hal/dawn/dawn_driver.cc deleted file mode 100644 index 34dd792..0000000 --- a/iree/hal/dawn/dawn_driver.cc +++ /dev/null
@@ -1,122 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/dawn/dawn_driver.h" - -#include "absl/memory/memory.h" -#include "absl/strings/str_cat.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/hal/dawn/dawn_device.h" -#include "iree/hal/device_info.h" -#include "third_party/dawn/src/include/dawn/dawn_proc.h" - -namespace iree { -namespace hal { -namespace dawn { - -namespace { - -// Populates device information from the given dawn_native::Adapter. -StatusOr<DeviceInfo> PopulateDeviceInfo(dawn_native::Adapter* adapter) { - // TODO(scotttodd): Query these for each backend or implement? - DeviceFeatureBitfield supported_features = DeviceFeature::kNone; - // supported_features |= DeviceFeature::kDebugging; - // supported_features |= DeviceFeature::kCoverage; - // supported_features |= DeviceFeature::kProfiling; - - // TODO(scotttodd): more clever/sanitized device naming. 
- std::string device_id = "dawn"; - std::string device_name = absl::StrCat("dawn-", adapter->GetPCIInfo().name); - - return DeviceInfo(device_id, device_name, supported_features, - reinterpret_cast<DriverDeviceID>(adapter)); -} - -} // namespace - -DawnDriver::DawnDriver() : Driver("dawn") { - dawn_instance_ = absl::make_unique<dawn_native::Instance>(); -} - -DawnDriver::~DawnDriver() = default; - -StatusOr<std::vector<DeviceInfo>> DawnDriver::EnumerateAvailableDevices() { - IREE_TRACE_SCOPE0("DawnDriver::EnumerateAvailableDevices"); - - if (dawn_backend_adapters_.empty()) { - // Discover adapters (i.e. devices and their associated backend APIs). - // Retain the list of adapters so pointers are valid for the lifetime of - // this object. - dawn_instance_->DiscoverDefaultAdapters(); - dawn_backend_adapters_ = dawn_instance_->GetAdapters(); - } else { - // Assume that the list of adapters does not change. This is not guaranteed - // to be true, but we also don't want to invalidate pointers by requesting - // a new list each time. If the list of available devices would change, - // tearing down and creating a new DawnDriver may be your best option. - } - - // Convert to our HAL structure. - std::vector<DeviceInfo> device_infos; - device_infos.reserve(dawn_backend_adapters_.size()); - for (auto& adapter : dawn_backend_adapters_) { - // TODO(scotttodd): if we fail should we just ignore the device in the list? - ASSIGN_OR_RETURN(auto device_info, PopulateDeviceInfo(&adapter)); - device_infos.push_back(std::move(device_info)); - } - return device_infos; -} - -StatusOr<ref_ptr<Device>> DawnDriver::CreateDefaultDevice() { - IREE_TRACE_SCOPE0("DawnDriver::CreateDefaultDevice"); - - // Query available devices. - ASSIGN_OR_RETURN(auto available_devices, EnumerateAvailableDevices()); - if (available_devices.empty()) { - return NotFoundErrorBuilder(IREE_LOC) << "No devices are available"; - } - - // Create the first non-null device, if any. 
- for (const auto& device : available_devices) { - auto* adapter = reinterpret_cast<dawn_native::Adapter*>(device.device_id()); - if (adapter->GetBackendType() != dawn_native::BackendType::Null) { - return CreateDevice(device.device_id()); - } - } - - // Otherwise create the first null device. - return CreateDevice(available_devices.front().device_id()); -} - -StatusOr<ref_ptr<Device>> DawnDriver::CreateDevice(DriverDeviceID device_id) { - IREE_TRACE_SCOPE0("DawnDriver::CreateDevice"); - - auto* adapter = reinterpret_cast<dawn_native::Adapter*>(device_id); - ASSIGN_OR_RETURN(auto device_info, PopulateDeviceInfo(adapter)); - - ::WGPUDevice c_backend_device = adapter->CreateDevice(); - if (!c_backend_device) { - return InternalErrorBuilder(IREE_LOC) << "Failed to create a Dawn device"; - } - DawnProcTable backend_procs = dawn_native::GetProcs(); - dawnProcSetProcs(&backend_procs); - ::wgpu::Device backend_device = ::wgpu::Device::Acquire(c_backend_device); - - return make_ref<DawnDevice>(device_info, backend_device); -} - -} // namespace dawn -} // namespace hal -} // namespace iree
diff --git a/iree/hal/dawn/dawn_driver.h b/iree/hal/dawn/dawn_driver.h deleted file mode 100644 index 0d5ea07..0000000 --- a/iree/hal/dawn/dawn_driver.h +++ /dev/null
@@ -1,49 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_DAWN_DAWN_DRIVER_H_ -#define IREE_HAL_DAWN_DAWN_DRIVER_H_ - -#include <memory> -#include <vector> - -#include "iree/hal/driver.h" -#include "third_party/dawn/src/include/dawn/webgpu_cpp.h" -#include "third_party/dawn/src/include/dawn_native/DawnNative.h" - -namespace iree { -namespace hal { -namespace dawn { - -class DawnDriver final : public Driver { - public: - DawnDriver(); - ~DawnDriver() override; - - StatusOr<std::vector<DeviceInfo>> EnumerateAvailableDevices() override; - - StatusOr<ref_ptr<Device>> CreateDefaultDevice() override; - - StatusOr<ref_ptr<Device>> CreateDevice(DriverDeviceID device_id) override; - - private: - std::unique_ptr<dawn_native::Instance> dawn_instance_; - std::vector<dawn_native::Adapter> dawn_backend_adapters_; -}; - -} // namespace dawn -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_DAWN_DAWN_DRIVER_H_
diff --git a/iree/hal/dawn/dawn_driver_module.cc b/iree/hal/dawn/dawn_driver_module.cc deleted file mode 100644 index 6f62bbf..0000000 --- a/iree/hal/dawn/dawn_driver_module.cc +++ /dev/null
@@ -1,39 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <memory> - -#include "iree/base/init.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/hal/dawn/dawn_driver.h" -#include "iree/hal/driver_registry.h" - -namespace iree { -namespace hal { -namespace dawn { -namespace { - -StatusOr<ref_ptr<Driver>> CreateDawnDriver() { return make_ref<DawnDriver>(); } - -} // namespace -} // namespace dawn -} // namespace hal -} // namespace iree - -IREE_REGISTER_MODULE_INITIALIZER(iree_hal_dawn_driver, { - QCHECK_OK(::iree::hal::DriverRegistry::shared_registry()->Register( - "dawn", ::iree::hal::dawn::CreateDawnDriver)); -}); -IREE_REGISTER_MODULE_INITIALIZER_SEQUENCE(iree_hal, iree_hal_dawn_driver);
diff --git a/iree/hal/dawn/dawn_driver_test.cc b/iree/hal/dawn/dawn_driver_test.cc deleted file mode 100644 index ccecbe4..0000000 --- a/iree/hal/dawn/dawn_driver_test.cc +++ /dev/null
@@ -1,45 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/dawn/dawn_driver.h" - -#include "iree/base/status_matchers.h" -#include "iree/testing/gtest.h" - -namespace iree { -namespace hal { -namespace dawn { -namespace { - -TEST(DawnDriverTest, CreateDefaultDevice) { - DawnDriver dawn_driver; - ASSERT_OK_AND_ASSIGN(auto default_device, dawn_driver.CreateDefaultDevice()); -} - -TEST(DawnDriverTest, EnumerateDevicesAndCreate) { - DawnDriver dawn_driver; - - ASSERT_OK_AND_ASSIGN(auto available_devices, - dawn_driver.EnumerateAvailableDevices()); - ASSERT_GT(available_devices.size(), 0); - - ASSERT_OK_AND_ASSIGN( - auto first_device, - dawn_driver.CreateDevice(available_devices[0].device_id())); -} - -} // namespace -} // namespace dawn -} // namespace hal -} // namespace iree
diff --git a/iree/modules/check/BUILD b/iree/modules/check/BUILD index 87f83a0..863b189 100644 --- a/iree/modules/check/BUILD +++ b/iree/modules/check/BUILD
@@ -80,7 +80,6 @@ deps = [ "//iree/base:api", "//iree/base:api_util", - "//iree/base:buffer_string_util", "//iree/base:status", "//iree/hal:api", "//iree/modules/hal",
diff --git a/iree/modules/check/CMakeLists.txt b/iree/modules/check/CMakeLists.txt index fb6d66e..804b844 100644 --- a/iree/modules/check/CMakeLists.txt +++ b/iree/modules/check/CMakeLists.txt
@@ -80,7 +80,6 @@ absl::strings iree::base::api iree::base::api_util - iree::base::buffer_string_util iree::base::status iree::hal::api iree::modules::hal
diff --git a/iree/modules/check/native_module.cc b/iree/modules/check/native_module.cc index d90d521..4647643 100644 --- a/iree/modules/check/native_module.cc +++ b/iree/modules/check/native_module.cc
@@ -24,7 +24,6 @@ #include "absl/strings/str_cat.h" #include "iree/base/api.h" #include "iree/base/api_util.h" -#include "iree/base/buffer_string_util.h" #include "iree/base/status.h" #include "iree/hal/api.h" #include "iree/modules/hal/hal_module.h" @@ -47,6 +46,22 @@ bytes.data_length / sizeof(T)); } +StatusOr<std::string> BufferViewToString(iree_hal_buffer_view_t* buffer_view) { + std::string result_str(4096, '\0'); + iree_status_t status; + do { + iree_host_size_t actual_length = 0; + status = iree_hal_buffer_view_format( + buffer_view, /*max_element_count=*/1024, result_str.size() + 1, + &result_str[0], &actual_length); + result_str.resize(actual_length); + } while (iree_status_is_out_of_range(status)); + if (!iree_status_is_ok(status)) { + return FromApiStatus(status, IREE_LOC); + } + return std::move(result_str); +} + template <typename T> Status ExpectAllTrue(iree_byte_span_t bytes) { EXPECT_THAT(AbslSpan<T>(bytes), Each(Not(0))); @@ -225,6 +240,9 @@ bool shape_eq = lhs_shape == rhs_shape; bool contents_eq = EqByteSpan(lhs_mapped_memory.contents, rhs_mapped_memory.contents); + iree_hal_buffer_unmap(lhs_buf, &lhs_mapped_memory); + iree_hal_buffer_unmap(rhs_buf, &rhs_mapped_memory); + if (!element_types_eq || !shape_eq || !contents_eq) { std::ostringstream os; os << "Expected equality of these values."; @@ -241,43 +259,18 @@ os << "\n" " lhs:\n" " "; - char lhs_element_type_str[16]; - RETURN_IF_ERROR(FromApiStatus( - iree_hal_format_element_type(iree_hal_buffer_view_element_type(lhs), - sizeof(lhs_element_type_str), - lhs_element_type_str, nullptr), - IREE_LOC)); - // TODO(b/146898896): Remove dependence on Shape. 
- PrintShapedTypeToStream(Shape{lhs_shape}, lhs_element_type_str, &os); - os << "="; - RETURN_IF_ERROR( - PrintNumericalDataToStream(Shape{lhs_shape}, lhs_element_type_str, - {lhs_mapped_memory.contents.data, - lhs_mapped_memory.contents.data_length}, - /*max_entries=*/1024, &os)); + ASSIGN_OR_RETURN(auto lhs_str, BufferViewToString(lhs)); + os << lhs_str; os << "\n" " rhs:\n" " "; - char rhs_element_type_str[16]; - RETURN_IF_ERROR(FromApiStatus( - iree_hal_format_element_type(iree_hal_buffer_view_element_type(rhs), - sizeof(rhs_element_type_str), - rhs_element_type_str, nullptr), - IREE_LOC)); - PrintShapedTypeToStream(Shape{rhs_shape}, rhs_element_type_str, &os); - os << "="; - RETURN_IF_ERROR( - PrintNumericalDataToStream(Shape{rhs_shape}, rhs_element_type_str, - {rhs_mapped_memory.contents.data, - rhs_mapped_memory.contents.data_length}, - /*max_entries=*/1024, &os)); + ASSIGN_OR_RETURN(auto rhs_str, BufferViewToString(rhs)); + os << rhs_str; // TODO(b/146898896): Use ADD_FAILURE_AT to propagate source location. ADD_FAILURE() << os.str(); } - iree_hal_buffer_unmap(lhs_buf, &lhs_mapped_memory); - iree_hal_buffer_unmap(rhs_buf, &rhs_mapped_memory); return OkStatus(); } @@ -328,6 +321,9 @@ AlmostEqByteSpan(lhs_mapped_memory.contents, rhs_mapped_memory.contents, lhs_element_type)); } + iree_hal_buffer_unmap(lhs_buf, &lhs_mapped_memory); + iree_hal_buffer_unmap(rhs_buf, &rhs_mapped_memory); + if (!element_types_eq || !shape_eq || !contents_could_be_almost_eq) { std::ostringstream os; os << "Expected near equality of these values."; @@ -344,43 +340,18 @@ os << "\n" " lhs:\n" " "; - char lhs_element_type_str[16]; - RETURN_IF_ERROR(FromApiStatus( - iree_hal_format_element_type(iree_hal_buffer_view_element_type(lhs), - sizeof(lhs_element_type_str), - lhs_element_type_str, nullptr), - IREE_LOC)); - // TODO(b/146898896): Remove dependence on Shape. 
- PrintShapedTypeToStream(Shape{lhs_shape}, lhs_element_type_str, &os); - os << "="; - RETURN_IF_ERROR( - PrintNumericalDataToStream(Shape{lhs_shape}, lhs_element_type_str, - {lhs_mapped_memory.contents.data, - lhs_mapped_memory.contents.data_length}, - /*max_entries=*/1024, &os)); + ASSIGN_OR_RETURN(auto lhs_str, BufferViewToString(lhs)); + os << lhs_str; os << "\n" " rhs:\n" " "; - char rhs_element_type_str[16]; - RETURN_IF_ERROR(FromApiStatus( - iree_hal_format_element_type(iree_hal_buffer_view_element_type(rhs), - sizeof(rhs_element_type_str), - rhs_element_type_str, nullptr), - IREE_LOC)); - PrintShapedTypeToStream(Shape{rhs_shape}, rhs_element_type_str, &os); - os << "="; - RETURN_IF_ERROR( - PrintNumericalDataToStream(Shape{rhs_shape}, rhs_element_type_str, - {rhs_mapped_memory.contents.data, - rhs_mapped_memory.contents.data_length}, - /*max_entries=*/1024, &os)); + ASSIGN_OR_RETURN(auto rhs_str, BufferViewToString(rhs)); + os << rhs_str; // TODO(b/146898896): Use ADD_FAILURE_AT to propagate source location. ADD_FAILURE() << os.str(); } - iree_hal_buffer_unmap(lhs_buf, &lhs_mapped_memory); - iree_hal_buffer_unmap(rhs_buf, &rhs_mapped_memory); return OkStatus(); }
diff --git a/iree/samples/custom_modules/BUILD b/iree/samples/custom_modules/BUILD index 411e69c..7a65e51 100644 --- a/iree/samples/custom_modules/BUILD +++ b/iree/samples/custom_modules/BUILD
@@ -55,9 +55,6 @@ deps = [ "//iree/base:api", "//iree/base:api_util", - "//iree/base:buffer_string_util", - "//iree/base:shape", - "//iree/base:shaped_buffer_string_util", "//iree/hal:api", "//iree/modules/hal", "//iree/vm",
diff --git a/iree/samples/custom_modules/CMakeLists.txt b/iree/samples/custom_modules/CMakeLists.txt index 0998857..44224a4 100644 --- a/iree/samples/custom_modules/CMakeLists.txt +++ b/iree/samples/custom_modules/CMakeLists.txt
@@ -59,9 +59,6 @@ DEPS iree::base::api iree::base::api_util - iree::base::buffer_string_util - iree::base::shape - iree::base::shaped_buffer_string_util iree::hal::api iree::modules::hal iree::vm
diff --git a/iree/samples/custom_modules/native_module.cc b/iree/samples/custom_modules/native_module.cc index 93ffcf9..ade2f9c 100644 --- a/iree/samples/custom_modules/native_module.cc +++ b/iree/samples/custom_modules/native_module.cc
@@ -19,9 +19,6 @@ #include "iree/base/api.h" #include "iree/base/api_util.h" -#include "iree/base/buffer_string_util.h" -#include "iree/base/shape.h" -#include "iree/base/shaped_buffer_string_util.h" #include "iree/hal/api.h" #include "iree/modules/hal/hal_module.h" #include "iree/vm/module_abi_cc.h" @@ -133,56 +130,36 @@ Status Initialize(int32_t unique_id) { // Allocate a unique ID to demonstrate per-context state. auto str_buffer = "ctx_" + std::to_string(unique_id); - return FromApiStatus( + RETURN_IF_ERROR(FromApiStatus( iree_custom_message_create(iree_make_cstring_view(str_buffer.c_str()), allocator_, &unique_message_), - IREE_LOC); + IREE_LOC)); + + // Setup a host-local allocator we can use because this sample doesn't have + // a real device allocator. + RETURN_IF_ERROR(FromApiStatus(iree_hal_allocator_create_host_local( + allocator_, &host_local_allocator_), + IREE_LOC)); + + return OkStatus(); } // custom.buffer_to_message(%buffer_view) -> %result StatusOr<vm::ref<iree_custom_message_t>> BufferToMessage( vm::ref<iree_hal_buffer_view_t> buffer_view) { - iree_hal_buffer_t* buffer = iree_hal_buffer_view_buffer(buffer_view.get()); - - // Map the buffer memory so we can read it back. - iree_hal_mapped_memory_t mapped_memory; - RETURN_IF_ERROR( - FromApiStatus(iree_hal_buffer_map(buffer, IREE_HAL_MEMORY_ACCESS_READ, - 0, IREE_WHOLE_BUFFER, &mapped_memory), - IREE_LOC)); - - // NOTE: these string methods take the old Shape type and as such have a - // rank limit. That limit is just an artifact of those APIs, not the - // buffer view shape type. 
- absl::InlinedVector<int32_t, kMaxRank> shape(kMaxRank); - size_t rank = 0; - RETURN_IF_ERROR(FromApiStatus( - iree_hal_buffer_view_shape(buffer_view.get(), shape.capacity(), - shape.data(), &rank), - IREE_LOC)); - shape.resize(rank); - char element_type_str[16]; - RETURN_IF_ERROR( - FromApiStatus(iree_hal_format_element_type( - iree_hal_buffer_view_element_type(buffer_view.get()), - sizeof(element_type_str), element_type_str, nullptr), - IREE_LOC)); - - // Print the buffer contents using our helpers. - std::string string_value; - RETURN_IF_ERROR(PrintNumericalDataToString( - Shape{shape}, element_type_str, - {mapped_memory.contents.data, mapped_memory.contents.data_length}, - /*max_entries=*/1024, &string_value)); - - // Prefix shape/type. - string_value = - absl::StrCat(PrintShapedTypeToString(Shape{shape}, element_type_str), - "=", string_value); - - // Unmap the buffer when we are done with it. - RETURN_IF_ERROR( - FromApiStatus(iree_hal_buffer_unmap(buffer, &mapped_memory), IREE_LOC)); + // Convert the buffer view to a [shape]x[type]=[contents] string. + std::string string_value(4096, '\0'); + iree_status_t status; + do { + iree_host_size_t actual_length = 0; + status = iree_hal_buffer_view_format( + buffer_view.get(), /*max_element_count=*/1024, + string_value.size() + 1, &string_value[0], &actual_length); + string_value.resize(actual_length); + } while (iree_status_is_out_of_range(status)); + if (!iree_status_is_ok(status)) { + return FromApiStatus(status, IREE_LOC); + } // Pack the string contents into a message. vm::ref<iree_custom_message_t> message; @@ -197,57 +174,17 @@ // custom.message_to_buffer(%message) -> %buffer_view StatusOr<vm::ref<iree_hal_buffer_view_t>> MessageToBuffer( vm::ref<iree_custom_message_t> message) { - // NOTE: these old-style parsing routines need to be updated for the new - // type system. They use different types, different shapes, etc. 
- auto str_parts = BufferStringParts::ExtractFrom( - absl::string_view(message->value.data, message->value.size)); - iree_hal_element_type_t element_type = IREE_HAL_ELEMENT_TYPE_NONE; - RETURN_IF_ERROR(FromApiStatus( - iree_hal_parse_element_type( - {str_parts.type_str.data(), str_parts.type_str.size()}, - &element_type), - IREE_LOC)); - ASSIGN_OR_RETURN(auto shape, ParseShape(str_parts.shape_str)); - - // TODO(benvanik): plumb through an allocator we can use. - size_t allocation_size = - shape.element_count() * iree_hal_element_byte_count(element_type); - vm::ref<iree_hal_buffer_t> buffer; - RETURN_IF_ERROR(FromApiStatus( - iree_hal_heap_buffer_allocate( - IREE_HAL_MEMORY_TYPE_HOST_LOCAL, - static_cast<iree_hal_buffer_usage_t>( - IREE_HAL_BUFFER_USAGE_ALL | IREE_HAL_BUFFER_USAGE_CONSTANT), - allocation_size, IREE_ALLOCATOR_SYSTEM, IREE_ALLOCATOR_SYSTEM, - &buffer), - IREE_LOC)); - if (!str_parts.data_str.empty()) { - // Map the buffer memory so we can write it with the data contents. - iree_hal_mapped_memory_t mapped_memory; - RETURN_IF_ERROR(FromApiStatus( - iree_hal_buffer_map(buffer.get(), - IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE, 0, - IREE_WHOLE_BUFFER, &mapped_memory), - IREE_LOC)); - - // Parse the data from the string right into the buffer. - RETURN_IF_ERROR(ParseBufferDataAsType( - str_parts.data_str, str_parts.type_str, - absl::MakeSpan(mapped_memory.contents.data, - mapped_memory.contents.data_length))); - - // Unmap the buffer when we are done with it. - RETURN_IF_ERROR(FromApiStatus( - iree_hal_buffer_unmap(buffer.get(), &mapped_memory), IREE_LOC)); - } - - // Wrap in a buffer view to pass back into the VM. + // Convert the [shape]x[type]=[contents] string to a buffer view. 
+ auto input_string = + absl::string_view(message->value.data, message->value.size); vm::ref<iree_hal_buffer_view_t> buffer_view; - RETURN_IF_ERROR(FromApiStatus( - iree_hal_buffer_view_create(buffer.get(), shape.data().data(), - shape.size(), element_type, - IREE_ALLOCATOR_SYSTEM, &buffer_view), - IREE_LOC)); + iree_status_t status = iree_hal_buffer_view_parse( + iree_string_view_t{input_string.data(), input_string.size()}, + host_local_allocator_.get(), allocator_, &buffer_view); + if (!iree_status_is_ok(status)) { + return FromApiStatus(status, IREE_LOC) + << "Parsing value '" << input_string << "'"; + } return std::move(buffer_view); } @@ -289,6 +226,11 @@ // perform during operation. iree_allocator_t allocator_ = IREE_ALLOCATOR_SYSTEM; + // HAL buffer allocator that uses host-local memory. This is just for this + // test as we don't actually use a HAL device and don't have a real device + // allocator to use. + vm::ref<iree_hal_allocator_t> host_local_allocator_; + // A unique message owned by the state and returned to the VM. // This demonstrates any arbitrary per-context state one may want to store. vm::ref<iree_custom_message_t> unique_message_;
diff --git a/iree/schemas/BUILD b/iree/schemas/BUILD index b72d7a9..4d9d6d1 100644 --- a/iree/schemas/BUILD +++ b/iree/schemas/BUILD
@@ -33,12 +33,6 @@ "--gen-object-api", ] -iree_flatbuffer_cc_library( - name = "buffer_data_def_cc_fbs", - srcs = ["buffer_data_def.fbs"], - flatc_args = FLATC_ARGS, -) - # TODO(benvanik): also expose as C using flatcc. iree_flatbuffer_cc_library( name = "bytecode_module_def_cc_fbs", @@ -53,12 +47,6 @@ ) iree_flatbuffer_cc_library( - name = "interpreter_module_def_cc_fbs", - srcs = ["interpreter_module_def.fbs"], - flatc_args = FLATC_ARGS, -) - -iree_flatbuffer_cc_library( name = "llvmir_executable_def_cc_fbs", srcs = ["llvmir_executable_def.fbs"], flatc_args = FLATC_ARGS, @@ -79,10 +67,8 @@ iree_build_test( name = "schema_build_test", targets = [ - ":buffer_data_def_cc_fbs", ":bytecode_module_def_cc_fbs", ":dylib_executable_def_cc_fbs", - ":interpreter_module_def_cc_fbs", ":llvmir_executable_def_cc_fbs", ":spirv_executable_def_cc_fbs", ":vmla_executable_def_cc_fbs", @@ -90,10 +76,8 @@ ) REFLECTION_SRCS = [] if not FLATBUFFER_SUPPORTS_REFLECTIONS else [ - "buffer_data_def.bfbs", "bytecode_module_def.bfbs", "dylib_executable_def.bfbs", - "interpreter_module_def.bfbs", "llvmir_executable_def.bfbs", "spirv_executable_def.bfbs", "vmla_executable_def.bfbs",
diff --git a/iree/schemas/CMakeLists.txt b/iree/schemas/CMakeLists.txt index ebabfc0..751ed40 100644 --- a/iree/schemas/CMakeLists.txt +++ b/iree/schemas/CMakeLists.txt
@@ -16,19 +16,6 @@ flatbuffer_cc_library( NAME - buffer_data_def_cc_fbs - SRCS - "buffer_data_def.fbs" - FLATC_ARGS - "--keep-prefix" - "--scoped-enums" - "--reflect-names" - "--gen-object-api" - PUBLIC -) - -flatbuffer_cc_library( - NAME bytecode_module_def_cc_fbs SRCS "bytecode_module_def.fbs" @@ -55,19 +42,6 @@ flatbuffer_cc_library( NAME - interpreter_module_def_cc_fbs - SRCS - "interpreter_module_def.fbs" - FLATC_ARGS - "--keep-prefix" - "--scoped-enums" - "--reflect-names" - "--gen-object-api" - PUBLIC -) - -flatbuffer_cc_library( - NAME llvmir_executable_def_cc_fbs SRCS "llvmir_executable_def.fbs"
diff --git a/iree/schemas/buffer_data_def.fbs b/iree/schemas/buffer_data_def.fbs deleted file mode 100644 index 546e2bf..0000000 --- a/iree/schemas/buffer_data_def.fbs +++ /dev/null
@@ -1,27 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -namespace iree; - -enum BufferConstantEncoding : uint8 { - DENSE = 0, - SPLAT, -} - -table BufferDataDef { - encoding:BufferConstantEncoding; - element_width:int; - shape:[int]; - contents:[byte]; -}
diff --git a/iree/schemas/interpreter_module_def.fbs b/iree/schemas/interpreter_module_def.fbs deleted file mode 100644 index 189e424..0000000 --- a/iree/schemas/interpreter_module_def.fbs +++ /dev/null
@@ -1,104 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -namespace iree; - -// 'Executable MODule'. -file_identifier "EMOD"; -file_extension "emod"; - -table FloatTypeDef { - width:int; -} - -table IntegerTypeDef { - width:int; -} - -table UnknownTypeDef { - dialect:string; - type_data:string; -} - -union ElementTypeDefUnion { - FloatTypeDef, - IntegerTypeDef, - UnknownTypeDef, -} - -table ElementTypeDef { - type_union:ElementTypeDefUnion; -} - -table MemRefTypeDef { - element_type:ElementTypeDef; - shape:[int]; - memory_space:int; -} - -table DeviceTypeDef {} -table CommandBufferTypeDef {} -table EventTypeDef {} -table SemaphoreTypeDef {} -table FenceTypeDef {} - -union TypeDefUnion { - MemRefTypeDef, - DeviceTypeDef, - CommandBufferTypeDef, - EventTypeDef, - SemaphoreTypeDef, - FenceTypeDef, -} - -table TypeDef { - type_union:TypeDefUnion; -} - -table FunctionTypeDef { - inputs:[TypeDef]; - results:[TypeDef]; -} - -table BytecodeDef { - local_count:int; - contents:[byte]; -} - -table FunctionAttributeDef { - key:string; - value:string; -} - -table FunctionDef { - name:string; - type:FunctionTypeDef; - - attrs:[FunctionAttributeDef]; - - bytecode:BytecodeDef; -} - -table FunctionTableDef { - functions:[FunctionDef]; - imports:[int]; - exports:[int]; -} - -table ModuleDef { - name:string; - function_table:FunctionTableDef; -} - -root_type ModuleDef;
diff --git a/iree/tools/BUILD b/iree/tools/BUILD index 3ee8120..6f3cd56 100644 --- a/iree/tools/BUILD +++ b/iree/tools/BUILD
@@ -344,10 +344,6 @@ hdrs = ["vm_util.h"], deps = [ "//iree/base:api_util", - "//iree/base:buffer_string_util", - "//iree/base:shape", - "//iree/base:shaped_buffer", - "//iree/base:shaped_buffer_string_util", "//iree/base:signature_mangle", "//iree/base:status", "//iree/hal:api",
diff --git a/iree/tools/CMakeLists.txt b/iree/tools/CMakeLists.txt index 4f1e8b4..9d57aff 100644 --- a/iree/tools/CMakeLists.txt +++ b/iree/tools/CMakeLists.txt
@@ -362,10 +362,6 @@ absl::span absl::strings iree::base::api_util - iree::base::buffer_string_util - iree::base::shape - iree::base::shaped_buffer - iree::base::shaped_buffer_string_util iree::base::signature_mangle iree::base::status iree::hal::api
diff --git a/iree/tools/vm_util.cc b/iree/tools/vm_util.cc index 30877f9..f911a34 100644 --- a/iree/tools/vm_util.cc +++ b/iree/tools/vm_util.cc
@@ -21,10 +21,6 @@ #include "absl/strings/strip.h" #include "absl/types/span.h" #include "iree/base/api_util.h" -#include "iree/base/buffer_string_util.h" -#include "iree/base/shape.h" -#include "iree/base/shaped_buffer.h" -#include "iree/base/shaped_buffer_string_util.h" #include "iree/base/signature_mangle.h" #include "iree/base/status.h" #include "iree/hal/api.h" @@ -130,51 +126,14 @@ break; } case RawSignatureParser::Type::kBuffer: { - ASSIGN_OR_RETURN(auto shaped_buffer, - ParseShapedBufferFromString(input_string), - _ << "Parsing value '" << input_string << "'"); - // Allocate the buffer. - iree_hal_buffer_t* buf = nullptr; - // TODO(benvanik): combined function for linear to optimal upload. - iree_device_size_t allocation_size = - shaped_buffer.shape().element_count() * - shaped_buffer.element_size(); - RETURN_IF_ERROR(FromApiStatus( - iree_hal_allocator_allocate_buffer( - allocator, - static_cast<iree_hal_memory_type_t>( - IREE_HAL_MEMORY_TYPE_HOST_LOCAL | - IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE), - static_cast<iree_hal_buffer_usage_t>( - IREE_HAL_BUFFER_USAGE_ALL | IREE_HAL_BUFFER_USAGE_CONSTANT), - allocation_size, &buf), - IREE_LOC)) - << "Allocating buffer"; - RETURN_IF_ERROR(FromApiStatus( - iree_hal_buffer_write_data(buf, 0, shaped_buffer.contents().data(), - shaped_buffer.contents().size()), - IREE_LOC)) - << "Populating buffer contents "; - - absl::InlinedVector<iree_hal_dim_t, 5> dims( - shaped_buffer.shape().size()); - // TODO(laurenzo): The following should work but Shape iterators - // cause access violations. - // std::copy(shaped_buffer.shape().begin(), shaped_buffer.shape().end(), - // dims.begin()); - for (size_t i = 0; i < dims.size(); ++i) { - dims[i] = shaped_buffer.shape()[i]; - } - - // Wrap in buffer view. 
iree_hal_buffer_view_t* buffer_view = nullptr; - RETURN_IF_ERROR(FromApiStatus( - iree_hal_buffer_view_create(buf, dims.data(), dims.size(), - hal_element_type, IREE_ALLOCATOR_SYSTEM, - &buffer_view), - IREE_LOC)) - << "Creating buffer view"; - iree_hal_buffer_release(buf); + iree_status_t status = iree_hal_buffer_view_parse( + iree_string_view_t{input_string.data(), input_string.size()}, + allocator, IREE_ALLOCATOR_SYSTEM, &buffer_view); + if (!iree_status_is_ok(status)) { + return FromApiStatus(status, IREE_LOC) + << "Parsing value '" << input_string << "'"; + } auto buffer_view_ref = iree_hal_buffer_view_move_ref(buffer_view); RETURN_IF_ERROR(FromApiStatus(iree_vm_variant_list_append_ref_move( variant_list, &buffer_view_ref), @@ -231,63 +190,20 @@ return InvalidArgumentErrorBuilder(IREE_LOC) << "failed dereferencing variant " << i; } - auto* buffer = iree_hal_buffer_view_buffer(buffer_view); - if (!buffer) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "failed to get buffer for variant " << i; + + std::string result_str(4096, '\0'); + iree_status_t status; + do { + iree_host_size_t actual_length = 0; + status = iree_hal_buffer_view_format( + buffer_view, /*max_element_count=*/1024, result_str.size() + 1, + &result_str[0], &actual_length); + result_str.resize(actual_length); + } while (iree_status_is_out_of_range(status)); + if (!iree_status_is_ok(status)) { + return FromApiStatus(status, IREE_LOC); } - auto print_mode = BufferDataPrintMode::kFloatingPoint; - int8_t element_size = 4; - - // Copy the dims out of the buffer view. 
- absl::InlinedVector<iree_hal_dim_t, 5> dims32( - iree_hal_buffer_view_shape_rank(buffer_view)); - iree_hal_buffer_view_shape(buffer_view, dims32.size(), dims32.data(), - nullptr); - absl::InlinedVector<int, 5> dims(dims32.size()); - std::copy(dims32.begin(), dims32.end(), dims.begin()); - Shape shape{dims}; - - switch (desc.buffer.scalar_type) { - case AbiConstants::ScalarType::kIeeeFloat16: - case AbiConstants::ScalarType::kIeeeFloat32: - case AbiConstants::ScalarType::kIeeeFloat64: - print_mode = BufferDataPrintMode::kFloatingPoint; - break; - case AbiConstants::ScalarType::kSint8: - case AbiConstants::ScalarType::kSint16: - case AbiConstants::ScalarType::kSint32: - case AbiConstants::ScalarType::kSint64: - print_mode = BufferDataPrintMode::kSignedInteger; - break; - case AbiConstants::ScalarType::kUint8: - case AbiConstants::ScalarType::kUint16: - case AbiConstants::ScalarType::kUint32: - case AbiConstants::ScalarType::kUint64: - print_mode = BufferDataPrintMode::kUnsignedInteger; - break; - default: - print_mode = BufferDataPrintMode::kBinary; - break; - } - element_size = AbiConstants::kScalarTypeSize[static_cast<unsigned>( - desc.buffer.scalar_type)]; - - iree_hal_mapped_memory_t mapped_memory; - RETURN_IF_ERROR(FromApiStatus( - iree_hal_buffer_map(buffer, IREE_HAL_MEMORY_ACCESS_READ, 0, - IREE_WHOLE_BUFFER, &mapped_memory), - IREE_LOC)) - << "mapping hal buffer"; - auto contents = absl::MakeConstSpan(mapped_memory.contents.data, - mapped_memory.contents.data_length); - ShapedBuffer shaped_buffer( - element_size, shape, - std::vector<uint8_t>(contents.begin(), contents.end())); - ASSIGN_OR_RETURN(auto result_str, PrintShapedBufferToString( - shaped_buffer, print_mode, 1024)); - iree_hal_buffer_unmap(buffer, &mapped_memory); *os << result_str << "\n"; break; }
diff --git a/iree/tools/vm_util.h b/iree/tools/vm_util.h index c039a1d..1a793c3 100644 --- a/iree/tools/vm_util.h +++ b/iree/tools/vm_util.h
@@ -46,7 +46,7 @@ // Buffers should be in the IREE standard shaped buffer format: // [shape]xtype=[value] // described in -// https://github.com/google/iree/tree/main/iree/base/buffer_string_util.h +// https://github.com/google/iree/tree/main/iree/hal/api.h // Uses |allocator| to allocate the buffers. // Uses descriptors in |descs| for type information and validation. // The returned variant list must be freed by the caller. @@ -61,7 +61,7 @@ // Prints buffers in the IREE standard shaped buffer format: // [shape]xtype=[value] // described in -// https://github.com/google/iree/tree/main/iree/base/buffer_string_util.h +// https://github.com/google/iree/tree/main/iree/hal/api.h // Uses descriptors in |descs| for type information and validation. Status PrintVariantList(absl::Span<const RawSignatureParser::Description> descs, iree_vm_variant_list_t* variant_list,