Merge main -> google
* f92602b7 Always emitting framepointers in generated ELFs (#3987)
* 89ef6328 bump Tracy to get Android fixes (#3988)
* 7f989eb2 Disable MLIR crash reproducer on CI in python tests. (#3943)
* 4db3d08c Adding a demonstration of using the flow dialect post-partitioning. (#3701)
* 589dfa7b Remove no-longer-functional flag (#3961)
* 64543172 Fix MacOS builds after hack-and-slash PR (#3962)
* 8fb887e4 Update links to coverage tables (#3956)
* c93facb4 Adding iree_atomic_slist_t type. (#3917)
* bd082ca6 Merge pull request #3874 from google/benvanik-hack-and-slash
* f4f4ea26 Use UnitTestSpec in tf.keras.layers tests (#3935)
* 7b8e9f75 Reverting flatcc to use our own cmake file for cross-compilation.
* 323e1fde Simplify dylib driver switch.
* 55f3de0d Only register VMLA driver in bindings/java/.
* f429916f Fix warning flag on Windows and HAL CTS driver registration. (#3911)
* 7951e228 Drop IREE_DRIVER_MODULES and iree/base:initializer from ModelBuilder.
* 4e111e2f Disable layering_check in iree/hal/drivers/BUILD.
* 4773736d Add package to iree/base/testing/BUILD.
* 7ca321c1 Skipping dylib driver in simple_embedding_test as a hack.
* 692deb59 Overriding the default MLIR -> LLVM module name.
* 513f40e8 Speculative removing nowindows tags (#3615). If there's something that still d..
* cc47813a Removing the broken forward declarations entirely from some codegen code. http..
* fbcad44d Removing _GNU_SOURCE copt for wait_handle.
* c9b10a01 Fixing bad type in hal/api.h (been there for ages!).
* e0c532ec Changing iree::InitializeEnvironment to iree_flags_parse. Preparation for #3814.
* 8886ac07 Removing iree_api_init from the API.
* 132d747c Removing ALWAYSLINK support from cmake.
* 0ed81f6b Removing iree/base/initializer.h.
* 0135343c Changing to an explicit driver registration mechanism. This is required for ha..
* bf091d3c Removing ALWAYSLINK support from external_cc_library.
* d4bb871d Changing iree-tblgen to not require alwayslink.
* 6bc6f90c Removing IREE_COMMON_INCLUDE_DIRS and uses by LLVM/MLIR.
* 3c082aab Removing IREE_COMMON_INCLUDE_DIRS mhlo pollution.
* 2395cb99 Removing emitc usage of IREE_COMMON_INCLUDE_DIRS for now.
* 036bd966 TODOs on future library layout changes.
* c3a13e62 Rearranging iree/vm/ to reduce a public + cc target.
* 493b0e2b Rearranging iree/base build rules. By moving the dynamic_library_test out of t..
* 8fd38bf9 Replacing uses of some absl utilities now in C++14.
* 67863190 Removing unused absl/types/variant.h include.
* 99bd1af5 Replace absl::exchange with iree::exchange to reduce absl/utility dep.
* c1d0ee10 Removing unused PLATFORM_VULKAN_DEPS. It may be needed internally but it shoul..
* 15437f4b Simplifying iree/hal/dylib/ build config.
* e6984a5a Simplifying iree/hal/ build config.
* 10062814 Simplifying iree/hal/vulkan/ build config.
* 827e51b0 Simplifying iree/hal/llvmjit/ build config.
* 9a72f5d1 Simplifying iree/hal/metal/ build config.
* c7a7d726 Simplifying iree/hal/vmla/ build config.
* 90faf21f Adding IREE_TARGET_GUI_LINKOPTS to remove custom linkopts use.
* e5774c30 Remove unused args from flatbuffers_c_library macro.
* 22d16b4d Adding iree/base/flatcc.h to make flatcc easier to include.
* e44dee56 Switching from -DVK_NO_PROTOTYPES to iree/hal/vulkan/vulkan_headers.h.
* 48ca2fe6 Removing build-config setting of _GNU_SOURCE.
* eeb7dde0 Goodbye flatbuffers (well, the C++ ones anyway).
* 9c676a86 Removing all build config/utils related to flatbuffers.
* 49c61213 byte->ubyte in flatbuffer defs.
* c99000a8 Replacing compiler use of VM bytecode module def flatbuffers->flatcc.
* 1bf1e8d7 Replacing runtime use of metal executable flatbuffers->flatcc. Maybe it works?..
* 48aafb89 Replacing runtime use of spirv executable flatbuffers->flatcc.
* 011e9a2d Replacing runtime use of llvmjit executable flatbuffers->flatcc.
* a021062f Replacing runtime use of dylib executable flatbuffers->flatcc.
* 53a05d73 Replacing runtime use of VMLA executable flatbuffers->flatcc.
* 99d30a99 Replacing compiler use of HAL executable flatbuffers->flatcc.
* 6ebd1b0c Removing unused tag field in metal/spirv.
* bc685ed7 Adding flatcc json support and making iree-dump-module use it.
* 94b11c35 Adding include for flatcc to flat_c_libraries.
* 1172cf1f Removing unused iree::schemas::reflection_data.
* c86281af Removing unneeded flatbuffers copts.
* 7f3a7e3a Fixing various type warnings. We don't today have these warnings enabled in ou..
* c17659fc Refining MSVC warning set to the minimum required and documenting.
* b7c92bf4 Cleaning up MSVC warnings and syncing with bazel warnings.
* 94356d3b Removing legacy repo_utils.bzl.
* 0f0d9c82 Prevent bazel-to-cmake from running on iree/base/CMakeLists.txt for now.
* 36225a4d Centralizing -ldl/-lpthread linkopts (as they were in bazel already).
* 31c4dbb9 Documenting iree_copts with a nice big warning.
* e4740a57 Pass android libraries as actual linkopts.
* 85cdd868 Fixing cmake style issues - prefer `if(` not `if (` please.
* bf4069e3 Sorting copts/linkopts so we can override things.
* 28040cd8 Simplifying VMA build integration.
* 479ef30f Replacing use of PROJECT_SOURCE_DIR/PROJECT_BINARY_DIR. Those use the previous..
PiperOrigin-RevId: 344157809
diff --git a/.clang-format b/.clang-format
index 0547576..2899a77 100644
--- a/.clang-format
+++ b/.clang-format
@@ -16,3 +16,13 @@
# mostly LLVM style (for naming/etc) but the Google formatting (because internal
# tooling has... issues).
BasedOnStyle: Google
+
+# Some includes have specific inclusion order requirements. To prevent
+# clang-format from breaking things when bucketing headers into categories we
+# override the behavior here for those well-known headers.
+IncludeCategories:
+ # We override VK_NO_PROTOTYPES so we don't pull in auto-imported symbols and
+ # this must always sort above any header (ours or otherwise) that may
+ # transitively #include <vulkan/vulkan.h>.
+ - Regex: 'iree/hal/vulkan/vulkan_headers.h'
+ Priority: -1
diff --git a/.github/workflows/publish_docs.yml b/.github/workflows/publish_docs.yml
index 05b097d..c2d1819 100644
--- a/.github/workflows/publish_docs.yml
+++ b/.github/workflows/publish_docs.yml
@@ -37,7 +37,7 @@
- name: Initializing submodules
run: ./scripts/git/submodule_versions.py init
- name: Installing Ninja build
- uses: seanmiddleditch/gha-setup-ninja@v1
+ uses: seanmiddleditch/gha-setup-ninja@v3
- name: Building documentation
run: |
./build_tools/cmake/build_docs.sh
@@ -54,8 +54,8 @@
# top directory.
mv -f docs/index.md .
git add docs/ index.md
- echo "::set-env name=has_diff::false"
- git diff --cached --exit-code || echo "::set-env name=has_diff::true"
+ echo "has_diff=false" >> $GITHUB_ENV
+ git diff --cached --exit-code || echo "has_diff=true" >> $GITHUB_ENV
- name: Committing updates
if: env.has_diff == 'true'
run: |
diff --git a/.github/workflows/synchronize_submodules.yml b/.github/workflows/synchronize_submodules.yml
index aba1d6c..8e13878 100644
--- a/.github/workflows/synchronize_submodules.yml
+++ b/.github/workflows/synchronize_submodules.yml
@@ -33,8 +33,8 @@
run: ./scripts/git/submodule_versions.py init
- name: Checking submodule state
run: |
- echo "::set-env name=has_diff::false"
- git diff --cached --exit-code || echo "::set-env name=has_diff::true"
+ echo "has_diff=false" >> $GITHUB_ENV
+ git diff --cached --exit-code || echo "has_diff=true" >> $GITHUB_ENV
- name: Committing updates
if: env.has_diff == 'true'
run: |
diff --git a/.github/workflows/update_llvm_dependent_submodules.yml b/.github/workflows/update_llvm_dependent_submodules.yml
index aa9c3fa..bd0319a 100644
--- a/.github/workflows/update_llvm_dependent_submodules.yml
+++ b/.github/workflows/update_llvm_dependent_submodules.yml
@@ -38,9 +38,9 @@
run: ./scripts/git/update_to_llvm_syncpoint.py
- name: Calculating SHAs
run: |
- echo "::set-env name=LLVM_SHA::$(git submodule status third_party/llvm-project | awk '{print $1}' | cut -c -12)"
- echo "::set-env name=TF_SHA::$(git submodule status third_party/tensorflow | awk '{print $1}' | cut -c -12)"
- echo "::set-env name=LLVM_BAZEL_SHA::$(git submodule status third_party/llvm-bazel | awk '{print $1}' | cut -c -12)"
+ echo "LLVM_SHA=$(git submodule status third_party/llvm-project | awk '{print $1}' | cut -c -12)" >> $GITHUB_ENV
+ echo "TF_SHA=$(git submodule status third_party/tensorflow | awk '{print $1}' | cut -c -12)" >> $GITHUB_ENV
+ echo "LLVM_BAZEL_SHA=$(git submodule status third_party/llvm-bazel | awk '{print $1}' | cut -c -12)" >> $GITHUB_ENV
- name: Creating Pull Request
uses: peter-evans/create-pull-request@v2
with:
diff --git a/.gitmodules b/.gitmodules
index 45076ab..dcedcea 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -10,9 +10,6 @@
[submodule "third_party/tensorflow"]
path = third_party/tensorflow
url = https://github.com/tensorflow/tensorflow.git
-[submodule "third_party/flatbuffers"]
- path = third_party/flatbuffers
- url = https://github.com/google/flatbuffers.git
[submodule "third_party/vulkan_headers"]
path = third_party/vulkan_headers
url = https://github.com/KhronosGroup/Vulkan-Headers.git
diff --git a/BUILD.bazel b/BUILD.bazel
index f6d4abf..f3de1d2 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -22,10 +22,11 @@
# "@bazel_skylib//"
# "@com_google_benchmark//"
# "@com_github_pytorch_cpuinfo//"
-# "@com_github_google_flatbuffers//"
# "@com_github_dvidelabs_flatcc//"
+# "@half//"
# "@com_google_googletest//"
# "@llvm-project//"
+# "@pffft//"
# "@iree_pybind11//"
# "@renderdoc_api//"
# "@com_google_ruy//"
diff --git a/CMakeLists.txt b/CMakeLists.txt
index e37de7a..3903d75 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -64,7 +64,7 @@
set(IREE_ENABLE_MLIR ON CACHE BOOL "Enable LLVM dependencies if the IREE compiler is build." FORCE)
endif()
-if (${IREE_ENABLE_MLIR})
+if(${IREE_ENABLE_MLIR})
set(IREE_MLIR_DEP_MODE "BUNDLED" CACHE STRING "One of BUNDLED (default), DISABLED, INSTALLED")
endif()
@@ -92,9 +92,9 @@
Vulkan
)
-if( IREE_HAL_DRIVERS_TO_BUILD STREQUAL "all" )
- set( IREE_HAL_DRIVERS_TO_BUILD ${IREE_ALL_HAL_DRIVERS} )
- # For cross compilation towords Android, we don't want LLVM JIT HAL driver.
+if(IREE_HAL_DRIVERS_TO_BUILD STREQUAL "all")
+ set(IREE_HAL_DRIVERS_TO_BUILD ${IREE_ALL_HAL_DRIVERS})
+ # For cross compilation towards Android, we don't want LLVM JIT HAL driver.
if(ANDROID)
list(REMOVE_ITEM IREE_HAL_DRIVERS_TO_BUILD LLVM)
endif()
@@ -121,24 +121,6 @@
set(IREE_HAL_DRIVER_${uppercase_backend} ON CACHE BOOL "" FORCE)
endforeach()
-# Enable HAL driver modules based on options.
-set (IREE_HAL_DRIVER_MODULES "")
-if(${IREE_HAL_DRIVER_DYLIB})
- list(APPEND IREE_HAL_DRIVER_MODULES iree::hal::dylib::dylib_driver_module)
-endif()
-if(${IREE_HAL_DRIVER_LLVM})
- list(APPEND IREE_HAL_DRIVER_MODULES iree::hal::llvmjit::llvmjit_driver_module)
-endif()
-if(${IREE_HAL_DRIVER_METAL})
- list(APPEND IREE_HAL_DRIVER_MODULES iree::hal::metal::metal_driver_module)
-endif()
-if(${IREE_HAL_DRIVER_VMLA})
- list(APPEND IREE_HAL_DRIVER_MODULES iree::hal::vmla::vmla_driver_module)
-endif()
-if(${IREE_HAL_DRIVER_VULKAN})
- list(APPEND IREE_HAL_DRIVER_MODULES iree::hal::vulkan::vulkan_driver_module)
-endif()
-
# List of all target backends to be built by default:
set(IREE_ALL_TARGET_BACKENDS
# TODO(#2645): Add DYLIB-LLVM-AOT when it doesn't require an env var
@@ -231,7 +213,7 @@
option(IREE_ENABLE_TSAN "Enable thread sanitizer" OFF)
option(IREE_ENABLE_CCACHE "Use ccache if installed to speed up rebuilds." OFF)
-if (${IREE_ENABLE_CCACHE})
+if(${IREE_ENABLE_CCACHE})
find_program(CCACHE_PROGRAM ccache)
if(CCACHE_PROGRAM)
set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "${CCACHE_PROGRAM}")
@@ -245,7 +227,6 @@
include(iree_macros)
include(iree_copts)
include(sanitizers)
-include(iree_whole_archive_link)
include(iree_cc_binary)
include(iree_cc_library)
include(iree_cc_test)
@@ -259,7 +240,7 @@
include(iree_check_test)
set(DEFAULT_CMAKE_BUILD_TYPE "Release")
-if (NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES)
+if(NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES)
message(STATUS "No build type selected, default to ${DEFAULT_CMAKE_BUILD_TYPE}")
set(CMAKE_BUILD_TYPE "${DEFAULT_CMAKE_BUILD_TYPE}" CACHE STRING "Build type (default ${DEFAULT_CMAKE_BUILD_TYPE})" FORCE)
endif()
@@ -336,18 +317,29 @@
message(STATUS "Adding bundled LLVM source dependency")
add_iree_mlir_src_dep("third_party/llvm-project")
- # Extend module path to allow submodules to use LLVM and MLIR CMake modules
- list(APPEND CMAKE_MODULE_PATH "${PROJECT_BINARY_DIR}/lib/cmake/mlir")
- list(APPEND CMAKE_MODULE_PATH "${PROJECT_BINARY_DIR}/third_party/llvm-project/llvm/lib/cmake/llvm/")
+ # Extend module path to allow submodules to use LLVM and MLIR CMake modules.
+ list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_BINARY_DIR}/lib/cmake/mlir")
+ list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_BINARY_DIR}/third_party/llvm-project/llvm/lib/cmake/llvm/")
- # Set include directories
+ # Add the bundled include directories for cmake files looking for them.
list(APPEND LLVM_INCLUDE_DIRS
- ${PROJECT_SOURCE_DIR}/third_party/llvm-project/llvm/include
- ${PROJECT_BINARY_DIR}/third_party/llvm-project/llvm/include
+ ${CMAKE_CURRENT_SOURCE_DIR}/third_party/llvm-project/llvm/include
+ ${CMAKE_CURRENT_BINARY_DIR}/third_party/llvm-project/llvm/include
)
list(APPEND MLIR_INCLUDE_DIRS
- ${PROJECT_SOURCE_DIR}/third_party/llvm-project/mlir/include
- ${PROJECT_BINARY_DIR}/third_party/llvm-project/llvm/tools/mlir/include
+ ${CMAKE_CURRENT_SOURCE_DIR}/third_party/llvm-project/mlir/include
+ ${CMAKE_CURRENT_BINARY_DIR}/third_party/llvm-project/llvm/tools/mlir/include
+ )
+
+ # Avoid globally modifying paths by instead adding the include paths to the
+ # rules that really should have them in the first place.
+ target_include_directories(LLVMSupport PUBLIC
+ $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/third_party/llvm-project/llvm/include>
+ $<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}/third_party/llvm-project/llvm/include>
+ )
+ target_include_directories(MLIRSupport PUBLIC
+ $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/third_party/llvm-project/mlir/include>
+ $<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}/third_party/llvm-project/llvm/tools/mlir/include>
)
# Set build option to use MHLO alongside with bundled MLIR
@@ -396,23 +388,19 @@
find_package(PythonLibs 3)
endif()
-list(APPEND CMAKE_MODULE_PATH
- ${CMAKE_CURRENT_LIST_DIR}/third_party/flatbuffers/CMake/
-)
-
include(external_cc_library)
include(flatbuffer_c_library)
-include(flatbuffer_cc_library)
add_subdirectory(build_tools/third_party/flatcc EXCLUDE_FROM_ALL)
+add_subdirectory(build_tools/third_party/half EXCLUDE_FROM_ALL)
add_subdirectory(build_tools/third_party/pffft EXCLUDE_FROM_ALL)
add_subdirectory(build_tools/third_party/renderdoc_api EXCLUDE_FROM_ALL)
add_subdirectory(build_tools/third_party/ruy EXCLUDE_FROM_ALL)
+add_subdirectory(build_tools/third_party/vulkan_memory_allocator EXCLUDE_FROM_ALL)
add_subdirectory(third_party/cpuinfo EXCLUDE_FROM_ALL)
add_subdirectory(third_party/googletest EXCLUDE_FROM_ALL)
add_subdirectory(third_party/abseil-cpp EXCLUDE_FROM_ALL)
-add_subdirectory(third_party/flatbuffers EXCLUDE_FROM_ALL)
add_subdirectory(third_party/flatcc EXCLUDE_FROM_ALL)
add_subdirectory(third_party/vulkan_headers EXCLUDE_FROM_ALL)
@@ -427,18 +415,10 @@
iree_get_executable_path(FLATBUFFERS_FLATC_EXECUTABLE flatc)
# Add a custom target to copy the flatc to the binary directory.
- add_custom_target(iree_host_flatc
- COMMAND
- "${CMAKE_COMMAND}" -E copy_if_different
- "${IREE_HOST_BINARY_ROOT}/third_party/flatbuffers/flatc${IREE_HOST_EXECUTABLE_SUFFIX}"
- "${IREE_HOST_BINARY_ROOT}/bin"
- DEPENDS iree_host_build_flatc
- COMMENT "Installing host flatc..."
- )
add_custom_target(iree_host_flatcc_cli
COMMAND
"${CMAKE_COMMAND}" -E copy_if_different
- "${PROJECT_SOURCE_DIR}/third_party/flatcc/bin/flatcc${IREE_HOST_EXECUTABLE_SUFFIX}"
+ "${CMAKE_CURRENT_SOURCE_DIR}/third_party/flatcc/bin/flatcc${IREE_HOST_EXECUTABLE_SUFFIX}"
"${IREE_HOST_BINARY_ROOT}/bin/flatcc_cli${IREE_HOST_EXECUTABLE_SUFFIX}"
DEPENDS iree_host_build_flatcc_cli
COMMENT "Installing host flatcc..."
@@ -537,8 +517,6 @@
add_subdirectory(experimental)
endif()
-# Note: this must be called after all libraries have been declared.
-iree_complete_binary_link_options()
if(${IREE_BUILD_PYTHON_BINDINGS})
iree_complete_py_extension_link_options()
endif()
diff --git a/SUBMODULE_VERSIONS b/SUBMODULE_VERSIONS
index fc186cc..05f7426 100644
--- a/SUBMODULE_VERSIONS
+++ b/SUBMODULE_VERSIONS
@@ -2,11 +2,10 @@
daff5fead3fbe22c6fc58310ca3f49caf117f185 third_party/benchmark
63b254577ed77a8004a9be6ac707f3dccc4e1fd9 third_party/cpuinfo
4c13807b7d43ff0946b7ffea0ae3aee9e611d778 third_party/dear_imgui
-a5d9d0f7d368054fd1691aedf1db4116efcc233e third_party/flatbuffers
4fb0ff7069bd88ee85902f4d0bb62794e5f6d021 third_party/flatcc
f2fb48c3b3d79a75a88a99fba6576b25d42ec528 third_party/googletest
-2ad01e4b485d8753600766d967d1a7358b98ddd3 third_party/llvm-bazel
-c8d73d939fa4fda9c87b3979225d02d63062bd68 third_party/llvm-project
+fe17a7eff316d5846742cec8ced48bb6c49831db third_party/llvm-bazel
+5ce85e66358a69e786093756c77fae2e140947c1 third_party/llvm-project
55801f03f9cc69abfcf8b508a873f702c11b3b5f third_party/mlir-emitc
74d7261be17cf659d5930d4830609406bd7553e3 third_party/pffft
d8c7ee00a687ac369e62e2032514a93a9b413502 third_party/pybind11
@@ -14,7 +13,7 @@
a1390ed39ec77ecfb574bc6fcd5bfc5e3adbdea9 third_party/sdl2
685f86471e9d26b3eb7676695a2e2cefb4551ae9 third_party/spirv_cross
f8bf11a0253a32375c32cad92c841237b96696c0 third_party/spirv_headers
-3ea78d0fdd5402809e12d067b81dcd2a43cc8a45 third_party/tensorflow
-d7059eca6351546d1f51e248fc75e49dfeee709e third_party/tracy
+218c3a2712bc72d2239299a5eef4ff3c156004ed third_party/tensorflow
+d8cb536712e876ba956f27f23dbede1c2eccad28 third_party/tracy
9bd3f561bcee3f01d22912de10bb07ce4e23d378 third_party/vulkan_headers
3528e2aed3e8808f33e1e7d63eeb1560456a605a third_party/vulkan_memory_allocator
diff --git a/WORKSPACE b/WORKSPACE
index 5c0ead1..75505a5 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -18,7 +18,7 @@
workspace(name = "iree_core")
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
-load(":repo_utils.bzl", "maybe")
+load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe")
###############################################################################
# Bazel rules.
@@ -63,7 +63,7 @@
rbe_autoconfig(
name = "rbe_default",
base_container_digest = "sha256:1a8ed713f40267bb51fe17de012fa631a20c52df818ccb317aaed2ee068dfc61",
- digest = "sha256:8b7809d630286183a2119a83e03d52c93de552ada79534aa23c5b95283e66134",
+ digest = "sha256:d6d895294076b5289e81489f664656211c41656cffe7c448ecb5c6f54f045974",
registry = "gcr.io",
repository = "iree-oss/rbe-toolchain",
use_checked_in_confs = "Force",
@@ -196,14 +196,6 @@
path = "third_party/googletest",
)
-# Note that TensorFlow provides this as "flatbuffers" which is wrong.
-# It is only used for TFLite and may cause ODR issues if not fixed.
-maybe(
- local_repository,
- name = "com_github_google_flatbuffers",
- path = "third_party/flatbuffers",
-)
-
maybe(
new_local_repository,
name = "com_github_dvidelabs_flatcc",
@@ -293,6 +285,13 @@
path = "third_party/pffft",
)
+maybe(
+ new_local_repository,
+ name = "half",
+ build_file = "build_tools/third_party/half/BUILD.overlay",
+ path = "third_party/half",
+)
+
GOOGLE_RESEARCH_COMMIT = "a5213e2c92c3e87849fe417ba42786d0324e7c75"
http_archive(
diff --git a/bindings/java/com/google/iree/native/CMakeLists.txt b/bindings/java/com/google/iree/native/CMakeLists.txt
index e749258..11d1607 100644
--- a/bindings/java/com/google/iree/native/CMakeLists.txt
+++ b/bindings/java/com/google/iree/native/CMakeLists.txt
@@ -16,10 +16,10 @@
NAME
cc_wrappers
SRCS
- "context_wrapper.cc"
- "function_wrapper.cc"
- "instance_wrapper.cc"
- "module_wrapper.cc"
+ "context_wrapper.cc"
+ "function_wrapper.cc"
+ "instance_wrapper.cc"
+ "module_wrapper.cc"
HDRS
"context_wrapper.h"
"function_wrapper.h"
@@ -27,16 +27,15 @@
"module_wrapper.h"
DEPS
iree::base::api
- iree::base::init
+ iree::base::flags
iree::base::logging
iree::base::status
iree::hal::api
- iree::hal::vmla::vmla_driver_module
+ iree::hal::drivers
iree::modules::hal
iree::modules::strings::strings_module
iree::modules::tensorlist::native_module
+ iree::vm
+ iree::vm::cc
iree::vm::bytecode_module
- iree::vm::context
- iree::vm::instance
- iree::vm::ref_cc
)
diff --git a/bindings/java/com/google/iree/native/context_wrapper.h b/bindings/java/com/google/iree/native/context_wrapper.h
index 180a259..40efe28 100644
--- a/bindings/java/com/google/iree/native/context_wrapper.h
+++ b/bindings/java/com/google/iree/native/context_wrapper.h
@@ -23,7 +23,7 @@
#include "iree/base/status.h"
#include "iree/hal/api.h"
#include "iree/modules/hal/hal_module.h"
-#include "iree/vm/context.h"
+#include "iree/vm/api.h"
namespace iree {
namespace java {
diff --git a/bindings/java/com/google/iree/native/function_wrapper.h b/bindings/java/com/google/iree/native/function_wrapper.h
index a8dd7dd..dbce860 100644
--- a/bindings/java/com/google/iree/native/function_wrapper.h
+++ b/bindings/java/com/google/iree/native/function_wrapper.h
@@ -17,7 +17,7 @@
#include <memory>
-#include "iree/vm/module.h"
+#include "iree/vm/api.h"
namespace iree {
namespace java {
diff --git a/bindings/java/com/google/iree/native/instance_wrapper.cc b/bindings/java/com/google/iree/native/instance_wrapper.cc
index cdd9fb0..e9679e6 100644
--- a/bindings/java/com/google/iree/native/instance_wrapper.cc
+++ b/bindings/java/com/google/iree/native/instance_wrapper.cc
@@ -16,7 +16,8 @@
#include <mutex>
-#include "iree/base/init.h"
+#include "iree/base/flags.h"
+#include "iree/hal/vmla/registration/driver_module.h"
#include "iree/modules/hal/hal_module.h"
#include "iree/modules/strings/strings_module.h"
#include "iree/modules/tensorlist/native_module.h"
@@ -27,13 +28,16 @@
namespace {
void SetupVm() {
- // TODO(jennik): Pass flags through from java.
+ // TODO(jennik): Pass flags through from java and use iree_flags_parse.
+ // This checked version will abort()/exit() and that's... not great.
char binname[] = "libiree.so";
char* argv[] = {binname};
char** aargv = argv;
int argc = 1;
- InitializeEnvironment(&argc, &aargv);
+ iree_flags_parse_checked(&argc, &aargv);
+ // TODO(jennik): register all available drivers
+ IREE_CHECK_OK(iree_hal_vmla_driver_module_register());
IREE_CHECK_OK(iree_vm_register_builtin_types());
IREE_CHECK_OK(iree_hal_module_register_types());
IREE_CHECK_OK(iree_tensorlist_module_register_types());
diff --git a/bindings/java/com/google/iree/native/instance_wrapper.h b/bindings/java/com/google/iree/native/instance_wrapper.h
index e97bfc5..66acc0f 100644
--- a/bindings/java/com/google/iree/native/instance_wrapper.h
+++ b/bindings/java/com/google/iree/native/instance_wrapper.h
@@ -16,7 +16,7 @@
#define IREE_BINDINGS_JAVA_COM_GOOGLE_IREE_NATIVE_INSTANCE_WRAPPER_H_
#include "iree/base/status.h"
-#include "iree/vm/instance.h"
+#include "iree/vm/api.h"
namespace iree {
namespace java {
diff --git a/bindings/javatests/com/google/iree/integration_test.cc b/bindings/javatests/com/google/iree/integration_test.cc
index 5c31c34..661878d 100644
--- a/bindings/javatests/com/google/iree/integration_test.cc
+++ b/bindings/javatests/com/google/iree/integration_test.cc
@@ -19,7 +19,7 @@
#include "bindings/java/com/google/iree/native/instance_wrapper.h"
#include "bindings/java/com/google/iree/native/module_wrapper.h"
#include "bindings/javatests/com/google/iree/simple_mul_bytecode_module.h"
-#include "iree/base/init.h"
+#include "iree/base/flags.h"
namespace iree {
namespace java {
diff --git a/bindings/python/build_defs.oss.bzl b/bindings/python/build_defs.oss.bzl
index a01a244..34d87b9 100644
--- a/bindings/python/build_defs.oss.bzl
+++ b/bindings/python/build_defs.oss.bzl
@@ -17,10 +17,8 @@
load("@iree_native_python//:build_defs.bzl", "py_extension")
load("@rules_cc//cc:defs.bzl", "cc_library")
load("@rules_python//python:defs.bzl", "py_binary", "py_library", "py_test")
-load("//iree:build_defs.oss.bzl", _PLATFORM_VULKAN_DEPS = "PLATFORM_VULKAN_DEPS")
NUMPY_DEPS = []
-PLATFORM_VULKAN_DEPS = _PLATFORM_VULKAN_DEPS
PYTHON_HEADERS_DEPS = ["@iree_native_python//:python_headers"]
PYTHON_CPP_EXTRA_DEPS = []
diff --git a/bindings/python/pyiree/rt/BUILD b/bindings/python/pyiree/rt/BUILD
index 4c9d89c..6e64af5 100644
--- a/bindings/python/pyiree/rt/BUILD
+++ b/bindings/python/pyiree/rt/BUILD
@@ -15,7 +15,6 @@
load(
"//bindings/python:build_defs.oss.bzl",
"NUMPY_DEPS",
- "PLATFORM_VULKAN_DEPS",
"PYBIND_COPTS",
"PYBIND_EXTENSION_COPTS",
"PYBIND_FEATURES",
@@ -31,12 +30,6 @@
licenses = ["notice"], # Apache 2.0
)
-DRIVER_DEPS = PLATFORM_VULKAN_DEPS + [
- "//iree/hal/vulkan:vulkan_driver_module",
- "//iree/hal/llvmjit:llvmjit_driver_module",
- "//iree/hal/vmla:vmla_driver_module",
-]
-
iree_py_library(
name = "rt",
srcs = [
@@ -59,10 +52,10 @@
features = PYBIND_FEATURES,
linkstatic = 1,
win_def_file = "export.def",
- deps = DRIVER_DEPS + [
+ deps = [
":rt_library",
"//bindings/python/pyiree/common",
- "//iree/base:initializer",
+ "//iree/hal/drivers",
],
)
@@ -90,10 +83,6 @@
"//iree/modules/tensorlist:native_module",
"//iree/vm",
"//iree/vm:bytecode_module",
- "//iree/vm:invocation",
- "//iree/vm:list",
- "//iree/vm:module",
- "//iree/vm:ref",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
diff --git a/bindings/python/pyiree/rt/CMakeLists.txt b/bindings/python/pyiree/rt/CMakeLists.txt
index 225dccd..7828073 100644
--- a/bindings/python/pyiree/rt/CMakeLists.txt
+++ b/bindings/python/pyiree/rt/CMakeLists.txt
@@ -35,8 +35,6 @@
iree::modules::tensorlist::native_module
iree::vm
iree::vm::bytecode_module
- iree::vm::invocation
- iree::vm::ref
absl::inlined_vector
absl::memory
absl::strings
@@ -56,10 +54,7 @@
::PyExtRtLib
bindings::python::pyiree::common::PyextCommonLib
DEPS
- iree::hal::vulkan::vulkan_driver_module
- iree::hal::llvmjit::llvmjit_driver_module
- iree::hal::vmla::vmla_driver_module
- iree::base::initializer
+ iree::hal::drivers
)
iree_py_library(
diff --git a/bindings/python/pyiree/rt/function_abi.cc b/bindings/python/pyiree/rt/function_abi.cc
index e92adf1..3cc046c 100644
--- a/bindings/python/pyiree/rt/function_abi.cc
+++ b/bindings/python/pyiree/rt/function_abi.cc
@@ -24,8 +24,7 @@
#include "iree/base/signature_mangle.h"
#include "iree/hal/api.h"
#include "iree/modules/hal/hal_module.h"
-#include "iree/vm/list.h"
-#include "iree/vm/ref.h"
+#include "iree/vm/api.h"
namespace iree {
namespace python {
diff --git a/bindings/python/pyiree/rt/initialize_module.cc b/bindings/python/pyiree/rt/initialize_module.cc
index e4e3824..840278a 100644
--- a/bindings/python/pyiree/rt/initialize_module.cc
+++ b/bindings/python/pyiree/rt/initialize_module.cc
@@ -18,13 +18,13 @@
#include "bindings/python/pyiree/rt/hal.h"
#include "bindings/python/pyiree/rt/host_types.h"
#include "bindings/python/pyiree/rt/vm.h"
-#include "iree/base/initializer.h"
+#include "iree/hal/drivers/init.h"
namespace iree {
namespace python {
PYBIND11_MODULE(binding, m) {
- IREE_RUN_MODULE_INITIALIZERS();
+ IREE_CHECK_OK(iree_hal_register_all_available_drivers());
m.doc() = "IREE Binding Backend Helpers";
SetupFunctionAbiBindings(m);
diff --git a/bindings/python/pyiree/rt/vm.cc b/bindings/python/pyiree/rt/vm.cc
index f7e0031..ce1a271 100644
--- a/bindings/python/pyiree/rt/vm.cc
+++ b/bindings/python/pyiree/rt/vm.cc
@@ -24,8 +24,7 @@
#include "iree/modules/hal/hal_module.h"
#include "iree/modules/strings/strings_module.h"
#include "iree/modules/tensorlist/native_module.h"
-#include "iree/vm/invocation.h"
-#include "iree/vm/module.h"
+#include "iree/vm/api.h"
namespace iree {
namespace python {
diff --git a/bindings/python/pyiree/rt/vm.h b/bindings/python/pyiree/rt/vm.h
index e65560c..34302bc 100644
--- a/bindings/python/pyiree/rt/vm.h
+++ b/bindings/python/pyiree/rt/vm.h
@@ -21,7 +21,6 @@
#include "iree/base/api.h"
#include "iree/vm/api.h"
#include "iree/vm/bytecode_module.h"
-#include "iree/vm/list.h"
namespace iree {
namespace python {
diff --git a/build_tools/BUILD b/build_tools/BUILD
new file mode 100644
index 0000000..cdd4ae3
--- /dev/null
+++ b/build_tools/BUILD
@@ -0,0 +1,33 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+# Helper for getting linkopts into libraries in a way we can easily spot and
+# rewrite from bazel-to-cmake and copybara.
+cc_library(
+ name = "default_linkopts",
+ linkopts = select({
+ "//iree:iree_is_msvc": [],
+ "//conditions:default": [
+ # Just include libraries that should be presumed in 2020.
+ "-ldl",
+ "-lpthread",
+ ],
+ }),
+)
diff --git a/build_tools/bazel/iree.bazelrc b/build_tools/bazel/iree.bazelrc
index 7fc9945..58d2b91 100644
--- a/build_tools/bazel/iree.bazelrc
+++ b/build_tools/bazel/iree.bazelrc
@@ -145,7 +145,7 @@
build:generic_clang --copt=-Wthread-safety
build:generic_clang --copt=-Wthread-safety-beta
build:generic_clang --copt=-Wunused-comparison
-build:generic_clang --copt=-Wunused-variable
+build:generic_clang --copt=-Wno-unused-variable
build:generic_clang --copt=-Wvla
# LINT.ThenChange(https://github.com/google/iree/tree/main/build_tools/cmake/iree_copts.cmake:clang_diagnostics)
@@ -274,15 +274,29 @@
# absl forces /W3 in their copts, so we exclude them to avoid D9025
build:_msvc_base --per_file_copt=+external,-com_google_absl@/w
+# Find the source of truth for these in iree_copts.cmake.
+build:_msvc_base --copt=/DWIN32_LEAN_AND_MEAN
+build:_msvc_base --copt=/DNOMINMAX
+build:_msvc_base --copt=/D_USE_MATH_DEFINES
+build:_msvc_base --copt=/D_CRT_SECURE_NO_WARNINGS
+build:_msvc_base --copt=/D_CRT_SECURE_CPP_OVERLOAD_STANDARD_NAMES
+build:_msvc_base --copt=/EHsc
+build:_msvc_base --copt=/bigobj
+
+build:_msvc_base --copt=/wd4200
+build:_msvc_base --copt=/wd4018
+build:_msvc_base --copt=/wd4146 # operator applied to unsigned type, result still unsigned
+build:_msvc_base --copt=/wd4244 # possible loss of data
+build:_msvc_base --copt=/wd4267 # initializing: possible loss of data
+build:_msvc_base --copt=/wd4005 # allow: macro redefinition
+build:_msvc_base --copt=/wd4065 # allow: switch statement contains 'default' but no 'case' labels
+build:_msvc_base --copt=/wd4141 # allow: inline used more than once
+build:_msvc_base --copt=/wd4624 # allow: destructor was implicitly defined as deleted
+build:_msvc_base --copt=/wd5105 # macro expansion producing 'defined' has undefined behavior
+
# And some more explicit disables. For some reason the `/w` on external doesn't
# work for these, maybe they come from headers?
-build:_msvc_base --copt=/wd4244 # possible loss of data
-build:_msvc_base --copt=/wd4624 # destructor was implicitly defined as deleted
-build:_msvc_base --copt=/wd4005 # macro redefinition
-build:_msvc_base --copt=/wd4267 # initializing: possible loss of data
-build:_msvc_base --copt=/wd4141 # inline used more than once
# new warning with the standards-compliant preprocessor. winbase itself is not standards-compliant
-build:_msvc_base --copt=/wd5105
build:_msvc_base --per_file_copt=mkl_dnn@/wd4551 # missing argument list
build:_msvc_base --per_file_copt=mkl_dnn@/wd4068 # unknown pragma
build:_msvc_base --per_file_copt=farmhash@/wd4319 # zero extending to T of greater size
@@ -295,16 +309,8 @@
build:_msvc_base --copt=/arch:AVX
# Host and target are the same in windows so don't waste time building both.
build:_msvc_base --distinct_host_configuration=false
-# Avoids incompatible versions of winsock and other badness.
-build:_msvc_base --copt=/DWIN32_LEAN_AND_MEAN
-# Why are min/max macros? No one knows.
-build:_msvc_base --copt=/DNOMINMAX
-# Yay for security warnings. Boo for non-standard.
-build:_msvc_base --copt=/D_CRT_SECURE_NO_WARNINGS
# TensorFlow requires the "monolithic" build mode for now on Windows.
build:_msvc_base --define framework_shared_object=false
-# Necessary for M_* math constants.
-build:_msvc_base --copt=/D_USE_MATH_DEFINES
# Workaround WinGDI.h defining `ERROR`, which conflicts with logging macros.
# Note that IREE and TensorFlow both `#undef ERROR` and define their own
diff --git a/build_tools/bazel/iree_flatcc.bzl b/build_tools/bazel/iree_flatcc.bzl
index e091af6..6d35299 100644
--- a/build_tools/bazel/iree_flatcc.bzl
+++ b/build_tools/bazel/iree_flatcc.bzl
@@ -37,6 +37,9 @@
outs += ["%s_builder.h" % (out_stem)]
if arg == "--verifier":
outs += ["%s_verifier.h" % (out_stem)]
+ if arg == "--json":
+ outs += ["%s_json_parser.h" % (out_stem)]
+ outs += ["%s_json_printer.h" % (out_stem)]
native.genrule(
name = name + "_gen",
diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py b/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py
index 6f477d1..f6dca0b 100644
--- a/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py
+++ b/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py
@@ -37,16 +37,6 @@
def __init__(self, converter):
self.converter = converter
- # TODO(gcmn): Do this in a less hard-coded way
- self.PLATFORM_VULKAN_DEPS = []
- self.PLATFORM_VULKAN_TEST_DEPS = ["//iree/testing:gtest_main"]
- self.FLATBUFFER_SUPPORTS_REFLECTIONS = False
- self.PLATFORM_VULKAN_LOADER_COPTS = []
- self.IREE_DRIVER_MODULES = [
- "//iree/hal/vmla:vmla_driver_module",
- "//iree/hal/vulkan:vulkan_driver_module",
- "//iree/hal/llvmjit:llvmjit_driver_module",
- ]
# ------------------------------------------------------------------------- #
# Conversion utilities, written to reduce boilerplate and allow for reuse #
@@ -109,9 +99,6 @@
else:
return ""
- def _convert_alwayslink_block(self, alwayslink):
- return self._convert_option_block("ALWAYSLINK", alwayslink)
-
def _convert_testonly_block(self, testonly):
return self._convert_option_block("TESTONLY", testonly)
@@ -348,7 +335,7 @@
srcs=None,
data=None,
deps=None,
- alwayslink=False,
+ defines=None,
testonly=False,
linkopts=None,
**kwargs):
@@ -360,7 +347,7 @@
srcs_block = self._convert_srcs_block(srcs)
data_block = self._convert_data_block(data)
deps_block = self._convert_deps_block(deps)
- alwayslink_block = self._convert_alwayslink_block(alwayslink)
+ defines_block = self._convert_string_list_block("DEFINES", defines)
testonly_block = self._convert_testonly_block(testonly)
self.converter.body += (f"iree_cc_library(\n"
@@ -370,7 +357,7 @@
f"{srcs_block}"
f"{data_block}"
f"{deps_block}"
- f"{alwayslink_block}"
+ f"{defines_block}"
f"{testonly_block}"
f" PUBLIC\n)\n\n")
@@ -499,17 +486,6 @@
f"{flatcc_args_block}"
f" PUBLIC\n)\n\n")
- def iree_flatbuffer_cc_library(self, name, srcs, flatc_args=None):
- name_block = self._convert_name_block(name)
- srcs_block = self._convert_srcs_block(srcs)
- flatc_args_block = self._convert_flatc_args_block(flatc_args)
-
- self.converter.body += (f"flatbuffer_cc_library(\n"
- f"{name_block}"
- f"{srcs_block}"
- f"{flatc_args_block}"
- f" PUBLIC\n)\n\n")
-
def gentbl(self,
name,
tblgen,
diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
index 5a23bc1..f20ae3d 100644
--- a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
+++ b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
@@ -16,6 +16,9 @@
# Bazel to CMake target name conversions used by bazel_to_cmake.py.
EXPLICIT_TARGET_MAPPING = {
+ # Internal utilities to emulate various binary/library options.
+ "//build_tools:default_linkopts": [],
+
# absl
"@com_google_absl//absl/flags:flag": ["absl::flags"],
"@com_google_absl//absl/flags:parse": ["absl::flags_parse"],
@@ -57,22 +60,20 @@
"@llvm-project//mlir:MlirOptLib": ["MLIROptLib"],
"@llvm-project//mlir:VectorOps": ["MLIRVector"],
# Vulkan
- # TODO(scotttodd): Set -DVK_NO_PROTOTYPES to COPTS for _no_prototypes.
- # Maybe add a wrapper CMake lib within build_tools/third_party/?
"@iree_vulkan_headers//:vulkan_headers": ["Vulkan::Headers"],
- "@iree_vulkan_headers//:vulkan_headers_no_prototypes": ["Vulkan::Headers"],
# The Bazel target maps to the IMPORTED target defined by FindVulkan().
"@vulkan_sdk//:sdk": ["Vulkan::Vulkan"],
# Misc single targets
"@com_google_benchmark//:benchmark": ["benchmark"],
- "@com_github_google_flatbuffers//:flatbuffers": ["flatbuffers"],
"@com_github_dvidelabs_flatcc//:flatcc": ["flatcc"],
"@com_github_dvidelabs_flatcc//:runtime": ["flatcc::runtime"],
"@com_google_googletest//:gtest": ["gmock", "gtest"],
"@renderdoc_api//:renderdoc_app": ["renderdoc_api::renderdoc_app"],
- "@pffft//:pffft": ["pffft"],
+ "@pffft": ["pffft"],
"@sdl2//:SDL2": ["SDL2-static"],
"@com_github_pytorch_cpuinfo//:cpuinfo": ["cpuinfo"],
+ "@half//:half": ["half"],
+ "@vulkan_memory_allocator//:impl_header_only": ["vulkan_memory_allocator"],
}
diff --git a/build_tools/cmake/external_cc_library.cmake b/build_tools/cmake/external_cc_library.cmake
index dd9799b..d9cd38e 100644
--- a/build_tools/cmake/external_cc_library.cmake
+++ b/build_tools/cmake/external_cc_library.cmake
@@ -32,7 +32,6 @@
# DEFINES: List of public defines
# INCLUDES: Include directories to add to dependencies
# LINKOPTS: List of link options
-# ALWAYSLINK: Always link the library into any binary with a direct dep.
# PUBLIC: Add this so that this library will be exported under ${PACKAGE}::
# Also in IDE, target will appear in ${PACKAGE} folder while non PUBLIC will be
# in ${PACKAGE}/internal.
@@ -79,7 +78,7 @@
# )
function(external_cc_library)
cmake_parse_arguments(_RULE
- "PUBLIC;ALWAYSLINK;TESTONLY"
+ "PUBLIC;TESTONLY"
"PACKAGE;NAME;ROOT"
"HDRS;SRCS;COPTS;DEFINES;LINKOPTS;DATA;DEPS;INCLUDES"
${ARGN}
@@ -122,7 +121,8 @@
)
target_include_directories(${_NAME} SYSTEM
PUBLIC
- "$<BUILD_INTERFACE:${IREE_COMMON_INCLUDE_DIRS}>"
+ "$<BUILD_INTERFACE:${IREE_SOURCE_DIR}>"
+ "$<BUILD_INTERFACE:${IREE_BINARY_DIR}>"
"$<BUILD_INTERFACE:${_RULE_INCLUDES}>"
)
target_compile_options(${_NAME}
@@ -134,8 +134,8 @@
PUBLIC
${_RULE_DEPS}
PRIVATE
- ${_RULE_LINKOPTS}
${IREE_DEFAULT_LINKOPTS}
+ ${_RULE_LINKOPTS}
)
target_compile_definitions(${_NAME}
PUBLIC
@@ -143,10 +143,6 @@
)
iree_add_data_dependencies(NAME ${_NAME} DATA ${_RULE_DATA})
- if(DEFINED _RULE_ALWAYSLINK)
- set_property(TARGET ${_NAME} PROPERTY ALWAYSLINK 1)
- endif()
-
# Add all external targets to a a folder in the IDE for organization.
if(_RULE_PUBLIC)
set_property(TARGET ${_NAME} PROPERTY FOLDER third_party)
@@ -164,19 +160,20 @@
add_library(${_NAME} INTERFACE)
target_include_directories(${_NAME} SYSTEM
INTERFACE
- "$<BUILD_INTERFACE:${IREE_COMMON_INCLUDE_DIRS}>"
+ "$<BUILD_INTERFACE:${IREE_SOURCE_DIR}>"
+ "$<BUILD_INTERFACE:${IREE_BINARY_DIR}>"
"$<BUILD_INTERFACE:${_RULE_INCLUDES}>"
)
target_compile_options(${_NAME}
INTERFACE
- ${_RULE_COPTS}
${IREE_DEFAULT_COPTS}
+ ${_RULE_COPTS}
)
target_link_libraries(${_NAME}
INTERFACE
- ${_RULE_DEPS}
- ${_RULE_LINKOPTS}
${IREE_DEFAULT_LINKOPTS}
+ ${_RULE_LINKOPTS}
+ ${_RULE_DEPS}
)
iree_add_data_dependencies(NAME ${_NAME} DATA ${_RULE_DATA})
target_compile_definitions(${_NAME}
diff --git a/build_tools/cmake/flatbuffer_c_library.cmake b/build_tools/cmake/flatbuffer_c_library.cmake
index 24ee470..23ade0c 100644
--- a/build_tools/cmake/flatbuffer_c_library.cmake
+++ b/build_tools/cmake/flatbuffer_c_library.cmake
@@ -21,10 +21,6 @@
# Parameters:
# NAME: name of target (see Note)
# SRCS: List of source files for the library
-# DEPS: List of other libraries to be linked in to the binary targets
-# COPTS: List of private compile options
-# DEFINES: List of public defines
-# LINKOPTS: List of link options
# FLATCC_ARGS: List of flattbuffers arguments. Default:
# "--common"
# "--reader"
@@ -39,32 +35,28 @@
#
# flatbuffer_c_library(
# NAME
-# base_schema
+# some_def
# SRCS
-# "a.fbs"
-# )
-# flatbuffer_c_library(
-# NAME
-# other_schemas
-# SRCS
-# "b.fbs"
-# DEPS
-# iree::schemas::base_schema
+# "some_def.fbs"
+# FLATCC_ARGS
+# "--reader"
+# "--builder"
+# "--verifier"
+# "--json"
# PUBLIC
# )
-#
# iree_cc_binary(
# NAME
# main_lib
# ...
# DEPS
-# iree::schemas::other_schemas
+# iree::schemas::some_def
# )
function(flatbuffer_c_library)
cmake_parse_arguments(_RULE
"PUBLIC;TESTONLY"
"NAME"
- "SRCS;COPTS;DEFINES;LINKOPTS;DEPS;FLATCC_ARGS"
+ "SRCS;FLATCC_ARGS"
${ARGN}
)
@@ -95,6 +87,9 @@
list(APPEND _OUTS "${_SRC_FILENAME}_builder.h")
elseif(_ARG STREQUAL "--verifier")
list(APPEND _OUTS "${_SRC_FILENAME}_verifier.h")
+ elseif(_ARG STREQUAL "--json")
+ list(APPEND _OUTS "${_SRC_FILENAME}_json_printer.h")
+ list(APPEND _OUTS "${_SRC_FILENAME}_json_parser.h")
endif()
endforeach()
endforeach()
@@ -125,28 +120,24 @@
${_GEN_TARGET}
DEPENDS
${_OUTS}
- ${_RULE_DEPS}
)
add_library(${_NAME} INTERFACE)
add_dependencies(${_NAME} ${_GEN_TARGET})
target_include_directories(${_NAME} SYSTEM
INTERFACE
- "$<BUILD_INTERFACE:${IREE_COMMON_INCLUDE_DIRS}>"
+ "$<BUILD_INTERFACE:${IREE_SOURCE_DIR}>"
+ "$<BUILD_INTERFACE:${IREE_BINARY_DIR}>"
${CMAKE_CURRENT_BINARY_DIR}
)
target_link_libraries(${_NAME}
INTERFACE
flatcc::runtime
- ${_RULE_LINKOPTS}
${IREE_DEFAULT_LINKOPTS}
)
- target_compile_definitions(${_NAME}
- INTERFACE
- ${_RULE_DEFINES}
- )
target_compile_options(${_NAME}
INTERFACE
+ "-I${IREE_ROOT_DIR}/third_party/flatcc/include/"
"-I${IREE_ROOT_DIR}/third_party/flatcc/include/flatcc/reflection/"
)
diff --git a/build_tools/cmake/flatbuffer_cc_library.cmake b/build_tools/cmake/flatbuffer_cc_library.cmake
deleted file mode 100644
index 52618de..0000000
--- a/build_tools/cmake/flatbuffer_cc_library.cmake
+++ /dev/null
@@ -1,137 +0,0 @@
-# Copyright 2019 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-include(BuildFlatBuffers)
-include(CMakeParseArguments)
-
-# flatbuffer_cc_library()
-#
-# CMake function to imitate Bazel's flatbuffer_cc_library rule.
-#
-# Parameters:
-# NAME: name of target (see Note)
-# SRCS: List of source files for the library
-# DEPS: List of other libraries to be linked in to the binary targets
-# COPTS: List of private compile options
-# DEFINES: List of public defines
-# LINKOPTS: List of link options
-# FLATC_ARGS: List of flattbuffers arguments. Default:
-# "--keep-prefix"
-# "--scoped-enums"
-# "--reflect-names"
-# "--gen-object-api"
-# PUBLIC: Add this so that this library will be exported under iree::
-# Also in IDE, target will appear in IREE folder while non PUBLIC will be in IREE/internal.
-# TESTONLY: When added, this target will only be built if user passes -DIREE_BUILD_TESTS=ON to CMake.
-#
-# Note:
-# By default, flatbuffer_cc_library will always create a library named ${NAME},
-# and alias target iree::${NAME}. The iree:: form should always be used.
-# This is to reduce namespace pollution.
-#
-# flatbuffer_cc_library(
-# NAME
-# base_schema
-# SRCS
-# "a.fbs"
-# )
-# flatbuffer_cc_library(
-# NAME
-# other_schemas
-# SRCS
-# "b.fbs"
-# DEPS
-# iree::schemas::base_schema
-# PUBLIC
-# )
-#
-# iree_cc_binary(
-# NAME
-# main_lib
-# ...
-# DEPS
-# iree::schemas::other_schemas
-# )
-function(flatbuffer_cc_library)
- cmake_parse_arguments(_RULE
- "PUBLIC;TESTONLY"
- "NAME"
- "SRCS;COPTS;DEFINES;LINKOPTS;DEPS;FLATC_ARGS"
- ${ARGN}
- )
-
- if(_RULE_TESTONLY AND NOT IREE_BUILD_TESTS)
- return()
- endif()
-
- # Prefix the library with the package name, so we get: iree_package_name
- iree_package_name(_PACKAGE_NAME)
- set(_NAME "${_PACKAGE_NAME}_${_RULE_NAME}")
-
- if(NOT DEFINED _RULE_FLATC_ARGS)
- set(FLATBUFFERS_FLATC_SCHEMA_EXTRA_ARGS
- # Preserve root-relative include paths in generated code.
- "--keep-prefix"
- # Use C++11 'enum class' for enums.
- "--scoped-enums"
- # Include reflection tables used for dumping debug representations.
- "--reflect-names"
- # Generate FooT types for unpack/pack support. Note that this should only
- # be used in tooling as the code size/runtime overhead is non-trivial.
- "--gen-object-api"
- )
- else()
- set(FLATBUFFERS_FLATC_SCHEMA_EXTRA_ARGS ${_RULE_FLATC_ARGS})
- endif()
-
- set(_GEN_TARGET "${_NAME}_gen")
-
- build_flatbuffers(
- "${_RULE_SRCS}"
- "${IREE_ROOT_DIR}"
- "${_GEN_TARGET}" # custom_target_name
- "${_RULE_DEPS}" # additional_dependencies
- "${CMAKE_CURRENT_BINARY_DIR}" # generated_include_dir
- "${CMAKE_CURRENT_BINARY_DIR}" # binary_schemas_dir
- "" # copy_text_schemas_dir
- )
-
- # Add dependency on flatc explicitly. This is needed for cross-compiling
- # where flatc comes from another CMake invocation for host.
- iree_add_executable_dependencies(${_GEN_TARGET} flatc)
-
- add_library(${_NAME} INTERFACE)
- add_dependencies(${_NAME} ${_GEN_TARGET})
- target_include_directories(${_NAME} SYSTEM
- INTERFACE
- "$<BUILD_INTERFACE:${IREE_COMMON_INCLUDE_DIRS}>"
- ${CMAKE_CURRENT_BINARY_DIR}
- )
- target_link_libraries(${_NAME}
- INTERFACE
- flatbuffers
- ${_RULE_LINKOPTS}
- ${IREE_DEFAULT_LINKOPTS}
- )
- target_compile_definitions(${_NAME}
- INTERFACE
- ${_RULE_DEFINES}
- )
-
- # Alias the iree_package_name library to iree::package::name.
- # This lets us more clearly map to Bazel and makes it possible to
- # disambiguate the underscores in paths vs. the separators.
- iree_package_ns(_PACKAGE_NS)
- add_library(${_PACKAGE_NS}::${_RULE_NAME} ALIAS ${_NAME})
-endfunction()
diff --git a/build_tools/cmake/iree_cc_binary.cmake b/build_tools/cmake/iree_cc_binary.cmake
index 5e1bd0a..168baa8 100644
--- a/build_tools/cmake/iree_cc_binary.cmake
+++ b/build_tools/cmake/iree_cc_binary.cmake
@@ -14,10 +14,6 @@
include(CMakeParseArguments)
-if (NOT DEFINED _IREE_CC_BINARY_NAMES)
- set(_IREE_CC_BINARY_NAMES "")
-endif()
-
# iree_cc_binary()
#
# CMake function to imitate Bazel's cc_binary rule.
@@ -112,7 +108,8 @@
endif()
target_include_directories(${_NAME} SYSTEM
PUBLIC
- ${IREE_COMMON_INCLUDE_DIRS}
+ "$<BUILD_INTERFACE:${IREE_SOURCE_DIR}>"
+ "$<BUILD_INTERFACE:${IREE_BINARY_DIR}>"
)
target_compile_definitions(${_NAME}
PUBLIC
@@ -120,45 +117,33 @@
)
target_compile_options(${_NAME}
PRIVATE
+ ${IREE_DEFAULT_COPTS}
${_RULE_COPTS}
)
target_link_options(${_NAME}
PRIVATE
- ${_RULE_LINKOPTS}
${IREE_DEFAULT_LINKOPTS}
+ ${_RULE_LINKOPTS}
+ )
+
+ # Replace dependencies passed by ::name with iree::package::name
+ iree_package_ns(_PACKAGE_NS)
+ list(TRANSFORM _RULE_DEPS REPLACE "^::" "${_PACKAGE_NS}::")
+
+ target_link_libraries(${_NAME}
+ PUBLIC
+ ${_RULE_DEPS}
)
iree_add_data_dependencies(NAME ${_NAME} DATA ${_RULE_DATA})
- iree_package_ns(_PACKAGE_NS)
- # Replace dependencies passed by ::name with ::iree::package::name
- list(TRANSFORM _RULE_DATA REPLACE "^::" "${_PACKAGE_NS}::")
- list(TRANSFORM _RULE_DEPS REPLACE "^::" "${_PACKAGE_NS}::")
-
# Add all IREE targets to a folder in the IDE for organization.
set_property(TARGET ${_NAME} PROPERTY FOLDER ${IREE_IDE_FOLDER}/binaries)
set_property(TARGET ${_NAME} PROPERTY CXX_STANDARD ${IREE_CXX_STANDARD})
set_property(TARGET ${_NAME} PROPERTY CXX_STANDARD_REQUIRED ON)
- # Defer computing transitive dependencies and calling target_link_libraries()
- # until all libraries have been declared.
- # Track target and deps, use in iree_complete_binary_link_options() later.
- set_property(GLOBAL APPEND PROPERTY _IREE_CC_BINARY_NAMES "${_NAME}")
- set_property(TARGET ${_NAME} PROPERTY DIRECT_DEPS ${_RULE_DEPS})
-
install(TARGETS ${_NAME}
RENAME ${_RULE_NAME}
COMPONENT ${_RULE_NAME}
RUNTIME DESTINATION bin)
endfunction()
-
-# Sets target_link_libraries() on all registered binaries.
-# This must be called after all libraries have been declared.
-function(iree_complete_binary_link_options)
- get_property(_NAMES GLOBAL PROPERTY _IREE_CC_BINARY_NAMES)
-
- foreach(_NAME ${_NAMES})
- get_target_property(_DIRECT_DEPS ${_NAME} DIRECT_DEPS)
- iree_whole_archive_link(${_NAME} ${_DIRECT_DEPS})
- endforeach(_NAME)
-endfunction()
diff --git a/build_tools/cmake/iree_cc_library.cmake b/build_tools/cmake/iree_cc_library.cmake
index e4d5b67..fc7cb5d 100644
--- a/build_tools/cmake/iree_cc_library.cmake
+++ b/build_tools/cmake/iree_cc_library.cmake
@@ -29,12 +29,10 @@
# DEFINES: List of public defines
# INCLUDES: Include directories to add to dependencies
# LINKOPTS: List of link options
-# ALWAYSLINK: Always link the library into any binary with a direct dep.
# PUBLIC: Add this so that this library will be exported under iree::
# Also in IDE, target will appear in IREE folder while non PUBLIC will be in IREE/internal.
# TESTONLY: When added, this target will only be built if user passes -DIREE_BUILD_TESTS=ON to CMake.
# SHARED: If set, will compile to a shared object.
-# WHOLEARCHIVE: If set, links all symbols from "ALWAYSLINK" libraries.
#
# Note:
# By default, iree_cc_library will always create a library named iree_${NAME},
@@ -69,7 +67,7 @@
function(iree_cc_library)
cmake_parse_arguments(
_RULE
- "PUBLIC;ALWAYSLINK;TESTONLY;SHARED;WHOLEARCHIVE"
+ "PUBLIC;TESTONLY;SHARED;WHOLEARCHIVE"
"NAME"
"HDRS;TEXTUAL_HDRS;SRCS;COPTS;DEFINES;LINKOPTS;DATA;DEPS;INCLUDES"
${ARGN}
@@ -79,10 +77,9 @@
return()
endif()
+ # Replace dependencies passed by ::name with iree::package::name
iree_package_ns(_PACKAGE_NS)
- # Replace dependencies passed by ::name with ::iree::package::name
list(TRANSFORM _RULE_DEPS REPLACE "^::" "${_PACKAGE_NS}::")
- list(TRANSFORM _RULE_DATA REPLACE "^::" "${_PACKAGE_NS}::")
# Prefix the library with the package name, so we get: iree_package_name.
iree_package_name(_PACKAGE_NAME)
@@ -105,13 +102,10 @@
endif()
if(NOT _RULE_IS_INTERFACE)
- if (_RULE_SHARED)
+ if(_RULE_SHARED)
add_library(${_NAME} SHARED "")
else()
add_library(${_NAME} STATIC "")
- if (_RULE_WHOLEARCHIVE)
- message(FATAL_ERROR "WHOLEARCHIVE must be set together with SHARED")
- endif()
endif()
target_sources(${_NAME}
@@ -122,7 +116,8 @@
)
target_include_directories(${_NAME} SYSTEM
PUBLIC
- "$<BUILD_INTERFACE:${IREE_COMMON_INCLUDE_DIRS}>"
+ "$<BUILD_INTERFACE:${IREE_SOURCE_DIR}>"
+ "$<BUILD_INTERFACE:${IREE_BINARY_DIR}>"
)
target_include_directories(${_NAME}
PUBLIC
@@ -130,19 +125,16 @@
)
target_compile_options(${_NAME}
PRIVATE
- ${_RULE_COPTS}
${IREE_DEFAULT_COPTS}
+ ${_RULE_COPTS}
)
- if(_RULE_WHOLEARCHIVE)
- iree_whole_archive_link(${_NAME} ${_RULE_DEPS})
- else()
- target_link_libraries(${_NAME} PUBLIC ${_RULE_DEPS})
- endif()
target_link_libraries(${_NAME}
+ PUBLIC
+ ${_RULE_DEPS}
PRIVATE
- ${_RULE_LINKOPTS}
${IREE_DEFAULT_LINKOPTS}
+ ${_RULE_LINKOPTS}
)
iree_add_data_dependencies(NAME ${_NAME} DATA ${_RULE_DATA})
@@ -151,10 +143,6 @@
${_RULE_DEFINES}
)
- if(DEFINED _RULE_ALWAYSLINK)
- set_property(TARGET ${_NAME} PROPERTY ALWAYSLINK 1)
- endif()
-
# Add all IREE targets to a folder in the IDE for organization.
if(_RULE_PUBLIC)
set_property(TARGET ${_NAME} PROPERTY FOLDER ${IREE_IDE_FOLDER})
@@ -172,18 +160,19 @@
add_library(${_NAME} INTERFACE)
target_include_directories(${_NAME} SYSTEM
INTERFACE
- "$<BUILD_INTERFACE:${IREE_COMMON_INCLUDE_DIRS}>"
+ "$<BUILD_INTERFACE:${IREE_SOURCE_DIR}>"
+ "$<BUILD_INTERFACE:${IREE_BINARY_DIR}>"
)
target_compile_options(${_NAME}
INTERFACE
- ${_RULE_COPTS}
${IREE_DEFAULT_COPTS}
+ ${_RULE_COPTS}
)
target_link_libraries(${_NAME}
INTERFACE
+ ${IREE_DEFAULT_LINKOPTS}
${_RULE_DEPS}
${_RULE_LINKOPTS}
- ${IREE_DEFAULT_LINKOPTS}
)
iree_add_data_dependencies(NAME ${_NAME} DATA ${_RULE_DATA})
target_compile_definitions(${_NAME}
diff --git a/build_tools/cmake/iree_cc_test.cmake b/build_tools/cmake/iree_cc_test.cmake
index 4ae27fb..0eac677 100644
--- a/build_tools/cmake/iree_cc_test.cmake
+++ b/build_tools/cmake/iree_cc_test.cmake
@@ -66,11 +66,6 @@
${ARGN}
)
- iree_package_ns(_PACKAGE_NS)
- # Replace dependencies passed by ::name with ::iree::package::name
- list(TRANSFORM _RULE_DATA REPLACE "^::" "${_PACKAGE_NS}::")
- list(TRANSFORM _RULE_DEPS REPLACE "^::" "${_PACKAGE_NS}::")
-
# Prefix the library with the package name, so we get: iree_package_name
iree_package_name(_PACKAGE_NAME)
set(_NAME "${_PACKAGE_NAME}_${_RULE_NAME}")
@@ -82,7 +77,8 @@
)
target_include_directories(${_NAME} SYSTEM
PUBLIC
- ${IREE_COMMON_INCLUDE_DIRS}
+ "$<BUILD_INTERFACE:${IREE_SOURCE_DIR}>"
+ "$<BUILD_INTERFACE:${IREE_BINARY_DIR}>"
)
target_compile_definitions(${_NAME}
PUBLIC
@@ -90,26 +86,32 @@
)
target_compile_options(${_NAME}
PRIVATE
+ ${IREE_DEFAULT_COPTS}
${_RULE_COPTS}
)
target_link_options(${_NAME}
PRIVATE
- ${_RULE_LINKOPTS}
${IREE_DEFAULT_LINKOPTS}
+ ${_RULE_LINKOPTS}
+ )
+
+ # Replace dependencies passed by ::name with iree::package::name
+ iree_package_ns(_PACKAGE_NS)
+ list(TRANSFORM _RULE_DEPS REPLACE "^::" "${_PACKAGE_NS}::")
+
+ target_link_libraries(${_NAME}
+ PUBLIC
+ ${_RULE_DEPS}
)
iree_add_data_dependencies(NAME ${_NAME} DATA ${_RULE_DATA})
+
# Add all IREE targets to a folder in the IDE for organization.
set_property(TARGET ${_NAME} PROPERTY FOLDER ${IREE_IDE_FOLDER}/test)
set_property(TARGET ${_NAME} PROPERTY CXX_STANDARD ${IREE_CXX_STANDARD})
set_property(TARGET ${_NAME} PROPERTY CXX_STANDARD_REQUIRED ON)
- # Defer computing transitive dependencies and calling target_link_libraries()
- # until all libraries have been declared.
- # Track target and deps, use in iree_complete_binary_link_options() later.
list(APPEND _RULE_DEPS "gmock")
- set_property(GLOBAL APPEND PROPERTY _IREE_CC_BINARY_NAMES "${_NAME}")
- set_property(TARGET ${_NAME} PROPERTY DIRECT_DEPS ${_RULE_DEPS})
string(REPLACE "::" "/" _PACKAGE_PATH ${_PACKAGE_NS})
set(_TEST_NAME "${_PACKAGE_PATH}/${_RULE_NAME}")
diff --git a/build_tools/cmake/iree_copts.cmake b/build_tools/cmake/iree_copts.cmake
index 968380c..903b576 100644
--- a/build_tools/cmake/iree_copts.cmake
+++ b/build_tools/cmake/iree_copts.cmake
@@ -21,22 +21,74 @@
# By default Abseil strips string literals on mobile platforms, which means
# we cannot run IREE binaries via command-line with proper options. Turn off
# the stripping.
-# TODO: we might still want to strip when compiling IREE into Android Java apps.
+# TODO(#3814): remove ABSL flags.
if(ANDROID)
add_definitions(-DABSL_FLAGS_STRIP_NAMES=0)
endif()
#-------------------------------------------------------------------------------
-# C++ used within IREE
+# C/C++ options as used within IREE
#-------------------------------------------------------------------------------
+#
+# ██ ██ █████ ██████ ███ ██ ██ ███ ██ ██████
+# ██ ██ ██ ██ ██ ██ ████ ██ ██ ████ ██ ██
+# ██ █ ██ ███████ ██████ ██ ██ ██ ██ ██ ██ ██ ██ ███
+# ██ ███ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
+# ███ ███ ██ ██ ██ ██ ██ ████ ██ ██ ████ ██████
+#
+# Everything here is added to *every* iree_cc_library/iree_cc_binary/etc.
+# That includes both runtime and compiler components, and these may propagate
+# out to user code interacting with either (such as custom modules).
+#
+# Be extremely judicious in the use of these flags.
+#
+# - Need to disable a warning?
+# Usually these are encountered in compiler-specific code and can be disabled
+# in a compiler-specific way. Only add global warning disables when it's clear
+# that we never want them or that they'll show up in a lot of places.
+#
+# See: https://stackoverflow.com/questions/3378560/how-to-disable-gcc-warnings-for-a-few-lines-of-code
+#
+# - Need to add a linker dependency?
+# First figure out if you *really* need it. If it's only required on specific
+# platforms and in very specific files, prefer autolinking when clang or msvc
+# are used. GCC is stubborn and doesn't have autolinking, so additional flags
+# may be required there.
+#
+# See: https://en.wikipedia.org/wiki/Auto-linking
+#
+# - Need to tweak a compilation mode setting (debug/asserts/etc)?
+# Don't do that here, and in general *don't do that at all* unless it's behind
+# a very specific IREE-prefixed cmake flag (like IREE_SIZE_OPTIMIZED).
+# There's no one-size solution when we are dealing with cross-project and
+# cross-compiled binaries - there's no safe way to set global options that
+# won't cause someone to break, and you probably don't really need to
+# change that setting anyway. Follow the rule of least surprise: if the user
+# has CMake's Debug configuration active then don't force things into release
+# mode, etc.
+#
+# - Need to add an include directory?
+# Don't do that here. Always prefer to fully-specify the path from the IREE
+# workspace root when it's known that the compilation will be occurring using
+# the files within the IREE checkout; for example, instead of adding a global
+# include path to third_party/foo/ and #include <foo.h>'ing, just
+# #include "third_party/foo/foo.h". This reduces build configuration, makes it
+# easier for readers to find the files, etc.
+#
+# - Still think you need to add an include directory? (system includes, etc)
+# Don't do that here, either. It's highly doubtful that every single target in
+# all of IREE (both compiler and runtime) on all platforms (both host and
+# cross-compilation targets) needs your special include directory. Add it on
+# the COPTS of the target you are using it in and, ideally, private to that
+# target (used in .c/cc files, not in a .h that leaks the include path
+# requirements to all consumers of the API).
set(IREE_CXX_STANDARD ${CMAKE_CXX_STANDARD})
-set(IREE_ROOT_DIR ${PROJECT_SOURCE_DIR})
-list(APPEND IREE_COMMON_INCLUDE_DIRS
- ${PROJECT_SOURCE_DIR}
- ${PROJECT_BINARY_DIR}
-)
+# TODO(benvanik): fix these names (or remove entirely).
+set(IREE_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR})
+set(IREE_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR})
+set(IREE_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
iree_select_compiler_opts(IREE_DEFAULT_COPTS
CLANG
@@ -97,7 +149,6 @@
"-Wthread-safety"
"-Wthread-safety-beta"
"-Wunused-comparison"
- "-Wunused-variable"
"-Wvla"
# LINT.ThenChange(https://github.com/google/iree/tree/main/build_tools/bazel/iree.bazelrc:clang_diagnostics)
@@ -111,27 +162,145 @@
"-Wno-gnu-label-as-value"
CLANG_OR_GCC
"-Wno-unused-parameter"
+ "-Wno-unused-variable"
"-Wno-undef"
"-fvisibility=hidden"
MSVC_OR_CLANG_CL
+ # Exclude a bunch of rarely-used APIs, such as crypto/DDE/shell.
+ # https://docs.microsoft.com/en-us/windows/win32/winprog/using-the-windows-headers
+ # NOTE: this is not really required anymore for build performance but does
+ # work around some issues that crop up with header version compatibility
+ # (abseil has issues with winsock versions).
"/DWIN32_LEAN_AND_MEAN"
+
+ # Don't allow windows.h to define MIN and MAX and conflict with the STL.
+ # There's no legit use for these macros as any code we are writing ourselves
+ # that we want a MIN/MAX in should be using an IREE-prefixed version
+ # instead: iree_min iree_max
+ # https://stackoverflow.com/a/4914108
+ "/DNOMINMAX"
+
+ # Adds M_PI and other constants to <math.h>/<cmath> (to match non-windows).
+ # https://docs.microsoft.com/en-us/cpp/c-runtime-library/math-constants
"/D_USE_MATH_DEFINES"
- "/wd4624"
- # 'inline': used more than once
- "/wd4141"
- # 'WIN32_LEAN_AND_MEAN': macro redefinition
- "/wd4005"
- "/wd4267"
- "/wd4141"
- "/wd4244"
- "/wd4146"
- "/wd4018"
- "/wd4065"
- # TODO(benvanik): figure out if really required or accidentally enabled.
+
+ # Disable the "deprecation" warnings about CRT functions like strcpy.
+ # Though the secure versions *are* better, they aren't portable and as such
+ # just make cross-platform code annoying. One solution is to reimplement
+ # them in a portable fashion and use those - and that's what we try to do
+ # in certain places where we can get away with it. Other uses, like getenv,
+ # are fine as these are not intended for use in core runtime code that needs
+ # to be secure (friends don't let friends ship entire compiler stacks
+ # embedded inside security sensitive applications anyway :).
+ # https://docs.microsoft.com/en-us/cpp/c-runtime-library/security-features-in-the-crt
+ "/D_CRT_SECURE_NO_WARNINGS"
+
+ # With the above said about the "deprecated" functions; this useful flag
+ # will at least try to use them when possible without any change to user
+ # code. Note however because the new versions use templates they won't be
+ # activated in C code; that's fine.
+ # https://docs.microsoft.com/en-us/cpp/c-runtime-library/secure-template-overloads
+ "/D_CRT_SECURE_CPP_OVERLOAD_STANDARD_NAMES"
+
+ # Configure exception handling for standard C++ behavior.
+ # - /EHs enables C++ catch-style exceptions
+ # - /EHc breaks unwinding across extern C boundaries, dramatically reducing
+ # unwind table size and associated exception handling overhead as the
+ # compiler can assume no exception will ever be thrown within any function
+ # annotated with extern "C".
+ # https://docs.microsoft.com/en-us/cpp/build/reference/eh-exception-handling-model
+ #
+ # TODO(benvanik): figure out if we need /EHs - we don't use exceptions in
+ # the runtime and I'm pretty sure LLVM doesn't use them either.
"/EHsc"
+
+ # Default max section count is 64k, which is woefully inadequate for some of
+ # the insanely bloated tablegen outputs LLVM/MLIR produces. This cranks it
+ # up to 2^32. It's not great that we have to generate/link files like that
+ # but it's better to not get spurious failures during LTCG.
+ # https://docs.microsoft.com/en-us/cpp/build/reference/bigobj-increase-number-of-sections-in-dot-obj-file
"/bigobj"
+
+ # "nonstandard extension used : zero-sized array in struct/union"
+ # This happens with unsized or zero-length arrays at the end of structs,
+ # which is completely valid in C where we do it and get this warning. Shut
+ # it up and rely on the better warnings from clang to catch if we try to
+ # use it where it really matters (on a class that has copy/move ctors, etc).
+ # https://docs.microsoft.com/en-us/cpp/error-messages/compiler-warnings/compiler-warning-levels-2-and-4-c4200
+ "/wd4200"
+
+ # "signed/unsigned mismatch in comparison"
+ # This is along the lines of a generic implicit conversion warning but tends
+ # to crop up in code that implicitly treats unsigned size_t values as if
+ # they were signed values instead of properly using ssize_t. In certain
+ # cases where the comparison being performed may be guarding access to
+ # memory this can cause unexpected behavior ("-1ull < 512ull, great let's
+ # dereference buffer[-1ull]!").
+ # https://docs.microsoft.com/en-us/cpp/error-messages/compiler-warnings/compiler-warning-level-3-c4018
+ #
+ # TODO(#3844): remove this (or make it per-file to iree/compiler, as LLVM
+ # tends to not care about these kinds of things and it crops up there a lot).
+ "/wd4018"
+
+ # Also common in LLVM is mismatching signed/unsigned math. That's even more
+ # dangerous than C4018: almost always these crop up in doing something with
+ # a size_t and a non-size_t value (usually int or something like it) and do
+ # you want out-of-bounds access exploits? Because that's how you get
+ # out-of-bounds access exploits. Before fuzzers took over, finding code and
+ # trying to compile it with this warning forced to be an error was a way to
+ # narrow down the places to look for attack vectors. I lived through the
+ # Microsoft SAL/safe-int code red, and once you get used to using the safe
+ # buffer offset/size manipulation functions it eliminates all kinds of
+ # annoying bugs - as well as potential security issues.
+ #
+ # TODO(#3844): work to remove this class of errors from our code. It's
+ # almost entirely in LLVM related stuff so per-file iree/compiler/... would
+ # be fine.
+ "/wd4146" # operator applied to unsigned type, result still unsigned
+ "/wd4244" # possible loss of data
+ "/wd4267" # initializing: possible loss of data
+
+ # Misc tweaks to better match reasonable clang/gcc behavior:
+ "/wd4005" # allow: macro redefinition
+ "/wd4065" # allow: switch statement contains 'default' but no 'case' labels
+ "/wd4141" # allow: inline used more than once
+ "/wd4624" # allow: destructor was implicitly defined as deleted
+
+ # TODO(benvanik): confirm these are all still required and document:
+ "/wd4146" # operator applied to unsigned type, result still unsigned
+ "/wd4244" # possible loss of data
+ "/wd4267" # initializing: possible loss of data
+ "/wd5105" # allow: macro expansion producing 'defined' has undefined behavior
)
-set(IREE_DEFAULT_LINKOPTS "${ABSL_DEFAULT_LINKOPTS}")
+
+if(NOT ANDROID)
+ iree_select_compiler_opts(_IREE_PTHREADS_LINKOPTS
+ CLANG_OR_GCC
+ "-lpthread"
+ )
+else()
+ # Android provides its own pthreads support with no linking required.
+endif()
+
+iree_select_compiler_opts(IREE_DEFAULT_LINKOPTS
+ ALL
+ # TODO(benvanik): remove the ABSL usage here; we aren't abseil.
+ "${ABSL_DEFAULT_LINKOPTS}"
+ CLANG_OR_GCC
+ # Required by all modern software, effectively:
+ "-ldl"
+ ${_IREE_PTHREADS_LINKOPTS}
+)
+
+# Add to LINKOPTS on a binary to configure it for X/Wayland/Windows/etc
+# depending on the target cross-compilation platform.
+if(${CMAKE_SYSTEM_NAME} STREQUAL "Windows")
+ set(IREE_TARGET_GUI_LINKOPTS "-SUBSYSTEM:WINDOWS")
+else()
+ set(IREE_TARGET_GUI_LINKOPTS "")
+endif()
+
+# TODO(benvanik): remove the ABSL usage here; we aren't abseil.
set(IREE_TEST_COPTS "${ABSL_TEST_COPTS}")
#-------------------------------------------------------------------------------
@@ -207,7 +376,7 @@
set(FLATBUFFERS_BUILD_GRPCTEST OFF CACHE BOOL "" FORCE)
set(FLATBUFFERS_INSTALL OFF CACHE BOOL "" FORCE)
set(FLATBUFFERS_INCLUDE_DIRS
- "${PROJECT_SOURCE_DIR}/third_party/flatbuffers/include/"
+ "${CMAKE_CURRENT_SOURCE_DIR}/third_party/flatbuffers/include/"
)
if(CMAKE_CROSSCOMPILING)
@@ -216,14 +385,6 @@
set(FLATBUFFERS_BUILD_FLATC ON CACHE BOOL "" FORCE)
endif()
-iree_select_compiler_opts(FLATBUFFERS_COPTS
- CLANG
- # Flatbuffers has a bunch of incorrect documentation annotations.
- "-Wno-documentation"
- "-Wno-documentation-unknown-command"
-)
-list(APPEND IREE_DEFAULT_COPTS ${FLATBUFFERS_COPTS})
-
#-------------------------------------------------------------------------------
# Third party: flatcc
#-------------------------------------------------------------------------------
@@ -267,35 +428,12 @@
set(LLVM_USE_LINKER ${IREE_USE_LINKER} CACHE STRING "" FORCE)
endif()
-# TODO: This should go in add_iree_mlir_src_dep at the top level.
-if(IREE_MLIR_DEP_MODE STREQUAL "BUNDLED")
- list(APPEND IREE_COMMON_INCLUDE_DIRS
- ${PROJECT_SOURCE_DIR}/third_party/llvm-project/llvm/include
- ${PROJECT_BINARY_DIR}/third_party/llvm-project/llvm/include
- ${PROJECT_SOURCE_DIR}/third_party/llvm-project/mlir/include
- ${PROJECT_BINARY_DIR}/third_party/llvm-project/llvm/tools/mlir/include
- )
-endif()
-
set(MLIR_TABLEGEN_EXE mlir-tblgen)
# iree-tblgen is not defined using the add_tablegen mechanism as other TableGen
# tools in LLVM.
iree_get_executable_path(IREE_TABLEGEN_EXE iree-tblgen)
#-------------------------------------------------------------------------------
-# Third party: tensorflow
-#-------------------------------------------------------------------------------
-
-list(APPEND IREE_COMMON_INCLUDE_DIRS
- ${PROJECT_SOURCE_DIR}/third_party/tensorflow
- ${PROJECT_SOURCE_DIR}/third_party/tensorflow/tensorflow/compiler/mlir/hlo/include/
- ${PROJECT_BINARY_DIR}/third_party/tensorflow
- ${PROJECT_BINARY_DIR}/third_party/tensorflow/tensorflow/compiler/mlir/hlo/include/
- ${PROJECT_BINARY_DIR}/third_party/tensorflow/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/
- ${PROJECT_BINARY_DIR}/third_party/tensorflow/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms
-)
-
-#-------------------------------------------------------------------------------
# Third party: mlir-emitc
#-------------------------------------------------------------------------------
@@ -304,10 +442,6 @@
set(EMITC_ENABLE_HLO OFF)
set(EMITC_INCLUDE_TESTS OFF)
- list(APPEND IREE_COMMON_INCLUDE_DIRS
- ${PROJECT_SOURCE_DIR}/third_party/mlir-emitc/include
- ${PROJECT_BINARY_DIR}/third_party/mlir-emitc/include
- )
add_definitions(-DIREE_HAVE_EMITC_DIALECT)
endif()
diff --git a/build_tools/cmake/iree_cross_compile.cmake b/build_tools/cmake/iree_cross_compile.cmake
index 3cfa1e9..f749f90 100644
--- a/build_tools/cmake/iree_cross_compile.cmake
+++ b/build_tools/cmake/iree_cross_compile.cmake
@@ -101,7 +101,8 @@
message(STATUS "C++ compiler for ${CONFIG_NAME} build: ${_CONFIG_CXX_COMPILER}")
add_custom_command(OUTPUT ${IREE_${CONFIG_NAME}_BINARY_ROOT}/CMakeCache.txt
- COMMAND "${CMAKE_COMMAND}" "${PROJECT_SOURCE_DIR}" -G "${CMAKE_GENERATOR}"
+ COMMAND "${CMAKE_COMMAND}" "${CMAKE_CURRENT_SOURCE_DIR}"
+ -G "${CMAKE_GENERATOR}"
-DCMAKE_MAKE_PROGRAM="${CMAKE_MAKE_PROGRAM}"
-DCMAKE_BUILD_TYPE="${CMAKE_BUILD_TYPE}"
-DCMAKE_C_COMPILER="${_CONFIG_C_COMPILER}"
@@ -150,7 +151,7 @@
if(NOT _RULE_CONFIG)
set(_RULE_CONFIG "$<CONFIG>")
endif()
- if (CMAKE_GENERATOR MATCHES "Make")
+ if(CMAKE_GENERATOR MATCHES "Make")
# Use special command for Makefiles to support parallelism.
set(${_RULE_CMDVAR}
"$(MAKE)" "-C" "${_RULE_BINDIR}" "${EXECUTABLE_TARGET}" PARENT_SCOPE)
diff --git a/build_tools/cmake/iree_multipy.cmake b/build_tools/cmake/iree_multipy.cmake
index a974e95..f773925 100644
--- a/build_tools/cmake/iree_multipy.cmake
+++ b/build_tools/cmake/iree_multipy.cmake
@@ -344,7 +344,7 @@
# Performs a depth-first search through the dependency graph, appending all
# dependencies of TARGET to the TRANSITIVE_DEPS list.
function(_iree_transitive_dependencies_helper TARGET TRANSITIVE_DEPS)
- if (NOT TARGET "${TARGET}")
+ if(NOT TARGET "${TARGET}")
# Excluded from the project, or invalid name? Just ignore.
return()
endif()
@@ -358,7 +358,7 @@
endif()
set(_RESULT "${${TRANSITIVE_DEPS}}")
- if (${_TARGET_NAME} IN_LIST _RESULT)
+ if(${_TARGET_NAME} IN_LIST _RESULT)
# Already visited, ignore.
return()
endif()
@@ -368,7 +368,7 @@
list(APPEND _RESULT ${_TARGET_NAME})
# Check for non-target identifiers again after resolving the alias.
- if (NOT TARGET ${_TARGET_NAME})
+ if(NOT TARGET ${_TARGET_NAME})
return()
endif()
@@ -408,7 +408,7 @@
foreach(_DEP ${_TRANSITIVE_DEPS})
# Check if _DEP is a library with the ALWAYSLINK property set.
set(_DEP_IS_ALWAYSLINK OFF)
- if (TARGET ${_DEP})
+ if(TARGET ${_DEP})
get_target_property(_DEP_TYPE ${_DEP} TYPE)
if(${_DEP_TYPE} STREQUAL "INTERFACE_LIBRARY")
# Can't be ALWAYSLINK since it's an INTERFACE library.
@@ -428,14 +428,14 @@
# For macOS, also add a `-Wl,-force_load` version of the dep.
if(MSVC)
get_target_property(_ALIASED_TARGET ${_DEP} ALIASED_TARGET)
- if (_ALIASED_TARGET)
+ if(_ALIASED_TARGET)
list(APPEND _ALWAYS_LINK_DEPS "-WHOLEARCHIVE:${_ALIASED_TARGET}")
else()
list(APPEND _ALWAYS_LINK_DEPS "-WHOLEARCHIVE:${_DEP}")
endif()
elseif(APPLE)
get_target_property(_ALIASED_TARGET ${_DEP} ALIASED_TARGET)
- if (_ALIASED_TARGET)
+ if(_ALIASED_TARGET)
list(APPEND _ALWAYS_LINK_DEPS "-Wl,-force_load $<TARGET_FILE:${_ALIASED_TARGET}>")
else()
list(APPEND _ALWAYS_LINK_DEPS "-Wl,-force_load $<TARGET_FILE:${_DEP}>")
diff --git a/build_tools/cmake/iree_setup_toolchain.cmake b/build_tools/cmake/iree_setup_toolchain.cmake
index 48e2463..c3789df 100644
--- a/build_tools/cmake/iree_setup_toolchain.cmake
+++ b/build_tools/cmake/iree_setup_toolchain.cmake
@@ -19,7 +19,7 @@
endfunction()
if(IREE_ENABLE_LLD)
- if (IREE_USE_LINKER)
+ if(IREE_USE_LINKER)
message(FATAL_ERROR "IREE_ENABLE_LLD and IREE_USE_LINKER can't be set at the same time")
endif()
set(IREE_USE_LINKER "lld")
diff --git a/build_tools/cmake/iree_tablegen_doc.cmake b/build_tools/cmake/iree_tablegen_doc.cmake
index cc62fe4..5e93390 100644
--- a/build_tools/cmake/iree_tablegen_doc.cmake
+++ b/build_tools/cmake/iree_tablegen_doc.cmake
@@ -53,7 +53,10 @@
endif()
- set(_INCLUDE_DIRS ${IREE_COMMON_INCLUDE_DIRS})
+ set(_INCLUDE_DIRS
+ "${MLIR_INCLUDE_DIRS}"
+ "${IREE_SOURCE_DIR}"
+ )
list(APPEND _INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR})
list(TRANSFORM _INCLUDE_DIRS PREPEND "-I")
@@ -73,7 +76,7 @@
endwhile()
# Put all dialect docs at one place.
- set(_DOC_DIR ${PROJECT_BINARY_DIR}/doc/Dialects/)
+ set(_DOC_DIR ${CMAKE_CURRENT_BINARY_DIR}/doc/Dialects/)
# Set a target to drive copy.
add_custom_target(${_NAME}_target
${CMAKE_COMMAND} -E make_directory ${_DOC_DIR}
diff --git a/build_tools/cmake/iree_tablegen_library.cmake b/build_tools/cmake/iree_tablegen_library.cmake
index e185e59..6226b16 100644
--- a/build_tools/cmake/iree_tablegen_library.cmake
+++ b/build_tools/cmake/iree_tablegen_library.cmake
@@ -41,7 +41,10 @@
endif()
set(LLVM_TARGET_DEFINITIONS ${_RULE_TD_FILE})
- set(_INCLUDE_DIRS ${IREE_COMMON_INCLUDE_DIRS})
+ set(_INCLUDE_DIRS
+ "${MLIR_INCLUDE_DIRS}"
+ "${IREE_SOURCE_DIR}"
+ )
list(APPEND _INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR})
list(TRANSFORM _INCLUDE_DIRS PREPEND "-I")
set(_OUTPUTS)
diff --git a/build_tools/cmake/iree_whole_archive_link.cmake b/build_tools/cmake/iree_whole_archive_link.cmake
deleted file mode 100644
index e34f900..0000000
--- a/build_tools/cmake/iree_whole_archive_link.cmake
+++ /dev/null
@@ -1,148 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Lists all transitive dependencies of DIRECT_DEPS in TRANSITIVE_DEPS.
-function(_iree_transitive_dependencies DIRECT_DEPS TRANSITIVE_DEPS)
- set(_TRANSITIVE "")
-
- foreach(_DEP ${DIRECT_DEPS})
- _iree_transitive_dependencies_helper(${_DEP} _TRANSITIVE)
- endforeach(_DEP)
-
- set(${TRANSITIVE_DEPS} "${_TRANSITIVE}" PARENT_SCOPE)
-endfunction()
-
-# Recursive helper function for _iree_transitive_dependencies.
-# Performs a depth-first search through the dependency graph, appending all
-# dependencies of TARGET to the TRANSITIVE_DEPS list.
-function(_iree_transitive_dependencies_helper TARGET TRANSITIVE_DEPS)
- if (NOT TARGET "${TARGET}")
- # Excluded from the project, or invalid name? Just ignore.
- return()
- endif()
-
- # Resolve aliases, canonicalize name formatting.
- get_target_property(_ALIASED_TARGET ${TARGET} ALIASED_TARGET)
- if(_ALIASED_TARGET)
- set(_TARGET_NAME ${_ALIASED_TARGET})
- else()
- string(REPLACE "::" "_" _TARGET_NAME ${TARGET})
- endif()
-
- set(_RESULT "${${TRANSITIVE_DEPS}}")
- if (${_TARGET_NAME} IN_LIST _RESULT)
- # Already visited, ignore.
- return()
- endif()
-
- # Append this target to the list. Dependencies of this target will be added
- # (if valid and not already visited) in recursive function calls.
- list(APPEND _RESULT ${_TARGET_NAME})
-
- # Check for non-target identifiers again after resolving the alias.
- if (NOT TARGET ${_TARGET_NAME})
- return()
- endif()
-
- # Get the list of direct dependencies for this target.
- get_target_property(_TARGET_TYPE ${_TARGET_NAME} TYPE)
- if(NOT ${_TARGET_TYPE} STREQUAL "INTERFACE_LIBRARY")
- get_target_property(_TARGET_DEPS ${_TARGET_NAME} LINK_LIBRARIES)
- else()
- get_target_property(_TARGET_DEPS ${_TARGET_NAME} INTERFACE_LINK_LIBRARIES)
- endif()
-
- if(_TARGET_DEPS)
- # Recurse on each dependency.
- foreach(_TARGET_DEP ${_TARGET_DEPS})
- _iree_transitive_dependencies_helper(${_TARGET_DEP} _RESULT)
- endforeach(_TARGET_DEP)
- endif()
-
- # Propagate the augmented list up to the parent scope.
- set(${TRANSITIVE_DEPS} "${_RESULT}" PARENT_SCOPE)
-endfunction()
-
-# Given the ${TARGET} and the libaries it directly depends on in ${ARGN},
-# properly establish the linking relationship by considering ALWAYSLINK
-# in a recursive manner.
-#
-# All symbols from ALWAYSLINK libraries will be included in ${TARGET},
-# regardless of whether they are directly referenced or not.
-function(iree_whole_archive_link TARGET)
- # List all dependencies, including transitive dependencies, then split the
- # dependency list into one for whole archive (ALWAYSLINK) and one for
- # standard linking (which only links in symbols that are directly used).
- _iree_transitive_dependencies("${ARGN}" _TRANSITIVE_DEPS)
- set(_ALWAYS_LINK_DEPS "")
- set(_STANDARD_DEPS "")
- foreach(_DEP ${_TRANSITIVE_DEPS})
- # Check if _DEP is a library with the ALWAYSLINK property set.
- set(_DEP_IS_ALWAYSLINK OFF)
- if (TARGET ${_DEP})
- get_target_property(_DEP_TYPE ${_DEP} TYPE)
- if(${_DEP_TYPE} STREQUAL "INTERFACE_LIBRARY")
- # Can't be ALWAYSLINK since it's an INTERFACE library.
- # We also can't even query for the property, since it isn't allowlisted.
- else()
- get_target_property(_DEP_IS_ALWAYSLINK ${_DEP} ALWAYSLINK)
- endif()
- endif()
-
- # Append to the corresponding list of deps.
- if(_DEP_IS_ALWAYSLINK)
- list(APPEND _ALWAYS_LINK_DEPS ${_DEP})
-
- # For MSVC, also add a `-WHOLEARCHIVE:` version of the dep.
- # CMake treats -WHOLEARCHIVE[:lib] as a link flag and will not actually
- # try to link the library in, so we need the flag *and* the dependency.
- # For macOS, also add a `-Wl,-force_load` version of the dep.
- if(MSVC)
- get_target_property(_ALIASED_TARGET ${_DEP} ALIASED_TARGET)
- if (_ALIASED_TARGET)
- list(APPEND _ALWAYS_LINK_DEPS "-WHOLEARCHIVE:${_ALIASED_TARGET}")
- else()
- list(APPEND _ALWAYS_LINK_DEPS "-WHOLEARCHIVE:${_DEP}")
- endif()
- elseif(APPLE)
- get_target_property(_ALIASED_TARGET ${_DEP} ALIASED_TARGET)
- if (_ALIASED_TARGET)
- list(APPEND _ALWAYS_LINK_DEPS "-Wl,-force_load $<TARGET_FILE:${_ALIASED_TARGET}>")
- else()
- list(APPEND _ALWAYS_LINK_DEPS "-Wl,-force_load $<TARGET_FILE:${_DEP}>")
- endif()
- endif()
- else()
- list(APPEND _STANDARD_DEPS ${_DEP})
- endif()
- endforeach(_DEP)
-
- # Call into target_link_libraries with the lists of deps.
- if(MSVC OR APPLE)
- target_link_libraries(${TARGET}
- PUBLIC
- ${_ALWAYS_LINK_DEPS}
- ${_STANDARD_DEPS}
- )
- else()
- target_link_libraries(${TARGET}
- PUBLIC
- "-Wl,--whole-archive"
- ${_ALWAYS_LINK_DEPS}
- "-Wl,--no-whole-archive"
- ${_STANDARD_DEPS}
- )
- endif()
-endfunction()
-
diff --git a/build_tools/docker/README.md b/build_tools/docker/README.md
index 19144bc..3172c86 100644
--- a/build_tools/docker/README.md
+++ b/build_tools/docker/README.md
@@ -69,38 +69,97 @@
This requires that the tagged image have a repository digest, which means it was
pushed to or pulled from GCR.
-## Deploying New Images
+## Adding or Updating an Image
-1. Modify the Dockerfiles as desired.
-2. Update `manage_images.py` to include the new image and its dependencies.
-3. Build and push the new image to GCR and update references to it:
+If you have worked with the `docker` images before, it is prudent to follow the
+steps in the "Debugging" section below before continuing.
- ```shell
- python3 build_tools/docker/manage_images.py --image "${IMAGE?}" --build --push --update_references
+### Part 1. Local Changes
+
+1. Update the `Dockerfile` for the image that you want to modify or add. If
+ you're adding a new image, or updating the dependencies between images, be
+ sure to update `IMAGES_TO_DEPENDENCIES` in `manage_images.py` as well.
+2. Build the image, push the image to GCR and update all references to the image
+ with the new GCR digest:
+
+ ```shell
+ python3 build_tools/docker/manage_images.py \
+ --image "${IMAGE?}" --build \
+ --tag latest \
+ --push \
+ --update_references
```
-4. Commit changes and send a PR for review.
+3. Test that the changes behave as expected locally and iterate on the steps
+ above.
-5. Merge your PR after is approved and all builds pass.
+### Part 2. Submitting to GitHub
-6. Kokoro builds preload images tagged with `prod` on VM creation, so after
- changing the images used, you should also update the images tagged as `prod`
- in GCR. Update your local reference to the `prod` tag to point at the new
- image:
+4. Commit the changes and send a PR for review. The CI will use the updated
+ digest references to test the new images.
- ```shell
- python3 build_tools/docker/manage_images.py --image "${IMAGE?}" --tag prod --build --update_references
+5. Merge your PR after it is approved and all CI tests pass. **Please remember to
+ complete the rest of the steps below**.
+
+### Part 3. Updating the `:prod` tag
+
+Kokoro builds preload images tagged with `prod` on VM creation, so after
+changing the images used, you should also update the images tagged as `prod`
+in GCR. This also makes development significantly easier for others who need to
+modify the `docker` images.
+
+6. On the `main` branch, build (but don't push) the images and locally tag them
+ with the `:prod` tag:
+
+ ```shell
+ python3 build_tools/docker/manage_images.py \
+ --image "${IMAGE?}" --build \
+ --tag prod \
+ --update_references
```
- The build steps here should all be cache hits and no references should
- actually be changed. If they are, that indicates the images you've just
- built are different from the ones that are being referenced. Stop and fix
- this before proceeding. This relies on you keeping your local copy of the
- Docker images. If you didn't, you'll have to manually pull the missing
- images by their digest.
+ This build should be entirely cache hits.
+7. We include `--update_references` in the command above so that we can check
+ that none of the images or references to them have been changed. Check that
+ the following command produces no output before continuing:
-7. Push the new images with the `prod` tag to GCR.
+ ```shell
+ git status --porcelain
+ ```
+
+ If the output is not empty then you'll need to find the source of the
+ discrepancy (e.g. a locally modified `Dockerfile`) and remove it, and repeat
+ steps 6 and 7 before continuing. (This relies on you keeping your local copy
+ of the Docker images. If you didn't, you'll have to manually pull the missing
+ images by their digest).
+8. Now that we've confirmed that none of the images were changed, we can push
+ them to GCR with the `:prod` tag.
```shell
- python3 build_tools/docker/manage_images.py --image "${IMAGE?}" --tag prod --push
+ python3 build_tools/docker/manage_images.py \
+ --image "${IMAGE?}" \
+ --tag prod \
+ --push
```
+
+## Debugging
+
+Sometimes old versions of the `:latest` images can be stored locally and produce
+unexpected behaviors. The following commands will download all of the prod
+images and then update the images tagged with `:latest` on your machine (and on
+GCR).
+
+```shell
+# Pull all :prod images
+python3 build_tools/docker/manage_images.py --images all --pull --tag prod
+# Update the :latest images to match the :prod images.
+# If you have a clean workspace this _shouldn't_ require building anything as
+# everything should be cache hits from the :prod images downloaded above, but if
+# the :prod images are behind then that will not be the case and this may take
+# several hours (depending on your machine).
+python3 build_tools/docker/manage_images.py \
+ --images all --build \
+ --tag latest \
+ --push \
+ --update_references
+```
diff --git a/build_tools/docker/bazel-python/Dockerfile b/build_tools/docker/bazel-python/Dockerfile
index 11c35ab..98b221b 100644
--- a/build_tools/docker/bazel-python/Dockerfile
+++ b/build_tools/docker/bazel-python/Dockerfile
@@ -24,6 +24,6 @@
python3-pip \
python3-setuptools \
&& python3 -m pip install --upgrade pip \
- && python3 -m pip install numpy
+ && python3 -m pip install numpy==1.19.4
ENV PYTHON_BIN /usr/bin/python3
diff --git a/build_tools/docker/bazel-tensorflow-nvidia/Dockerfile b/build_tools/docker/bazel-tensorflow-nvidia/Dockerfile
index a9b4cf4..adb2e60 100644
--- a/build_tools/docker/bazel-tensorflow-nvidia/Dockerfile
+++ b/build_tools/docker/bazel-tensorflow-nvidia/Dockerfile
@@ -18,4 +18,4 @@
FROM gcr.io/iree-oss/bazel-tensorflow-vulkan AS final
RUN apt-get update \
- && DEBIAN_FRONTEND=noninteractive apt-get install -y vulkan-sdk nvidia-driver-440
+ && DEBIAN_FRONTEND=noninteractive apt-get install -y libnvidia-gl-440=440.100-0ubuntu0.18.04.1
diff --git a/build_tools/docker/bazel-tensorflow-vulkan/Dockerfile b/build_tools/docker/bazel-tensorflow-vulkan/Dockerfile
index 849a32c..05b7209 100644
--- a/build_tools/docker/bazel-tensorflow-vulkan/Dockerfile
+++ b/build_tools/docker/bazel-tensorflow-vulkan/Dockerfile
@@ -13,23 +13,21 @@
# limitations under the License.
# A base image for building IREE with TensorFlow integrations using Bazel and
-# running Vulkan tests. Requires a child image to provide a Vulkan ICD.
+# running Vulkan tests. This image provides the Vulkan SDK. Requires a child
+# image to provide a Vulkan ICD.
FROM gcr.io/iree-oss/bazel-tensorflow AS final
-RUN apt-get update && apt-get install -y wget
+ARG VULKAN_SDK_VERSION=1.2.154.0
-ARG VULKAN_SDK_VERSION=1.2.141
+COPY --from=gcr.io/iree-oss/vulkan /opt/vulkan-sdk/ /opt/vulkan-sdk/
-# Disable apt-key parse waring. If someone knows how to do whatever the "proper"
-# thing is then feel free. The warning complains about parsing apt-key output,
-# which we're not even doing.
-ARG APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=1
+ENV VULKAN_SDK="/opt/vulkan-sdk/${VULKAN_SDK_VERSION}/x86_64"
-RUN wget -qO - http://packages.lunarg.com/lunarg-signing-key-pub.asc \
- | apt-key add - \
- && wget -qO \
- "/etc/apt/sources.list.d/lunarg-vulkan-${VULKAN_SDK_VERSION?}-bionic.list" \
- "http://packages.lunarg.com/vulkan/${VULKAN_SDK_VERSION?}/lunarg-vulkan-${VULKAN_SDK_VERSION?}-bionic.list" \
- && apt-get update \
- && apt-get install -y vulkan-sdk
+ENV PATH="${VULKAN_SDK}/bin:$PATH"
+
+# Symlink the Vulkan loader to a system library directory. This is needed to
+# allow Vulkan applications to find the Vulkan loader. It also avoids using
+# LD_LIBRARY_PATH, which is not supported well by Docker.
+RUN ln -s "${VULKAN_SDK}/lib/libvulkan.so" /usr/lib/x86_64-linux-gnu/ \
+ && ln -s "${VULKAN_SDK}/lib/libvulkan.so.1" /usr/lib/x86_64-linux-gnu/
diff --git a/build_tools/docker/bazel-tensorflow/Dockerfile b/build_tools/docker/bazel-tensorflow/Dockerfile
index a9dd7fb..1ce0944 100644
--- a/build_tools/docker/bazel-tensorflow/Dockerfile
+++ b/build_tools/docker/bazel-tensorflow/Dockerfile
@@ -17,4 +17,4 @@
FROM gcr.io/iree-oss/bazel-python AS final
# Install tensorflow.
-RUN python3 -m pip install tf-nightly
+RUN python3 -m pip install tf-nightly==2.5.0.dev20201116
diff --git a/build_tools/docker/cmake-python-nvidia/Dockerfile b/build_tools/docker/cmake-python-nvidia/Dockerfile
index cce8d2f..e9559d4 100644
--- a/build_tools/docker/cmake-python-nvidia/Dockerfile
+++ b/build_tools/docker/cmake-python-nvidia/Dockerfile
@@ -28,4 +28,4 @@
FROM gcr.io/iree-oss/cmake-python-vulkan AS final
RUN apt-get update \
- && apt-get install -y nvidia-driver-440
+ && DEBIAN_FRONTEND=noninteractive apt-get install -y libnvidia-gl-440=440.100-0ubuntu0.18.04.1
diff --git a/build_tools/docker/cmake-python-vulkan/Dockerfile b/build_tools/docker/cmake-python-vulkan/Dockerfile
index a20221b..af95afc 100644
--- a/build_tools/docker/cmake-python-vulkan/Dockerfile
+++ b/build_tools/docker/cmake-python-vulkan/Dockerfile
@@ -12,27 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-# A base image for building IREE using CMake and running Vulkan tests. Requires
-# a child image to provide a Vulkan ICD.
+# A base image for building IREE using CMake and running Vulkan tests.
+# This image provides the Vulkan SDK. Requires a child image to provide
+# a Vulkan ICD.
FROM gcr.io/iree-oss/cmake-python AS final
-# It would be nice to have a separate install vulkan image and copy from that,
-# but I don't know all the files that installing the vulkan-sdk adds.
-RUN apt-get update && apt-get install -y wget
+ARG VULKAN_SDK_VERSION=1.2.154.0
-ARG VULKAN_SDK_VERSION=1.2.141
-# Disable apt-key parse waring. If someone knows how to do whatever the "proper"
-# thing is then feel free. The warning complains about parsing apt-key output,
-# which we're not even doing.
-ARG APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=1
+COPY --from=gcr.io/iree-oss/vulkan /opt/vulkan-sdk/ /opt/vulkan-sdk/
-RUN wget -qO - http://packages.lunarg.com/lunarg-signing-key-pub.asc \
- | apt-key add - \
- && wget -qO \
- "/etc/apt/sources.list.d/lunarg-vulkan-${VULKAN_SDK_VERSION?}-bionic.list" \
- "http://packages.lunarg.com/vulkan/${VULKAN_SDK_VERSION?}/lunarg-vulkan-${VULKAN_SDK_VERSION?}-bionic.list" \
- && apt-get update \
- && apt-get install -y vulkan-sdk
+ENV VULKAN_SDK="/opt/vulkan-sdk/${VULKAN_SDK_VERSION}/x86_64"
-RUN rm /usr/bin/wget
+ENV PATH="${VULKAN_SDK}/bin:$PATH"
+
+# Symlink the Vulkan loader to a system library directory. This is needed to
+# allow Vulkan applications to find the Vulkan loader. It also avoids using
+# LD_LIBRARY_PATH, which is not supported well by Docker.
+RUN ln -s "${VULKAN_SDK}/lib/libvulkan.so" /usr/lib/x86_64-linux-gnu/ \
+ && ln -s "${VULKAN_SDK}/lib/libvulkan.so.1" /usr/lib/x86_64-linux-gnu/
diff --git a/build_tools/docker/cmake-python/Dockerfile b/build_tools/docker/cmake-python/Dockerfile
index d206097..0f1e618 100644
--- a/build_tools/docker/cmake-python/Dockerfile
+++ b/build_tools/docker/cmake-python/Dockerfile
@@ -14,10 +14,6 @@
# An image for building IREE and its Python bindings using CMake.
-ARG CMAKE_MAJOR_VERSION=3
-ARG CMAKE_MINOR_VERSION=13
-ARG CMAKE_PATCH_VERSION=5
-
FROM gcr.io/iree-oss/cmake AS final
# Dependencies for the python bindings tests.
RUN apt-get update \
@@ -26,6 +22,6 @@
python3-pip \
python3-setuptools \
&& python3 -m pip install --upgrade pip \
- && python3 -m pip install numpy absl-py
+ && python3 -m pip install numpy==1.19.4 absl-py
ENV PYTHON_BIN /usr/bin/python3
diff --git a/build_tools/docker/manage_images.py b/build_tools/docker/manage_images.py
index e8c7c8b..ad5c3a5 100755
--- a/build_tools/docker/manage_images.py
+++ b/build_tools/docker/manage_images.py
@@ -17,6 +17,8 @@
Includes information on their dependency graph and GCR URL.
+See the README for information on how to add and update images.
+
Example usage:
Rebuild the cmake image and all images that transitively on depend on it,
@@ -54,16 +56,17 @@
'bazel-tensorflow': ['bazel-python'],
'bazel-tensorflow-nvidia': ['bazel-tensorflow-vulkan'],
'bazel-tensorflow-swiftshader': ['bazel-tensorflow-vulkan', 'swiftshader'],
- 'bazel-tensorflow-vulkan': ['bazel-tensorflow'],
+ 'bazel-tensorflow-vulkan': ['bazel-tensorflow', 'vulkan'],
'cmake': ['base', 'util'],
'cmake-android': ['cmake', 'util'],
'cmake-python': ['cmake'],
'cmake-python-nvidia': ['cmake-python-vulkan'],
'cmake-python-swiftshader': ['cmake-python-vulkan', 'swiftshader'],
- 'cmake-python-vulkan': ['cmake-python'],
- 'rbe-toolchain': [],
+ 'cmake-python-vulkan': ['cmake-python', 'vulkan'],
+ 'rbe-toolchain': ['vulkan'],
'swiftshader': ['cmake'],
'util': [],
+ 'vulkan': ['util'],
}
IMAGES_TO_DEPENDENT_IMAGES = {k: [] for k in IMAGES_TO_DEPENDENCIES}
diff --git a/build_tools/docker/rbe-toolchain/Dockerfile b/build_tools/docker/rbe-toolchain/Dockerfile
index ffb0f0f..f3fef55 100644
--- a/build_tools/docker/rbe-toolchain/Dockerfile
+++ b/build_tools/docker/rbe-toolchain/Dockerfile
@@ -79,25 +79,20 @@
&& python3 -m pip install numpy
######################## Vulkan SDK ############################################
-ARG VULKAN_SDK_VERSION=1.2.141
-# Disable apt-key parse waring. If someone knows how to do whatever the "proper"
-# thing is then feel free. The warning complains about parsing apt-key output,
-# which we're not even doing.
-ARG APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=1
+ARG VULKAN_SDK_VERSION=1.2.154.0
-RUN apt-get update \
- && apt-get install apt-transport-https
+COPY --from=gcr.io/iree-oss/vulkan /opt/vulkan-sdk/ /opt/vulkan-sdk/
-# Note that this image is based on Ubuntu 16.04 (xenial) as opposed to
-# Ubuntu 18.04 (bionic), which we use for our other images.
-RUN wget -qO - http://packages.lunarg.com/lunarg-signing-key-pub.asc \
- | apt-key add - \
- && wget -qO \
- "/etc/apt/sources.list.d/lunarg-vulkan-${VULKAN_SDK_VERSION?}-xenial.list" \
- "http://packages.lunarg.com/vulkan/${VULKAN_SDK_VERSION?}/lunarg-vulkan-${VULKAN_SDK_VERSION?}-xenial.list" \
- && apt-get update \
- && apt-get install -y vulkan-sdk
+ENV VULKAN_SDK="/opt/vulkan-sdk/${VULKAN_SDK_VERSION}/x86_64"
+
+ENV PATH="${VULKAN_SDK}/bin:$PATH"
+
+# Symlink the Vulkan loader to a system library directory. This is needed to
+# allow Vulkan applications to find the Vulkan loader. It also avoids using
+# LD_LIBRARY_PATH, which is not supported well by Docker.
+RUN ln -s "${VULKAN_SDK}/lib/libvulkan.so" /usr/lib/x86_64-linux-gnu/ \
+ && ln -s "${VULKAN_SDK}/lib/libvulkan.so.1" /usr/lib/x86_64-linux-gnu/
######################## Swiftshader ###########################################
COPY --from=install-swiftshader /swiftshader /swiftshader
diff --git a/build_tools/docker/vulkan/Dockerfile b/build_tools/docker/vulkan/Dockerfile
new file mode 100644
index 0000000..24a8914
--- /dev/null
+++ b/build_tools/docker/vulkan/Dockerfile
@@ -0,0 +1,25 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+FROM gcr.io/iree-oss/util AS base
+
+ARG VULKAN_SDK_VERSION=1.2.154.0
+
+RUN wget -q \
+ "https://sdk.lunarg.com/sdk/download/${VULKAN_SDK_VERSION?}/linux/vulkansdk-linux-${VULKAN_SDK_VERSION?}.tar.gz" \
+ -O "/tmp/vulkansdk-linux-x86_64-${VULKAN_SDK_VERSION}.tar.gz"
+
+RUN mkdir -p /opt/vulkan-sdk
+
+RUN tar -xf /tmp/vulkansdk-linux-x86_64-$VULKAN_SDK_VERSION.tar.gz -C /opt/vulkan-sdk
diff --git a/build_tools/kokoro/gcp_ubuntu/bazel/linux/x86-swiftshader/bindings/build_kokoro.sh b/build_tools/kokoro/gcp_ubuntu/bazel/linux/x86-swiftshader/bindings/build_kokoro.sh
index 199a9a2..9e18be2 100755
--- a/build_tools/kokoro/gcp_ubuntu/bazel/linux/x86-swiftshader/bindings/build_kokoro.sh
+++ b/build_tools/kokoro/gcp_ubuntu/bazel/linux/x86-swiftshader/bindings/build_kokoro.sh
@@ -32,7 +32,7 @@
docker_setup
docker run "${DOCKER_RUN_ARGS[@]?}" \
- gcr.io/iree-oss/bazel-python@sha256:3dc3d72717edec56a7ce8a910a9333aa10bd7c9814858880325968047b3a28d4 \
+ gcr.io/iree-oss/bazel-python@sha256:473b7e294136bc38abc1941042f0c0404199de5827f141520f0b6757305b7a95 \
build_tools/kokoro/gcp_ubuntu/bazel/linux/x86-swiftshader/bindings/build.sh
# Kokoro will rsync this entire directory back to the executor orchestrating the
diff --git a/build_tools/kokoro/gcp_ubuntu/bazel/linux/x86-swiftshader/core/build_kokoro.sh b/build_tools/kokoro/gcp_ubuntu/bazel/linux/x86-swiftshader/core/build_kokoro.sh
index e78cd07..72f53b1 100755
--- a/build_tools/kokoro/gcp_ubuntu/bazel/linux/x86-swiftshader/core/build_kokoro.sh
+++ b/build_tools/kokoro/gcp_ubuntu/bazel/linux/x86-swiftshader/core/build_kokoro.sh
@@ -32,7 +32,7 @@
docker_setup
docker run "${DOCKER_RUN_ARGS[@]?}" \
- gcr.io/iree-oss/bazel@sha256:6b83817206384c7e8ad4f522162ee31b927ac9912095dc3ef1d3c8f580feba92 \
+ gcr.io/iree-oss/bazel@sha256:59da17e5cc8176890a6e1bda369b1f3d398e27af3d47e02e1ffd5b76729c215b \
build_tools/kokoro/gcp_ubuntu/bazel/linux/x86-swiftshader/core/build.sh
# Kokoro will rsync this entire directory back to the executor orchestrating the
diff --git a/build_tools/kokoro/gcp_ubuntu/bazel/linux/x86-swiftshader/integrations/build_kokoro.sh b/build_tools/kokoro/gcp_ubuntu/bazel/linux/x86-swiftshader/integrations/build_kokoro.sh
index 9e9d18c..3b9f412 100755
--- a/build_tools/kokoro/gcp_ubuntu/bazel/linux/x86-swiftshader/integrations/build_kokoro.sh
+++ b/build_tools/kokoro/gcp_ubuntu/bazel/linux/x86-swiftshader/integrations/build_kokoro.sh
@@ -32,7 +32,7 @@
docker_setup
docker run "${DOCKER_RUN_ARGS[@]?}" \
- gcr.io/iree-oss/bazel-tensorflow-swiftshader@sha256:684d38c5fbb4a362476305138d4a78fab8710daa0794afe0f2e41f6b61627fe1 \
+ gcr.io/iree-oss/bazel-tensorflow-swiftshader@sha256:39c0e43c503bddfacd69758a50f02450ad2322d35324e2f56997aebb33a1b20a \
build_tools/kokoro/gcp_ubuntu/bazel/linux/x86-swiftshader/integrations/build.sh
# Kokoro will rsync this entire directory back to the executor orchestrating the
diff --git a/build_tools/kokoro/gcp_ubuntu/bazel/linux/x86-turing/integrations/build_kokoro.sh b/build_tools/kokoro/gcp_ubuntu/bazel/linux/x86-turing/integrations/build_kokoro.sh
index 5ba3605..dc43124 100755
--- a/build_tools/kokoro/gcp_ubuntu/bazel/linux/x86-turing/integrations/build_kokoro.sh
+++ b/build_tools/kokoro/gcp_ubuntu/bazel/linux/x86-turing/integrations/build_kokoro.sh
@@ -36,7 +36,7 @@
# TODO(#3550): Allow this to follow the checked-in Docker hierarchy.
docker run "${DOCKER_RUN_ARGS[@]?}" \
--gpus all \
- gcr.io/iree-oss/bazel-tensorflow-nvidia@sha256:754dc09c558157f82e9d53451486951fc096e8d2a2b9a1306a29ebfe9e0772df \
+ gcr.io/iree-oss/bazel-tensorflow-nvidia@sha256:e5e96ec1709e83355ee2264c97c26fa5c3d40f749a62734f4787b17a83f2c3b8 \
build_tools/kokoro/gcp_ubuntu/bazel/linux/x86-turing/integrations/build.sh
# Kokoro will rsync this entire directory back to the executor orchestrating the
diff --git a/build_tools/kokoro/gcp_ubuntu/cmake/android/arm64-v8a/build_kokoro.sh b/build_tools/kokoro/gcp_ubuntu/cmake/android/arm64-v8a/build_kokoro.sh
index 18aa00d..43f6e0a 100755
--- a/build_tools/kokoro/gcp_ubuntu/cmake/android/arm64-v8a/build_kokoro.sh
+++ b/build_tools/kokoro/gcp_ubuntu/cmake/android/arm64-v8a/build_kokoro.sh
@@ -32,7 +32,7 @@
docker_setup
docker run "${DOCKER_RUN_ARGS[@]?}" \
- gcr.io/iree-oss/cmake-android@sha256:dbee219f04bff26ea04c41e5e4232ab5486abb04c2abf17ab60ff33f7b279227 \
+ gcr.io/iree-oss/cmake-android@sha256:7accda0b84e2ae337740f2ee71801ee30f2155900abf1cf7b73ea47c15dc694f \
build_tools/kokoro/gcp_ubuntu/cmake/android/build.sh arm64-v8a
# Kokoro will rsync this entire directory back to the executor orchestrating the
diff --git a/build_tools/kokoro/gcp_ubuntu/cmake/linux/x86-swiftshader/build_kokoro.sh b/build_tools/kokoro/gcp_ubuntu/cmake/linux/x86-swiftshader/build_kokoro.sh
index dea51d5..eb80e5a 100755
--- a/build_tools/kokoro/gcp_ubuntu/cmake/linux/x86-swiftshader/build_kokoro.sh
+++ b/build_tools/kokoro/gcp_ubuntu/cmake/linux/x86-swiftshader/build_kokoro.sh
@@ -32,7 +32,7 @@
docker_setup
docker run "${DOCKER_RUN_ARGS[@]?}" \
- gcr.io/iree-oss/cmake-python-swiftshader@sha256:e84cac3152543a6f300701bfce98e44b25d4e52dc0ffec715b6998058d91d583 \
+ gcr.io/iree-oss/cmake-python-swiftshader@sha256:3e3d3427f3a58b32fa3ed578b610e411e0b81fd0e1984ac9b0fceae8bf8343dc \
build_tools/kokoro/gcp_ubuntu/cmake/linux/x86-swiftshader/build.sh
# Kokoro will rsync this entire directory back to the executor orchestrating the
diff --git a/build_tools/kokoro/gcp_ubuntu/cmake/linux/x86-turing/build_kokoro.sh b/build_tools/kokoro/gcp_ubuntu/cmake/linux/x86-turing/build_kokoro.sh
index 8965dd6..c77abcd 100755
--- a/build_tools/kokoro/gcp_ubuntu/cmake/linux/x86-turing/build_kokoro.sh
+++ b/build_tools/kokoro/gcp_ubuntu/cmake/linux/x86-turing/build_kokoro.sh
@@ -33,7 +33,7 @@
docker run "${DOCKER_RUN_ARGS[@]?}" \
--gpus all \
- gcr.io/iree-oss/cmake-python-nvidia@sha256:fb0babff91402d999a2532816d2ee3a9df29ead152e3af7df0215bab9ce85682 \
+ gcr.io/iree-oss/cmake-python-nvidia@sha256:310e3b399717905bb2b485f3ebed32222915c7dc4dc075aa4e1b8551101fe607 \
build_tools/kokoro/gcp_ubuntu/cmake/linux/x86-turing/build.sh
# Kokoro will rsync this entire directory back to the executor orchestrating the
diff --git a/build_tools/third_party/flatcc/BUILD.overlay b/build_tools/third_party/flatcc/BUILD.overlay
index 3481995..729d57c 100644
--- a/build_tools/third_party/flatcc/BUILD.overlay
+++ b/build_tools/third_party/flatcc/BUILD.overlay
@@ -14,13 +14,14 @@
package(default_visibility = ["//visibility:public"])
-# NOTE: we exclude JSON parsing/printing to avoid additional dependencies.
cc_library(
name = "runtime",
srcs = [
"config/config.h",
"src/runtime/builder.c",
"src/runtime/emitter.c",
+ "src/runtime/json_parser.c",
+ "src/runtime/json_printer.c",
"src/runtime/refmap.c",
"src/runtime/verifier.c",
],
@@ -45,10 +46,19 @@
"include/flatcc/reflection/flatbuffers_common_builder.h",
"include/flatcc/reflection/flatbuffers_common_reader.h",
] + glob(["include/flatcc/portable/**/*.h"]),
- copts = [
- "-Iexternal/com_github_dvidelabs_flatcc/config/",
- "-Iexternal/com_github_dvidelabs_flatcc/include/",
+ textual_hdrs = [
+ "include/flatcc/flatcc_json_parser.h",
+ "include/flatcc/flatcc_json_printer.h",
],
+ copts = [
+ "-Iexternal/com_github_dvidelabs_flatcc/config/",
+ "-Iexternal/com_github_dvidelabs_flatcc/include/",
+ ] + select({
+ "@bazel_tools//src/conditions:windows": [],
+ "//conditions:default": [
+ "-Wno-implicit-fallthrough",
+ ],
+ }),
includes = [
"include/",
],
diff --git a/build_tools/third_party/flatcc/CMakeLists.txt b/build_tools/third_party/flatcc/CMakeLists.txt
index 82b1cb2..f1ad43e 100644
--- a/build_tools/third_party/flatcc/CMakeLists.txt
+++ b/build_tools/third_party/flatcc/CMakeLists.txt
@@ -27,6 +27,8 @@
SRCS
"src/runtime/builder.c"
"src/runtime/emitter.c"
+ "src/runtime/json_parser.c"
+ "src/runtime/json_printer.c"
"src/runtime/refmap.c"
"src/runtime/verifier.c"
HDRS
@@ -40,6 +42,8 @@
"include/flatcc/flatcc_flatbuffers.h"
"include/flatcc/flatcc_identifier.h"
"include/flatcc/flatcc_iov.h"
+ "include/flatcc/flatcc_json_parser.h"
+ "include/flatcc/flatcc_json_printer.h"
"include/flatcc/flatcc_portable.h"
"include/flatcc/flatcc_prologue.h"
"include/flatcc/flatcc_refmap.h"
diff --git a/build_tools/third_party/half/BUILD.overlay b/build_tools/third_party/half/BUILD.overlay
new file mode 100644
index 0000000..b27851b
--- /dev/null
+++ b/build_tools/third_party/half/BUILD.overlay
@@ -0,0 +1,22 @@
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(default_visibility = ["//visibility:public"])
+
+cc_library(
+ name = "includes",
+ hdrs = ["half.hpp"],
+ include_prefix = "third_party/half",
+)
diff --git a/build_tools/third_party/half/CMakeLists.txt b/build_tools/third_party/half/CMakeLists.txt
new file mode 100644
index 0000000..4ea1374
--- /dev/null
+++ b/build_tools/third_party/half/CMakeLists.txt
@@ -0,0 +1,26 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set(HALF_API_ROOT "${IREE_ROOT_DIR}/third_party/half/")
+
+external_cc_library(
+ PACKAGE
+ half
+ NAME
+ includes
+ ROOT
+ ${HALF_API_ROOT}
+ HDRS
+ "half.hpp"
+)
diff --git a/build_tools/third_party/tensorflow/tensorflow/compiler/mlir/hlo/CMakeLists.txt b/build_tools/third_party/tensorflow/tensorflow/compiler/mlir/hlo/CMakeLists.txt
index 4b9df5e..4d270b9 100644
--- a/build_tools/third_party/tensorflow/tensorflow/compiler/mlir/hlo/CMakeLists.txt
+++ b/build_tools/third_party/tensorflow/tensorflow/compiler/mlir/hlo/CMakeLists.txt
@@ -12,8 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-set(TF_MLIR_HLO_SRC_ROOT
- "${IREE_ROOT_DIR}/third_party/tensorflow/tensorflow/compiler/mlir/hlo/"
+set(TF_MLIR_HLO_SOURCE_DIR
+ "${IREE_SOURCE_DIR}/third_party/tensorflow/tensorflow/compiler/mlir/hlo/"
+)
+set(TF_MLIR_HLO_BINARY_DIR
+ "${IREE_BINARY_DIR}/third_party/tensorflow/tensorflow/compiler/mlir/hlo/"
)
external_cc_library(
@@ -22,7 +25,7 @@
NAME
mlir_hlo
ROOT
- ${TF_MLIR_HLO_SRC_ROOT}
+ ${TF_MLIR_HLO_SOURCE_DIR}
DEPS
MhloDialect
MhloInferFusibilityOpInterface
@@ -32,5 +35,10 @@
MhloPasses
MhloLhloToLinalg
MLIRMhloUtils
+ INCLUDES
+ "${TF_MLIR_HLO_SOURCE_DIR}/"
+ "${TF_MLIR_HLO_SOURCE_DIR}/include/"
+ "${TF_MLIR_HLO_BINARY_DIR}/"
+ "${TF_MLIR_HLO_BINARY_DIR}/include/"
PUBLIC
)
diff --git a/build_tools/third_party/vulkan_headers/BUILD.overlay b/build_tools/third_party/vulkan_headers/BUILD.overlay
index 3667e68..06aa1ea 100644
--- a/build_tools/third_party/vulkan_headers/BUILD.overlay
+++ b/build_tools/third_party/vulkan_headers/BUILD.overlay
@@ -14,18 +14,6 @@
package(default_visibility = ["//visibility:public"])
-# Exports all headers but defining VK_NO_PROTOTYPES to disable the
-# inclusion of C function prototypes. Useful if dynamically loading
-# all symbols via dlopen/etc.
-# Not all headers are hermetic, so they are just included as textual
-# headers to disable additional validation.
-cc_library(
- name = "vulkan_headers_no_prototypes",
- defines = ["VK_NO_PROTOTYPES"],
- includes = ["include"],
- textual_hdrs = glob(["include/vulkan/*.h"]),
-)
-
# Exports all headers, including C function prototypes. Useful if statically
# linking against the Vulkan SDK.
# Not all headers are hermetic, so they are just included as textual
diff --git a/build_tools/third_party/vulkan_memory_allocator/CMakeLists.txt b/build_tools/third_party/vulkan_memory_allocator/CMakeLists.txt
new file mode 100644
index 0000000..5e8a120
--- /dev/null
+++ b/build_tools/third_party/vulkan_memory_allocator/CMakeLists.txt
@@ -0,0 +1,28 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set(VMA_ROOT "${IREE_ROOT_DIR}/third_party/vulkan_memory_allocator/")
+
+external_cc_library(
+ PACKAGE
+ vulkan_memory_allocator
+ NAME
+ vulkan_memory_allocator
+ ROOT
+ ${VMA_ROOT}
+ HDRS
+ "src/vk_mem_alloc.h"
+ INCLUDES
+ "${VMA_ROOT}/src/"
+)
diff --git a/colab/edge_detection.ipynb b/colab/edge_detection.ipynb
index 18528e4..4501f53 100644
--- a/colab/edge_detection.ipynb
+++ b/colab/edge_detection.ipynb
@@ -90,7 +90,7 @@
"import numpy as np\n",
"import tensorflow as tf\n",
"from pyiree.tf import compiler as ireec\n",
- "from pyiree.tf.support import tf_utils\n",
+ "from pyiree.tf.support import module_utils\n",
"from pyiree import rt as ireert"
]
},
@@ -271,7 +271,7 @@
"\n",
"backend_choice = \"iree_vmla (CPU)\" #@param [ \"iree_vmla (CPU)\", \"iree_llvmjit (CPU)\", \"iree_vulkan (GPU/SwiftShader)\" ]\n",
"backend_choice = backend_choice.split(\" \")[0]\n",
- "backend = tf_utils.BackendInfo(backend_choice)"
+ "backend = module_utils.BackendInfo(backend_choice)"
]
},
{
@@ -435,7 +435,7 @@
"\n",
"backend_choice = \"iree_vmla (CPU)\" #@param [ \"iree_vmla (CPU)\", \"iree_llvmjit (CPU)\", \"iree_vulkan (GPU/SwiftShader)\" ]\n",
"backend_choice = backend_choice.split(\" \")[0]\n",
- "backend = tf_utils.BackendInfo(backend_choice)"
+ "backend = module_utils.BackendInfo(backend_choice)"
]
},
{
diff --git a/colab/mnist_tensorflow.ipynb b/colab/mnist_tensorflow.ipynb
index 733f65d..4a1f1aa 100644
--- a/colab/mnist_tensorflow.ipynb
+++ b/colab/mnist_tensorflow.ipynb
@@ -86,7 +86,7 @@
"\n",
"from pyiree import rt as ireert\n",
"from pyiree.tf import compiler as ireec\n",
- "from pyiree.tf.support import tf_utils\n",
+ "from pyiree.tf.support import module_utils\n",
"\n",
"from matplotlib import pyplot as plt\n",
"import numpy as np\n",
@@ -335,7 +335,7 @@
"\n",
"backend_choice = \"iree_vmla (CPU)\" #@param [ \"iree_vmla (CPU)\", \"iree_llvmjit (CPU)\", \"iree_vulkan (GPU/SwiftShader)\" ]\n",
"backend_choice = backend_choice.split(\" \")[0]\n",
- "backend = tf_utils.BackendInfo(backend_choice)"
+ "backend = module_utils.BackendInfo(backend_choice)"
],
"execution_count": 8,
"outputs": []
@@ -353,7 +353,7 @@
"source": [
"#@title Compile the mhlo MLIR to an IREE backend and prepare a context to execute it\n",
"\n",
- "iree_module = tf_utils.IreeCompiledModule.create_from_instance(\n",
+ "iree_module = module_utils.IreeCompiledModule.create_from_instance(\n",
" inference_module, backend, exported_names, ARTIFACTS_DIR)\n",
"\n",
"print(\"* Module compiled! See intermediate .mlir files in\", ARTIFACTS_DIR, \"*\")"
@@ -462,4 +462,4 @@
]
}
]
-}
\ No newline at end of file
+}
diff --git a/colab/resnet.ipynb b/colab/resnet.ipynb
index 8595da9..af7e8e6 100644
--- a/colab/resnet.ipynb
+++ b/colab/resnet.ipynb
@@ -82,7 +82,7 @@
"\n",
"from pyiree import rt as ireert\n",
"from pyiree.tf import compiler as ireec\n",
- "from pyiree.tf.support import tf_utils\n",
+ "from pyiree.tf.support import module_utils\n",
"\n",
"import tensorflow as tf\n",
"from matplotlib import pyplot as plt\n",
@@ -139,7 +139,7 @@
"\n",
"backend_choice = \"iree_vmla (CPU)\" #@param [ \"iree_vmla (CPU)\", \"iree_llvmjit (CPU)\", \"iree_vulkan (GPU/SwiftShader)\" ]\n",
"backend_choice = backend_choice.split(\" \")[0]\n",
- "backend = tf_utils.BackendInfo(backend_choice)"
+ "backend = module_utils.BackendInfo(backend_choice)"
],
"execution_count": 3,
"outputs": []
@@ -314,4 +314,4 @@
]
}
]
-}
\ No newline at end of file
+}
diff --git a/docs/get_started/getting_started_android_cmake.md b/docs/get_started/getting_started_android_cmake.md
index 05cbd71..874b183 100644
--- a/docs/get_started/getting_started_android_cmake.md
+++ b/docs/get_started/getting_started_android_cmake.md
@@ -276,8 +276,8 @@
```shell
$ ../iree-build-android/host/bin/iree-translate \
-iree-mlir-to-vm-bytecode-module \
- -iree-llvm-target-triple=aarch64-linux-android \
-iree-hal-target-backends=dylib-llvm-aot \
+ -iree-llvm-target-triple=aarch64-linux-android \
$PWD/iree/tools/test/simple.mlir \
-o /tmp/simple-llvm_aot.vmfb
```
diff --git a/docs/get_started/getting_started_windows_cmake.md b/docs/get_started/getting_started_windows_cmake.md
index a6025b2..04a48e0 100644
--- a/docs/get_started/getting_started_windows_cmake.md
+++ b/docs/get_started/getting_started_windows_cmake.md
@@ -89,6 +89,51 @@
> cmake --build ..\iree-build\
```
+## Target Configuration
+
+### LLVM AOT Backend
+
+`-iree-hal-target-backends=dylib-llvm-aot` can be used to generate modules with
+ahead-of-time compiled kernels stored in DLLs. Run the iree-opt/iree-translate
+tools from a command prompt with `lld-link.exe` or `link.exe` tools on the
+`PATH` and the MSVC/Windows SDK environment variables; the easiest way to get
+this configured is to use the `vcvarsall.bat` or `vcvars64.bat` files to set
+your environment. See
+[the Microsoft documentation](https://docs.microsoft.com/en-us/cpp/build/building-on-the-command-line?view=vs-2019)
+for details on configuring the toolchain.
+
+If you want to manually specify the linker used, set the
+`IREE_LLVMAOT_LINKER_PATH` environment variable to the path of the linker:
+
+```powershell
+> $env:IREE_LLVMAOT_LINKER_PATH = "C:\Tools\LLVM\bin\lld-link.exe"
+```
+
+Translate a source MLIR file into an IREE module:
+
+```powershell
+> ..\iree-build\iree\tools\iree-translate.exe `
+  -iree-mlir-to-vm-bytecode-module `
+  -iree-hal-target-backends=dylib-llvm-aot `
+  iree/tools/test/simple.mlir `
+  -o $env:TMP/simple-llvm_aot.vmfb
+```
+
+Note that this will use the host machine as the target by default, and the
+exact target triple and architecture can be specified with flags when
+cross-compiling:
+
+```powershell
+> ..\iree-build\iree\tools\iree-translate.exe `
+  -iree-mlir-to-vm-bytecode-module `
+  -iree-hal-target-backends=dylib-llvm-aot `
+  -iree-llvm-target-triple=x86_64-pc-windows-msvc `
+  -iree-llvm-target-cpu=host `
+  -iree-llvm-target-cpu-features=host `
+  iree/tools/test/simple.mlir `
+  -o $env:TMP/simple-llvm_aot.vmfb
+```
+
## What's next?
### Take a Look Around
diff --git a/experimental/ModelBuilder/ModelBuilder.cpp b/experimental/ModelBuilder/ModelBuilder.cpp
index d4151b2..aff020f 100644
--- a/experimental/ModelBuilder/ModelBuilder.cpp
+++ b/experimental/ModelBuilder/ModelBuilder.cpp
@@ -239,19 +239,21 @@
ArrayRef<Type> returnTypes,
ValueRange values) {
auto &builder = ScopedContext::getBuilderRef();
- auto funcOp =
+ auto callerFunc =
builder.getInsertionBlock()->getParent()->getParentOfType<FuncOp>();
- Operation *func = SymbolTable::lookupNearestSymbolFrom(funcOp, functionName);
- if (!func) {
+ FuncOp calleeFunc =
+ SymbolTable::lookupNearestSymbolFrom<FuncOp>(callerFunc, functionName);
+ if (!calleeFunc) {
OpBuilder::InsertionGuard insertGuard(builder);
- auto module = funcOp.getParentOfType<ModuleOp>();
+ auto module = callerFunc.getParentOfType<ModuleOp>();
builder.setInsertionPointToStart(module.getBody());
- func = builder.create<FuncOp>(
+ calleeFunc = builder.create<FuncOp>(
module.getLoc(), functionName,
FunctionType::get(SmallVector<Type, 4>(values.getTypes()), returnTypes,
builder.getContext()));
+ calleeFunc.setPrivate();
}
- return std_call(builder.getSymbolRefAttr(func), returnTypes, values);
+ return std_call(calleeFunc, values);
}
MLIRFuncOpConfig &MLIRFuncOpConfig::setNoInline(bool v) {
@@ -292,7 +294,10 @@
if (emitCInterface)
f.setAttr("llvm.emit_c_interface", mlir::UnitAttr::get(ctx));
- if (!declOnly) f.addEntryBlock();
+ if (!declOnly)
+ f.addEntryBlock();
+ else
+ f.setPrivate();
}
// -----------------------------------------------------------------------------
diff --git a/experimental/ModelBuilder/VulkanWrapperPass.cpp b/experimental/ModelBuilder/VulkanWrapperPass.cpp
index c9ba950..f29d798 100644
--- a/experimental/ModelBuilder/VulkanWrapperPass.cpp
+++ b/experimental/ModelBuilder/VulkanWrapperPass.cpp
@@ -95,9 +95,9 @@
vulkanLaunchTypes.insert(vulkanLaunchTypes.end(), args.begin(), args.end());
// Declare vulkan launch function.
- builder.create<FuncOp>(loc, kVulkanLaunch,
- FunctionType::get(vulkanLaunchTypes, ArrayRef<Type>{},
- loc->getContext()));
+ auto type = FunctionType::get(vulkanLaunchTypes, {}, loc->getContext());
+ FuncOp vkLaunch = builder.create<FuncOp>(loc, kVulkanLaunch, type);
+ vkLaunch.setPrivate();
return success();
}
diff --git a/experimental/ModelBuilder/test/BUILD b/experimental/ModelBuilder/test/BUILD
index f378bea..00cccb2 100644
--- a/experimental/ModelBuilder/test/BUILD
+++ b/experimental/ModelBuilder/test/BUILD
@@ -12,10 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-# Tests for end-to-end IREE support starting from the XLA HLO dialect.
+# Tests for ModelBuilder.
load("//iree:lit_test.bzl", "iree_lit_test_suite")
-load("//iree:build_defs.oss.bzl", "IREE_DRIVER_MODULES", "PLATFORM_VULKAN_DEPS")
package(
default_visibility = ["//visibility:public"],
@@ -109,13 +108,12 @@
deps = [
"//experimental/ModelBuilder",
"//experimental/ModelBuilder:ModelRunner",
- "//iree/base:initializer",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:SPIRVDialect",
"@llvm-project//mlir:mlir_runner_utils",
- ] + PLATFORM_VULKAN_DEPS + IREE_DRIVER_MODULES,
+ ],
)
cc_binary(
@@ -125,7 +123,6 @@
"//experimental/ModelBuilder",
"//experimental/ModelBuilder:ModelRunner",
"//experimental/ModelBuilder:VulkanLaunchWrapper",
- "//iree/base:initializer",
"//iree/compiler/Conversion/CodegenUtils",
"//iree/compiler/Conversion/LinalgToSPIRV",
"@llvm-project//llvm:Support",
@@ -149,7 +146,7 @@
"@llvm-project//mlir:VectorToLLVM",
# mlir_runner_utils with iostream needed for printMemRef atm
"@llvm-project//mlir:mlir_runner_utils",
- ] + PLATFORM_VULKAN_DEPS + IREE_DRIVER_MODULES,
+ ],
)
cc_binary(
@@ -162,7 +159,6 @@
"//experimental/ModelBuilder",
"//experimental/ModelBuilder:ModelRunner",
"//experimental/ModelBuilder:VulkanLaunchWrapper",
- "//iree/base:initializer",
"//iree/compiler/Conversion/CodegenUtils",
"//iree/compiler/Conversion/LinalgToSPIRV",
"@llvm-project//llvm:Support",
@@ -182,7 +178,7 @@
"@llvm-project//mlir:mlir_c_runner_utils",
# mlir_runner_utils with iostream needed for printMemRef atm
"@llvm-project//mlir:mlir_runner_utils",
- ] + PLATFORM_VULKAN_DEPS + IREE_DRIVER_MODULES,
+ ],
)
cc_binary(
@@ -195,7 +191,6 @@
"//experimental/ModelBuilder",
"//experimental/ModelBuilder:ModelRunner",
"//experimental/ModelBuilder:VulkanLaunchWrapper",
- "//iree/base:initializer",
"//iree/compiler/Conversion/CodegenUtils",
"//iree/compiler/Conversion/LinalgToSPIRV",
"@llvm-project//llvm:Support",
@@ -215,7 +210,7 @@
"@llvm-project//mlir:mlir_c_runner_utils",
# mlir_runner_utils with iostream needed for printMemRef atm
"@llvm-project//mlir:mlir_runner_utils",
- ] + PLATFORM_VULKAN_DEPS + IREE_DRIVER_MODULES,
+ ],
)
cc_binary(
diff --git a/experimental/ModelBuilder/test/BenchMatMulVectorGPU.cpp b/experimental/ModelBuilder/test/BenchMatMulVectorGPU.cpp
index 68e7910..18493fa 100644
--- a/experimental/ModelBuilder/test/BenchMatMulVectorGPU.cpp
+++ b/experimental/ModelBuilder/test/BenchMatMulVectorGPU.cpp
@@ -16,7 +16,6 @@
#include "experimental/ModelBuilder/ModelBuilder.h"
#include "experimental/ModelBuilder/ModelRunner.h"
#include "experimental/ModelBuilder/VulkanWrapperPass.h"
-#include "iree/base/initializer.h"
#include "iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.h"
#include "iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/MemorySpace.h"
@@ -541,7 +540,6 @@
}
int main(int argc, char **argv) {
- iree::Initializer::RunInitializers();
// Allow LLVM setup through command line and parse the
// test specific option for a runtime support library.
llvm::InitLLVM y(argc, argv);
diff --git a/experimental/ModelBuilder/test/TestMatMulVulkan.cpp b/experimental/ModelBuilder/test/TestMatMulVulkan.cpp
index a8590ac..c05c296 100644
--- a/experimental/ModelBuilder/test/TestMatMulVulkan.cpp
+++ b/experimental/ModelBuilder/test/TestMatMulVulkan.cpp
@@ -34,7 +34,6 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Parser.h"
-#include "iree/base/initializer.h"
#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
#include "mlir/Pass/PassManager.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
@@ -164,7 +163,6 @@
}
int main(int argc, char **argv) {
- iree::Initializer::RunInitializers();
// Allow LLVM setup through command line and parse the
// test specific option for a runtime support library.
llvm::InitLLVM y(argc, argv);
diff --git a/experimental/ModelBuilder/test/TestSimpleJITVulkan.cpp b/experimental/ModelBuilder/test/TestSimpleJITVulkan.cpp
index 0e3d388..be71ae5 100644
--- a/experimental/ModelBuilder/test/TestSimpleJITVulkan.cpp
+++ b/experimental/ModelBuilder/test/TestSimpleJITVulkan.cpp
@@ -33,7 +33,6 @@
#include "mlir/IR/OperationSupport.h"
#include "mlir/Parser.h"
#include "mlir/ExecutionEngine/RunnerUtils.h"
-#include "iree/base/initializer.h"
static llvm::cl::opt<std::string> vulkanWrapper(
"vulkan-wrapper", llvm::cl::desc("Vulkan wrapper library"),
@@ -118,7 +117,6 @@
}
int main(int argc, char **argv) {
- iree::Initializer::RunInitializers();
// Allow LLVM setup through command line and parse the
// test specific option for a runtime support library.
llvm::InitLLVM y(argc, argv);
diff --git a/experimental/ModelBuilder/test/TestVectorToGPU.cpp b/experimental/ModelBuilder/test/TestVectorToGPU.cpp
index a269ca9..e8d778e 100644
--- a/experimental/ModelBuilder/test/TestVectorToGPU.cpp
+++ b/experimental/ModelBuilder/test/TestVectorToGPU.cpp
@@ -35,7 +35,6 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Parser.h"
-#include "iree/base/initializer.h"
#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
#include "mlir/Pass/PassManager.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
@@ -149,7 +148,6 @@
}
int main(int argc, char **argv) {
- iree::Initializer::RunInitializers();
// Allow LLVM setup through command line and parse the
// test specific option for a runtime support library.
llvm::InitLLVM y(argc, argv);
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/BUILD b/integrations/tensorflow/bindings/python/pyiree/tf/support/BUILD
index fc67c96..09fe138 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/BUILD
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/BUILD
@@ -29,9 +29,11 @@
name = "support",
srcs = [
"__init__.py",
+ "module_utils.py",
"tf_test_driver.py",
"tf_test_utils.py",
"tf_utils.py",
+ "trace_utils.py",
],
deps = INTREE_TENSORFLOW_PY_DEPS + [
"//integrations/tensorflow/bindings/python:pathsetup", # build_cleaner: keep
@@ -41,6 +43,22 @@
)
iree_py_test(
+ name = "module_utils_test",
+ srcs = [
+ "module_utils.py",
+ "module_utils_test.py",
+ ],
+ python_version = "PY3",
+ tags = [
+ "driver=llvm",
+ "driver=vmla",
+ ],
+ deps = INTREE_TENSORFLOW_PY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+)
+
+iree_py_test(
name = "tf_test_utils_test",
srcs = [
"tf_test_utils.py",
@@ -59,8 +77,19 @@
"tf_utils_test.py",
],
python_version = "PY3",
+ deps = INTREE_TENSORFLOW_PY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+)
+
+iree_py_test(
+ name = "trace_utils_test",
+ srcs = [
+ "trace_utils.py",
+ "trace_utils_test.py",
+ ],
+ python_version = "PY3",
tags = [
- "driver=llvm",
"driver=vmla",
],
deps = INTREE_TENSORFLOW_PY_DEPS + [
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/module_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/module_utils.py
new file mode 100644
index 0000000..fb55501
--- /dev/null
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/module_utils.py
@@ -0,0 +1,982 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utilities for compiling 'tf.Module's."""
+
+import collections
+import os
+import tempfile
+from typing import Any, Callable, Dict, Sequence, Set, Tuple, Type, Union
+
+from absl import flags
+from absl import logging
+import numpy as np
+from pyiree import rt
+from pyiree.tf import compiler
+from pyiree.tf.support import tf_utils
+import tensorflow.compat.v2 as tf
+
+flags.DEFINE_bool(
+ "capture_crash_reproducer", True,
+ "Captures MLIR crash reproducers in the artifacts directory for crashes "
+ "and suppresses C++ stack traces.")
+
+FLAGS = flags.FLAGS
+
+
+def _running_bazel_test() -> bool:
+ # Bazel guarantees that TEST_TMPDIR is set when `bazel test` is running.
+ return "TEST_TMPDIR" in os.environ
+
+
+def _setup_mlir_crash_reproducer(
+ function: Any, # pytype doesn't support arbitrary Callable[*args, **kwargs]
+ artifacts_dir: str,
+ backend_id: str,
+) -> Any: # Callable[Any, Any]
+  """Wraps `function` so that an MLIR crash reproducer is saved if it crashes.
+
+  Writes to `artifacts_dir/reproducer__{backend_id}.mlir` if a crash occurs.
+
+ Args:
+ function: The callable to decorate.
+ artifacts_dir: The directory to write the reproducer to.
+    backend_id: The unique backend name to use when writing the reproducer.
+
+ Returns:
+ A function with the same API as the passed function.
+ """
+
+ def decorator(*args, **kwargs):
+ # Set up a crash reproducer for debugging.
+ if artifacts_dir is not None:
+ compiler.Context.default_crash_reproducer_path = os.path.join(
+ artifacts_dir, f"reproducer__{backend_id}.mlir")
+ try:
+ results = function(*args, **kwargs)
+ except Exception: # pylint: disable=broad-except
+ # Disable the crash reproducer (to avoid inadvertently overwriting it).
+ if artifacts_dir is not None:
+ compiler.Context.default_crash_reproducer_path = None
+ raise
+ return results
+
+ return decorator
+
+
+def _incrementally_lower_compiler_module(
+ compiler_module: compiler.Module,
+ backend_info: "BackendInfo",
+ artifacts_dir: str,
+) -> Tuple[compiler.binding.OpaqueBlob, Union[str, None]]:
+  """Lowers an MLIR compiler module incrementally and saves its outputs.
+
+ If artifacts_dir is provided then the following artifacts will be saved:
+ tf_input.mlir:
+ MLIR for the module in TF's input dialect.
+ iree_input.mlir:
+ The MLIR above translated to IREE via compiler.TF_IMPORT_PASS_PIPELINE.
+ backend_id/compiled.vmfb:
+ A VM FlatBuffer compiled to the target backend from the IREE MLIR above.
+
+ Args:
+ compiler_module: A compiler.Module to lower.
+ backend_info: BackendInfo with the details for lowering compiler_module to
+ IREE.
+ artifacts_dir: An optional string pointing to where compilation artifacts
+ should be saved. No compilation artifacts will be saved if this is not
+ provided.
+ """
+ if artifacts_dir is not None:
+ os.makedirs(artifacts_dir, exist_ok=True)
+ tf_mlir_path = os.path.join(artifacts_dir, "tf_input.mlir")
+ logging.info("Saving raw TF input MLIR to: %s", tf_mlir_path)
+ with open(tf_mlir_path, "w") as f:
+ f.write(compiler_module.to_asm())
+
+ # Manually run the passes that tf_module_to_compiler_module usually would.
+ compiler_module.run_pass_pipeline(compiler.TF_IMPORT_PASS_PIPELINE)
+
+ if artifacts_dir is not None:
+ iree_mlir_path = os.path.join(artifacts_dir, "iree_input.mlir")
+ logging.info("Saving IREE input MLIR to: %s", iree_mlir_path)
+ with open(iree_mlir_path, "w") as f:
+ f.write(compiler_module.to_asm())
+
+ compiled_module = compiler_module.compile(
+ target_backends=backend_info.compiler_targets)
+
+ compiled_path = None
+ if artifacts_dir is not None:
+ backend_dir = os.path.join(artifacts_dir, backend_info.backend_id)
+ os.makedirs(backend_dir, exist_ok=True)
+ compiled_path = os.path.join(backend_dir, "compiled.vmfb")
+ logging.info("Saving compiled IREE module to: %s", compiled_path)
+ with open(compiled_path, "wb") as f:
+ f.write(compiled_module)
+ return compiled_module, compiled_path
+
+
+def _incrementally_compile_tf_module(
+ module: Type[tf.Module],
+ backend_info: "BackendInfo",
+ exported_names: Sequence[str] = (),
+ artifacts_dir: str = None,
+) -> Tuple[compiler.binding.OpaqueBlob, Union[str, None]]:
+ """Compile a TensorFlow tf.Module and optionally save compilation artifacts.
+
+ The module blob this creates is not callable. See IreeCompiledModule for an
+ API that returns a module that can be called without any further steps.
+
+ See _incrementally_lower_compiler_module's docstring for details about which
+ artifacts will be saved.
+
+ Args:
+ module: A tf.Module.
+ backend_info: BackendInfo with the details for compiling this module.
+ exported_names: Optional sequence representing the exported names to keep.
+ artifacts_dir: An optional string pointing to where compilation artifacts
+ should be saved. No compilation artifacts will be saved if this is not
+ provided.
+
+ Returns:
+ A compiled IREE module blob and the path to the compiled VM FlatBuffer if
+ artifacts_dir is provided.
+ """
+
+ def _compile_module(module, backend_info, exported_names, artifacts_dir):
+ compiler_module = compiler.tf_module_to_compiler_module(module,
+ exported_names,
+ pass_pipeline=())
+ return _incrementally_lower_compiler_module(compiler_module, backend_info,
+ artifacts_dir)
+
+ # Avoid the crash reproducer under tests or if the flag is false.
+ # Developers can run tests outside of the test runner (e.g. `bazel run`) to
+ # use the crash reproducer.
+ if (FLAGS.capture_crash_reproducer and not _running_bazel_test()):
+ _compile_module = _setup_mlir_crash_reproducer(_compile_module,
+ artifacts_dir,
+ backend_info.backend_id)
+ return _compile_module(module, backend_info, exported_names, artifacts_dir)
+
+
+def _incrementally_compile_tf_signature_def_saved_model(
+    saved_model_dir: str, saved_model_tags: Set[str],
+    backend_info: "BackendInfo", exported_name: str, artifacts_dir: str):
+  """Compiles one signature of a SignatureDef SavedModel for backend_info.
+
+  The returned module blob is not callable by itself; IreeCompiledModule offers
+  an API that yields a callable module. The docstring of
+  _incrementally_lower_compiler_module describes which artifacts get saved.
+
+  Args:
+    saved_model_dir: Directory of the saved model.
+    saved_model_tags: Optional set of tags to use when loading the model.
+    backend_info: BackendInfo with the details for compiling the saved model.
+    exported_name: Name of the signature on the saved model to compile.
+    artifacts_dir: Optional directory in which to save compilation artifacts.
+      Nothing is saved when this is not provided.
+
+  Returns:
+    A compiled IREE module blob and the path to the compiled VM FlatBuffer if
+    artifacts_dir is provided.
+  """
+
+  def _compile_module(saved_model_dir, saved_model_tags, backend_info,
+                      exported_name, artifacts_dir):
+    # Import the requested signature as raw TF input MLIR, then lower it.
+    imported = compiler.tf_signature_def_saved_model_to_compiler_module(
+        saved_model_dir, saved_model_tags, [exported_name], pass_pipeline=())
+    return _incrementally_lower_compiler_module(imported, backend_info,
+                                                artifacts_dir)
+
+  # The crash reproducer is only installed when the flag requests it and we are
+  # not running under the test runner. Running the test outside of it (e.g.
+  # `bazel run`) still gets the reproducer.
+  use_reproducer = FLAGS.capture_crash_reproducer and not _running_bazel_test()
+  if use_reproducer:
+    _compile_module = _setup_mlir_crash_reproducer(_compile_module,
+                                                   artifacts_dir,
+                                                   backend_info.backend_id)
+  return _compile_module(saved_model_dir, saved_model_tags, backend_info,
+                         exported_name, artifacts_dir)
+
+
+class _FunctionWrapper(object):
+  """Interface shared by the callable wrappers defined in this module."""
+
+  def __call__(self, *args, **kwargs):
+    """Subclasses override this to invoke the function they wrap."""
+    raise NotImplementedError()
+
+  def get_serialized_values(self) -> Tuple[Tuple[str], Tuple[str]]:
+    """Placeholder matching _IreeFunctionWrapper's API: two 1-tuples of ''."""
+    empty = ("",)
+    return empty, empty
+
+
+class CompiledModule(object):
+  """Common interface for the TF, TFLite and IREE compiled modules."""
+
+  def __init__(
+      self,
+      module_name: str,
+      backend_info: "BackendInfo",
+      compiled_paths: Union[Dict[str, str], None],
+  ):
+    """Stores the attributes every compiled module has; not useful alone.
+
+    Args:
+      module_name: Name for this compiled module. Usually the name of the
+        tf.Module subclass or instance that is compiled.
+      backend_info: BackendInfo with the details about compiling this module.
+      compiled_paths: Maps compiled method names to the file paths of their
+        serialized representations, or None.
+    """
+    self.compiled_paths = compiled_paths
+    self.backend_info = backend_info
+    self.module_name = module_name
+
+  def reinitialize(self):
+    """Resets all stateful variables; subclasses must implement this."""
+    raise NotImplementedError()
+
+  @classmethod
+  def create_from_class(cls,
+                        module_class: Type[tf.Module],
+                        backend_info: "BackendInfo",
+                        exported_names: Sequence[str] = (),
+                        artifacts_dir: str = None):
+    """Compiles a tf.Module subclass for the backend in backend_info.
+
+    Args:
+      module_class: The tf.Module subclass to compile.
+      backend_info: BackendInfo with the details for compiling this module.
+      exported_names: Optional sequence of the exported names to keep.
+      artifacts_dir: Optional directory in which to save compilation artifacts.
+        Nothing is saved when this is not provided.
+    """
+    raise NotImplementedError()
+
+  @classmethod
+  def create_from_instance(cls,
+                           module_instance: tf.Module,
+                           backend_info: "BackendInfo",
+                           exported_names: Sequence[str] = (),
+                           artifacts_dir: str = None):
+    """Compiles a tf.Module instance for the backend in backend_info.
+
+    Only IreeCompiledModule implements this.
+
+    Args:
+      module_instance: The tf.Module instance to compile.
+      backend_info: BackendInfo with the details for compiling module to IREE.
+      exported_names: Optional sequence of the exported names to keep.
+      artifacts_dir: Optional directory in which to save compilation artifacts.
+        Nothing is saved when this is not provided.
+    """
+    raise NotImplementedError()
+
+  @classmethod
+  def create_from_signature_def_saved_model(cls,
+                                            saved_model_dir: str,
+                                            saved_model_tags: Set[str],
+                                            module_name: str,
+                                            backend_info: "BackendInfo",
+                                            exported_name: str,
+                                            input_names: Sequence[str],
+                                            output_names: Sequence[str],
+                                            artifacts_dir: str = None):
+    """Compiles a SignatureDef SavedModel for the backend in backend_info.
+
+    Args:
+      saved_model_dir: Directory of the saved model.
+      saved_model_tags: Optional set of tags to use when loading the model.
+      module_name: A name for this compiled module.
+      backend_info: BackendInfo with the details for compiling the saved model.
+      exported_name: Name of the signature on the saved model to compile.
+      input_names: A sequence of kwargs to feed to the saved model.
+      output_names: A sequence of named outputs to extract from the saved
+        model.
+      artifacts_dir: Optional directory in which to save compilation artifacts.
+        Nothing is saved when this is not provided.
+    """
+    raise NotImplementedError()
+
+  def __getattr__(self, attr: str) -> _FunctionWrapper:
+    # Subclasses resolve `attr` to a callable wrapper around a compiled
+    # function; the base class has nothing to resolve.
+    raise NotImplementedError()
+
+  def iree_serializable(self):
+    """False here; IreeCompiledModule overrides this."""
+    return False
+
+  def tflite_serializable(self):
+    """False here; TfLiteCompiledModule overrides this."""
+    return False
+
+
+class _IreeFunctionWrapper(_FunctionWrapper):
+  """Callable adapter around a function bound in an IREE runtime context."""
+
+  def __init__(self, context: rt.SystemContext, f: rt.system_api.BoundFunction):
+    # NOTE(review): the context is stored but never read in this class -
+    # presumably to keep it alive alongside the bound function; confirm.
+    self._context = context
+    self._f = f
+
+  def __call__(self, *args, **kwargs):
+    bound_function = self._f
+    return bound_function(*args, **kwargs)
+
+  def get_serialized_values(self) -> Tuple[Tuple[str], Tuple[str]]:
+    """Returns the wrapped function's cxx serialized inputs and outputs."""
+    bound_function = self._f
+    return bound_function.get_serialized_values()
+
+
+class IreeCompiledModule(CompiledModule):
+  """Iree compiled module."""
+
+  def __init__(
+      self,
+      module_name: str,
+      backend_info: "BackendInfo",
+      compiled_paths: Dict[str, str],
+      vm_module: rt.VmModule,
+      config: rt.Config,
+  ):
+    """Base constructor - Use one of the named constructors instead.
+
+    Args:
+      module_name: A name for this compiled module. In most cases this will be
+        the name of the tf.Module subclass or instance that is compiled.
+      backend_info: BackendInfo with the details about compiling this module.
+      compiled_paths: A dictionary mapping compiled method names to file paths
+        corresponding to their serialized representations.
+      vm_module: A rt.VmModule containing compilation info to wrap.
+      config: A rt.Config containing compilation info to wrap.
+    """
+    super().__init__(module_name, backend_info, compiled_paths)
+    self._vm_module = vm_module
+    self._config = config
+    self.reinitialize()
+
+  @classmethod
+  def _create_from_blob(cls, module_name: str, backend_info: "BackendInfo",
+                        module_blob, compiled_path):
+    """Builds an instance from the (module blob, compiled path) compile result.
+
+    Args:
+      module_name: A name for the compiled module.
+      backend_info: BackendInfo whose `driver` selects the runtime config.
+      module_blob: The compiled module blob to load as a VM module.
+      compiled_path: Path the compiled module was saved to, or None if it was
+        not saved.
+    """
+    vm_module = rt.VmModule.from_flatbuffer(module_blob)
+    config = rt.Config(driver_name=backend_info.driver)
+
+    compiled_paths = None
+    if compiled_path is not None:
+      # IREE bundles every compiled method into the same compiled module, so
+      # every method name maps to the same path.
+      compiled_paths = collections.defaultdict(lambda: compiled_path)
+
+    return cls(module_name, backend_info, compiled_paths, vm_module, config)
+
+  @classmethod
+  def create_from_class(cls,
+                        module_class: Type[tf.Module],
+                        backend_info: "BackendInfo",
+                        exported_names: Sequence[str] = (),
+                        artifacts_dir: str = None):
+    """Compile a tf.Module subclass to the target backend in backend_info.
+
+    Args:
+      module_class: The tf.Module subclass to compile.
+      backend_info: BackendInfo with the details for compiling module to IREE.
+      exported_names: Optional sequence representing the exported names to keep.
+      artifacts_dir: An optional string pointing to where compilation artifacts
+        should be saved. No compilation artifacts will be saved if this is not
+        provided.
+    """
+    # Seed before instantiating, since the module's __init__ runs here.
+    tf_utils.set_random_seed()
+    module_instance = module_class()
+    return cls.create_from_instance(module_instance, backend_info,
+                                    exported_names, artifacts_dir)
+
+  @classmethod
+  def create_from_instance(cls,
+                           module_instance: tf.Module,
+                           backend_info: "BackendInfo",
+                           exported_names: Sequence[str] = (),
+                           artifacts_dir: str = None):
+    """Compile a tf.Module instance to the target backend in backend_info.
+
+    Args:
+      module_instance: The tf.Module instance to compile.
+      backend_info: BackendInfo with the details for compiling module to IREE.
+      exported_names: Optional sequence representing the exported names to keep.
+      artifacts_dir: An optional string pointing to where compilation artifacts
+        should be saved. No compilation artifacts will be saved if this is not
+        provided.
+    """
+    module_blob, compiled_path = _incrementally_compile_tf_module(
+        module=module_instance,
+        backend_info=backend_info,
+        exported_names=exported_names,
+        artifacts_dir=artifacts_dir)
+    module_name = type(module_instance).__name__
+    return cls._create_from_blob(module_name, backend_info, module_blob,
+                                 compiled_path)
+
+  @classmethod
+  def create_from_signature_def_saved_model(cls,
+                                            saved_model_dir: str,
+                                            saved_model_tags: Set[str],
+                                            module_name: str,
+                                            backend_info: "BackendInfo",
+                                            exported_name: str,
+                                            input_names: Sequence[str],
+                                            output_names: Sequence[str],
+                                            artifacts_dir: str = None):
+    """Compile a SignatureDef SavedModel to the target backend in backend_info.
+
+    Args:
+      saved_model_dir: Directory of the saved model.
+      saved_model_tags: Optional set of tags to use when loading the model.
+      module_name: A name for this compiled module.
+      backend_info: BackendInfo with the details for compiling the saved model.
+      exported_name: A str representing the signature on the saved model to
+        compile.
+      input_names: A sequence of kwargs to feed to the saved model. Unused
+        here.
+      output_names: A sequence of named outputs to extract from the saved
+        model. Unused here.
+      artifacts_dir: An optional string pointing to where compilation artifacts
+        should be saved. No compilation artifacts will be saved if this is not
+        provided.
+    """
+    del input_names  # Unused.
+    del output_names  # Unused.
+    module_blob, compiled_path = (
+        _incrementally_compile_tf_signature_def_saved_model(
+            saved_model_dir, saved_model_tags, backend_info, exported_name,
+            artifacts_dir))
+    return cls._create_from_blob(module_name, backend_info, module_blob,
+                                 compiled_path)
+
+  def reinitialize(self):
+    """Reinitializes all stateful variables."""
+    # set_random_seed is not needed here because the model_class.__init__ is not
+    # called.
+    self._context = rt.SystemContext(modules=[self._vm_module],
+                                     config=self._config)
+
+  def __getattr__(self, attr: str) -> _IreeFunctionWrapper:
+    # __getattr__ only runs after normal lookup has failed, so being asked for
+    # one of the attributes read below means it has not been assigned yet.
+    # Raise AttributeError rather than recursing without bound.
+    if attr in ("_context", "_vm_module"):
+      raise AttributeError(attr)
+    # Try to resolve it as a function.
+    m = self._context.modules[self._vm_module.name]
+    f = m[attr]
+    return _IreeFunctionWrapper(self._context, f)
+
+  def iree_serializable(self) -> bool:
+    """True when the compiled module was saved to disk."""
+    return self.compiled_paths is not None
+
+
+class _TfFunctionWrapper(_FunctionWrapper):
+  """Makes a TF function return numpy-normalized results."""
+
+  def __init__(self, f: Callable[..., Any]):
+    self._f = f
+
+  def __call__(self, *args, **kwargs):
+    # Inbound args are passed through untouched; TensorFlow will auto-convert
+    # them. Only the results are converted.
+    return tf_utils.convert_to_numpy(self._f(*args, **kwargs))
+
+
+def _convert_inputs_to_tensors(function):
+  """Wraps function so each arg and kwarg value is converted to a tensor."""
+
+  def decorator(*args, **kwargs):
+    tensor_args = [tf.convert_to_tensor(arg) for arg in args]
+    tensor_kwargs = {}
+    for key, value in kwargs.items():
+      tensor_kwargs[key] = tf.convert_to_tensor(value)
+    return function(*tensor_args, **tensor_kwargs)
+
+  return decorator
+
+
+class SignatureDefSavedModelWrapper(object):
+  """Wraps a SavedModel to imitate a tf.Module with a method 'exported_name'."""
+
+  def __init__(self, saved_model_dir: str, saved_model_tags: Set[str],
+               exported_name: str):
+    self.saved_model = tf.saved_model.load(saved_model_dir,
+                                           tags=saved_model_tags)
+    # Expose the requested signature as an attribute named after it, with its
+    # inputs converted to tensors on the way in.
+    signature = self.saved_model.signatures[exported_name]
+    setattr(self, exported_name, _convert_inputs_to_tensors(signature))
+
+
+class TfCompiledModule(CompiledModule):
+  """TensorFlow 'compiled' module.
+
+  This facade exists to provide a complementary API to IreeCompiledModule and
+  normalize TensorFlow's output to Numpy.
+  """
+
+  def __init__(
+      self,
+      module_name: str,
+      backend_info: "BackendInfo",
+      constructor: Callable[[], tf.Module],
+      exported_names: Sequence[str],
+  ):
+    """Base constructor - Use one of the named constructors instead.
+
+    Args:
+      module_name: A name for this compiled module. In most cases this will be
+        the name of the tf.Module subclass or instance that is compiled.
+      backend_info: BackendInfo with the details about compiling this module.
+      constructor: A callable (class or function) which returns the tf.Module
+        subclass instance to wrap.
+      exported_names: an optional iterable of strings representing which of the
+        tf.Module subclass instance's functions should be callable. If
+        exported_names is empty then all functions will be callable.
+    """
+    super().__init__(module_name, backend_info, compiled_paths=None)
+    self._constructor = constructor
+    self._exported_names = exported_names
+    self.reinitialize()
+
+  @classmethod
+  def create_from_class(cls,
+                        module_class: Type[tf.Module],
+                        backend_info: "BackendInfo",
+                        exported_names: Sequence[str] = (),
+                        artifacts_dir: str = None):
+    """Compile a tf.Module subclass to the target backend in backend_info.
+
+    Args:
+      module_class: The tf.Module subclass to compile.
+      backend_info: BackendInfo with the details for compiling this module.
+      exported_names: Optional sequence representing the exported names to keep.
+      artifacts_dir: Accepted for API parity with the other compiled modules;
+        not used by this method.
+    """
+    module_name = module_class.__name__
+    constructor = module_class
+    return cls(module_name, backend_info, constructor, exported_names)
+
+  @classmethod
+  def create_from_signature_def_saved_model(cls,
+                                            saved_model_dir: str,
+                                            saved_model_tags: Set[str],
+                                            module_name: str,
+                                            backend_info: "BackendInfo",
+                                            exported_name: str,
+                                            input_names: Sequence[str],
+                                            output_names: Sequence[str],
+                                            artifacts_dir: str = None):
+    """Compile a SignatureDef SavedModel to the target backend in backend_info.
+
+    Args:
+      saved_model_dir: Directory of the saved model.
+      saved_model_tags: Optional set of tags to use when loading the model.
+      module_name: A name for this compiled module.
+      backend_info: BackendInfo with the details for compiling the saved model.
+      exported_name: A str representing the signature on the saved model to
+        compile.
+      input_names: A sequence of kwargs to feed to the saved model. Not used
+        by this method.
+      output_names: A sequence of named outputs to extract from the saved
+        model. Not used by this method.
+      artifacts_dir: Accepted for API parity with the other compiled modules;
+        not used by this method.
+    """
+    # The SavedModel is reloaded on every reinitialize() via this constructor.
+    constructor = lambda: SignatureDefSavedModelWrapper(
+        saved_model_dir, saved_model_tags, exported_name)
+    return cls(module_name, backend_info, constructor, [exported_name])
+
+  def reinitialize(self):
+    """Reinitializes all stateful variables."""
+    # Seed first: the constructor below runs the module's __init__.
+    tf_utils.set_random_seed()
+    self._tf_module = self._constructor()
+
+  def __getattr__(self, attr: str) -> _TfFunctionWrapper:
+    # __getattr__ only runs after normal lookup has failed, so being asked for
+    # one of the attributes read below means it has not been assigned yet.
+    # Raise AttributeError rather than recursing without bound.
+    if attr in ("_tf_module", "_exported_names"):
+      raise AttributeError(attr)
+    # Try to resolve it as a function.
+    exported = not self._exported_names or attr in self._exported_names
+    if not hasattr(self._tf_module, attr) or not exported:
+      raise AttributeError(f"The TensorFlow module does not have attr '{attr}'")
+    f = getattr(self._tf_module, attr)
+    # Use callable() instead of the attribute's truth value: evaluating the
+    # truthiness of an array-like attribute can itself raise.
+    if not callable(f):
+      raise AttributeError(
+          f"The TensorFlow module does not have a callable attr '{attr}'")
+    return _TfFunctionWrapper(f)
+
+
+def _get_non_inhereted_function_names(cls):
+  """Gets the names in dir(cls) that none of cls's direct bases have.
+
+  Args:
+    cls: The class to inspect.
+
+  Returns:
+    A sorted list of attribute names, so the order is the same on every run.
+  """
+  names = set(dir(cls))
+  for parent in cls.__bases__:
+    names -= set(dir(parent))
+  # A bare list(set) would depend on string hashing and vary between runs.
+  return sorted(names)
+
+
+def _get_concrete_functions(module_class: Type[tf.Module],
+                            exported_names: Sequence[str] = ()):
+  """Instantiates module_class and gets concrete functions for its methods.
+
+  Returns:
+    The concrete functions, the names they were taken from, and the module
+    instance they belong to.
+  """
+  if len(exported_names) == 0:
+    # Fall back to every name on 'module_class' that its bases don't have.
+    exported_names = _get_non_inhereted_function_names(module_class)
+  instance = module_class()
+  functions = [
+      getattr(instance, name).get_concrete_function()
+      for name in exported_names
+  ]
+  return functions, exported_names, instance
+
+
+def tf_module_to_tflite_module_bytes(
+    module_class: Type[tf.Module], exported_names: Sequence[str] = ()
+) -> Dict[str, bytes]:
+  """Compiles a tf.Module's methods with TFLite.
+
+  Args:
+    module_class: A tf.Module subclass to compile with TFLite.
+    exported_names: an optional iterable of strings representing which of the
+      module_class's functions should be compiled. If exported_names is empty
+      then all functions will be compiled.
+
+  Returns:
+    A dict mapping method names to compiled TFLite module bytes.
+
+  Raises:
+    RuntimeError: if any method fails to convert. Every method is attempted
+      first so the error lists all failures.
+  """
+  tflite_modules = []
+  methods, method_names, instance = _get_concrete_functions(
+      module_class, exported_names)
+  failed_methods = []
+  for method, method_name in zip(methods, method_names):
+    logging.info("Attempting to convert '%s' to tflite...", method_name)
+    try:
+      converter = tf.lite.TFLiteConverter.from_concrete_functions([method])
+      tflite_module = converter.convert()
+    except Exception as e:
+      logging.error("Failed to convert '%s' to tflite.", method_name)
+      logging.error("TFLite exception: %s", e)
+      failed_methods.append(method_name)
+    else:
+      # Only report success once convert() has actually returned.
+      logging.info("...converted '%s' to tflite.", method_name)
+      tflite_modules.append(tflite_module)
+
+  if failed_methods:
+    raise RuntimeError(
+        f"Failed to convert the following methods to tflite: {failed_methods}")
+
+  # Keep variables alive until TFLite has done the conversion; ConcreteFunctions
+  # themselves only keep weak references to variables.
+  del instance
+  return dict(zip(method_names, tflite_modules))
+
+
+def tf_signature_def_saved_model_to_tflite_module_bytes(
+    saved_model_dir: str,
+    saved_model_tags: Set[str],
+    exported_name: str,
+    input_names: Sequence[str],
+    output_names: Sequence[str],
+) -> Dict[str, bytes]:
+  """Converts one signature of a SignatureDef SavedModel with TFLite.
+
+  Args:
+    saved_model_dir: Directory of the saved model.
+    saved_model_tags: Optional set of tags to use when loading the model.
+    exported_name: Name of the signature on the saved model to compile.
+    input_names: A sequence of kwargs to feed to the saved model.
+    output_names: A sequence of named outputs to extract from the saved model.
+
+  Returns:
+    A single-entry dict mapping the signature name to the compiled TFLite
+    module bytes.
+  """
+  converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model(
+      saved_model_dir,
+      tag_set=saved_model_tags,
+      signature_key=exported_name,
+      input_arrays=input_names,
+      output_arrays=output_names)
+  return {exported_name: converter.convert()}
+
+
+def tflite_module_bytes_to_tflite_interpreters(
+    tflite_module_bytes: Dict[str, bytes],
+    artifacts_dir: str = None
+) -> Tuple[Dict[str, tf.lite.Interpreter], Union[Dict[str, str], None]]:
+  """Compile a dict of TFLite compiled bytes to TFLite interpreters.
+
+  Args:
+    tflite_module_bytes: A dict mapping method names to compiled TFLite byte
+      strings.
+    artifacts_dir: an optional path to save compilation artifacts to.
+
+  Returns:
+    A dictionary mapping method names to TFLite interpreters and a dictionary
+    mapping method names to compiled tflite graph paths (or None if
+    artifacts_dir is None).
+  """
+  interpreters = dict()
+
+  if artifacts_dir is None:
+    # Nothing has to persist, so load each model straight from its bytes
+    # instead of from a temporary file that is deleted while the interpreter
+    # is still in use.
+    for method_name, tflite_module in tflite_module_bytes.items():
+      interpreters[method_name] = tf.lite.Interpreter(
+          model_content=tflite_module)
+    return interpreters, None
+
+  compiled_paths = dict()
+  tflite_dir = os.path.join(artifacts_dir, "tflite")
+  for method_name, tflite_module in tflite_module_bytes.items():
+    # Save the compiled module, then load the interpreter from the saved file.
+    os.makedirs(tflite_dir, exist_ok=True)
+    tflite_path = os.path.join(tflite_dir, f"{method_name}.tflite")
+    with open(tflite_path, "wb") as f:
+      f.write(tflite_module)
+
+    interpreters[method_name] = tf.lite.Interpreter(tflite_path)
+    compiled_paths[method_name] = tflite_path
+
+  return interpreters, compiled_paths
+
+
+class _TfLiteFunctionWrapper(_FunctionWrapper):
+  """Lets a TFLite interpreter be called like an ordinary python function."""
+
+  def __init__(self, interpreter: tf.lite.Interpreter,
+               output_names: Sequence[str]):
+    self._interpreter = interpreter
+    self._output_names = output_names
+
+  def __call__(self, *args,
+               **kwargs) -> Union[Dict[str, Any], Tuple[Any], np.ndarray]:
+    """Runs the interpreter on positional args or on kwargs, never both.
+
+    Returns:
+      A dict keyed by output name when output_names was given; otherwise the
+      single output, or a tuple of outputs when there are several.
+    """
+    if args and kwargs:
+      raise ValueError("Passing both args and kwargs is not supported by "
+                       "_TfLiteFunctionWrapper")
+
+    # Set up and run the function.
+    interpreter = self._interpreter
+    interpreter.allocate_tensors()
+    input_details = interpreter.get_input_details()
+
+    if args:
+      # Keras models that take a list of inputs arrive here as one list
+      # argument; unpack it so each element is fed to its own input tensor.
+      if len(args) == 1 and isinstance(args[0], list):
+        args = args[0]
+      for arg, detail in zip(args, input_details):
+        interpreter.set_tensor(detail["index"], arg)
+    else:
+      for detail in input_details:
+        interpreter.set_tensor(detail["index"], kwargs[detail["name"]])
+
+    interpreter.invoke()
+
+    # Extract the outputs from the TFLite interpreter.
+    named = self._output_names is not None
+    outputs = []
+    for detail in interpreter.get_output_details():
+      value = tf_utils.normalize_numpy(
+          interpreter.get_tensor(detail["index"]))
+      if not named:
+        outputs.append(value)
+        continue
+      name = detail["name"]
+      if name not in self._output_names:
+        raise ValueError(f"Expected '{name}' to be in {self._output_names}")
+      outputs.append([name, value])
+
+    # Process them to match the output of the tf.Module.
+    if named:
+      return dict(outputs)
+    if len(outputs) == 1:
+      return outputs[0]
+    return tuple(outputs)
+
+
+class TfLiteCompiledModule(CompiledModule):
+  """Compiles a tf.Module with TFLite and allows it to be called."""
+
+  def __init__(
+      self,
+      module_name: str,
+      backend_info: "BackendInfo",
+      compiled_paths: Dict[str, str],
+      interpreters: Dict[str, tf.lite.Interpreter],
+      output_names: Sequence[str] = None,
+  ):
+    """Base constructor - Use one of the named constructors instead.
+
+    Args:
+      module_name: A name for this compiled module. In most cases this will be
+        the name of the tf.Module subclass or instance that is compiled.
+      backend_info: BackendInfo with the details about compiling this module.
+      compiled_paths: A dictionary mapping compiled method names to file paths
+        corresponding to their serialized representations.
+      interpreters: A dict of tf.lite.Interpreters to make callable.
+      output_names: Optional sequence of output names. It is handed to every
+        _TfLiteFunctionWrapper this module creates.
+    """
+    super().__init__(module_name, backend_info, compiled_paths)
+    self._interpreters = interpreters
+    self._output_names = output_names
+
+  @classmethod
+  def create_from_class(cls,
+                        module_class: Type[tf.Module],
+                        backend_info: "BackendInfo",
+                        exported_names: Sequence[str] = (),
+                        artifacts_dir: str = None):
+    """Compile a tf.Module subclass to the target backend in backend_info.
+
+    Args:
+      module_class: The tf.Module subclass to compile.
+      backend_info: BackendInfo with the details for compiling this module.
+      exported_names: Optional sequence representing the exported names to keep.
+      artifacts_dir: An optional string pointing to where compilation artifacts
+        should be saved. No compilation artifacts will be saved if this is not
+        provided.
+    """
+    # Seed before converting, since conversion instantiates module_class.
+    tf_utils.set_random_seed()
+    tflite_module_bytes = tf_module_to_tflite_module_bytes(
+        module_class, exported_names)
+    interpreters, compiled_paths = tflite_module_bytes_to_tflite_interpreters(
+        tflite_module_bytes, artifacts_dir)
+    module_name = module_class.__name__
+    return cls(module_name, backend_info, compiled_paths, interpreters)
+
+  @classmethod
+  def create_from_signature_def_saved_model(cls,
+                                            saved_model_dir: str,
+                                            saved_model_tags: Set[str],
+                                            module_name: str,
+                                            backend_info: "BackendInfo",
+                                            exported_name: str,
+                                            input_names: Sequence[str],
+                                            output_names: Sequence[str],
+                                            artifacts_dir: str = None):
+    """Compile a SignatureDef SavedModel to the target backend in backend_info.
+
+    Args:
+      saved_model_dir: Directory of the saved model.
+      saved_model_tags: Optional set of tags to use when loading the model.
+      module_name: A name for this compiled module.
+      backend_info: BackendInfo with the details for compiling the saved model.
+      exported_name: A str representing the signature on the saved model to
+        compile.
+      input_names: A sequence of kwargs to feed to the saved model.
+      output_names: A sequence of named outputs to extract from the saved model.
+      artifacts_dir: An optional string pointing to where compilation artifacts
+        should be saved. No compilation artifacts will be saved if this is not
+        provided.
+    """
+    tflite_module_bytes = tf_signature_def_saved_model_to_tflite_module_bytes(
+        saved_model_dir, saved_model_tags, exported_name, input_names,
+        output_names)
+    interpreters, compiled_paths = tflite_module_bytes_to_tflite_interpreters(
+        tflite_module_bytes, artifacts_dir)
+    return cls(module_name, backend_info, compiled_paths, interpreters,
+               output_names)
+
+  def reinitialize(self):
+    """Reinitializes all stateful variables."""
+    # This is a noop because TFLite (mostly) doesn't support stateful modules.
+    pass
+
+  def __getattr__(self, attr: str) -> _TfLiteFunctionWrapper:
+    # __getattr__ only runs after normal lookup has failed, so being asked for
+    # `_interpreters` means it has not been assigned yet. Raise AttributeError
+    # rather than recursing without bound.
+    if attr == "_interpreters":
+      raise AttributeError(attr)
+    # Try to resolve it as an interpreter.
+    if attr not in self._interpreters:
+      raise AttributeError(
+          f"The TFLite module does not have an interpreter for '{attr}'")
+    return _TfLiteFunctionWrapper(self._interpreters[attr], self._output_names)
+
+  def tflite_serializable(self) -> bool:
+    """True when the compiled TFLite modules were saved to disk."""
+    return self.compiled_paths is not None
+
+
+class BackendInfo:
+  """Holds the compilation details for one named backend."""
+
+  # Per-backend settings: the CompiledModule subclass to build, plus the
+  # `driver` and `compiler_targets` values (None for the non-IREE backends).
+  _name_to_info = {
+      "tf": {
+          "compiled_module_class": TfCompiledModule,
+          "driver": None,
+          "compiler_targets": None,
+      },
+      "tflite": {
+          "compiled_module_class": TfLiteCompiledModule,
+          "driver": None,
+          "compiler_targets": None,
+      },
+      "iree_vmla": {
+          "compiled_module_class": IreeCompiledModule,
+          "driver": "vmla",
+          "compiler_targets": ["vmla"]
+      },
+      "iree_vulkan": {
+          "compiled_module_class": IreeCompiledModule,
+          "driver": "vulkan",
+          "compiler_targets": ["vulkan-*"]
+      },
+  }
+
+  def __init__(self, backend_name: str, backend_id: str = None):
+    """Looks up the compilation details for backend_name.
+
+    Args:
+      backend_name: a str specifying which backend to use. Should be one of
+        'tf', 'tflite', 'iree_vmla', 'iree_vulkan'.
+      backend_id: an optional str specifying what name to use when saving
+        compiled artifacts. Must satisfy `backend_id.startswith(backend_name)`.
+        Defaults to backend_name.
+
+    Raises:
+      KeyError: if backend_name is not one of ['tf', 'tflite', 'iree_vmla',
+        'iree_vulkan'].
+      ValueError: if backend_id doesn't start with backend_name.
+    """
+    known_backends = self._name_to_info
+    if backend_name not in known_backends:
+      raise KeyError(
+          "Expected backend_name to be one of "
+          f"{list(known_backends.keys())} but got '{backend_name}'.")
+    if backend_id is None:
+      backend_id = backend_name
+    elif not backend_id.startswith(backend_name):
+      raise ValueError(f"Expected backend_id to start with '{backend_name}' "
+                       f"but got '{backend_id}'.")
+
+    self.backend_name = backend_name
+    self.backend_id = backend_id
+
+    info = known_backends[backend_name]
+    self._compiled_module_class = info["compiled_module_class"]
+    self.driver = info["driver"]
+    self.compiler_targets = info["compiler_targets"]
+
+  def compile_from_class(self,
+                         module_class: Type[tf.Module],
+                         exported_names: Sequence[str] = (),
+                         artifacts_dir: str = None) -> CompiledModule:
+    """Creates a 'CompiledModule' for this backend."""
+    module_cls = self._compiled_module_class
+    return module_cls.create_from_class(module_class, self, exported_names,
+                                        artifacts_dir)
+
+  def compile_signature_def_saved_model(
+      self,
+      saved_model_dir: str,
+      saved_model_tags: Set[str],
+      module_name: str,
+      exported_name: str,
+      input_names: Sequence[str],
+      output_names: Sequence[str],
+      artifacts_dir: str = None) -> CompiledModule:
+    """Creates a 'CompiledModule' for this backend from a SavedModel."""
+    module_cls = self._compiled_module_class
+    return module_cls.create_from_signature_def_saved_model(
+        saved_model_dir, saved_model_tags, module_name, self, exported_name,
+        input_names, output_names, artifacts_dir)
+
+  @classmethod
+  def get_all_backends(cls) -> Sequence["BackendInfo"]:
+    """Returns a list of all BackendInfo configurations."""
+    return list(map(BackendInfo, cls._name_to_info))
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/module_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/module_utils_test.py
new file mode 100644
index 0000000..8baa6bc
--- /dev/null
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/module_utils_test.py
@@ -0,0 +1,114 @@
+# Lint as: python3
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for pyiree.tf.support.module_utils."""
+
+import os
+import tempfile
+
+from absl import logging
+from absl.testing import parameterized
+from pyiree.tf.support import module_utils
+import tensorflow as tf
+
+
+class ConstantModule(tf.Module):
+ """Test module with one stateless method that returns a constant."""
+
+ @tf.function(input_signature=[])
+ def meaning(self):
+ return tf.constant([42.])
+
+
+class StatefulCountingModule(tf.Module):
+ """Test module holding a counter variable that starts at [0.]."""
+
+ def __init__(self):
+ self.count = tf.Variable([0.])
+
+ @tf.function(input_signature=[])
+ def increment(self):
+ # Adds one to the counter in place; nothing is returned.
+ self.count.assign_add(tf.constant([1.]))
+
+ @tf.function(input_signature=[])
+ def get_count(self):
+ return self.count
+
+
+class RandomInitModule(tf.Module):
+ """Test module whose variable is initialized from tf.random.uniform."""
+
+ def __init__(self):
+ self.value = tf.Variable(tf.random.uniform([1]))
+
+ @tf.function(input_signature=[])
+ def get(self):
+ return self.value
+
+
+class UtilsTests(tf.test.TestCase, parameterized.TestCase):
+ """Tests artifact saving, state reset and seeded initialization."""
+
+ def test_artifact_saving(self):
+ """Compiling with artifacts_dir set leaves the expected files on disk."""
+ backend_info = module_utils.BackendInfo('iree_vmla')
+ with tempfile.TemporaryDirectory() as artifacts_dir:
+ tf_module = ConstantModule()
+ # The first returned value (the module blob) is not inspected here.
+ iree_module_utils, compiled_path = (
+ module_utils._incrementally_compile_tf_module(
+ tf_module, backend_info=backend_info,
+ artifacts_dir=artifacts_dir))
+
+ artifacts_to_check = [
+ 'tf_input.mlir',
+ 'iree_input.mlir',
+ compiled_path,
+ ]
+ for artifact in artifacts_to_check:
+ # os.path.join returns its second argument unchanged when that argument
+ # is absolute, so this also works if compiled_path is a full path.
+ artifact_path = os.path.join(artifacts_dir, artifact)
+ logging.info('Checking path: %s', artifact_path)
+ self.assertTrue(os.path.exists(artifact_path))
+
+ @parameterized.named_parameters([
+ ('tensorflow', 'tf'),
+ ('vmla', 'iree_vmla'),
+ ])
+ def test_unaltered_state(self, backend_name):
+ """reinitialize() returns the counter to its starting value."""
+ backend_info = module_utils.BackendInfo(backend_name)
+ module = backend_info.compile_from_class(StatefulCountingModule)
+
+ # Test that incrementing works properly.
+ self.assertEqual([0.], module.get_count())
+ module.increment()
+ self.assertEqual([1.], module.get_count())
+
+ module.reinitialize()
+ # Test reinitialization.
+ self.assertEqual([0.], module.get_count())
+
+ @parameterized.named_parameters([
+ ('tensorflow', 'tf'),
+ ('vmla', 'iree_vmla'),
+ ])
+ def test_random_initialization(self, backend_name):
+ """Random initial values match across compiles and reinitialization."""
+ backend_info = module_utils.BackendInfo(backend_name)
+
+ # Test compilation is the same.
+ module_1 = backend_info.compile_from_class(RandomInitModule)
+ module_2 = backend_info.compile_from_class(RandomInitModule)
+ self.assertAllEqual(module_1.get(), module_2.get())
+
+ # Test reinitialization is the same.
+ old_value = module_1.get()
+ module_1.reinitialize()
+ self.assertAllEqual(old_value, module_1.get())
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
index 49eb68a..313be88 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -24,24 +24,23 @@
import collections
import copy
-import glob
-import inspect
+import itertools
import os
-import pickle
-import sys
+import re
import tempfile
-from typing import Any, Callable, Dict, Sequence, Set, Tuple, Type, Union
+from typing import Any, Callable, Dict, List, Sequence, Set, Tuple, Type, Union
from absl import flags
from absl import logging
-import numpy as np
+from pyiree.tf.support import module_utils
from pyiree.tf.support import tf_utils
+from pyiree.tf.support import trace_utils
import tensorflow.compat.v2 as tf
flags.DEFINE_string("reference_backend", "tf",
"The backend to treat as a source of truth.")
-flags.DEFINE_string("target_backends", None,
- "Explicit comma-delimited list of target backends.")
+flags.DEFINE_list("target_backends", None,
+ "Explicit comma-delimited list of target backends.")
flags.DEFINE_string(
"artifacts_dir", None,
"Specifies a directory to dump compilation artifacts and traces to. "
@@ -51,11 +50,9 @@
"Summarize the inputs and outputs of each module trace logged to disk.")
flags.DEFINE_bool("log_all_traces", False,
"Log all traces to logging.info, even if comparison passes.")
-flags.DEFINE_bool(
- "get_saved_model", False,
- "Creates and stores a SavedModel for the tf.Module class to be tested.")
+
FLAGS = flags.FLAGS
-NUMPY_LINEWIDTH = 120
+DEFAULT_INPUT_GENERATOR = tf_utils.uniform
def _setup_artifacts_dir(module_name: str) -> str:
@@ -76,7 +73,7 @@
def _parse_target_backends() -> Tuple[Sequence[str], Sequence[str]]:
"""Decodes --target_backends and creates unique ids for them."""
- backend_names = FLAGS.target_backends.split(",")
+ backend_names = FLAGS.target_backends
backend_to_index = {k: 0 for k in backend_names if backend_names.count(k) > 1}
backend_ids = []
@@ -93,7 +90,7 @@
return backend_names, backend_ids
-def get_target_backends() -> Sequence[tf_utils.BackendInfo]:
+def get_target_backends() -> Sequence[module_utils.BackendInfo]:
"""Gets the BackendInfo instances to compare with the reference backend.
By default all backends in BackendInfo will be used. Specific backends to
@@ -106,503 +103,21 @@
logging.info("Using backends from command line: %s", FLAGS.target_backends)
backend_names, backend_ids = _parse_target_backends()
backends = [
- tf_utils.BackendInfo(backend_name, backend_id)
+ module_utils.BackendInfo(backend_name, backend_id)
for backend_name, backend_id in zip(backend_names, backend_ids)
]
else:
# If no backends are specified, use them all.
- backends = tf_utils.BackendInfo.get_all_backends()
+ backends = module_utils.BackendInfo.get_all_backends()
return backends
-def _indent(input_str: str, indentation: int = 2) -> str:
- """Indents a string by the specified number of spaces, defaulting to 2."""
- spaces = " " * indentation
- lines = input_str.split("\n")
- # Prepend spaces to each non-empty line.
- lines = [f"{spaces}{line}" if len(line) else line for line in lines]
- return "\n".join(lines)
-
-
-def _zfill_width(length: int) -> Union[int, None]:
- return int(np.ceil(np.log10(length))) if length else None
-
-
-class ModuleCall:
-
- def __init__(self,
- method: str,
- inputs: Tuple[Any],
- outputs: Tuple[Any],
- serialized_inputs: Tuple[str],
- serialized_outputs: Tuple[str],
- rtol: float = 1e-6,
- atol: float = 1e-6):
- """Records the details of a call to a CompiledModule."""
- self.method = method
-
- for value in inputs:
- if isinstance(value, tf.Tensor):
- raise TypeError("Expected inputs to be native python types or numpy "
- f"arrays, but got {type(value)}")
-
- # Deepcopy to safegard against mutation.
- self.inputs = copy.deepcopy(inputs)
- if outputs is not None:
- outputs = copy.deepcopy(outputs)
- else:
- outputs = tuple()
- self.outputs = outputs if isinstance(outputs, tuple) else (outputs,)
-
- self.serialized_inputs = serialized_inputs
- self.serialized_outputs = serialized_outputs
-
- self.rtol = rtol
- self.atol = atol
-
- def get_tolerances(self) -> Tuple[float, float]:
- """Gets the floating point tolerances associated with this call."""
- return self.rtol, self.atol
-
- def _get_shape_and_dtype(self, value: Any) -> str:
- if isinstance(value, np.ndarray):
- return tf_utils.get_shape_and_dtype(value, allow_non_mlir_dtype=True)
- else:
- return str(type(value))
-
- def __str__(self):
- prior_printoptions = np.get_printoptions()
- np.set_printoptions(linewidth=NUMPY_LINEWIDTH)
-
- header = f"Method: {self.method}"
- inputs = "\n".join(_indent(str(value)) for value in self.inputs)
- input_shapes = ", ".join(
- self._get_shape_and_dtype(value) for value in self.inputs)
-
- outputs = "\n".join(_indent(str(value)) for value in self.outputs)
- output_shapes = ", ".join(
- self._get_shape_and_dtype(value) for value in self.outputs)
-
- tolerances = _indent(f"rtol={self.rtol}, atol={self.atol}")
- body = (f"Inputs: {input_shapes}\n{inputs}\n"
- f"Outputs: {output_shapes}\n{outputs}"
- f"\nTolerances:\n{tolerances}")
- result = f"{header}\n{_indent(body)}"
-
- np.set_printoptions(**prior_printoptions)
- return result
-
- def serialize(self, call_dir: str) -> None:
- """Stores a serialized copy of this call.
-
- Can be loaded via ModuleCall.load(call_dir)
-
- Args:
- call_dir: str, the path to the directory to serialize this call to.
- """
- os.makedirs(call_dir, exist_ok=True)
-
- metadata = {
- "method": self.method,
- "serialized_inputs": self.serialized_inputs,
- "serialized_outputs": self.serialized_outputs,
- "rtol": self.rtol,
- "atol": self.atol
- }
- with open(os.path.join(call_dir, "metadata.pkl"), "wb") as f:
- pickle.dump(metadata, f)
-
- width = _zfill_width(len(self.inputs))
- for i, value in enumerate(self.inputs):
- path = os.path.join(call_dir, f"input_{str(i).zfill(width)}.pkl")
- with open(path, "wb") as f:
- pickle.dump(value, f)
-
- width = _zfill_width(len(self.outputs))
- for i, value in enumerate(self.outputs):
- path = os.path.join(call_dir, f"output_{str(i).zfill(width)}.pkl")
- with open(path, "wb") as f:
- pickle.dump(value, f)
-
- @staticmethod
- def load(call_dir: str) -> "ModuleCall":
- """Loads and returns a trace serialized with ModuleCall.serialize."""
- with open(os.path.join(call_dir, "metadata.pkl"), "rb") as f:
- kwargs = pickle.load(f)
-
- for result_type in ["input", "output"]:
- key = f"{result_type}s" # inputs or outputs
- kwargs[key] = []
-
- files = glob.glob(os.path.join(call_dir, f"{result_type}_*.pkl"))
- for filename in sorted(files):
- with open(filename, "rb") as f:
- kwargs[key].append(pickle.load(f))
-
- # Convert to tuple to match python's return type for multiple results.
- kwargs[key] = tuple(kwargs[key])
-
- return ModuleCall(**kwargs)
-
-
-class Trace:
- """Stores the inputs and outputs of a series of calls to a module."""
-
- def __init__(self,
- module: Union[tf_utils.CompiledModule, None],
- function: Union[Callable[["TracedModule"], None], None],
- _load_dict: Dict[str, Any] = None):
- """Extracts metadata from module and function and initializes.
-
- Example usage:
- def forward_pass(...):
- ...
- module = IreeCompiledModule(...)
- trace = Trace(module, forward_pass)
- forward_pass(TracedModule(module, trace))
-
- Args:
- module: the module who's outputs this trace will record.
- function: the function that module will be traced on.
- _load_dict: used internally
- """
- if _load_dict is None:
- # Extract metadata from module and function.
- self.module_name = module.module_name
- self.compiled_paths = module.compiled_paths
- self.backend_name = module.backend_info.backend_name
- self.backend_id = module.backend_info.backend_id
- self.backend_driver = module.backend_info.driver
- self.iree_serializable = module.iree_serializable()
- self.tflite_serializable = module.tflite_serializable()
- self.function_name = function.__name__
- self.function_sourcefile = inspect.getsourcefile(function)
- source, start_line = inspect.getsourcelines(function)
- self.function_line_numbers = (start_line, start_line + len(source))
- self.function_source = "".join(source)
-
- self.calls = []
- else:
- self.module_name = _load_dict["module_name"]
- self.compiled_paths = _load_dict["compiled_paths"]
- self.backend_name = _load_dict["backend_name"]
- self.backend_id = _load_dict["backend_id"]
- self.backend_driver = _load_dict["backend_driver"]
- self.iree_serializable = _load_dict["iree_serializable"]
- self.tflite_serializable = _load_dict["tflite_serializable"]
- self.function_name = _load_dict["function_name"]
- self.function_sourcefile = _load_dict["function_sourcefile"]
- self.function_line_numbers = _load_dict["function_line_numbers"]
- self.function_source = _load_dict["function_source"]
- self.calls = _load_dict["calls"]
-
- def __str__(self):
- header = (f"Trace of {self.module_name} compiled to '{self.backend_id}' "
- f"on function '{self.function_name}':")
- # Give each call a number so it's easier to compare between multiple traces.
- calls = [f"{i + 1}. {str(call)}" for i, call in enumerate(self.calls)]
- calls = _indent("\n".join(calls))
- return f"{header}\n{calls}"
-
- def __iter__(self):
- for call in self.calls:
- yield call
-
- @staticmethod
- def compare_traces(ref_trace: "Trace",
- tar_trace: "Trace") -> Tuple[bool, Sequence[str]]:
- traces_match = True
- error_messages = []
-
- # Check that all method invocations match.
- ref_methods = [(call.method, call.rtol, call.atol) for call in ref_trace]
- tar_methods = [(call.method, call.rtol, call.atol) for call in tar_trace]
- if ref_methods != tar_methods:
- # Raise a ValueError instead of returning False since this is an
- # unexpected error.
- raise ValueError(
- "The reference and target traces have different call structures:\n"
- f"Reference: {ref_methods}\nTarget: {tar_methods}")
-
- for ref_call, tar_call in zip(ref_trace, tar_trace):
- logging.info("Comparing calls to '%s'", ref_call.method)
- rtol, atol = ref_call.get_tolerances()
-
- inputs_match, error_message = Trace._check_same(ref_call.inputs,
- tar_call.inputs, rtol,
- atol)
- if not inputs_match:
- error_messages.append(error_message)
- logging.error("Inputs did not match.")
- outputs_match, error_message = Trace._check_same(ref_call.outputs,
- tar_call.outputs, rtol,
- atol)
- if not outputs_match:
- error_messages.append(error_message)
- logging.error("Outputs did not match.")
- calls_match = inputs_match and outputs_match
-
- if not calls_match:
- logging.error("Comparision between '%s' and '%s' failed on method '%s'",
- ref_trace.backend_id, tar_trace.backend_id,
- ref_call.method)
- logging.error("Reference call '%s':\n%s", ref_trace.backend_id,
- ref_call)
- logging.error("Target call '%s':\n%s", tar_trace.backend_id, tar_call)
-
- traces_match = traces_match and calls_match
- return traces_match, error_messages
-
- @staticmethod
- def _check_same(ref: Any, tar: Any, rtol: float,
- atol: float) -> Tuple[bool, Union[str, None]]:
- """Checks that ref and tar have identical datastructures and values."""
- # Check for matching types.
- if not isinstance(tar, type(ref)):
- error = ("Expected ref and tar to have the same type but got "
- f"'{type(ref)}' and '{type(tar)}'")
- logging.error(error)
- return False, error
-
- if ref is None:
- # Nothing to compare (e.g. the called method had no outputs).
- return True, None
-
- # Recursive check for dicts.
- if isinstance(ref, dict):
- if ref.keys() != tar.keys():
- error = ("Expected ref and tar to have the same keys, but got "
- f"'{ref.keys()}' and '{tar.keys()}'")
- logging.error(error)
- return False, error
- # Check that all of the dictionaries' values are the same.
- for key in ref:
- same, error = Trace._check_same(ref[key], tar[key], rtol, atol)
- if not same:
- return same, error
-
- # Recursive check for iterables.
- elif isinstance(ref, list) or isinstance(ref, tuple):
- if len(ref) != len(tar):
- error = ("Expected ref and tar to have the same length, but got "
- f"{len(ref)} and {len(tar)}")
- logging.error(error)
- return False, error
- # Check that all of the iterables' values are the same.
- for i in range(len(ref)):
- same, error = Trace._check_same(ref[i], tar[i], rtol, atol)
- if not same:
- return same, error
-
- # Base check for numpy arrays.
- elif isinstance(ref, np.ndarray):
- if ref.dtype != tar.dtype:
- error = ("Expected ref and tar to have the same dtype, but got "
- f"'{ref.dtype}' and '{tar.dtype}'")
- logging.error(error)
- return False, error
- if ref.size == tar.size == 0:
- return True, None
-
- if np.issubdtype(ref.dtype, np.floating):
- same = np.allclose(ref, tar, rtol=rtol, atol=atol, equal_nan=True)
- abs_diff = np.max(np.abs(ref - tar))
- rel_diff = np.max(np.abs(ref - tar) / np.max(np.abs(tar)))
- diff_string = (f"Max abs diff: {abs_diff:.2e}, atol: {atol:.2e}, "
- f"max relative diff: {rel_diff:.2e}, rtol: {rtol:.2e}")
- if not same:
- error = ("Floating point difference between ref and tar was too "
- f"large. {diff_string}")
- logging.error(error)
- else:
- error = None
- logging.info(
- "Floating point difference between ref and tar was within "
- "tolerance. %s", diff_string)
- return same, error
- elif np.issubdtype(ref.dtype, np.integer):
- same = np.array_equal(ref, tar)
- if not same:
- abs_diff = np.max(np.abs(ref - tar))
- error = ("Expected array equality between ref and tar, but got "
- f"a max elementwise difference of {abs_diff}")
- logging.error(error)
- else:
- error = None
- return same, error
- else:
- return np.array_equal(ref, tar), None
-
- # Base check for native number types.
- elif isinstance(ref, (int, float)):
- return ref == tar, None
-
- # If outputs end up here then an extra branch for that type should be added.
- else:
- raise TypeError(f"Encountered results with unexpected type {type(ref)}")
- return True, None
-
- def save_plaintext(self, trace_dir: str, summarize: bool = True) -> None:
- """Saves a human-readable string representation of this trace to disk.
-
- Args:
- trace_dir: str, path to the directory to save the trace in.
- summarize: a bool controlling whether numpy should summarize the inputs
- and outputs if they're large. Setting this to False is very slow for
- large outputs.
- """
- prior_printoptions = np.get_printoptions()
- np.set_printoptions(
- linewidth=NUMPY_LINEWIDTH,
- threshold=None if summarize else sys.maxsize,
- edgeitems=10) # Can show more items since they won't clutter the logs.
-
- path = os.path.join(trace_dir, "log.txt")
- with open(path, "w") as f:
- f.write(str(self))
- f.write("\n")
-
- np.set_printoptions(**prior_printoptions)
-
- def serialize(self, trace_dir: str) -> None:
- """Stores a serialized copy of this trace in trace_dir.
-
- It can be loaded via `Trace.load(trace_dir)`.
-
- Args:
- trace_dir: str, path to the directory to serialize the trace to.
- """
-
- compiled_paths = None
- if self.compiled_paths is not None:
- # Convert to a dict to avoid the issues with serializing defaultdicts.
- compiled_paths = dict(self.compiled_paths)
-
- # Python serialization.
- metadata = {
- "module_name": self.module_name,
- "compiled_paths": compiled_paths,
- "backend_name": self.backend_name,
- "backend_id": self.backend_id,
- "backend_driver": self.backend_driver,
- "iree_serializable": self.iree_serializable,
- "tflite_serializable": self.tflite_serializable,
- "function_name": self.function_name,
- "function_sourcefile": self.function_sourcefile,
- "function_line_numbers": self.function_line_numbers,
- "function_source": self.function_source
- }
- with open(os.path.join(trace_dir, "metadata.pkl"), "wb") as f:
- pickle.dump(metadata, f)
-
- width = _zfill_width(len(self.calls))
- for i, call in enumerate(self.calls):
- call_dir = os.path.join(trace_dir, f"call_{str(i).zfill(width)}")
- call.serialize(call_dir)
-
- # C++ benchmark serialization.
- if self.iree_serializable or self.tflite_serializable:
- entry_function = self.calls[0].method
- compiled_path = self.compiled_paths[entry_function]
-
- if self.iree_serializable:
- serialized_inputs = ", ".join(self.calls[0].serialized_inputs)
- flagfile = [
- f"--module_file={compiled_path}",
- f"--driver={self.backend_driver}",
- f"--function_inputs={serialized_inputs}",
- f"--entry_function={entry_function}",
- ]
- with open(os.path.join(trace_dir, "flagfile"), "w") as f:
- f.writelines(line + "\n" for line in flagfile)
- else:
- with open(os.path.join(trace_dir, "graph_path"), "w") as f:
- f.writelines(compiled_path + "\n")
-
- @staticmethod
- def load(trace_dir: str) -> "Trace":
- """Loads and returns a trace serialized with Trace.serialize.
-
- Args:
- trace_dir: str, path to the directory of the serialized trace.
-
- Returns:
- A Trace deserialized from trace_dir.
- """
- with open(os.path.join(trace_dir, "metadata.pkl"), "rb") as f:
- load_dict = pickle.load(f)
- call_dirs = sorted(glob.glob(os.path.join(trace_dir, "call_*")))
- calls = [ModuleCall.load(call_dir) for call_dir in call_dirs]
- load_dict["calls"] = calls
- return Trace(module=None, function=None, _load_dict=load_dict)
-
-
-def _get_trace_dir(artifacts_dir: str, trace: Trace) -> str:
- trace_dir = os.path.join(artifacts_dir, trace.backend_id, "traces",
- trace.function_name)
- os.makedirs(trace_dir, exist_ok=True)
- return trace_dir
-
-
-class TracedModule:
-
- def __init__(self, module: tf_utils.CompiledModule, trace: Trace):
- """Wraps a CompiledModule so that all inputs and outputs are traced.
-
- The TracedModule returned will have an API almost identical to that of the
- passed CompiledModule. The only changes is that if the keywords `rtol` or
- `atol` are passed to one of the CompiledModule's methods, then they will be
- used to set the tolerance for comparing that call to the same call in
- another trace. So for example, calling `traced_module.add(a, b rtol=1e-8)`
- would be the same as calling `module.add(a, b)`.
-
- Args:
- module: the CompiledModule to trace.
- trace: the Trace to record calls to this module with.
- """
- self._module = module
- self._trace = trace
-
- def _trace_call(self, method: tf_utils._FunctionWrapper, method_name: str):
- """Decorates a CompiledModule method to capture its inputs and outputs."""
-
- def call(*args, **kwargs):
- # Pop manually specified tolerances from the kwargs (if any).
- tolerances = {}
- tolerances["rtol"] = kwargs.pop("rtol", None)
- tolerances["atol"] = kwargs.pop("atol", None)
- # Only pass these to ModuleCall if they were specified by the user.
- tolerances = {k: v for k, v in tolerances.items() if v is not None}
-
- # Run the method and record the details of the call.
- outputs = method(*args, **kwargs)
- serialized_inputs, serialized_outputs = method.get_serialized_values()
- self._trace.calls.append(
- ModuleCall(method_name, args, outputs, serialized_inputs,
- serialized_outputs, **tolerances))
- return outputs
-
- return call
-
- def __getattr__(self, attr):
- # Try to resolve it as an attr on self._module.
- if not hasattr(self._module, attr):
- raise AttributeError(f"The compiled module does not have attr '{attr}'")
- module_attr = getattr(self._module, attr)
- if not hasattr(module_attr, "__call__"):
- # e.g. traced_module.backend
- return module_attr
- else:
- # e.g. traced_module.simple_mul(a, b)
- return self._trace_call(module_attr, method_name=attr)
-
-
Modules = collections.namedtuple("Modules",
["ref_module", "tar_modules", "artifacts_dir"])
# We have to use a global variable to store the compiled modules so that we can
# avoid recompilation. This is because the TestCase class resets it's entire
-# state and calls __init__ before each unittest. It also calls __init__ one
+# state and calls __init__ before each unit_test. It also calls __init__ one
# additional time before that for good measure, which means without storing the
# modules somewhere else we would have to compile each of them at least twice.
# We can't store the modules on the class itself via setUpClass because of #2900
@@ -633,8 +148,8 @@
artifacts_dir = _setup_artifacts_dir(module_class.__name__)
# Get the backend information for this test.
- ref_backend_info = tf_utils.BackendInfo(FLAGS.reference_backend,
- f"{FLAGS.reference_backend}_ref")
+ ref_backend_info = module_utils.BackendInfo(FLAGS.reference_backend,
+ f"{FLAGS.reference_backend}_ref")
tar_backend_infos = get_target_backends()
compile_backend = lambda backend_info: backend_info.compile_from_class(
@@ -676,8 +191,8 @@
artifacts_dir = _setup_artifacts_dir(module_name)
# Get the backend information for this test.
- ref_backend_info = tf_utils.BackendInfo(FLAGS.reference_backend,
- f"{FLAGS.reference_backend}_ref")
+ ref_backend_info = module_utils.BackendInfo(FLAGS.reference_backend,
+ f"{FLAGS.reference_backend}_ref")
tar_backend_infos = get_target_backends()
compile_backend = (
@@ -693,13 +208,230 @@
return _global_modules
-def tf_function_unittest(input_generator: tf_utils.InputGeneratorType = None,
- input_args: Sequence[Any] = None,
- atol: float = None,
- rtol: float = None,
- name: str = None,
- **tf_function_kwargs):
- """Creates a tf.function that can be used to generate unittests.
+# We use global variables to store the configuration information for
+# tf_function_unit_tests because tensorflow.python.eager.def_function.Function
+# is not an API that we can subclass, and storing the information directly on
+# that class results in it being deleted at tf.Module initialization.
+# _global_unit_test_configs is a dict mapping exported_names to dicts containing
+# a get-function for input args and the tolerance kwargs for the trace.
+global _global_unit_test_configs
+_global_unit_test_configs = dict()
+
+
+class UnitTestSpec:
+
+ def __init__(self,
+ unit_test_name: str,
+ input_signature: Sequence[tf.TensorSpec],
+ input_generator: tf_utils.InputGeneratorType = None,
+ input_args: Union[Sequence[Any], None] = None,
+ kwargs: Dict[str, Any] = None):
+ self.unit_test_name = tf_utils.remove_special_characters(unit_test_name)
+ self.input_signature = input_signature
+ self.input_args = input_args
+ self.kwargs = dict() if kwargs is None else kwargs
+ self.input_generator = input_generator
+
+ def with_name(self, new_name: str) -> "UnitTestSpec":
+ return UnitTestSpec(new_name, self.input_signature, self.input_generator,
+ self.input_args, self.kwargs)
+
+ def __str__(self):
+ return self.unit_test_name
+
+
+def _dictionary_product(dictionary: Dict[Any, Any]) -> List[Dict[Any, Any]]:
+ """Returns a named cartesian product of dictionary's values.
+
+ Converts {'a': [1, 2], 'b': [3, 4]} into
+ [{'a': 1, 'b': 3}, {'a': 1, 'b': 4}, {'a': 2, 'b': 3}, {'a': 2, 'b': 4}]
+ """
+ product = [[]]
+ for values in dictionary.values():
+ # Iteratively grow the elements of the product.
+ product = [element + [value] for element in product for value in values]
+ dicts = [{k: v for k, v in zip(dictionary, element)} for element in product]
+ return dicts
+
+
+def _named_kwargs_product(
+ kwargs_to_values: Dict[str, Sequence[Any]]) -> Dict[str, Dict[str, Any]]:
+ """Splits kwargs_to_values into a Cartesian product of its elements."""
+ # Validate 'kwargs_to_values'
+ if kwargs_to_values is None:
+ kwargs_to_values = dict() # Use only default kwargs.
+ for kwarg_key, kwarg_values in kwargs_to_values.items():
+ if not isinstance(kwarg_values, Sequence):
+ raise TypeError(f"Expected kwargs_to_values[{repr(kwarg_key)}] to be a "
+ f"sequence, but got '{type(kwarg_values)}'")
+
+ # Expand across a Cartesian product.
+ kwargs_product = _dictionary_product(kwargs_to_values)
+ # {'a': 1, 'b': 3} -> "a_1__b_3"
+ dict_to_str = lambda d: "__".join([f"{k}_{v}" for k, v in d.items()])
+ return {dict_to_str(kwargs): kwargs for kwargs in kwargs_product}
+
+
+def unit_test_specs_from_signatures(
+ signature_shapes: Sequence[Sequence[Sequence[int]]],
+ signature_dtypes: Sequence[tf.DType] = [tf.float32],
+ input_generators: Union[Sequence[tf_utils.InputGeneratorType],
+ Dict[str, tf_utils.InputGeneratorType]] = [
+ DEFAULT_INPUT_GENERATOR
+ ],
+ kwargs_to_values: Dict[str, Sequence[Any]] = None) -> List[UnitTestSpec]:
+ """Generates a Cartesian product of UnitTestSpecs from the given arguments.
+
+ Args:
+ signature_shapes:
+ A sequence (representing multiple signatures to test) of sequences
+ (representing the shapes of the args in those signatures) of ints
+ (representing the individual sizes of those shapes).
+ signature_dtypes:
+ A sequence of dtypes to test each signature with.
+ input_generators:
+ Either:
+ 1. a sequence of input generators to test each of the signature-dtype
+ pairs with
+ 2. a dictionary mapping input generator names to input generators to
+ test each of the signature-dtype pairs with. This format must be used
+ if any of the generators are lambda functions.
+ kwargs_to_values:
+ A dict mapping kwarg names to sequences of values that they can take.
+
+ Returns:
+ A list of 'UnitTestSpec's generated from the provided arguments.
+ """
+ # Validate 'signature_shapes'
+ for i, shapes in enumerate(signature_shapes):
+ if not isinstance(shapes, Sequence):
+ raise TypeError(f"Expected signature_shapes[{i}] to be a sequence, but "
+ f"got '{type(shapes)}'")
+ for j, shape in enumerate(shapes):
+ if not isinstance(shape, Sequence):
+ raise TypeError(f"Expected signature_shapes[{i}][{j}] to be a "
+ f"sequence, but got '{type(shape)}'")
+ for k, size in enumerate(shape):
+ if not isinstance(size, int):
+ raise TypeError(f"Expected signature_shapes[{i}][{j}][{k}] to be an "
+ f"int but got '{type(size)}")
+
+ # Parse 'signature_shapes'
+ names_to_shapes = dict()
+ for signature in signature_shapes:
+ # Converts [[1, 2, 3], [4, 5]] into 1x2x3_4x5.
+ signature_key = "_".join(
+ ["x".join(str(size) for size in shape) for shape in signature])
+ names_to_shapes[signature_key] = signature
+
+ # Validate 'signature_dtypes'
+ for i, dtype in enumerate(signature_dtypes):
+ if not isinstance(dtype, tf.DType):
+ raise TypeError(
+ f"Expected dtypes[{i}] to be a tf.DType, but got '{type(dtype)}'")
+
+ # Parse 'signature_dtypes'
+ # 'complex64' -> 'c64'
+ abbreviate = lambda dtype: re.sub(r"([a-z])[a-z]*([0-9]+)", r"\1\2", dtype)
+ names_to_dtypes = {
+ abbreviate(dtype.name): dtype for dtype in signature_dtypes
+ }
+
+ # Validate 'input_generators'
+ if not isinstance(input_generators, (Sequence, Dict)):
+ raise TypeError("Expected 'input_generators' to be a sequence or "
+ f"dictionary, but got '{type(input_generators)}'")
+ if isinstance(input_generators, Sequence):
+ for i, generator in enumerate(input_generators):
+ if generator.__name__ == "<lambda>":
+ raise TypeError(
+ f"'input_generators' was a sequence but input_generators[{i}] was "
+ "lambda function. 'input_generators' must be a dictionary if "
+ "lambda functions are used.")
+
+ # Parse 'input_generators'
+ if isinstance(input_generators, Sequence):
+ names_to_generators = {gen.__name__: gen for gen in input_generators}
+ else:
+ names_to_generators = input_generators
+
+ # Validate and parse 'kwargs_to_values'
+ names_to_kwargs = _named_kwargs_product(kwargs_to_values)
+
+ # Create a Cartesian product through all specifications and their names.
+ specs = [
+ names_to_shapes, names_to_dtypes, names_to_generators, names_to_kwargs
+ ]
+ # pytype: disable=attribute-error
+ key_product = itertools.product(*[list(spec.keys()) for spec in specs])
+ value_product = itertools.product(*[list(spec.values()) for spec in specs])
+ # pytype: enable=attribute-error
+
+ # Generate a UnitTestSpec for each element in the above product.
+ unit_tests = []
+ for keys, (shapes, dtype, generator, kwargs) in zip(key_product,
+ value_product):
+ unit_test_name = "__".join(key for key in keys if key)
+ input_signature = [tf.TensorSpec(shape, dtype) for shape in shapes]
+ unit_tests.append(
+ UnitTestSpec(
+ unit_test_name=unit_test_name,
+ input_signature=input_signature,
+ input_generator=generator,
+ input_args=None,
+ kwargs=kwargs,
+ ))
+ return unit_tests
+
+
+def unit_test_specs_from_args(
+ names_to_input_args: Dict[str, Sequence[Any]],
+ kwargs_to_values: Dict[str, Sequence[Any]] = None) -> List[UnitTestSpec]:
+ """Generates a Cartesian product of UnitTestSpecs from the given arguments.
+
+ Args:
+ names_to_input_args:
+ A dict mapping names for input arguments to the arguments themselves.
+ kwargs_to_values:
+ A dict mapping kwarg names to sequences of values that they can take.
+
+ Returns:
+ A list of 'UnitTestSpec's generated from the provided arguments.
+ """
+ # Validate and parse 'kwargs_to_values'
+ names_to_kwargs = _named_kwargs_product(kwargs_to_values)
+
+ # Create a Cartesian product through all specifications and their names.
+ specs = [names_to_input_args, names_to_kwargs]
+ key_product = itertools.product(*[list(spec.keys()) for spec in specs])
+ value_product = itertools.product(*[list(spec.values()) for spec in specs])
+
+ # Generate a UnitTestSpec for each element in the above product.
+ unit_tests = []
+ for keys, (input_args, kwargs) in zip(key_product, value_product):
+ unit_test_name = "__".join(key for key in keys if key)
+ input_signature = tf_utils.apply_function(
+ input_args,
+ lambda x: tf.TensorSpec.from_tensor(tf.convert_to_tensor(x)))
+ unit_tests.append(
+ UnitTestSpec(
+ unit_test_name=unit_test_name,
+ input_signature=input_signature,
+ input_generator=None,
+ input_args=input_args,
+ kwargs=kwargs,
+ ))
+ return unit_tests
+
+
+def tf_function_unit_test(input_generator: tf_utils.InputGeneratorType = None,
+ input_args: Sequence[Any] = None,
+ atol: float = None,
+ rtol: float = None,
+ name: str = None,
+ static_signature: Sequence[tf.TensorSpec] = None,
+ **tf_function_kwargs):
+ """Creates a tf.function that can be used to generate unit_tests.
If 'input_generator' and 'input_args' are unspecified then the function will
be tested using random uniform data.
@@ -707,7 +439,7 @@
Args:
input_generator:
an optional callable taking a shape and dtype that returns input data for
- the unittest.
+ the unit_test.
input_args:
an optional sequence of values to pass as positional args to the function.
atol:
@@ -719,6 +451,10 @@
name:
optional, the name to reference this function with. Must be used if
decorating a lambda.
+ static_signature:
+ optional, a signature with the same structure as 'input_signature'. Used
+ to specify the correct shape for data generation when dynamic dims are
+ provided.
Raises:
ValueError: if 'input_generator' and 'input_args' are both specified.
@@ -729,7 +465,7 @@
__name__ attribute if 'name' was specified.
"""
- def _store_unittest_info(function):
+ def _store_unit_test_info(function):
# Validate arguments.
if input_generator is not None and input_args is not None:
raise ValueError(
@@ -737,58 +473,59 @@
function = tf.function(**tf_function_kwargs)(function)
- # Used to identify that the tf.function was created by this decorator.
- function.is_tf_function_unittest = True
-
- # Set function.get_trace_args.
- if input_generator is not None:
- # Use the user-specificed input_generator.
- function.get_trace_args = lambda: tf_utils.generate_inputs(
- function.input_signature, input_generator)
- elif input_args is not None:
- # Use the user-specified input_args.
- function.get_trace_args = lambda: copy.deepcopy(input_args)
- else:
- # No user data specification – default to using random uniform data.
- function.get_trace_args = lambda: tf_utils.generate_inputs(
- function.input_signature, tf_utils.uniform)
-
- # Set function.trace_kwargs.
- function.trace_kwargs = dict(atol=atol, rtol=rtol)
-
- # Set function.__name__.
+ # Set function.__name__
if name is not None:
function.__name__ = name
elif function.__name__ == "<lambda>":
raise ValueError("The 'name' kwarg must be provided when decorating a "
"lambda function.")
+ global _global_unit_test_configs
+ if function.__name__ not in _global_unit_test_configs:
+
+ if static_signature is not None:
+ signature = static_signature
+ else:
+ signature = function.input_signature
+
+ if input_generator is not None:
+ # Use the user-specified input_generator.
+ get_trace_args = lambda: tf_utils.generate_inputs(
+ signature, input_generator)
+ elif input_args is not None:
+ # Use the user-specified input_args.
+ get_trace_args = lambda: copy.deepcopy(input_args)
+ else:
+ # No user data specification – default to using random uniform data.
+ get_trace_args = lambda: tf_utils.generate_inputs(
+ signature, DEFAULT_INPUT_GENERATOR)
+
+ _global_unit_test_configs[function.__name__] = dict(
+ get_trace_args=get_trace_args,
+ trace_kwargs=dict(atol=atol, rtol=rtol))
+
return function
- return _store_unittest_info
+ return _store_unit_test_info
class TestModule(tf.Module):
- """Thin wrapper of tf.Module with helper methods for tf_function_unittests."""
+ """Thin tf.Module wrapper with helper methods for tf_function_unit_tests."""
@classmethod
- def get_tf_function_unittests(cls):
- """Get all tf_function_unittest-created tf.functions on the class."""
- tf_function_unittests = []
- for name in dir(cls):
- value = getattr(cls, name)
- if hasattr(value, 'is_tf_function_unittest'):
- tf_function_unittests.append(value)
+ def get_tf_function_unit_tests(cls):
+ """Get all tf_function_unit_test-created tf.functions on the class."""
+ # Initialize the module to ensure that _global_unit_test_configs has the
+ # info for all of the unit_tests. (Only doing this if
+ # _global_unit_test_configs is empty wouldn't address the case where some
+ # unit_tests are defined on the class and some are generated by __init__).
+ cls()
- if not len(tf_function_unittests):
+ tf_function_unit_tests = list(_global_unit_test_configs.keys())
+ if not len(tf_function_unit_tests):
raise ValueError(
- "'get_tf_function_unittests' was called but no unittests were found.")
- return tf_function_unittests
-
- @classmethod
- def get_exported_names(cls):
- """Get the names of all tf_function_unittest-created tf.functions"""
- return [function.__name__ for function in cls.get_tf_function_unittests()]
+ "'get_tf_function_unit_tests' was called but no tests were found.")
+ return tf_function_unit_tests
class TracedModuleTestCase(tf.test.TestCase):
@@ -802,37 +539,42 @@
module.reinitialize()
@classmethod
- def generate_unittests(cls, module_class: Type[TestModule]):
- """Generates unittests for each 'tf_function_unittest' on 'module_class'."""
- for function in module_class.get_tf_function_unittests():
- # We have to pass the closure argument 'funcion' to 'trace' via a kwarg
- # instead of using it directly in the body because 'function' is
- # overwritten in each iteration of this loop, and python will only use
- # the most recent version of 'function'. If we didn't do this, then we
- # would only test the last function in this loop. The same is true for
- # passing 'trace' to 'unittest'.
+ def generate_unit_tests(cls, module_class: Type[TestModule]):
+ """Generates tests for each 'tf_function_unit_test' on 'module_class'."""
+ for function_name in module_class.get_tf_function_unit_tests():
+ # We have to pass the closure arguments 'function_name', 'get_args' and
+ # 'kwargs' to 'trace' as kwargs instead of using them directly in the body
+ # because 'function_name' and 'unit_test_config' are overwritten in each
+ # iteration of this loop, and Python will only use the most recent version
+ # of each. If we didn't do this, then we would only test the last function
+ # in this loop. The same is true for passing 'trace' to 'unit_test'.
+ unit_test_config = _global_unit_test_configs[function_name]
# Runs the inputs through a (traced) module.
- def trace(module, function=function):
- getattr(module, function.__name__)(*function.get_trace_args(),
- **function.trace_kwargs)
+ def trace(module,
+ function_name=function_name,
+ get_args=unit_test_config["get_trace_args"],
+ kwargs=unit_test_config["trace_kwargs"]):
+ getattr(module, function_name)(*get_args(), **kwargs)
# Give the trace the name of the tf.function that it is testing.
- trace.__name__ = function.__name__
+ trace.__name__ = function_name
# Runs 'trace' on modules compiled to each backend and compares them.
- def unittest(self, trace=trace):
+ def unit_test(self, trace=trace):
self.compare_backends(trace, self._modules)
- # Make 'unittest' a function on the TracedModuleTestCase, which tells
+ # Make 'unit_test' a function on the TracedModuleTestCase, which tells
# the test runner to run it.
- unittest.__name__ = f"test_{function.__name__}"
- if hasattr(cls, unittest.__name__):
- raise ValueError("Tried to generate multiple instances of the unittest "
- f"'{unittest.__name__}'.")
- setattr(cls, unittest.__name__, unittest)
+ unit_test.__name__ = f"test_{function_name}"
+ if hasattr(cls, unit_test.__name__):
+ raise ValueError("Tried to generate multiple instances of the "
+ f"unit_test '{unit_test.__name__}'.")
+ setattr(cls, unit_test.__name__, unit_test)
- def compare_backends(self, trace_function: Callable[[TracedModule], None],
+ def compare_backends(self,
+ trace_function: Callable[[trace_utils.TracedModule],
+ None],
modules: Modules) -> None:
"""Run the reference and target backends on trace_function and compare them.
@@ -843,19 +585,20 @@
trace_function: a function accepting a TracedModule as its argument.
"""
# Create Traces for each backend.
- ref_trace = Trace(modules.ref_module, trace_function)
+ ref_trace = trace_utils.Trace(modules.ref_module, trace_function)
tar_traces = [
- Trace(module, trace_function) for module in modules.tar_modules
+ trace_utils.Trace(module, trace_function)
+ for module in modules.tar_modules
]
# Run the traces through trace_function with their associated modules.
tf_utils.set_random_seed()
- trace_function(TracedModule(modules.ref_module, ref_trace))
+ trace_function(trace_utils.TracedModule(modules.ref_module, ref_trace))
if FLAGS.log_all_traces:
logging.info(ref_trace)
for module, trace in zip(modules.tar_modules, tar_traces):
tf_utils.set_random_seed()
- trace_function(TracedModule(module, trace))
+ trace_function(trace_utils.TracedModule(module, trace))
if FLAGS.log_all_traces:
logging.info(trace)
@@ -865,17 +608,18 @@
for i, tar_trace in enumerate(tar_traces):
logging.info("Comparing the reference backend '%s' with '%s'",
ref_trace.backend_id, tar_trace.backend_id)
- traces_match, errors = Trace.compare_traces(ref_trace, tar_trace)
+ traces_match, errors = trace_utils.compare_traces(ref_trace, tar_trace)
if not traces_match:
failed_backend_indices.append(i)
error_messages.extend(errors)
# Save the results to disk before validating.
- ref_trace_dir = _get_trace_dir(modules.artifacts_dir, ref_trace)
+ ref_trace_dir = trace_utils.get_trace_dir(modules.artifacts_dir, ref_trace)
ref_trace.save_plaintext(ref_trace_dir, FLAGS.summarize)
ref_trace.serialize(ref_trace_dir)
for tar_trace in tar_traces:
- tar_trace_dir = _get_trace_dir(modules.artifacts_dir, tar_trace)
+ tar_trace_dir = trace_utils.get_trace_dir(modules.artifacts_dir,
+ tar_trace)
tar_trace.save_plaintext(tar_trace_dir, FLAGS.summarize)
tar_trace.serialize(tar_trace_dir)
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
index 63cfa2f..84c1c36 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
@@ -14,58 +14,26 @@
# limitations under the License.
"""Tests for pyiree.tf.support.tf_test_utils."""
-import os
-import tempfile
-
-from absl.testing import parameterized
import numpy as np
from pyiree.tf.support import tf_test_utils
from pyiree.tf.support import tf_utils
import tensorflow as tf
-class StatefulCountingModule(tf.Module):
+class TfFunctionUnitTestModule(tf_test_utils.TestModule):
- def __init__(self):
- self.count = tf.Variable([0.])
-
- @tf.function(input_signature=[])
- def increment(self):
- self.count.assign_add(tf.constant([1.]))
-
- @tf.function(input_signature=[])
- def get_count(self):
- return self.count
-
- @tf.function(input_signature=[tf.TensorSpec([1])])
- def increment_by(self, value):
- self.count.assign_add(value)
-
- @tf.function(input_signature=[tf.TensorSpec([1]), tf.TensorSpec([1])])
- def increment_by_max(self, a, b):
- result = tf.maximum(a, b)
- self.count.assign_add(result)
- return result
-
- @tf.function(input_signature=[])
- def decrement(self):
- self.count.assign_sub(tf.constant([1.]))
-
-
-class TfFunctionUnittestModule(tf_test_utils.TestModule):
-
- @tf_test_utils.tf_function_unittest(input_signature=[])
+ @tf_test_utils.tf_function_unit_test(input_signature=[])
def no_args(self):
return np.array([True], dtype=np.bool)
- @tf_test_utils.tf_function_unittest(input_signature=[
+ @tf_test_utils.tf_function_unit_test(input_signature=[
tf.TensorSpec([4]),
tf.TensorSpec([4]),
])
def default_uniform_inputs(self, a, b):
return a + b
- @tf_test_utils.tf_function_unittest(
+ @tf_test_utils.tf_function_unit_test(
input_signature=[
tf.TensorSpec([4]),
tf.TensorSpec([4]),
@@ -75,7 +43,7 @@
def custom_input_generator(self, a, b):
return a + b
- @tf_test_utils.tf_function_unittest(
+ @tf_test_utils.tf_function_unit_test(
input_signature=[
tf.TensorSpec([4]),
tf.TensorSpec([4]),
@@ -89,7 +57,7 @@
return a + b
# This test will fail if atol is not successfully set.
- @tf_test_utils.tf_function_unittest(
+ @tf_test_utils.tf_function_unit_test(
input_signature=[
tf.TensorSpec([128, 3072], tf.float32),
tf.TensorSpec([3072, 256], tf.float32),
@@ -100,158 +68,7 @@
return tf.matmul(a, b)
-class TestUtilsTests(tf.test.TestCase, parameterized.TestCase):
-
- @parameterized.named_parameters([
- {
- 'testcase_name': 'all the same',
- 'array_c': np.array([0, 1, 2]),
- 'array_d': np.array(['0', '1', '2']),
- 'array_e': np.array([0.0, 0.1, 0.2]),
- 'tar_same': True,
- },
- {
- 'testcase_name': 'wrong int',
- 'array_c': np.array([1, 1, 2]),
- 'array_d': np.array(['0', '1', '2']),
- 'array_e': np.array([0.0, 0.1, 0.2]),
- 'tar_same': False,
- },
- {
- 'testcase_name': 'wrong string',
- 'array_c': np.array([0, 1, 2]),
- 'array_d': np.array(['a', '1', '2']),
- 'array_e': np.array([0.0, 0.1, 0.2]),
- 'tar_same': False,
- },
- {
- 'testcase_name': 'wrong float',
- 'array_c': np.array([0, 1, 2]),
- 'array_d': np.array(['0', '1', '2']),
- 'array_e': np.array([1.0, 0.1, 0.2]),
- 'tar_same': False,
- },
- ])
- def test_recursive_check_same(self, array_c, array_d, array_e, tar_same):
-
- # yapf: disable
- ref = {
- 'a': 1,
- 'b': [
- {'c': np.array([0, 1, 2])},
- {'d': np.array(['0', '1', '2'])},
- {'e': np.array([0.0, 0.1, 0.2])}
- ],
- }
- tar = {
- 'a': 1,
- 'b': [
- {'c': array_c},
- {'d': array_d},
- {'e': array_e}
- ],
- }
- # yapf: enable
- same, _ = tf_test_utils.Trace._check_same(ref, tar, rtol=1e-6, atol=1e-6)
- self.assertEqual(tar_same, same)
-
- def test_trace_inputs_and_outputs(self):
-
- def trace_function(module):
- # No inputs or outputs
- module.increment()
- # Only inputs
- module.increment_by(np.array([81.], dtype=np.float32))
- # Only outputs
- module.get_count()
-
- module = tf_utils.TfCompiledModule.create_from_class(
- StatefulCountingModule, tf_utils.BackendInfo('tf'))
- trace = tf_test_utils.Trace(module, trace_function)
- trace_function(tf_test_utils.TracedModule(module, trace))
-
- self.assertIsInstance(trace.calls[0].inputs, tuple)
- self.assertEmpty(trace.calls[0].inputs)
- self.assertIsInstance(trace.calls[0].outputs, tuple)
- self.assertEmpty(trace.calls[0].outputs)
-
- self.assertAllClose(trace.calls[1].inputs[0], [81.])
- self.assertAllClose(trace.calls[2].outputs[0], [82.])
-
- def test_nonmatching_methods(self):
-
- def tf_function(module):
- module.increment()
- module.increment()
-
- def vmla_function(module):
- module.increment()
- module.decrement()
-
- tf_module = tf_utils.TfCompiledModule.create_from_class(
- StatefulCountingModule, tf_utils.BackendInfo('tf'))
- tf_trace = tf_test_utils.Trace(tf_module, tf_function)
- tf_function(tf_test_utils.TracedModule(tf_module, tf_trace))
-
- vmla_module = tf_utils.IreeCompiledModule.create_from_class(
- StatefulCountingModule, tf_utils.BackendInfo('iree_vmla'))
- vmla_trace = tf_test_utils.Trace(vmla_module, vmla_function)
- vmla_function(tf_test_utils.TracedModule(vmla_module, vmla_trace))
-
- with self.assertRaises(ValueError):
- tf_test_utils.Trace.compare_traces(tf_trace, vmla_trace)
-
- def test_nonmatching_inputs(self):
-
- def tf_function(module):
- module.increment_by(np.array([42.], dtype=np.float32))
-
- def vmla_function(module):
- module.increment_by(np.array([22.], dtype=np.float32))
-
- tf_module = tf_utils.TfCompiledModule.create_from_class(
- StatefulCountingModule, tf_utils.BackendInfo('tf'))
- tf_trace = tf_test_utils.Trace(tf_module, tf_function)
- tf_function(tf_test_utils.TracedModule(tf_module, tf_trace))
-
- vmla_module = tf_utils.IreeCompiledModule.create_from_class(
- StatefulCountingModule, tf_utils.BackendInfo('iree_vmla'))
- vmla_trace = tf_test_utils.Trace(vmla_module, vmla_function)
- vmla_function(tf_test_utils.TracedModule(vmla_module, vmla_trace))
-
- same, error_messages = tf_test_utils.Trace.compare_traces(
- tf_trace, vmla_trace)
- self.assertFalse(same)
-
- def test_trace_serialize_and_load(self):
-
- def trace_function(module):
- module.increment()
- module.increment_by(np.array([81.], dtype=np.float32))
- module.increment_by_max(np.array([81], dtype=np.float32),
- np.array([92], dtype=np.float32))
- module.get_count()
-
- module = tf_utils.IreeCompiledModule.create_from_class(
- StatefulCountingModule, tf_utils.BackendInfo('iree_vmla'))
- trace = tf_test_utils.Trace(module, trace_function)
- trace_function(tf_test_utils.TracedModule(module, trace))
-
- with tempfile.TemporaryDirectory() as artifacts_dir:
- trace_function_dir = tf_test_utils._get_trace_dir(artifacts_dir, trace)
- trace.serialize(trace_function_dir)
- self.assertTrue(
- os.path.exists(os.path.join(trace_function_dir, 'metadata.pkl')))
- loaded_trace = tf_test_utils.Trace.load(trace_function_dir)
-
- # Check all calls match.
- self.assertTrue(tf_test_utils.Trace.compare_traces(trace, loaded_trace))
-
- # Check all other metadata match.
- self.assertAllEqual(trace.__dict__.keys(), loaded_trace.__dict__.keys())
- for key in trace.__dict__.keys():
- if key != 'calls':
- self.assertEqual(trace.__dict__[key], loaded_trace.__dict__[key])
+class TestUtilsTests(tf.test.TestCase):
def test_tf_function_unittet(self):
@@ -260,9 +77,9 @@
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._modules = tf_test_utils.compile_tf_module(
- TfFunctionUnittestModule)
+ TfFunctionUnitTestModule)
- TfFunctionUnittestTest.generate_unittests(TfFunctionUnittestModule)
+ TfFunctionUnittestTest.generate_unit_tests(TfFunctionUnitTestModule)
test_case = TfFunctionUnittestTest()
self.assertTrue(hasattr(test_case, 'test_no_args'))
self.assertTrue(hasattr(test_case, 'test_default_uniform_inputs'))
@@ -270,7 +87,7 @@
self.assertTrue(hasattr(test_case, 'test_custom_input_args'))
self.assertTrue(hasattr(test_case, 'test_high_tolerance'))
- # Will throw an error if 'atol' and 'rtol' are not set.
+ # Will throw an error if 'atol' is not set.
test_case = TfFunctionUnittestTest()
test_case.test_high_tolerance()
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
index 687d78b..1af3cca 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
@@ -14,21 +14,18 @@
# limitations under the License.
"""Utilities interop with TensorFlow."""
-# pylint: disable=protected-access
-
-import collections
import os
import random
import re
-import tempfile
-from typing import Any, Callable, Dict, Sequence, Set, Tuple, Type, Union
+from typing import Any, Callable, Sequence, Set, Tuple, Union
from absl import logging
import numpy as np
-from pyiree import rt
-from pyiree.tf import compiler
import tensorflow.compat.v2 as tf
+InputGeneratorType = Callable[[Sequence[int], Union[tf.DType, np.dtype]],
+ np.ndarray]
+
def set_random_seed(seed: int = 0) -> None:
"""Set random seed for tf, np and random."""
@@ -37,17 +34,19 @@
np.random.seed(seed)
-InputGeneratorType = Callable[[Sequence[int], Union[tf.DType, np.dtype]],
- np.ndarray]
-
-
def uniform(shape: Sequence[int],
dtype: Union[tf.DType, np.dtype] = np.float32,
low: float = -1.0,
high: float = 1.0) -> np.ndarray:
- """np.random.uniform with simplified API and dtype control."""
+ """np.random.uniform with simplified API and dtype and bool support."""
dtype = dtype.as_numpy_dtype if isinstance(dtype, tf.DType) else dtype
- return np.random.uniform(size=shape, low=low, high=high).astype(dtype)
+ if dtype == np.bool:
+ return np.random.choice(2, shape).astype(np.bool)
+ else:
+ values = np.random.uniform(size=shape, low=low, high=high)
+ if np.issubdtype(dtype, np.integer):
+ values = np.round(values)
+ return values.astype(dtype)
def ndarange(shape: Sequence[int],
@@ -57,26 +56,63 @@
return np.arange(np.prod(shape), dtype=dtype).reshape(shape)
+def random_permutation(
+ shape: Sequence[int],
+ dtype: Union[tf.DType, np.dtype] = np.float32) -> np.ndarray:
+ """Returns a random permutation of [0, np.prod(shape))."""
+ values = ndarange(shape, dtype)
+ np.random.shuffle(values)
+ return values
+
+
+def apply_function(values, function):
+ """Applies 'function' recursively to the inputted values."""
+ if isinstance(values, list):
+ return [apply_function(v, function) for v in values]
+ elif isinstance(values, tuple):
+ return tuple(apply_function(v, function) for v in values)
+ elif isinstance(values, dict):
+ return {k: apply_function(v, function) for k, v in values.items()}
+ else:
+ return function(values)
+
+
def generate_inputs(
spec, # Union[Sequence[tf.TensorSpec], tf.TensorSpec]
input_generator: InputGeneratorType,
) -> Sequence[np.ndarray]:
"""Generates inputs for a given input signature using 'input_generator'."""
- if isinstance(spec, Sequence):
- # 'spec' is a sequence of 'tf.TensorSpec'.
- # Recursively generate inputs.
- return [generate_inputs(s, input_generator) for s in spec]
- elif isinstance(spec, tf.TensorSpec):
- # Handle dynamic shapes (e.g. batches) by substituting an int for None.
- shape = [size if size is not None else 2 for size in spec.shape]
- return input_generator(shape, spec.dtype)
- else:
- raise TypeError("Expected 'spec' to be a sequence of 'tf.TensorSpec' or "
- f"'tf.TensorSpec', but got '{type(spec)}'")
+ make_static = lambda shape: [dim if dim is not None else 2 for dim in shape]
+ generate = lambda spec: input_generator(make_static(spec.shape), spec.dtype)
+ return apply_function(spec, generate)
+
+
+def normalize_numpy(result: np.ndarray):
+ """Normalizes TF and TFLite's outputs to match IREE's"""
+ if np.isscalar(result):
+ result = np.array(result)
+ if result.dtype == np.bool:
+ # IREE interprets bools as int8s, so we modify this for comparison.
+ result = result.astype(dtype=np.int8)
+ return result
+
+
+def convert_to_numpy(values: Any) -> Any:
+ """Converts any tf.Tensor in values to numpy."""
+
+ def _convert_to_numpy(tensor: Any) -> Any:
+ if not isinstance(tensor, tf.Tensor):
+ return tensor
+ return normalize_numpy(tensor.numpy())
+
+ return apply_function(values, _convert_to_numpy)
def to_mlir_type(dtype: np.dtype) -> str:
"""Returns a string that denotes the type 'dtype' in MLIR style."""
+ if not isinstance(dtype, np.dtype):
+ # Handle np.int8 _not_ being a dtype.
+ dtype = np.dtype(dtype)
bits = dtype.itemsize * 8
if np.issubdtype(dtype, np.integer):
return f"i{bits}"
@@ -116,942 +152,147 @@
return result
-def _setup_mlir_crash_reproducer(
- function: Any, # pytype doesn't support arbitrary Callable[*args, **kwargs]
- artifacts_dir: str,
- backend_id: str,
-) -> Any: # Callable[Any, Any]
- """Wraps `function` so that it a MLIR crash reproducer is saved if it crashes.
+def remove_special_characters(value: str) -> str:
+ """Replaces special characters with '_' while keeping instances of '__'."""
+ normalized_parts = []
+ for part in value.split("__"):
+ part = re.sub(r"[^a-zA-Z0-9_]", "_", part) # Remove special characters.
+ part = re.sub(r"_+", "_", part) # Remove duplicate "_".
+ part = part.strip("_") # Don't end or start in "_".
+ normalized_parts.append(part)
+ return "__".join(normalized_parts)
- Writes to `artifacts_dir/reproducer__{backend}.mlir` in the case of a crash.
- Args:
- function: The callable to decorate.
- artifacts_dir: The directory to write the reproducer to.
- backend_id: The unique backend name to use when writting the reproducer.
+def is_complex(tensors: Union[Sequence[tf.TensorSpec], tf.TensorSpec]) -> bool:
+ if isinstance(tensors, Sequence):
+ for tensor in tensors:
+ if is_complex(tensor):
+ return True
+ return False
+ else:
+ return tensors.dtype.is_complex # pytype: disable=attribute-error
- Returns:
- A function with the same API as the passed function.
- """
+
+def _complex_wrapper(function):
+ """Wraps a tf.function to allow compiling functions of complex numbers."""
def decorator(*args, **kwargs):
- # Set up a crash reproducer for debugging.
- if artifacts_dir is not None:
- compiler.Context.default_crash_reproducer_path = os.path.join(
- artifacts_dir, f"reproducer__{backend_id}.mlir")
- try:
- results = function(*args, **kwargs)
- except Exception: # pylint: disable=broad-except
- # Disable the crash reproducer (to avoid inadvertently overwriting it).
- if artifacts_dir is not None:
- compiler.Context.default_crash_reproducer_path = None
- raise
- return results
+ inputs = []
+ for real, imag in zip(args[::2], args[1::2]):
+ inputs.append(tf.complex(real, imag))
+ result = function(*inputs, **kwargs)
+ # TODO(meadowlark): Support returning complex numbers.
+ return tf.math.real(result) + tf.math.imag(result)
return decorator
-def _incrementally_lower_compiler_module(
- compiler_module: compiler.Module,
- backend_info: "BackendInfo",
- artifacts_dir: str,
-) -> Tuple[compiler.binding.OpaqueBlob, Union[str, None]]:
- """Lowers a MLIR compiler module incrementally and saves its outputs.
+def rewrite_complex_signature(function, signature: Sequence[tf.TensorSpec]):
+ """Compatibility layer for testing complex numbers."""
+ if not all([spec.dtype.is_complex for spec in signature]):
+ raise NotImplementedError("Signatures with mixed complex and non-complex "
+ "tensor specs are not supported.")
- If artifacts_dir is provided then the following artifacts will be saved:
- tf_input.mlir:
- MLIR for the module in TF's input dialect.
- iree_input.mlir:
- The MLIR above translated to IREE via compiler.TF_IMPORT_PASS_PIPELINE.
- backend_id/compiled.vmfb:
- A VM FlatBuffer compiled to the target backend from the IREE MLIR above.
+ # Rewrite the signature, replacing all complex tensors with pairs of real
+ # and imaginary tensors.
+ real_imag_signature = []
+ for spec in signature:
+ new_dtype = tf.float32 if spec.dtype.size == 8 else tf.float64
+ real_imag_signature.append(tf.TensorSpec(spec.shape, new_dtype))
+ real_imag_signature.append(tf.TensorSpec(spec.shape, new_dtype))
- Args:
- compiler_module: A compiler.Module to lower.
- backend_info: BackendInfo with the details for lowering compiler_module to
- IREE.
- artifacts_dir: An optional string pointing to where compilation artifacts
- should be saved. No compilation artifacts will be saved if this is not
- provided.
- """
- if artifacts_dir is not None:
- os.makedirs(artifacts_dir, exist_ok=True)
- tf_mlir_path = os.path.join(artifacts_dir, "tf_input.mlir")
- logging.info("Saving raw TF input MLIR to: %s", tf_mlir_path)
- with open(tf_mlir_path, "w") as f:
- f.write(compiler_module.to_asm())
+ return _complex_wrapper(function), real_imag_signature
- # Manually run the passes that tf_module_to_compiler_module usually would.
- compiler_module.run_pass_pipeline(compiler.TF_IMPORT_PASS_PIPELINE)
- if artifacts_dir is not None:
- iree_mlir_path = os.path.join(artifacts_dir, "iree_input.mlir")
- logging.info("Saving IREE input MLIR to: %s", iree_mlir_path)
- with open(iree_mlir_path, "w") as f:
- f.write(compiler_module.to_asm())
+def make_dims_dynamic(spec: tf.TensorSpec) -> tf.TensorSpec:
+ """Gives a tf.TensorSpec dynamic dims."""
+ return tf.TensorSpec([None] * len(spec.shape), spec.dtype)
- compiled_module = compiler_module.compile(
- target_backends=backend_info.compiler_targets)
- compiled_path = None
- if artifacts_dir is not None:
- backend_dir = os.path.join(artifacts_dir, backend_info.backend_id)
- os.makedirs(backend_dir, exist_ok=True)
- compiled_path = os.path.join(backend_dir, "compiled.vmfb")
- logging.info("Saving compiled IREE module to: %s", compiled_path)
- with open(compiled_path, "wb") as f:
- f.write(compiled_module)
- return compiled_module, compiled_path
+def check_same(ref: Any, tar: Any, rtol: float,
+ atol: float) -> Tuple[bool, Union[str, None]]:
+ """Checks that ref and tar have identical datastructures and values."""
+ # Check for matching types.
+ if not isinstance(tar, type(ref)):
+ error = ("Expected ref and tar to have the same type but got "
+ f"'{type(ref)}' and '{type(tar)}'")
+ logging.error(error)
+ return False, error
+ if ref is None:
+ # Nothing to compare (e.g. the called method had no outputs).
+ return True, None
-def _incrementally_compile_tf_module(
- module: Type[tf.Module],
- backend_info: "BackendInfo",
- exported_names: Sequence[str] = (),
- artifacts_dir: str = None,
-) -> Tuple[compiler.binding.OpaqueBlob, Union[str, None]]:
- """Compile a TensorFlow tf.Module and optionally save compilation artifacts.
+ # Recursive check for dicts.
+ if isinstance(ref, dict):
+ if ref.keys() != tar.keys():
+ error = ("Expected ref and tar to have the same keys, but got "
+ f"'{ref.keys()}' and '{tar.keys()}'")
+ logging.error(error)
+ return False, error
+ # Check that all of the dictionaries' values are the same.
+ for key in ref:
+ same, error = check_same(ref[key], tar[key], rtol, atol)
+ if not same:
+ return same, error
- The module blob this creates is not callable. See IreeCompiledModule for an
- API that returns a module that can be called without any further steps.
+ # Recursive check for iterables.
+ elif isinstance(ref, list) or isinstance(ref, tuple):
+ if len(ref) != len(tar):
+ error = ("Expected ref and tar to have the same length, but got "
+ f"{len(ref)} and {len(tar)}")
+ logging.error(error)
+ return False, error
+ # Check that all of the iterables' values are the same.
+ for i in range(len(ref)):
+ same, error = check_same(ref[i], tar[i], rtol, atol)
+ if not same:
+ return same, error
- See _incrementally_lower_compiler_module's docstring for details about which
- artifacts will be saved.
+ # Base check for numpy arrays.
+ elif isinstance(ref, np.ndarray):
+ if ref.dtype != tar.dtype:
+ error = ("Expected ref and tar to have the same dtype, but got "
+ f"'{ref.dtype}' and '{tar.dtype}'")
+ logging.error(error)
+ return False, error
+ if ref.size == tar.size == 0:
+ return True, None
- Args:
- module: A tf.Module.
- backend_info: BackendInfo with the details for compiling this module.
- exported_names: Optional sequence representing the exported names to keep.
- artifacts_dir: An optional string pointing to where compilation artifacts
- should be saved. No compilation artifacts will be saved if this is not
- provided.
-
- Returns:
- A compiled IREE module blob and the path to the compiled VM FlatBuffer if
- artifacts_dir is provided.
- """
-
- def _compile_module(module, backend_info, exported_names, artifacts_dir):
- compiler_module = compiler.tf_module_to_compiler_module(module,
- exported_names,
- pass_pipeline=())
- return _incrementally_lower_compiler_module(compiler_module, backend_info,
- artifacts_dir)
-
- _compile_module = _setup_mlir_crash_reproducer(_compile_module, artifacts_dir,
- backend_info.backend_id)
- return _compile_module(module, backend_info, exported_names, artifacts_dir)
-
-
-def _incrementally_compile_tf_signature_def_saved_model(
- saved_model_dir: str, saved_model_tags: Set[str],
- backend_info: "BackendInfo", exported_name: str, artifacts_dir: str):
- """Compile a SignatureDef SavedModel and optionally save compilation artifacts.
-
- The module blob this creates is not callable. See IreeCompiledModule for an
- API that returns a module that can be called without any further steps.
-
- See _incrementally_lower_compiler_module's docstring for details about which
- artifacts will be saved.
-
- Args:
- saved_model_dir: Directory of the saved model.
- saved_model_tags: Optional set of tags to use when loading the model.
- backend_info: BackendInfo with the details for compiling the saved model.
- exported_name: A str representing the signature on the saved model to
- compile.
- artifacts_dir: An optional string pointing to where compilation artifacts
- should be saved. No compilation artifacts will be saved if this is not
- provided.
-
- Returns:
- A compiled IREE module blob and the path to the compiled VM FlatBuffer if
- artifacts_dir is provided.
- """
-
- def _compile_module(saved_model_dir, saved_model_tags, backend_info,
- exported_name, artifacts_dir):
- # Convert the tf_module into raw TF input MLIR.
- compiler_module = compiler.tf_signature_def_saved_model_to_compiler_module(
- saved_model_dir, saved_model_tags, [exported_name], pass_pipeline=())
- return _incrementally_lower_compiler_module(compiler_module, backend_info,
- artifacts_dir)
-
- _compile_module = _setup_mlir_crash_reproducer(_compile_module, artifacts_dir,
- backend_info.backend_id)
- return _compile_module(saved_model_dir, saved_model_tags, backend_info,
- exported_name, artifacts_dir)
-
-
-class _FunctionWrapper(object):
-
- def __call__(self, *args, **kwargs):
- raise NotImplementedError()
-
- def get_serialized_values(self) -> Tuple[Tuple[str], Tuple[str]]:
- """Dummy function to match _IreeFunctionWrapper's API."""
- return ("",), ("",)
-
-
-class CompiledModule(object):
- """Base class for the TF and IREE compiled modules."""
-
- def __init__(
- self,
- module_name: str,
- backend_info: "BackendInfo",
- compiled_paths: Union[Dict[str, str], None],
- ):
- """Shared base constructor – not useful on its own.
-
- Args:
- module_name: A name for this compiled module. In most cases this will be
- the name of the tf.Module subclass or instance that is compiled.
- backend_info: BackendInfo with the details about compiling this module.
- compiled_paths: A dictionary mapping compiled method names to file paths
- corresponding to their serialized representations.
- """
- self.module_name = module_name
- self.backend_info = backend_info
- self.compiled_paths = compiled_paths
-
- def reinitialize(self):
- """Reinitializes all stateful variables."""
- raise NotImplementedError()
-
- @classmethod
- def create_from_class(cls,
- module_class: Type[tf.Module],
- backend_info: "BackendInfo",
- exported_names: Sequence[str] = (),
- artifacts_dir: str = None):
- """Compile a tf.Module subclass to the target backend in backend_info.
-
- Args:
- module_class: The tf.Module subclass to compile.
- backend_info: BackendInfo with the details for compiling this module.
- exported_names: Optional sequence representing the exported names to keep.
- artifacts_dir: An optional string pointing to where compilation artifacts
- should be saved. No compilation artifacts will be saved if this is not
- provided.
- """
- raise NotImplementedError()
-
- @classmethod
- def create_from_instance(cls,
- module_instance: tf.Module,
- backend_info: "BackendInfo",
- exported_names: Sequence[str] = (),
- artifacts_dir: str = None):
- """Compile a tf.Module instance to the target backend in backend_info.
-
- This is only implemented for IreeCompiledModule.
-
- Args:
- module_instance: The tf.Module instance to compile.
- backend_info: BackendInfo with the details for compiling module to IREE.
- exported_names: Optional sequence representing the exported names to keep.
- artifacts_dir: An optional string pointing to where compilation artifacts
- should be saved. No compilation artifacts will be saved if this is not
- provided.
- """
- raise NotImplementedError()
-
- @classmethod
- def create_from_signature_def_saved_model(cls,
- saved_model_dir: str,
- saved_model_tags: Set[str],
- module_name: str,
- backend_info: "BackendInfo",
- exported_name: str,
- input_names: Sequence[str],
- output_names: Sequence[str],
- artifacts_dir: str = None):
- """Compile a SignatureDef SavedModel to the target backend in backend_info.
-
- Args:
- saved_model_dir: Directory of the saved model.
- saved_model_tags: Optional set of tags to use when loading the model.
- module_name: A name for this compiled module.
- backend_info: BackendInfo with the details for compiling the saved model.
- exported_name: A str representing the signature on the saved model to
- compile.
- input_names: A sequence of kwargs to feed to the saved model.
- output_names: A sequence of named outputs to extract from the saved model.
- artifacts_dir: An optional string pointing to where compilation artifacts
- should be saved. No compilation artifacts will be saved if this is not
- provided.
- """
- raise NotImplementedError()
-
- def __getattr__(self, attr: str) -> _FunctionWrapper:
- raise NotImplementedError()
-
- def iree_serializable(self):
- return False
-
- def tflite_serializable(self):
- return False
-
-
-class _IreeFunctionWrapper(_FunctionWrapper):
- """Wraps an IREE function, making it callable."""
-
- def __init__(self, context: rt.SystemContext, f: rt.system_api.BoundFunction):
- self._context = context
- self._f = f
-
- def __call__(self, *args, **kwargs):
- return self._f(*args, **kwargs)
-
- def get_serialized_values(self) -> Tuple[Tuple[str], Tuple[str]]:
- """Get cxx serialized inputs and outputs for this function."""
- return self._f.get_serialized_values()
-
-
-class IreeCompiledModule(CompiledModule):
- """Iree compiled module."""
-
- def __init__(
- self,
- module_name: str,
- backend_info: "BackendInfo",
- compiled_paths: Dict[str, str],
- vm_module: rt.VmModule,
- config: rt.Config,
- ):
- """Base constructor – Use one of the named constructors instead.
-
- Args:
- module_name: A name for this compiled module. In most cases this will be
- the name of the tf.Module subclass or instance that is compiled.
- backend_info: BackendInfo with the details about compiling this module.
- compiled_paths: A dictionary mapping compiled method names to file paths
- corresponding to their serialized representations.
- vm_module: A rt.VmModule containing compilation info to wrap.
- config: A rt.Config containing compilation info to wrap.
- """
- super().__init__(module_name, backend_info, compiled_paths)
- self._vm_module = vm_module
- self._config = config
- self.reinitialize()
-
- @classmethod
- def create_from_class(cls,
- module_class: Type[tf.Module],
- backend_info: "BackendInfo",
- exported_names: Sequence[str] = (),
- artifacts_dir: str = None):
- """Compile a tf.Module subclass to the target backend in backend_info.
-
- Args:
- module_class: The tf.Module subclass to compile.
- backend_info: BackendInfo with the details for compiling module to IREE.
- exported_names: Optional sequence representing the exported names to keep.
- artifacts_dir: An optional string pointing to where compilation artifacts
- should be saved. No compilation artifacts will be saved if this is not
- provided.
- """
- set_random_seed()
- module_instance = module_class()
- return cls.create_from_instance(module_instance, backend_info,
- exported_names, artifacts_dir)
-
- @classmethod
- def create_from_instance(cls,
- module_instance: tf.Module,
- backend_info: "BackendInfo",
- exported_names: Sequence[str] = (),
- artifacts_dir: str = None):
- """Compile a tf.Module instance to the target backend in backend_info.
-
- Args:
- module_instance: The tf.Module instance to compile.
- backend_info: BackendInfo with the details for compiling module to IREE.
- exported_names: Optional sequence representing the exported names to keep.
- artifacts_dir: An optional string pointing to where compilation artifacts
- should be saved. No compilation artifacts will be saved if this is not
- provided.
- """
- module_blob, compiled_path = _incrementally_compile_tf_module(
- module=module_instance,
- backend_info=backend_info,
- exported_names=exported_names,
- artifacts_dir=artifacts_dir)
- vm_module = rt.VmModule.from_flatbuffer(module_blob)
- config = rt.Config(driver_name=backend_info.driver)
-
- compiled_paths = None
- if compiled_path is not None:
- # IREE bundles every compiled method into the same compiled module.
- compiled_paths = collections.defaultdict(lambda: compiled_path)
-
- module_name = type(module_instance).__name__
-
- return cls(module_name, backend_info, compiled_paths, vm_module, config)
-
- @classmethod
- def create_from_signature_def_saved_model(cls,
- saved_model_dir: str,
- saved_model_tags: Set[str],
- module_name: str,
- backend_info: "BackendInfo",
- exported_name: str,
- input_names: Sequence[str],
- output_names: Sequence[str],
- artifacts_dir: str = None):
- """Compile a SignatureDef SavedModel to the target backend in backend_info.
-
- Args:
- saved_model_dir: Directory of the saved model.
- saved_model_tags: Optional set of tags to use when loading the model.
- module_name: A name for this compiled module.
- backend_info: BackendInfo with the details for compiling the saved model.
- exported_name: A str representing the signature on the saved model to
- compile.
- input_names: A sequence of kwargs to feed to the saved model.
- output_names: A sequence of named outputs to extract from the saved model.
- artifacts_dir: An optional string pointing to where compilation artifacts
- should be saved. No compilation artifacts will be saved if this is not
- provided.
- """
- del input_names # Unused.
- del output_names # Unused.
- module_blob, compiled_path = _incrementally_compile_tf_signature_def_saved_model(
- saved_model_dir, saved_model_tags, backend_info, exported_name,
- artifacts_dir)
- vm_module = rt.VmModule.from_flatbuffer(module_blob)
- config = rt.Config(driver_name=backend_info.driver)
-
- compiled_paths = None
- if compiled_path is not None:
- # IREE bundles every compiled method into the same compiled module :)
- compiled_paths = collections.defaultdict(lambda: compiled_path)
-
- return cls(module_name, backend_info, compiled_paths, vm_module, config)
-
- def reinitialize(self):
- """Reinitializes all stateful variables."""
- # set_random_seed is not needed here because the model_class.__init__ is not
- # called.
- self._context = rt.SystemContext(modules=[self._vm_module],
- config=self._config)
-
- def __getattr__(self, attr: str) -> _IreeFunctionWrapper:
- # Try to resolve it as a function.
- m = self._context.modules[self._vm_module.name]
- f = m[attr]
- return _IreeFunctionWrapper(self._context, f)
-
- def iree_serializable(self) -> bool:
- return self.compiled_paths is not None
-
-
-def _normalize_numpy(result: np.ndarray):
- """Normalizes TF and TFLite's outputs to match IREE's"""
- if np.isscalar(result):
- result = np.array(result)
- if result.dtype == np.bool:
- # IREE interprets bools as int8s, so we modify this for comparison.
- result = result.astype(dtype=np.int8)
- return result
-
-
-class _TfFunctionWrapper(_FunctionWrapper):
- """Wraps a TF function, normalizing it to numpy."""
-
- def __init__(self, f: Callable[..., Any]):
- self._f = f
-
- def _convert_to_numpy(self, tensor: Any) -> Any:
- if not isinstance(tensor, tf.Tensor):
- return tensor
- return _normalize_numpy(tensor.numpy())
-
- def __call__(self, *args, **kwargs):
- # TensorFlow will auto-convert all inbound args.
- results = self._f(*args, **kwargs)
- # Then unmarshal them to numpy in the same way that the other backends do.
- # Handle single result (technically ambiguous with return of a tuple,
- # which is sad).
- if not isinstance(results, tuple):
- results = (results,)
- return tf.nest.map_structure(self._convert_to_numpy,
- *results,
- check_types=False)
-
-
-def _convert_inputs_to_tensors(function):
-
- def decorator(*args, **kwargs):
- args = [tf.convert_to_tensor(arg) for arg in args]
- kwargs = {k: tf.convert_to_tensor(v) for k, v in kwargs.items()}
- return function(*args, **kwargs)
-
- return decorator
-
-
-class SignatureDefSavedModelWrapper(object):
- """Wraps a SavedModel to imitate a tf.Module with a method 'exported_name'."""
-
- def __init__(self, saved_model_dir: str, saved_model_tags: Set[str],
- exported_name: str):
- self.saved_model = tf.saved_model.load(saved_model_dir,
- tags=saved_model_tags)
- inference_func = self.saved_model.signatures[exported_name]
- inference_func = _convert_inputs_to_tensors(inference_func)
- self.__setattr__(exported_name, inference_func)
-
-
-class TfCompiledModule(CompiledModule):
- """TensorFlow 'compiled' module.
-
- This facade exists to provide a complimentary API to IreeCompiledModule and
- normalize TensorFlow's output to Numpy.
- """
-
- def __init__(
- self,
- module_name: str,
- backend_info: "BackendInfo",
- constructor: Callable[[], tf.Module],
- exported_names: Sequence[str],
- ):
- """Base constructor – Use one of the named constructors instead.
-
- Args:
- module_name: A name for this compiled module. In most cases this will be
- the name of the tf.Module subclass or instance that is compiled.
- backend_info: BackendInfo with the details about compiling this module.
- constructor: A callable (class or function) which returns the tf.Module
- subclass instance to wrap.
- exported_names: an optional iterable of strings representing which of the
- tf.Module subclass instance's functions should be callable. If
- exported_names is empty then all functions will be callable.
- """
- super().__init__(module_name, backend_info, compiled_paths=None)
- self._constructor = constructor
- self._exported_names = exported_names
- self.reinitialize()
-
- @classmethod
- def create_from_class(cls,
- module_class: Type[tf.Module],
- backend_info: "BackendInfo",
- exported_names: Sequence[str] = (),
- artifacts_dir: str = None):
- """Compile a tf.Module subclass to the target backend in backend_info.
-
- Args:
- module_class: The tf.Module subclass to compile.
- backend_info: BackendInfo with the details for compiling this module.
- exported_names: Optional sequence representing the exported names to keep.
- artifacts_dir: An optional string pointing to where compilation artifacts
- should be saved. No compilation artifacts will be saved if this is not
- provided.
- """
- module_name = module_class.__name__
- constructor = module_class
- return cls(module_name, backend_info, constructor, exported_names)
-
- @classmethod
- def create_from_signature_def_saved_model(cls,
- saved_model_dir: str,
- saved_model_tags: Set[str],
- module_name: str,
- backend_info: "BackendInfo",
- exported_name: str,
- input_names: Sequence[str],
- output_names: Sequence[str],
- artifacts_dir: str = None):
- """Compile a SignatureDef SavedModel to the target backend in backend_info.
-
- Args:
- saved_model_dir: Directory of the saved model.
- saved_model_tags: Optional set of tags to use when loading the model.
- module_name: A name for this compiled module.
- backend_info: BackendInfo with the details for compiling the saved model.
- exported_name: A str representing the signature on the saved model to
- compile.
- input_names: A sequence of kwargs to feed to the saved model.
- output_names: A sequence of named outputs to extract from the saved model.
- artifacts_dir: An optional string pointing to where compilation artifacts
- should be saved. No compilation artifacts will be saved if this is not
- provided.
- """
- constructor = lambda: SignatureDefSavedModelWrapper(
- saved_model_dir, saved_model_tags, exported_name)
- return cls(module_name, backend_info, constructor, [exported_name])
-
- def reinitialize(self):
- """Reinitializes all stateful variables."""
- set_random_seed()
- self._tf_module = self._constructor()
-
- def __getattr__(self, attr: str) -> _TfFunctionWrapper:
- # Try to resolve it as a function.
- exported = not self._exported_names or attr in self._exported_names
- if not hasattr(self._tf_module, attr) or not exported:
- raise AttributeError(f"The TensorFlow module does not have attr '{attr}'")
- f = getattr(self._tf_module, attr)
- if not f or not hasattr(f, "__call__"):
- raise AttributeError(
- f"The TensorFlow module does not have a callable attr '{attr}'")
- return _TfFunctionWrapper(f)
-
-
-def _get_non_inhereted_function_names(cls):
- """Gets all methods that cls has that its parents don't have."""
- names = set(dir(cls))
- for parent in cls.__bases__:
- names -= set(dir(parent))
- return list(names)
-
-
-def _get_concrete_functions(module_class: Type[tf.Module],
- exported_names: Sequence[str] = ()):
- """Get concrete functions from non-inherited methods or exported_names."""
- if not len(exported_names):
- # Get all method names on 'module_class' that aren't on 'tf.Module'.
- exported_names = _get_non_inhereted_function_names(module_class)
- instance = module_class()
- functions = []
- for name in exported_names:
- functions.append(getattr(instance, name).get_concrete_function())
- return functions, exported_names
-
-
-def tf_module_to_tflite_module_bytes(
- module_class: Type[tf.Module], exported_names: Sequence[str] = ()
-) -> Dict[str, bytes]:
- """Compiles a tf.Module's methods with TFLite.
-
- Args:
- module_class: A tf.Module subclass to compile with TFLite.
- exported_names: an optional iterable of strings representing which of the
- module_class's functions should be compiled. If exported_names is empty
- then all functions will be compiled.
-
- Returns:
- A dict mapping method names to compiled TFLite module bytes.
- """
- tflite_modules = []
- methods, method_names = _get_concrete_functions(module_class, exported_names)
- for method in methods:
- converter = tf.lite.TFLiteConverter.from_concrete_functions([method])
- tflite_modules.append(converter.convert())
- return dict(zip(method_names, tflite_modules))
-
-
-def tf_signature_def_saved_model_to_tflite_module_bytes(
- saved_model_dir: str,
- saved_model_tags: Set[str],
- exported_name: str,
- input_names: Sequence[str],
- output_names: Sequence[str],
-) -> Dict[str, bytes]:
- """Compiles a SignatureDef SavedModel signature with TFLite.
-
- Args:
- saved_model_dir: Directory of the saved model.
- saved_model_tags: Optional set of tags to use when loading the model.
- exported_name: A str representing the signature on the saved model to
- compile.
- input_names: A sequence of kwargs to feed to the saved model.
- output_names: A sequence of named outputs to extract from the saved model.
-
- Returns:
- A dict mapping the signature name to the compiled TFLite module bytes.
- """
- converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model(
- saved_model_dir,
- tag_set=saved_model_tags,
- signature_key=exported_name,
- input_arrays=input_names,
- output_arrays=output_names)
- tflite_module = converter.convert()
- return dict([[exported_name, tflite_module]])
-
-
-def tflite_module_bytes_to_tflite_interpreters(
- tflite_module_bytes: Dict[str, bytes],
- artifacts_dir: str = None
-) -> Tuple[Dict[str, tf.lite.Interpreter], Union[Dict[str, str], None]]:
- """Compile a dict of TFLite compiled bytes to TFLite interpreters.
-
- Args:
- tflite_module_bytes: A dict mapping method names to compiled TFLite byte
- strings.
- artifacts_dir: an optional path to save compilation artifacts to.
-
- Returns:
- A dictionary mapping method names to TFLite interpreters and a dictionary
- mapping method names to compiled tflite graph paths (or None if
- artifacts_dir is None).
- """
- interpreters = dict()
- compiled_paths = None
- if artifacts_dir is not None:
- compiled_paths = dict()
-
- def _interpret_bytes(method_name: str, tflite_module: bytes, base_dir: str):
- """Save compiled TFLite module bytes and convert into an interpreter."""
- tflite_dir = os.path.join(base_dir, "tflite")
- os.makedirs(tflite_dir, exist_ok=True)
- tflite_path = os.path.join(tflite_dir, f"{method_name}.tflite")
- with open(tflite_path, "wb") as f:
- f.write(tflite_module)
-
- interpreters[method_name] = tf.lite.Interpreter(tflite_path)
- if artifacts_dir is not None:
- compiled_paths[method_name] = tflite_path
-
- # Load each of the converted methods above into tf.lite.Interpreters.
- for method_name, tflite_module in tflite_module_bytes.items():
- if artifacts_dir is None:
- with tempfile.TemporaryDirectory() as base_dir:
- _interpret_bytes(method_name, tflite_module, base_dir)
- else:
- _interpret_bytes(method_name, tflite_module, artifacts_dir)
-
- return interpreters, compiled_paths
-
-
-class _TfLiteFunctionWrapper(_FunctionWrapper):
- """Wraps a TFLite interpreter and makes it behave like a python function."""
-
- def __init__(self, interpreter: tf.lite.Interpreter,
- output_names: Sequence[str]):
- self._interpreter = interpreter
- self._output_names = output_names
-
- def __call__(self, *args,
- **kwargs) -> Union[Dict[str, Any], Tuple[Any], np.ndarray]:
- if len(args) and len(kwargs):
- raise ValueError("Passing both args and kwargs is not supported by "
- "_TfLiteFunctionWrapper")
-
- # Set up and run the function.
- self._interpreter.allocate_tensors()
-
- if len(args):
- # Specifically to get TFLite to work with keras models that take a list of
- # inputs instead of a sequence of args as their inputs, because it decides
- # to change the input signature but it still technically works if you
- # ignore that it does that.
- if len(args) == 1 and isinstance(args[0], list):
- args = args[0]
-
- for arg, detail in zip(args, self._interpreter.get_input_details()):
- self._interpreter.set_tensor(detail["index"], arg)
- else:
- for detail in self._interpreter.get_input_details():
- self._interpreter.set_tensor(detail["index"], kwargs[detail["name"]])
-
- self._interpreter.invoke()
-
- # Extract the outputs from the TFLite interpreter.
- outputs = []
- for detail in self._interpreter.get_output_details():
- value = _normalize_numpy(self._interpreter.get_tensor(detail["index"]))
- if self._output_names is not None:
- name = detail["name"]
- if name not in self._output_names:
- raise ValueError(f"Expected '{name}' to be in {self._output_names}")
- outputs.append([detail["name"], value])
+ if np.issubdtype(ref.dtype, np.floating):
+ same = np.allclose(ref, tar, rtol=rtol, atol=atol, equal_nan=True)
+ abs_diff = np.max(np.abs(ref - tar))
+ rel_diff = np.max(np.abs(ref - tar) / np.max(np.abs(tar)))
+ diff_string = (f"Max abs diff: {abs_diff:.2e}, atol: {atol:.2e}, "
+ f"max relative diff: {rel_diff:.2e}, rtol: {rtol:.2e}")
+ if not same:
+ error = ("Floating point difference between ref and tar was too "
+ f"large. {diff_string}")
+ logging.error(error)
else:
- outputs.append(value)
-
- # Process them to match the output of the tf.Module.
- if self._output_names is not None:
- return dict(outputs)
+ error = None
+ logging.info(
+ "Floating point difference between ref and tar was within "
+ "tolerance. %s", diff_string)
+ return same, error
+ elif np.issubdtype(ref.dtype, np.integer):
+ same = np.array_equal(ref, tar)
+ if not same:
+ abs_diff = np.max(np.abs(ref - tar))
+ error = ("Expected array equality between ref and tar, but got "
+ f"a max elementwise difference of {abs_diff}")
+ logging.error(error)
+ else:
+ error = None
+ return same, error
else:
- if len(outputs) == 1:
- return outputs[0]
- return tuple(outputs)
+ return np.array_equal(ref, tar), None
+ # Base check for native number types.
+ elif isinstance(ref, (int, float)):
+ return ref == tar, None
-class TfLiteCompiledModule(CompiledModule):
- """Compiles a tf.Module with TFLite and allows it to be called."""
-
- def __init__(
- self,
- module_name: str,
- backend_info: "BackendInfo",
- compiled_paths: Dict[str, str],
- interpreters: Dict[str, tf.lite.Interpreter],
- output_names: Sequence[str] = None,
- ):
- """Base constructor – Use one of the named constructors instead.
-
- Args:
- module_name: A name for this compiled module. In most cases this will be
- the name of the tf.Module subclass or instance that is compiled.
- backend_info: BackendInfo with the details about compiling this module.
- compiled_paths: A dictionary mapping compiled method names to file paths
- corresponding to their serialized representations.
- interpreters: A dict of tf.lite.Interpreters to make callable.
- """
- super().__init__(module_name, backend_info, compiled_paths)
- self._interpreters = interpreters
- self._output_names = output_names
-
- @classmethod
- def create_from_class(cls,
- module_class: Type[tf.Module],
- backend_info: "BackendInfo",
- exported_names: Sequence[str] = (),
- artifacts_dir: str = None):
- """Compile a tf.Module subclass to the target backend in backend_info.
-
- Args:
- module_class: The tf.Module subclass to compile.
- backend_info: BackendInfo with the details for compiling this module.
- exported_names: Optional sequence representing the exported names to keep.
- artifacts_dir: An optional string pointing to where compilation artifacts
- should be saved. No compilation artifacts will be saved if this is not
- provided.
- """
- set_random_seed()
- tflite_module_bytes = tf_module_to_tflite_module_bytes(
- module_class, exported_names)
- interpreters, compiled_paths = tflite_module_bytes_to_tflite_interpreters(
- tflite_module_bytes, artifacts_dir)
- module_name = module_class.__name__
- return cls(module_name, backend_info, compiled_paths, interpreters)
-
- @classmethod
- def create_from_signature_def_saved_model(cls,
- saved_model_dir: str,
- saved_model_tags: Set[str],
- module_name: str,
- backend_info: "BackendInfo",
- exported_name: str,
- input_names: Sequence[str],
- output_names: Sequence[str],
- artifacts_dir: str = None):
- """Compile a SignatureDef SavedModel to the target backend in backend_info.
-
- Args:
- saved_model_dir: Directory of the saved model.
- saved_model_tags: Optional set of tags to use when loading the model.
- module_name: A name for this compiled module.
- backend_info: BackendInfo with the details for compiling the saved model.
- exported_name: A str representing the signature on the saved model to
- compile.
- input_names: A sequence of kwargs to feed to the saved model.
- output_names: A sequence of named outputs to extract from the saved model.
- artifacts_dir: An optional string pointing to where compilation artifacts
- should be saved. No compilation artifacts will be saved if this is not
- provided.
- """
- tflite_module_bytes = tf_signature_def_saved_model_to_tflite_module_bytes(
- saved_model_dir, saved_model_tags, exported_name, input_names,
- output_names)
- interpreters, compiled_paths = tflite_module_bytes_to_tflite_interpreters(
- tflite_module_bytes, artifacts_dir)
- return cls(module_name, backend_info, compiled_paths, interpreters,
- output_names)
-
- def reinitialize(self):
- """Reinitializes all stateful variables."""
- # This is a noop because TFLite (mostly) doesn't support stateful modules.
- pass
-
- def __getattr__(self, attr: str) -> _TfLiteFunctionWrapper:
- # Try to resolve it as an interpreter.
- if not attr in self._interpreters:
- raise AttributeError(
- f"The TFLite module does not have an interpreter for '{attr}'")
- return _TfLiteFunctionWrapper(self._interpreters[attr], self._output_names)
-
- def tflite_serializable(self) -> bool:
- return self.compiled_paths is not None
-
-
-class BackendInfo:
- """Contains information for compiling the specified backend."""
-
- _name_to_info = {
- "tf": {
- "compiled_module_class": TfCompiledModule,
- "driver": None,
- "compiler_targets": None,
- },
- "tflite": {
- "compiled_module_class": TfLiteCompiledModule,
- "driver": None,
- "compiler_targets": None,
- },
- "iree_vmla": {
- "compiled_module_class": IreeCompiledModule,
- "driver": "vmla",
- "compiler_targets": ["vmla"]
- },
- "iree_llvmjit": {
- "compiled_module_class": IreeCompiledModule,
- "driver": "llvm",
- "compiler_targets": ["llvm-ir"]
- },
- "iree_vulkan": {
- "compiled_module_class": IreeCompiledModule,
- "driver": "vulkan",
- "compiler_targets": ["vulkan-*"]
- },
- }
-
- def __init__(self, backend_name: str, backend_id: str = None):
- """Creates a BackendInfo with the compilation details for backend_name.
-
- Args:
- backend_name: a str specifying which backend to use. Should be one of
- 'tf', 'iree_vmla', 'iree_llvmjit', 'iree_vulkan'.
- backend_id: an optional str specifying what name to use when saving
- compiled artifacts. Must satisfy `backend_id.startswith(backend_name)`.
-
- Raises:
- KeyError: if backend_name is not one of ['tf', 'iree_vmla',
- 'iree_llvmjit', 'iree_vulkan'].
- ValueError: if backend_id doesn't start with backend_name.
- """
- if backend_name not in self._name_to_info:
- raise KeyError(
- "Expected backend_name to be one of "
- f"{list(self._name_to_info.keys())} but got '{backend_name}'.")
- if backend_id is not None and not backend_id.startswith(backend_name):
- raise ValueError(f"Expected backend_id to start with '{backend_name}' "
- f"but got '{backend_id}'.")
-
- self.backend_name = backend_name
- self.backend_id = backend_name if backend_id is None else backend_id
-
- info = self._name_to_info[backend_name]
- self._compiled_module_class = info["compiled_module_class"]
- self.driver = info["driver"]
- self.compiler_targets = info["compiler_targets"]
-
- def compile_from_class(self,
- module_class: Type[tf.Module],
- exported_names: Sequence[str] = (),
- artifacts_dir: str = None) -> CompiledModule:
- """Creates a 'CompiledModule' for this backend."""
- return self._compiled_module_class.create_from_class(
- module_class, self, exported_names, artifacts_dir)
-
- def compile_signature_def_saved_model(
- self,
- saved_model_dir: str,
- saved_model_tags: Set[str],
- module_name: str,
- exported_name: str,
- input_names: Sequence[str],
- output_names: Sequence[str],
- artifacts_dir: str = None) -> CompiledModule:
- return self._compiled_module_class.create_from_signature_def_saved_model(
- saved_model_dir, saved_model_tags, module_name, self, exported_name,
- input_names, output_names, artifacts_dir)
-
- @classmethod
- def get_all_backends(cls) -> Sequence["BackendInfo"]:
- """Returns a list of all BackendInfo configurations."""
- return [BackendInfo(backend_name) for backend_name in cls._name_to_info]
+ # If outputs end up here then an extra branch for that type should be added.
+ else:
+ raise TypeError(f"Encountered results with unexpected type {type(ref)}")
+ return True, None
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
index 9934460..deef8d9 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
@@ -14,125 +14,87 @@
# limitations under the License.
"""Tests for pyiree.tf.support.tf_utils."""
-import os
-import tempfile
-
-from absl import logging
from absl.testing import parameterized
import numpy as np
from pyiree.tf.support import tf_utils
import tensorflow as tf
-class ConstantModule(tf.Module):
-
- @tf.function(input_signature=[])
- def meaning(self):
- return tf.constant([42.])
-
-
-class StatefulCountingModule(tf.Module):
-
- def __init__(self):
- self.count = tf.Variable([0.])
-
- @tf.function(input_signature=[])
- def increment(self):
- self.count.assign_add(tf.constant([1.]))
-
- @tf.function(input_signature=[])
- def get_count(self):
- return self.count
-
-
-class RandomInitModule(tf.Module):
-
- def __init__(self):
- self.value = tf.Variable(tf.random.uniform([1]))
-
- @tf.function(input_signature=[])
- def get(self):
- return self.value
-
-
class UtilsTests(tf.test.TestCase, parameterized.TestCase):
- def test_artifact_saving(self):
- backend_info = tf_utils.BackendInfo('iree_vmla')
- with tempfile.TemporaryDirectory() as artifacts_dir:
- tf_module = ConstantModule()
- iree_compiled_module, compiled_path = (
- tf_utils._incrementally_compile_tf_module(
- tf_module, backend_info=backend_info,
- artifacts_dir=artifacts_dir))
+ @parameterized.named_parameters([('int8_to_i8', np.int8, 'i8'),
+ ('int32_to_i32', np.int32, 'i32'),
+ ('float32_to_f32', np.float32, 'f32'),
+ ('float64_to_f64', np.float64, 'f64')])
+ def test_to_mlir_type(self, numpy_type, mlir_type):
+ self.assertEqual(tf_utils.to_mlir_type(numpy_type), mlir_type)
- artifacts_to_check = [
- 'tf_input.mlir',
- 'iree_input.mlir',
- compiled_path,
- ]
- for artifact in artifacts_to_check:
- artifact_path = os.path.join(artifacts_dir, artifact)
- logging.info('Checking path: %s', artifact_path)
- self.assertTrue(os.path.exists(artifact_path))
+ @parameterized.named_parameters([
+ ('single_i32', [np.array([1, 2], dtype=np.int32)], '2xi32=1 2'),
+ ('single_f32', [np.array([1, 2], dtype=np.float32)], '2xf32=1.0 2.0'),
+ ])
+ def test_save_input_values(self, inputs, inputs_str):
+ self.assertEqual(tf_utils.save_input_values(inputs), inputs_str)
+
+ def test_apply_function(self):
+ inputs = [1, [2, 3], (4, 5), {'6': 6, '78': [7, 8]}]
+ expected = [0, [1, 2], (3, 4), {'6': 5, '78': [6, 7]}]
+ result = tf_utils.apply_function(inputs, lambda x: x - 1)
+ self.assertEqual(result, expected)
+ self.assertNotEqual(inputs, expected)
@parameterized.named_parameters([
{
- 'testcase_name': 'tensorflow',
- 'backend_name': 'tf',
+ 'testcase_name': 'all the same',
+ 'array_c': np.array([0, 1, 2]),
+ 'array_d': np.array(['0', '1', '2']),
+ 'array_e': np.array([0.0, 0.1, 0.2]),
+ 'tar_same': True,
},
{
- 'testcase_name': 'vmla',
- 'backend_name': 'iree_vmla',
+ 'testcase_name': 'wrong int',
+ 'array_c': np.array([1, 1, 2]),
+ 'array_d': np.array(['0', '1', '2']),
+ 'array_e': np.array([0.0, 0.1, 0.2]),
+ 'tar_same': False,
+ },
+ {
+ 'testcase_name': 'wrong string',
+ 'array_c': np.array([0, 1, 2]),
+ 'array_d': np.array(['a', '1', '2']),
+ 'array_e': np.array([0.0, 0.1, 0.2]),
+ 'tar_same': False,
+ },
+ {
+ 'testcase_name': 'wrong float',
+ 'array_c': np.array([0, 1, 2]),
+ 'array_d': np.array(['0', '1', '2']),
+ 'array_e': np.array([1.0, 0.1, 0.2]),
+ 'tar_same': False,
},
])
- def test_unaltered_state(self, backend_name):
- backend_info = tf_utils.BackendInfo(backend_name)
- module = backend_info.compile_from_class(StatefulCountingModule)
+ def test_recursive_check_same(self, array_c, array_d, array_e, tar_same):
- # Test that incrementing works properly.
- self.assertEqual([0.], module.get_count())
- module.increment()
- self.assertEqual([1.], module.get_count())
-
- module.reinitialize()
- # Test reinitialization.
- self.assertEqual([0.], module.get_count())
-
- def test_to_mlir_type(self):
- self.assertEqual('i8', tf_utils.to_mlir_type(np.dtype('int8')))
- self.assertEqual('i32', tf_utils.to_mlir_type(np.dtype('int32')))
- self.assertEqual('f32', tf_utils.to_mlir_type(np.dtype('float32')))
- self.assertEqual('f64', tf_utils.to_mlir_type(np.dtype('float64')))
-
- def test_save_input_values(self):
- inputs = [np.array([1, 2], dtype=np.int32)]
- self.assertEqual('2xi32=1 2', tf_utils.save_input_values(inputs))
- inputs = [np.array([1, 2], dtype=np.float32)]
- self.assertEqual('2xf32=1.0 2.0', tf_utils.save_input_values(inputs))
-
- @parameterized.named_parameters([
- {
- 'testcase_name': 'tensorflow',
- 'backend_name': 'tf',
- },
- {
- 'testcase_name': 'vmla',
- 'backend_name': 'iree_vmla',
- },
- ])
- def test_random_initialization(self, backend_name):
- backend_info = tf_utils.BackendInfo(backend_name)
-
- # Test compilation is the same.
- module_1 = backend_info.compile_from_class(RandomInitModule)
- module_2 = backend_info.compile_from_class(RandomInitModule)
- self.assertAllEqual(module_1.get(), module_2.get())
-
- # Test reinitialization is the same.
- old_value = module_1.get()
- module_1.reinitialize()
- self.assertAllEqual(old_value, module_1.get())
+ # yapf: disable
+ ref = {
+ 'a': 1,
+ 'b': [
+ {'c': np.array([0, 1, 2])},
+ {'d': np.array(['0', '1', '2'])},
+ {'e': np.array([0.0, 0.1, 0.2])}
+ ],
+ }
+ tar = {
+ 'a': 1,
+ 'b': [
+ {'c': array_c},
+ {'d': array_d},
+ {'e': array_e}
+ ],
+ }
+ # yapf: enable
+ same, _ = tf_utils.check_same(ref, tar, rtol=1e-6, atol=1e-6)
+ self.assertEqual(tar_same, same)
if __name__ == '__main__':
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/trace_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/trace_utils.py
new file mode 100644
index 0000000..1a0c789
--- /dev/null
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/trace_utils.py
@@ -0,0 +1,421 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utilities for tracing tf.function inputs and outputs."""
+
+# This file uses the following abbreviations:
+# ref: reference - for the reference CompiledModule
+# tar: target - for one of the target CompiledModules
+
+import copy
+import glob
+import inspect
+import os
+import pickle
+import sys
+import textwrap
+from typing import Any, Callable, Dict, Sequence, Tuple, Union
+
+from absl import logging
+import numpy as np
+from pyiree.tf.support import module_utils
+from pyiree.tf.support import tf_utils
+import tensorflow.compat.v2 as tf
+
+NUMPY_LINEWIDTH = 120
+INDENT = " " * 2
+
+
+def _zfill_width(length: int) -> Union[int, None]:
+ return int(np.ceil(np.log10(length))) if length else None
+
+
+def get_trace_dir(artifacts_dir: str, trace: "Trace") -> str:
+ trace_dir = os.path.join(artifacts_dir, trace.backend_id, "traces",
+ trace.function_name)
+ os.makedirs(trace_dir, exist_ok=True)
+ return trace_dir
+
+
+class ModuleCall:
+
+ def __init__(self,
+ method: str,
+ inputs: Tuple[Any],
+ outputs: Tuple[Any],
+ serialized_inputs: Tuple[str],
+ serialized_outputs: Tuple[str],
+ rtol: float = 1e-6,
+ atol: float = 1e-6):
+ """Records the details of a call to a CompiledModule."""
+ self.method = method
+
+ # Deepcopy to safeguard against mutation.
+ self.inputs = copy.deepcopy(inputs)
+ if outputs is not None:
+ outputs = copy.deepcopy(outputs)
+ else:
+ outputs = tuple()
+ self.outputs = outputs if isinstance(outputs, tuple) else (outputs,)
+
+ self.serialized_inputs = serialized_inputs
+ self.serialized_outputs = serialized_outputs
+
+ self.rtol = rtol
+ self.atol = atol
+
+ def get_tolerances(self) -> Tuple[float, float]:
+ """Gets the floating point tolerances associated with this call."""
+ return self.rtol, self.atol
+
+ def _get_shape_and_dtype(self, value: Any) -> str:
+ if isinstance(value, np.ndarray):
+ return tf_utils.get_shape_and_dtype(value, allow_non_mlir_dtype=True)
+ else:
+ return str(type(value))
+
+ def __str__(self):
+ prior_printoptions = np.get_printoptions()
+ np.set_printoptions(linewidth=NUMPY_LINEWIDTH)
+
+ header = f"Method: {self.method}"
+ inputs = "\n".join(
+ [textwrap.indent(str(value), INDENT) for value in self.inputs])
+ input_shapes = ", ".join(
+ self._get_shape_and_dtype(value) for value in self.inputs)
+
+ outputs = "\n".join(
+ [textwrap.indent(str(value), INDENT) for value in self.outputs])
+ output_shapes = ", ".join(
+ self._get_shape_and_dtype(value) for value in self.outputs)
+
+ tolerances = textwrap.indent(f"rtol={self.rtol}, atol={self.atol}", INDENT)
+ body = (f"Inputs: {input_shapes}\n{inputs}\n"
+ f"Outputs: {output_shapes}\n{outputs}"
+ f"\nTolerances:\n{tolerances}")
+ result = f"{header}\n{textwrap.indent(body, INDENT)}"
+
+ np.set_printoptions(**prior_printoptions)
+ return result
+
+ def serialize(self, call_dir: str) -> None:
+ """Stores a serialized copy of this call.
+
+ Can be loaded via ModuleCall.load(call_dir)
+
+ Args:
+ call_dir: str, the path to the directory to serialize this call to.
+ """
+ os.makedirs(call_dir, exist_ok=True)
+
+ metadata = {
+ "method": self.method,
+ "serialized_inputs": self.serialized_inputs,
+ "serialized_outputs": self.serialized_outputs,
+ "rtol": self.rtol,
+ "atol": self.atol
+ }
+ with open(os.path.join(call_dir, "metadata.pkl"), "wb") as f:
+ pickle.dump(metadata, f)
+
+ width = _zfill_width(len(self.inputs))
+ for i, value in enumerate(self.inputs):
+ path = os.path.join(call_dir, f"input_{str(i).zfill(width)}.pkl")
+ with open(path, "wb") as f:
+ pickle.dump(value, f)
+
+ width = _zfill_width(len(self.outputs))
+ for i, value in enumerate(self.outputs):
+ path = os.path.join(call_dir, f"output_{str(i).zfill(width)}.pkl")
+ with open(path, "wb") as f:
+ pickle.dump(value, f)
+
+ @staticmethod
+ def load(call_dir: str) -> "ModuleCall":
+ """Loads and returns a trace serialized with ModuleCall.serialize."""
+ with open(os.path.join(call_dir, "metadata.pkl"), "rb") as f:
+ kwargs = pickle.load(f)
+
+ for result_type in ["input", "output"]:
+ key = f"{result_type}s" # inputs or outputs
+ kwargs[key] = []
+
+ files = glob.glob(os.path.join(call_dir, f"{result_type}_*.pkl"))
+ for filename in sorted(files):
+ with open(filename, "rb") as f:
+ kwargs[key].append(pickle.load(f))
+
+ # Convert to tuple to match python's return type for multiple results.
+ kwargs[key] = tuple(kwargs[key])
+
+ return ModuleCall(**kwargs)
+
+
+class Trace:
+ """Stores the inputs and outputs of a series of calls to a module."""
+
+ def __init__(self,
+ module: Union[module_utils.CompiledModule, None],
+ function: Union[Callable[["TracedModule"], None], None],
+ _load_dict: Dict[str, Any] = None):
+ """Extracts metadata from module and function and initializes.
+
+ Example usage:
+ def forward_pass(...):
+ ...
+ module = IreeCompiledModule(...)
+ trace = Trace(module, forward_pass)
+ forward_pass(TracedModule(module, trace))
+
+ Args:
+ module: the module whose outputs this trace will record.
+ function: the function that module will be traced on.
+ _load_dict: used internally
+ """
+ if _load_dict is None:
+ # Extract metadata from module and function.
+ self.module_name = module.module_name
+ self.compiled_paths = module.compiled_paths
+ self.backend_name = module.backend_info.backend_name
+ self.backend_id = module.backend_info.backend_id
+ self.backend_driver = module.backend_info.driver
+ self.iree_serializable = module.iree_serializable()
+ self.tflite_serializable = module.tflite_serializable()
+ self.function_name = function.__name__
+ self.function_sourcefile = inspect.getsourcefile(function)
+ source, start_line = inspect.getsourcelines(function)
+ self.function_line_numbers = (start_line, start_line + len(source))
+ self.function_source = "".join(source)
+
+ self.calls = []
+ else:
+ self.module_name = _load_dict["module_name"]
+ self.compiled_paths = _load_dict["compiled_paths"]
+ self.backend_name = _load_dict["backend_name"]
+ self.backend_id = _load_dict["backend_id"]
+ self.backend_driver = _load_dict["backend_driver"]
+ self.iree_serializable = _load_dict["iree_serializable"]
+ self.tflite_serializable = _load_dict["tflite_serializable"]
+ self.function_name = _load_dict["function_name"]
+ self.function_sourcefile = _load_dict["function_sourcefile"]
+ self.function_line_numbers = _load_dict["function_line_numbers"]
+ self.function_source = _load_dict["function_source"]
+ self.calls = _load_dict["calls"]
+
+ def __str__(self):
+ header = (f"Trace of {self.module_name} compiled to '{self.backend_id}' "
+ f"on function '{self.function_name}':")
+ # Give each call a number so it's easier to compare between multiple traces.
+ calls = [f"{i + 1}. {str(call)}" for i, call in enumerate(self.calls)]
+ calls = textwrap.indent("\n".join(calls), prefix=INDENT)
+ return f"{header}\n{calls}"
+
+ def __iter__(self):
+ for call in self.calls:
+ yield call
+
+ def save_plaintext(self, trace_dir: str, summarize: bool = True) -> None:
+ """Saves a human-readable string representation of this trace to disk.
+
+ Args:
+ trace_dir: str, path to the directory to save the trace in.
+ summarize: a bool controlling whether numpy should summarize the inputs
+ and outputs if they're large. Setting this to False is very slow for
+ large outputs.
+ """
+ prior_printoptions = np.get_printoptions()
+ np.set_printoptions(
+ linewidth=NUMPY_LINEWIDTH,
+ threshold=None if summarize else sys.maxsize,
+ edgeitems=10) # Can show more items since they won't clutter the logs.
+
+ path = os.path.join(trace_dir, "log.txt")
+ with open(path, "w") as f:
+ f.write(str(self))
+ f.write("\n")
+
+ np.set_printoptions(**prior_printoptions)
+
+ def serialize(self, trace_dir: str) -> None:
+ """Stores a serialized copy of this trace in trace_dir.
+
+ It can be loaded via `Trace.load(trace_dir)`.
+
+ Args:
+ trace_dir: str, path to the directory to serialize the trace to.
+ """
+
+ compiled_paths = None
+ if self.compiled_paths is not None:
+ # Convert to a dict to avoid the issues with serializing defaultdicts.
+ compiled_paths = dict(self.compiled_paths)
+
+ # Python serialization.
+ metadata = {
+ "module_name": self.module_name,
+ "compiled_paths": compiled_paths,
+ "backend_name": self.backend_name,
+ "backend_id": self.backend_id,
+ "backend_driver": self.backend_driver,
+ "iree_serializable": self.iree_serializable,
+ "tflite_serializable": self.tflite_serializable,
+ "function_name": self.function_name,
+ "function_sourcefile": self.function_sourcefile,
+ "function_line_numbers": self.function_line_numbers,
+ "function_source": self.function_source
+ }
+ with open(os.path.join(trace_dir, "metadata.pkl"), "wb") as f:
+ pickle.dump(metadata, f)
+
+ width = _zfill_width(len(self.calls))
+ for i, call in enumerate(self.calls):
+ call_dir = os.path.join(trace_dir, f"call_{str(i).zfill(width)}")
+ call.serialize(call_dir)
+
+ # C++ benchmark serialization.
+ if self.iree_serializable or self.tflite_serializable:
+ entry_function = self.calls[0].method
+ compiled_path = self.compiled_paths[entry_function]
+
+ if self.iree_serializable:
+ serialized_inputs = ", ".join(self.calls[0].serialized_inputs)
+ flagfile = [
+ f"--module_file={compiled_path}",
+ f"--driver={self.backend_driver}",
+ f"--function_inputs={serialized_inputs}",
+ f"--entry_function={entry_function}",
+ ]
+ with open(os.path.join(trace_dir, "flagfile"), "w") as f:
+ f.writelines(line + "\n" for line in flagfile)
+ else:
+ with open(os.path.join(trace_dir, "graph_path"), "w") as f:
+ f.writelines(compiled_path + "\n")
+
+ @staticmethod
+ def load(trace_dir: str) -> "Trace":
+ """Loads and returns a trace serialized with Trace.serialize.
+
+ Args:
+ trace_dir: str, path to the directory of the serialized trace.
+
+ Returns:
+ A Trace deserialized from trace_dir.
+ """
+ with open(os.path.join(trace_dir, "metadata.pkl"), "rb") as f:
+ load_dict = pickle.load(f)
+ call_dirs = sorted(glob.glob(os.path.join(trace_dir, "call_*")))
+ calls = [ModuleCall.load(call_dir) for call_dir in call_dirs]
+ load_dict["calls"] = calls
+ return Trace(module=None, function=None, _load_dict=load_dict)
+
+
+class TracedModule:
+
+ def __init__(self, module: module_utils.CompiledModule, trace: Trace):
+ """Wraps a CompiledModule so that all inputs and outputs are traced.
+
+ The TracedModule returned will have an API almost identical to that of the
+ passed CompiledModule. The only change is that if the keywords `rtol` or
+ `atol` are passed to one of the CompiledModule's methods, then they will be
+ used to set the tolerance for comparing that call to the same call in
+ another trace. So for example, calling `traced_module.add(a, b, rtol=1e-8)`
+ would be the same as calling `module.add(a, b)`.
+
+ Args:
+ module: the CompiledModule to trace.
+ trace: the Trace to record calls to this module with.
+ """
+ self._module = module
+ self._trace = trace
+
+ def _trace_call(self, method: module_utils._FunctionWrapper,
+ method_name: str):
+ """Decorates a CompiledModule method to capture its inputs and outputs."""
+
+ def call(*args, **kwargs):
+ # Pop manually specified tolerances from the kwargs (if any).
+ tolerances = {}
+ tolerances["rtol"] = kwargs.pop("rtol", None)
+ tolerances["atol"] = kwargs.pop("atol", None)
+ # Only pass these to ModuleCall if they were specified by the user.
+ tolerances = {k: v for k, v in tolerances.items() if v is not None}
+
+ # Ensure the inputs are numpy inputs.
+ args = tf_utils.convert_to_numpy(args)
+ kwargs = tf_utils.convert_to_numpy(kwargs)
+
+ # Run the method and record the details of the call.
+ outputs = method(*args, **kwargs)
+ serialized_inputs, serialized_outputs = method.get_serialized_values()
+ self._trace.calls.append(
+ ModuleCall(method_name, args, outputs, serialized_inputs,
+ serialized_outputs, **tolerances))
+ return outputs
+
+ return call
+
+ def __getattr__(self, attr):
+ # Try to resolve it as an attr on self._module.
+ if not hasattr(self._module, attr):
+ raise AttributeError(f"The compiled module does not have attr '{attr}'")
+ module_attr = getattr(self._module, attr)
+ if not hasattr(module_attr, "__call__"):
+ # e.g. traced_module.backend
+ return module_attr
+ else:
+ # e.g. traced_module.simple_mul(a, b)
+ return self._trace_call(module_attr, method_name=attr)
+
+
+def compare_traces(ref_trace: Trace,
+ tar_trace: Trace) -> Tuple[bool, Sequence[str]]:
+ traces_match = True
+ error_messages = []
+
+ # Check that all method invocations match.
+ ref_methods = [(call.method, call.rtol, call.atol) for call in ref_trace]
+ tar_methods = [(call.method, call.rtol, call.atol) for call in tar_trace]
+ if ref_methods != tar_methods:
+ # Raise a ValueError instead of returning False since this is an
+ # unexpected error.
+ raise ValueError(
+ "The reference and target traces have different call structures:\n"
+ f"Reference: {ref_methods}\nTarget: {tar_methods}")
+
+ for ref_call, tar_call in zip(ref_trace, tar_trace):
+ logging.info("Comparing calls to '%s'", ref_call.method)
+ rtol, atol = ref_call.get_tolerances()
+
+ inputs_match, error_message = tf_utils.check_same(ref_call.inputs,
+ tar_call.inputs, rtol,
+ atol)
+ if not inputs_match:
+ error_messages.append(error_message)
+ logging.error("Inputs did not match.")
+ outputs_match, error_message = tf_utils.check_same(ref_call.outputs,
+ tar_call.outputs, rtol,
+ atol)
+ if not outputs_match:
+ error_messages.append(error_message)
+ logging.error("Outputs did not match.")
+ calls_match = inputs_match and outputs_match
+
+ if not calls_match:
+ logging.error("Comparision between '%s' and '%s' failed on method '%s'",
+ ref_trace.backend_id, tar_trace.backend_id, ref_call.method)
+ logging.error("Reference call '%s':\n%s", ref_trace.backend_id, ref_call)
+ logging.error("Target call '%s':\n%s", tar_trace.backend_id, tar_call)
+
+ traces_match = traces_match and calls_match
+ return traces_match, error_messages
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/trace_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/trace_utils_test.py
new file mode 100644
index 0000000..58315c8
--- /dev/null
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/trace_utils_test.py
@@ -0,0 +1,156 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for pyiree.tf.support.trace_utils."""
+
+import os
+import tempfile
+
+from absl.testing import parameterized
+import numpy as np
+from pyiree.tf.support import module_utils
+from pyiree.tf.support import trace_utils
+import tensorflow as tf
+
+
+class StatefulCountingModule(tf.Module):
+
+ def __init__(self):
+ self.count = tf.Variable([0.])
+
+ @tf.function(input_signature=[])
+ def increment(self):
+ self.count.assign_add(tf.constant([1.]))
+
+ @tf.function(input_signature=[])
+ def get_count(self):
+ return self.count
+
+ @tf.function(input_signature=[tf.TensorSpec([1])])
+ def increment_by(self, value):
+ self.count.assign_add(value)
+
+ @tf.function(input_signature=[tf.TensorSpec([1]), tf.TensorSpec([1])])
+ def increment_by_max(self, a, b):
+ result = tf.maximum(a, b)
+ self.count.assign_add(result)
+ return result
+
+ @tf.function(input_signature=[])
+ def decrement(self):
+ self.count.assign_sub(tf.constant([1.]))
+
+
+class TestUtilsTests(tf.test.TestCase, parameterized.TestCase):
+
+ def test_trace_inputs_and_outputs(self):
+
+ def trace_function(module):
+ # No inputs or outputs
+ module.increment()
+ # Only inputs
+ module.increment_by(np.array([81.], dtype=np.float32))
+ # Only outputs
+ module.get_count()
+
+ module = module_utils.TfCompiledModule.create_from_class(
+ StatefulCountingModule, module_utils.BackendInfo('tf'))
+ trace = trace_utils.Trace(module, trace_function)
+ trace_function(trace_utils.TracedModule(module, trace))
+
+ self.assertIsInstance(trace.calls[0].inputs, tuple)
+ self.assertEmpty(trace.calls[0].inputs)
+ self.assertIsInstance(trace.calls[0].outputs, tuple)
+ self.assertEmpty(trace.calls[0].outputs)
+
+ self.assertAllClose(trace.calls[1].inputs[0], [81.])
+ self.assertAllClose(trace.calls[2].outputs[0], [82.])
+
+ def test_nonmatching_methods(self):
+
+ def tf_function(module):
+ module.increment()
+ module.increment()
+
+ def vmla_function(module):
+ module.increment()
+ module.decrement()
+
+ tf_module = module_utils.TfCompiledModule.create_from_class(
+ StatefulCountingModule, module_utils.BackendInfo('tf'))
+ tf_trace = trace_utils.Trace(tf_module, tf_function)
+ tf_function(trace_utils.TracedModule(tf_module, tf_trace))
+
+ vmla_module = module_utils.IreeCompiledModule.create_from_class(
+ StatefulCountingModule, module_utils.BackendInfo('iree_vmla'))
+ vmla_trace = trace_utils.Trace(vmla_module, vmla_function)
+ vmla_function(trace_utils.TracedModule(vmla_module, vmla_trace))
+
+ with self.assertRaises(ValueError):
+ trace_utils.compare_traces(tf_trace, vmla_trace)
+
+ def test_nonmatching_inputs(self):
+
+ def tf_function(module):
+ module.increment_by(np.array([42.], dtype=np.float32))
+
+ def vmla_function(module):
+ module.increment_by(np.array([22.], dtype=np.float32))
+
+ tf_module = module_utils.TfCompiledModule.create_from_class(
+ StatefulCountingModule, module_utils.BackendInfo('tf'))
+ tf_trace = trace_utils.Trace(tf_module, tf_function)
+ tf_function(trace_utils.TracedModule(tf_module, tf_trace))
+
+ vmla_module = module_utils.IreeCompiledModule.create_from_class(
+ StatefulCountingModule, module_utils.BackendInfo('iree_vmla'))
+ vmla_trace = trace_utils.Trace(vmla_module, vmla_function)
+ vmla_function(trace_utils.TracedModule(vmla_module, vmla_trace))
+
+ same, error_messages = trace_utils.compare_traces(tf_trace, vmla_trace)
+ self.assertFalse(same)
+
+ def test_trace_serialize_and_load(self):
+
+ def trace_function(module):
+ module.increment()
+ module.increment_by(np.array([81.], dtype=np.float32))
+ module.increment_by_max(np.array([81], dtype=np.float32),
+ np.array([92], dtype=np.float32))
+ module.get_count()
+
+ module = module_utils.IreeCompiledModule.create_from_class(
+ StatefulCountingModule, module_utils.BackendInfo('iree_vmla'))
+ trace = trace_utils.Trace(module, trace_function)
+ trace_function(trace_utils.TracedModule(module, trace))
+
+ with tempfile.TemporaryDirectory() as artifacts_dir:
+ trace_function_dir = trace_utils.get_trace_dir(artifacts_dir, trace)
+ trace.serialize(trace_function_dir)
+ self.assertTrue(
+ os.path.exists(os.path.join(trace_function_dir, 'metadata.pkl')))
+ loaded_trace = trace_utils.Trace.load(trace_function_dir)
+
+ # Check all calls match.
+ self.assertTrue(trace_utils.compare_traces(trace, loaded_trace))
+
+ # Check all other metadata match.
+ self.assertAllEqual(trace.__dict__.keys(), loaded_trace.__dict__.keys())
+ for key in trace.__dict__.keys():
+ if key != 'calls':
+ self.assertEqual(trace.__dict__[key], loaded_trace.__dict__[key])
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/integrations/tensorflow/compiler/CheckNoTF.cpp b/integrations/tensorflow/compiler/CheckNoTF.cpp
index 246f28a..b69a1ac 100644
--- a/integrations/tensorflow/compiler/CheckNoTF.cpp
+++ b/integrations/tensorflow/compiler/CheckNoTF.cpp
@@ -13,13 +13,13 @@
// limitations under the License.
#include "llvm/Support/FormatVariadic.h"
+#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
diff --git a/integrations/tensorflow/compiler/LegalizeTF.cpp b/integrations/tensorflow/compiler/LegalizeTF.cpp
index e453436..edff0d2 100644
--- a/integrations/tensorflow/compiler/LegalizeTF.cpp
+++ b/integrations/tensorflow/compiler/LegalizeTF.cpp
@@ -12,14 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
diff --git a/integrations/tensorflow/compiler/TFSavedModelLowerGlobalTensors.cpp b/integrations/tensorflow/compiler/TFSavedModelLowerGlobalTensors.cpp
index 40aad21..42fee39 100644
--- a/integrations/tensorflow/compiler/TFSavedModelLowerGlobalTensors.cpp
+++ b/integrations/tensorflow/compiler/TFSavedModelLowerGlobalTensors.cpp
@@ -82,8 +82,7 @@
auto variableOp = globalBuilder.create<IREE::Flow::VariableOp>(
globalTensor.getLoc(), flowSymName, globalTensor.is_mutable(),
globalTensor.type(), globalTensor.value());
- SymbolTable::setSymbolVisibility(variableOp,
- SymbolTable::Visibility::Private);
+ variableOp.setPrivate();
}
// TODO(silvasean): Make this conversion interprocedural.
diff --git a/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/test/convert_tf_to_tf_tensorlist.mlir b/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/test/convert_tf_to_tf_tensorlist.mlir
index 38f1aeb..1f08ff4 100644
--- a/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/test/convert_tf_to_tf_tensorlist.mlir
+++ b/integrations/tensorflow/compiler/dialect/tf_tensorlist/conversion/test/convert_tf_to_tf_tensorlist.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-tf-opt %s -pass-pipeline=convert-tf-to-tf_tensorlist -split-input-file -allow-unregistered-dialect -verify-diagnostics | IreeFileCheck %s
+// RUN: iree-tf-opt %s -pass-pipeline='func(convert-tf-to-tf_tensorlist)' -split-input-file -allow-unregistered-dialect -verify-diagnostics | IreeFileCheck %s
// TODO(silvasean): Handle interprocedural conversion.
diff --git a/integrations/tensorflow/e2e/BUILD b/integrations/tensorflow/e2e/BUILD
index 57f73df..7fa4ff7 100644
--- a/integrations/tensorflow/e2e/BUILD
+++ b/integrations/tensorflow/e2e/BUILD
@@ -13,8 +13,8 @@
# limitations under the License.
# Test coverage across backends for e2e tests is defined directly in the BUILD
-# files. A coverage table generated from this file can be viewed here:
-# https://google.github.io/iree/tf-e2e-coverage
+# files. Coverage tables generated from this file can be viewed here:
+# https://google.github.io/iree/tensorflow-coverage
# Updates made to test suite names should also be reflected here:
# https://github.com/google/iree/blob/main/scripts/update_e2e_coverage.py
@@ -59,7 +59,6 @@
# keep sorted
TFLITE_FAILING = [
"broadcasting_test.py",
- "complex_test.py",
"concat_test.py",
"dynamic_mlp_relu_test.py",
"dynamic_mlp_test.py",
@@ -67,10 +66,12 @@
"einsum_static_test.py",
"einsum_vector_test.py",
"fft_test.py",
- "finite_test.py",
"gather_test.py",
+ "image_resize_test.py",
"mandelbrot_test.py",
"matrix_ops_dynamic_test.py",
+ "quantization_dyn_test.py",
+ "reduce_test.py",
"resource_ops_test.py",
"ring_buffer_test.py",
"scatter_update_test.py",
@@ -103,10 +104,11 @@
"fft_test.py", # TODO(natashaknk): Get this working after kernel is in.
"fill_test.py", # TODO(jennik): Get this test working on IREE.
"linspace_test.py", # TODO(https://github.com/google/iree/issues/1521)
- "logical_ops_test.py",
"mandelbrot_test.py", # TODO(silvasean): Get this working on IREE.
"matrix_ops_dynamic_test.py",
+ "quantization_dyn_test.py",
"range_test.py",
+ "reduce_test.py",
"ring_buffer_test.py", # TODO(b/148747011)
"scatter_update_test.py",
"strings_test.py",
@@ -114,7 +116,6 @@
# keep sorted
VULKAN_FAILING = [
- "bool_test.py",
"broadcast_to_test.py",
"broadcasting_test.py",
"conv_transpose_test.py",
@@ -126,10 +127,11 @@
"fft_test.py", # TODO(natashaknk): Get this working after kernel is in.
"fill_test.py", # TODO(jennik): Get this test working on IREE.
"linspace_test.py", # TODO(https://github.com/google/iree/issues/1521)
- "logical_ops_test.py",
"mandelbrot_test.py", # TODO(silvasean): Get this working on IREE.
"matrix_ops_dynamic_test.py",
+ "quantization_dyn_test.py",
"range_test.py",
+ "reduce_test.py",
"ring_buffer_test.py", # TODO(b/148747011)
"scatter_update_test.py",
"strings_test.py",
diff --git a/integrations/tensorflow/e2e/README.md b/integrations/tensorflow/e2e/README.md
index 0ac1605..ccc2f63 100644
--- a/integrations/tensorflow/e2e/README.md
+++ b/integrations/tensorflow/e2e/README.md
@@ -18,7 +18,7 @@
If you do not have your environment setup to use IREE with Vulkan (see
[this doc](https://google.github.io/iree/get-started/generic-vulkan-env-setup)),
then you can run the manual test targets with
-`--target_backends=tf,iree_vmla,iree_llvmjit` (that is, by omitting
+`--target_backends=tf,iree_vmla` (that is, by omitting
`iree_vulkan` from the list of backends to run the tests on).
The test suites can be run excluding Vulkan by specifying
@@ -32,16 +32,16 @@
specified directory. These artifacts include MLIR across various lowerings and
the compiled VM FlatBuffer. A basic example of creating and calling an
`IreeCompiledModule` can be found in
-[`tf_utils_test.py`](https://github.com/google/iree/blob/main/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py)
+[`module_utils_test.py`](https://github.com/google/iree/blob/main/integrations/tensorflow/bindings/python/pyiree/tf/support/module_utils_test.py)
When using Keras models or tf.Modules with functions that IREE can't compile,
`exported_names` should be specified. For example:
```python
-from pyiree.tf.support import tf_utils
-vmla_module = tf_utils.IreeCompiledModule(
+from pyiree.tf.support import module_utils
+vmla_module = module_utils.IreeCompiledModule(
module_class=KerasTFModuleClass,
- backend_info=tf_utils.BackendInfo('iree_vmla'),
+ backend_info=module_utils.BackendInfo('iree_vmla'),
exported_names=['predict'])
vmla_module.predict(...)
```
@@ -52,17 +52,17 @@
preferred.
```shell
-# Run math_test on all backends.
-bazel run //integrations/tensorflow/e2e:math_test_manual
+# Run conv_test on all backends.
+bazel run //integrations/tensorflow/e2e:conv_test_manual
-# Run math_test comparing TensorFlow to itself (e.g. to debug randomization).
-bazel run //integrations/tensorflow/e2e:math_test_manual -- --target_backends=tf
+# Run conv_test comparing TensorFlow to itself (e.g. to debug randomization).
+bazel run //integrations/tensorflow/e2e:conv_test_manual -- --target_backends=tf
-# Run math_test comparing the VMLA backend and TensorFlow.
-bazel run //integrations/tensorflow/e2e:math_test_manual -- --target_backends=iree_vmla
+# Run conv_test comparing the VMLA backend and TensorFlow.
+bazel run //integrations/tensorflow/e2e:conv_test_manual -- --target_backends=iree_vmla
-# Run math_test comparing the VMLA backend to itself multiple times.
-bazel run //integrations/tensorflow/e2e:math_test_manual -- \
+# Run conv_test comparing the VMLA backend to itself multiple times.
+bazel run //integrations/tensorflow/e2e:conv_test_manual -- \
--reference_backend=iree_vmla --target_backends=iree_vmla,iree_vmla
```
@@ -72,10 +72,10 @@
## Writing Tests
-There are two ways to write tests – via `tf_test_utils.tf_function_unittest` and
+There are two ways to write tests – via `tf_test_utils.tf_function_unit_test` and
via test methods on a child of `tf_test_utils.TracedModuleTestCase`.
-### Via `tf_test_utils.tf_function_unittest`
+### Via `tf_test_utils.tf_function_unit_test`
This is preferred in the cases where
@@ -86,7 +86,7 @@
Tests are specified by writing modules that inherit from
`tf_test_utils.TestModule` (which is a thin wrapper around `tf.Module`) with
-methods decorated with `@tf_test_utils.tf_function_unittest` (with is a thin
+methods decorated with `@tf_test_utils.tf_function_unit_test` (which is a thin
wrapper around `tf.function`).
#### Basic example
@@ -101,14 +101,14 @@
# function. The 'input_signature' is required. If no other arguments are
# specified then uniform random data is generated from the input signature
# to numerically test the function.
- @tf_test_utils.tf_function_unittest(input_signature=[
+ @tf_test_utils.tf_function_unit_test(input_signature=[
tf.TensorSpec([1, 4, 5, 1], tf.float32),
tf.TensorSpec([1, 1, 1, 1], tf.float32),
])
def conv2d_1451x1111_valid(self, img, kernel):
return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "VALID", name="result")
- @tf_test_utils.tf_function_unittest(input_signature=[
+ @tf_test_utils.tf_function_unit_test(input_signature=[
tf.TensorSpec([2, 4, 5, 1], tf.float32),
tf.TensorSpec([1, 1, 1, 1], tf.float32),
])
@@ -130,7 +130,7 @@
```
Finally, in the `main` function, you need to call
-`.generate_unittests(module_class)` on your `TestCase` to actually generate
+`.generate_unit_tests(module_class)` on your `TestCase` to actually generate
the unittests that we specified:
```python
@@ -138,12 +138,12 @@
del argv # Unused
if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
- # Generates unittests for all @tf_test_utils.tf_function_unittest decorated
+ # Generates unittests for all @tf_test_utils.tf_function_unit_test decorated
# functions on the module class.
# Note: if you are automatically generating functions to test they need to be
# specified via a `classmethod` prior to this call _as well_ as via `__init__`
# to properly handle stateful `tf.function`s.
- ConvTest.generate_unittests(Conv2dModule)
+ ConvTest.generate_unit_tests(Conv2dModule)
tf.test.main()
@@ -154,9 +154,9 @@
This generates two unittests: `test_conv2d_1451x1111_valid` and
`test_conv2d_2451x1111_valid`.
-#### Configuring `@tf_test_utils.tf_function_unittest`
+#### Configuring `@tf_test_utils.tf_function_unit_test`
-By default `@tf_test_utils.tf_function_unittest` uses uniform random input data
+By default `@tf_test_utils.tf_function_unit_test` uses uniform random input data
to numerically test the function, but you can specify an `input_generator` or
`input_args` to test data-specific behaviors:
@@ -234,8 +234,8 @@
to check numerical correctness against TensorFlow. Tests targets that pass are
placed into the `e2e_tests` test suite. Tests that fail on particular backends
are recorded in lists in the `BUILD` files. For example, if
-`experimental_new_test.py` fails on the `iree_llvmjit` and `iree_vulkan`
-backends then the following lines should be added to the `BUILD` file:
+`experimental_new_test.py` fails on the `iree_vulkan` backend then the following
+lines should be added to the `BUILD` file:
```build
LLVM_FAILING = [
diff --git a/integrations/tensorflow/e2e/bool_test.py b/integrations/tensorflow/e2e/bool_test.py
deleted file mode 100644
index df09ecb..0000000
--- a/integrations/tensorflow/e2e/bool_test.py
+++ /dev/null
@@ -1,65 +0,0 @@
-# Lint as: python3
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Tests for ops in the tf.math module."""
-
-from absl import app
-import numpy as np
-from pyiree.tf.support import tf_test_utils
-import tensorflow.compat.v2 as tf
-
-
-class BooleanModule(tf_test_utils.TestModule):
-
- @tf_test_utils.tf_function_unittest(input_signature=[])
- def constant(self):
- return np.array([True, False, True], dtype=np.bool)
-
- @tf_test_utils.tf_function_unittest(
- input_signature=[tf.TensorSpec([4], tf.float32)],
- input_args=[np.array([0.0, 1.2, 1.5, 3.75], dtype=np.float32)])
- def greater_than(self, x):
- return x > 1.0
-
- @tf_test_utils.tf_function_unittest(
- input_signature=[
- tf.TensorSpec([4], tf.bool),
- tf.TensorSpec([4], tf.bool)
- ],
- input_args=[
- np.array([True, True, False, False], dtype=np.bool),
- np.array([True, False, False, True], dtype=np.bool)
- ],
- )
- def logical_and(self, x, y):
- return tf.math.logical_and(x, y)
-
-
-class BooleanTest(tf_test_utils.TracedModuleTestCase):
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self._modules = tf_test_utils.compile_tf_module(BooleanModule)
-
-
-def main(argv):
- del argv # Unused
- if hasattr(tf, 'enable_v2_behavior'):
- tf.enable_v2_behavior()
- BooleanTest.generate_unittests(BooleanModule)
- tf.test.main()
-
-
-if __name__ == '__main__':
- app.run(main)
diff --git a/integrations/tensorflow/e2e/complex_test.py b/integrations/tensorflow/e2e/complex_test.py
deleted file mode 100644
index e60f532..0000000
--- a/integrations/tensorflow/e2e/complex_test.py
+++ /dev/null
@@ -1,60 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from absl import app
-import numpy as np
-from pyiree.tf.support import tf_test_utils
-import tensorflow.compat.v2 as tf
-
-
-class ComplexModule(tf.Module):
-
- def __init__(self):
- pass
-
- @tf.function(input_signature=[
- tf.TensorSpec([2], tf.float32),
- tf.TensorSpec([2], tf.float32)
- ])
- def complex_exp(self, real, imag):
- tensor = tf.complex(real, imag)
- exp = tf.exp(tensor)
- return tf.math.real(exp)
-
-
-class ComplexTest(tf_test_utils.TracedModuleTestCase):
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self._modules = tf_test_utils.compile_tf_module(ComplexModule)
-
- def test_complex(self):
-
- def complex_exp(module):
- real = np.array([2., 3.], dtype=np.float32)
- imag = np.array([-1., 0.4], dtype=np.float32)
- module.complex_exp(real, imag)
-
- self.compare_backends(complex_exp, self._modules)
-
-
-def main(argv):
- del argv # Unused
- if hasattr(tf, 'enable_v2_behavior'):
- tf.enable_v2_behavior()
- tf.test.main()
-
-
-if __name__ == '__main__':
- app.run(main)
diff --git a/integrations/tensorflow/e2e/conv_test.py b/integrations/tensorflow/e2e/conv_test.py
index 11ca8d6..ca0076e 100644
--- a/integrations/tensorflow/e2e/conv_test.py
+++ b/integrations/tensorflow/e2e/conv_test.py
@@ -22,14 +22,14 @@
class Conv2dModule(tf_test_utils.TestModule):
- @tf_test_utils.tf_function_unittest(input_signature=[
+ @tf_test_utils.tf_function_unit_test(input_signature=[
tf.TensorSpec([1, 4, 5, 1], tf.float32),
tf.TensorSpec([1, 1, 1, 1], tf.float32),
])
def conv2d_1451x1111_valid(self, img, kernel):
return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "VALID", name="result")
- @tf_test_utils.tf_function_unittest(input_signature=[
+ @tf_test_utils.tf_function_unit_test(input_signature=[
tf.TensorSpec([1, 4, 5, 1], tf.float32),
tf.TensorSpec([2, 2, 1, 1], tf.float32),
])
@@ -40,7 +40,7 @@
dilations=[1, 2, 1, 1],
name="result")
- @tf_test_utils.tf_function_unittest(input_signature=[
+ @tf_test_utils.tf_function_unit_test(input_signature=[
tf.TensorSpec([1, 4, 5, 2], tf.float32),
tf.TensorSpec([2, 2, 2, 3], tf.float32),
])
@@ -51,70 +51,70 @@
dilations=[1, 2, 1, 1],
name="result")
- @tf_test_utils.tf_function_unittest(input_signature=[
+ @tf_test_utils.tf_function_unit_test(input_signature=[
tf.TensorSpec([2, 4, 5, 1], tf.float32),
tf.TensorSpec([1, 1, 1, 1], tf.float32),
])
def conv2d_2451x1111_valid(self, img, kernel):
return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "VALID", name="result")
- @tf_test_utils.tf_function_unittest(input_signature=[
+ @tf_test_utils.tf_function_unit_test(input_signature=[
tf.TensorSpec([1, 4, 5, 1], tf.float32),
tf.TensorSpec([2, 3, 1, 1], tf.float32),
])
def conv2d_1451x2311_valid(self, img, kernel):
return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "VALID", name="result")
- @tf_test_utils.tf_function_unittest(input_signature=[
+ @tf_test_utils.tf_function_unit_test(input_signature=[
tf.TensorSpec([1, 4, 5, 1], tf.float32),
tf.TensorSpec([2, 3, 1, 1], tf.float32),
])
def conv2d_1451x2311_same(self, img, kernel):
return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "SAME", name="result")
- @tf_test_utils.tf_function_unittest(input_signature=[
+ @tf_test_utils.tf_function_unit_test(input_signature=[
tf.TensorSpec([2, 4, 5, 1], tf.float32),
tf.TensorSpec([2, 3, 1, 1], tf.float32),
])
def conv2d_2451x2311_same(self, img, kernel):
return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "SAME", name="result")
- @tf_test_utils.tf_function_unittest(input_signature=[
+ @tf_test_utils.tf_function_unit_test(input_signature=[
tf.TensorSpec([1, 4, 5, 2], tf.float32),
tf.TensorSpec([3, 2, 2, 1], tf.float32),
])
def conv2d_1452x3221_same(self, img, kernel):
return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "SAME", name="result")
- @tf_test_utils.tf_function_unittest(input_signature=[
+ @tf_test_utils.tf_function_unit_test(input_signature=[
tf.TensorSpec([1, 4, 5, 1], tf.float32),
tf.TensorSpec([1, 1, 1, 2], tf.float32),
])
def conv2d_1451x1112_same(self, img, kernel):
return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "SAME", name="result")
- @tf_test_utils.tf_function_unittest(input_signature=[
+ @tf_test_utils.tf_function_unit_test(input_signature=[
tf.TensorSpec([1, 4, 5, 2], tf.float32),
tf.TensorSpec([1, 1, 2, 2], tf.float32),
])
def conv2d_1452x1122_same(self, img, kernel):
return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "SAME", name="result")
- @tf_test_utils.tf_function_unittest(input_signature=[
+ @tf_test_utils.tf_function_unit_test(input_signature=[
tf.TensorSpec([1, 4, 5, 2], tf.float32),
tf.TensorSpec([2, 2, 2, 3], tf.float32),
])
def conv2d_1452x2223_same(self, img, kernel):
return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "SAME", name="result")
- @tf_test_utils.tf_function_unittest(input_signature=[
+ @tf_test_utils.tf_function_unit_test(input_signature=[
tf.TensorSpec([1, 4, 5, 2], tf.float32),
tf.TensorSpec([2, 2, 2, 3], tf.float32),
])
def conv2d_1452x2223_valid(self, img, kernel):
return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "VALID", name="result")
- @tf_test_utils.tf_function_unittest(input_signature=[
+ @tf_test_utils.tf_function_unit_test(input_signature=[
tf.TensorSpec([2, 4, 5, 2], tf.float32),
tf.TensorSpec([2, 2, 2, 3], tf.float32),
])
@@ -133,7 +133,7 @@
del argv # Unused
if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
- ConvTest.generate_unittests(Conv2dModule)
+ ConvTest.generate_unit_tests(Conv2dModule)
tf.test.main()
diff --git a/integrations/tensorflow/e2e/finite_test.py b/integrations/tensorflow/e2e/finite_test.py
deleted file mode 100644
index 761cae4..0000000
--- a/integrations/tensorflow/e2e/finite_test.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from absl import app
-import numpy as np
-from pyiree.tf.support import tf_test_utils
-import tensorflow.compat.v2 as tf
-
-
-class FiniteModule(tf.Module):
-
- @tf.function(input_signature=[tf.TensorSpec([4], tf.float32)])
- def finite(self, x):
- return tf.math.is_finite(x)
-
-
-class FiniteTest(tf_test_utils.TracedModuleTestCase):
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self._modules = tf_test_utils.compile_tf_module(FiniteModule)
-
- def test_finite(self):
-
- def finite(module):
- module.finite(np.array([0.0, 1.2, -5.0, np.inf], dtype=np.float32))
-
- self.compare_backends(finite, self._modules)
-
-
-def main(argv):
- del argv # Unused
- if hasattr(tf, 'enable_v2_behavior'):
- tf.enable_v2_behavior()
- tf.test.main()
-
-
-if __name__ == '__main__':
- app.run(main)
diff --git a/integrations/tensorflow/e2e/image_resize_test.py b/integrations/tensorflow/e2e/image_resize_test.py
new file mode 100644
index 0000000..51d6c7d
--- /dev/null
+++ b/integrations/tensorflow/e2e/image_resize_test.py
@@ -0,0 +1,69 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from absl import app
+import numpy as np
+from pyiree.tf.support import tf_utils
+from pyiree.tf.support import tf_test_utils
+import tensorflow.compat.v1 as tf
+
+
+class ResizeImageModule(tf.Module):
+
+ def __init__(self):
+ pass
+
+ @tf.function(input_signature=[tf.TensorSpec([1, 52, 37, 1], tf.int32)])
+ def downsample_nearest_neighbor(self, image):
+ size = np.asarray([8, 7], dtype=np.int32)
+ return tf.image.resize_nearest_neighbor(image, size)
+
+ @tf.function(input_signature=[tf.TensorSpec([1, 8, 7, 1], tf.int32)])
+ def upsample_nearest_neighbor(self, image):
+ size = np.asarray([52, 37], dtype=np.int32)
+ return tf.image.resize_nearest_neighbor(image, size)
+
+
+class ResizeImageTest(tf_test_utils.TracedModuleTestCase):
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._modules = tf_test_utils.compile_tf_module(ResizeImageModule)
+
+ def test_downsample_nearest_neighbor(self):
+
+ def downsample_nearest_neighbor(module):
+ img = tf_utils.ndarange([1, 52, 37, 1], dtype=np.int32)
+ module.downsample_nearest_neighbor(img)
+
+ self.compare_backends(downsample_nearest_neighbor, self._modules)
+
+ def test_upsample_nearest_neighbor(self):
+
+ def upsample_nearest_neighbor(module):
+ img = tf_utils.ndarange([1, 8, 7, 1], dtype=np.int32)
+ module.upsample_nearest_neighbor(img)
+
+ self.compare_backends(upsample_nearest_neighbor, self._modules)
+
+
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
+ tf.enable_v2_behavior()
+ tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/iree_e2e_cartesian_product_test_suite.bzl b/integrations/tensorflow/e2e/iree_e2e_cartesian_product_test_suite.bzl
index ee6d4f1..7f49f83 100644
--- a/integrations/tensorflow/e2e/iree_e2e_cartesian_product_test_suite.bzl
+++ b/integrations/tensorflow/e2e/iree_e2e_cartesian_product_test_suite.bzl
@@ -164,6 +164,13 @@
tests = []
for flags in all_flag_configurations:
+ if len(flags["target_backends"].split(",")) > 1:
+ fail("Multiple target backends cannot be specified at once, but " +
+ "got `{}`".format(flags["target_backends"]))
+ driver = get_driver(flags["target_backends"])
+ if not driver:
+ continue
+
# Check if this is a failing configuration.
failing = flags in failing_flag_configurations
@@ -180,12 +187,6 @@
tests.append(test_name)
args = ["--{}={}".format(k, v) for k, v in flags.items()]
-
- if len(flags["target_backends"].split(",")) > 1:
- fail("Multiple target backends cannot be specified at once, but " +
- "got `{}`".format(flags["target_backends"]))
-
- driver = get_driver(flags["target_backends"])
py_test_tags = ["driver={}".format(driver)]
if tags != None: # `is` is not supported.
py_test_tags += tags
diff --git a/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl b/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl
index 8fb35a4..661cc30 100644
--- a/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl
+++ b/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl
@@ -19,10 +19,19 @@
def get_driver(backend):
# TODO(#2175): Simplify this after backend names are standardized.
driver = backend.replace("iree_", "") # "iree_<driver>" --> "<driver>"
+
+ # TODO(#2673): enable LLVM AOT for these tests. JIT is deprecated.
if driver == "llvmjit":
- driver = "llvm"
+ driver = ""
return driver
+def set_difference(include, exclude):
+ return [
+ value
+ for value in include
+ if value not in exclude
+ ]
+
def iree_e2e_test_suite(
name,
backends_to_srcs,
@@ -70,6 +79,8 @@
]
driver = get_driver(backend)
+ if not driver:
+ continue
py_test_tags = ["driver={}".format(driver)]
if tags != None: # `is` is not supported.
py_test_tags += tags
diff --git a/integrations/tensorflow/e2e/keras/BUILD b/integrations/tensorflow/e2e/keras/BUILD
index d6e61ad..6ecd6b5 100644
--- a/integrations/tensorflow/e2e/keras/BUILD
+++ b/integrations/tensorflow/e2e/keras/BUILD
@@ -13,8 +13,8 @@
# limitations under the License.
# Test coverage across backends for e2e tests is defined directly in the BUILD
-# files. A coverage table generated from this file can be viewed here:
-# https://google.github.io/iree/tf-e2e-coverage
+# files. Coverage tables generated from this file can be viewed here:
+# https://google.github.io/iree/tensorflow-coverage
# Updates made to test suite names should also be reflected here:
# https://github.com/google/iree/blob/main/scripts/update_e2e_coverage.py
@@ -44,13 +44,13 @@
vision_model_test_manual is for manual testing of all keras vision models.
Test will run only manually with all parameters specified manually, for example:
bazel run -c opt integrations/tensorflow/e2e/keras:vision_model_test_manual -- \
---target_backends=tf,iree_vmla,iree_llvmjit \
+--target_backends=tf,iree_vmla \
--data=imagenet \
--url=https://storage.googleapis.com/iree_models/ \
--model=ResNet50
Command arguments description:
---target_backends: can be combination of these: tf,iree_vmla,iree_llvmjit
+--target_backends: can be a combination of these: tf,iree_vmla
--data: can be 'imagenet' or 'cifar10'.
imagenet - input image size (1, 224, 224, 3)
cifar10 - input image size (1, 32, 32, 3) - it is used for quick tests
diff --git a/integrations/tensorflow/e2e/keras/layers/BUILD b/integrations/tensorflow/e2e/keras/layers/BUILD
index d13e515..759741e 100644
--- a/integrations/tensorflow/e2e/keras/layers/BUILD
+++ b/integrations/tensorflow/e2e/keras/layers/BUILD
@@ -13,8 +13,8 @@
# limitations under the License.
# Test coverage across backends for e2e tests is defined directly in the BUILD
-# files. A coverage table generated from this file can be viewed here:
-# https://google.github.io/iree/tf-e2e-coverage
+# files. Coverage tables generated from this file can be viewed here:
+# https://google.github.io/iree/tensorflow-coverage/tf-keras-coverage
# Updates made to test suite names should also be reflected here:
# https://github.com/google/iree/blob/main/scripts/update_e2e_coverage.py
@@ -87,7 +87,6 @@
"Embedding",
"Flatten",
"GRU",
- "GRUCell",
"GaussianDropout",
"GaussianNoise",
"GlobalAveragePooling1D",
@@ -98,7 +97,6 @@
"GlobalMaxPool3D",
"InputLayer",
"LSTM",
- "LSTMCell",
"Lambda",
"LayerNormalization",
"LeakyReLU",
@@ -119,8 +117,7 @@
"Reshape",
"SeparableConv1D",
"SeparableConv2D",
- "SimpleRNN",
- "SimpleRNNCell",
+ # "SimpleRNN", # TODO(meadowlark): Debug flakiness.
"Softmax",
"SpatialDropout1D",
"SpatialDropout2D",
@@ -137,15 +134,6 @@
FAILING_STATIC = [
{
- # Wrapping these in a tf.function appears to cause a keras bug.
- "layer": [
- "ConvLSTM2D",
- "GRUCell",
- "LSTMCell",
- "SimpleRNNCell",
- ],
- },
- {
# Failing on TFLite
"layer": [
"AveragePooling3D",
@@ -153,6 +141,7 @@
"Conv3D",
"ConvLSTM2D",
"LayerNormalization",
+ "Softmax",
"MaxPool3D",
"ZeroPadding3D",
],
@@ -162,12 +151,12 @@
# Failing on IREE
"layer": [
"ConvLSTM2D",
+ "GRU",
+ "LSTM", # Failing unless 'return_sequences = True'
"LayerNormalization",
"LeakyReLU",
"LocallyConnected2D",
- "Masking",
"MultiHeadAttention",
- "SimpleRNN",
"UpSampling2D",
],
"target_backends": [
@@ -188,6 +177,7 @@
# Failing on LLVM and Vulkan
"layer": [
"Lambda",
+ "Masking",
"MaxPool1D",
"MaxPool2D",
"MaxPool3D",
@@ -200,11 +190,11 @@
{
# Failing on Vulkan
"layer": [
+ "Attention",
+ "AdditiveAttention",
"AveragePooling1D",
"AveragePooling2D",
"AveragePooling3D",
- "GRU",
- "LSTM", # TODO(silvasean): Get this test working on Vulkan.
"ThresholdedReLU",
],
"target_backends": "iree_vulkan",
@@ -218,9 +208,9 @@
flags_to_values = {
"reference_backend": "tf",
"layer": LAYERS,
- "dynamic_batch": False,
+ "dynamic_dims": False,
"training": False,
- "test_full_api": False,
+ "test_default_kwargs_only": True,
"target_backends": [
"tf",
"tflite",
@@ -239,12 +229,14 @@
# bazel run integrations/tensorflow/e2e/keras/layers:layers_test_manual -- \
# --list_layers_with_full_api_tests
LAYERS_WITH_FULL_API_TESTS = [
+ "ActivityRegularization",
"AdditiveAttention",
"Attention",
"AveragePooling1D",
"AveragePooling2D",
"AveragePooling3D",
"BatchNormalization",
+ "Concatenate",
"Conv1D",
"Conv1DTranspose",
"Conv2D",
@@ -252,23 +244,21 @@
"Conv3D",
"Conv3DTranspose",
# "ConvLSTM2D", # TODO(meadowlark): Debug flakiness.
+ "Cropping1D",
+ "Cropping2D",
+ "Cropping3D",
"DepthwiseConv2D",
- "GlobalAveragePooling1D",
- "GlobalAveragePooling2D",
- "GlobalAveragePooling3D",
- "GlobalMaxPool1D",
- "GlobalMaxPool2D",
- "GlobalMaxPool3D",
"GRU",
+ "LSTM",
"LocallyConnected1D",
"LocallyConnected2D",
- "LSTM",
"MaxPool1D",
"MaxPool2D",
"MaxPool3D",
"SeparableConv1D",
"SeparableConv2D",
"SimpleRNN",
+ # "SimpleRNN", # TODO(meadowlark): Debug flakiness.
]
FAILING_FULL_API = [
@@ -276,8 +266,6 @@
# Failing on TFLite
"layer": [
"AveragePooling3D",
- "Conv1D",
- "Conv2D",
"Conv2DTranspose",
"Conv3D",
"Conv3DTranspose",
@@ -289,6 +277,8 @@
"LSTM",
"MaxPool1D",
"MaxPool3D",
+ "SeparableConv1D", # Failing on Kokoro.
+ "SeparableConv2D",
"SimpleRNN",
],
"target_backends": "tflite",
@@ -296,19 +286,13 @@
{
# Failing on IREE
"layer": [
- "Conv1D",
- "Conv2D",
"Conv2DTranspose",
"Conv3DTranspose",
- "Conv3D",
"ConvLSTM2D",
- "DepthwiseConv2D",
"GRU",
"LocallyConnected1D",
"LocallyConnected2D",
"LSTM",
- "SeparableConv1D",
- "SeparableConv2D",
"SimpleRNN",
],
"target_backends": [
@@ -318,6 +302,10 @@
],
},
{
+ "layer": "Conv3D",
+ "target_backends": "iree_vmla",
+ },
+ {
# Failing on LLVM and Vulakn
"layer": [
"AdditiveAttention",
@@ -344,9 +332,9 @@
flags_to_values = {
"reference_backend": "tf",
"layer": LAYERS_WITH_FULL_API_TESTS,
- "dynamic_batch": False,
+ "dynamic_dims": False,
"training": False,
- "test_full_api": True,
+ "test_default_kwargs_only": False,
"target_backends": [
"tf",
"tflite",
@@ -363,14 +351,6 @@
FAILING_DYNAMIC = [
{
- # Wrapping these in a tf.function appears to cause a keras bug.
- "layer": [
- "GRUCell",
- "LSTMCell",
- "SimpleRNNCell",
- ],
- },
- {
# TFLite does not support dynamic shapes.
"target_backends": "tflite",
},
@@ -381,6 +361,7 @@
"AveragePooling1D",
"AveragePooling2D",
"AveragePooling3D",
+ "BatchNormalization",
"Concatenate",
"Conv1D",
"Conv1DTranspose",
@@ -392,7 +373,9 @@
"Cropping1D",
"Cropping2D",
"Cropping3D",
+ "Dense",
"DepthwiseConv2D",
+ "Dot",
"ELU",
"Flatten",
"GRU",
@@ -432,9 +415,6 @@
"Add",
"Attention",
"Average",
- "BatchNormalization",
- "Dense",
- "Dot",
"GlobalAveragePooling1D",
"GlobalAveragePooling2D",
"GlobalAveragePooling3D",
@@ -460,15 +440,15 @@
]
iree_e2e_cartesian_product_test_suite(
- name = "layers_dynamic_batch_tests",
+ name = "layers_dynamic_dims_tests",
srcs = ["layers_test.py"],
failing_configurations = FAILING_DYNAMIC,
flags_to_values = {
"reference_backend": "tf",
"layer": LAYERS,
- "dynamic_batch": True,
+ "dynamic_dims": True,
"training": False,
- "test_full_api": False,
+ "test_default_kwargs_only": True,
"target_backends": [
"tf",
"tflite",
@@ -492,14 +472,11 @@
# "ConvLSTM2D", # TODO(meadowlark): Debug flakiness.
"Dropout",
"GRU",
- "GRUCell",
"GaussianDropout",
"GaussianNoise",
"LSTM",
- "LSTMCell",
"MultiHeadAttention",
- "SimpleRNN",
- "SimpleRNNCell",
+ # "SimpleRNN", # TODO(meadowlark): Debug flakiness.
"SpatialDropout1D",
"SpatialDropout2D",
"SpatialDropout3D",
@@ -507,14 +484,6 @@
FAILING_TRAINING = [
{
- # Wrapping these in a tf.function appears to cause a keras bug.
- "layer": [
- "GRUCell",
- "LSTMCell",
- "SimpleRNNCell",
- ],
- },
- {
# Failing on TFLite:
"layer": [
"AlphaDropout",
@@ -534,6 +503,7 @@
"AdditiveAttention",
"AlphaDropout",
"Attention",
+ "BatchNormalization",
"ConvLSTM2D",
"Dropout",
"GaussianDropout",
@@ -561,9 +531,9 @@
flags_to_values = {
"reference_backend": "tf",
"layer": LAYERS_WITH_TRAINING_BEHAVIOR,
- "dynamic_batch": False,
+ "dynamic_dims": False,
"training": True,
- "test_full_api": False,
+ "test_default_kwargs_only": True,
"target_backends": [
"tf",
"tflite",
diff --git a/integrations/tensorflow/e2e/keras/layers/layers_test.py b/integrations/tensorflow/e2e/keras/layers/layers_test.py
index 0ab5f32..45a7d95 100644
--- a/integrations/tensorflow/e2e/keras/layers/layers_test.py
+++ b/integrations/tensorflow/e2e/keras/layers/layers_test.py
@@ -16,11 +16,13 @@
import collections
import copy
+import inspect
import os
-from typing import Any, Dict, Sequence, Union
+from typing import Any, Dict, List, Sequence, Tuple, Union
from absl import app
from absl import flags
+from absl import logging
from pyiree.tf.support import tf_test_utils
from pyiree.tf.support import tf_utils
import tensorflow.compat.v2 as tf
@@ -28,516 +30,448 @@
FLAGS = flags.FLAGS
DROPOUT = 0.5
-DIM = 4
-RANK_2_INPUT = [DIM] * 2
-RANK_3_INPUT = [DIM] * 3
-RANK_4_INPUT = [DIM] * 4
+CONV_FILTERS = 2
+CONV_KERNEL_SIZE = 2
+DIM = 3
-CONV_1D_INPUT = [2, 8, 3]
-CONV_2D_INPUT = [2, 8, 8, 3]
-CONV_3D_INPUT = [2, 8, 8, 8, 3]
+# Used for attention layers and recurrent layers.
+RANK_3_SHAPE = [DIM] * 3
+# Highest rank that tf.keras will allow for all layers.
+RANK_5_SHAPE = [DIM] * 5
-# Configs are namedtuples storing keyword arguments and shapes to test a
-# tf.keras.layers.Layer with. They are used in two ways:
-# 1. To directly specify the kwargs and shapes for a layers test.
-# 2. In 'generate_configs', to specify how to change a default config to
-# specify a non-default test. In this case, the overriding Config will
-# exclusively specify the shape of the test if its shape is not None, and
-# the overriding Config will extend/update the kwargs of the default
-# Config.
-Config = collections.namedtuple('Config', ['kwargs', 'shapes'])
-# Use old default API for compatibility with Python 3.6.
-Config.__new__.__defaults__ = (dict(), None)
+UNARY_SIGNATURE_SHAPES = [[RANK_5_SHAPE]]
+BINARY_SIGNATURE_SHAPES = [[RANK_5_SHAPE] * 2]
+TERNARY_SIGNATURE_SHAPES = [[RANK_5_SHAPE] * 3]
+CONV_1D_SIGNATURE_SHAPES = [[[2, 8, 3]]]
+CONV_2D_SIGNATURE_SHAPES = [[[2, 8, 8, 3]]]
+CONV_3D_SIGNATURE_SHAPES = [[[2, 8, 8, 8, 3]]]
-def generate_configs(default_config: Config,
- override_configs: Dict[str, Config]) -> Dict[str, Config]:
- """Generates a dict of 'Config's based off changes to a default Config."""
- configs = {'default': default_config}
- for exported_name, config in override_configs.items():
- shapes = default_config.shapes if config.shapes is None else config.shapes
+RNN_SIGNATURE_SHAPES = [[RANK_3_SHAPE]]
+RNN_KWARGS_TO_VALUES = dict(units=[4],
+ return_sequences=[False, True],
+ stateful=[False, True])
- # Deep copy to avoid inplace mutation of the default.
- kwargs = copy.deepcopy(default_config.kwargs)
- kwargs.update(config.kwargs) # Adds new and overwrites old kwargs.
+POOLING_KWARGS_TO_VALUES = dict(strides=[None, 2],
+ padding=["valid", "same"],
+ data_format=[None, "channels_first"])
+CONV_KWARGS_TO_VALUES = dict(filters=[CONV_FILTERS],
+ kernel_size=[CONV_KERNEL_SIZE],
+ strides=[1, 2],
+ padding=["valid", "same"],
+ data_format=[None, "channels_first"],
+ dilation_rate=[1, 2])
+# Address pooling and conv layers having different default values for
+# 'data_format' for 1D layers.
+POOLING_1D_KWARGS_TO_VALUES = copy.deepcopy(POOLING_KWARGS_TO_VALUES)
+POOLING_1D_KWARGS_TO_VALUES.update(
+ {"data_format": ["channels_last", "channels_first"]})
+CONV_1D_KWARGS_TO_VALUES = copy.deepcopy(CONV_KWARGS_TO_VALUES)
+CONV_1D_KWARGS_TO_VALUES.update(
+ {"data_format": ["channels_last", "channels_first"]})
- configs[exported_name] = Config(kwargs, shapes)
- return configs
-
-
-# A dict mapping tf.keras.layers names to either a single Config (representing
-# the kwargs and shapes to use to test a Layer) or a dict mapping exported_names
-# to Configs. The latter case is usually automatically generated via
-# 'generate_configs', with the 'Config's in 'override_configs' specifying how
-# to modify the 'default_config's kwargs and shapes.
-#
-# Each entry will be normalized to be a dict mapping exported_names to Configs,
-# with a default exported_name of 'default'.
-LAYER_TO_UNITTEST_CONFIGURATIONS = {
- 'Activation':
- Config(dict(activation='relu'), [RANK_2_INPUT]),
- 'ActivityRegularization':
- Config(dict(l1=0.1, l2=0.1), shapes=[RANK_2_INPUT]),
- 'Add':
- Config(shapes=[RANK_2_INPUT, RANK_2_INPUT]),
- 'AdditiveAttention':
- generate_configs(
- default_config=Config(
- shapes=[RANK_3_INPUT, RANK_3_INPUT, RANK_3_INPUT],),
- override_configs={
- 'causal': Config(dict(causal=True)),
- },
- ),
- 'AlphaDropout':
- Config(dict(rate=DROPOUT), [RANK_2_INPUT]),
- 'Attention':
- generate_configs(
- default_config=Config(
- shapes=[RANK_3_INPUT, RANK_3_INPUT, RANK_3_INPUT],),
- override_configs={
- 'causal': Config(dict(causal=True)),
- },
- ),
- 'Average':
- Config(shapes=[RANK_2_INPUT, RANK_2_INPUT]),
- 'AveragePooling1D':
- generate_configs(
- default_config=Config(shapes=[CONV_1D_INPUT]),
- override_configs={
- 'strides': Config(dict(strides=3)),
- 'padding_same': Config(dict(padding='same')),
- 'channels_first': Config(dict(data_format='channels_first')),
- },
- ),
- 'AveragePooling2D':
- generate_configs(
- default_config=Config(shapes=[CONV_2D_INPUT]),
- override_configs={
- 'strides': Config(dict(strides=3)),
- 'padding_same': Config(dict(padding='same')),
- # TF: Default AvgPoolingOp only supports NHWC on device type CPU
- # 'channels_first': Config(dict(data_format='channels_first')),
- },
- ),
- 'AveragePooling3D':
- generate_configs(
- default_config=Config(shapes=[CONV_3D_INPUT]),
- override_configs={
- 'strides': Config(dict(strides=3)),
- 'padding_same': Config(dict(padding='same')),
- 'channels_first': Config(dict(data_format='channels_first')),
- },
- ),
- 'BatchNormalization':
- generate_configs(
- default_config=Config(shapes=[RANK_2_INPUT]),
- override_configs={'renorm': Config(dict(renorm=True))},
- ),
- 'Concatenate':
- Config(shapes=[RANK_4_INPUT, RANK_4_INPUT]),
- 'Conv1D':
- generate_configs(
- default_config=Config(
- kwargs=dict(filters=4, kernel_size=3),
- shapes=[CONV_1D_INPUT],
- ),
- override_configs={
- 'strides': Config(dict(strides=3)),
- 'padding_same': Config(dict(padding='same')),
- # TF: The Conv2D op currently only supports the NHWC tensor
- # format on the CPU.
- # 'channels_first': Config(dict(data_format='channels_first')),
- 'dilation_rate': Config(dict(dilation_rate=3)),
- },
- ),
- 'Conv1DTranspose':
- generate_configs(
- default_config=Config(
- kwargs=dict(filters=4, kernel_size=3),
- shapes=[CONV_1D_INPUT],
- ),
- override_configs={
- 'strides': Config(dict(strides=3)),
- 'padding_same': Config(dict(padding='same')),
- # TF: Conv2DCustomBackpropInputOp only supports NHWC
- # 'channels_first': Config(dict(data_format='channels_first')),
- # TF: Current libxsmm and customized CPU implementations do not
- # yet support dilation rates larger than 1.
- # 'dilation_rate': Config(dict(dilation_rate=3)),
- },
- ),
- 'Conv2D':
- generate_configs(
- default_config=Config(
- kwargs=dict(filters=4, kernel_size=3),
- shapes=[CONV_2D_INPUT],
- ),
- override_configs={
- 'strides': Config(dict(strides=3)),
- 'padding_same': Config(dict(padding='same')),
- # TF: The Conv2D op currently only supports the NHWC tensor
- # format on the CPU.
- # 'channels_first': Config(dict(data_format='channels_first')),
- 'dilation_rate': Config(dict(dilation_rate=3)),
- },
- ),
- 'Conv2DTranspose':
- generate_configs(
- default_config=Config(
- kwargs=dict(filters=4, kernel_size=3),
- shapes=[CONV_2D_INPUT],
- ),
- override_configs={
- 'strides': Config(dict(strides=3)),
- 'padding_same': Config(dict(padding='same')),
- 'channels_first': Config(dict(data_format='channels_first')),
- 'dilation_rate': Config(dict(dilation_rate=3)),
- },
- ),
- 'Conv3D':
- generate_configs(
- default_config=Config(
- kwargs=dict(filters=4, kernel_size=3),
- shapes=[CONV_3D_INPUT],
- ),
- override_configs={
- 'strides': Config(dict(strides=3)),
- 'padding_same': Config(dict(padding='same')),
- # TF: The Conv3D op currently only supports the NHWC tensor
- # format on the CPU.
- # 'channels_first': Config(dict(data_format='channels_first')),
- 'dilation_rate': Config(dict(dilation_rate=3)),
- },
- ),
- 'Conv3DTranspose':
- generate_configs(
- default_config=Config(
- kwargs=dict(filters=4, kernel_size=3),
- shapes=[CONV_3D_INPUT],
- ),
- override_configs={
- 'strides': Config(dict(strides=3)),
- 'padding_same': Config(dict(padding='same')),
- # TF: Conv3DBackpropInputOpV2 only supports NDHWC on the CPU.
- # 'channels_first': Config(dict(data_format='channels_first')),
- 'dilation_rate': Config(dict(dilation_rate=3)),
- },
- ),
- 'ConvLSTM2D':
- generate_configs(
- default_config=Config(
- kwargs=dict(filters=4, kernel_size=3, return_state=True),
- shapes=[CONV_3D_INPUT],
- ),
- override_configs={
- 'strides': Config(dict(strides=3)),
- 'padding': Config(dict(padding='same')),
- 'channels_first': Config(dict(data_format='channels_first')),
- 'dilation_rate': Config(dict(dilation_rate=3)),
- 'go_backwards': Config(dict(go_backwards=True)),
- 'stateful': Config(dict(stateful=True)),
- },
- ),
- 'Cropping1D':
- Config(dict(cropping=2), [CONV_1D_INPUT]),
- 'Cropping2D':
- Config(dict(cropping=2), [CONV_2D_INPUT]),
- 'Cropping3D':
- Config(dict(cropping=2), [CONV_3D_INPUT]),
- 'Dense':
- Config(dict(units=4), [RANK_2_INPUT]),
- 'DepthwiseConv2D':
- generate_configs(
- default_config=Config(
- kwargs=dict(kernel_size=3),
- shapes=[CONV_2D_INPUT],
- ),
- override_configs={
- 'strides': Config(dict(strides=3)),
- 'padding_same': Config(dict(padding='same')),
- 'channels_first': Config(dict(data_format='channels_first')),
- 'depth_multiplier': Config(dict(depth_multiplier=2)),
- 'dilation_rate': Config(dict(dilation_rate=2)),
- },
- ),
- 'Dot':
- Config(dict(axes=(1, 2)), [RANK_3_INPUT, RANK_3_INPUT]),
- 'Dropout':
- Config(dict(rate=DROPOUT), [RANK_3_INPUT]),
- 'ELU':
- Config(shapes=[RANK_2_INPUT]),
- 'Embedding':
- Config(dict(input_dim=4, output_dim=2), [RANK_2_INPUT]),
- 'Flatten':
- Config(shapes=[RANK_2_INPUT]),
- 'GRU':
- generate_configs(
- default_config=Config(
- kwargs=dict(units=4, return_sequences=True),
- shapes=[RANK_3_INPUT],
- ),
- override_configs={
- 'implementation_1': Config(dict(implementation=1)),
- 'go_backwards': Config(dict(go_backwards=True)),
- 'time_major': Config(dict(time_major=True)),
- 'stateful': Config(dict(stateful=True)),
- },
- ),
- 'GRUCell':
- Config(dict(units=4), [RANK_2_INPUT, RANK_2_INPUT]),
- 'GaussianDropout':
- Config(dict(rate=DROPOUT), [RANK_2_INPUT]),
- 'GaussianNoise':
- Config(dict(stddev=1.0), [RANK_2_INPUT]),
- 'GlobalAveragePooling1D':
- generate_configs(
- default_config=Config(shapes=[CONV_1D_INPUT]),
- override_configs={
- 'channels_first': Config(dict(data_format='channels_first')),
- },
- ),
- 'GlobalAveragePooling2D':
- generate_configs(
- default_config=Config(shapes=[CONV_2D_INPUT]),
- override_configs={
- 'channels_first': Config(dict(data_format='channels_first')),
- },
- ),
- 'GlobalAveragePooling3D':
- generate_configs(
- default_config=Config(shapes=[CONV_3D_INPUT]),
- override_configs={
- 'channels_first': Config(dict(data_format='channels_first')),
- },
- ),
- 'GlobalMaxPool1D':
- generate_configs(
- default_config=Config(shapes=[CONV_1D_INPUT]),
- override_configs={
- 'channels_first': Config(dict(data_format='channels_first')),
- },
- ),
- 'GlobalMaxPool2D':
- generate_configs(
- default_config=Config(shapes=[CONV_2D_INPUT]),
- override_configs={
- 'channels_first': Config(dict(data_format='channels_first')),
- },
- ),
- 'GlobalMaxPool3D':
- generate_configs(
- default_config=Config(shapes=[CONV_3D_INPUT]),
- override_configs={
- 'channels_first': Config(dict(data_format='channels_first')),
- },
- ),
- 'InputLayer':
- Config(shapes=[RANK_2_INPUT]),
- 'LSTM':
- generate_configs(
- default_config=Config(
- kwargs=dict(units=4, return_sequences=True),
- shapes=[RANK_3_INPUT],
- ),
- override_configs={
- 'implementation_1': Config(dict(implementation=1)),
- 'go_backwards': Config(dict(go_backwards=True)),
- 'time_major': Config(dict(time_major=True)),
- 'stateful': Config(dict(stateful=True)),
- },
- ),
- 'LSTMCell':
- Config(dict(units=4), [RANK_2_INPUT, RANK_2_INPUT]),
- 'Lambda':
- Config(dict(function=lambda x: x**2), [RANK_2_INPUT]),
- 'LayerNormalization':
- Config(shapes=[RANK_2_INPUT]),
- 'LeakyReLU':
- Config(shapes=[RANK_2_INPUT]),
- 'LocallyConnected1D':
- generate_configs(
- default_config=Config(
- kwargs=dict(filters=4, kernel_size=3),
- shapes=[CONV_1D_INPUT],
- ),
- override_configs={
- 'strides': Config(dict(strides=3)),
- 'padding_same': Config(dict(padding='same', implementation=2)),
- 'channels_first': Config(dict(data_format='channels_first')),
- 'sparse_implementation': Config(dict(implementation=3)),
- },
- ),
- 'LocallyConnected2D':
- generate_configs(
- default_config=Config(
- kwargs=dict(filters=4, kernel_size=3),
- shapes=[CONV_2D_INPUT],
- ),
- override_configs={
- 'strides': Config(dict(strides=3)),
- 'padding_same': Config(dict(padding='same', implementation=2)),
- 'channels_first': Config(dict(data_format='channels_first')),
- 'sparse_implementation': Config(dict(implementation=3)),
- },
- ),
- 'Masking':
- Config(shapes=[RANK_2_INPUT]),
- 'MaxPool1D':
- generate_configs(
- default_config=Config(shapes=[CONV_1D_INPUT]),
- override_configs={
- 'strides': Config(dict(strides=3)),
- 'padding_same': Config(dict(padding='same')),
- 'channels_first': Config(dict(data_format='channels_first')),
- },
- ),
- 'MaxPool2D':
- generate_configs(
- default_config=Config(shapes=[CONV_2D_INPUT]),
- override_configs={
- 'strides': Config(dict(strides=3)),
- 'padding_same': Config(dict(padding='same')),
- # TF: Default MaxPoolingOp only supports NHWC on device type CPU
- # 'channels_first': Config(dict(data_format='channels_first')),
- },
- ),
- 'MaxPool3D':
- generate_configs(
- default_config=Config(shapes=[CONV_3D_INPUT]),
- override_configs={
- 'strides': Config(dict(strides=3)),
- 'padding_same': Config(dict(padding='same')),
- 'channels_first': Config(dict(data_format='channels_first')),
- },
- ),
- 'Maximum':
- Config(shapes=[RANK_2_INPUT, RANK_2_INPUT]),
- 'Minimum':
- Config(shapes=[RANK_2_INPUT, RANK_2_INPUT]),
- 'MultiHeadAttention':
- Config(dict(num_heads=2, key_dim=3), [RANK_3_INPUT, RANK_3_INPUT]),
- 'Multiply':
- Config(shapes=[RANK_2_INPUT, RANK_2_INPUT]),
- 'PReLU':
- Config(shapes=[RANK_2_INPUT]),
- 'Permute':
- Config(dict(dims=(3, 1, 2)), [RANK_4_INPUT]),
- 'ReLU':
- Config(shapes=[RANK_2_INPUT]),
- 'RepeatVector':
- Config(dict(n=3), [RANK_2_INPUT]),
- 'Reshape':
- Config(dict(target_shape=[1, 1, 1] + RANK_3_INPUT[1:]), [RANK_3_INPUT]),
- 'SeparableConv1D':
- generate_configs(
- default_config=Config(
- kwargs=dict(filters=4, kernel_size=3),
- shapes=[CONV_1D_INPUT],
- ),
- override_configs={
- 'strides': Config(dict(strides=3)),
- 'padding_same': Config(dict(padding='same')),
- # TF: Depthwise convolution on CPU is only supported for NHWC
- # format
- # 'channels_first': Config(dict(data_format='channels_first')),
- 'depth_multiplier': Config(dict(depth_multiplier=2)),
- 'dilation_rate': Config(dict(dilation_rate=2)),
- },
- ),
- 'SeparableConv2D':
- generate_configs(
- default_config=Config(
- kwargs=dict(filters=4, kernel_size=3),
- shapes=[CONV_2D_INPUT],
- ),
- override_configs={
- 'strides': Config(dict(strides=3)),
- 'padding_same': Config(dict(padding='same')),
- # TF: Depthwise convolution on CPU is only supported for NHWC
- # format
- # 'channels_first': Config(dict(data_format='channels_first')),
- 'depth_multiplier': Config(dict(depth_multiplier=2)),
- 'dilation_rate': Config(dict(dilation_rate=2)),
- },
- ),
- 'SimpleRNN':
- generate_configs(
- default_config=Config(
- kwargs=dict(units=4, return_sequences=True),
- shapes=[RANK_3_INPUT],
- ),
- override_configs={
- 'go_backwards': Config(dict(go_backwards=True)),
- 'stateful': Config(dict(stateful=True)),
- },
- ),
- 'SimpleRNNCell':
- Config(dict(units=4), [RANK_2_INPUT, RANK_2_INPUT]),
- 'Softmax':
- Config(shapes=[RANK_2_INPUT]),
- 'SpatialDropout1D':
- Config(dict(rate=DROPOUT), [CONV_1D_INPUT]),
- 'SpatialDropout2D':
- Config(dict(rate=DROPOUT), [CONV_2D_INPUT]),
- 'SpatialDropout3D':
- Config(dict(rate=DROPOUT), [CONV_3D_INPUT]),
- 'Subtract':
- Config(shapes=[RANK_2_INPUT, RANK_2_INPUT]),
- 'ThresholdedReLU':
- Config(shapes=[RANK_2_INPUT]),
- 'UpSampling1D':
- Config(shapes=[CONV_1D_INPUT]),
- 'UpSampling2D':
- Config(shapes=[CONV_2D_INPUT]),
- 'UpSampling3D':
- Config(shapes=[CONV_3D_INPUT]),
- 'ZeroPadding1D':
- Config(shapes=[CONV_1D_INPUT]),
- 'ZeroPadding2D':
- Config(shapes=[CONV_2D_INPUT]),
- 'ZeroPadding3D':
- Config(shapes=[CONV_3D_INPUT]),
+# Unsupported by TensorFlow (at least on CPU).
+LAYERS_TO_TF_UNSUPPORTED_NON_DEFAULT_KWARGS = {
+ "AveragePooling2D": ["data_format"],
+ "Conv1D": ["data_format"],
+ "Conv1DTranspose": ["data_format", "dilation_rate"],
+ "Conv2D": ["data_format"],
+ "Conv3D": ["data_format"],
+ "Conv3DTranspose": ["data_format"],
+ "LocallyConnected1D": ["padding"],
+ "LocallyConnected2D": ["padding"],
+ "MaxPool2D": ["data_format"],
}
-# Normalize LAYER_TO_UNITTEST_CONFIGURATIONS
-for key, value in LAYER_TO_UNITTEST_CONFIGURATIONS.items():
- if isinstance(value, Config):
- LAYER_TO_UNITTEST_CONFIGURATIONS[key] = {'default': value}
+# Some layers have kwargs which cannot both have non-default values.
+LAYERS_TO_MUTUALLY_EXCLUSIVE_KWARGS = {
+ "Conv1D": ["strides", "dilation_rate"],
+ "Conv2D": ["strides", "dilation_rate"],
+ "Conv2DTranspose": ["strides", "dilation_rate"],
+ "Conv3D": ["strides", "dilation_rate"],
+ "ConvLSTM2D": ["strides", "dilation_rate"],
+}
+
+
+def get_default_kwargs_values(layer: str) -> Dict[str, Any]:
+ """Gets the default kwargs for a tf.keras.layers layer."""
+ layer_class = getattr(tf.keras.layers, layer)
+ layer_parameters = inspect.signature(layer_class.__init__).parameters
+ kwargs_to_default_values = {
+ kwarg: value.default
+ for kwarg, value in layer_parameters.items()
+ if value.default is not inspect.Parameter.empty
+ }
+ return kwargs_to_default_values
+
+
+def _equal_or_splat_equal(value: Any, sequence: Any) -> bool:
+ """Returns True if value==sequence or value==(every element in sequence)."""
+ if value == sequence:
+ return True
+ elif isinstance(sequence, (list, tuple)):
+ for element in sequence:
+ if not _equal_or_splat_equal(value, element):
+ return False
+ return True
+ return False
+
+
+def get_non_default_kwargs(
+ layer: str, unit_test_spec: tf_test_utils.UnitTestSpec) -> List[str]:
+ """Returns all non-default optional kwargs in unit_test_spec."""
+ kwargs_to_defaults = get_default_kwargs_values(layer)
+ non_default_kwargs = []
+ for kwarg, value in unit_test_spec.kwargs.items():
+ if (kwarg in kwargs_to_defaults and
+ not _equal_or_splat_equal(value, kwargs_to_defaults[kwarg])):
+ non_default_kwargs.append(kwarg)
+ return non_default_kwargs
+
+
+def unsupported_by_tf(layer: str,
+ unit_test_spec: tf_test_utils.UnitTestSpec) -> bool:
+ """True if unit_test_spec specifies tf-unsupported non-default kwargs."""
+ if layer in LAYERS_TO_TF_UNSUPPORTED_NON_DEFAULT_KWARGS:
+ unsupported_kwargs = LAYERS_TO_TF_UNSUPPORTED_NON_DEFAULT_KWARGS[layer]
+ non_default_kwargs = get_non_default_kwargs(layer, unit_test_spec)
+ return any(kwarg in unsupported_kwargs for kwarg in non_default_kwargs)
+ return False
+
+
+def has_mutually_exclusive_kwargs(
+ layer: str, unit_test_spec: tf_test_utils.UnitTestSpec) -> bool:
+ """True if unit_test_spec specifies mutually exclusive non-default kwargs."""
+ if layer in LAYERS_TO_MUTUALLY_EXCLUSIVE_KWARGS:
+ mutually_exclusive_kwargs = LAYERS_TO_MUTUALLY_EXCLUSIVE_KWARGS[layer]
+ non_default_kwargs = get_non_default_kwargs(layer, unit_test_spec)
+ return set(mutually_exclusive_kwargs).issubset(set(non_default_kwargs))
+ return False
+
+
+# A dictionary mapping tf.keras.layers names to lists of UnitTestSpecs.
+# Each unit_test_name will have the tf.keras.layer name prepended to it.
+#
+# Each layer is required to have a UnitTestSpec with all-default values for
+# unrequired kwargs. This allows us to separately test the basic api and the
+# full api.
+LAYERS_TO_UNIT_TEST_SPECS = {
+ "Activation":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(activation=["relu"])),
+ "ActivityRegularization":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(l1=[0.0, 0.1], l2=[0.0, 0.1])),
+ "Add":
+ tf_test_utils.unit_test_specs_from_signatures(BINARY_SIGNATURE_SHAPES),
+ "AdditiveAttention":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=[(RANK_3_SHAPE, RANK_3_SHAPE, RANK_3_SHAPE)],
+ kwargs_to_values=dict(causal=[False, True])),
+ "AlphaDropout":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(rate=[DROPOUT])),
+ "Attention":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=[(RANK_3_SHAPE, RANK_3_SHAPE, RANK_3_SHAPE)],
+ kwargs_to_values=dict(causal=[False, True])),
+ "Average":
+ tf_test_utils.unit_test_specs_from_signatures(BINARY_SIGNATURE_SHAPES),
+ "AveragePooling1D":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_1D_SIGNATURE_SHAPES,
+ kwargs_to_values=POOLING_1D_KWARGS_TO_VALUES),
+ "AveragePooling2D":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_2D_SIGNATURE_SHAPES,
+ kwargs_to_values=POOLING_KWARGS_TO_VALUES),
+ "AveragePooling3D":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_3D_SIGNATURE_SHAPES,
+ kwargs_to_values=POOLING_KWARGS_TO_VALUES),
+ "BatchNormalization":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(renorm=[False, True])),
+ "Concatenate":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(axis=[-1, 0])),
+ "Conv1D":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_1D_SIGNATURE_SHAPES,
+ kwargs_to_values=CONV_1D_KWARGS_TO_VALUES),
+ "Conv1DTranspose":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_1D_SIGNATURE_SHAPES,
+ kwargs_to_values=CONV_KWARGS_TO_VALUES),
+ "Conv2D":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_2D_SIGNATURE_SHAPES,
+ kwargs_to_values=CONV_KWARGS_TO_VALUES),
+ "Conv2DTranspose":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_2D_SIGNATURE_SHAPES,
+ kwargs_to_values=CONV_KWARGS_TO_VALUES),
+ "Conv3D":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_3D_SIGNATURE_SHAPES,
+ kwargs_to_values=CONV_KWARGS_TO_VALUES),
+ "Conv3DTranspose":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_3D_SIGNATURE_SHAPES,
+ kwargs_to_values=CONV_KWARGS_TO_VALUES),
+ "ConvLSTM2D":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_3D_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(filters=[CONV_FILTERS],
+ kernel_size=[CONV_KERNEL_SIZE],
+ return_state=[False, True],
+ strides=[1, 2],
+ dilation_rate=[1, 2],
+ stateful=[False, True])),
+ "Cropping1D":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_1D_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(cropping=[1, (1, 2)])),
+ "Cropping2D":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_2D_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(cropping=[0, ((1, 2), (2, 1))])),
+ "Cropping3D":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_3D_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(cropping=[1, ((1, 2), (2, 1), (1, 0))])),
+ "Dense":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(units=[8])),
+ "DepthwiseConv2D":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_2D_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(kernel_size=[CONV_KERNEL_SIZE],
+ strides=[1, 2],
+ padding=["valid", "same"],
+ dilation_rate=[1, 2],
+ depth_multiplier=[1, 2])),
+ "Dot":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(axes=[(1, 2)])),
+ "Dropout":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(rate=[DROPOUT])),
+ "ELU":
+ tf_test_utils.unit_test_specs_from_signatures(UNARY_SIGNATURE_SHAPES),
+ "Embedding":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(input_dim=[4], output_dim=[2])),
+ "Flatten":
+ tf_test_utils.unit_test_specs_from_signatures(UNARY_SIGNATURE_SHAPES),
+ "GRU":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=RNN_SIGNATURE_SHAPES,
+ kwargs_to_values=RNN_KWARGS_TO_VALUES),
+ "GaussianDropout":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=RNN_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(rate=[DROPOUT])),
+ "GaussianNoise":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=RNN_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(stddev=[1.0])),
+ "GlobalAveragePooling1D":
+ tf_test_utils.unit_test_specs_from_signatures(CONV_1D_SIGNATURE_SHAPES),
+ "GlobalAveragePooling2D":
+ tf_test_utils.unit_test_specs_from_signatures(CONV_2D_SIGNATURE_SHAPES),
+ "GlobalAveragePooling3D":
+ tf_test_utils.unit_test_specs_from_signatures(CONV_3D_SIGNATURE_SHAPES),
+ "GlobalMaxPool1D":
+ tf_test_utils.unit_test_specs_from_signatures(CONV_1D_SIGNATURE_SHAPES),
+ "GlobalMaxPool2D":
+ tf_test_utils.unit_test_specs_from_signatures(CONV_2D_SIGNATURE_SHAPES),
+ "GlobalMaxPool3D":
+ tf_test_utils.unit_test_specs_from_signatures(CONV_3D_SIGNATURE_SHAPES),
+ "InputLayer":
+ tf_test_utils.unit_test_specs_from_signatures(UNARY_SIGNATURE_SHAPES),
+ "LSTM":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=RNN_SIGNATURE_SHAPES,
+ kwargs_to_values=RNN_KWARGS_TO_VALUES),
+ "Lambda":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(function=[lambda x: x**2])),
+ "LayerNormalization":
+ tf_test_utils.unit_test_specs_from_signatures(UNARY_SIGNATURE_SHAPES),
+ "LeakyReLU":
+ tf_test_utils.unit_test_specs_from_signatures(UNARY_SIGNATURE_SHAPES),
+ "LocallyConnected1D":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_1D_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(filters=[CONV_FILTERS],
+ kernel_size=[CONV_KERNEL_SIZE],
+ strides=[1, 3],
+ padding=["valid", "same"],
+ implementation=[1, 3])),
+ "LocallyConnected2D":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_2D_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(filters=[CONV_FILTERS],
+ kernel_size=[CONV_KERNEL_SIZE],
+ strides=[1, 3],
+ padding=["valid", "same"],
+ implementation=[1, 3])),
+ "Masking":
+ tf_test_utils.unit_test_specs_from_signatures(UNARY_SIGNATURE_SHAPES),
+ "MaxPool1D":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_1D_SIGNATURE_SHAPES,
+ kwargs_to_values=POOLING_1D_KWARGS_TO_VALUES),
+ "MaxPool2D":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_2D_SIGNATURE_SHAPES,
+ kwargs_to_values=POOLING_KWARGS_TO_VALUES),
+ "MaxPool3D":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_3D_SIGNATURE_SHAPES,
+ kwargs_to_values=POOLING_KWARGS_TO_VALUES),
+ "Maximum":
+ tf_test_utils.unit_test_specs_from_signatures(BINARY_SIGNATURE_SHAPES),
+ "Minimum":
+ tf_test_utils.unit_test_specs_from_signatures(BINARY_SIGNATURE_SHAPES),
+ "MultiHeadAttention":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=[(RANK_3_SHAPE, RANK_3_SHAPE)],
+ kwargs_to_values=dict(num_heads=[2], key_dim=[3])),
+ "Multiply":
+ tf_test_utils.unit_test_specs_from_signatures(BINARY_SIGNATURE_SHAPES),
+ "PReLU":
+ tf_test_utils.unit_test_specs_from_signatures(UNARY_SIGNATURE_SHAPES),
+ "Permute":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(dims=[(3, 1, 4, 2)])),
+ "ReLU":
+ tf_test_utils.unit_test_specs_from_signatures(UNARY_SIGNATURE_SHAPES),
+ "RepeatVector":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=[((2, 2),)], kwargs_to_values=dict(n=[3])),
+ "Reshape":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=[((3, 2, 2, 2),)],
+ kwargs_to_values=dict(target_shape=[(2, 1, 4, 1)])),
+ "SeparableConv1D":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_1D_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(filters=[CONV_FILTERS],
+ kernel_size=[CONV_KERNEL_SIZE],
+ strides=[1, 2],
+ padding=["valid", "same"],
+ dilation_rate=[1, 2],
+ depth_multiplier=[1, 2])),
+ "SeparableConv2D":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_2D_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(filters=[CONV_FILTERS],
+ kernel_size=[CONV_KERNEL_SIZE],
+ strides=[1, 2],
+ padding=["valid", "same"],
+ dilation_rate=[1, 2],
+ depth_multiplier=[1, 2])),
+ "SimpleRNN":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=RNN_SIGNATURE_SHAPES,
+ kwargs_to_values=RNN_KWARGS_TO_VALUES),
+ "Softmax":
+ tf_test_utils.unit_test_specs_from_signatures(UNARY_SIGNATURE_SHAPES),
+ "SpatialDropout1D":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_1D_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(rate=[DROPOUT])),
+ "SpatialDropout2D":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_2D_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(rate=[DROPOUT])),
+ "SpatialDropout3D":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_3D_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(rate=[DROPOUT])),
+ "Subtract":
+ tf_test_utils.unit_test_specs_from_signatures(BINARY_SIGNATURE_SHAPES),
+ "ThresholdedReLU":
+ tf_test_utils.unit_test_specs_from_signatures(UNARY_SIGNATURE_SHAPES),
+ "UpSampling1D":
+ tf_test_utils.unit_test_specs_from_signatures(CONV_1D_SIGNATURE_SHAPES),
+ "UpSampling2D":
+ tf_test_utils.unit_test_specs_from_signatures(CONV_2D_SIGNATURE_SHAPES),
+ "UpSampling3D":
+ tf_test_utils.unit_test_specs_from_signatures(CONV_3D_SIGNATURE_SHAPES),
+ "ZeroPadding1D":
+ tf_test_utils.unit_test_specs_from_signatures(CONV_1D_SIGNATURE_SHAPES),
+ "ZeroPadding2D":
+ tf_test_utils.unit_test_specs_from_signatures(CONV_2D_SIGNATURE_SHAPES),
+ "ZeroPadding3D":
+ tf_test_utils.unit_test_specs_from_signatures(CONV_3D_SIGNATURE_SHAPES),
+}
+
+for layer, specs in LAYERS_TO_UNIT_TEST_SPECS.items():
+ # Update using 'with_name' to avoid updating shared UnitTestSpecs.
+ specs = [spec.with_name(f"{layer}__{spec.unit_test_name}") for spec in specs]
+ LAYERS_TO_UNIT_TEST_SPECS[layer] = specs
+
+ # Validate that there are not multiple UnitTestSpecs with the same name.
+ seen_unit_test_names = set()
+ for spec in specs:
+ if spec.unit_test_name in seen_unit_test_names:
+ raise ValueError(
+ f"Found multiple UnitTestSpecs with the name '{spec.unit_test_name}'")
+ seen_unit_test_names.add(spec.unit_test_name)
+
+ # Validate that there is one spec that has default values for all unrequired
+ # kwargs.
+ has_default_unrequired_kwargs = False
+ for spec in specs:
+ if not get_non_default_kwargs(layer, spec):
+ has_default_unrequired_kwargs = True
+
+ if not has_default_unrequired_kwargs:
+ raise ValueError(
+ f"The configuration for '{layer}' did not have a UnitTestSpec with all "
+ "default kwargs.")
# Layers that allow specifying the 'dropout' kwarg.
DROPOUT_LAYERS = [
- 'AdditiveAttention', 'Attention', 'ConvLSTM2D', 'GRU', 'GRUCell', 'LSTM',
- 'LSTMCell', 'MultiHeadAttention', 'SimpleRNN', 'SimpleRNNCell'
+ "AdditiveAttention", "Attention", "ConvLSTM2D", "GRU", "LSTM",
+ "MultiHeadAttention", "SimpleRNN"
]
-flags.DEFINE_string('layer', 'Dense',
- f'One of {list(LAYER_TO_UNITTEST_CONFIGURATIONS.keys())}.')
+flags.DEFINE_string("layer", None,
+ f"One of {list(LAYERS_TO_UNIT_TEST_SPECS.keys())}.")
flags.DEFINE_bool(
- 'dynamic_batch', False,
- 'Whether or not to compile the layer with a dynamic batch size.')
-flags.DEFINE_bool('training', False,
- 'Whether or not to compile the layer in training mode.')
+ "dynamic_dims", False,
+ "Whether or not to compile the layer with a dynamic dimension sizes.")
+flags.DEFINE_bool("training", False,
+ "Whether or not to compile the layer in training mode.")
flags.DEFINE_bool(
- 'test_full_api', False,
- 'Whether or not to test multiple layer configurations using non-required '
- 'kwargs.')
+ "test_default_kwargs_only", True,
+ "Whether or not to test multiple layer configurations using non-required "
+ "kwargs.")
flags.DEFINE_bool(
- 'list_layers_with_full_api_tests', False,
- 'Whether or not to print out all layers with non-default configurations '
- '(and skip running the tests).')
-
-
-def get_configs() -> Dict[str, Config]:
- """Gets the configs that we want to test for FLAGS.layer."""
- configs = LAYER_TO_UNITTEST_CONFIGURATIONS[FLAGS.layer]
- if not FLAGS.test_full_api:
- return {'default': configs['default']}
- return configs # pytype: disable=bad-return-type
+ "list_layers_with_full_api_tests", False,
+ "Whether or not to print out all layers with non-default configurations "
+ "(and skip running the tests).")
def get_input(shape: Sequence[int]) -> tf.keras.layers.Input:
- """Gets the input shape(s) that we want to test."""
- batch_size = None if FLAGS.dynamic_batch else shape[0]
+ """Converts a shape into a tf.keras.Input."""
+ # Most keras layers are only compatible with dynamic batch sizes.
+ batch_size = None if FLAGS.dynamic_dims else shape[0]
return tf.keras.layers.Input(batch_size=batch_size, shape=shape[1:])
@@ -551,56 +485,76 @@
return list(args) if isinstance(args, tuple) else args
-def create_wrapped_keras_layer(config: Config) -> tf.keras.Model:
+def create_wrapped_keras_layer(
+ layer: str, unit_test_spec: tf_test_utils.UnitTestSpec) -> tf.keras.Model:
"""Wraps a keras layer in a model for compilation."""
- layer_class = getattr(tf.keras.layers, FLAGS.layer)
+ layer_class = getattr(tf.keras.layers, layer)
- if FLAGS.training and FLAGS.layer in DROPOUT_LAYERS:
- config.kwargs['dropout'] = DROPOUT
+ kwargs = copy.deepcopy(unit_test_spec.kwargs)
+ if FLAGS.training and layer in DROPOUT_LAYERS:
+ kwargs["dropout"] = DROPOUT
- inputs = keras_input_normalizer([get_input(shape) for shape in config.shapes])
- if FLAGS.layer == 'MultiHeadAttention':
+ if "dtype" not in unit_test_spec.kwargs:
+ kwargs["dtype"] = unit_test_spec.input_signature[0].dtype
+
+ inputs = keras_input_normalizer(
+ [get_input(spec.shape) for spec in unit_test_spec.input_signature])
+ if layer == "MultiHeadAttention":
# TODO(meadowlark): Remove specialization if API changes.
- outputs = layer_class(**config.kwargs)(*inputs)
+ outputs = layer_class(**kwargs)(*inputs)
else:
- outputs = layer_class(**config.kwargs)(inputs)
+ outputs = layer_class(**kwargs)(inputs)
return tf.keras.Model(inputs, outputs)
-def create_tf_function_unittest(config: Config, exported_name: str,
- model: tf.keras.Model) -> tf.function:
+def create_layer_unit_test(
+ model: tf.keras.Model,
+ unit_test_spec: tf_test_utils.UnitTestSpec) -> tf.function:
"""Wrap the model's __call__ function in a tf.function for testing."""
- input_shapes = config.shapes
- if FLAGS.dynamic_batch:
- input_shapes = [[None] + shape[1:] for shape in input_shapes]
+ static_signature = unit_test_spec.input_signature
- input_signature = [tf.TensorSpec(shape) for shape in input_shapes]
- if len(input_signature) > 1:
- input_signature = [input_signature]
+ dynamic_signature = static_signature
+ if FLAGS.dynamic_dims:
+ dynamic_signature = tf_utils.apply_function(dynamic_signature,
+ tf_utils.make_dims_dynamic)
+
+ if len(static_signature) > 1:
+ static_signature = [static_signature]
+ dynamic_signature = [dynamic_signature]
call = lambda *args: model(keras_arg_wrapper(*args), training=FLAGS.training)
- return tf_test_utils.tf_function_unittest(input_signature=input_signature,
- name=exported_name)(call)
+ return tf_test_utils.tf_function_unit_test(
+ input_signature=dynamic_signature,
+ static_signature=static_signature,
+ input_generator=unit_test_spec.input_generator,
+ input_args=unit_test_spec.input_args,
+ name=unit_test_spec.unit_test_name)(call)
class KerasLayersModule(tf_test_utils.TestModule):
- @classmethod
- def configure_class(cls):
- """Configure each tf_function_unittest and define it on the cls."""
- for i, (exported_name, config) in enumerate(get_configs().items()):
- model = create_wrapped_keras_layer(config)
- setattr(cls, exported_name,
- create_tf_function_unittest(config, exported_name, model))
-
def __init__(self):
super().__init__()
self.models = []
- for i, (exported_name, config) in enumerate(get_configs().items()):
- model = create_wrapped_keras_layer(config)
+ for unit_test_spec in LAYERS_TO_UNIT_TEST_SPECS[FLAGS.layer]:
+ if (FLAGS.test_default_kwargs_only and
+ get_non_default_kwargs(FLAGS.layer, unit_test_spec)):
+ # Skip all UnitTestSpecs with non-default unrequired kwargs.
+ continue
+
+ if (unsupported_by_tf(FLAGS.layer, unit_test_spec) or
+ has_mutually_exclusive_kwargs(FLAGS.layer, unit_test_spec)):
+ # Filter out UnitTestSpecs with kwargs that TensorFlow can't run on
+ # CPU or that are mutually exclusive. This allows us to take a product
+ # like that in CONV_KWARGS_TO_VALUES and filter out the configurations
+ # lacking support for particular layers.
+ continue
+
+ model = create_wrapped_keras_layer(FLAGS.layer, unit_test_spec)
+ # IREE requires that the models are stored on the module instance.
self.models.append(model)
- setattr(self, exported_name,
- create_tf_function_unittest(config, exported_name, model))
+ layer_unit_test = create_layer_unit_test(model, unit_test_spec)
+ setattr(self, unit_test_spec.unit_test_name, layer_unit_test)
class KerasLayersTest(tf_test_utils.TracedModuleTestCase):
@@ -609,38 +563,41 @@
super().__init__(*args, **kwargs)
self._modules = tf_test_utils.compile_tf_module(
KerasLayersModule,
- exported_names=KerasLayersModule.get_exported_names())
+ exported_names=KerasLayersModule.get_tf_function_unit_tests())
def main(argv):
del argv # Unused.
- if hasattr(tf, 'enable_v2_behavior'):
+ if hasattr(tf, "enable_v2_behavior"):
tf.enable_v2_behavior()
- if FLAGS.layer not in LAYER_TO_UNITTEST_CONFIGURATIONS:
- raise ValueError(f"Unrecognized layer: '{FLAGS.layer}'.")
-
if FLAGS.list_layers_with_full_api_tests:
- for layer, configs in sorted(LAYER_TO_UNITTEST_CONFIGURATIONS.items()):
- if len(configs) > 1:
+ for layer, unit_test_specs in sorted(LAYERS_TO_UNIT_TEST_SPECS.items()):
+ if len(unit_test_specs) > 1:
print(f' "{layer}",')
return
- # Set up name for saving artifacts.
- dynamic_batch_str = 'dynamic_batch' if FLAGS.dynamic_batch else 'static_batch'
- training_str = 'training' if FLAGS.training else 'non_training'
- full_api_str = 'full_api' if FLAGS.test_full_api else 'default_api'
- settings_str = f'{full_api_str}_{dynamic_batch_str}_{training_str}'
- KerasLayersModule.__name__ = os.path.join('keras_layers', FLAGS.layer,
- settings_str)
+ if FLAGS.layer not in LAYERS_TO_UNIT_TEST_SPECS:
+ raise ValueError(f"Unrecognized layer: '{FLAGS.layer}'")
- # Use the configurations for FLAGS.layer to add the tf.functions we wish
- # to test to the KerasLayersModule, and then generate unittests for each of
- # them.
- KerasLayersModule.configure_class()
- KerasLayersTest.generate_unittests(KerasLayersModule)
+ # Set up name for saving artifacts.
+ dynamic_str = "dynamic" if FLAGS.dynamic_dims else "static"
+ training_str = "training" if FLAGS.training else "non_training"
+ full_api_str = "default_api" if FLAGS.test_default_kwargs_only else "full_api"
+ settings_str = f"{full_api_str}_{dynamic_str}_{training_str}"
+ relative_artifacts_dir = os.path.join("tf", "keras", "layers", FLAGS.layer,
+ settings_str)
+ # The relative artifacts directory path is calculated from the module name.
+ # TODO(meadowlark): provide a better way of overriding this default.
+ KerasLayersModule.__name__ = relative_artifacts_dir
+
+ unit_tests = KerasLayersModule.get_tf_function_unit_tests()
+ logging.info("Testing the following %s functions: %s", len(unit_tests),
+ unit_tests)
+
+ KerasLayersTest.generate_unit_tests(KerasLayersModule)
tf.test.main()
-if __name__ == '__main__':
+if __name__ == "__main__":
app.run(main)
diff --git a/integrations/tensorflow/e2e/logical_ops_test.py b/integrations/tensorflow/e2e/logical_ops_test.py
deleted file mode 100644
index cb25db2..0000000
--- a/integrations/tensorflow/e2e/logical_ops_test.py
+++ /dev/null
@@ -1,93 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Tests for ops in the tf.math module that specifically handle logical ops."""
-
-from absl import app
-import numpy as np
-from pyiree.tf.support import tf_test_utils
-import tensorflow.compat.v2 as tf
-
-
-class LogicalOpsModule(tf.Module):
-
- @tf.function(input_signature=[
- tf.TensorSpec([4], tf.bool),
- tf.TensorSpec([4], tf.bool)
- ])
- def logical_and(self, x, y):
- return tf.math.logical_and(x, y)
-
- @tf.function(input_signature=[
- tf.TensorSpec([4], tf.bool),
- tf.TensorSpec([4], tf.bool)
- ])
- def logical_or(self, x, y):
- return tf.math.logical_or(x, y)
-
- @tf.function(input_signature=[
- tf.TensorSpec([4], tf.bool),
- tf.TensorSpec([4], tf.bool)
- ])
- def logical_xor(self, x, y):
- return tf.math.logical_xor(x, y)
-
- @tf.function(input_signature=[tf.TensorSpec([4], tf.bool)])
- def logical_not(self, x):
- return tf.math.logical_not(x)
-
-
-class LogicalOpsTest(tf_test_utils.TracedModuleTestCase):
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self._modules = tf_test_utils.compile_tf_module(LogicalOpsModule)
-
- # yapf: disable
- def test_logical_and(self):
- def logical_and(module):
- module.logical_and(
- np.array([1, 1, 0, 0], dtype=np.bool),
- np.array([0, 1, 1, 0], dtype=np.bool))
- self.compare_backends(logical_and, self._modules)
-
- def test_logical_or(self):
- def logical_or(module):
- module.logical_or(
- np.array([1, 1, 0, 0], dtype=np.bool),
- np.array([0, 1, 1, 0], dtype=np.bool))
- self.compare_backends(logical_or, self._modules)
-
- def test_logical_xor(self):
- def logical_xor(module):
- module.logical_xor(
- np.array([1, 1, 0, 0], dtype=np.bool),
- np.array([0, 1, 1, 0], dtype=np.bool))
- self.compare_backends(logical_xor, self._modules)
-
- def test_logical_not(self):
- def logical_not(module):
- module.logical_not(np.array([0, 1, 1, 0], dtype=np.bool))
- self.compare_backends(logical_not, self._modules)
- # yapf: enable
-
-
-def main(argv):
- del argv # Unused
- if hasattr(tf, 'enable_v2_behavior'):
- tf.enable_v2_behavior()
- tf.test.main()
-
-
-if __name__ == '__main__':
- app.run(main)
diff --git a/integrations/tensorflow/e2e/math/BUILD b/integrations/tensorflow/e2e/math/BUILD
new file mode 100644
index 0000000..1a1c46d
--- /dev/null
+++ b/integrations/tensorflow/e2e/math/BUILD
@@ -0,0 +1,992 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Test coverage across backends for e2e tests is defined directly in the BUILD
+# files. Coverage tables generated from this file can be viewed here:
+# https://google.github.io/iree/tensorflow-coverage/tf-base-coverage
+# Updates made to test suite names should also be reflected here:
+# https://github.com/google/iree/blob/main/scripts/update_e2e_coverage.py
+
+load(
+ "//bindings/python:build_defs.oss.bzl",
+ "INTREE_TENSORFLOW_PY_DEPS",
+ "NUMPY_DEPS",
+ "iree_py_binary",
+ "iree_py_test",
+)
+load(
+ "//integrations/tensorflow/e2e:iree_e2e_test_suite.bzl",
+ "set_difference",
+)
+load(
+ "//integrations/tensorflow/e2e:iree_e2e_cartesian_product_test_suite.bzl",
+ "iree_e2e_cartesian_product_test_suite",
+)
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+[
+ iree_py_binary( # One "<test name>_manual" binary per *_test.py matched by the glob below.
+ name = src.replace(".py", "_manual"),
+ srcs = [src],
+ main = src,
+ python_version = "PY3",
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+ )
+ for src in glob(
+ ["*_test.py"],
+ exclude = ["keyword_spotting_streaming_test.py"], # NOTE(review): presumably carried over from another package's BUILD; confirm a file by this name exists under math/.
+ )
+]
+
+# These functions were selected using all of the functions in the tf.math docs:
+# https://www.tensorflow.org/api_docs/python/tf/math
+TF_MATH_FUNCTIONS = [
+ "abs",
+ "accumulate_n",
+ "acos",
+ "acosh",
+ "add",
+ "add_n",
+ "angle",
+ "argmax",
+ "argmin",
+ "asin",
+ "asinh",
+ "atan",
+ "atan2",
+ "atanh",
+ "bessel_i0",
+ "bessel_i0e",
+ "bessel_i1",
+ "bessel_i1e",
+ "betainc",
+ "bincount",
+ "ceil",
+ "confusion_matrix",
+ "cos",
+ "cosh",
+ "count_nonzero",
+ "cumprod",
+ "cumsum",
+ "cumulative_logsumexp",
+ "digamma",
+ "divide",
+ "divide_no_nan",
+ "equal",
+ "erf",
+ "erfc",
+ "erfinv",
+ "exp",
+ "expm1",
+ "floor",
+ "floordiv",
+ "floormod",
+ "greater",
+ "greater_equal",
+ "igamma",
+ "igammac",
+ "imag",
+ "in_top_k",
+ "invert_permutation",
+ "is_finite",
+ "is_inf",
+ "is_nan",
+ "is_non_decreasing",
+ "is_strictly_increasing",
+ "lbeta",
+ "less",
+ "less_equal",
+ "lgamma",
+ "log",
+ "log1p",
+ "log_sigmoid",
+ "log_softmax",
+ "logical_and",
+ "logical_not",
+ "logical_or",
+ "logical_xor",
+ "maximum",
+ "minimum",
+ "mod",
+ "multiply",
+ "multiply_no_nan",
+ "ndtri",
+ "negative",
+ "nextafter",
+ "not_equal",
+ "polygamma",
+ "polyval",
+ "pow",
+ "real",
+ "reciprocal",
+ "reciprocal_no_nan",
+ "reduce_all",
+ "reduce_any",
+ "reduce_euclidean_norm",
+ "reduce_logsumexp",
+ "reduce_max",
+ "reduce_mean",
+ "reduce_min",
+ "reduce_prod",
+ "reduce_std",
+ "reduce_sum",
+ "reduce_variance",
+ "rint",
+ "round",
+ "rsqrt",
+ "scalar_mul",
+ "segment_max",
+ "segment_mean",
+ "segment_min",
+ "segment_prod",
+ "segment_sum",
+ "sigmoid",
+ "sign",
+ "sin",
+ "sinh",
+ "sobol_sample",
+ "softmax",
+ "softplus",
+ "softsign",
+ "sqrt",
+ "square",
+ "squared_difference",
+ "subtract",
+ "tan",
+ "tanh",
+ # "top_k", # TODO(meadowlark): Enable once list outputs are supported.
+ "truediv",
+ "unsorted_segment_max",
+ "unsorted_segment_mean",
+ "unsorted_segment_min",
+ "unsorted_segment_prod",
+ "unsorted_segment_sqrt_n",
+ "unsorted_segment_sum",
+ "xdivy",
+ "xlog1py",
+ "xlogy",
+ "zero_fraction",
+ "zeta",
+]
+
+# keep sorted
+TFLITE_FAILING = [
+ "abs", # Failing for integer inputs.
+ "acos",
+ "acosh",
+ "asin",
+ "asinh",
+ "atan",
+ "atan2",
+ "atanh",
+ "bessel_i0",
+ "bessel_i0e",
+ "bessel_i1",
+ "bessel_i1e",
+ "betainc",
+ "bincount",
+ "confusion_matrix",
+ "conj",
+ "cosh",
+ "cumprod",
+ "cumulative_logsumexp",
+ "digamma",
+ "divide", # Failing for integer inputs.
+ "erf",
+ "erfc",
+ "erfinv",
+ "expm1",
+ "igamma",
+ "igammac",
+ "in_top_k",
+ "invert_permutation",
+ "is_finite",
+ "is_non_decreasing",
+ "is_strictly_increasing",
+ "l2_normalize",
+ "lbeta",
+ "lgamma",
+ "log1p",
+ "log_sigmoid",
+ "ndtri",
+ "nextafter",
+ "polygamma",
+ "polyval",
+ "pow", # Failing for integer inputs.
+ "reduce_all",
+ "reduce_euclidean_norm",
+ "reduce_logsumexp",
+ "reduce_mean",
+ "reduce_std",
+ "reduce_variance",
+ "rint",
+ "segment_max",
+ "segment_mean",
+ "segment_min",
+ "segment_prod",
+ "sign",
+ "sinh",
+ "sobol_sample",
+ "softmax",
+ "softplus",
+ "softsign",
+ "tan",
+ "unsorted_segment_max",
+ "unsorted_segment_mean",
+ "unsorted_segment_min",
+ "unsorted_segment_prod",
+ "unsorted_segment_sqrt_n",
+ "unsorted_segment_sum",
+ "xdivy",
+ "xlog1py",
+ "xlogy",
+ "zeta",
+]
+
+# Note: The VMLA_FAILING_DYNAMIC specification extends this list. Newly-passing
+# functions removed from this list may need to be added to VMLA_FAILING_DYNAMIC.
+# keep sorted
+VMLA_FAILING = [
+ "acosh",
+ "argmax",
+ "argmin",
+ "asin",
+ "asinh",
+ "atan2",
+ "atanh",
+ "bessel_i0",
+ "bessel_i0e",
+ "bessel_i1",
+ "bessel_i1e",
+ "betainc",
+ "bincount",
+ "confusion_matrix",
+ "cosh",
+ "count_nonzero",
+ "cumprod",
+ "cumulative_logsumexp",
+ "digamma",
+ "divide", # Failing for integer inputs because iree doesn't output 'f64'.
+ "erf",
+ "erfc",
+ "erfinv",
+ "expm1",
+ "igamma",
+ "igammac",
+ "in_top_k",
+ "invert_permutation",
+ "is_nan",
+ "is_non_decreasing",
+ "is_strictly_increasing",
+ "ndtri",
+ "nextafter",
+ "polygamma",
+ "pow", # Failing for integer inputs.
+ "reduce_all",
+ "reduce_any",
+ "reduce_euclidean_norm",
+ "reduce_prod",
+ "rint",
+ "segment_max",
+ "segment_mean",
+ "segment_min",
+ "segment_prod",
+ "segment_sum",
+ "sign",
+ "sobol_sample",
+ "softsign",
+ "unsorted_segment_max",
+ "unsorted_segment_mean",
+ "unsorted_segment_min",
+ "unsorted_segment_prod",
+ "unsorted_segment_sqrt_n",
+ "unsorted_segment_sum",
+ "xdivy",
+ "xlog1py",
+ "xlogy",
+ "zeta",
+]
+
+# Note: The LLVM_FAILING_DYNAMIC specification extends this list. Newly-passing
+# functions removed from this list may need to be added to LLVM_FAILING_DYNAMIC.
+# keep sorted
+LLVM_FAILING = [
+ "acos",
+ "acosh",
+ "argmax",
+ "argmin",
+ "asin",
+ "asinh",
+ "atan",
+ "atan2",
+ "atanh",
+ "bessel_i0",
+ "bessel_i0e",
+ "bessel_i1",
+ "bessel_i1e",
+ "betainc",
+ "bincount",
+ "confusion_matrix",
+ "cosh",
+ "count_nonzero",
+ "cumprod",
+ "cumulative_logsumexp",
+ "digamma",
+ "divide", # Failing for integer inputs because iree doesn't output 'f64'.
+ "erf",
+ "erfc",
+ "erfinv",
+ "expm1",
+ "igamma",
+ "igammac",
+ "in_top_k",
+ "invert_permutation",
+ "is_nan",
+ "is_non_decreasing",
+ "is_strictly_increasing",
+ "l2_normalize",
+ "logical_or",
+ "logical_xor",
+ "ndtri",
+ "nextafter",
+ "polygamma",
+ "pow",
+ "reduce_all",
+ "reduce_any",
+ "reduce_euclidean_norm",
+ "reduce_logsumexp",
+ "reduce_max",
+ "reduce_mean",
+ "reduce_min",
+ "reduce_prod",
+ "reduce_std",
+ "reduce_sum",
+ "reduce_variance",
+ "rint",
+ "segment_max",
+ "segment_mean",
+ "segment_min",
+ "segment_prod",
+ "segment_sum",
+ "sign",
+ "sobol_sample",
+ "softsign",
+ "unsorted_segment_max",
+ "unsorted_segment_mean",
+ "unsorted_segment_min",
+ "unsorted_segment_prod",
+ "unsorted_segment_sqrt_n",
+ "unsorted_segment_sum",
+ "xdivy",
+ "xlog1py",
+ "xlogy",
+ "zeta",
+]
+
+# Note: The VULKAN_FAILING_DYNAMIC specification extends this list.
+# Newly-passing functions removed from this list may need to be added to
+# VULKAN_FAILING_DYNAMIC.
+# keep sorted
+VULKAN_FAILING = [
+ "acos",
+ "acosh",
+ "argmax",
+ "argmin",
+ "asin",
+ "asinh",
+ "atan",
+ "atan2",
+ "atanh",
+ "bessel_i0",
+ "bessel_i0e",
+ "bessel_i1",
+ "bessel_i1e",
+ "betainc",
+ "bincount",
+ "confusion_matrix",
+ "cosh",
+ "count_nonzero",
+ "cumprod",
+ "cumsum",
+ "cumulative_logsumexp",
+ "digamma",
+ "divide", # Failing for integer inputs because iree doesn't output 'f64'.
+ "erf",
+ "erfc",
+ "erfinv",
+ "expm1",
+ "igamma",
+ "igammac",
+ "in_top_k",
+ "invert_permutation",
+ "is_nan",
+ "is_non_decreasing",
+ "is_strictly_increasing",
+ "l2_normalize",
+ "logical_and",
+ "logical_not",
+ "logical_or",
+ "logical_xor",
+ "mod", # Passes with swiftshader, but fails on Turing GPU
+ "ndtri",
+ "nextafter",
+ "polygamma",
+ "pow",
+ "reduce_all",
+ "reduce_any",
+ "reduce_euclidean_norm",
+ "reduce_logsumexp",
+ "reduce_max",
+ "reduce_mean",
+ "reduce_min",
+ "reduce_prod",
+ "reduce_std",
+ "reduce_sum",
+ "reduce_variance",
+ "rint",
+ "segment_max",
+ "segment_mean",
+ "segment_min",
+ "segment_prod",
+ "segment_sum",
+ "sign",
+ "sobol_sample",
+ "softsign",
+ "unsorted_segment_max",
+ "unsorted_segment_mean",
+ "unsorted_segment_min",
+ "unsorted_segment_prod",
+ "unsorted_segment_sqrt_n",
+ "unsorted_segment_sum",
+ "xdivy",
+ "xlog1py",
+ "xlogy",
+ "zeta",
+]
+
+# ---- INDIVIDUAL STATIC TESTS ----------------------------------------------- #
+
+# These tests allow us to generate coverage tables and give a finer-grained view
+# of the coverage, but are very slow due to bazel overhead, so they are not
+# run on the internal or OSS CI.
+iree_e2e_cartesian_product_test_suite(
+ name = "math_tests",
+ srcs = ["math_test.py"],
+ failing_configurations = [
+ {
+ # Failing on TFLite.
+ "functions": TFLITE_FAILING,
+ "target_backends": "tflite",
+ },
+ {
+ # Failing on vmla.
+ "functions": VMLA_FAILING,
+ "target_backends": "iree_vmla",
+ },
+ {
+ # Failing on llvm.
+ "functions": LLVM_FAILING,
+ "target_backends": "iree_llvmjit",
+ },
+ {
+ # Failing on vulkan.
+ "functions": VULKAN_FAILING,
+ "target_backends": "iree_vulkan",
+ },
+ ],
+ flags_to_values = {
+ "reference_backend": "tf",
+ "functions": TF_MATH_FUNCTIONS,
+ "dynamic_dims": False,
+ "test_complex": False,
+ "target_backends": [
+ "tf",
+ "tflite",
+ "iree_vmla",
+ "iree_llvmjit",
+ "iree_vulkan",
+ ],
+ },
+ main = "math_test.py",
+ tags = [
+ "manual",
+ "nokokoro",
+ "notap",
+ ],
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+)
+
+# ---- MULTIPLE STATIC TESTS ------------------------------------------------ #
+
+# These tests compile all functions in tf.math at once for testing so that
+# we can run them on the CI with 5 additional targets instead of 640. The tests
+# are run sharded such that about 5 functions run per shard. This is a
+# reasonable tradeoff between shard startup overhead and critical path test
+# latency.
+
+# TODO(#3810) 'multiply' outputs all zeros when compiled with other functions.
+VMLA_FAILING_MULTIPLE = VMLA_FAILING + ["multiply"]
+
+# TODO(#3810) Including 'square' causes error: Received signal 11.
+# @unused
+LLVM_FAILING_MULTIPLE = LLVM_FAILING + ["square"]
+
+# TODO(#3810) Including 'square' causes error: Received signal 11.
+VULKAN_FAILING_MULTIPLE = VULKAN_FAILING + ["square"]
+
+[
+ iree_py_test(
+ name = "math_tests_multiple__{}".format(target_backend),
+ srcs = ["math_test.py"],
+ args = [
+ "--reference_backend=tf",
+ "--target_backends={}".format(target_backend),
+ "--functions={}".format(",".join(functions)),
+ "--dynamic_dims=False",
+ ],
+ main = "math_test.py",
+ python_version = "PY3",
+ shard_count = len(functions) // 5,
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+ )
+ for target_backend, functions in dict(
+ # TODO(#2673): enable LLVM AOT for these tests. JIT is deprecated.
+ # iree_llvmjit = set_difference(TF_MATH_FUNCTIONS, LLVM_FAILING_MULTIPLE),
+ iree_vmla = set_difference(TF_MATH_FUNCTIONS, VMLA_FAILING_MULTIPLE),
+ iree_vulkan = set_difference(TF_MATH_FUNCTIONS, VULKAN_FAILING_MULTIPLE),
+ tf = TF_MATH_FUNCTIONS,
+ tflite = set_difference(TF_MATH_FUNCTIONS, TFLITE_FAILING),
+ ).items()
+]
+
+# ---- INDIVIDUAL DYNAMIC TESTS ---------------------------------------------- #
+
+# keep sorted
+VMLA_FAILING_DYNAMIC = VMLA_FAILING + [
+ "angle",
+ "cumsum",
+ "divide_no_nan",
+ "equal",
+ "floordiv",
+ "floormod",
+ "imag",
+ "lbeta",
+ "lgamma",
+ "log1p",
+ "log_sigmoid",
+ "logical_and",
+ "logical_not",
+ "logical_or",
+ "logical_xor",
+ "mod",
+ "multiply_no_nan",
+ "not_equal",
+ "reciprocal_no_nan",
+ "reduce_logsumexp",
+ "reduce_max",
+ "reduce_mean",
+ "reduce_min",
+ "reduce_std",
+ "reduce_sum",
+ "reduce_variance",
+ "round",
+ "softplus",
+ "zero_fraction",
+]
+
+# keep sorted
+LLVM_FAILING_DYNAMIC = LLVM_FAILING + [
+ "accumulate_n",
+ "add",
+ "add_n",
+ "angle",
+ "cumsum",
+ "divide",
+ "divide_no_nan",
+ "equal",
+ "floordiv",
+ "floormod",
+ "greater",
+ "greater_equal",
+ "is_finite",
+ "is_inf",
+ "lbeta",
+ "less",
+ "less_equal",
+ "lgamma",
+ "log1p",
+ "log_sigmoid",
+ "log_softmax",
+ "logical_and",
+ "logical_not",
+ "maximum",
+ "minimum",
+ "mod",
+ "multiply",
+ "multiply_no_nan",
+ "not_equal",
+ "polyval",
+ "reciprocal",
+ "reciprocal_no_nan",
+ "reduce_mean",
+ "round",
+ "scalar_mul",
+ "sigmoid",
+ "sinh",
+ "softmax",
+ "softplus",
+ "square",
+ "squared_difference",
+ "subtract",
+ "tan",
+ "truediv",
+ "zero_fraction",
+]
+
+# keep sorted
+VULKAN_FAILING_DYNAMIC = VULKAN_FAILING + [
+ "abs",
+ "accumulate_n",
+ "add",
+ "add_n",
+ "angle",
+ "ceil",
+ "cos",
+ "divide",
+ "divide_no_nan",
+ "equal",
+ "exp",
+ "floor",
+ "floordiv",
+ "floormod",
+ "greater",
+ "greater_equal",
+ "imag",
+ "is_finite",
+ "is_inf",
+ "lbeta",
+ "less",
+ "less_equal",
+ "lgamma",
+ "log",
+ "log1p",
+ "log_sigmoid",
+ "log_softmax",
+ "maximum",
+ "minimum",
+ "mod",
+ "multiply",
+ "multiply_no_nan",
+ "negative",
+ "not_equal",
+ "polyval",
+ "reciprocal",
+ "reciprocal_no_nan",
+ "reduce_max",
+ "reduce_mean",
+ "reduce_min",
+ "reduce_sum",
+ "round",
+ "rsqrt",
+ "scalar_mul",
+ "sigmoid",
+ "sin",
+ "sinh",
+ "softmax",
+ "softplus",
+ "sqrt",
+ "square",
+ "squared_difference",
+ "subtract",
+ "tan",
+ "tanh",
+ "truediv",
+ "zero_fraction",
+]
+
+# These tests allow us to generate coverage tables and give a finer-grained view
+# of the coverage, but are very slow due to bazel overhead, so they are not
+# run on the internal or OSS CI.
+iree_e2e_cartesian_product_test_suite(
+ name = "math_dynamic_dims_tests",
+ srcs = ["math_test.py"],
+ failing_configurations = [
+ {
+ # TFLite does not support dynamic shapes.
+ "target_backends": "tflite",
+ },
+ {
+ # Failing on vmla.
+ "functions": VMLA_FAILING_DYNAMIC,
+ "target_backends": "iree_vmla",
+ },
+ {
+ # Failing on llvm.
+ "functions": LLVM_FAILING_DYNAMIC,
+ "target_backends": "iree_llvmjit",
+ },
+ {
+ # Failing on vulkan.
+ "functions": VULKAN_FAILING_DYNAMIC,
+ "target_backends": "iree_vulkan",
+ },
+ ],
+ flags_to_values = {
+ "reference_backend": "tf",
+ "functions": TF_MATH_FUNCTIONS,
+ "dynamic_dims": True,
+ "test_complex": False,
+ "target_backends": [
+ "tf",
+ "tflite",
+ "iree_vmla",
+ "iree_llvmjit",
+ "iree_vulkan",
+ ],
+ },
+ main = "math_test.py",
+ tags = [
+ "manual",
+ "nokokoro",
+ "notap",
+ ],
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+)
+
+# ---- MULTIPLE DYNAMIC TESTS ----------------------------------------------- #
+
+# These tests compile all functions in tf.math at once for testing so that
+# we can run them on the CI with 4 additional targets instead of 512. The tests
+# are run sharded such that about 5 functions run per shard. This is a
+# reasonable tradeoff between shard startup overhead and critical path test
+# latency.
+
+# TODO(#3810) 'multiply' outputs all zeros when compiled with other functions.
+VMLA_FAILING_DYNAMIC_MULTIPLE = VMLA_FAILING_DYNAMIC + ["multiply"]
+
+[
+ iree_py_test( # One sharded test per backend, running every function expected to pass with dynamic dims.
+ name = "math_dynamic_dims_tests_multiple__{}".format(target_backend),
+ srcs = ["math_test.py"],
+ args = [
+ "--reference_backend=tf",
+ "--target_backends={}".format(target_backend),
+ "--functions={}".format(",".join(functions)),
+ "--dynamic_dims=True", # Must match the *_FAILING_DYNAMIC exclusions used below.
+ ],
+ main = "math_test.py",
+ python_version = "PY3",
+ shard_count = len(functions) // 5,
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+ )
+ for target_backend, functions in dict(
+ # TODO(#2673): enable LLVM AOT for these tests. JIT is deprecated.
+ # iree_llvmjit = set_difference(TF_MATH_FUNCTIONS, LLVM_FAILING_DYNAMIC),
+ iree_vmla = set_difference(TF_MATH_FUNCTIONS, VMLA_FAILING_DYNAMIC_MULTIPLE),
+ iree_vulkan = set_difference(TF_MATH_FUNCTIONS, VULKAN_FAILING_DYNAMIC),
+ tf = TF_MATH_FUNCTIONS,
+ ).items()
+]
+
+# ---- INDIVIDUAL COMPLEX TESTS ---------------------------------------------- #
+
+# This list was generated by running:
+# bazel run integrations/tensorflow/e2e/math:math_test_manual -- --list_functions_with_complex_tests
+COMPLEX_FUNCTIONS = [
+ "abs",
+ "add",
+ "angle",
+ "asinh",
+ "atanh",
+ "conj",
+ "cos",
+ "cosh",
+ "count_nonzero",
+ "cumprod",
+ "cumsum",
+ "divide",
+ "divide_no_nan",
+ "exp",
+ "expm1",
+ "imag",
+ "l2_normalize",
+ "log",
+ "log1p",
+ "multiply",
+ "multiply_no_nan",
+ "negative",
+ "pow",
+ "real",
+ "reciprocal",
+ "reciprocal_no_nan",
+ "reduce_euclidean_norm",
+ "reduce_std",
+ "reduce_variance",
+ "rsqrt",
+ "sigmoid",
+ "sign",
+ "sin",
+ "sinh",
+ "sqrt",
+ "square",
+ "squared_difference",
+ "subtract",
+ "tan",
+ "tanh",
+ "truediv",
+ "xdivy",
+ "xlog1py",
+ "xlogy",
+ "zero_fraction",
+]
+
+# keep sorted
+FAILING_COMPLEX = [
+ "angle",
+ "cos",
+ "cumsum",
+ "divide_no_nan",
+ "log",
+ "log1p",
+ "multiply_no_nan",
+ "negative",
+ "reciprocal",
+ "reciprocal_no_nan",
+ "reduce_std",
+ "reduce_variance",
+ "rsqrt",
+ "sigmoid",
+ "sin",
+ "sinh",
+ "sqrt",
+ "tan",
+ "tanh",
+ "zero_fraction",
+]
+
+VMLA_FAILING_COMPLEX = VMLA_FAILING + FAILING_COMPLEX
+
+LLVM_FAILING_COMPLEX = LLVM_FAILING + FAILING_COMPLEX
+
+VULKAN_FAILING_COMPLEX = VULKAN_FAILING + FAILING_COMPLEX
+
+# These tests allow us to generate coverage tables and give a finer-grained view
+# of the coverage, but are very slow due to bazel overhead, so they are not
+# run on the internal or OSS CI.
+iree_e2e_cartesian_product_test_suite(
+ name = "math_complex_tests",
+ srcs = ["math_test.py"],
+ failing_configurations = [
+ {
+ # TFLite does not support complex numbers.
+ "target_backends": "tflite",
+ },
+ {
+ # Failing on vmla.
+ "functions": VMLA_FAILING_COMPLEX,
+ "target_backends": "iree_vmla",
+ },
+ {
+ # Failing on llvm.
+ "functions": LLVM_FAILING_COMPLEX,
+ "target_backends": "iree_llvmjit",
+ },
+ {
+ # Failing on vulkan.
+ "functions": VULKAN_FAILING_COMPLEX,
+ "target_backends": "iree_vulkan",
+ },
+ ],
+ flags_to_values = {
+ "reference_backend": "tf",
+ "functions": COMPLEX_FUNCTIONS,
+ "dynamic_dims": False,
+ "test_complex": True,
+ "target_backends": [
+ "tf",
+ "tflite",
+ "iree_vmla",
+ "iree_llvmjit",
+ "iree_vulkan",
+ ],
+ },
+ main = "math_test.py",
+ tags = [
+ "manual",
+ "nokokoro",
+ "notap",
+ ],
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+)
+
+# ---- MULTIPLE COMPLEX TESTS ----------------------------------------------- #
+
+# These tests compile all functions in tf.math at once for testing so that
+# we can run them on the CI with 4 additional targets instead of 512. The tests
+# are run sharded such that about 5 functions run per shard. This is a
+# reasonable tradeoff between shard startup overhead and critical path test
+# latency.
+
+# TODO(#3810) 'multiply' outputs all zeros when compiled with other functions.
+VMLA_FAILING_COMPLEX_MULTIPLE = VMLA_FAILING_COMPLEX + ["multiply"]
+
+# TODO(#3810) Including 'square' causes error: Received signal 11.
+# @unused
+LLVM_FAILING_COMPLEX_MULTIPLE = LLVM_FAILING_COMPLEX + ["square"]
+
+# TODO(#3810) Including 'square' causes error: Received signal 11.
+VULKAN_FAILING_COMPLEX_MULTIPLE = VULKAN_FAILING_COMPLEX + ["square"]
+
+[
+ iree_py_test( # One sharded test per backend, running all non-excluded functions in a single target.
+ name = "math_complex_tests_multiple__{}".format(target_backend),
+ srcs = ["math_test.py"],
+ args = [
+ "--reference_backend=tf",
+ "--target_backends={}".format(target_backend),
+ "--functions={}".format(",".join(functions)),
+ "--dynamic_dims=False", # NOTE(review): --test_complex is not passed here, unlike math_complex_tests; confirm whether this suite should set it.
+ ],
+ main = "math_test.py",
+ python_version = "PY3",
+ shard_count = len(functions) // 5,
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+ )
+ for target_backend, functions in dict(
+ # TODO(#2673): enable LLVM AOT for these tests. JIT is deprecated.
+ # iree_llvmjit = set_difference(TF_MATH_FUNCTIONS, LLVM_FAILING_COMPLEX_MULTIPLE),
+ iree_vmla = set_difference(TF_MATH_FUNCTIONS, VMLA_FAILING_COMPLEX_MULTIPLE),
+ iree_vulkan = set_difference(TF_MATH_FUNCTIONS, VULKAN_FAILING_COMPLEX_MULTIPLE),
+ tf = TF_MATH_FUNCTIONS,
+ ).items()
+]
diff --git a/integrations/tensorflow/e2e/math/math_test.py b/integrations/tensorflow/e2e/math/math_test.py
new file mode 100644
index 0000000..00020d0
--- /dev/null
+++ b/integrations/tensorflow/e2e/math/math_test.py
@@ -0,0 +1,749 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import collections
+import os
+from typing import Any, Dict, Sequence, Type, Union
+
+from absl import app
+from absl import flags
+import numpy as np
+from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
+import tensorflow.compat.v2 as tf
+
+FLAGS = flags.FLAGS
+
+# As high as tf goes without breaking.
+RANK_7_SHAPE = [2] * 7
+UNARY_SIGNATURE_SHAPES = [[RANK_7_SHAPE]]
+BINARY_SIGNATURE_SHAPES = [[RANK_7_SHAPE] * 2]
+TERNARY_SIGNATURE_SHAPES = [[RANK_7_SHAPE] * 3]
+
+# Reused UnitTestSpecs.
+SEGMENT_UNIT_TEST_SPECS = tf_test_utils.unit_test_specs_from_args(
+ names_to_input_args={
+ "tf_doc_example": [
+ tf.constant([
+ [1, 2, 3, 4],
+ [4, 3, 2, 1],
+ [5, 6, 7, 8],
+ ], np.float32),
+ np.array([0, 0, 1], np.int32),
+ ]
+ })
+UNSORTED_SEGMENT_UNIT_TEST_SPECS = tf_test_utils.unit_test_specs_from_args(
+ names_to_input_args={
+ "tf_doc_example": [
+ tf.constant([
+ [1, 2, 3, 4],
+ [4, 3, 2, 1],
+ [5, 6, 7, 8],
+ ], np.float32),
+ np.array([0, 0, 1], np.int32),
+ 2,
+ ]
+ })
+
+REDUCE_KWARGS_TO_VALUES = {
+ "axis": [None, 1],
+ "keepdims": [False, True],
+}
+
+# A dictionary mapping tf.math function names to lists of UnitTestSpecs.
+# Each unit_test_name will have the tf.math function name prepended to it.
+FUNCTIONS_TO_UNIT_TEST_SPECS = {
+ "abs":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
+ "accumulate_n": [
+ tf_test_utils.UnitTestSpec(
+ unit_test_name='f32',
+ input_signature=[[tf.TensorSpec(RANK_7_SHAPE, tf.float32)] * 5]),
+ tf_test_utils.UnitTestSpec(
+ unit_test_name='i32',
+ input_signature=[[tf.TensorSpec(RANK_7_SHAPE, tf.int32)] * 5]),
+ ],
+ "acos":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "acosh":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32],
+ input_generators=[tf_utils.ndarange]),
+ "add":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
+ "add_n": [
+ tf_test_utils.UnitTestSpec(
+ unit_test_name='f32',
+ input_signature=[[tf.TensorSpec(RANK_7_SHAPE, tf.float32)] * 5]),
+ tf_test_utils.UnitTestSpec(
+ unit_test_name='i32',
+ input_signature=[[tf.TensorSpec(RANK_7_SHAPE, tf.int32)] * 5]),
+ ],
+ "angle":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "argmax":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32]),
+ "argmin":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32]),
+ "asin":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "asinh":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "atan":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "atan2":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "atanh":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "bessel_i0":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "bessel_i0e":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "bessel_i1":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "bessel_i1e":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "betainc":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=TERNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "bincount":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.int32],
+ input_generators=[tf_utils.ndarange]),
+ "ceil":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "confusion_matrix":
+ tf_test_utils.unit_test_specs_from_args(names_to_input_args={
+ "five_classes": [tf.constant([1, 2, 4]),
+ tf.constant([2, 2, 4])]
+ }),
+ "conj":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
+ "cos":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "cosh":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "count_nonzero":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32, tf.complex64],
+ input_generators=[tf_utils.ndarange]),
+ "cumprod":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
+ "cumsum":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
+ "cumulative_logsumexp":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "digamma":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "divide":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
+ "divide_no_nan":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "equal":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32]),
+ "erf":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "erfc":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "erfinv":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "exp":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "expm1":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "floor":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "floordiv":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32],
+ # Avoid integer division by 0.
+ input_generators={
+ "uniform_1_3":
+ lambda *args: tf_utils.uniform(*args, low=1.0, high=3.0)
+ }),
+ "floormod":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32],
+ # Avoid integer division by 0.
+ input_generators={
+ "uniform_1_3":
+ lambda *args: tf_utils.uniform(*args, low=1.0, high=3.0)
+ }),
+ "greater":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32]),
+ "greater_equal":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32]),
+ "igamma":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "igammac":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "imag":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "in_top_k": [
+ tf_test_utils.UnitTestSpec(
+ unit_test_name="k_3",
+ input_signature=[
+ tf.TensorSpec([8], tf.int32),
+ tf.TensorSpec([8, 3])
+ ],
+ input_generator=tf_utils.ndarange,
+ kwargs=dict(k=3),
+ )
+ ],
+ "invert_permutation": [
+ tf_test_utils.UnitTestSpec(
+ unit_test_name="random",
+ input_signature=[tf.TensorSpec([8], tf.int32)],
+ input_generator=tf_utils.random_permutation,
+ )
+ ],
+ "is_finite":
+ tf_test_utils.unit_test_specs_from_args(names_to_input_args={
+ "nan_and_inf": [tf.constant([[1., np.nan], [np.inf, 2.]])]
+ }),
+ "is_inf":
+ tf_test_utils.unit_test_specs_from_args(names_to_input_args={
+ "nan_and_inf": [tf.constant([[1., np.nan], [np.inf, 2.]])]
+ }),
+ "is_nan":
+ tf_test_utils.unit_test_specs_from_args(names_to_input_args={
+ "nan_and_inf": [tf.constant([[1., np.nan], [np.inf, 2.]])]
+ }),
+ "is_non_decreasing":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32]),
+ "is_strictly_increasing":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32]),
+ "l2_normalize":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "lbeta":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "less":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32]),
+ "less_equal":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32]),
+ "lgamma":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "log":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "log1p":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "log_sigmoid":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "log_softmax":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "logical_and":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.bool]),
+ "logical_not":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.bool]),
+ "logical_or":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.bool]),
+ "logical_xor":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.bool]),
+ "maximum":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32]),
+ "minimum":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32]),
+ "mod":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32],
+ input_generators={
+ "positive_ndarange": lambda *args: tf_utils.ndarange(*args) + 1
+ }),
+ "multiply":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
+ "multiply_no_nan":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "ndtri":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "negative":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
+ "nextafter":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES),
+ "not_equal":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32]),
+ "polygamma":
+ tf_test_utils.unit_test_specs_from_args(names_to_input_args={
+ "nan_and_inf": [tf.ones(16), tf.linspace(0.5, 4, 16)]
+ }),
+ "polyval": [
+ tf_test_utils.UnitTestSpec(
+ unit_test_name="three_coeffs",
+ input_signature=[[tf.TensorSpec(RANK_7_SHAPE)] * 3,
+ tf.TensorSpec([])],
+ )
+ ],
+ "pow":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32, tf.complex64],
+ input_generators={
+ "positive_ndarange": lambda *args: tf_utils.ndarange(*args) + 1
+ }),
+ "real":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "reciprocal":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "reciprocal_no_nan":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "reduce_all": [
+ # Explicitly test all True inputs to be absolutely sure that some
+ # reduction axes return True.
+ *tf_test_utils.unit_test_specs_from_args(
+ names_to_input_args={
+ "all_true": [np.ones(RANK_7_SHAPE, np.bool)],
+ },
+ kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
+ *tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.bool],
+ kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
+ ],
+ "reduce_any": [
+ # Explicitly test all False inputs to be absolutely sure that some
+ # reduction axes return False.
+ *tf_test_utils.unit_test_specs_from_args(
+ names_to_input_args={
+ "all_false": [np.zeros(RANK_7_SHAPE, np.bool)],
+ },
+ kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
+ *tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.bool],
+ kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
+ ],
+ "reduce_euclidean_norm":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64],
+ kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
+ "reduce_logsumexp":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32],
+ kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
+ "reduce_max":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32],
+ kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
+ "reduce_mean":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32],
+ kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
+ "reduce_min":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32],
+ kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
+ "reduce_prod":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32],
+ kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
+ "reduce_std":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64],
+ kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
+ "reduce_sum":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32],
+ kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
+ "reduce_variance":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64],
+ kwargs_to_values=REDUCE_KWARGS_TO_VALUES),
+ "rint":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "round":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "rsqrt":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "scalar_mul":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=[[[], [8]]]),
+ "segment_max":
+ SEGMENT_UNIT_TEST_SPECS,
+ "segment_mean":
+ SEGMENT_UNIT_TEST_SPECS,
+ "segment_min":
+ SEGMENT_UNIT_TEST_SPECS,
+ "segment_prod":
+ SEGMENT_UNIT_TEST_SPECS,
+ "segment_sum":
+ SEGMENT_UNIT_TEST_SPECS,
+ "sigmoid":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "sign":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
+ "sin":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "sinh":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "sobol_sample":
+ tf_test_utils.unit_test_specs_from_args(
+ names_to_input_args={"simple": [4, 3]}),
+ "softmax":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "softplus":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "softsign":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32]),
+ "sqrt":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "square":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
+ "squared_difference":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
+ "subtract":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
+ "tan":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "tanh":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "top_k":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32],
+ kwargs_to_values={"k": [1, 2]}),
+ "truediv":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "unsorted_segment_max":
+ UNSORTED_SEGMENT_UNIT_TEST_SPECS,
+ "unsorted_segment_mean":
+ UNSORTED_SEGMENT_UNIT_TEST_SPECS,
+ "unsorted_segment_min":
+ UNSORTED_SEGMENT_UNIT_TEST_SPECS,
+ "unsorted_segment_prod":
+ UNSORTED_SEGMENT_UNIT_TEST_SPECS,
+ "unsorted_segment_sqrt_n":
+ UNSORTED_SEGMENT_UNIT_TEST_SPECS,
+ "unsorted_segment_sum":
+ UNSORTED_SEGMENT_UNIT_TEST_SPECS,
+ "xdivy":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "xlog1py":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "xlogy":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.complex64]),
+ "zero_fraction":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ signature_dtypes=[tf.float32, tf.int32, tf.complex64]),
+ "zeta":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ # The function is poorly behaved near zero, so we test this range
+ # to avoid outputting all nans.
+ input_generators={
+ "uniform_3_4":
+ lambda *args: tf_utils.uniform(*args, low=3.0, high=4.0)
+ },
+ )
+}
+
+for function, specs in FUNCTIONS_TO_UNIT_TEST_SPECS.items():
+ # Update using 'with_name' to avoid updating shared UnitTestSpecs.
+ specs = [
+ spec.with_name(f"{function}__{spec.unit_test_name}") for spec in specs
+ ]
+ FUNCTIONS_TO_UNIT_TEST_SPECS[function] = specs
+
+ # Validate that there are not multiple UnitTestSpecs with the same name.
+ seen_unit_test_names = set()
+ for spec in specs:
+ if spec.unit_test_name in seen_unit_test_names:
+ raise ValueError(
+ f"Found multiple UnitTestSpecs with the name '{spec.unit_test_name}'")
+ seen_unit_test_names.add(spec.unit_test_name)
+
+flags.DEFINE_list(
+ "functions", None,
+ f"Any of {list(FUNCTIONS_TO_UNIT_TEST_SPECS.keys())}. If more than one "
+ "function is provided then len(--target_backends) must be one.")
+flags.DEFINE_bool(
+ "dynamic_dims", False,
+ "Whether or not to compile the layer with dynamic dimensions.")
+flags.DEFINE_bool(
+ "test_complex", False,
+ "Whether or not to test or ignore function signatures with complex types.")
+flags.DEFINE_bool(
+ 'list_functions_with_complex_tests', False,
+ 'Whether or not to print out all functions with complex inputs '
+ '(and skip running the tests).')
+
+
+def create_function_unit_test(
+ function_name: str,
+ unit_test_spec: tf_test_utils.UnitTestSpec) -> tf.function:
+ """Creates a tf_function_unit_test from the provided UnitTestSpec."""
+ function = getattr(tf.math, function_name)
+ signature = unit_test_spec.input_signature
+
+ if tf_utils.is_complex(signature):
+ function, signature = tf_utils.rewrite_complex_signature(
+ function, signature)
+ wrapped_function = lambda *args: function(*args, **unit_test_spec.kwargs)
+
+ if FLAGS.dynamic_dims:
+ signature = tf_utils.apply_function(signature, tf_utils.make_dims_dynamic)
+
+ return tf_test_utils.tf_function_unit_test(
+ input_signature=signature,
+ input_generator=unit_test_spec.input_generator,
+ input_args=unit_test_spec.input_args,
+ name=unit_test_spec.unit_test_name,
+ rtol=1e-5,
+ atol=1e-5)(wrapped_function)
+
+
+class TfMathModule(tf_test_utils.TestModule):
+
+ def __init__(self):
+ super().__init__()
+ for function in FLAGS.functions:
+ for unit_test_spec in FUNCTIONS_TO_UNIT_TEST_SPECS[function]:
+ if not FLAGS.test_complex and tf_utils.is_complex(
+ unit_test_spec.input_signature):
+ continue
+ function_unit_test = create_function_unit_test(function, unit_test_spec)
+ setattr(self, unit_test_spec.unit_test_name, function_unit_test)
+
+
+class TfMathTest(tf_test_utils.TracedModuleTestCase):
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._modules = tf_test_utils.compile_tf_module(
+ TfMathModule, exported_names=TfMathModule.get_tf_function_unit_tests())
+
+
+def main(argv):
+ del argv # Unused.
+ if hasattr(tf, "enable_v2_behavior"):
+ tf.enable_v2_behavior()
+
+ if FLAGS.list_functions_with_complex_tests:
+ for function_name, unit_test_specs in FUNCTIONS_TO_UNIT_TEST_SPECS.items():
+ for spec in unit_test_specs:
+ if tf_utils.is_complex(spec.input_signature):
+ print(f' "{function_name}",')
+ return
+
+ if FLAGS.functions is None:
+ raise flags.IllegalFlagValueError(
+ "'--functions' must be specified if "
+ "'--list_functions_with_complex_tests' isn't")
+
+ if len(FLAGS.functions) > 1:
+ # We only allow testing multiple functions with a single target backend
+ # so that we can store the artifacts under:
+ # 'artifacts_dir/multiple_functions__backend/...'
+ # We specialize the 'multiple_functions' dir by backend to avoid overwriting
+ # tf_input.mlir and iree_input.mlir. These are typically identical across
+ # backends, but are not when the functions to compile change per-backend.
+ if len(FLAGS.target_backends) != 1:
+ raise flags.IllegalFlagValueError(
+ "Expected len(target_backends) == 1 when len(functions) > 1, but got "
+ f"the following values for target_backends: {FLAGS.target_backends}.")
+ function_str = f"multiple_functions__{FLAGS.target_backends[0]}"
+ else:
+ function_str = FLAGS.functions[0]
+ dim_str = "dynamic_dims" if FLAGS.dynamic_dims else "static_dims"
+ settings_str = os.path.join(function_str, dim_str)
+ # The relative artifacts directory path is calculated from the module name.
+ # TODO(meadowlark): provide a better way of overriding this default.
+ TfMathModule.__name__ = os.path.join("tf", "math", settings_str)
+
+ TfMathTest.generate_unit_tests(TfMathModule)
+ tf.test.main()
+
+
+if __name__ == "__main__":
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/math_test.py b/integrations/tensorflow/e2e/math_test.py
deleted file mode 100644
index 79d1f11..0000000
--- a/integrations/tensorflow/e2e/math_test.py
+++ /dev/null
@@ -1,88 +0,0 @@
-# Lint as: python3
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Tests for ops in the tf.math module."""
-
-from absl import app
-import numpy as np
-from pyiree.tf.support import tf_test_utils
-import tensorflow.compat.v2 as tf
-
-
-class MathModule(tf.Module):
-
- @tf.function(input_signature=[tf.TensorSpec([4], tf.float32)])
- def abs(self, x):
- return tf.math.abs(x)
-
- @tf.function(input_signature=[tf.TensorSpec([4], tf.float32)])
- def ceil(self, x):
- return tf.math.ceil(x)
-
- @tf.function(input_signature=[tf.TensorSpec([4], tf.float32)])
- def cos(self, x):
- return tf.math.cos(x)
-
- @tf.function(input_signature=[tf.TensorSpec([4], tf.float32)])
- def log(self, x):
- return tf.math.log(x)
-
- @tf.function(input_signature=[tf.TensorSpec([4], tf.float32)])
- def mod(self, x):
- return tf.math.mod(x, 2.0)
-
-
-class MathTest(tf_test_utils.TracedModuleTestCase):
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self._modules = tf_test_utils.compile_tf_module(MathModule)
-
- # yapf: disable
- def test_abs(self):
- def abs(module):
- module.abs(np.array([-0.5, 0.0, 0.5, 1.0], dtype=np.float32))
- self.compare_backends(abs, self._modules)
-
- def test_ceil(self):
- def ceil(module):
- module.ceil(np.array([0.0, 1.2, 1.5, 3.75], dtype=np.float32))
- self.compare_backends(ceil, self._modules)
-
- def test_cos(self):
- def cos(module):
- module.cos(np.array([-0.5, 0.0, 0.5, 1.0], dtype=np.float32))
- self.compare_backends(cos, self._modules)
-
- def test_log(self):
- def log(module):
- module.log(np.array([0.1, 0.2, 0.5, 1.0], dtype=np.float32))
- self.compare_backends(log, self._modules)
-
- def test_mod(self):
- def mod(module):
- module.mod(np.array([0.0, 1.2, 1.5, 3.75], dtype=np.float32))
- self.compare_backends(mod, self._modules)
- # yapf: enable
-
-
-def main(argv):
- del argv # Unused
- if hasattr(tf, 'enable_v2_behavior'):
- tf.enable_v2_behavior()
- tf.test.main()
-
-
-if __name__ == '__main__':
- app.run(main)
diff --git a/integrations/tensorflow/e2e/quantization_dyn_test.py b/integrations/tensorflow/e2e/quantization_dyn_test.py
new file mode 100644
index 0000000..d1c44ef
--- /dev/null
+++ b/integrations/tensorflow/e2e/quantization_dyn_test.py
@@ -0,0 +1,58 @@
+# Lint as: python3
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for tf.quantization ops with dynamically shaped inputs."""
+
+from absl import app
+import numpy as np
+from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
+import tensorflow.compat.v2 as tf
+
+
+class QuantizationDynModule(tf.Module):
+
+ @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
+ def fake_quant(self, x):
+ return tf.quantization.fake_quant_with_min_max_args(x,
+ min=-6,
+ max=6,
+ num_bits=8,
+ narrow_range=False,
+ name=None)
+
+
+class QuantizationDynTest(tf_test_utils.TracedModuleTestCase):
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._modules = tf_test_utils.compile_tf_module(QuantizationDynModule)
+
+ def test_fake_quant(self):
+
+ def abs(module):
+ module.fake_quant(tf_utils.uniform([32], low=-6, high=6))
+
+ self.compare_backends(abs, self._modules)
+
+
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
+ tf.enable_v2_behavior()
+ tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/quantization_test.py b/integrations/tensorflow/e2e/quantization_test.py
new file mode 100644
index 0000000..2ccf8f8
--- /dev/null
+++ b/integrations/tensorflow/e2e/quantization_test.py
@@ -0,0 +1,55 @@
+# Lint as: python3
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for ops in the tf.quantization module."""
+
+from absl import app
+import numpy as np
+from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
+import tensorflow.compat.v2 as tf
+
+
+class QuantizationModule(tf_test_utils.TestModule):
+
+ @tf_test_utils.tf_function_unit_test(
+ input_signature=[tf.TensorSpec([32], tf.float32)],
+ input_generator=lambda *args: tf_utils.uniform(*args, low=-6, high=6))
+ def fake_quant(self, x):
+ return tf.quantization.fake_quant_with_min_max_args(x,
+ min=-6,
+ max=6,
+ num_bits=8,
+ narrow_range=False,
+ name=None)
+
+
+class QuantizationTest(tf_test_utils.TracedModuleTestCase):
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._modules = tf_test_utils.compile_tf_module(QuantizationModule)
+
+
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
+ tf.enable_v2_behavior()
+
+ QuantizationTest.generate_unit_tests(QuantizationModule)
+ tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/reduce_test.py b/integrations/tensorflow/e2e/reduce_test.py
new file mode 100644
index 0000000..61f919d
--- /dev/null
+++ b/integrations/tensorflow/e2e/reduce_test.py
@@ -0,0 +1,94 @@
+# Lint as: python3
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for ops in the tf.math module."""
+
+from absl import app
+import numpy as np
+from pyiree.tf.support import tf_test_utils
+from pyiree.tf.support import tf_utils
+import tensorflow.compat.v2 as tf
+
+
+class ReduceModule(tf.Module):
+
+ @tf.function(input_signature=[tf.TensorSpec([4, 4, 4], tf.float32)])
+ def max(self, x):
+ return tf.math.reduce_max(x, axis=1)
+
+ @tf.function(input_signature=[tf.TensorSpec([4, 4, 4], tf.float32)])
+ def min(self, x):
+ return tf.math.reduce_min(x, axis=1)
+
+ @tf.function(input_signature=[tf.TensorSpec([4, 4, 4], tf.float32)])
+ def sum(self, x):
+ return tf.math.reduce_sum(x, axis=1)
+
+ @tf.function(input_signature=[tf.TensorSpec([4, 2], tf.bool)])
+ def reduce_any(self, x):
+ return tf.math.reduce_any(x, axis=1)
+
+ @tf.function(input_signature=[tf.TensorSpec([4, 2], tf.bool)])
+ def reduce_all(self, x):
+ return tf.math.reduce_all(x, axis=1)
+
+
+class ReduceTest(tf_test_utils.TracedModuleTestCase):
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._modules = tf_test_utils.compile_tf_module(ReduceModule)
+
+ # yapf: disable
+ def test_max(self):
+ def max(module):
+ arr = tf_utils.uniform([4, 4, 4], dtype=tf.float32)
+ module.max(arr)
+ self.compare_backends(max, self._modules)
+
+ def test_min(self):
+ def min(module):
+ arr = tf_utils.uniform([4, 4, 4], dtype=tf.float32)
+ module.min(arr)
+ self.compare_backends(min, self._modules)
+
+ def test_sum(self):
+ def sum(module):
+ arr = tf_utils.uniform([4, 4, 4], dtype=tf.float32)
+ module.sum(arr)
+ self.compare_backends(sum, self._modules)
+
+ def test_any(self):
+ def reduce_any(module):
+ arr = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.bool)
+ module.reduce_any(arr)
+ self.compare_backends(reduce_any, self._modules)
+
+ def test_all(self):
+ def reduce_all(module):
+ arr = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.bool)
+ module.reduce_all(arr)
+ self.compare_backends(reduce_all, self._modules)
+ # yapf: enable
+
+
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
+ tf.enable_v2_behavior()
+ tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/slim_vision_models/BUILD b/integrations/tensorflow/e2e/slim_vision_models/BUILD
index 88f60b0..f871fe2 100644
--- a/integrations/tensorflow/e2e/slim_vision_models/BUILD
+++ b/integrations/tensorflow/e2e/slim_vision_models/BUILD
@@ -13,8 +13,8 @@
# limitations under the License.
# Test coverage across backends for e2e tests is defined directly in the BUILD
-# files. A coverage table generated from this file can be viewed here:
-# https://google.github.io/iree/tf-e2e-coverage
+# files. Coverage tables generated from this file can be viewed here:
+# https://google.github.io/iree/tensorflow-coverage/vision-coverage
# Updates made to test suite names should also be reflected here:
# https://github.com/google/iree/blob/main/scripts/update_e2e_coverage.py
@@ -76,7 +76,6 @@
"resnet_v2_152",
],
"target_backends": [
- "iree_llvmjit",
"iree_vulkan",
],
},
@@ -159,7 +158,6 @@
"tf",
"tflite",
"iree_vmla",
- "iree_llvmjit",
"iree_vulkan",
],
},
diff --git a/iree/base/BUILD b/iree/base/BUILD
index 835aa26..e6e634d 100644
--- a/iree/base/BUILD
+++ b/iree/base/BUILD
@@ -14,7 +14,7 @@
# Common types and utilities used in the IREE codebase.
-load("//build_tools/embed_data:build_defs.bzl", "cc_embed_data")
+load("//build_tools/bazel:iree_flatcc.bzl", "iree_flatbuffer_c_library")
package(
default_visibility = ["//visibility:public"],
@@ -22,33 +22,55 @@
licenses = ["notice"], # Apache 2.0
)
-cc_library(
- name = "alignment",
- hdrs = ["alignment.h"],
- deps = [
- ":target_platform",
- ],
-)
+#===------------------------------------------------------------------------===#
+# Public API
+#===------------------------------------------------------------------------===#
cc_library(
name = "api",
- srcs = [
- "api.c",
- ],
+ srcs = ["api.c"],
hdrs = ["api.h"],
visibility = ["//visibility:public"],
deps = [
- ":api_hdrs",
- ":target_platform",
+ ":core_headers",
":tracing",
],
)
+#===------------------------------------------------------------------------===#
+# Core headers (platform detection, compiler compat, etc)
+#===------------------------------------------------------------------------===#
+
cc_library(
- name = "api_hdrs",
- hdrs = ["api.h"],
+ name = "core_headers",
+ hdrs = [
+ "alignment.h",
+ "atomics.h",
+ "bitfield.h",
+ "math.h",
+ "memory.h",
+ "target_platform.h",
+ ],
+ deps = [
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/types:span", # bitfield.h
+ ],
)
+cc_test(
+ name = "bitfield_test",
+ srcs = ["bitfield_test.cc"],
+ deps = [
+ ":core_headers",
+ "//iree/testing:gtest",
+ "//iree/testing:gtest_main",
+ ],
+)
+
+#===------------------------------------------------------------------------===#
+# Internal IREE C++ wrappers and utilities
+#===------------------------------------------------------------------------===#
+
cc_library(
name = "arena",
srcs = ["arena.cc"],
@@ -71,26 +93,20 @@
)
cc_library(
- name = "atomics",
- hdrs = ["atomics.h"],
+ name = "atomic_slist",
+ srcs = ["atomic_slist.c"],
+ hdrs = ["atomic_slist.h"],
deps = [
- ":target_platform",
- ],
-)
-
-cc_library(
- name = "bitfield",
- hdrs = ["bitfield.h"],
- deps = [
- "@com_google_absl//absl/types:span",
+ ":core_headers",
+ ":synchronization",
],
)
cc_test(
- name = "bitfield_test",
- srcs = ["bitfield_test.cc"],
+ name = "atomic_slist_test",
+ srcs = ["atomic_slist_test.cc"],
deps = [
- ":bitfield",
+ ":atomic_slist",
"//iree/testing:gtest",
"//iree/testing:gtest_main",
],
@@ -103,61 +119,25 @@
"dynamic_library_win32.cc",
],
hdrs = ["dynamic_library.h"],
- linkopts = select({
- "//iree:iree_is_msvc": [],
- "//conditions:default": [
- "-ldl",
- ],
- }),
deps = [
+ ":core_headers",
":file_path",
":logging",
":status",
- ":target_platform",
":tracing",
+ "//build_tools:default_linkopts",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)
-cc_binary(
- name = "dynamic_library_test_library.so",
- testonly = True,
- srcs = ["dynamic_library_test_library.cc"],
- linkshared = True,
-)
-
-cc_embed_data(
- name = "dynamic_library_test_library",
- testonly = True,
- srcs = [":dynamic_library_test_library.so"],
- cc_file_output = "dynamic_library_test_library_embed.cc",
- cpp_namespace = "iree",
- flatten = True,
- h_file_output = "dynamic_library_test_library_embed.h",
-)
-
-cc_test(
- name = "dynamic_library_test",
- srcs = ["dynamic_library_test.cc"],
- deps = [
- ":dynamic_library",
- ":dynamic_library_test_library",
- ":file_io",
- ":status",
- ":target_platform",
- "//iree/testing:gtest",
- "//iree/testing:gtest_main",
- ],
-)
-
cc_library(
name = "file_io",
hdrs = ["file_io.h"],
deps = [
+ ":core_headers",
":status",
- ":target_platform",
"//iree/base/internal:file_io_internal",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
@@ -233,39 +213,32 @@
)
cc_library(
- name = "flatbuffer_util",
- srcs = ["flatbuffer_util.cc"],
- hdrs = ["flatbuffer_util.h"],
+ name = "flags",
+ srcs = ["flags.cc"],
+ hdrs = ["flags.h"],
deps = [
- ":file_mapping",
- ":memory",
- ":ref_ptr",
- ":status",
- ":tracing",
- "@com_github_google_flatbuffers//:flatbuffers",
- "@com_google_absl//absl/memory",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:optional",
- "@com_google_absl//absl/types:span",
- ],
-)
-
-cc_library(
- name = "init",
- srcs = ["init.cc"],
- hdrs = ["init.h"],
- deps = [
- ":initializer",
+ ":api",
"@com_google_absl//absl/flags:parse",
],
)
cc_library(
- name = "initializer",
- srcs = ["initializer.cc"],
- hdrs = ["initializer.h"],
+ name = "flatcc",
+ hdrs = ["flatcc.h"],
deps = [
- ":target_platform",
+ ":flatcc_dummy",
+ "@com_github_dvidelabs_flatcc//:runtime",
+ ],
+)
+
+iree_flatbuffer_c_library(
+ name = "flatcc_dummy",
+ srcs = ["flatcc.fbs"],
+ flatcc_args = [
+ "--reader",
+ "--builder",
+ "--verifier",
+ "--json",
],
)
@@ -317,24 +290,8 @@
],
hdrs = ["main.h"],
deps = [
+ ":core_headers",
":logging",
- ":target_platform",
- ],
-)
-
-cc_library(
- name = "math",
- hdrs = ["math.h"],
- deps = [
- "@com_google_absl//absl/base:core_headers",
- ],
-)
-
-cc_library(
- name = "memory",
- hdrs = ["memory.h"],
- deps = [
- "@com_google_absl//absl/types:span",
],
)
@@ -404,16 +361,9 @@
name = "synchronization",
srcs = ["synchronization.c"],
hdrs = ["synchronization.h"],
- linkopts = select({
- "//iree:iree_is_msvc": [],
- "//conditions:default": [
- "-lpthread",
- ],
- }),
deps = [
":api",
- ":atomics",
- ":target_platform",
+ ":core_headers",
":tracing",
],
)
@@ -421,9 +371,6 @@
cc_test(
name = "synchronization_benchmark",
srcs = ["synchronization_benchmark.cc"],
- tags = [
- "nowindows", # TODO(#3615) make this link on windows
- ],
deps = [
":synchronization",
"//iree/testing:benchmark_main",
@@ -434,9 +381,6 @@
cc_test(
name = "synchronization_test",
srcs = ["synchronization_test.cc"],
- tags = [
- "nowindows", # TODO(#3615) make this link on windows
- ],
deps = [
":synchronization",
"//iree/testing:gtest",
@@ -445,11 +389,6 @@
)
cc_library(
- name = "target_platform",
- hdrs = ["target_platform.h"],
-)
-
-cc_library(
name = "time",
hdrs = ["time.h"],
deps = [
@@ -477,31 +416,18 @@
"threading_win32.c",
],
hdrs = ["threading.h"],
- copts = [
- "-D_GNU_SOURCE=1",
- ],
- linkopts = select({
- "//iree:iree_is_msvc": [],
- "//conditions:default": [
- "-ldl",
- "-lpthread",
- ],
- }),
deps = [
":api",
- ":atomics",
+ ":core_headers",
":synchronization",
- ":target_platform",
":tracing",
+ "//build_tools:default_linkopts",
],
)
cc_test(
name = "threading_benchmark",
srcs = ["threading_benchmark.cc"],
- tags = [
- "nowindows", # TODO(#3615) make this link on windows
- ],
deps = [
":threading",
"//iree/testing:benchmark_main",
@@ -515,9 +441,6 @@
"threading_impl.h",
"threading_test.cc",
],
- tags = [
- "nowindows", # TODO(#3615) make this link on windows
- ],
deps = [
":synchronization",
":threading",
@@ -534,34 +457,32 @@
],
)
-# Dependent code has been removed and wait_handle is currently incompatible
-# with Windows, so excluding entirely.
-# See google/iree/65
-# cc_library(
-# name = "wait_handle",
-# srcs = ["wait_handle.cc"],
-# hdrs = ["wait_handle.h"],
-# deps = [
-# ":logging",
-# ":ref_ptr",
-# ":status",
-# ":time",
-# "@com_google_absl//absl/base:core_headers",
-# "@com_google_absl//absl/container:fixed_array",
-# "@com_google_absl//absl/strings",
-# "@com_google_absl//absl/time",
-# "@com_google_absl//absl/types:span",
-# ],
-# )
+cc_library(
+ name = "wait_handle",
+ srcs = [
+ "wait_handle.c",
+ "wait_handle_epoll.c",
+ "wait_handle_impl.h",
+ "wait_handle_kqueue.c",
+ "wait_handle_poll.c",
+ "wait_handle_posix.c",
+ "wait_handle_posix.h",
+ "wait_handle_win32.c",
+ ],
+ hdrs = ["wait_handle.h"],
+ deps = [
+ ":api",
+ ":core_headers",
+ ":tracing",
+ ],
+)
-# cc_test(
-# name = "wait_handle_test",
-# srcs = ["wait_handle_test.cc"],
-# deps = [
-# ":status",
-# ":wait_handle",
-# "@com_google_absl//absl/time",
-# "//iree/testing:gtest",
-# "//iree/testing:gtest_main",
-# ],
-# )
+cc_test(
+ name = "wait_handle_test",
+ srcs = ["wait_handle_test.cc"],
+ deps = [
+ ":wait_handle",
+ "//iree/testing:gtest",
+ "//iree/testing:gtest_main",
+ ],
+)
diff --git a/iree/base/CMakeLists.txt b/iree/base/CMakeLists.txt
index 01a5b60..c45233e 100644
--- a/iree/base/CMakeLists.txt
+++ b/iree/base/CMakeLists.txt
@@ -12,17 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-iree_add_all_subdirs()
+# bazel_to_cmake: DO NOT EDIT (tracing rule move)
-iree_cc_library(
- NAME
- alignment
- HDRS
- "alignment.h"
- DEPS
- ::target_platform
- PUBLIC
-)
+iree_add_all_subdirs()
iree_cc_library(
NAME
@@ -32,20 +24,38 @@
SRCS
"api.c"
DEPS
- ::api_hdrs
- ::target_platform
+ ::core_headers
::tracing
PUBLIC
)
iree_cc_library(
NAME
- api_hdrs
+ core_headers
HDRS
- "api.h"
+ "alignment.h"
+ "atomics.h"
+ "bitfield.h"
+ "math.h"
+ "memory.h"
+ "target_platform.h"
+ DEPS
+ absl::core_headers
+ absl::span
PUBLIC
)
+iree_cc_test(
+ NAME
+ bitfield_test
+ SRCS
+ "bitfield_test.cc"
+ DEPS
+ ::core_headers
+ iree::testing::gtest
+ iree::testing::gtest_main
+)
+
iree_cc_library(
NAME
arena
@@ -73,41 +83,28 @@
iree_cc_library(
NAME
- atomics
+ atomic_slist
HDRS
- "atomics.h"
+ "atomic_slist.h"
+ SRCS
+ "atomic_slist.c"
DEPS
- ::target_platform
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- bitfield
- HDRS
- "bitfield.h"
- DEPS
- absl::span
+ ::core_headers
+ ::synchronization
PUBLIC
)
iree_cc_test(
NAME
- bitfield_test
+ atomic_slist_test
SRCS
- "bitfield_test.cc"
+ "atomic_slist_test.cc"
DEPS
- ::bitfield
- absl::core_headers
+ ::atomic_slist
iree::testing::gtest
iree::testing::gtest_main
)
-iree_select_compiler_opts(_DYNAMIC_LIBRARY_LINKOPTS
- CLANG_OR_GCC
- "-ldl"
-)
-
iree_cc_library(
NAME
dynamic_library
@@ -116,13 +113,11 @@
SRCS
"dynamic_library_posix.cc"
"dynamic_library_win32.cc"
- LINKOPTS
- ${_DYNAMIC_LIBRARY_LINKOPTS}
DEPS
+ ::core_headers
::file_path
::logging
::status
- ::target_platform
::tracing
absl::memory
absl::span
@@ -130,59 +125,14 @@
PUBLIC
)
-# TODO(scotttodd): clean up bazel_to_cmake handling here
-# * this is a cc_binary in Bazel, but `linkshared` fits iree_cc_library better
-# * the output file name is platform-specific, get it with $<TARGET_FILE:>
-iree_cc_library(
- NAME
- dynamic_library_test_library.so
- OUT
- dynamic_library_test_library.so
- SRCS
- "dynamic_library_test_library.cc"
- TESTONLY
- SHARED
-)
-
-iree_cc_embed_data(
- NAME
- dynamic_library_test_library
- GENERATED_SRCS
- "$<TARGET_FILE:iree::base::dynamic_library_test_library.so>"
- CC_FILE_OUTPUT
- "dynamic_library_test_library_embed.cc"
- H_FILE_OUTPUT
- "dynamic_library_test_library_embed.h"
- TESTONLY
- CPP_NAMESPACE
- "iree"
- FLATTEN
- PUBLIC
-)
-
-iree_cc_test(
- NAME
- dynamic_library_test
- SRCS
- "dynamic_library_test.cc"
- DEPS
- ::dynamic_library
- ::dynamic_library_test_library
- ::file_io
- ::status
- ::target_platform
- iree::testing::gtest
- iree::testing::gtest_main
-)
-
iree_cc_library(
NAME
file_io
HDRS
"file_io.h"
DEPS
+ ::core_headers
::status
- ::target_platform
absl::memory
absl::span
absl::strings
@@ -268,47 +218,38 @@
iree_cc_library(
NAME
- flatbuffer_util
+ flags
HDRS
- "flatbuffer_util.h"
+ "flags.h"
SRCS
- "flatbuffer_util.cc"
+ "flags.cc"
DEPS
- ::file_mapping
- ::memory
- ::ref_ptr
- ::status
- ::tracing
- absl::memory
- absl::optional
- absl::span
- absl::strings
- flatbuffers
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- init
- HDRS
- "init.h"
- SRCS
- "init.cc"
- DEPS
+ ::api
absl::flags_parse
- ::initializer
PUBLIC
)
iree_cc_library(
NAME
- initializer
+ flatcc
HDRS
- "initializer.h"
- SRCS
- "initializer.cc"
+ "flatcc.h"
DEPS
- ::target_platform
+ ::flatcc_dummy
+ flatcc::runtime
+ PUBLIC
+)
+
+flatbuffer_c_library(
+ NAME
+ flatcc_dummy
+ SRCS
+ "flatcc.fbs"
+ FLATCC_ARGS
+ "--reader"
+ "--builder"
+ "--verifier"
+ "--json"
PUBLIC
)
@@ -363,28 +304,8 @@
"main_posix.cc"
"main_win32.cc"
DEPS
+ ::core_headers
::logging
- ::target_platform
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- math
- HDRS
- "math.h"
- DEPS
- absl::core_headers
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- memory
- HDRS
- "memory.h"
- DEPS
- absl::span
PUBLIC
)
@@ -459,16 +380,6 @@
iree::testing::gtest_main
)
-if(NOT ANDROID)
- iree_select_compiler_opts(_SYNCHRONIZATION_LINKOPTS
- CLANG_OR_GCC
- "-lpthread"
- )
-else()
- # Android provides its own pthreads support with no linking required.
- set(_SYNCHRONIZATION_LINKOPTS "")
-endif()
-
iree_cc_library(
NAME
synchronization
@@ -476,12 +387,9 @@
"synchronization.h"
SRCS
"synchronization.c"
- LINKOPTS
- ${_SYNCHRONIZATION_LINKOPTS}
DEPS
::api
- ::atomics
- ::target_platform
+ ::core_headers
::tracing
PUBLIC
)
@@ -510,14 +418,6 @@
iree_cc_library(
NAME
- target_platform
- HDRS
- "target_platform.h"
- PUBLIC
-)
-
-iree_cc_library(
- NAME
time
HDRS
"time.h"
@@ -537,20 +437,6 @@
iree::testing::gtest_main
)
-if(NOT ANDROID)
- iree_select_compiler_opts(_THREADING_LINKOPTS
- CLANG_OR_GCC
- "-ldl"
- "-lpthread"
- )
-else()
- iree_select_compiler_opts(_THREADING_LINKOPTS
- CLANG_OR_GCC
- "-ldl"
- # Android provides its own pthreads support with no linking required.
- )
-endif()
-
iree_cc_library(
NAME
threading
@@ -562,15 +448,10 @@
"threading_impl.h"
"threading_pthreads.c"
"threading_win32.c"
- COPTS
- "-D_GNU_SOURCE=1"
- LINKOPTS
- ${_THREADING_LINKOPTS}
DEPS
::api
- ::atomics
+ ::core_headers
::synchronization
- ::target_platform
::tracing
PUBLIC
)
@@ -599,11 +480,7 @@
iree::testing::gtest_main
)
-iree_select_compiler_opts(IREE_LINKOPTS_TRACING
- GCC_OR_CLANG
- -ldl
-)
-
+# TODO(benvanik): redirect to internal/tracing/ or something.
if(${IREE_ENABLE_RUNTIME_TRACING})
iree_cc_library(
NAME
@@ -614,10 +491,8 @@
"${IREE_ROOT_DIR}/third_party/tracy/TracyC.h"
SRCS
"tracing.cc"
- LINKOPTS
- ${IREE_LINKOPTS_TRACING}
DEPS
- ::target_platform
+ ::core_headers
absl::core_headers
DEFINES
# TODO(#2114): Change the mode to 2.
@@ -636,36 +511,34 @@
)
endif()
-# TODO(benvanik): get wait_handle ported to win32.
-# iree_cc_library(
-# NAME
-# wait_handle
-# HDRS
-# "wait_handle.h"
-# SRCS
-# "wait_handle.cc"
-# DEPS
-# absl::base
-# absl::fixed_array
-# absl::span
-# absl::strings
-# absl::time
-# iree::base::logging
-# iree::base::ref_ptr
-# iree::base::status
-# iree::base::time
-# PUBLIC
-# )
-#
-# iree_cc_test(
-# NAME
-# wait_handle_test
-# SRCS
-# "wait_handle_test.cc"
-# DEPS
-# absl::time
-# iree::base::status
-# iree::base::wait_handle
-# iree::testing::gtest
-# iree::testing::gtest_main
-# )
+iree_cc_library(
+ NAME
+ wait_handle
+ HDRS
+ "wait_handle.h"
+ SRCS
+ "wait_handle.c"
+ "wait_handle_epoll.c"
+ "wait_handle_impl.h"
+ "wait_handle_kqueue.c"
+ "wait_handle_poll.c"
+ "wait_handle_posix.c"
+ "wait_handle_posix.h"
+ "wait_handle_win32.c"
+ DEPS
+ ::api
+ ::core_headers
+ ::tracing
+ PUBLIC
+)
+
+iree_cc_test(
+ NAME
+ wait_handle_test
+ SRCS
+ "wait_handle_test.cc"
+ DEPS
+ ::wait_handle
+ iree::testing::gtest
+ iree::testing::gtest_main
+)
diff --git a/iree/base/alignment.h b/iree/base/alignment.h
index fe9cfaf..024a334 100644
--- a/iree/base/alignment.h
+++ b/iree/base/alignment.h
@@ -18,12 +18,22 @@
#ifndef IREE_BASE_ALIGNMENT_H_
#define IREE_BASE_ALIGNMENT_H_
+#include <stddef.h>
+
#include "iree/base/target_platform.h"
#ifdef __cplusplus
extern "C" {
#endif
+// https://en.cppreference.com/w/c/types/max_align_t
+#if defined(IREE_PLATFORM_WINDOWS)
+// NOTE: 16 is a specified Microsoft API requirement for some functions.
+#define iree_max_align_t 16
+#else
+#define iree_max_align_t sizeof(long double)  // NOTE(review): a byte count, not a type; may be 12 (non-pow2) on i386 - confirm
+#endif // IREE_PLATFORM_*
+
// https://en.cppreference.com/w/c/language/_Alignas
// https://en.cppreference.com/w/c/language/_Alignof
#if defined(IREE_COMPILER_MSVC)
diff --git a/iree/base/api.c b/iree/base/api.c
index cc6eee6..47c5d68 100644
--- a/iree/base/api.c
+++ b/iree/base/api.c
@@ -15,6 +15,7 @@
#include "iree/base/api.h"
#include <assert.h>
+#include <errno.h>
#include <limits.h>
#include <stdarg.h>
#include <stdio.h>
@@ -215,6 +216,182 @@
}
//===----------------------------------------------------------------------===//
+// iree_status_t canonical errors
+//===----------------------------------------------------------------------===//
+
+// Maps a C `errno` value to the canonical IREE status code.
+//
+// 0 maps to IREE_STATUS_OK and any value without an explicit case below maps
+// to IREE_STATUS_UNKNOWN. Error numbers that are not defined by every
+// <errno.h> (platform-specific extensions and the obsolescent XSI STREAMS
+// values ENODATA/ENOSR/ENOSTR/ETIME) are wrapped in #ifdef so the switch
+// still compiles on platforms that omit them.
+IREE_API_EXPORT iree_status_code_t
+iree_status_code_from_errno(int error_number) {
+  switch (error_number) {
+    case 0:
+      return IREE_STATUS_OK;
+    case EINVAL:        // Invalid argument
+    case ENAMETOOLONG:  // Filename too long
+    case E2BIG:         // Argument list too long
+    case EDESTADDRREQ:  // Destination address required
+    case EDOM:          // Mathematics argument out of domain of function
+    case EFAULT:        // Bad address
+    case EILSEQ:        // Illegal byte sequence
+    case ENOPROTOOPT:   // Protocol not available
+#ifdef ENOSTR
+    case ENOSTR:  // Not a STREAM
+#endif
+    case ENOTSOCK:    // Not a socket
+    case ENOTTY:      // Inappropriate I/O control operation
+    case EPROTOTYPE:  // Protocol wrong type for socket
+    case ESPIPE:      // Invalid seek
+      return IREE_STATUS_INVALID_ARGUMENT;
+    case ETIMEDOUT:  // Connection timed out
+#ifdef ETIME
+    case ETIME:  // Timer expired
+#endif
+      return IREE_STATUS_DEADLINE_EXCEEDED;
+    case ENODEV:  // No such device
+    case ENOENT:  // No such file or directory
+#ifdef ENOMEDIUM
+    case ENOMEDIUM:  // No medium found
+#endif
+    case ENXIO:  // No such device or address
+    case ESRCH:  // No such process
+      return IREE_STATUS_NOT_FOUND;
+    case EEXIST:         // File exists
+    case EADDRNOTAVAIL:  // Address not available
+    case EALREADY:       // Connection already in progress
+#ifdef ENOTUNIQ
+    case ENOTUNIQ:  // Name not unique on network
+#endif
+      return IREE_STATUS_ALREADY_EXISTS;
+    case EPERM:   // Operation not permitted
+    case EACCES:  // Permission denied
+#ifdef ENOKEY
+    case ENOKEY:  // Required key not available
+#endif
+    case EROFS:  // Read only file system
+      return IREE_STATUS_PERMISSION_DENIED;
+    case ENOTEMPTY:   // Directory not empty
+    case EISDIR:      // Is a directory
+    case ENOTDIR:     // Not a directory
+    case EADDRINUSE:  // Address already in use
+    case EBADF:       // Invalid file descriptor
+#ifdef EBADFD
+    case EBADFD:  // File descriptor in bad state
+#endif
+    case EBUSY:    // Device or resource busy
+    case ECHILD:   // No child processes
+    case EISCONN:  // Socket is connected
+#ifdef EISNAM
+    case EISNAM:  // Is a named type file
+#endif
+#ifdef ENOTBLK
+    case ENOTBLK:  // Block device required
+#endif
+    case ENOTCONN:  // The socket is not connected
+    case EPIPE:     // Broken pipe
+#ifdef ESHUTDOWN
+    case ESHUTDOWN:  // Cannot send after transport endpoint shutdown
+#endif
+    case ETXTBSY:  // Text file busy
+#ifdef EUNATCH
+    case EUNATCH:  // Protocol driver not attached
+#endif
+      return IREE_STATUS_FAILED_PRECONDITION;
+    case ENOSPC:  // No space left on device
+#ifdef EDQUOT
+    case EDQUOT:  // Disk quota exceeded
+#endif
+    case EMFILE:   // Too many open files
+    case EMLINK:   // Too many links
+    case ENFILE:   // Too many open files in system
+    case ENOBUFS:  // No buffer space available
+#ifdef ENODATA
+    case ENODATA:  // No message is available on the STREAM read queue
+#endif
+    case ENOMEM:  // Not enough space
+#ifdef ENOSR
+    case ENOSR:  // No STREAM resources
+#endif
+#ifdef EUSERS
+    case EUSERS:  // Too many users
+#endif
+      return IREE_STATUS_RESOURCE_EXHAUSTED;
+#ifdef ECHRNG
+    case ECHRNG:  // Channel number out of range
+#endif
+    case EFBIG:      // File too large
+    case EOVERFLOW:  // Value too large to be stored in data type
+    case ERANGE:     // Result too large
+      return IREE_STATUS_OUT_OF_RANGE;
+#ifdef ENOPKG
+    case ENOPKG:  // Package not installed
+#endif
+    case ENOSYS:        // Function not implemented
+    case ENOTSUP:       // Operation not supported
+    case EAFNOSUPPORT:  // Address family not supported
+#ifdef EPFNOSUPPORT
+    case EPFNOSUPPORT:  // Protocol family not supported
+#endif
+    case EPROTONOSUPPORT:  // Protocol not supported
+#ifdef ESOCKTNOSUPPORT
+    case ESOCKTNOSUPPORT:  // Socket type not supported
+#endif
+    case EXDEV:  // Improper link
+      return IREE_STATUS_UNIMPLEMENTED;
+    case EAGAIN:  // Resource temporarily unavailable
+#ifdef ECOMM
+    case ECOMM:  // Communication error on send
+#endif
+    case ECONNREFUSED:  // Connection refused
+    case ECONNABORTED:  // Connection aborted
+    case ECONNRESET:    // Connection reset
+    case EINTR:         // Interrupted function call
+#ifdef EHOSTDOWN
+    case EHOSTDOWN:  // Host is down
+#endif
+    case EHOSTUNREACH:  // Host is unreachable
+    case ENETDOWN:      // Network is down
+    case ENETRESET:     // Connection aborted by network
+    case ENETUNREACH:   // Network unreachable
+    case ENOLCK:        // No locks available
+    case ENOLINK:       // Link has been severed
+#ifdef ENONET
+    case ENONET:  // Machine is not on the network
+#endif
+      return IREE_STATUS_UNAVAILABLE;
+    case EDEADLK:  // Resource deadlock avoided
+#ifdef ESTALE
+    case ESTALE:  // Stale file handle
+#endif
+      return IREE_STATUS_ABORTED;
+    case ECANCELED:  // Operation cancelled
+      return IREE_STATUS_CANCELLED;
+    default:
+      return IREE_STATUS_UNKNOWN;
+  }
+}
+
+#if defined(IREE_PLATFORM_WINDOWS)
+// Maps a Win32 error code (as produced by GetLastError) to the canonical IREE
+// status code. Codes without an explicit case map to IREE_STATUS_UNKNOWN.
+IREE_API_EXPORT iree_status_code_t
+iree_status_code_from_win32_error(uint32_t error) {
+  switch (error) {
+    case ERROR_SUCCESS:  // The operation completed successfully
+      return IREE_STATUS_OK;
+    case ERROR_INVALID_HANDLE:  // The handle is invalid
+      return IREE_STATUS_INVALID_ARGUMENT;
+    case ERROR_PATH_NOT_FOUND:  // Path not found
+    case ERROR_FILE_NOT_FOUND:  // File not found
+      return IREE_STATUS_NOT_FOUND;
+    case ERROR_ACCESS_DENIED:  // Access is denied
+      return IREE_STATUS_PERMISSION_DENIED;
+    case ERROR_HANDLE_EOF:           // Reached the end of the file
+    case ERROR_HANDLE_DISK_FULL:     // The disk is full
+    case ERROR_OUTOFMEMORY:          // Not enough storage available
+    case ERROR_TOO_MANY_OPEN_FILES:  // The system cannot open the file
+      return IREE_STATUS_RESOURCE_EXHAUSTED;
+    case ERROR_NOT_SUPPORTED:  // The request is not supported
+      return IREE_STATUS_UNIMPLEMENTED;
+    case ERROR_READ_FAULT:  // Cannot read from the specified device
+    case ERROR_NOT_READY:   // The device is not ready
+      return IREE_STATUS_UNAVAILABLE;
+    case ERROR_WRITE_FAULT:  // Cannot write to the specified device
+      return IREE_STATUS_DATA_LOSS;
+    default:
+      return IREE_STATUS_UNKNOWN;
+  }
+}
+#endif  // IREE_PLATFORM_WINDOWS
+
+//===----------------------------------------------------------------------===//
// iree_status_t
//===----------------------------------------------------------------------===//
@@ -772,11 +949,6 @@
expected_version, actual_version);
}
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_api_init(int* argc,
- char*** argv) {
- return iree_ok_status();
-}
-
//===----------------------------------------------------------------------===//
// iree_time_t and iree_duration_t
//===----------------------------------------------------------------------===//
@@ -816,6 +988,22 @@
return iree_time_now() + timeout_ns;
}
+// Converts the absolute |deadline_ns| into a relative timeout in nanoseconds.
+// The infinite sentinels are translated directly without querying the clock;
+// a deadline that has already elapsed clamps to IREE_DURATION_ZERO so callers
+// can treat it as a poll.
+IREE_API_EXPORT iree_duration_t
+iree_absolute_deadline_to_timeout_ns(iree_time_t deadline_ns) {
+  if (deadline_ns == IREE_TIME_INFINITE_PAST) return IREE_DURATION_ZERO;
+  if (deadline_ns == IREE_TIME_INFINITE_FUTURE) return IREE_DURATION_INFINITE;
+  // The result stays in nanoseconds: no unit conversion or rounding happens
+  // here, we only subtract the current time and clamp at zero.
+  const iree_time_t now_ns = iree_time_now();
+  return deadline_ns < now_ns ? IREE_DURATION_ZERO : deadline_ns - now_ns;
+}
+
//===----------------------------------------------------------------------===//
// iree_allocator_t
//===----------------------------------------------------------------------===//
diff --git a/iree/base/api.h b/iree/base/api.h
index 61df05c..3b02694 100644
--- a/iree/base/api.h
+++ b/iree/base/api.h
@@ -175,9 +175,9 @@
// `restrict` keyword, not supported by some older compilers.
// We define our own macro in case dependencies use `restrict` differently.
-#if defined _MSC_VER && _MSC_VER >= 1900
+#if defined(_MSC_VER) && _MSC_VER >= 1900
#define IREE_RESTRICT __restrict
-#elif defined _MSC_VER
+#elif defined(_MSC_VER)
#define IREE_RESTRICT
#else
#define IREE_RESTRICT restrict
@@ -627,6 +627,18 @@
IREE_CHECK_EQ(IREE_STATUS_OK, iree_status_consume_code(expr))
#define IREE_ASSERT_ARGUMENT(name) assert(name)
+// Maps an errno value to its canonical status code (0=OK, unlisted=UNKNOWN).
+// https://en.cppreference.com/w/cpp/error/errno_macros
+IREE_API_EXPORT iree_status_code_t
+iree_status_code_from_errno(int error_number);
+
+#if defined(_WIN32) || defined(_WIN64)
+// Maps a Win32 GetLastError code to its canonical status code (else UNKNOWN).
+// https://docs.microsoft.com/en-us/windows/win32/api/errhandlingapi/nf-errhandlingapi-getlasterror
+IREE_API_EXPORT iree_status_code_t
+iree_status_code_from_win32_error(uint32_t error);
+#endif // _WIN32 || _WIN64
+
// Returns a NUL-terminated string constant for the given status code, such as
// IREE_STATUS_UNAVAILABLE = "UNAVAILABLE". Do not rely on string-matching the
// result as the exact text may change.
@@ -725,24 +737,6 @@
iree_api_version_check(iree_api_version_t expected_version,
iree_api_version_t* out_actual_version);
-// Initializes IREE for use within a binary.
-//
-// Specifically, this parses any command line flags and performs module
-// initialization (such as for tracing and dynamic driver registration). If
-// your application is certain it does not need this functionality, this call
-// may be skipped.
-//
-// |argc| and |argv| should contain any command line flags to parse.
-// If there are no flags to parse, nullptr may be passed, but this should still
-// be called so other initialization happens.
-//
-// This should typically be called early in some sort of main() function once,
-// before calling most other API functions. Certain core API functions here
-// such as iree_api_version_check, iree_allocator_malloc, and
-// iree_allocator_free are safe to call before this.
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_api_init(int* argc,
- char*** argv);
-
//===----------------------------------------------------------------------===//
// iree_time_t and iree_duration_t
//===----------------------------------------------------------------------===//
@@ -776,6 +770,12 @@
IREE_API_EXPORT iree_time_t
iree_relative_timeout_to_deadline_ns(iree_duration_t timeout_ns);
+// Converts an absolute deadline time to a relative timeout in nanoseconds;
+// deadlines already in the past yield IREE_DURATION_ZERO. The sentinels
+// IREE_TIME_INFINITE_PAST/_FUTURE are mapped without querying the clock.
+IREE_API_EXPORT iree_duration_t
+iree_absolute_deadline_to_timeout_ns(iree_time_t deadline_ns);
+
//===----------------------------------------------------------------------===//
// iree_allocator_t (std::allocator-like interface)
//===----------------------------------------------------------------------===//
diff --git a/iree/base/atomic_slist.c b/iree/base/atomic_slist.c
new file mode 100644
index 0000000..c495c7e
--- /dev/null
+++ b/iree/base/atomic_slist.c
@@ -0,0 +1,114 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/base/atomic_slist.h"
+
+#include <assert.h>
+#include <string.h>
+
+// TODO(benvanik): add TSAN annotations when switched to atomics:
+// https://github.com/gcc-mirror/gcc/blob/master/libsanitizer/include/sanitizer/tsan_interface_atomic.h
+// https://reviews.llvm.org/D18500
+
+// Prepares |out_list| for use: the whole structure is zeroed (leaving the
+// head pointer NULL) and then its mutex is initialized.
+void iree_atomic_slist_initialize(iree_atomic_slist_t* out_list) {
+  memset(out_list, 0, sizeof(iree_atomic_slist_t));
+  iree_slim_mutex_initialize(&out_list->mutex);
+}
+
+// Tears down |list|: deinitializes its mutex and then zeroes the structure.
+// Entries still linked into the list are neither visited nor released here.
+void iree_atomic_slist_deinitialize(iree_atomic_slist_t* list) {
+  // TODO(benvanik): assert empty.
+  iree_slim_mutex_deinitialize(&list->mutex);
+  memset(list, 0, sizeof(iree_atomic_slist_t));
+}
+
+// Prepends the chain starting at |head| and ending at |tail| to |list| while
+// holding the list mutex. A NULL |head| is a no-op. Only |tail|->next is
+// written here, so the caller is expected to have linked |head| through |tail|
+// already and |tail| must be non-NULL whenever |head| is.
+void iree_atomic_slist_concat(iree_atomic_slist_t* list,
+                              iree_atomic_slist_entry_t* head,
+                              iree_atomic_slist_entry_t* tail) {
+  if (IREE_UNLIKELY(head == NULL)) return;
+  iree_slim_mutex_lock(&list->mutex);
+  // Splice the chain in front of whatever the list currently holds.
+  tail->next = list->head;
+  list->head = head;
+  iree_slim_mutex_unlock(&list->mutex);
+}
+
+// Pushes |entry| onto the front of |list| while holding the list mutex.
+void iree_atomic_slist_push(iree_atomic_slist_t* list,
+                            iree_atomic_slist_entry_t* entry) {
+  iree_slim_mutex_lock(&list->mutex);
+  // Same linking steps as iree_atomic_slist_push_unsafe, done under the lock.
+  entry->next = list->head;
+  list->head = entry;
+  iree_slim_mutex_unlock(&list->mutex);
+}
+
+// Pushes |entry| onto the front of |list| without taking the list mutex.
+void iree_atomic_slist_push_unsafe(iree_atomic_slist_t* list,
+                                   iree_atomic_slist_entry_t* entry) {
+  // NOTE: no lock is held here and no atomic operation will be used when this
+  // is actually made atomic.
+  iree_atomic_slist_entry_t* previous_head = list->head;
+  entry->next = previous_head;
+  list->head = entry;
+}
+
+// Removes and returns the entry at the front of |list|, or NULL when the list
+// is empty. The returned entry's |next| pointer is left as it was.
+iree_atomic_slist_entry_t* iree_atomic_slist_pop(iree_atomic_slist_t* list) {
+  iree_atomic_slist_entry_t* popped = NULL;
+  iree_slim_mutex_lock(&list->mutex);
+  popped = list->head;
+  if (popped) {
+    list->head = popped->next;
+  }
+  iree_slim_mutex_unlock(&list->mutex);
+  return popped;
+}
+
+// Steals every entry from |list|, leaving it empty, and hands the entries to
+// the caller as a NULL-terminated chain ordered according to |flush_order|.
+//
+// Returns false without writing |out_head|/|out_tail| when the list is empty
+// or when |flush_order| is not a recognized value; in the latter case the list
+// is left untouched. Otherwise returns true with |out_head| (required) set to
+// the first entry and, when |out_tail| is non-NULL, |out_tail| set to the last.
+bool iree_atomic_slist_flush(iree_atomic_slist_t* list,
+                             iree_atomic_slist_flush_order_t flush_order,
+                             iree_atomic_slist_entry_t** out_head,
+                             iree_atomic_slist_entry_t** out_tail) {
+  // Validate the order before detaching anything: rejecting it after the steal
+  // would drop every entry while reporting that nothing was flushed.
+  if (flush_order != IREE_ATOMIC_SLIST_FLUSH_ORDER_APPROXIMATE_LIFO &&
+      flush_order != IREE_ATOMIC_SLIST_FLUSH_ORDER_APPROXIMATE_FIFO) {
+    return false;
+  }
+
+  // Exchange list head with NULL to steal the entire list. The list will be in
+  // the native LIFO order of the slist.
+  iree_slim_mutex_lock(&list->mutex);
+  iree_atomic_slist_entry_t* head = list->head;
+  list->head = NULL;
+  iree_slim_mutex_unlock(&list->mutex);
+  if (!head) return false;
+
+  if (flush_order == IREE_ATOMIC_SLIST_FLUSH_ORDER_APPROXIMATE_LIFO) {
+    // List is already in native LIFO order. If the user wants a tail we have
+    // to scan for it, though, which we really only want to do when required
+    // as it's a linked list pointer walk.
+    *out_head = head;
+    if (out_tail) {
+      iree_atomic_slist_entry_t* last = head;
+      while (last->next) last = last->next;
+      *out_tail = last;
+    }
+    return true;
+  }
+
+  // FIFO: reverse the chain in a single pass. The stolen head ends up as the
+  // tail (with a NULL next), and since the whole list is walked anyway both
+  // ends can be returned to the caller cheaply.
+  iree_atomic_slist_entry_t* reversed = NULL;
+  iree_atomic_slist_entry_t* p = head;
+  while (p != NULL) {
+    iree_atomic_slist_entry_t* next = p->next;
+    p->next = reversed;
+    reversed = p;
+    p = next;
+  }
+  if (out_tail) *out_tail = head;
+  *out_head = reversed;
+  return true;
+}
diff --git a/iree/base/atomic_slist.h b/iree/base/atomic_slist.h
new file mode 100644
index 0000000..f6fb25b
--- /dev/null
+++ b/iree/base/atomic_slist.h
@@ -0,0 +1,264 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// NOTE: the best kind of synchronization is no synchronization; always try to
+// design your algorithm so that you don't need anything from this file :)
+// See https://travisdowns.github.io/blog/2020/07/06/concurrency-costs.html
+
+#ifndef IREE_BASE_ATOMIC_SLIST_H_
+#define IREE_BASE_ATOMIC_SLIST_H_
+
+#include <stddef.h>
+
+#include "iree/base/alignment.h"
+#include "iree/base/atomics.h"
+#include "iree/base/synchronization.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// The embedded pointer to the next entry in the slist. This points to the
+// internal iree_atomic_slist_entry_t, *not* the user-provided pointer.
+typedef void* iree_atomic_slist_intrusive_ptr_t;
+
+// DO NOT USE: implementation detail.
+typedef struct iree_atomic_slist_entry_s {
+ struct iree_atomic_slist_entry_s* next;  // Following entry; NULL at the tail.
+} iree_atomic_slist_entry_t;
+
+// Lightweight contention-avoiding singly linked list.
+// This models optimistically-ordered LIFO behavior (stack push/pop) using
+// atomic primitives.
+//
+// ***************************************************
+// ******** ONLY APPROXIMATE ORDER GUARANTEES ********
+// ***************************************************
+//
+// This makes it extremely efficient for when only eventual consistency across
+// producers and consumers is required. The most common example is free lists
+// where all that matters is that entries make it into the list and not that
+// they have any particular order between them. Work queues where all tasks
+// within the queue are able to execute in any order like with wavefront-style
+// scheduling can also benefit from this relaxed behavior.
+//
+// If a strict ordering is required this can be used as a primitive to construct
+// a flat-combining data structure where data structure change requests are
+// published to this list and a combiner is chosen to land the published data in
+// an appropriate order:
+// http://people.csail.mit.edu/shanir/publications/Flat%20Combining%20SPAA%2010.pdf
+//
+// There's often still a benefit in unordered scenarios of having LIFO behavior
+// as it promotes cache-friendly small linked lists when there is a small number
+// of producers and consumers (1:1 is the best case), though as the producer and
+// consumer count increases the LIFO behavior can pessimize performance as there
+// is more contention for the list head pointer. Prefer to shard across multiple
+// per-core/thread lists and use techniques like flat-combining for the
+// cross-core/thread aggregation/sequencing.
+//
+// This API is modeled roughly on the Windows SList type:
+// https://docs.microsoft.com/en-us/windows/win32/sync/interlocked-singly-linked-lists
+// which is roughly compatible with the Apple OSAtomic queue:
+// https://developer.apple.com/library/archive/documentation/System/Conceptual/ManPages_iPhoneOS/man3/OSAtomicEnqueue.3.html
+// https://opensource.apple.com/source/libplatform/libplatform-125/include/libkern/OSAtomicQueue.h.auto.html
+//
+// Usage:
+// https://docs.microsoft.com/en-us/windows/win32/sync/using-singly-linked-lists
+//
+// WARNING: this is an extremely sharp pufferfish-esque API. Don't use it. 🐡
+//
+// TODO(benvanik): verify behavior (and worthwhileness) of supporting platform
+// primitives. The benefit of something like OSAtomicEnqueue/Dequeue is that it
+// may have better tooling (TSAN), special intrinsic handling in the compiler,
+// etc. That said, the Windows Interlocked* variants don't seem to. Having a
+// single heavily tested implementation seems more worthwhile than several.
+typedef iree_alignas(iree_max_align_t) struct {
+ // TODO(benvanik): spend some time golfing this. Unblocking myself for now :)
+ iree_slim_mutex_t mutex;  // Held by pop/flush while they touch |head|.
+ iree_atomic_slist_entry_t* head;  // Top of the stack; NULL when empty.
+} iree_atomic_slist_t;
+
+// Initializes an slist handle to an empty list.
+// Lists must be flushed to empty and deinitialized when no longer needed with
+// iree_atomic_slist_deinitialize.
+//
+// NOTE: not thread-safe; existing |out_list| contents are discarded.
+void iree_atomic_slist_initialize(iree_atomic_slist_t* out_list);
+
+// Deinitializes an slist.
+// The list must be empty; callers are expected to flush the list from the same
+// thread making this call when it is guaranteed no other thread may be trying
+// to use the list.
+//
+// NOTE: not thread-safe; |list| must not be used by any other thread.
+void iree_atomic_slist_deinitialize(iree_atomic_slist_t* list);
+
+// Concatenates a span of entries into the list in the order they are provided.
+//
+// Example:
+// existing slist: C B A
+// provided span: 1 2 3
+// resulting slist: 1 2 3 C B A
+void iree_atomic_slist_concat(iree_atomic_slist_t* list,
+ iree_atomic_slist_entry_t* head,
+ iree_atomic_slist_entry_t* tail);
+
+// Pushes an entry into the list.
+//
+// existing slist: C B A
+// provided entry: 1
+// resulting slist: 1 C B A
+void iree_atomic_slist_push(iree_atomic_slist_t* list,
+ iree_atomic_slist_entry_t* entry);
+
+// Pushes an entry into the list without using an atomic update.
+// This is useful for when |list| is known to be inaccessible to any other
+// thread, such as when populating a stack-local list prior to sharing it.
+void iree_atomic_slist_push_unsafe(iree_atomic_slist_t* list,
+ iree_atomic_slist_entry_t* entry);
+
+// Pops the most recently pushed entry from the list and returns it.
+// Returns NULL if the list was empty at the time it was queried.
+//
+// existing slist: C B A
+// resulting slist: B A
+// returned entry: C
+iree_atomic_slist_entry_t* iree_atomic_slist_pop(iree_atomic_slist_t* list);
+
+// Defines the approximate order in which a span of flushed entries is returned.
+typedef uint32_t iree_atomic_slist_flush_order_t;
+enum {
+ // |out_head| and |out_tail| will be set to a span of the entries roughly in
+ // the order they were pushed to the list in LIFO (stack) order.
+ //
+ // Example:
+ // slist: C B A
+ // result: C B A (or when contended possibly C A B)
+ IREE_ATOMIC_SLIST_FLUSH_ORDER_APPROXIMATE_LIFO,
+ // |out_head| and |out_tail| will be set to the first and last entries
+ // pushed respectively, turning this LIFO slist into a FIFO queue.
+ //
+ // Example:
+ // slist: C B A
+ // result: A B C (or when contended possibly B A C)
+ IREE_ATOMIC_SLIST_FLUSH_ORDER_APPROXIMATE_FIFO,
+};
+
+// Removes all items from the list and returns them in **APPROXIMATELY** the
+// |flush_order| requested. As there are no order guarantees there may be slight
+// transpositions of entries that were pushed from multiple processors or even
+// interleaved entries within spans of entries pushed with
+// iree_atomic_slist_concat.
+//
+// If |out_tail| is not required it can be omitted and this may avoid the
+// need for the flush to walk the list and touch each entry.
+//
+// Returns true if any items were present and false if the output list is empty.
+// Note that because atomic data structures can race it's possible for there to
+// both be something in the list prior to this call and something in the list
+// after the call and yet the return can still be false.
+bool iree_atomic_slist_flush(iree_atomic_slist_t* list,
+ iree_atomic_slist_flush_order_t flush_order,
+ iree_atomic_slist_entry_t** out_head,
+ iree_atomic_slist_entry_t** out_tail);
+
+//==============================================================================
+// Typed wrapper generator for iree_atomic_slist_t
+//==============================================================================
+
+// Generates a typed |name|_slist_t and |name|_slist_* wrappers for |type|.
+// |next_offset| is the byte offset in |type| of its intrusive next pointer;
+// it is parenthesized on use so compound expressions expand correctly.
+// Usage:
+//   typedef struct {
+//     int some_fields;
+//     iree_atomic_slist_intrusive_ptr_t slist_next;
+//     int more_fields;
+//   } my_type_t;
+//   IREE_TYPED_ATOMIC_SLIST_WRAPPER(my_type, my_type_t,
+//                                   offsetof(my_type_t, slist_next));
+//   my_type_slist_t list;
+//   my_type_slist_initialize(&list);
+//   my_type_t* entry = allocate_my_type(123);
+//   my_type_slist_push(&list, entry);
+//   entry = my_type_slist_pop(&list);
+#define IREE_TYPED_ATOMIC_SLIST_WRAPPER(name, type, next_offset) \
+  static inline iree_atomic_slist_entry_t* name##_slist_entry_from_ptr( \
+      type* entry) { \
+    return entry ? (iree_atomic_slist_entry_t*)((uint8_t*)entry + \
+                                                (next_offset)) \
+                 : NULL; \
+  } \
+  static inline type* name##_slist_entry_to_ptr( \
+      iree_atomic_slist_entry_t* entry) { \
+    return entry ? (type*)(((uint8_t*)entry) - (next_offset)) : NULL; \
+  } \
+ \
+  static inline type* name##_slist_get_next(type* entry) { \
+    if (!entry) return NULL; \
+    iree_atomic_slist_entry_t* node = name##_slist_entry_from_ptr(entry); \
+    return name##_slist_entry_to_ptr(node->next); \
+  } \
+  static inline void name##_slist_set_next(type* entry, type* next) { \
+    name##_slist_entry_from_ptr(entry)->next = \
+        name##_slist_entry_from_ptr(next); \
+  } \
+ \
+  typedef iree_alignas(iree_max_align_t) struct { \
+    iree_atomic_slist_t impl; \
+  } name##_slist_t; \
+ \
+  static inline void name##_slist_initialize(name##_slist_t* out_list) { \
+    iree_atomic_slist_initialize(&out_list->impl); \
+  } \
+  static inline void name##_slist_deinitialize(name##_slist_t* list) { \
+    iree_atomic_slist_deinitialize(&list->impl); \
+  } \
+ \
+  static inline void name##_slist_push(name##_slist_t* list, type* entry) { \
+    iree_atomic_slist_push(&list->impl, name##_slist_entry_from_ptr(entry)); \
+  } \
+  static inline void name##_slist_push_unsafe(name##_slist_t* list, \
+                                              type* entry) { \
+    iree_atomic_slist_push_unsafe(&list->impl, \
+                                  name##_slist_entry_from_ptr(entry)); \
+  } \
+  static inline void name##_slist_concat(name##_slist_t* list, type* head, \
+                                         type* tail) { \
+    iree_atomic_slist_concat(&list->impl, name##_slist_entry_from_ptr(head), \
+                             name##_slist_entry_from_ptr(tail)); \
+  } \
+  static inline type* name##_slist_pop(name##_slist_t* list) { \
+    return name##_slist_entry_to_ptr(iree_atomic_slist_pop(&list->impl)); \
+  } \
+ \
+  static inline bool name##_slist_flush( \
+      name##_slist_t* list, iree_atomic_slist_flush_order_t flush_order, \
+      type** out_head, type** out_tail) { \
+    iree_atomic_slist_entry_t* head = NULL; \
+    iree_atomic_slist_entry_t* tail = NULL; \
+    if (!iree_atomic_slist_flush(&list->impl, flush_order, &head, \
+                                 out_tail ? &tail : NULL)) { \
+      return false; /* empty list */ \
+    } \
+    *out_head = name##_slist_entry_to_ptr(head); \
+    if (out_tail) *out_tail = name##_slist_entry_to_ptr(tail); \
+    return true; \
+  }
+
+#ifdef __cplusplus
+} // extern "C"
+#endif
+
+#endif // IREE_BASE_ATOMIC_SLIST_H_
diff --git a/iree/base/atomic_slist_test.cc b/iree/base/atomic_slist_test.cc
new file mode 100644
index 0000000..8b32f7a
--- /dev/null
+++ b/iree/base/atomic_slist_test.cc
@@ -0,0 +1,191 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/base/atomic_slist.h"
+
+#include "iree/testing/gtest.h"
+
+namespace {
+
+struct dummy_entry_t {
+ // NOTE: |value| comes first so that slist_next sits at a non-zero offset.
+ size_t value = 0;
+ iree_atomic_slist_intrusive_ptr_t slist_next = NULL;
+};
+IREE_TYPED_ATOMIC_SLIST_WRAPPER(dummy, dummy_entry_t,
+ offsetof(dummy_entry_t, slist_next));
+
+std::vector<dummy_entry_t> MakeDummySListItems(size_t base_index,
+                                               size_t count) {
+  // Values run consecutively: base_index, base_index + 1, ...
+  std::vector<dummy_entry_t> items(count);
+  size_t next_value = base_index;
+  for (dummy_entry_t& item : items) item.value = next_value++;
+  return items;
+}
+
+TEST(AtomicSList, Lifetime) {
+ iree_atomic_slist_t list; // NOTE: intentionally uninitialized.
+ iree_atomic_slist_initialize(&list);
+ iree_atomic_slist_deinitialize(&list);  // Nothing was pushed: list is empty.
+}
+
+TEST(AtomicSList, BasicUsage) {
+  dummy_slist_t list;
+  dummy_slist_initialize(&list);
+
+  // A freshly initialized list has nothing to pop.
+  EXPECT_EQ(NULL, dummy_slist_pop(&list));
+
+  // Push 0..5 one at a time; every push lands on top of the previous one.
+  // New contents: 5 4 3 2 1 0
+  auto item_storage = MakeDummySListItems(0, 6);
+  for (dummy_entry_t& item : item_storage) dummy_slist_push(&list, &item);
+
+  // Popping drains from the top, so the values come back highest-first.
+  // New contents: (empty)
+  for (size_t i = 0; i < item_storage.size(); ++i) {
+    const size_t expected_value = item_storage.size() - 1 - i;
+    dummy_entry_t* popped = dummy_slist_pop(&list);
+    ASSERT_TRUE(popped);
+    EXPECT_EQ(expected_value, popped->value);
+  }
+
+  // Everything that was pushed has been popped again.
+  EXPECT_EQ(NULL, dummy_slist_pop(&list));
+
+  // The list is empty again, as deinitialization requires.
+  dummy_slist_deinitialize(&list);
+}
+
+TEST(AtomicSList, Concat) {
+  dummy_slist_t list;
+  dummy_slist_initialize(&list);
+
+  // Seed the list with two individually pushed items.
+  // New contents: 1 0
+  auto initial_item_storage = MakeDummySListItems(0, 2);
+  for (dummy_entry_t& item : initial_item_storage) {
+    dummy_slist_push(&list, &item);
+  }
+
+  // Link a second batch into a chain by hand, the way a caller would when
+  // assembling a span before handing it to the list.
+  // Items: 2 3 4
+  auto span_item_storage = MakeDummySListItems(2, 3);
+  for (size_t i = 1; i < span_item_storage.size(); ++i) {
+    dummy_slist_set_next(&span_item_storage[i - 1], &span_item_storage[i]);
+  }
+
+  // Splice the whole chain onto the front of the list in one call.
+  // New contents: 2 3 4 1 0
+  dummy_entry_t* span_head = &span_item_storage.front();
+  dummy_slist_concat(&list, span_head, &span_item_storage.back());
+
+  // The span pops in the order it was linked (2->4), i.e. FIFO relative to
+  // how the chain was built.
+  // New contents: 1 0
+  for (size_t i = 0; i < span_item_storage.size(); ++i) {
+    dummy_entry_t* popped = dummy_slist_pop(&list);
+    ASSERT_TRUE(popped);
+    EXPECT_EQ(/*base_index=*/2 + i, popped->value);
+  }
+
+  // The seeded items are still underneath, in their usual LIFO order.
+  // New contents: (empty)
+  for (size_t i = initial_item_storage.size(); i > 0; --i) {
+    dummy_entry_t* popped = dummy_slist_pop(&list);
+    ASSERT_TRUE(popped);
+    EXPECT_EQ(i - 1, popped->value);
+  }
+
+  dummy_slist_deinitialize(&list);
+}
+
+TEST(AtomicSList, FlushLIFO) {
+  dummy_slist_t list;
+  dummy_slist_initialize(&list);
+
+  // An empty list flushes to nothing and reports false.
+  dummy_entry_t* head = NULL;
+  dummy_entry_t* tail = NULL;
+  const iree_atomic_slist_flush_order_t order =
+      IREE_ATOMIC_SLIST_FLUSH_ORDER_APPROXIMATE_LIFO;
+  EXPECT_FALSE(dummy_slist_flush(&list, order, &head, &tail));
+
+  // Push 0..3 individually; every push lands on top of the previous one, so
+  // the newest item ends up at the front.
+  // New contents: 3 2 1 0
+  auto item_storage = MakeDummySListItems(0, 4);
+  for (dummy_entry_t& item : item_storage) dummy_slist_push(&list, &item);
+
+  // Flushing steals every entry and leaves the list empty.
+  // New contents: (empty)
+  EXPECT_TRUE(dummy_slist_flush(&list, order, &head, &tail));
+  EXPECT_EQ(NULL, dummy_slist_pop(&list));
+
+  // LIFO keeps the native stack order, so the flushed span runs opposite to
+  // item storage: head is the last item pushed and tail the first.
+  EXPECT_EQ(&item_storage.back(), head);
+  EXPECT_EQ(&item_storage.front(), tail);
+  dummy_entry_t* cursor = head;
+  for (size_t i = item_storage.size(); i > 0; --i) {
+    ASSERT_TRUE(cursor);
+    EXPECT_EQ(i - 1, cursor->value);
+    cursor = dummy_slist_get_next(cursor);
+  }
+  // Walking past the tail must end the chain.
+  EXPECT_EQ(NULL, cursor);
+
+  dummy_slist_deinitialize(&list);
+}
+
+TEST(AtomicSList, FlushFIFO) {
+  dummy_slist_t list;
+  dummy_slist_initialize(&list);
+
+  // An empty list flushes to nothing and reports false.
+  dummy_entry_t* head = NULL;
+  dummy_entry_t* tail = NULL;
+  const iree_atomic_slist_flush_order_t order =
+      IREE_ATOMIC_SLIST_FLUSH_ORDER_APPROXIMATE_FIFO;
+  EXPECT_FALSE(dummy_slist_flush(&list, order, &head, &tail));
+
+  // Push 0..3 individually; the list itself is still a LIFO stack.
+  // New contents: 3 2 1 0
+  auto item_storage = MakeDummySListItems(0, 4);
+  for (dummy_entry_t& item : item_storage) dummy_slist_push(&list, &item);
+
+  // Flushing steals every entry and leaves the list empty.
+  // New contents: (empty)
+  EXPECT_TRUE(dummy_slist_flush(&list, order, &head, &tail));
+  EXPECT_EQ(NULL, dummy_slist_pop(&list));
+
+  // FIFO reverses the stack, so the flushed span matches item storage order:
+  // head is the first item pushed and tail the last.
+  EXPECT_EQ(&item_storage.front(), head);
+  EXPECT_EQ(&item_storage.back(), tail);
+  dummy_entry_t* cursor = head;
+  for (size_t i = 0; i < item_storage.size(); ++i) {
+    ASSERT_TRUE(cursor);
+    EXPECT_EQ(i, cursor->value);
+    cursor = dummy_slist_get_next(cursor);
+  }
+  // Walking past the tail must end the chain.
+  EXPECT_EQ(NULL, cursor);
+
+  dummy_slist_deinitialize(&list);
+}
+
+} // namespace
diff --git a/iree/base/atomics.h b/iree/base/atomics.h
index b97d803..38b85a3 100644
--- a/iree/base/atomics.h
+++ b/iree/base/atomics.h
@@ -266,6 +266,34 @@
#endif // iree_atomic_load_auto
//==============================================================================
+// Pointer-width atomics
+//==============================================================================
+
+#if defined(IREE_PTR_SIZE_32)
+typedef iree_atomic_int32_t iree_atomic_ptr_t;  // Pointer-width atomic integer.
+#define iree_atomic_load_ptr iree_atomic_load_int32
+#define iree_atomic_store_ptr iree_atomic_store_int32
+#define iree_atomic_fetch_add_ptr iree_atomic_fetch_add_int32
+#define iree_atomic_fetch_sub_ptr iree_atomic_fetch_sub_int32
+#define iree_atomic_exchange_ptr iree_atomic_exchange_int32
+#define iree_atomic_compare_exchange_strong_ptr \
+ iree_atomic_compare_exchange_strong_int32
+#define iree_atomic_compare_exchange_weak_ptr \
+ iree_atomic_compare_exchange_weak_int32
+#else  // IREE_PTR_SIZE_32 not defined: the 64-bit aliases are used instead.
+typedef iree_atomic_int64_t iree_atomic_ptr_t;  // Pointer-width atomic integer.
+#define iree_atomic_load_ptr iree_atomic_load_int64
+#define iree_atomic_store_ptr iree_atomic_store_int64
+#define iree_atomic_fetch_add_ptr iree_atomic_fetch_add_int64
+#define iree_atomic_fetch_sub_ptr iree_atomic_fetch_sub_int64
+#define iree_atomic_exchange_ptr iree_atomic_exchange_int64
+#define iree_atomic_compare_exchange_strong_ptr \
+ iree_atomic_compare_exchange_strong_int64
+#define iree_atomic_compare_exchange_weak_ptr \
+ iree_atomic_compare_exchange_weak_int64
+#endif // IREE_PTR_SIZE_32
+
+//==============================================================================
// Reference count atomics
//==============================================================================
// These are just aliases that allow use to have nicely readable ref counting
diff --git a/iree/base/flags.cc b/iree/base/flags.cc
new file mode 100644
index 0000000..4edd463
--- /dev/null
+++ b/iree/base/flags.cc
@@ -0,0 +1,54 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/base/flags.h"
+
+#include <stdlib.h>
+#include <string.h>
+
+// TODO(#3814): replace abseil with pretty much anything else.
+#include "absl/flags/parse.h"
+
+iree_status_t iree_flags_parse(int* argc, char*** argv) {
+  if (argc == nullptr || argv == nullptr || *argc == 0) {
+    // No flags; that's fine - in some environments flags aren't supported.
+    return iree_ok_status();
+  }
+
+  // Abseil consumes the flags and hands back the arguments it did not consume.
+  auto positional_args = absl::ParseCommandLine(*argc, *argv);
+  if (positional_args.size() < static_cast<size_t>(*argc)) {
+    // Compact |argv| in place so it only holds the surviving arguments and
+    // re-terminate it at the new |argc|, keeping argv[argc] == NULL.
+    *argc = static_cast<int>(positional_args.size());
+    for (int i = 0; i < *argc; ++i) (*argv)[i] = positional_args[i];
+    (*argv)[*argc] = nullptr;
+  }
+
+  return iree_ok_status();
+}
+
+void iree_flags_parse_checked(int* argc, char*** argv) {
+  iree_status_t status = iree_flags_parse(argc, argv);
+  // A cancelled parse (for example --help) is a graceful, successful exit.
+  if (iree_status_is_cancelled(status)) exit(EXIT_SUCCESS);
+  // Parsed cleanly: hand control back to the caller.
+  if (iree_status_is_ok(status)) return;
+
+  // Any other status is fatal for a command line tool. The status is dropped
+  // without being reported until logging is available here.
+  // TODO(#2843): replace C++ logging.
+  iree_status_ignore(status);
+  exit(EXIT_FAILURE);
+}
diff --git a/iree/base/flags.h b/iree/base/flags.h
new file mode 100644
index 0000000..9368f68
--- /dev/null
+++ b/iree/base/flags.h
@@ -0,0 +1,59 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_BASE_FLAGS_H_
+#define IREE_BASE_FLAGS_H_
+
+#include "iree/base/api.h"
+
+//===----------------------------------------------------------------------===//
+// Flag parsing
+//===----------------------------------------------------------------------===//
+
+// Parses flags from the given command line arguments.
+// All flag-style arguments ('--foo', '-f', etc) will be consumed and argc/argv
+// will be updated to contain only the program name (index 0) and any remaining
+// positional arguments.
+//
+// Returns success if all flags were parsed and execution should continue.
+// May return IREE_STATUS_CANCELLED if execution should be cancelled gracefully
+// such as when --help is used.
+//
+// Usage:
+// extern "C" int main(int argc, char** argv) {
+// iree_status_t status = iree_flags_parse(&argc, &argv);
+// if (iree_status_is_cancelled(status)) return 0;
+// if (!iree_status_is_ok(status)) {
+// // TODO(#2843): replace C++ logging.
+// LOG(ERROR) << status;
+// iree_status_ignore(status);
+// return 1;
+// }
+// consume_positional_args(argc, argv);
+// return 0;
+// }
+//
+// Example:
+// argc = 4, argv = ['program', 'abc', '--flag=2', '-p']
+// Results:
+// argc = 2, argv = ['program', 'abc']
+iree_status_t iree_flags_parse(int* argc, char*** argv);
+
+// Parses flags as with iree_flags_parse but will use exit() or abort().
+// WARNING: this is almost always what you want in a command line tool and
+// *never* what you want when embedded in a host process. You don't want a flag
+// typo to shut down your entire server/sandbox/Android app/etc.
+void iree_flags_parse_checked(int* argc, char*** argv);
+
+#endif // IREE_BASE_FLAGS_H_
diff --git a/iree/base/flatbuffer_util.cc b/iree/base/flatbuffer_util.cc
deleted file mode 100644
index 117aea2..0000000
--- a/iree/base/flatbuffer_util.cc
+++ /dev/null
@@ -1,143 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/base/flatbuffer_util.h"
-
-#include <cerrno>
-#include <cstring>
-
-#include "absl/memory/memory.h"
-#include "iree/base/file_mapping.h"
-#include "iree/base/memory.h"
-#include "iree/base/status.h"
-#include "iree/base/tracing.h"
-
-namespace iree {
-
-FlatBufferFileBase::~FlatBufferFileBase() {
- if (deleter_) {
- deleter_();
- deleter_ = []() {};
- }
-}
-
-Status FlatBufferFileBase::Create(const void* root_ptr,
- std::function<void()> deleter) {
- IREE_TRACE_SCOPE0("FlatBufferFileBase::Create");
-
- root_ptr_ = root_ptr;
- deleter_ = std::move(deleter);
-
- return OkStatus();
-}
-
-Status FlatBufferFileBase::CreateWithBackingBuffer(
- const void* root_ptr, ::flatbuffers::DetachedBuffer backing_buffer) {
- IREE_TRACE_SCOPE0("FlatBufferFileBase::Create");
-
- root_ptr_ = root_ptr;
-
- // Pass along the buffer provided so we keep it alive until the
- // FlatBufferFileBase is destructed.
- auto backing_buffer_baton = IreeMoveToLambda(backing_buffer);
- deleter_ = [backing_buffer_baton]() { (void)backing_buffer_baton.value; };
-
- return OkStatus();
-}
-
-Status FlatBufferFileBase::Wrap(const void* root_ptr) {
- IREE_TRACE_SCOPE0("FlatBufferFileBase::Wrap");
- return Create(root_ptr, []() {});
-}
-
-Status FlatBufferFileBase::FromBuffer(Identifier identifier,
- absl::Span<const uint8_t> buffer_data,
- std::function<void()> deleter,
- size_t root_type_size,
- VerifierFn verifier_fn) {
- IREE_TRACE_SCOPE();
-
- // Sanity check buffer for the minimum size as FlatBuffers doesn't.
- if (buffer_data.size() < 16) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Provided serialized flatbuffer buffer is too small to be legit "
- "at size="
- << buffer_data.size();
- }
-
- // Ensure the buffer has the BIPE magic bytes.
- if (identifier.has_value() && !::flatbuffers::BufferHasIdentifier(
- buffer_data.data(), identifier.value())) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Provided serialized buffer does not contain the expected type; "
- "magic bytes mismatch (expected "
- << identifier.value() << ")";
- }
-
- // Verify the FlatBuffer contains valid offsets and won't try to read out of
- // bounds of the buffer. We inline a bit of VerifyBufferFromStart so this code
- // can stay generic.
- {
- IREE_TRACE_SCOPE0("FlatBufferFileBase::FromBufferVerification");
- ::flatbuffers::Verifier verifier{buffer_data.data(), buffer_data.size()};
- if (!verifier_fn(identifier.value_or(nullptr), &verifier)) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "FlatBuffer failed to verify as expected type; possibly "
- "corrupt input";
- }
- }
-
- // Resolve the root pointer in the buffer.
- // This is GetMutableRoot such that we don't need to know T.
- root_ptr_ = buffer_data.data() +
- ::flatbuffers::EndianScalar(
- *reinterpret_cast<const ::flatbuffers::uoffset_t*>(
- buffer_data.data()));
- if (!root_ptr_) {
- return FailedPreconditionErrorBuilder(IREE_LOC)
- << "Unable to resolve root table";
- }
- deleter_ = std::move(deleter);
-
- return OkStatus();
-}
-
-Status FlatBufferFileBase::WrapBuffer(Identifier identifier,
- absl::Span<const uint8_t> buffer_data,
- size_t root_type_size,
- VerifierFn verifier_fn) {
- IREE_TRACE_SCOPE0("FlatBufferFileBase::WrapBuffer");
- return FromBuffer(
- identifier, buffer_data, []() {}, root_type_size, verifier_fn);
-}
-
-Status FlatBufferFileBase::LoadFile(Identifier identifier, std::string path,
- size_t root_type_size,
- VerifierFn verifier_fn) {
- IREE_TRACE_SCOPE0("FlatBufferFileBase::LoadFile");
-
- IREE_ASSIGN_OR_RETURN(auto file_mapping, FileMapping::OpenRead(path));
- auto buffer_data = file_mapping->data();
-
- auto handle_baton = IreeMoveToLambda(file_mapping);
- return FromBuffer(
- identifier, buffer_data,
- [handle_baton]() {
- // Keeping the mmap handle alive.
- (void)handle_baton.value;
- },
- root_type_size, verifier_fn);
-}
-
-} // namespace iree
diff --git a/iree/base/flatbuffer_util.h b/iree/base/flatbuffer_util.h
deleted file mode 100644
index 0e8c81f..0000000
--- a/iree/base/flatbuffer_util.h
+++ /dev/null
@@ -1,328 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_BASE_FLATBUFFER_UTIL_H_
-#define IREE_BASE_FLATBUFFER_UTIL_H_
-
-#include <cstddef>
-#include <cstdint>
-#include <functional>
-#include <memory>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "absl/strings/string_view.h"
-#include "absl/types/optional.h"
-#include "absl/types/span.h"
-#include "flatbuffers/flatbuffers.h"
-#include "iree/base/memory.h"
-#include "iree/base/ref_ptr.h"
-#include "iree/base/status.h"
-
-namespace iree {
-
-// A helper wrapper that moves the wrapped object on copy.
-// This is particularly handy for capturing unique_ptrs in lambdas.
-// Usage example:
-//
-// std::unique_ptr<Foo> foo_ptr(new Foo());
-// move_on_copy<std::unique_ptr<Foo>> foo(std::move(foo_ptr));
-// auto some_lambda = [bar]() { ... }
-//
-template <typename T>
-struct move_on_copy {
- explicit move_on_copy(T&& t) : value(std::move(t)) {}
-
- move_on_copy(move_on_copy const& other) : value(std::move(other.value)) {}
-
- move_on_copy(move_on_copy&& other) : value(std::move(other.value)) {}
-
- move_on_copy& operator=(move_on_copy const& other) {
- value = std::move(other.value);
- return *this;
- }
-
- move_on_copy& operator=(move_on_copy&& other) {
- value = std::move(other.value);
- return *this;
- }
-
- mutable T value;
-};
-
-// Utility to aid in moving ref_ptr's into closures.
-//
-// Usage:
-// auto baton = MoveToLambda(my_ref);
-// DoSomething([baton] () { baton.value; });
-#define IreeMoveToLambda(p) ::iree::move_on_copy<decltype(p)>(std::move(p))
-
-// Wraps a FlatBuffer String in an absl::string_view.
-// Returns empty-string ("") for nullptr values.
-inline absl::string_view WrapString(const ::flatbuffers::String* value) {
- return value ? absl::string_view{value->data(), value->size()} : "";
-}
-
-// Base type for FlatBufferFile<T>. See below.
-class FlatBufferFileBase : public RefObject<FlatBufferFileBase> {
- public:
- using Identifier = absl::optional<const char*>;
-
- virtual ~FlatBufferFileBase();
-
- protected:
- template <typename T>
- friend class FlatBufferFile;
-
- using VerifierFn = bool (*)(const char* identifier,
- ::flatbuffers::Verifier* verifier);
-
- FlatBufferFileBase() = default;
-
- const void* root_ptr() const { return root_ptr_; }
-
- // Redirections of template static methods on FlatBufferFile so we can put the
- // implementations in a shared compilation unit.
- // See FlatBufferFile<T> for doc comments.
- Status Create(const void* root_ptr, std::function<void()> deleter);
- Status CreateWithBackingBuffer(const void* root_ptr,
- ::flatbuffers::DetachedBuffer backing_buffer);
- Status Wrap(const void* root);
- Status FromBuffer(Identifier identifier,
- absl::Span<const uint8_t> buffer_data,
- std::function<void()> deleter, size_t root_type_size,
- VerifierFn verifier_fn);
- // Initializes from an STL byte based container (string and vector of
- // char/byte should be compatible).
- template <typename Container>
- Status FromContainer(Identifier identifier, Container container,
- size_t root_type_size, VerifierFn verifier_fn);
- Status WrapBuffer(Identifier identifier,
- absl::Span<const uint8_t> buffer_data,
- size_t root_type_size, VerifierFn verifier_fn);
- Status LoadFile(Identifier identifier, std::string path,
- size_t root_type_size, VerifierFn verifier_fn);
-
- private:
- const void* root_ptr_ = nullptr;
- std::function<void()> deleter_;
-};
-
-// Immutable root FlatBuffer type wrapper with support for loading and backing
-// buffer management.
-//
-// Immutable and thread-safe.
-template <typename T>
-class FlatBufferFile final : public FlatBufferFileBase {
- public:
- // Creates a FlatBufferFile from an in-memory root pointer.
- // The provided |deleter| will be called when the FlatBufferFile is destructed
- // and can be used to deallocate/clean up resources.
- //
- // This assumes that the root pointer has already been verified as valid.
- // If verification is required instead use FromBuffer on the original buffer.
- static StatusOr<ref_ptr<FlatBufferFile<T>>> Create(
- const T* root, std::function<void()> deleter);
-
- // Creates a FlatBufferFile from an in-memory root pointer and the detached
- // backing buffer storing it.
- //
- // Example:
- // FlatBufferBuilder fbb;
- // MyTypeBuilder mtb(fbb);
- // fbb.Finish(mtb.Finish());
- // auto my_type = FlatBufferFile<MyType>::CreateWithBackingBuffer(
- // fbb.Release());
- // my_type->foo();
- static StatusOr<ref_ptr<FlatBufferFile<T>>> CreateWithBackingBuffer(
- ::flatbuffers::DetachedBuffer backing_buffer);
-
- // Wraps a caller-owned in-memory root pointer.
- // The provided |root| must remain valid for the lifetime of the returned
- // FlatBufferFile.
- //
- // This assumes that the root pointer has already been verified as valid.
- // If verification is required instead use FromBuffer on the original buffer.
- static StatusOr<ref_ptr<FlatBufferFile<T>>> Wrap(const T* root);
-
- // Creates a FlatBufferFile wrapping an external data buffer with a deleter
- // function that will be called when the FlatBufferFile is destructed.
- static StatusOr<ref_ptr<FlatBufferFile<T>>> FromBuffer(
- Identifier identifier, absl::Span<const uint8_t> buffer_data,
- std::function<void()> deleter);
-
- // Creates a FlatBufferFile from a serialized data buffer.
- // The FlatBufferFile takes ownership of the vector.
- static StatusOr<ref_ptr<FlatBufferFile<T>>> FromBuffer(
- Identifier identifier, std::vector<uint8_t> buffer_data);
-
- // Loads a FlatBufferFile from an external buffer owned by the caller.
- // The buffer must remain valid until the Pipeline is destroyed.
- static StatusOr<ref_ptr<FlatBufferFile<T>>> WrapBuffer(
- Identifier identifier, absl::Span<const uint8_t> buffer_data);
-
- // Loads the FlatBufferFile from a serialized byte-based STL container.
- template <typename Container>
- static StatusOr<ref_ptr<FlatBufferFile<T>>> FromContainer(
- Identifier identifier, Container buffer_data);
-
- // Loads a FlatBufferFile from a serialized string.
- // The FlatBufferFile takes ownership of the string.
- static StatusOr<ref_ptr<FlatBufferFile<T>>> FromString(
- Identifier identifier, std::string buffer_data) {
- return FromContainer(identifier, std::move(buffer_data));
- }
-
- // Loads a FlatBufferFile from a serialized byte vector.
- // The FlatBufferFile takes ownership of the vector.
- static StatusOr<ref_ptr<FlatBufferFile<T>>> FromVector(
- Identifier identifier, std::vector<uint8_t> buffer_data) {
- return FromContainer(identifier, std::move(buffer_data));
- }
-
- // Loads a FlatBufferFile from a serialized file on the file system.
- // This will attempt to mmap the file and is the preferred way of loading as
- // only those pages that contain requested tables will be read.
- static StatusOr<ref_ptr<FlatBufferFile<T>>> LoadFile(Identifier identifier,
- std::string path);
-
- ~FlatBufferFile() override = default;
-
- // Typed root pointer of the file.
- const T* root() const { return reinterpret_cast<const T*>(root_ptr()); }
-
- private:
- FlatBufferFile() = default;
-
- // Conforms to VerifierFn.
- static bool VerifierFnT(const char* identifier,
- ::flatbuffers::Verifier* verifier) {
- return verifier->VerifyBuffer<T>(identifier);
- }
-};
-
-template <typename Container>
-Status FlatBufferFileBase::FromContainer(Identifier identifier,
- Container container,
- size_t root_type_size,
- VerifierFn verifier_fn) {
- static_assert(sizeof(*container.data()) == 1,
- "Expected container of byte sized elements");
- auto buffer_data = absl::MakeConstSpan(
- // Double static_cast through void is safer than reinterpret_cast.
- static_cast<const uint8_t*>(static_cast<const void*>(container.data())),
- container.size());
- // Use a baton to keep the container alive until the FlatBufferFileBase is
- // destroyed.
- auto buffer_data_baton = IreeMoveToLambda(container);
- return FromBuffer(
- identifier, buffer_data,
- [buffer_data_baton]() {
- // Keeping the container alive.
- (void)buffer_data_baton.value;
- },
- root_type_size, verifier_fn);
-}
-
-// static
-template <typename T>
-StatusOr<ref_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::Create(
- const T* root, std::function<void()> deleter) {
- ref_ptr<FlatBufferFile<T>> flat_buffer_file{new FlatBufferFile<T>};
- auto* base_file = static_cast<FlatBufferFileBase*>(flat_buffer_file.get());
- IREE_RETURN_IF_ERROR(base_file->Create(root, std::move(deleter)));
- return std::move(flat_buffer_file);
-}
-
-// static
-template <typename T>
-StatusOr<ref_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::CreateWithBackingBuffer(
- ::flatbuffers::DetachedBuffer backing_buffer) {
- ref_ptr<FlatBufferFile<T>> flat_buffer_file{new FlatBufferFile<T>};
- auto* base_file = static_cast<FlatBufferFileBase*>(flat_buffer_file.get());
- auto* root_ptr = ::flatbuffers::GetRoot<T>(backing_buffer.data());
- IREE_RETURN_IF_ERROR(
- base_file->CreateWithBackingBuffer(root_ptr, std::move(backing_buffer)));
- return std::move(flat_buffer_file);
-}
-
-// static
-template <typename T>
-StatusOr<ref_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::Wrap(const T* root) {
- ref_ptr<FlatBufferFile<T>> flat_buffer_file{new FlatBufferFile<T>};
- auto* base_file = static_cast<FlatBufferFileBase*>(flat_buffer_file.get());
- IREE_RETURN_IF_ERROR(base_file->Wrap(root));
- return std::move(flat_buffer_file);
-}
-
-// static
-template <typename T>
-StatusOr<ref_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::FromBuffer(
- Identifier identifier, absl::Span<const uint8_t> buffer_data,
- std::function<void()> deleter) {
- ref_ptr<FlatBufferFile<T>> flat_buffer_file{new FlatBufferFile<T>};
- auto* base_file = static_cast<FlatBufferFileBase*>(flat_buffer_file.get());
- IREE_RETURN_IF_ERROR(base_file->FromBuffer(
- identifier, buffer_data, std::move(deleter), sizeof(T), VerifierFnT));
- return std::move(flat_buffer_file);
-}
-
-// static
-template <typename T>
-StatusOr<ref_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::FromBuffer(
- Identifier identifier, std::vector<uint8_t> buffer_data) {
- auto* buffer_data_ptr = new decltype(buffer_data);
- (*buffer_data_ptr) = std::move(buffer_data);
- return FromBuffer(identifier, absl::MakeConstSpan(*buffer_data_ptr),
- [buffer_data_ptr]() { delete buffer_data_ptr; });
-}
-
-// static
-template <typename T>
-StatusOr<ref_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::WrapBuffer(
- Identifier identifier, absl::Span<const uint8_t> buffer_data) {
- ref_ptr<FlatBufferFile<T>> flat_buffer_file{new FlatBufferFile<T>};
- auto* base_file = static_cast<FlatBufferFileBase*>(flat_buffer_file.get());
- IREE_RETURN_IF_ERROR(
- base_file->WrapBuffer(identifier, buffer_data, sizeof(T), VerifierFnT));
- return std::move(flat_buffer_file);
-}
-
-// static
-template <typename T>
-template <typename Container>
-StatusOr<ref_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::FromContainer(
- Identifier identifier, Container buffer_data) {
- ref_ptr<FlatBufferFile<T>> flat_buffer_file{new FlatBufferFile<T>};
- auto* base_file = static_cast<FlatBufferFileBase*>(flat_buffer_file.get());
- IREE_RETURN_IF_ERROR(base_file->FromContainer(
- identifier, std::move(buffer_data), sizeof(T), VerifierFnT));
- return std::move(flat_buffer_file);
-}
-
-// static
-template <typename T>
-StatusOr<ref_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::LoadFile(
- Identifier identifier, std::string path) {
- ref_ptr<FlatBufferFile<T>> flat_buffer_file{new FlatBufferFile<T>};
- auto* base_file = static_cast<FlatBufferFileBase*>(flat_buffer_file.get());
- IREE_RETURN_IF_ERROR(
- base_file->LoadFile(identifier, std::move(path), sizeof(T), VerifierFnT));
- return std::move(flat_buffer_file);
-}
-
-} // namespace iree
-
-#endif // IREE_BASE_FLATBUFFER_UTIL_H_
diff --git a/iree/base/flatcc.fbs b/iree/base/flatcc.fbs
new file mode 100644
index 0000000..590afb4
--- /dev/null
+++ b/iree/base/flatcc.fbs
@@ -0,0 +1,30 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+namespace iree;
+
+// HACK: flatcc public API headers are incomplete and some things only exist
+// when pulled in via generated headers. So here we give ourselves something to
+// include that's always available and cheap.
+//
+// Instead of directly including this file use iree/base/flatcc.h.
+//
+// Normally including any generated file will include the appropriate headers in
+// the required order (as they are non-hermetic), but that requires that we have
+// a generated file. Though most of the API is exposed through the main includes
+// there are various types that only get generated and included by way of the
+// common headers that are not easily included.
+struct __IncludeWorkaround {
+ reserved:int;
+}
diff --git a/iree/base/flatcc.h b/iree/base/flatcc.h
new file mode 100644
index 0000000..0940517
--- /dev/null
+++ b/iree/base/flatcc.h
@@ -0,0 +1,50 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_BASE_FLATCC_H_
+#define IREE_BASE_FLATCC_H_
+
+//===----------------------------------------------------------------------===//
+// flatcc include order fixes
+//===----------------------------------------------------------------------===//
+//
+// This header merely wraps the flatcc headers that are generally useful to
+// include in various places that may not know the specific messages they are
+// working with.
+//
+// If using flatcc prefer to include this file over any hard-to-handle flatcc
+// file such as flatbuffers_common_reader.h or flatbuffers_common_builder.h.
+//
+// NOTE: include order matters here, so keep clang-format from reordering:
+// clang-format off
+
+#include "flatcc/reflection/flatbuffers_common_reader.h"
+#include "iree/base/flatcc_reader.h"
+
+#include "flatcc/flatcc_verifier.h"
+#include "iree/base/flatcc_verifier.h"
+
+#include "flatcc/flatcc_builder.h"
+#include "flatcc/reflection/flatbuffers_common_builder.h"
+#include "iree/base/flatcc_builder.h"
+
+#include "flatcc/flatcc_json_parser.h"
+#include "iree/base/flatcc_json_parser.h"
+
+#include "flatcc/flatcc_json_printer.h"
+#include "iree/base/flatcc_json_printer.h"
+
+// clang-format on
+
+#endif // IREE_BASE_FLATCC_H_
diff --git a/iree/base/init.cc b/iree/base/init.cc
deleted file mode 100644
index 33729c0..0000000
--- a/iree/base/init.cc
+++ /dev/null
@@ -1,42 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/base/init.h"
-
-#include <string.h>
-
-#include <set>
-
-#include "absl/flags/parse.h"
-#include "iree/base/initializer.h"
-
-namespace iree {
-
-void InitializeEnvironment(int* argc, char*** argv) {
- if (argc != nullptr && argv != nullptr && *argc != 0) {
- auto positional_args = absl::ParseCommandLine(*argc, *argv);
- if (positional_args.size() < *argc) {
- // Edit the passed argument refs to only include positional args.
- *argc = positional_args.size();
- for (int i = 0; i < *argc; ++i) {
- (*argv)[i] = positional_args[i];
- }
- (*argv)[*argc + 1] = nullptr;
- }
- }
-
- IREE_RUN_MODULE_INITIALIZERS();
-}
-
-} // namespace iree
diff --git a/iree/base/init.h b/iree/base/init.h
deleted file mode 100644
index 9c93a93..0000000
--- a/iree/base/init.h
+++ /dev/null
@@ -1,35 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_BASE_INIT_H_
-#define IREE_BASE_INIT_H_
-
-// Initialization happens automatically during InitializeEnvironment(), which
-// should be called early in main(), before other code runs.
-
-namespace iree {
-
-// Initializes the system environment in a binary.
-//
-// This first parses command line flags, then resolves module initializers
-// by calling IREE_RUN_MODULE_INITIALIZERS().
-//
-// 'argc' and 'argv' are the command line flags to parse.
-//
-// This should typically be called early in main(), before other code runs.
-void InitializeEnvironment(int* argc, char*** argv);
-
-} // namespace iree
-
-#endif // IREE_BASE_INIT_H_
diff --git a/iree/base/initializer.cc b/iree/base/initializer.cc
deleted file mode 100644
index 5e1e21e..0000000
--- a/iree/base/initializer.cc
+++ /dev/null
@@ -1,100 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/base/initializer.h"
-
-#include <string.h>
-
-#include <mutex> // NOLINT
-#include <set>
-
-namespace iree {
-
-static Initializer::NameMap* static_name_map = nullptr;
-
-struct Initializer::InitializerData {
- Initializer* initializer_obj;
- std::set<std::string> dependency_names;
-
- InitializerData() : initializer_obj(nullptr) {}
- explicit InitializerData(Initializer* i) : initializer_obj(i) {}
-};
-
-Initializer::DependencyRegisterer::DependencyRegisterer(
- const char* name, Initializer* initializer, const Dependency& dependency) {
- NameMap* name_map = InitializerNameMap();
-
- // Insert 'dependency' into the 'dependency_names' set for 'initializer'.
- InitializerData* initializer_data = &(*name_map)[name];
- initializer_data->dependency_names.insert(dependency.name);
-
- // Ensure that 'dependency' exists in the map.
- InitializerData* dependency_data = &(*name_map)[dependency.name];
- dependency_data->initializer_obj = dependency.initializer;
-}
-
-Initializer::Initializer(const char* name, InitializerFunc function)
- : name_(name), function_(function), done_(false) {
- // Register this Initializer instance (wrapped by an InitializerData) within
- // the static name map.
- NameMap* name_map = InitializerNameMap();
- InitializerData* initializer_data = &(*name_map)[name];
- initializer_data->initializer_obj = this;
-}
-
-void Initializer::RunInitializers() {
- static std::once_flag init_once;
- std::call_once(init_once, &Initializer::InitializeOnceCallback);
-}
-
-void Initializer::InitializeOnceCallback() {
- // Run each registered Initializer, in lexicographic order of their names.
- // Initializer dependencies will be run first as needed.
- NameMap* name_map = InitializerNameMap();
- for (auto& p : *name_map) {
- RunInitializer(&p.second);
- }
-}
-
-void Initializer::Require() {
- NameMap* name_map = InitializerNameMap();
- InitializerData* initializer_data = &(name_map->find(name_)->second);
- RunInitializer(initializer_data);
-}
-
-Initializer::NameMap* Initializer::InitializerNameMap() {
- if (static_name_map == nullptr) {
- static_name_map = new Initializer::NameMap;
- }
- return static_name_map;
-}
-
-void Initializer::RunInitializer(InitializerData* initializer_data) {
- if (initializer_data->initializer_obj->done_) {
- return;
- }
-
- // Run Initializer dependencies first.
- NameMap* name_map = InitializerNameMap();
- for (const auto& dependency_name : initializer_data->dependency_names) {
- auto dep_init = name_map->find(dependency_name);
- RunInitializer(&dep_init->second);
- }
-
- // Finally run the Initializer itself.
- initializer_data->initializer_obj->function_();
- initializer_data->initializer_obj->done_ = true;
-}
-
-} // namespace iree
diff --git a/iree/base/initializer.h b/iree/base/initializer.h
deleted file mode 100644
index b31bc73..0000000
--- a/iree/base/initializer.h
+++ /dev/null
@@ -1,119 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_BASE_INITIALIZER_H_
-#define IREE_BASE_INITIALIZER_H_
-
-#include <map>
-#include <string>
-
-#include "iree/base/target_platform.h"
-
-namespace iree {
-
-// Initializer macros are defined in this files:
-// IREE_DECLARE_MODULE_INITIALIZER(name)
-// IREE_REGISTER_MODULE_INITIALIZER(name, body)
-// IREE_REGISTER_MODULE_INITIALIZER_SEQUENCE(name1, name2)
-// IREE_REQUIRE_MODULE_INITIALIZED(name)
-// IREE_RUN_MODULE_INITIALIZERS()
-// IREE_REQUIRE_MODULE_LINKED(name)
-//
-// These macros allow for arranging pieces of initialization code to be
-// executed at a well-defined time and in a well-defined order.
-
-// A static instance of this class is declared for each piece of initialization
-// code using the initializer macros.
-class Initializer {
- public:
- typedef void (*InitializerFunc)();
-
- Initializer(const char* name, InitializerFunc function);
-
- // Runs all registered initializers that have not yet run.
- // The initializers are invoked in lexicographically increasing order by name,
- // except as necessary to satisfy dependencies.
- //
- // This is normally called by InitializeEnvironment(), so application code
- // typically should not call it directly.
- static void RunInitializers();
-
- // Runs this initializer if it has not yet run, including any dependencies.
- void Require();
-
- struct Dependency {
- Dependency(const char* n, Initializer* i) : name(n), initializer(i) {}
- const char* const name;
- Initializer* const initializer;
- };
-
- // A static instance of this class is declared for each piece of
- // initializer ordering definition.
- struct DependencyRegisterer {
- DependencyRegisterer(const char* name, Initializer* initializer,
- const Dependency& dependency);
- };
-
- struct InitializerData;
- typedef std::map<std::string, InitializerData> NameMap;
-
- private:
- static NameMap* InitializerNameMap();
- static void RunInitializer(InitializerData* initializer_data);
- static void InitializeOnceCallback();
-
- const std::string name_;
- InitializerFunc function_;
- bool done_;
-};
-
-} // namespace iree
-
-#define IREE_DECLARE_MODULE_INITIALIZER(name) \
- extern ::iree::Initializer iree_initializer_##name
-
-#define IREE_REGISTER_MODULE_INITIALIZER(name, body) \
- static void iree_init_##name() { body; } \
- ::iree::Initializer iree_initializer_##name(#name, iree_init_##name)
-
-#define IREE_REGISTER_MODULE_INITIALIZER_SEQUENCE(name1, name2) \
- namespace { \
- static ::iree::Initializer::DependencyRegisterer \
- iree_initializer_dependency_##name1##_##name2( \
- #name2, &iree_initializer_##name2, \
- ::iree::Initializer::Dependency(#name1, &iree_initializer_##name1)); \
- }
-
-#define IREE_REQUIRE_MODULE_INITIALIZED(name) \
- do { \
- IREE_DECLARE_MODULE_INITIALIZER(name); \
- iree_initializer_##name.Require(); \
- } while (0)
-
-#define IREE_RUN_MODULE_INITIALIZERS() \
- do { \
- ::iree::Initializer::RunInitializers(); \
- } while (0)
-
-#if !defined(IREE_COMPILER_MSVC)
-#define IREE_ATTRIBUTE_USED __attribute__((used))
-#else
-#define IREE_ATTRIBUTE_USED
-#endif // IREE_COMPILER_MSVC
-
-#define IREE_REQUIRE_MODULE_LINKED(name) \
- IREE_ATTRIBUTE_USED static ::iree::Initializer* iree_module_ref_##name = \
- &iree_initializer_##name
-
-#endif // IREE_BASE_INITIALIZER_H_
diff --git a/iree/base/internal/BUILD b/iree/base/internal/BUILD
index ae3f2fa..2fe99ee 100644
--- a/iree/base/internal/BUILD
+++ b/iree/base/internal/BUILD
@@ -25,8 +25,8 @@
srcs = ["file_handle_win32.cc"],
hdrs = ["file_handle_win32.h"],
deps = [
+ "//iree/base:core_headers",
"//iree/base:status",
- "//iree/base:target_platform",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],
@@ -40,10 +40,10 @@
],
deps = [
":file_handle_win32",
+ "//iree/base:core_headers",
"//iree/base:file_io_hdrs",
"//iree/base:file_path",
"//iree/base:status",
- "//iree/base:target_platform",
"//iree/base:tracing",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
@@ -58,8 +58,8 @@
],
deps = [
":file_handle_win32",
+ "//iree/base:core_headers",
"//iree/base:file_mapping_hdrs",
- "//iree/base:target_platform",
"//iree/base:tracing",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
@@ -88,9 +88,10 @@
deps = [
":ostringstream",
"//iree/base:api",
+ "//iree/base:core_headers",
"//iree/base:logging",
- "//iree/base:target_platform",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/utility",
],
)
diff --git a/iree/base/internal/CMakeLists.txt b/iree/base/internal/CMakeLists.txt
index 973d297..f9975d7 100644
--- a/iree/base/internal/CMakeLists.txt
+++ b/iree/base/internal/CMakeLists.txt
@@ -24,8 +24,8 @@
DEPS
absl::memory
absl::strings
+ iree::base::core_headers
iree::base::status
- iree::base::target_platform
PUBLIC
)
@@ -39,10 +39,10 @@
::file_handle_win32
absl::memory
absl::strings
+ iree::base::core_headers
iree::base::file_io_hdrs
iree::base::file_path
iree::base::status
- iree::base::target_platform
iree::base::tracing
PUBLIC
)
@@ -57,8 +57,8 @@
::file_handle_win32
absl::memory
absl::strings
+ iree::base::core_headers
iree::base::file_mapping_hdrs
- iree::base::target_platform
iree::base::tracing
PUBLIC
)
@@ -88,8 +88,9 @@
::ostringstream
absl::core_headers
absl::strings
+ absl::utility
iree::base::api
+ iree::base::core_headers
iree::base::logging
- iree::base::target_platform
PUBLIC
)
diff --git a/iree/base/internal/file_io_win32.cc b/iree/base/internal/file_io_win32.cc
index 004f0f4..3bf4b0e 100644
--- a/iree/base/internal/file_io_win32.cc
+++ b/iree/base/internal/file_io_win32.cc
@@ -48,7 +48,8 @@
result.resize(file->size());
DWORD bytes_read = 0;
if (::ReadFile(file->handle(), const_cast<char*>(result.data()),
- result.size(), &bytes_read, nullptr) == FALSE) {
+ static_cast<DWORD>(result.size()), &bytes_read,
+ nullptr) == FALSE) {
return Win32ErrorToCanonicalStatusBuilder(GetLastError(), IREE_LOC)
<< "Unable to read file span of " << result.size() << " bytes from '"
<< path << "'";
@@ -63,8 +64,8 @@
Status SetFileContents(const std::string& path, absl::string_view content) {
IREE_TRACE_SCOPE0("file_io::SetFileContents");
IREE_ASSIGN_OR_RETURN(auto file, FileHandle::OpenWrite(std::move(path), 0));
- if (::WriteFile(file->handle(), content.data(), content.size(), NULL, NULL) ==
- FALSE) {
+ if (::WriteFile(file->handle(), content.data(),
+ static_cast<DWORD>(content.size()), NULL, NULL) == FALSE) {
return Win32ErrorToCanonicalStatusBuilder(GetLastError(), IREE_LOC)
<< "Unable to write file span of " << content.size() << " bytes to '"
<< path << "'";
@@ -103,7 +104,8 @@
std::string temp_path(64, '\0');
for (bool retry_query = true; retry_query;) {
- DWORD required_length = ::GetTempPathA(temp_path.size(), &temp_path[0]);
+ DWORD required_length =
+ ::GetTempPathA(static_cast<DWORD>(temp_path.size()), &temp_path[0]);
retry_query = required_length > temp_path.size();
temp_path.resize(required_length);
}
diff --git a/iree/base/internal/status_builder.cc b/iree/base/internal/status_builder.cc
index 48200a2..24f6da6 100644
--- a/iree/base/internal/status_builder.cc
+++ b/iree/base/internal/status_builder.cc
@@ -17,6 +17,7 @@
#include <cerrno>
#include <cstdio>
+#include "iree/base/api.h"
#include "iree/base/target_platform.h"
namespace iree {
@@ -133,188 +134,17 @@
return StatusBuilder(StatusCode::kUnknown, location);
}
-// Returns the code for |error_number|, which should be an |errno| value.
-// See https://en.cppreference.com/w/cpp/error/errno_macros and similar refs.
-static StatusCode ErrnoToCanonicalCode(int error_number) {
- switch (error_number) {
- case 0:
- return StatusCode::kOk;
- case EINVAL: // Invalid argument
- case ENAMETOOLONG: // Filename too long
- case E2BIG: // Argument list too long
- case EDESTADDRREQ: // Destination address required
- case EDOM: // Mathematics argument out of domain of function
- case EFAULT: // Bad address
- case EILSEQ: // Illegal byte sequence
- case ENOPROTOOPT: // Protocol not available
- case ENOSTR: // Not a STREAM
- case ENOTSOCK: // Not a socket
- case ENOTTY: // Inappropriate I/O control operation
- case EPROTOTYPE: // Protocol wrong type for socket
- case ESPIPE: // Invalid seek
- return StatusCode::kInvalidArgument;
- case ETIMEDOUT: // Connection timed out
- case ETIME: // Timer expired
- return StatusCode::kDeadlineExceeded;
- case ENODEV: // No such device
- case ENOENT: // No such file or directory
-#ifdef ENOMEDIUM
- case ENOMEDIUM: // No medium found
-#endif
- case ENXIO: // No such device or address
- case ESRCH: // No such process
- return StatusCode::kNotFound;
- case EEXIST: // File exists
- case EADDRNOTAVAIL: // Address not available
- case EALREADY: // Connection already in progress
-#ifdef ENOTUNIQ
- case ENOTUNIQ: // Name not unique on network
-#endif
- return StatusCode::kAlreadyExists;
- case EPERM: // Operation not permitted
- case EACCES: // Permission denied
-#ifdef ENOKEY
- case ENOKEY: // Required key not available
-#endif
- case EROFS: // Read only file system
- return StatusCode::kPermissionDenied;
- case ENOTEMPTY: // Directory not empty
- case EISDIR: // Is a directory
- case ENOTDIR: // Not a directory
- case EADDRINUSE: // Address already in use
- case EBADF: // Invalid file descriptor
-#ifdef EBADFD
- case EBADFD: // File descriptor in bad state
-#endif
- case EBUSY: // Device or resource busy
- case ECHILD: // No child processes
- case EISCONN: // Socket is connected
-#ifdef EISNAM
- case EISNAM: // Is a named type file
-#endif
-#ifdef ENOTBLK
- case ENOTBLK: // Block device required
-#endif
- case ENOTCONN: // The socket is not connected
- case EPIPE: // Broken pipe
-#ifdef ESHUTDOWN
- case ESHUTDOWN: // Cannot send after transport endpoint shutdown
-#endif
- case ETXTBSY: // Text file busy
-#ifdef EUNATCH
- case EUNATCH: // Protocol driver not attached
-#endif
- return StatusCode::kFailedPrecondition;
- case ENOSPC: // No space left on device
-#ifdef EDQUOT
- case EDQUOT: // Disk quota exceeded
-#endif
- case EMFILE: // Too many open files
- case EMLINK: // Too many links
- case ENFILE: // Too many open files in system
- case ENOBUFS: // No buffer space available
- case ENODATA: // No message is available on the STREAM read queue
- case ENOMEM: // Not enough space
- case ENOSR: // No STREAM resources
-#ifdef EUSERS
- case EUSERS: // Too many users
-#endif
- return StatusCode::kResourceExhausted;
-#ifdef ECHRNG
- case ECHRNG: // Channel number out of range
-#endif
- case EFBIG: // File too large
- case EOVERFLOW: // Value too large to be stored in data type
- case ERANGE: // Result too large
- return StatusCode::kOutOfRange;
-#ifdef ENOPKG
- case ENOPKG: // Package not installed
-#endif
- case ENOSYS: // Function not implemented
- case ENOTSUP: // Operation not supported
- case EAFNOSUPPORT: // Address family not supported
-#ifdef EPFNOSUPPORT
- case EPFNOSUPPORT: // Protocol family not supported
-#endif
- case EPROTONOSUPPORT: // Protocol not supported
-#ifdef ESOCKTNOSUPPORT
- case ESOCKTNOSUPPORT: // Socket type not supported
-#endif
- case EXDEV: // Improper link
- return StatusCode::kUnimplemented;
- case EAGAIN: // Resource temporarily unavailable
-#ifdef ECOMM
- case ECOMM: // Communication error on send
-#endif
- case ECONNREFUSED: // Connection refused
- case ECONNABORTED: // Connection aborted
- case ECONNRESET: // Connection reset
- case EINTR: // Interrupted function call
-#ifdef EHOSTDOWN
- case EHOSTDOWN: // Host is down
-#endif
- case EHOSTUNREACH: // Host is unreachable
- case ENETDOWN: // Network is down
- case ENETRESET: // Connection aborted by network
- case ENETUNREACH: // Network unreachable
- case ENOLCK: // No locks available
- case ENOLINK: // Link has been severed
-#ifdef ENONET
- case ENONET: // Machine is not on the network
-#endif
- return StatusCode::kUnavailable;
- case EDEADLK: // Resource deadlock avoided
-#ifdef ESTALE
- case ESTALE: // Stale file handle
-#endif
- return StatusCode::kAborted;
- case ECANCELED: // Operation cancelled
- return StatusCode::kCancelled;
- default:
- return StatusCode::kUnknown;
- }
-}
-
StatusBuilder ErrnoToCanonicalStatusBuilder(int error_number,
SourceLocation location) {
- return StatusBuilder(ErrnoToCanonicalCode(error_number), location);
+ return StatusBuilder(iree_status_code_from_errno(error_number), location);
}
#if defined(IREE_PLATFORM_WINDOWS)
-// Returns the code for |error| which should be a Win32 error dword.
-static StatusCode Win32ErrorToCanonicalCode(uint32_t error) {
- switch (error) {
- case ERROR_SUCCESS:
- return StatusCode::kOk;
- case ERROR_FILE_NOT_FOUND:
- case ERROR_PATH_NOT_FOUND:
- return StatusCode::kNotFound;
- case ERROR_TOO_MANY_OPEN_FILES:
- case ERROR_OUTOFMEMORY:
- case ERROR_HANDLE_DISK_FULL:
- case ERROR_HANDLE_EOF:
- return StatusCode::kResourceExhausted;
- case ERROR_ACCESS_DENIED:
- return StatusCode::kPermissionDenied;
- case ERROR_INVALID_HANDLE:
- return StatusCode::kInvalidArgument;
- case ERROR_NOT_READY:
- case ERROR_READ_FAULT:
- return StatusCode::kUnavailable;
- case ERROR_WRITE_FAULT:
- return StatusCode::kDataLoss;
- case ERROR_NOT_SUPPORTED:
- return StatusCode::kUnimplemented;
- default:
- return StatusCode::kUnknown;
- }
-}
-
StatusBuilder Win32ErrorToCanonicalStatusBuilder(uint32_t error,
SourceLocation location) {
// TODO(benvanik): use FormatMessage; or defer until required?
- return StatusBuilder(Win32ErrorToCanonicalCode(error), location)
+ return StatusBuilder(iree_status_code_from_win32_error(error), location)
<< "<TBD>: " << error;
}
diff --git a/iree/base/internal/statusor.h b/iree/base/internal/statusor.h
index 93338a9..d2ecac7 100644
--- a/iree/base/internal/statusor.h
+++ b/iree/base/internal/statusor.h
@@ -15,6 +15,9 @@
#ifndef IREE_BASE_INTERNAL_STATUSOR_H_
#define IREE_BASE_INTERNAL_STATUSOR_H_
+#include <type_traits>
+#include <utility>
+
#include "absl/base/attributes.h"
#include "absl/utility/utility.h"
#include "iree/base/internal/status.h"
@@ -49,12 +52,11 @@
template <typename T, typename U>
struct IsAmbiguousStatusOrForInitialization
: // Strip const-value refs from type and check again, else false_type.
- public absl::conditional_t<
- std::is_same<absl::remove_cv_t<absl::remove_reference_t<U>>,
- U>::value,
+ public std::conditional_t<
+ std::is_same<std::remove_cv_t<std::remove_reference_t<U>>, U>::value,
std::false_type,
IsAmbiguousStatusOrForInitialization<
- T, absl::remove_cv_t<absl::remove_reference_t<U>>>> {};
+ T, std::remove_cv_t<std::remove_reference_t<U>>>> {};
template <typename T, typename U>
struct IsAmbiguousStatusOrForInitialization<T, StatusOr<U>>
@@ -62,17 +64,17 @@
template <typename T, typename U>
using IsStatusOrDirectInitializationAmbiguous = absl::disjunction<
- std::is_same<StatusOr<T>, absl::remove_cv_t<absl::remove_reference_t<U>>>,
- std::is_same<Status, absl::remove_cv_t<absl::remove_reference_t<U>>>,
- std::is_same<StatusBuilder, absl::remove_cv_t<absl::remove_reference_t<U>>>,
+ std::is_same<StatusOr<T>, std::remove_cv_t<std::remove_reference_t<U>>>,
+ std::is_same<Status, std::remove_cv_t<std::remove_reference_t<U>>>,
+ std::is_same<StatusBuilder, std::remove_cv_t<std::remove_reference_t<U>>>,
std::is_same<absl::in_place_t,
- absl::remove_cv_t<absl::remove_reference_t<U>>>,
+ std::remove_cv_t<std::remove_reference_t<U>>>,
IsAmbiguousStatusOrForInitialization<T, U>>;
template <typename T, typename U>
using IsStatusOrDirectInitializationValid = absl::disjunction<
// The is_same allows nested status ors to ignore this check iff same type.
- std::is_same<T, absl::remove_cv_t<absl::remove_reference_t<U>>>,
+ std::is_same<T, std::remove_cv_t<std::remove_reference_t<U>>>,
absl::negation<IsStatusOrDirectInitializationAmbiguous<T, U>>>;
class Helper {
@@ -327,7 +329,7 @@
// explicit.
template <
typename U,
- absl::enable_if_t<
+ std::enable_if_t<
absl::conjunction<
absl::negation<std::is_same<T, U>>,
std::is_constructible<T, const U&>,
@@ -339,7 +341,7 @@
: Base(static_cast<const typename StatusOr<U>::Base&>(other)) {}
template <
typename U,
- absl::enable_if_t<
+ std::enable_if_t<
absl::conjunction<
absl::negation<std::is_same<T, U>>,
std::is_constructible<T, const U&>,
@@ -352,7 +354,7 @@
template <
typename U,
- absl::enable_if_t<
+ std::enable_if_t<
absl::conjunction<
absl::negation<std::is_same<T, U>>, std::is_constructible<T, U&&>,
std::is_convertible<U&&, T>,
@@ -363,7 +365,7 @@
: Base(static_cast<typename StatusOr<U>::Base&&>(other)) {}
template <
typename U,
- absl::enable_if_t<
+ std::enable_if_t<
absl::conjunction<
absl::negation<std::is_same<T, U>>, std::is_constructible<T, U&&>,
absl::negation<std::is_convertible<U&&, T>>,
@@ -378,7 +380,7 @@
// StatusOr<U>.
template <
typename U,
- absl::enable_if_t<
+ std::enable_if_t<
absl::conjunction<
absl::negation<std::is_same<T, U>>,
std::is_constructible<T, const U&>,
@@ -393,7 +395,7 @@
}
template <
typename U,
- absl::enable_if_t<
+ std::enable_if_t<
absl::conjunction<
absl::negation<std::is_same<T, U>>, std::is_constructible<T, U&&>,
std::is_assignable<T, U&&>,
@@ -445,7 +447,7 @@
// a StatusOr<J>, where J is convertible to T.
template <
typename U = T,
- absl::enable_if_t<
+ std::enable_if_t<
absl::conjunction<
internal_statusor::IsStatusOrDirectInitializationValid<T, U&&>,
std::is_constructible<T, U&&>,
@@ -456,7 +458,7 @@
template <
typename U = T,
- absl::enable_if_t<
+ std::enable_if_t<
absl::conjunction<
internal_statusor::IsStatusOrDirectInitializationValid<T, U&&>,
std::is_constructible<T, U&&>,
diff --git a/iree/base/platform_headers.h b/iree/base/platform_headers.h
deleted file mode 100644
index 16ad04b..0000000
--- a/iree/base/platform_headers.h
+++ /dev/null
@@ -1,40 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_BASE_PLATFORM_HEADERS_H_
-#define IREE_BASE_PLATFORM_HEADERS_H_
-
-#include "iree/base/target_platform.h"
-
-#if defined(IREE_PLATFORM_WINDOWS)
-
-#if defined(_MSC_VER)
-// Abseil compatibility: don't include incompatible winsock versions.
-#ifndef WIN32_LEAN_AND_MEAN
-#define WIN32_LEAN_AND_MEAN
-#endif // WIN32_LEAN_AND_MEAN
-// Abseil compatibility: don't define min and max macros.
-#ifndef NOMINMAX
-#define NOMINMAX
-#endif // NOMINMAX
-#endif // _MSC_VER
-
-#include <windows.h>
-
-// WinGDI.h defines `ERROR`, undef to avoid conflict naming.
-#undef ERROR
-
-#endif // IREE_PLATFORM_WINDOWS
-
-#endif // IREE_BASE_PLATFORM_HEADERS_H_
diff --git a/iree/base/synchronization.c b/iree/base/synchronization.c
index ed185955..8bc186e 100644
--- a/iree/base/synchronization.c
+++ b/iree/base/synchronization.c
@@ -95,7 +95,7 @@
#elif defined(IREE_PLATFORM_WINDOWS)
-#pragma comment(lib, "synchronization")
+#pragma comment(lib, "Synchronization.lib")
static inline iree_status_t iree_futex_wait(void* address,
uint32_t expected_value,
diff --git a/iree/base/synchronization.h b/iree/base/synchronization.h
index 6cf7dfa..9da41bb 100644
--- a/iree/base/synchronization.h
+++ b/iree/base/synchronization.h
@@ -43,13 +43,6 @@
#define IREE_PTR_GUARDED_BY(x) \
IREE_THREAD_ANNOTATION_ATTRIBUTE(pt_guarded_by(x))
-#define IREE_EXCLUSIVE_LOCK_FUNCTION(...) \
- IREE_THREAD_ANNOTATION_ATTRIBUTE(exclusive_lock_function(__VA_ARGS__))
-#define IREE_EXCLUSIVE_TRYLOCK_FUNCTION(...) \
- IREE_THREAD_ANNOTATION_ATTRIBUTE(exclusive_trylock_function(__VA_ARGS__))
-#define IREE_UNLOCK_FUNCTION(...) \
- IREE_THREAD_ANNOTATION_ATTRIBUTE(unlock_function(__VA_ARGS__))
-
#if defined(IREE_PLATFORM_ANDROID) || defined(IREE_PLATFORM_EMSCRIPTEN) || \
defined(IREE_PLATFORM_LINUX) || defined(IREE_PLATFORM_WINDOWS)
#define IREE_PLATFORM_HAS_FUTEX 1
diff --git a/iree/base/target_platform.h b/iree/base/target_platform.h
index 7c9aecb..d89bbb4 100644
--- a/iree/base/target_platform.h
+++ b/iree/base/target_platform.h
@@ -15,7 +15,9 @@
#ifndef IREE_BASE_TARGET_PLATFORM_H_
#define IREE_BASE_TARGET_PLATFORM_H_
-// The Bazel rule defines one of the following top-level platforms and then
+#include <stdint.h>
+
+// The build system defines one of the following top-level platforms and then
// one platform+architecture pair for that platform.
//
// IREE_ARCH_ARM_32
@@ -25,6 +27,10 @@
// IREE_ARCH_X86_32
// IREE_ARCH_X86_64
//
+// IREE_PTR_SIZE
+// IREE_PTR_SIZE_32
+// IREE_PTR_SIZE_64
+//
// IREE_ENDIANNESS_LITTLE
// IREE_ENDIANNESS_BIG
//
@@ -81,6 +87,18 @@
#endif // all archs
//==============================================================================
+// IREE_PTR_SIZE_*
+//==============================================================================
+
+#if UINTPTR_MAX > UINT32_MAX  // <stdint.h> only; UINT_MAX needs <limits.h>.
+#define IREE_PTR_SIZE_64
+#define IREE_PTR_SIZE 8
+#else  // Pointers fit in 32 bits.
+#define IREE_PTR_SIZE_32
+#define IREE_PTR_SIZE 4
+#endif  // UINTPTR_MAX > UINT32_MAX
+
+//==============================================================================
// IREE_ENDIANNESS_*
//==============================================================================
// https://en.wikipedia.org/wiki/Endianness
diff --git a/iree/base/testing/BUILD b/iree/base/testing/BUILD
new file mode 100644
index 0000000..43b4afe
--- /dev/null
+++ b/iree/base/testing/BUILD
@@ -0,0 +1,52 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//build_tools/embed_data:build_defs.bzl", "cc_embed_data")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_binary(
+ name = "dynamic_library_test_library.so",
+ testonly = True,
+ srcs = ["dynamic_library_test_library.cc"],
+ linkshared = True,
+)
+
+cc_embed_data(
+ name = "dynamic_library_test_library",
+ testonly = True,
+ srcs = [":dynamic_library_test_library.so"],
+ cc_file_output = "dynamic_library_test_library_embed.cc",
+ cpp_namespace = "iree",
+ flatten = True,
+ h_file_output = "dynamic_library_test_library_embed.h",
+)
+
+cc_test(
+ name = "dynamic_library_test",
+ srcs = ["dynamic_library_test.cc"],
+ deps = [
+ ":dynamic_library_test_library",
+ "//iree/base:core_headers",
+ "//iree/base:dynamic_library",
+ "//iree/base:file_io",
+ "//iree/base:status",
+ "//iree/testing:gtest",
+ "//iree/testing:gtest_main",
+ ],
+)
diff --git a/iree/base/testing/CMakeLists.txt b/iree/base/testing/CMakeLists.txt
new file mode 100644
index 0000000..ba88fe8
--- /dev/null
+++ b/iree/base/testing/CMakeLists.txt
@@ -0,0 +1,60 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# bazel_to_cmake: DO NOT EDIT (some scotttodd todos remain)
+
+# TODO(scotttodd): clean up bazel_to_cmake handling here
+# * this is a cc_binary in Bazel, but `linkshared` fits iree_cc_library better
+# * the output file name is platform-specific, get it with $<TARGET_FILE:>
+iree_cc_library(
+ NAME
+ dynamic_library_test_library.so
+ OUT
+ dynamic_library_test_library.so
+ SRCS
+ "dynamic_library_test_library.cc"
+ TESTONLY
+ SHARED
+)
+
+iree_cc_embed_data(
+ NAME
+ dynamic_library_test_library
+ GENERATED_SRCS
+ "$<TARGET_FILE:iree::base::testing::dynamic_library_test_library.so>"
+ CC_FILE_OUTPUT
+ "dynamic_library_test_library_embed.cc"
+ H_FILE_OUTPUT
+ "dynamic_library_test_library_embed.h"
+ TESTONLY
+ CPP_NAMESPACE
+ "iree"
+ FLATTEN
+ PUBLIC
+)
+
+iree_cc_test(
+ NAME
+ dynamic_library_test
+ SRCS
+ "dynamic_library_test.cc"
+ DEPS
+ ::dynamic_library_test_library
+ iree::base::core_headers
+ iree::base::dynamic_library
+ iree::base::file_io
+ iree::base::status
+ iree::testing::gtest
+ iree::testing::gtest_main
+)
diff --git a/iree/base/dynamic_library_test.cc b/iree/base/testing/dynamic_library_test.cc
similarity index 96%
rename from iree/base/dynamic_library_test.cc
rename to iree/base/testing/dynamic_library_test.cc
index f607c22..406a3c5 100644
--- a/iree/base/dynamic_library_test.cc
+++ b/iree/base/testing/dynamic_library_test.cc
@@ -16,10 +16,10 @@
#include <string>
-#include "iree/base/dynamic_library_test_library_embed.h"
#include "iree/base/file_io.h"
#include "iree/base/status.h"
#include "iree/base/target_platform.h"
+#include "iree/base/testing/dynamic_library_test_library_embed.h"
#include "iree/testing/gtest.h"
#include "iree/testing/status_matchers.h"
@@ -64,7 +64,6 @@
TEST_F(DynamicLibraryTest, LoadLibrarySuccess) {
IREE_ASSERT_OK_AND_ASSIGN(auto library,
DynamicLibrary::Load(library_temp_path_.c_str()));
- EXPECT_EQ(library_temp_path_, library->file_name());
}
TEST_F(DynamicLibraryTest, LoadLibraryFailure) {
diff --git a/iree/base/dynamic_library_test_library.cc b/iree/base/testing/dynamic_library_test_library.cc
similarity index 100%
rename from iree/base/dynamic_library_test_library.cc
rename to iree/base/testing/dynamic_library_test_library.cc
diff --git a/iree/base/threading.c b/iree/base/threading.c
index 0b3a49f..833194f 100644
--- a/iree/base/threading.c
+++ b/iree/base/threading.c
@@ -43,6 +43,16 @@
}
//==============================================================================
+// iree_thread_affinity_t
+//==============================================================================
+
+// TODO(benvanik): add more helpers and possibly move cpuinfo usage into here.
+
+void iree_thread_affinity_set_any(iree_thread_affinity_t* out_thread_affinity) {
+ memset(out_thread_affinity, 0x00, sizeof(*out_thread_affinity));
+}
+
+//==============================================================================
// iree_thread_override_list_t
//==============================================================================
// This is shared by multiple platform implementations and gets stripped in LTO
diff --git a/iree/base/threading.h b/iree/base/threading.h
index 56eeb7b..f6cb7fd 100644
--- a/iree/base/threading.h
+++ b/iree/base/threading.h
@@ -55,6 +55,37 @@
};
typedef int32_t iree_thread_priority_class_t;
+// Specifies the processor affinity for a particular thread.
+// Each platform handles this differently (if at all).
+//
+// macOS/iOS:
+// Only affinity tags are supported; the ID will be used by the kernel to
+// group threads that have matching values together and (hopefully) schedule
+// them on cores that may share some level of the cache hierarchy. The API is
+// effectively just asking nicely and hoping the kernel is on the same
+// wavelength.
+//
+// Linux/Android:
+// sched_setaffinity is used to pin the thread to the core with the given ID.
+// There are, naturally, issues on Android where if the governor has turned
+// off some cores (such as powering down big cores in an ARM big.LITTLE
+// configuration) the affinity request will be dropped on the floor even if
+// the cores are later enabled. This is one of the reasons why we note in
+// iree_thread_request_affinity that requests may need to be made at
+// ¯\_(ツ)_/¯ intervals. In the future we can try to hook into power
+// management infra to see if we can tell when we need to do this.
+//
+// Windows:
+// Stuff just works. Love it.
+typedef struct {
+ uint32_t specified : 1;  // 0 = no affinity requested; requests return early.
+ uint32_t group : 7;  // Processor group; of the visible backends only Win32 reads it.
+ uint32_t id : 24;  // CPU index (Linux/Win32) or affinity tag (macOS/iOS).
+} iree_thread_affinity_t;
+
+// Sets |out_thread_affinity| to match with any processor in the system.
+void iree_thread_affinity_set_any(iree_thread_affinity_t* out_thread_affinity);
+
// Thread creation parameters.
// All are optional and the entire struct can safely be zero-initialized.
typedef struct {
@@ -73,9 +104,14 @@
bool create_suspended;
// Initial priority class.
- // This may be changed later via iree_thread_set_priority_class; see that for
- // more information.
+ // This may be changed later via iree_thread_priority_class_override_begin;
+ // see that for more information.
iree_thread_priority_class_t priority_class;
+
+ // Initial thread affinity.
+ // This may be changed later via iree_thread_request_affinity; see that for
+ // more information.
+ iree_thread_affinity_t initial_affinity;
} iree_thread_create_params_t;
typedef int (*iree_thread_entry_t)(void* entry_arg);
@@ -119,6 +155,26 @@
// iree_thread_priority_class_override_begin.
void iree_thread_override_end(iree_thread_override_t* override_token);
+// Updates the thread affinity of the given |thread|.
+// Affinities are not sticky and may need to be refreshed over time as CPUs are
+// enabled/disabled by the OS (such as power mode changes, governor adjustments,
+// etc). Users wanting to ensure threads have specific affinities may want to
+// request updates whenever new large amounts of work are about to be performed.
+//
+// NOTE: thread affinities are just a hint. The OS scheduler is free to do
+// whatever it wants up to and including entirely ignoring the specified
+// affinity. In many cases where cores are oversubscribed setting an affinity
+// mask can pessimize battery/thermals/performance as the OS will sometimes try
+// to shuffle around threads to disable physical cores/etc.
+//
+// Compatibility warning: Apple/darwin only support affinity groups, with each
+// unique affinity sharing time with all others of the same value. This means
+// that trying to get clever with several thread sets with overlapping
+// affinities will likely not work as expected. Try to stick with threads that
+// run only on a single processor.
+void iree_thread_request_affinity(iree_thread_t* thread,
+ iree_thread_affinity_t affinity);
+
// Resumes |thread| if it was created suspended.
// This has no effect if the thread is not suspended.
void iree_thread_resume(iree_thread_t* thread);
diff --git a/iree/base/threading_darwin.c b/iree/base/threading_darwin.c
index 6f72a2f..c0ba729 100644
--- a/iree/base/threading_darwin.c
+++ b/iree/base/threading_darwin.c
@@ -12,10 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "iree/base/atomics.h"
-#include "iree/base/threading.h"
+// NOTE: must be first to ensure that we can define settings for all includes.
#include "iree/base/threading_impl.h"
-#include "iree/base/tracing.h"
#if defined(IREE_PLATFORM_APPLE)
@@ -25,6 +23,10 @@
#include <pthread.h>
#include <string.h>
+#include "iree/base/atomics.h"
+#include "iree/base/threading.h"
+#include "iree/base/tracing.h"
+
// Useful to see how pthreads is implemented on (old) darwin:
// https://opensource.apple.com/source/Libc/Libc-825.40.1/pthreads/pthread.c.auto.html
@@ -123,6 +125,9 @@
}
thread->mach_port = pthread_mach_thread_np(thread->handle);
+ if (params.initial_affinity.specified) {
+ iree_thread_request_affinity(thread, params.initial_affinity);
+ }
// Retain the thread for the thread itself; this way if the caller immediately
// releases the iree_thread_t handle the thread won't explode.
@@ -195,6 +200,23 @@
IREE_TRACE_ZONE_END(z0);
}
+void iree_thread_request_affinity(iree_thread_t* thread,
+ iree_thread_affinity_t affinity) {
+ if (!affinity.specified) return;  // No affinity requested.
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ // See:
+ // https://gist.github.com/Coneko/4234842
+ // https://fergofrog.com/code/cbowser/xnu/osfmk/mach/thread_policy.h.html
+ // http://www.hybridkernel.com/2015/01/18/binding_threads_to_cores_osx.html
+ thread_affinity_policy_data_t policy_data = {affinity.id};  // |group| unused.
+ thread_policy_set(thread->mach_port, THREAD_AFFINITY_POLICY,
+ (thread_policy_t)(&policy_data),
+ THREAD_AFFINITY_POLICY_COUNT);  // Result ignored.
+
+ IREE_TRACE_ZONE_END(z0);
+}
+
void iree_thread_resume(iree_thread_t* thread) {
IREE_TRACE_ZONE_BEGIN(z0);
diff --git a/iree/base/threading_impl.h b/iree/base/threading_impl.h
index b6794f2..40de48f 100644
--- a/iree/base/threading_impl.h
+++ b/iree/base/threading_impl.h
@@ -15,6 +15,18 @@
#ifndef IREE_BASE_THREADING_IMPL_H_
#define IREE_BASE_THREADING_IMPL_H_
+// Ensure that any posix header we include exposes GNU stuff. Ignored on
+// platforms where we either don't have the GNU stuff or don't have posix
+// headers at all.
+//
+// Note that this does not need to be the same for all compilation units, only
+// those we want to access the non-portable features in. It *must* be defined
+// prior to including any of the files, though, as otherwise header-guards will
+// cause the setting at the time of first inclusion to win.
+//
+// https://stackoverflow.com/a/5583764
+#define _GNU_SOURCE 1
+
#include <assert.h>
#include <errno.h>
#include <stddef.h>
@@ -22,6 +34,7 @@
#include "iree/base/api.h"
#include "iree/base/synchronization.h"
+#include "iree/base/target_platform.h"
#include "iree/base/threading.h"
#ifdef __cplusplus
diff --git a/iree/base/threading_pthreads.c b/iree/base/threading_pthreads.c
index 5482414..8d2f14c 100644
--- a/iree/base/threading_pthreads.c
+++ b/iree/base/threading_pthreads.c
@@ -12,11 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "iree/base/atomics.h"
-#include "iree/base/synchronization.h"
-#include "iree/base/threading.h"
+// NOTE: must be first to ensure that we can define settings for all includes.
#include "iree/base/threading_impl.h"
-#include "iree/base/tracing.h"
#if defined(IREE_PLATFORM_ANDROID) || defined(IREE_PLATFORM_EMSCRIPTEN) || \
defined(IREE_PLATFORM_LINUX)
@@ -31,6 +28,17 @@
#include <time.h>
#include <unistd.h>
+#include "iree/base/atomics.h"
+#include "iree/base/synchronization.h"
+#include "iree/base/threading.h"
+#include "iree/base/tracing.h"
+
+// Older glibc doesn't have a gettid wrapper:
+// https://stackoverflow.com/a/63494768
+#if __GLIBC__ == 2 && __GLIBC_MINOR__ < 30
+#define gettid() syscall(SYS_gettid)
+#endif
+
struct iree_thread_s {
iree_atomic_ref_count_t ref_count;
iree_allocator_t allocator;
@@ -156,6 +164,9 @@
if (params.priority_class != IREE_THREAD_PRIORITY_CLASS_NORMAL) {
iree_thread_set_priority_class(thread, params.priority_class);
}
+ if (params.initial_affinity.specified) {
+ iree_thread_request_affinity(thread, params.initial_affinity);
+ }
IREE_TRACE_ZONE_END(z0);
*out_thread = thread;
@@ -270,6 +281,23 @@
IREE_TRACE_ZONE_END(z0);
}
+void iree_thread_request_affinity(iree_thread_t* thread,
+ iree_thread_affinity_t affinity) {
+ if (!affinity.specified) return;  // No affinity requested.
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ // NOTE: Android uses Linux lightweight processes (LWP) for threads, so the
+ // pid is really the tid. This is *not* anything related to pthread_self.
+ pid_t tid = gettid();  // NOTE(review): caller's tid; |thread| unused. Confirm.
+
+ cpu_set_t cpu_set;
+ CPU_ZERO(&cpu_set);
+ CPU_SET(affinity.id, &cpu_set);  // Pins to exactly one CPU; |group| unused.
+ sched_setaffinity(tid, sizeof(cpu_set), &cpu_set);  // Result ignored.
+
+ IREE_TRACE_ZONE_END(z0);
+}
+
void iree_thread_resume(iree_thread_t* thread) {
IREE_TRACE_ZONE_BEGIN(z0);
diff --git a/iree/base/threading_win32.c b/iree/base/threading_win32.c
index 56ee1aa..ff7356b 100644
--- a/iree/base/threading_win32.c
+++ b/iree/base/threading_win32.c
@@ -12,12 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// NOTE: must be first to ensure that we can define settings for all includes.
+#include "iree/base/threading_impl.h"
+
+#if defined(IREE_PLATFORM_WINDOWS)
+
#include "iree/base/atomics.h"
#include "iree/base/threading.h"
#include "iree/base/threading_impl.h"
#include "iree/base/tracing.h"
-#if defined(IREE_PLATFORM_WINDOWS)
+// Great documentation:
+// https://www.microsoftpressstore.com/articles/article.aspx?p=2233328
struct iree_thread_s {
iree_atomic_ref_count_t ref_count;
@@ -167,6 +173,9 @@
if (params.priority_class != IREE_THREAD_PRIORITY_CLASS_NORMAL) {
iree_thread_set_priority_class(thread, params.priority_class);
}
+ if (params.initial_affinity.specified) {
+ iree_thread_request_affinity(thread, params.initial_affinity);
+ }
// Retain the thread for the thread itself; this way if the caller immediately
// releases the iree_thread_t handle the thread won't explode.
@@ -255,6 +264,20 @@
IREE_TRACE_ZONE_END(z0);
}
+void iree_thread_request_affinity(iree_thread_t* thread,
+ iree_thread_affinity_t affinity) {
+ if (!affinity.specified) return;  // No affinity requested.
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ PROCESSOR_NUMBER ideal_processor;
+ memset(&ideal_processor, 0, sizeof(ideal_processor));
+ ideal_processor.Group = affinity.group;
+ ideal_processor.Number = affinity.id;  // NOTE(review): |id| is 24 bits; confirm Number can hold it.
+ SetThreadIdealProcessorEx(thread->handle, &ideal_processor, NULL);  // Result ignored.
+
+ IREE_TRACE_ZONE_END(z0);
+}
+
void iree_thread_resume(iree_thread_t* thread) {
IREE_TRACE_ZONE_BEGIN(z0);
diff --git a/iree/base/tracing.h b/iree/base/tracing.h
index 5bbf393..a5c3b4f 100644
--- a/iree/base/tracing.h
+++ b/iree/base/tracing.h
@@ -266,6 +266,15 @@
// The C-string |name| will be copied and does not need to be a literal.
#define IREE_TRACE_SET_THREAD_NAME(name) iree_tracing_set_thread_name_impl(name)
+// Evaluates the expression code only if tracing is enabled.
+//
+// Example:
+// struct {
+// IREE_TRACE(uint32_t trace_only_value);
+// } my_object;
+// IREE_TRACE(my_object.trace_only_value = 5);
+#define IREE_TRACE(expr) expr
+
// Begins a new zone with the parent function name.
#define IREE_TRACE_ZONE_BEGIN(zone_id) \
IREE_TRACE_ZONE_BEGIN_NAMED(zone_id, NULL)
@@ -377,6 +386,7 @@
#else
#define IREE_TRACE_SET_APP_INFO(value, value_length)
#define IREE_TRACE_SET_THREAD_NAME(name)
+#define IREE_TRACE(expr)
#define IREE_TRACE_ZONE_BEGIN(zone_id)
#define IREE_TRACE_ZONE_BEGIN_NAMED(zone_id, name_literal)
#define IREE_TRACE_ZONE_BEGIN_NAMED_DYNAMIC(zone_id, name, name_length)
diff --git a/iree/base/wait_handle.c b/iree/base/wait_handle.c
new file mode 100644
index 0000000..20d4847
--- /dev/null
+++ b/iree/base/wait_handle.c
@@ -0,0 +1,33 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/base/wait_handle.h"
+
+//===----------------------------------------------------------------------===//
+// iree_wait_handle_t
+//===----------------------------------------------------------------------===//
+
+iree_status_t iree_wait_handle_wrap_primitive(
+ iree_wait_primitive_type_t primitive_type,
+ iree_wait_primitive_value_t primitive_value,
+ iree_wait_handle_t* out_handle) {
+ memset(out_handle, 0, sizeof(*out_handle));  // Zero every field first.
+ out_handle->type = primitive_type;
+ out_handle->value = primitive_value;  // Copied by value; no validation.
+ return iree_ok_status();  // No failure path in this implementation.
+}
+
+void iree_wait_handle_deinitialize(iree_wait_handle_t* handle) {
+ memset(handle, 0, sizeof(*handle));  // Zeroes the handle; closes nothing.
+}
diff --git a/iree/base/wait_handle.cc b/iree/base/wait_handle.cc
deleted file mode 100644
index ddaec34..0000000
--- a/iree/base/wait_handle.cc
+++ /dev/null
@@ -1,536 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/base/wait_handle.h"
-
-#include <errno.h>
-#include <fcntl.h>
-#include <poll.h>
-#include <time.h>
-#include <unistd.h>
-
-#include <type_traits>
-#include <utility>
-
-#include "absl/container/fixed_array.h"
-#include "absl/strings/str_cat.h"
-#include "absl/time/clock.h"
-#include "absl/time/time.h"
-#include "iree/base/status.h"
-
-// TODO(benvanik): organize these macros - they are terrible.
-
-#if !defined(__ANDROID__) && !defined(OS_IOS) && !defined(__EMSCRIPTEN__)
-#define IREE_HAS_PPOLL 1
-#endif // !__ANDROID__ && !__EMSCRIPTEN__
-#define IREE_HAS_POLL 1
-
-#if !defined(OS_IOS) && !defined(OS_MACOSX) && !defined(__EMSCRIPTEN__)
-#define IREE_HAS_EVENTFD 1
-#endif
-#define IREE_HAS_PIPE 1
-// #define IREE_HAS_SYNC_FILE 1
-
-#if defined(IREE_HAS_EVENTFD)
-#include <sys/eventfd.h>
-#endif // IREE_HAS_EVENTFD
-
-namespace iree {
-
-namespace {
-
-constexpr int kInvalidFd = WaitableObject::kInvalidFd;
-constexpr int kSignaledFd = WaitableObject::kSignaledFd;
-
-// Retries a syscall until it succeeds or fails for a real reason.
-template <typename SyscallT, typename... ParamsT>
-StatusOr<typename std::result_of<SyscallT(ParamsT...)>::type> Syscall(
- SyscallT syscall, ParamsT&&... params) {
- while (true) {
- const auto rv = syscall(std::forward<ParamsT>(params)...);
- if (rv >= 0) return rv;
- if (errno == EINTR) {
- // Retry on EINTR.
- continue;
- } else {
- return ErrnoToCanonicalStatusBuilder(errno, IREE_LOC);
- }
- }
-}
-
-#if defined(IREE_HAS_PPOLL)
-
-// ppoll(), present on Linux.
-// ppoll is preferred as it has a much better timing mechanism; poll can have a
-// large slop on the deadline.
-// Documentation: https://linux.die.net/man/2/poll
-StatusOr<int> SystemPoll(absl::Span<pollfd> poll_fds, Time deadline_ns) {
- // Convert the deadline into a tmo_p struct for ppoll that controls whether
- // the call is blocking or non-blocking. Note that we must do this every
- // iteration of the loop as a previous ppoll may have taken some of the
- // time.
- //
- // See the ppoll docs for more information as to what the expected value is:
- // http://man7.org/linux/man-pages/man2/poll.2.html
- timespec timeout_spec;
- timespec* tmo_p;
- if (deadline == InfinitePast()) {
- // 0 for non-blocking.
- timeout_spec = {0};
- tmo_p = &timeout_spec;
- } else if (deadline == InfiniteFuture()) {
- // nullptr to ppoll() to block forever.
- tmo_p = nullptr;
- } else {
- // Wait only for as much time as we have before the deadline is exceeded.
- absl::Duration remaining_time = deadline - Now();
- if (remaining_time < absl::ZeroDuration()) {
- // Note: we likely have already bailed before getting here with a negative
- // duration.
- return DeadlineExceededErrorBuilder(IREE_LOC);
- }
- timeout_spec = absl::ToTimespec(remaining_time);
- tmo_p = &timeout_spec;
- }
- return Syscall(::ppoll, poll_fds.data(), poll_fds.size(), tmo_p, nullptr);
-}
-
-#elif defined(IREE_HAS_POLL)
-
-// poll(), present pretty much everywhere.
-// Documentation: https://linux.die.net/man/2/poll
-StatusOr<int> SystemPoll(absl::Span<pollfd> poll_fds, Time deadline_ns) {
- int timeout;
- if (deadline == InfinitePast()) {
- // Don't block.
- timeout = 0;
- } else if (deadline == InfiniteFuture()) {
- // Block forever.
- timeout = -1;
- } else {
- absl::Duration remaining_time = deadline - Now();
- if (remaining_time < absl::ZeroDuration()) {
- return DeadlineExceededErrorBuilder(IREE_LOC);
- }
- timeout = static_cast<int>(absl::ToInt64Milliseconds(remaining_time));
- }
- return Syscall(::poll, poll_fds.data(), poll_fds.size(), timeout);
-}
-
-#else
-#error "No SystemPoll implementation"
-#endif // IREE_HAS_PPOLL / IREE_HAS_POLL / etc
-
-// Builds the list of pollfds to for ppoll wait on and will perform any
-// required wait handle callbacks.
-//
-// The provided deadline will be observed if any of the wait handles needs to
-// block for acquiring an fd.
-StatusOr<absl::FixedArray<pollfd>> AcquireWaitHandles(
- WaitHandle::WaitHandleSpan wait_handles, Time deadline_ns) {
- absl::FixedArray<pollfd> poll_fds{wait_handles.size()};
- for (int i = 0; i < wait_handles.size(); ++i) {
- poll_fds[i].events = POLLIN | POLLPRI | POLLERR | POLLHUP | POLLNVAL;
- poll_fds[i].revents = 0;
- // NOTE: poll will ignore any negative fds and our kInvalidFd == -1 so we
- // can still put them in the list and it'll just skip them.
- if (!wait_handles[i] || !wait_handles[i]->object()) {
- poll_fds[i].fd = kInvalidFd;
- continue;
- }
-
- // Acquire the file descriptor for waiting.
- // This may block (if |deadline| allows it) if the fd is not yet available.
- // This is like a pre-wait for the actual poll operation. It can be bad with
- // WaitAny, though we could handle that better here.
- IREE_ASSIGN_OR_RETURN(
- auto fd_info, wait_handles[i]->object()->AcquireFdForWait(deadline));
- poll_fds[i].fd = fd_info.second;
-
- // Abort if deadline exceeded.
- if (deadline != InfinitePast() && deadline < Now()) {
- return DeadlineExceededErrorBuilder(IREE_LOC)
- << "Deadline exceeded acquiring for fds";
- }
- }
- return poll_fds;
-}
-
-Status ClearFd(WaitableObject::FdType fd_type, int fd) {
- // Read in a loop until the read would block.
- // Depending on how the users setup the fd the act of reading may reset the
- // entire handle (such as with the default eventfd mode) or multiple reads
- // may be required (such as with semaphores).
- while (true) {
-#if defined(IREE_HAS_EVENTFD)
- eventfd_t val = 0;
- int rv = ::eventfd_read(fd, &val);
-#elif defined(IREE_HAS_PIPE)
- char buf;
- int rv = ::read(fd, &buf, 1);
-#else
- return UnimplementedErrorBuilder(IREE_LOC) << "fd_type cannot be cleared";
-#endif // IREE_HAS_EVENTFD
- if (rv != -1) {
- // Success! Keep going.
- continue;
- } else {
- if (errno == EWOULDBLOCK) {
- // The read would have blocked meaning that we've hit the end and
- // successfully cleared the fd.
- return OkStatus();
- } else if (errno == EINTR) {
- // Retry.
- continue;
- } else {
- return ErrnoToCanonicalStatusBuilder(errno, IREE_LOC)
- << "ClearFd failed";
- }
- }
- }
-}
-
-// Performs a single poll on multiple fds and returns information about the
-// signaled fds, if any.
-Status MultiPoll(WaitHandle::WaitHandleSpan wait_handles,
- absl::Span<pollfd> poll_fds, Time deadline_ns,
- int* out_any_signaled_index, int* out_unsignaled_count) {
- *out_any_signaled_index = -1;
- *out_unsignaled_count = 0;
-
- // poll has a nasty behavior where it allows -1 for fds... except for at [0].
- // To keep the rest of the code sane we correct for that here as epoll doesn't
- // have that behavior and we may want to special case this later.
- bool any_valid_fds = true;
- int swapped_zero_index = -1;
- if (poll_fds[0].fd < 0) {
- // Find a valid handle.
- for (int i = 1; i < poll_fds.size(); ++i) {
- if (poll_fds[i].fd > 0) {
- swapped_zero_index = i;
- std::swap(poll_fds[0], poll_fds[i]);
- break;
- }
- }
- if (swapped_zero_index == -1) {
- // No valid handles found, meaning that all handles are invalid.
- // We'll skip the wait below so we can share the processing code for any
- // fds that may be kSignaledFd.
- any_valid_fds = false;
- }
- }
-
- // Pass handles to ppoll.
- // http://man7.org/linux/man-pages/man2/poll.2.html
- if (any_valid_fds) {
- IREE_ASSIGN_OR_RETURN(int rv, SystemPoll(poll_fds, deadline));
- if (rv == 0) {
- // Call timed out and no descriptors were ready.
- // If this was just a poll then that's fine.
- return DeadlineExceededErrorBuilder(IREE_LOC);
- }
- }
-
- // If we had swapped fds[0] above we need to correct for that now.
- if (swapped_zero_index != -1) {
- std::swap(poll_fds[0], poll_fds[swapped_zero_index]);
- }
-
- // |rv| denotes the number of fds that were ready. Run through the list and
- // find the ones that were ready and mark them as completed.
- for (int i = 0; i < poll_fds.size(); ++i) {
- if (poll_fds[i].fd == kSignaledFd || poll_fds[i].revents == POLLIN) {
- // First attempt any resolve actions. If these fail we can't consider the
- // fd as having been signaled.
- IREE_ASSIGN_OR_RETURN(
- bool resolved,
- wait_handles[i]->object()->TryResolveWakeOnFd(poll_fds[i].fd));
- if (!resolved) {
- ++(*out_unsignaled_count);
- continue;
- }
-
- // Successful wait. Kill the fd so it is ignored on the next poll.
- poll_fds[i].fd = kInvalidFd;
- *out_any_signaled_index = i;
- } else if (poll_fds[i].revents) {
- if (poll_fds[i].revents & POLLERR) {
- return InternalErrorBuilder(IREE_LOC);
- } else if (poll_fds[i].revents & POLLHUP) {
- return CancelledErrorBuilder(IREE_LOC);
- } else if (poll_fds[i].revents & POLLNVAL) {
- return InvalidArgumentErrorBuilder(IREE_LOC);
- } else {
- return UnknownErrorBuilder(IREE_LOC);
- }
- } else if (poll_fds[i].fd != kInvalidFd) {
- ++(*out_unsignaled_count);
- }
- }
-
- return OkStatus();
-}
-
-} // namespace
-
-// static
-std::atomic<uint64_t> WaitHandle::next_unique_id_{1};
-
-// static
-WaitHandle WaitHandle::AlwaysSignaling() {
- class AlwaysSignalingObject : public WaitableObject {
- public:
- std::string DebugString() const override { return "signal"; }
- StatusOr<std::pair<FdType, int>> AcquireFdForWait(
- Time deadline_ns) override {
- return std::make_pair(FdType::kPermanent, kSignaledFd);
- }
- StatusOr<bool> TryResolveWakeOnFd(int fd) override { return true; }
- };
- static auto* obj = new AlwaysSignalingObject();
- return WaitHandle(add_ref(obj));
-}
-
-// static
-WaitHandle WaitHandle::AlwaysFailing() {
- class AlwaysFailingObject : public WaitableObject {
- public:
- std::string DebugString() const override { return "fail"; }
- StatusOr<std::pair<FdType, int>> AcquireFdForWait(
- Time deadline_ns) override {
- return InternalErrorBuilder(IREE_LOC) << "AlwaysFailingObject";
- }
- StatusOr<bool> TryResolveWakeOnFd(int fd) override {
- return InternalErrorBuilder(IREE_LOC) << "AlwaysFailingObject";
- }
- };
- static auto* obj = new AlwaysFailingObject();
- return WaitHandle(add_ref(obj));
-}
-
-// static
-Status WaitHandle::WaitAll(WaitHandleSpan wait_handles, Time deadline_ns) {
- if (wait_handles.empty()) return OkStatus();
-
- // Build the list of pollfds to wait on.
- IREE_ASSIGN_OR_RETURN(auto poll_fds,
- AcquireWaitHandles(wait_handles, deadline));
-
- // Loop until all handles have been signaled or the deadline is exceeded.
- int unsignaled_count = 0;
- do {
- int any_signaled_index = 0;
- IREE_RETURN_IF_ERROR(MultiPoll(wait_handles, absl::MakeSpan(poll_fds),
- deadline, &any_signaled_index,
- &unsignaled_count));
- } while (unsignaled_count > 0 && Now() < deadline);
-
- if (unsignaled_count == 0) {
- // All waits resolved.
- return OkStatus();
- } else {
- // One or more were unsignaled.
- return DeadlineExceededErrorBuilder(IREE_LOC);
- }
-}
-
-// static
-StatusOr<bool> WaitHandle::TryWaitAll(WaitHandleSpan wait_handles) {
- auto status = WaitAll(wait_handles, InfinitePast());
- if (status.ok()) {
- return true;
- } else if (IsDeadlineExceeded(status)) {
- return false;
- }
- return status;
-}
-
-// static
-StatusOr<int> WaitHandle::WaitAny(WaitHandleSpan wait_handles,
- Time deadline_ns) {
- if (wait_handles.empty()) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "At least one wait handle is required for WaitAny";
- }
-
- // Build the list of pollfds to wait on.
- IREE_ASSIGN_OR_RETURN(auto poll_fds,
- AcquireWaitHandles(wait_handles, deadline));
-
- // Poll once; this makes a WaitAny just a WaitMulti that doesn't loop.
- int any_signaled_index = -1;
- int unsignaled_count = 0;
- IREE_RETURN_IF_ERROR(MultiPoll(wait_handles, absl::MakeSpan(poll_fds),
- deadline, &any_signaled_index,
- &unsignaled_count));
- if (any_signaled_index == -1) {
- // No wait handles were valid. Pretend 0 was signaled.
- return 0;
- }
- return any_signaled_index;
-}
-
-// static
-StatusOr<int> WaitHandle::TryWaitAny(WaitHandleSpan wait_handles) {
- auto status_or = WaitAny(wait_handles, InfinitePast());
- return IsDeadlineExceeded(status_or.status()) ? -1 : status_or;
-}
-
-// Storage for static class variables; these won't be needed when we can use
-// c++17 everywhere.
-constexpr int WaitableObject::kInvalidFd;
-constexpr int WaitableObject::kSignaledFd;
-
-WaitHandle::WaitHandle(ref_ptr<WaitableObject> object)
- : unique_id_(++next_unique_id_), object_(std::move(object)) {}
-
-WaitHandle::~WaitHandle() { Dispose(); }
-
-void WaitHandle::Dispose() { object_.reset(); }
-
-WaitHandle::WaitHandle(WaitHandle&& other)
- : unique_id_(other.unique_id_), object_(std::move(other.object_)) {
- other.unique_id_ = 0;
-}
-
-WaitHandle& WaitHandle::operator=(WaitHandle&& other) {
- if (this != std::addressof(other)) {
- // Close current handle.
- Dispose();
-
- // Take ownership of handle and resources.
- object_ = std::move(other.object_);
-
- other.unique_id_ = ++next_unique_id_;
- }
- return *this;
-}
-
-std::string WaitHandle::DebugString() const {
- return object_ ? object_->DebugString() : absl::StrCat("wh_", unique_id_);
-}
-
-StatusOr<bool> WaitHandle::TryWait() {
- auto status = WaitAll({this}, InfinitePast());
- if (status.ok()) {
- return true;
- } else if (IsDeadlineExceeded(status)) {
- return false;
- }
- return status;
-}
-
-ManualResetEvent::ManualResetEvent(const char* debug_name)
- : debug_name_(debug_name) {
- Initialize();
-}
-
-ManualResetEvent::~ManualResetEvent() { Dispose(); }
-
-void ManualResetEvent::Initialize() {
-#if defined(IREE_HAS_EVENTFD)
- // Create with an eventfd by default when we support it.
- // eventfd has lower overhead than pipes (the syscalls are cheap).
- // This usually will only fail if the system is completely out of handles.
- //
- // Docs: http://man7.org/linux/man-pages/man2/eventfd.2.html
- fd_type_ = FdType::kEventFd;
- fd_ = Syscall(::eventfd, 0, EFD_CLOEXEC | EFD_NONBLOCK).value();
-#elif defined(IREE_HAS_PIPE)
- // Android/Linux/iOS-compatible POSIX pipe handle.
- // Two handles are generated: one for transmitting and one for receiving.
- //
- // Docs: http://man7.org/linux/man-pages/man2/pipe.2.html
- fd_type_ = FdType::kPipe;
- int pipefd[2];
- Syscall(::pipe, pipefd).value();
- Syscall(::fcntl, pipefd[0], F_SETFL, O_NONBLOCK).value();
- fd_ = pipefd[0];
- write_fd_ = pipefd[1];
-#else
-// NOTE: sync_file does not use Notifier as they come from the kernel.
-#error "No fd-based sync primitive on this platform"
-#endif // IREE_HAS_EVENTFD / IREE_HAS_PIPE / etc
-}
-
-void ManualResetEvent::Dispose() {
- if (fd_ != kInvalidFd) {
- // Always signal, as we need to ensure waiters are woken.
- IREE_CHECK_OK(Set());
- Syscall(::close, fd_).value();
- fd_ = kInvalidFd;
- }
- if (write_fd_ != kInvalidFd) {
- Syscall(::close, write_fd_).value();
- write_fd_ = kInvalidFd;
- }
-}
-
-ManualResetEvent::ManualResetEvent(ManualResetEvent&& other)
- : fd_type_(other.fd_type_),
- fd_(other.fd_),
- write_fd_(other.write_fd_),
- debug_name_(other.debug_name_) {
- other.fd_type_ = FdType::kPermanent;
- other.fd_ = kInvalidFd;
- other.write_fd_ = kInvalidFd;
- other.debug_name_ = nullptr;
-}
-
-ManualResetEvent& ManualResetEvent::operator=(ManualResetEvent&& other) {
- if (this != std::addressof(other)) {
- Dispose();
- fd_type_ = other.fd_type_;
- fd_ = other.fd_;
- write_fd_ = other.write_fd_;
- debug_name_ = other.debug_name_;
- other.fd_type_ = FdType::kPermanent;
- other.fd_ = kInvalidFd;
- other.write_fd_ = kInvalidFd;
- other.debug_name_ = nullptr;
- other.Initialize();
- }
- return *this;
-}
-
-std::string ManualResetEvent::DebugString() const {
- if (debug_name_) {
- return debug_name_;
- }
-#if defined(IREE_HAS_EVENTFD)
- return absl::StrCat("eventfd_", fd_);
-#elif defined(IREE_HAS_PIPE)
- return absl::StrCat("pipe_", fd_, "_", write_fd_);
-#else
- return absl::StrCat("unknown_", fd_, "_", write_fd_);
-#endif // IREE_HAS_EVENTFD / IREE_HAS_PIPE
-}
-
-Status ManualResetEvent::Set() {
-#if defined(IREE_HAS_EVENTFD)
- return Syscall(::eventfd_write, fd_, 1ull).status();
-#elif defined(IREE_HAS_PIPE)
- char buf = '\n';
- return Syscall(::write, write_fd_, &buf, 1).status();
-#else
- return UnimplementedErrorBuilder(IREE_LOC)
- << "No fd-based sync primitive on this platform";
-#endif // IREE_HAS_EVENTFD / IREE_HAS_PIPE
-}
-
-Status ManualResetEvent::Reset() { return ClearFd(fd_type_, fd_); }
-
-WaitHandle ManualResetEvent::OnSet() { return WaitHandle(add_ref(this)); }
-
-} // namespace iree
diff --git a/iree/base/wait_handle.h b/iree/base/wait_handle.h
index 7d9a4af..b95bf7e 100644
--- a/iree/base/wait_handle.h
+++ b/iree/base/wait_handle.h
@@ -1,4 +1,4 @@
-// Copyright 2019 Google LLC
+// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -15,303 +15,297 @@
#ifndef IREE_BASE_WAIT_HANDLE_H_
#define IREE_BASE_WAIT_HANDLE_H_
-#include <atomic>
-#include <cstdint>
-#include <string>
-#include <utility>
+#include "iree/base/api.h"
+#include "iree/base/target_platform.h"
-#include "absl/types/span.h"
-#include "iree/base/ref_ptr.h"
-#include "iree/base/status.h"
-#include "iree/base/time.h"
+#if defined(IREE_PLATFORM_WINDOWS)
+// Though Windows can support pipes no one uses them so for simplicity we only
+// expose HANDLEs.
+#define IREE_HAVE_WAIT_TYPE_WIN32_HANDLE 1
+#elif defined(IREE_PLATFORM_ANDROID) || defined(IREE_PLATFORM_LINUX)
+// Treat Android and modern linux as (mostly) the same.
+#define IREE_HAVE_WAIT_TYPE_EVENTFD 1
+#define IREE_HAVE_WAIT_TYPE_PIPE 1
+#else
+// BSD/Darwin/etc all have pipe.
+#define IREE_HAVE_WAIT_TYPE_PIPE 1
+#endif // IREE_PLATFORM_*
-namespace iree {
+// TODO(benvanik): see if we can get sync file on linux too:
+#if defined(IREE_PLATFORM_ANDROID)
+#define IREE_HAVE_WAIT_TYPE_SYNC_FILE 1
+#endif // IREE_PLATFORM_ANDROID
-// Interfaces for waitable objects that can produce WaitHandles.
-// WaitableObjects are much like ::thread::Selectable, only they support both
-// the classic locking style as well as file descriptors for use with select().
-//
-// Usage:
-// class MyWaitableObject : public WaitableObject {
-// public:
-// std::string DebugString() const override { return "something useful"; }
-// WaitHandle OnAsyncTask() {
-// return WaitHandle(retain_ref(this));
-// }
-// private:
-// StatusOr<std::pair<FdType, int>> AcquireFdForWait(
-// Time deadline_ns) override {
-// // If blocking traditionally do so now and then return this:
-// return std::make_pair(FdType::kPermanent, kSignaledFd);
-// // Otherwise, see ManualResetEvent for an example using fds.
-// }
-// StatusOr<bool> TryResolveWakeOnFd(int fd) override {
-// // Return true iff the object is really acquired, such as the semaphore
-// // being decremented.
-// return true;
-// }
-// };
-class WaitableObject : public RefObject<WaitableObject> {
- public:
- // Indicates that a file descriptor is invalid. It will not block when waited
- // upon.
- constexpr static int kInvalidFd = -1;
- // Indicates that a file descriptor should be treated as signaled.
- // Waiting on this fd should return as if it has already been signaled.
- constexpr static int kSignaledFd = -2;
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
- // Defines the type of the native handle used for synchronization.
- enum class FdType : uint16_t {
- // Event has no handle and should be treated as permanently signaled.
- kPermanent,
+//===----------------------------------------------------------------------===//
+// iree_wait_primitive_*
+//===----------------------------------------------------------------------===//
- // Android/Linux/iOS-compatible POSIX pipe handle.
- // Two handles are generated: one for transmitting and one for receiving.
- //
- // More information:
- // http://man7.org/linux/man-pages/man2/pipe.2.html
- kPipe,
+// TODO(benvanik): conditionally compile out enum values unavailable (to avoid
+// runtime surprises).
- // Android/Linux eventfd handle.
- // These are akin to pipe() but require only a single handle and have
- // significantly lower overhead (equivalent if not slightly better than
- // pthreads condvars).
- //
- // eventfds support acting as both semaphores and auto reset events.
- //
- // More information:
- // http://man7.org/linux/man-pages/man2/eventfd.2.html
- kEventFd,
-
- // Android/Linux sync_file handle (aka 'sync fence').
- // The handle is allocated indirectly by the device driver via the
- // <linux/sync_file.h> API. It may be waited upon with poll(), select(), or
- // epoll() and must be closed with close() when no longer required. If
- // waiting on multiple sync_files the caller should first merge them
- // together.
- //
- // A sync_file must only be used as fences (one-shot manual reset events).
- //
- // More information:
- // https://www.kernel.org/doc/Documentation/sync_file.txt
- // https://lwn.net/Articles/702339/
- // https://source.android.com/devices/graphics/implement-vsync#explicit_synchronization
- kSyncFile,
- };
-
- virtual ~WaitableObject() = default;
-
- // Returns a string representing the object, either specified as a debug_name
- // or a unique ID.
- virtual std::string DebugString() const = 0;
-
- // Attempts to acquire a file descriptor for the waitable objects by the given
- // |deadline|. In many cases this will return immediately with a valid fd.
+// Specifies the type of a wait handle.
+enum iree_wait_primitive_type_e {
+ // Android/Linux eventfd handle.
+ // These are akin to pipe() but require only a single handle and have
+ // significantly lower overhead (equivalent if not slightly better than
+ // pthreads condvars).
//
- // In cases where the file descriptor may not be available the call may block
- // until either it is available or the |deadline| has elapsed. Use
- // InfinitePast() to prevent blocking.
+ // eventfds support acting as both semaphores and auto reset events.
//
- // Returns a valid file descriptor or kInvalidFd as an indication that the
- // object should not be waited on (already signaled, etc). Can return
- // kSignaledFd to indicate that it's already known that the handle has been
- // signaled and the caller should resolve as if it caused a wake normally.
- virtual StatusOr<std::pair<FdType, int>> AcquireFdForWait(
- Time deadline_ns) = 0;
+ // More information:
+ // http://man7.org/linux/man-pages/man2/eventfd.2.html
+ IREE_WAIT_PRIMITIVE_TYPE_EVENT_FD = 1u,
- // Tries to resolve the object with the given |fd|.
- // In many cases this will no-op, however some types may require additional
- // checks to ensure that the wait operation succeeded (such as semaphores
- // that may need to query a count). If resolution fails the waitable object
- // must not be considered signaled. This call will never block.
- virtual StatusOr<bool> TryResolveWakeOnFd(int fd) = 0;
+ // Android/Linux sync_file handle (aka 'sync fence').
+ // The handle is allocated indirectly by the device driver via the
+ // <linux/sync_file.h> API. It may be waited upon with poll(), select(), or
+ // epoll() and must be closed with close() when no longer required. If
+ // waiting on multiple sync_files the caller should first merge them
+ // together.
+ //
+ // A sync_file must only be used as fences (one-shot manual reset events).
+ //
+ // More information:
+ // https://www.kernel.org/doc/Documentation/sync_file.txt
+ // https://lwn.net/Articles/702339/
+ // https://source.android.com/devices/graphics/implement-vsync#explicit_synchronization
+ // https://developer.android.com/ndk/reference/group/sync
+ IREE_WAIT_PRIMITIVE_TYPE_SYNC_FILE = 2u,
+
+ // Android/Linux/iOS-compatible POSIX pipe handle.
+ // Two handles are generated: one for transmitting and one for receiving.
+ //
+ // More information:
+ // http://man7.org/linux/man-pages/man2/pipe.2.html
+ IREE_WAIT_PRIMITIVE_TYPE_PIPE = 3u,
+
+ // Windows HANDLE type.
+ // The HANDLE may represent a thread, event, semaphore, timer, etc.
+ //
+ // More information:
+ // https://docs.microsoft.com/en-us/windows/win32/sysinfo/object-categories
+ // https://docs.microsoft.com/en-us/windows/win32/sync/using-event-objects
+ IREE_WAIT_PRIMITIVE_TYPE_WIN32_HANDLE = 4u,
};
+typedef uint8_t iree_wait_primitive_type_t;
-// Handle to waitable objects.
-// WaitHandles are created by a particular synchronization primitive, such as
-// Fence, as a way for one or more observers to poll or wait for notification.
+// A handle value whose behavior is defined by the iree_wait_primitive_type_t.
+typedef union {
+#if defined(IREE_HAVE_WAIT_TYPE_EVENTFD)
+ // IREE_WAIT_PRIMITIVE_TYPE_EVENT_FD
+ struct {
+ int fd;
+ } event;
+#endif // IREE_HAVE_WAIT_TYPE_EVENTFD
+#if defined(IREE_HAVE_WAIT_TYPE_SYNC_FILE)
+ // IREE_WAIT_PRIMITIVE_TYPE_SYNC_FILE
+ struct {
+ int fd;
+ } sync_file;
+#endif // IREE_HAVE_WAIT_TYPE_SYNC_FILE
+#if defined(IREE_HAVE_WAIT_TYPE_PIPE)
+ // IREE_WAIT_PRIMITIVE_TYPE_PIPE
+ union {
+ struct {
+ int read_fd;  // Shares storage with fds[0].
+ int write_fd;  // Shares storage with fds[1].
+ };
+ int fds[2];  // Array view of the same two descriptors.
+ } pipe;
+#endif // IREE_HAVE_WAIT_TYPE_PIPE
+#if defined(IREE_HAVE_WAIT_TYPE_WIN32_HANDLE)
+ // IREE_WAIT_PRIMITIVE_TYPE_WIN32_HANDLE
+ struct {
+ uintptr_t handle;  // NOTE(review): presumably a HANDLE held as an integer - confirm.
+ } win32;
+#endif // IREE_HAVE_WAIT_TYPE_WIN32_HANDLE
+} iree_wait_primitive_value_t;
+
+//===----------------------------------------------------------------------===//
+// iree_wait_handle_t
+//===----------------------------------------------------------------------===//
+
+// Non-owning handle reference to a waitable object.
+// TODO(benvanik): packing to ensure we are getting the expected alignments.
+typedef struct {
+ iree_wait_primitive_type_t type; // uint8_t
+ union {
+ // Used by iree_wait_set_t storage to track the number of duplicate
+ // instances of a particular handle within the set to avoid needing to store
+ // them all separately. A dupe_count of 0 means there is one unique handle.
+ uint32_t dupe_count : 16;
+ // Used by iree_wait_any and iree_wait_set_erase to optimize the
+ // wait-wake-erase pattern by avoiding the need to scan the internal storage
+ // list to erase a handle.
+ uint32_t index : 16;
+ // (3 bytes total available)
+ uint8_t storage[3];  // NOTE(review): uint32_t bit-fields above likely size/align this union to 4 bytes - confirm.
+ } set_internal;
+ iree_wait_primitive_value_t value;
+} iree_wait_handle_t;
+
+// Initializes a wait handle with the given primitive type and value.
+// Wait handles do not retain the provided primitives and they must be kept
+// valid (allocated and open) for the duration any wait handle references them.
+iree_status_t iree_wait_handle_wrap_primitive(
+ iree_wait_primitive_type_t primitive_type,
+ iree_wait_primitive_value_t primitive_value,
+ iree_wait_handle_t* out_handle);
+
+// Deinitializes a wait handle.
+// Note that wait handles do not retain the underlying wait primitive and
+// deinitializing a handle will not close the resource.
+void iree_wait_handle_deinitialize(iree_wait_handle_t* handle);
+
+//===----------------------------------------------------------------------===//
+// iree_wait_set_t
+//===----------------------------------------------------------------------===//
+
+// A platform-specific cache of wait handles that can be multi-waited.
+// By caching callers don't need to build the list each wait and implementations
+// can store acceleration information or kernel API data structures and either
+// optimize or make compliant sets such as by deduplicating or sorting by
+// primitive type to perform a multi-api multi-wait.
//
-// External synchronization primitives can be wrapped in WaitHandles to enable
-// other libraries or languages to be waited on alongside WaitHandles created
-// by the IREE primitives like Fence. See the notes on WaitHandleType for a list
-// of handle types that are supported.
+// Certain handle types may also gain benefits: when syncfile is used we can use
+// sync_merge to coalesce wait handles when performing a wait-all on multiple
+// handles.
//
-// Wait handles are thread-safe in that multiple threads may be waiting on them
-// concurrently.
-class WaitHandle {
- public:
- // Returns a WaitHandle that when waited on will never block.
- static WaitHandle AlwaysSignaling();
+// This cache shines when handles are persistent (such as sockets/eventfds/etc)
+// and the set will rarely be changing relative to how many times it will be
+// waited on. It's not as optimal in the cases of one-shot waits on small
+// numbers of handles but those are also the cases where the set overhead is
+// small (2 set insertions all touching hot cache lines is fine) and we gain
+// the benefits of a unified code path and nice error handling/validation.
+//
+// Thread-compatible; only one thread may be manipulating or waiting on a
+// particular set at any time.
+typedef struct iree_wait_set_s iree_wait_set_t;
- // Returns a WaitHandle that when waited on will always fail.
- static WaitHandle AlwaysFailing();
+// Allocates a wait set with the maximum |capacity| of unique handles.
+iree_status_t iree_wait_set_allocate(iree_host_size_t capacity,
+ iree_allocator_t allocator,
+ iree_wait_set_t** out_set);
- using WaitHandleSpan = absl::Span<WaitHandle* const>;
+// Frees a wait set. The wait set must not be being waited on.
+void iree_wait_set_free(iree_wait_set_t* set);
- // Blocks the caller until all passed |wait_handles| are signaled or the
- // |deadline| elapses.
- //
- // Returns success if the wait is successful and all events have been
- // signaled.
- //
- // Returns DEADLINE_EXCEEDED if the |deadline| elapses without all handles
- // having been signaled. Note that a subset of the |wait_handles| may have
- // been signaled and each can be queried to see which one.
- static Status WaitAll(WaitHandleSpan wait_handles, Time deadline_ns);
- static Status WaitAll(WaitHandleSpan wait_handles, Duration timeout_ns) {
- return WaitAll(wait_handles, RelativeTimeoutToDeadlineNanos(timeout_ns));
- }
- static Status WaitAll(WaitHandleSpan wait_handles) {
- return WaitAll(wait_handles, InfiniteFuture());
- }
+// Inserts a wait handle into the set.
+// If the handle is already in the set it will be reference counted such that a
+// matching number of iree_wait_set_erase calls are required.
+iree_status_t iree_wait_set_insert(iree_wait_set_t* set,
+ iree_wait_handle_t handle);
- // Tries waiting on the handles and returns immediately if it would have
- // blocked. The caller will not be blocked even if a handle has not yet been
- // signaled.
- //
- // Returns true if all handles have been signaled.
- static StatusOr<bool> TryWaitAll(WaitHandleSpan wait_handles);
+// Erases a wait handle from the set.
+// Decrements the reference count; if the same handle was inserted multiple
+// times then it may still remain in the set after an erase!
+void iree_wait_set_erase(iree_wait_set_t* set, iree_wait_handle_t handle);
- // Blocks the caller until at least one of the |wait_handles| is signaled or
- // the |deadline| elapses.
- //
- // Returns the index into |wait_handles| of a handle that was signaled. Note
- // that more than one handle may have been signaled and all of the other
- // |wait_handles| should be queried or waited on again until waits for them
- // succeed.
- //
- // Returns DEADLINE_EXCEEDED if the |deadline| elapses without any handles
- // having been signaled.
- static StatusOr<int> WaitAny(WaitHandleSpan wait_handles, Time deadline_ns);
- static StatusOr<int> WaitAny(WaitHandleSpan wait_handles,
- Duration timeout_ns) {
- return WaitAny(wait_handles, RelativeTimeoutToDeadlineNanos(timeout_ns));
- }
- static StatusOr<int> WaitAny(WaitHandleSpan wait_handles) {
- return WaitAny(wait_handles, InfiniteFuture());
- }
+// Clears all handles from the wait set.
+void iree_wait_set_clear(iree_wait_set_t* set);
- // Tries waiting for at least one handle to complete and returns immediately
- // if none have been. The caller will not be blocked even if a handle has not
- // yet been signaled.
- //
- // Returns the index into |wait_handles| of a handle that was signaled. Note
- // that more than one handle may have been signaled and all of the other
- // |wait_handles| should be queried or waited on again until waits for them
- // succeed.
- //
- // Returns -1 if no handles were signaled.
- static StatusOr<int> TryWaitAny(WaitHandleSpan wait_handles);
+// TODO(benvanik): signal/interrupt API to make a wait set wake up.
+// Can be implemented with signals/QueueUserAPC/etc. The workaround is that the
+// caller will need to create their own events to add to the set where for
+// transient wakes we could avoid that extra overhead.
- // Default constructor creates a permanently signaled handle.
- // Waiting on this handle will never block.
- WaitHandle() = default;
+// Blocks the caller until all of the passed wait handles are signaled or the
+// |deadline_ns| elapses.
+//
+// A deadline of IREE_TIME_INFINITE_PAST will act as a poll and not block the
+// caller. IREE_TIME_INFINITE_FUTURE can be used to block until signaled.
+//
+// Returns success if all handles were signaled either prior to the call or
+// during the wait.
+//
+// Returns IREE_STATUS_DEADLINE_EXCEEDED if the deadline elapses without all
+// handles having been signaled. Note that zero or more handles may have
+// actually signaled even if the deadline is exceeded (such as if they signal
+// while the waiting thread is resuming from the failed wait).
+//
+// iree_wait_set_t is thread-compatible; only one thread may be manipulating or
+// waiting on a set at any time.
+iree_status_t iree_wait_all(iree_wait_set_t* set, iree_time_t deadline_ns);
- // Wraps an existing sync file descriptor.
- // Ownership of the file descriptor is transferred to the WaitHandle and must
- // be duplicated by the caller if they want to continue using it.
- explicit WaitHandle(ref_ptr<WaitableObject> object);
+// Blocks the caller until at least one of the handles is signaled or the
+// |deadline_ns| elapses.
+//
+// A deadline of IREE_TIME_INFINITE_PAST will act as a poll and not block the
+// caller. IREE_TIME_INFINITE_FUTURE can be used to block until signaled.
+//
+// Returns success if any handle was signaled either prior to the call or
+// during the wait. One of the signaled handles will be returned in
+// the optional |out_wake_handle| argument; note however that one or more
+// handles may have signaled and which handle is returned is unspecified.
+// Callers are expected to use the handle to short-circuit scanning the handles
+// list but if a full scan is going to happen regardless it can be ignored.
+//
+// |out_wake_handle| contains an optimization for wait-wake-erase set
+// operations; it is cheap to pass the woken handle to iree_wait_set_erase if
+// there are no interleaving operations that change the set layout.
+//
+// Returns IREE_STATUS_DEADLINE_EXCEEDED if the deadline elapses without any
+// handle having been signaled.
+//
+// iree_wait_set_t is thread-compatible; only one thread may be manipulating or
+// waiting on a set at any time.
+iree_status_t iree_wait_any(iree_wait_set_t* set, iree_time_t deadline_ns,
+ iree_wait_handle_t* out_wake_handle);
- ~WaitHandle();
+// Blocks the caller until the given wait handle is signaled or |deadline_ns|
+// elapses. This is functionally equivalent to iree_wait_any/iree_wait_all used
+// on a set with a single handle in it but depending on the implementation may
+// not require additional allocations/state tracking.
+//
+// A deadline of IREE_TIME_INFINITE_PAST will act as a poll and not block the
+// caller. IREE_TIME_INFINITE_FUTURE can be used to block until signaled.
+//
+// Returns success if the handle was signaled either prior to the call or
+// during the wait.
+//
+// Returns IREE_STATUS_DEADLINE_EXCEEDED if the deadline elapses without the
+// handle having been signaled.
+iree_status_t iree_wait_one(iree_wait_handle_t* handle,
+ iree_time_t deadline_ns);
- // Copying not supported. Create a new WaitHandle from the source.
- WaitHandle(const WaitHandle&) = delete;
- WaitHandle& operator=(const WaitHandle&) = delete;
+//===----------------------------------------------------------------------===//
+// iree_event_t
+//===----------------------------------------------------------------------===//
- // Moving supported; sync primitive ownership is transferred.
- WaitHandle(WaitHandle&& other);
- WaitHandle& operator=(WaitHandle&& other);
+// A manual reset event (aka binary semaphore).
+// https://docs.microsoft.com/en-us/windows/win32/sync/event-objects
+//
+// Events are much heavier than iree_notification_t but are waitable objects
+// that can be passed to iree_wait_all/iree_wait_any. Prefer iree_notification_t
+// when multiwaiting is not required.
+//
+// Which primitive is used will depend on the current platform.
+typedef iree_wait_handle_t iree_event_t;
- // Unique ID for the WaitHandle instance.
- // Two wait handles, even if waiting on the same underlying primitive, will
- // have differing unique_ids. This can be used for deduping the handles or
- // storing handles in a map.
- uint64_t unique_id() const { return unique_id_; }
+// Initializes an event in either the signaled or unsignaled state.
+// The event must be closed with iree_event_deinitialize.
+iree_status_t iree_event_initialize(bool initial_state,
+ iree_event_t* out_event);
- // Returns a unique string representing the handle.
- std::string DebugString() const;
+// Deinitializes an event.
+void iree_event_deinitialize(iree_event_t* event);
- // Blocks the caller until the handle is signaled or the |deadline| elapses.
- //
- // If waiting on multiple wait handles use WaitAll or WaitAny instead of
- // multiple calls to Wait as they can significantly reduce overhead.
- //
- // Returns success if the wait is successful and the |wait_handle| was
- // signaled. Returns DEADLINE_EXCEEDED if the timeout elapses without the
- // handle having been signaled.
- Status Wait(Time deadline_ns) { return WaitAll({this}, deadline); }
- Status Wait(Duration timeout_ns) {
- return WaitAll({this}, RelativeTimeoutToDeadlineNanos(timeout_ns));
- }
- Status Wait() { return WaitAll({this}, InfiniteFuture()); }
+// Sets the event object to the signaled state.
+// The event stays signaled until iree_event_reset is called. Multiple waiters
+// will be woken and attempted waits while the event is set will succeed
+// immediately.
+void iree_event_set(iree_event_t* event);
- // Tries waiting on the handle and returns immediately if it would have
- // waited. The caller will not be blocked even if the handle has not yet been
- // signaled.
- //
- // Returns true if the handle has been signaled.
- StatusOr<bool> TryWait();
+// Resets the event object to the unsignaled state.
+// Resetting an event that is already reset has no effect.
+void iree_event_reset(iree_event_t* event);
- // These accessors should generally be considered opaque but may be useful to
- // code trying to interop with other runtimes.
- const ref_ptr<WaitableObject>& object() const { return object_; }
-
- private:
- // Disposes the handle by closing the fd and issuing callbacks.
- void Dispose();
-
- static std::atomic<uint64_t> next_unique_id_;
-
- uint64_t unique_id_ = 0;
- ref_ptr<WaitableObject> object_;
-};
-
-// A manually-resettable event primitive.
-// Effectively a binary semaphore with a maximum_count of 1 when running in
-// auto-reset mode but also provides a sticky manual reset mode.
-class ManualResetEvent : public WaitableObject {
- public:
- explicit ManualResetEvent(const char* debug_name = nullptr);
-
- ~ManualResetEvent() override;
-
- // Copying not supported.
- ManualResetEvent(const ManualResetEvent&) = delete;
- ManualResetEvent& operator=(const ManualResetEvent&) = delete;
-
- // Moving supported; sync primitive ownership is transferred.
- ManualResetEvent(ManualResetEvent&& other);
- ManualResetEvent& operator=(ManualResetEvent&& other);
-
- std::string DebugString() const override;
-
- // Sets the specified event object to the signaled state.
- // The event stays signaled until Reset is called. Multiple waiters will be
- // woken.
- Status Set();
-
- // Resets the specified event object to the nonsignaled state.
- // Resetting an event that is already reset has no effect.
- Status Reset();
-
- // Returns a WaitHandle that will be signaled when the event is set.
- WaitHandle OnSet();
-
- protected:
- void Initialize();
- void Dispose();
-
- StatusOr<std::pair<FdType, int>> AcquireFdForWait(Time deadline_ns) override {
- return std::make_pair(fd_type_, fd_);
- }
- StatusOr<bool> TryResolveWakeOnFd(int fd) override { return true; }
-
- FdType fd_type_ = FdType::kPermanent;
- int fd_ = kInvalidFd;
- int write_fd_ = kInvalidFd; // Used only for fd_type_ == kPipe.
- const char* debug_name_ = nullptr;
-};
-
-} // namespace iree
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
#endif // IREE_BASE_WAIT_HANDLE_H_
diff --git a/iree/base/wait_handle_epoll.c b/iree/base/wait_handle_epoll.c
new file mode 100644
index 0000000..1249d99
--- /dev/null
+++ b/iree/base/wait_handle_epoll.c
@@ -0,0 +1,74 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// NOTE: must be first to ensure that we can define settings for all includes.
+#include "iree/base/wait_handle_impl.h"
+
+#if IREE_WAIT_API == IREE_WAIT_API_EPOLL
+
+#include "iree/base/tracing.h"
+#include "iree/base/wait_handle_posix.h"
+
+//===----------------------------------------------------------------------===//
+// iree_wait_set_t
+//===----------------------------------------------------------------------===//
+
+// TODO(benvanik): iree_wait_set_s using an epoll fd.
+// epoll lets us route the wait set operations right to kernel and not need our
+// own duplicate data structure. epoll is great, just not available on mac/ios
+// so we still need poll for that. linux/android have epoll (BSDs use kqueue).
+struct iree_wait_set_s {
+ // NOTE: we could in theory use the epoll handle directly (iree_wait_set_s
+ // then is just a pointer). Then allocate/free just go straight to the system.
+ int reserved;  // Placeholder; not read by any of the stubs in this file.
+};
+
+iree_status_t iree_wait_set_allocate(iree_host_size_t capacity,
+                                     iree_allocator_t allocator,
+                                     iree_wait_set_t** out_set) {
+  return iree_status_from_code(IREE_STATUS_UNIMPLEMENTED);  // TODO(benvanik): epoll_create()
+}
+
+void iree_wait_set_free(iree_wait_set_t* set) {
+ // TODO(benvanik): close() the epoll fd; currently a no-op.
+}
+
+iree_status_t iree_wait_set_insert(iree_wait_set_t* set,
+                                   iree_wait_handle_t handle) {
+  return iree_status_from_code(IREE_STATUS_UNIMPLEMENTED);  // TODO(benvanik): epoll_ctl(EPOLL_CTL_ADD)
+}
+
+void iree_wait_set_erase(iree_wait_set_t* set, iree_wait_handle_t handle) {
+ // TODO(benvanik): epoll_ctl(EPOLL_CTL_DEL); currently a no-op.
+}
+
+void iree_wait_set_clear(iree_wait_set_t* set) {
+ // TODO(benvanik): close and reopen the epoll fd? Currently a no-op.
+}
+
+iree_status_t iree_wait_all(iree_wait_set_t* set, iree_time_t deadline_ns) {
+  return iree_status_from_code(IREE_STATUS_UNIMPLEMENTED);  // TODO(benvanik): epoll_wait
+}
+
+iree_status_t iree_wait_any(iree_wait_set_t* set, iree_time_t deadline_ns,
+                            iree_wait_handle_t* out_wake_handle) {
+  return iree_status_from_code(IREE_STATUS_UNIMPLEMENTED);  // TODO(benvanik): epoll_wait
+}
+
+iree_status_t iree_wait_one(iree_wait_handle_t* handle,
+                            iree_time_t deadline_ns) {
+  return iree_status_from_code(IREE_STATUS_UNIMPLEMENTED);  // TODO(benvanik): just use poll?
+}
+
+#endif // IREE_WAIT_API == IREE_WAIT_API_EPOLL
diff --git a/iree/base/wait_handle_impl.h b/iree/base/wait_handle_impl.h
new file mode 100644
index 0000000..38f82e5
--- /dev/null
+++ b/iree/base/wait_handle_impl.h
@@ -0,0 +1,69 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_BASE_WAIT_HANDLE_IMPL_H_
+#define IREE_BASE_WAIT_HANDLE_IMPL_H_
+
+//===----------------------------------------------------------------------===//
+// Platform overrides
+//===----------------------------------------------------------------------===//
+// NOTE: this must come first prior to any local/system includes!
+
+// Ensure that any posix header we include exposes GNU stuff. Ignored on
+// platforms where we either don't have the GNU stuff or don't have posix
+// headers at all.
+//
+// Note that this does not need to be the same for all compilation units, only
+// those we want to access the non-portable features in. It *must* be defined
+// prior to including any of the files, though, as otherwise header-guards will
+// cause the setting at the time of first inclusion to win.
+//
+// https://stackoverflow.com/a/5583764
+#define _GNU_SOURCE 1
+
+//===----------------------------------------------------------------------===//
+// Active wait API implementation selection (wait_handle_*.c)
+//===----------------------------------------------------------------------===//
+
+#include "iree/base/target_platform.h"
+
+// Priorities are (kqueue|epoll) > ppoll > poll; only POLL/PPOLL are picked below.
+#define IREE_WAIT_API_POLL 1
+#define IREE_WAIT_API_PPOLL 2
+#define IREE_WAIT_API_EPOLL 3
+#define IREE_WAIT_API_KQUEUE 4
+
+// NOTE: we could be tighter here, but we today only have win32 or not-win32.
+#if defined(IREE_PLATFORM_WINDOWS)
+#define IREE_WAIT_API 0 // WFMO used in wait_handle_win32.c
+#else
+
+// TODO(benvanik): EPOLL on android/linux/bsd/etc.
+// TODO(benvanik): KQUEUE on mac/ios.
+// KQUEUE is not implemented yet. Use POLL for mac/ios (and Emscripten).
+#if !defined(IREE_PLATFORM_APPLE) && !defined(__EMSCRIPTEN__)
+#define IREE_WAIT_API IREE_WAIT_API_PPOLL
+#else
+#define IREE_WAIT_API IREE_WAIT_API_POLL
+#endif // !IREE_PLATFORM_APPLE && !__EMSCRIPTEN__
+
+#endif // IREE_PLATFORM_WINDOWS
+
+//===----------------------------------------------------------------------===//
+// Wait handle included with options set
+//===----------------------------------------------------------------------===//
+
+#include "iree/base/wait_handle.h"
+
+#endif // IREE_BASE_WAIT_HANDLE_IMPL_H_
diff --git a/iree/base/wait_handle_kqueue.c b/iree/base/wait_handle_kqueue.c
new file mode 100644
index 0000000..5e03f2c
--- /dev/null
+++ b/iree/base/wait_handle_kqueue.c
@@ -0,0 +1,71 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// NOTE: must be first to ensure that we can define settings for all includes.
+#include "iree/base/wait_handle_impl.h"
+
+#if IREE_WAIT_API == IREE_WAIT_API_KQUEUE
+
+#include "iree/base/tracing.h"
+#include "iree/base/wait_handle_posix.h"
+
+//===----------------------------------------------------------------------===//
+// iree_wait_set_t
+//===----------------------------------------------------------------------===//
+
+// TODO(benvanik): iree_wait_set_s using a kqueue.
+// Could just cast the kqueue() fd to iree_wait_set_s* to avoid allocs.
+// https://developer.apple.com/library/archive/documentation/System/Conceptual/ManPages_iPhoneOS/man2/kqueue.2.html
+struct iree_wait_set_s {
+ int reserved;  // Placeholder; not read by any of the stubs in this file.
+};
+
+iree_status_t iree_wait_set_allocate(iree_host_size_t capacity,
+                                     iree_allocator_t allocator,
+                                     iree_wait_set_t** out_set) {
+  return iree_status_from_code(IREE_STATUS_UNIMPLEMENTED);  // TODO(benvanik): kqueue support
+}
+
+void iree_wait_set_free(iree_wait_set_t* set) {
+ // TODO(benvanik): close() the kqueue fd; currently a no-op.
+}
+
+iree_status_t iree_wait_set_insert(iree_wait_set_t* set,
+                                   iree_wait_handle_t handle) {
+  return iree_status_from_code(IREE_STATUS_UNIMPLEMENTED);  // TODO(benvanik): kqueue support
+}
+
+void iree_wait_set_erase(iree_wait_set_t* set, iree_wait_handle_t handle) {
+ // TODO(benvanik): kqueue support; currently a no-op.
+}
+
+void iree_wait_set_clear(iree_wait_set_t* set) {
+ // TODO(benvanik): kqueue support; currently a no-op.
+}
+
+iree_status_t iree_wait_all(iree_wait_set_t* set, iree_time_t deadline_ns) {
+  return iree_status_from_code(IREE_STATUS_UNIMPLEMENTED);  // TODO(benvanik): kqueue support
+}
+
+iree_status_t iree_wait_any(iree_wait_set_t* set, iree_time_t deadline_ns,
+                            iree_wait_handle_t* out_wake_handle) {
+  return iree_status_from_code(IREE_STATUS_UNIMPLEMENTED);  // TODO(benvanik): kqueue support
+}
+
+iree_status_t iree_wait_one(iree_wait_handle_t* handle,
+                            iree_time_t deadline_ns) {
+  return iree_status_from_code(IREE_STATUS_UNIMPLEMENTED);  // TODO(benvanik): kqueue support
+}
+
+#endif // IREE_WAIT_API == IREE_WAIT_API_KQUEUE
diff --git a/iree/base/wait_handle_poll.c b/iree/base/wait_handle_poll.c
new file mode 100644
index 0000000..2ba2eb1
--- /dev/null
+++ b/iree/base/wait_handle_poll.c
@@ -0,0 +1,406 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// NOTE: must be first to ensure that we can define settings for all includes.
+#include "iree/base/wait_handle_impl.h"
+
+#if IREE_WAIT_API == IREE_WAIT_API_POLL || IREE_WAIT_API == IREE_WAIT_API_PPOLL
+
+#include <errno.h>
+#include <poll.h>
+#include <time.h>
+
+#include "iree/base/tracing.h"
+#include "iree/base/wait_handle_posix.h"
+
+//===----------------------------------------------------------------------===//
+// Platform utilities
+//===----------------------------------------------------------------------===//
+
+// ppoll is preferred as it has a much better timing mechanism; poll can have a
+// large slop on the deadline as not only is it at ms timeout granularity but
+// in general tends to round more.
+//
+// poll/ppoll may spuriously wake with an EINTR. We don't do anything with that
+// opportunity (no fancy signal stuff), but we do need to retry the poll and
+// ensure that we do so with an updated timeout based on the deadline.
+//
+// Documentation: https://linux.die.net/man/2/poll
+
+#if IREE_WAIT_API == IREE_WAIT_API_POLL
+static iree_status_t iree_syscall_poll(struct pollfd* fds, nfds_t nfds,
+                                       iree_time_t deadline_ns,
+                                       int* out_signaled_count) {
+  // poll() with EINTR retry; |out_signaled_count| receives fds with events.
+  *out_signaled_count = 0;
+  int rv = -1;
+  do {
+    // Recomputed each retry; <0 blocks forever and 0 polls without blocking.
+    int timeout_ms = -1;
+    if (deadline_ns != IREE_TIME_INFINITE_FUTURE) {
+      iree_duration_t ns = iree_absolute_deadline_to_timeout_ns(deadline_ns);
+      // Round up to whole ms (never wake early) and clamp to the int range.
+      int64_t ms = ns <= 0 ? 0 : (ns - 1) / 1000000 + 1;
+      timeout_ms = ms > INT32_MAX ? INT32_MAX : (int)ms;
+    }
+    rv = poll(fds, nfds, timeout_ms);
+  } while (rv < 0 && errno == EINTR);
+  if (rv > 0) {
+    *out_signaled_count = rv;  // One or more events set.
+    return iree_ok_status();
+  }
+  if (rv == 0) return iree_status_from_code(IREE_STATUS_DEADLINE_EXCEEDED);
+  return iree_make_status(iree_status_code_from_errno(errno),
+                          "poll failure %d", errno);
+}
+#elif IREE_WAIT_API == IREE_WAIT_API_PPOLL
+static iree_status_t iree_syscall_poll(struct pollfd* fds, nfds_t nfds,
+ iree_time_t deadline_ns,
+ int* out_signaled_count) {
+ *out_signaled_count = 0;
+ int rv = -1;
+ do {
+ // Convert the deadline into a tmo_p struct for ppoll that controls whether
+ // the call is blocking or non-blocking. Note that we must do this every
+ // iteration of the loop as a previous ppoll may have taken some of the
+ // time.
+ //
+ // See the ppoll docs for more information as to what the expected value is:
+ // https://man7.org/linux/man-pages/man2/poll.2.html
+ struct timespec timeout_ts;
+ struct timespec* tmo_p = &timeout_ts;
+ if (deadline_ns == IREE_TIME_INFINITE_PAST) {
+ // Block never: a zeroed timeout makes ppoll return immediately.
+ memset(&timeout_ts, 0, sizeof(timeout_ts));
+ } else if (deadline_ns == IREE_TIME_INFINITE_FUTURE) {
+ // Block forever (NULL timeout to ppoll).
+ tmo_p = NULL;
+ } else {
+ // Wait only for as much time as we have before the deadline is exceeded.
+ iree_duration_t timeout_ns = deadline_ns - iree_time_now();
+ if (timeout_ns < 0) {
+ // We've reached the deadline; we'll still perform the poll though as
+ // the caller is likely expecting that behavior (intentional context
+ // switch/thread yield/etc).
+ memset(&timeout_ts, 0, sizeof(timeout_ts));
+ } else {
+ timeout_ts.tv_sec = (time_t)(timeout_ns / 1000000000ull);
+ timeout_ts.tv_nsec = (long)(timeout_ns % 1000000000ull);
+ }
+ }
+ rv = ppoll(fds, nfds, tmo_p, NULL);  // NULL sigmask: no signal mask swap.
+ } while (rv < 0 && errno == EINTR);  // Retry when interrupted by a signal.
+ if (rv > 0) {
+ // One or more events set; rv is passed back as the signaled count.
+ *out_signaled_count = rv;
+ return iree_ok_status();
+ } else if (rv == 0) {
+ // Timeout; no events set.
+ return iree_status_from_code(IREE_STATUS_DEADLINE_EXCEEDED);
+ } else {  // rv < 0 with an errno other than EINTR.
+ return iree_make_status(iree_status_code_from_errno(errno),
+ "ppoll failure %d", errno);
+ }
+}
+#else
+#error "unsupported IREE_WAIT_API value"
+#endif // IREE_WAIT_API
+
+//===----------------------------------------------------------------------===//
+// iree_wait_set_t
+//===----------------------------------------------------------------------===//
+
+struct iree_wait_set_s {
+ iree_allocator_t allocator;  // Allocated the set; also used to free it.
+
+ // Total capacity of each handle list.
+ iree_host_size_t handle_capacity;
+
+ // Total number of valid user_handles/poll_fds.
+ iree_host_size_t handle_count;
+
+ // User-provided handles; entry i corresponds to poll_fds[i].
+ // We only really need to track these so that we can preserve the handle
+ // types; we could either just do that (a few bytes) or keep them here as-is
+ // where they are a bit easier to debug.
+ iree_wait_handle_t* user_handles;
+
+ // Native list of fds+req we can pass to poll/ppoll/etc and that will receive
+ // the output information like which events were triggered during the wait.
+ //
+ // pollfd::events is specified when the fds are added to the set and then each
+ // wait pollfd::revents is modified during the poll syscall.
+ struct pollfd* poll_fds;
+};
+
+iree_status_t iree_wait_set_allocate(iree_host_size_t capacity,
+ iree_allocator_t allocator,
+ iree_wait_set_t** out_set) {
+ // Be reasonable; 64K objects is too high (even if poll supports it, which is
+ // hard to tell if it does). set_internal.index is also only 16 bits wide.
+ if (capacity >= UINT16_MAX) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "wait set capacity of %zu is unreasonably large",
+ capacity);
+ }
+ // The set header and both lists share one allocation: [set][handles][fds].
+ iree_host_size_t user_handle_list_size =
+ capacity * sizeof(iree_wait_handle_t);
+ iree_host_size_t poll_fd_list_size = capacity * sizeof(struct pollfd);
+ iree_host_size_t total_size =
+ sizeof(iree_wait_set_t) + user_handle_list_size + poll_fd_list_size;
+
+ iree_wait_set_t* set = NULL;
+ IREE_RETURN_IF_ERROR(
+ iree_allocator_malloc(allocator, total_size, (void**)&set));
+ set->allocator = allocator;
+ set->handle_capacity = capacity;
+ iree_wait_set_clear(set);  // Starts empty (handle_count = 0).
+ // Carve the two lists out of the trailing storage of the allocation.
+ set->user_handles =
+ (iree_wait_handle_t*)((uint8_t*)set + sizeof(iree_wait_set_t));
+ set->poll_fds =
+ (struct pollfd*)((uint8_t*)set->user_handles + user_handle_list_size);
+
+ *out_set = set;
+ return iree_ok_status();
+}
+
+void iree_wait_set_free(iree_wait_set_t* set) {
+ iree_allocator_free(set->allocator, set);  // Also releases both lists.
+}
+
+iree_status_t iree_wait_set_insert(iree_wait_set_t* set,
+                                   iree_wait_handle_t handle) {
+  if (set->handle_count + 1 > set->handle_capacity) {
+    return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
+                            "wait set capacity reached");
+  }
+
+  // Fill in the next free slot but only commit it (by bumping handle_count)
+  // once the handle has been accepted so a failure leaves the set unchanged.
+  // NOTE: a handle inserted twice currently occupies two separate slots.
+  iree_host_size_t index = set->handle_count;
+  IREE_RETURN_IF_ERROR(iree_wait_handle_wrap_primitive(
+      handle.type, handle.value, &set->user_handles[index]));
+  // NOTE: poll will ignore any negative fds.
+  struct pollfd* poll_fd = &set->poll_fds[index];
+  poll_fd->fd = iree_wait_primitive_get_read_fd(&handle);
+  poll_fd->events = POLLIN | POLLPRI;  // implicit POLLERR | POLLHUP | POLLNVAL
+  poll_fd->revents = 0;
+  ++set->handle_count;
+  return iree_ok_status();
+}
+
+void iree_wait_set_erase(iree_wait_set_t* set, iree_wait_handle_t handle) {
+  // Find the user handle in the set. If valid we can use the native index set
+  // after an iree_wait_any wake for a quick lookup; otherwise this requires a
+  // linear scan of (hopefully) a small list.
+  iree_host_size_t index = handle.set_internal.index;
+  if (IREE_UNLIKELY(index >= set->handle_count) ||
+      IREE_UNLIKELY(!iree_wait_primitive_compare_identical(
+          &set->user_handles[index], &handle))) {
+    index = set->handle_count;  // doubles as the "not found" sentinel
+    for (iree_host_size_t i = 0; i < set->handle_count; ++i) {
+      if (iree_wait_primitive_compare_identical(&set->user_handles[i],
+                                                &handle)) {
+        index = i;
+        break;
+      }
+    }
+    // Not in the set (or the set is empty): erasing must be a no-op rather
+    // than evicting an unrelated handle or underflowing handle_count.
+    if (index == set->handle_count) return;
+  }
+
+  // Remove from both handle lists. Since we make no guarantees about the
+  // order of the lists we can just move the last entry into the vacated slot.
+  iree_host_size_t tail_index = set->handle_count - 1;
+  if (tail_index > index) {
+    set->poll_fds[index] = set->poll_fds[tail_index];
+    set->user_handles[index] = set->user_handles[tail_index];
+  }
+  --set->handle_count;
+}
+// Empties the set by resetting the count; no storage is touched or released.
+void iree_wait_set_clear(iree_wait_set_t* set) { set->handle_count = 0; }
+
+// Maps poll revents to a status. POLLERR/POLLHUP/POLLNVAL fail and leave
+// |out_signaled| untouched; otherwise it reports whether POLLIN was set.
+static iree_status_t iree_wait_set_resolve_poll_events(short revents,
+ bool* out_signaled) {
+ if (revents & POLLERR) {
+ return iree_make_status(IREE_STATUS_INTERNAL, "POLLERR on fd");
+ } else if (revents & POLLHUP) {
+ return iree_make_status(IREE_STATUS_CANCELLED, "POLLHUP on fd");
+ } else if (revents & POLLNVAL) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "POLLNVAL on fd");
+ }
+ *out_signaled = (revents & POLLIN) != 0;  // POLLPRI alone does not count.
+ return iree_ok_status();
+}
+
+iree_status_t iree_wait_all(iree_wait_set_t* set, iree_time_t deadline_ns) {
+  // Make the syscall only when we have at least one valid fd.
+  // Don't use this as a sleep.
+  if (set->handle_count <= 0) {
+    return iree_ok_status();
+  }
+
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  // TODO(benvanik): see if we can use tracy's mutex tracking to make waits
+  // nicer (at least showing signal->wait relations).
+
+  // poll ignores entries with a negative fd. Wait-all needs to repeatedly poll
+  // until every handle has signaled, so to reduce overhead (and not miss
+  // events) each entry that resolves is neutered by setting its fd to -1 so
+  // that the kernel skips it on the next iteration.
+  //
+  // Note that the commonly suggested trick of negating the fd cannot be used:
+  // https://manpages.debian.org/buster/manpages-dev/poll.2.en.html
+  //   This provides an easy way of ignoring a file descriptor for a single
+  //   poll() call: simply negate the fd field. Note, however, that this
+  //   technique can't be used to ignore file descriptor 0.
+  // (-0 == 0, so an entry for fd 0 would never be ignored.)
+  struct pollfd* poll_fd_base = set->poll_fds;
+  nfds_t poll_fd_count = set->handle_count;
+
+  // Only entries that poll will actually look at can ever signal; entries
+  // inserted with a negative fd are ignored by the kernel and not waited on.
+  int unsignaled_count = 0;
+  for (nfds_t i = 0; i < poll_fd_count; ++i) {
+    if (poll_fd_base[i].fd >= 0) ++unsignaled_count;
+  }
+
+  iree_status_t status = iree_ok_status();
+  while (iree_status_is_ok(status) && unsignaled_count > 0) {
+    // Skip leading ignored entries so the list always starts at a live fd.
+    while (poll_fd_count > 0 && poll_fd_base[0].fd < 0) {
+      ++poll_fd_base;
+      --poll_fd_count;
+    }
+
+    int signaled_count = 0;
+    status = iree_syscall_poll(poll_fd_base, poll_fd_count, deadline_ns,
+                               &signaled_count);
+    if (!iree_status_is_ok(status)) {
+      // Deadline exceeded or the poll itself failed. Fall through so that the
+      // poll_fds list is restored below.
+      break;
+    }
+
+    // Neuter any that have successfully resolved.
+    for (nfds_t i = 0; i < poll_fd_count; ++i) {
+      if (poll_fd_base[i].fd < 0) continue;
+      bool signaled = false;
+      status =
+          iree_wait_set_resolve_poll_events(poll_fd_base[i].revents, &signaled);
+      if (!iree_status_is_ok(status)) {
+        // One (or more) fds had an issue; the outer loop condition stops the
+        // wait and the failure is returned after the list is restored.
+        break;
+      }
+      if (signaled) {
+        // Count only what is neutered so that an fd is never counted twice.
+        poll_fd_base[i].fd = -1;
+        --unsignaled_count;
+      }
+    }
+  }
+
+  // Since we destroyed the list of fds during the operation we need to
+  // restore it so that the next wait can happen. The fds are re-derived from
+  // the user handles instead of being flipped back so that entries that were
+  // never neutered (timeout/failure) are left exactly as they were.
+  for (iree_host_size_t i = 0; i < set->handle_count; ++i) {
+    set->poll_fds[i].fd =
+        iree_wait_primitive_get_read_fd(&set->user_handles[i]);
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+// Waits until at least one handle in |set| is signaled and copies it into
+// |out_wake_handle| (with set_internal.index recording its position in the
+// set). |out_wake_handle| is zeroed when the set is empty or when no signaled
+// handle is found. Failures from iree_syscall_poll, or from an fd reporting an
+// error condition before the first signaled entry, are returned as-is.
+iree_status_t iree_wait_any(iree_wait_set_t* set, iree_time_t deadline_ns,
+                            iree_wait_handle_t* out_wake_handle) {
+  // Nothing to wait on: skip the syscall entirely (an empty set must not be
+  // usable as a sleep) and report an empty wake handle.
+  if (set->handle_count <= 0) {
+    memset(out_wake_handle, 0, sizeof(*out_wake_handle));
+    return iree_ok_status();
+  }
+
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  // TODO(benvanik): see if we can use tracy's mutex tracking to make waits
+  // nicer (at least showing signal->wait relations).
+
+  // A single poll across the whole set is enough for wait-any; there is no
+  // need to track which handles remain unsignaled.
+  int ready_count = 0;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_syscall_poll(set->poll_fds, set->handle_count, deadline_ns,
+                            &ready_count));
+
+  // Start from an empty wake handle and fill it in with the first signaled
+  // entry, scanning in set order.
+  memset(out_wake_handle, 0, sizeof(*out_wake_handle));
+  iree_host_size_t index = 0;
+  while (ready_count > 0 && index < set->handle_count) {
+    bool is_signaled = false;
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_wait_set_resolve_poll_events(set->poll_fds[index].revents,
+                                              &is_signaled));
+    if (is_signaled) {
+      memcpy(out_wake_handle, &set->user_handles[index],
+             sizeof(*out_wake_handle));
+      out_wake_handle->set_internal.index = index;
+      break;
+    }
+    ++index;
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return iree_ok_status();
+}
+
+// Waits on a single |handle| until it is signaled.
+//
+// Returns OK when the handle is signaled (or has no readable fd to wait on),
+// IREE_STATUS_DEADLINE_EXCEEDED when the poll completes without the handle
+// being signaled, and a failure status if the poll fails or the fd reports
+// POLLERR/POLLHUP/POLLNVAL.
+iree_status_t iree_wait_one(iree_wait_handle_t* handle,
+                            iree_time_t deadline_ns) {
+  struct pollfd poll_fds;
+  poll_fds.fd = iree_wait_primitive_get_read_fd(handle);
+  // No fd to wait on. Return an explicit OK status rather than a bare boolean
+  // so that the result is well-typed.
+  if (poll_fds.fd == -1) return iree_ok_status();
+  poll_fds.events = POLLIN;
+  poll_fds.revents = 0;
+
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  // TODO(benvanik): see if we can use tracy's mutex tracking to make waits
+  // nicer (at least showing signal->wait relations).
+
+  // Just check for our single handle/event.
+  // The benefit of this is that we didn't need to heap alloc the pollfds and
+  // the cache should all stay hot. Reusing the same iree_syscall_poll as the
+  // multi-wait variants ensures consistent handling (and the same syscall
+  // showing in strace/tracy/etc).
+  int signaled_count = 0;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_syscall_poll(&poll_fds, 1, deadline_ns, &signaled_count));
+
+  // poll() also counts fds that only reported an error condition, so resolve
+  // the returned events the same way the multi-wait variants do instead of
+  // treating any non-zero count as a successful signal.
+  bool signaled = false;
+  if (signaled_count > 0) {
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_wait_set_resolve_poll_events(poll_fds.revents, &signaled));
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return signaled ? iree_ok_status()
+                  : iree_status_from_code(IREE_STATUS_DEADLINE_EXCEEDED);
+}
+
+#endif // IREE_WAIT_API == IREE_WAIT_API_POLL ||
+ // IREE_WAIT_API == IREE_WAIT_API_PPOLL
diff --git a/iree/base/wait_handle_posix.c b/iree/base/wait_handle_posix.c
new file mode 100644
index 0000000..088f6cc
--- /dev/null
+++ b/iree/base/wait_handle_posix.c
@@ -0,0 +1,289 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/base/wait_handle_posix.h"
+
+#include "iree/base/tracing.h"
+
+// NOTE: we could be tighter here, but we today only have win32 or not-win32.
+#if !defined(IREE_PLATFORM_WINDOWS)
+
+#include <errno.h>
+#include <fcntl.h>
+#include <unistd.h>
+
+#if defined(IREE_HAVE_WAIT_TYPE_EVENTFD)
+#include <sys/eventfd.h>
+#endif // IREE_HAVE_WAIT_TYPE_EVENTFD
+#if defined(IREE_HAVE_WAIT_TYPE_SYNC_FILE)
+#include <android/sync.h>
+#endif // IREE_HAVE_WAIT_TYPE_SYNC_FILE
+
+//===----------------------------------------------------------------------===//
+// iree_wait_primitive_* raw calls
+//===----------------------------------------------------------------------===//
+
+#if defined(IREE_HAVE_WAIT_TYPE_EVENTFD)
+// Initializes |out_handle| as an eventfd-backed wait handle. The eventfd is
+// created close-on-exec and non-blocking with an initial counter of 1 when
+// |initial_state| is true and 0 otherwise. Returns a status derived from errno
+// if eventfd() fails.
+static iree_status_t iree_wait_primitive_create_eventfd(
+    bool initial_state, iree_wait_handle_t* out_handle) {
+  memset(out_handle, 0, sizeof(*out_handle));
+  out_handle->type = IREE_WAIT_PRIMITIVE_TYPE_EVENT_FD;
+
+  // https://man7.org/linux/man-pages/man2/eventfd.2.html
+  const unsigned int initial_count = initial_state ? 1 : 0;
+  out_handle->value.event.fd =
+      eventfd(initial_count, EFD_CLOEXEC | EFD_NONBLOCK);
+  if (IREE_UNLIKELY(out_handle->value.event.fd == -1)) {
+    return iree_make_status(iree_status_code_from_errno(errno),
+                            "failed to create eventfd (%d)", errno);
+  }
+
+  return iree_ok_status();
+}
+#endif // IREE_HAVE_WAIT_TYPE_EVENTFD
+
+#if defined(IREE_HAVE_WAIT_TYPE_PIPE)
+// Initializes |out_handle| as a pipe-backed wait handle with both ends set to
+// non-blocking. When |initial_state| is true a single nonce is written so that
+// the handle starts out signaled.
+//
+// Returns a status derived from errno if pipe() or fcntl() fails, or the
+// failure from the initial write. Once the pipe has been created any later
+// failure closes both fds before returning so that they do not leak.
+static iree_status_t iree_wait_primitive_create_pipe(
+    bool initial_state, iree_wait_handle_t* out_handle) {
+  memset(out_handle, 0, sizeof(*out_handle));
+  out_handle->type = IREE_WAIT_PRIMITIVE_TYPE_PIPE;
+
+  // Create read (fds[0]) and write (fds[1]) handles.
+  // https://man7.org/linux/man-pages/man2/pipe.2.html
+  if (IREE_UNLIKELY(pipe(out_handle->value.pipe.fds) < 0)) {
+    return iree_make_status(iree_status_code_from_errno(errno),
+                            "failed to create pipe (%d)", errno);
+  }
+
+  // Set both fds to non-blocking.
+  // NOTE: we could use pipe2 when available on linux to avoid the need for the
+  // fcntl, but BSD/darwin/etc don't have it so we'd still need a fallback. This
+  // is effectively the same as passing O_NONBLOCK to pipe2.
+  for (int i = 0; i < 2; ++i) {
+    if (IREE_UNLIKELY(
+            fcntl(out_handle->value.pipe.fds[i], F_SETFL, O_NONBLOCK) < 0)) {
+      // Snapshot errno first: closing the fds below may overwrite it.
+      const int fcntl_errno = errno;
+      // Both ends were already opened by pipe(); release them so that the
+      // failure does not leak file descriptors.
+      iree_wait_primitive_close(out_handle);
+      return iree_make_status(iree_status_code_from_errno(fcntl_errno),
+                              "failed to set pipe fd %d to non-blocking (%d)",
+                              i, fcntl_errno);
+    }
+  }
+
+  // Initially triggered means we just write once to the pipe.
+  // This write must not fail as if the caller requested the state they would
+  // likely deadlock if the first read would block.
+  if (initial_state) {
+    iree_status_t status = iree_wait_primitive_write(out_handle);
+    if (!iree_status_is_ok(status)) {
+      iree_wait_primitive_close(out_handle);
+      return status;
+    }
+  }
+
+  return iree_ok_status();
+}
+#endif // IREE_HAVE_WAIT_TYPE_PIPE
+
+// Creates a wait primitive using the best type compiled in for this platform:
+// eventfd when available, otherwise a pipe. Returns UNIMPLEMENTED when neither
+// type is enabled.
+iree_status_t iree_wait_primitive_create_native(
+    bool initial_state, iree_wait_handle_t* out_handle) {
+  iree_status_t status;
+#if defined(IREE_HAVE_WAIT_TYPE_EVENTFD)
+  // Always prefer eventfd when present; they rock.
+  status = iree_wait_primitive_create_eventfd(initial_state, out_handle);
+#elif defined(IREE_HAVE_WAIT_TYPE_PIPE)
+  // Pipes are fine but much heavier than eventfds.
+  status = iree_wait_primitive_create_pipe(initial_state, out_handle);
+#else
+  status = iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                            "no native wait handle type supported");
+#endif  // IREE_HAVE_WAIT_TYPE_*
+  return status;
+}
+
+void iree_wait_primitive_close_fd(int fd) {
+ int rv;
+ IREE_SYSCALL(rv, close(fd));
+ // NOTE: we could fail to close if the handle is invalid/already closed/etc.
+ // As Windows has undefined behavior when handles are closed while there are
+ // active waits we don't use fd closes as load-bearing operations and it's
+ // fine to ignore the error.
+}
+
+// Closes the fd(s) backing |handle| according to its primitive type and then
+// deinitializes the handle. Types not compiled in (or not recognized) have
+// nothing closed but are still deinitialized.
+void iree_wait_primitive_close(iree_wait_handle_t* handle) {
+  switch (handle->type) {
+#if defined(IREE_HAVE_WAIT_TYPE_EVENTFD)
+    case IREE_WAIT_PRIMITIVE_TYPE_EVENT_FD:
+      iree_wait_primitive_close_fd(handle->value.event.fd);
+      break;
+#endif  // IREE_HAVE_WAIT_TYPE_EVENTFD
+#if defined(IREE_HAVE_WAIT_TYPE_SYNC_FILE)
+    case IREE_WAIT_PRIMITIVE_TYPE_SYNC_FILE:
+      iree_wait_primitive_close_fd(handle->value.sync_file.fd);
+      break;
+#endif  // IREE_HAVE_WAIT_TYPE_SYNC_FILE
+#if defined(IREE_HAVE_WAIT_TYPE_PIPE)
+    case IREE_WAIT_PRIMITIVE_TYPE_PIPE:
+      // A pipe has two ends; release the read side first, then the write side.
+      iree_wait_primitive_close_fd(handle->value.pipe.read_fd);
+      iree_wait_primitive_close_fd(handle->value.pipe.write_fd);
+      break;
+#endif  // IREE_HAVE_WAIT_TYPE_PIPE
+    default:
+      // Nothing to close for any other type.
+      break;
+  }
+  iree_wait_handle_deinitialize(handle);
+}
+
+// Returns true when |lhs| and |rhs| have the same primitive type and their
+// value storage compares bytewise equal.
+bool iree_wait_primitive_compare_identical(const iree_wait_handle_t* lhs,
+                                           const iree_wait_handle_t* rhs) {
+  if (lhs->type != rhs->type) return false;
+  const int difference = memcmp(&lhs->value, &rhs->value, sizeof(lhs->value));
+  return difference == 0;
+}
+
+// Returns the fd that can be read/waited on for |handle| based on its
+// primitive type, or -1 when the type is not one of the compiled-in fd-backed
+// types.
+int iree_wait_primitive_get_read_fd(const iree_wait_handle_t* handle) {
+  int read_fd = -1;
+  switch (handle->type) {
+#if defined(IREE_HAVE_WAIT_TYPE_EVENTFD)
+    case IREE_WAIT_PRIMITIVE_TYPE_EVENT_FD:
+      read_fd = handle->value.event.fd;
+      break;
+#endif  // IREE_HAVE_WAIT_TYPE_EVENTFD
+#if defined(IREE_HAVE_WAIT_TYPE_SYNC_FILE)
+    case IREE_WAIT_PRIMITIVE_TYPE_SYNC_FILE:
+      read_fd = handle->value.sync_file.fd;
+      break;
+#endif  // IREE_HAVE_WAIT_TYPE_SYNC_FILE
+#if defined(IREE_HAVE_WAIT_TYPE_PIPE)
+    case IREE_WAIT_PRIMITIVE_TYPE_PIPE:
+      read_fd = handle->value.pipe.read_fd;
+      break;
+#endif  // IREE_HAVE_WAIT_TYPE_PIPE
+    default:
+      break;
+  }
+  return read_fd;
+}
+
+// Performs a single non-blocking read of one nonce from |handle|.
+//
+// Only polling is supported: |deadline_ns| must be IREE_TIME_INFINITE_PAST or
+// UNIMPLEMENTED is returned. Returns OK when a value was read,
+// IREE_STATUS_DEADLINE_EXCEEDED (without a full status payload) when the read
+// would have blocked because nothing is pending, UNIMPLEMENTED for sync files
+// and unhandled types, and an errno-derived failure otherwise.
+iree_status_t iree_wait_primitive_read(iree_wait_handle_t* handle,
+                                       iree_time_t deadline_ns) {
+  // Until we need it this does not support anything but polling.
+  // If we want to support auto reset events we'd want to implement blocking.
+  if (deadline_ns != IREE_TIME_INFINITE_PAST) {
+    return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                            "reads are just polls today");
+  }
+
+  int rv = -1;
+  switch (handle->type) {
+#if defined(IREE_HAVE_WAIT_TYPE_EVENTFD)
+    case IREE_WAIT_PRIMITIVE_TYPE_EVENT_FD: {
+      eventfd_t val = 0;
+      IREE_SYSCALL(rv, eventfd_read(handle->value.event.fd, &val));
+      break;
+    }
+#endif  // IREE_HAVE_WAIT_TYPE_EVENTFD
+#if defined(IREE_HAVE_WAIT_TYPE_SYNC_FILE)
+    case IREE_WAIT_PRIMITIVE_TYPE_SYNC_FILE:
+      return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                              "sync files not yet implemented");
+#endif  // IREE_HAVE_WAIT_TYPE_SYNC_FILE
+#if defined(IREE_HAVE_WAIT_TYPE_PIPE)
+    case IREE_WAIT_PRIMITIVE_TYPE_PIPE: {
+      char buf;
+      IREE_SYSCALL(rv, read(handle->value.pipe.read_fd, &buf, 1));
+      break;
+    }
+#endif  // IREE_HAVE_WAIT_TYPE_PIPE
+    default:
+      return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                              "unhandled wait type %d", (int)handle->type);
+  }
+  if (rv >= 0) {
+    // Read completed successfully.
+    return iree_ok_status();
+  }
+
+  // Snapshot errno so that the classification and the message below agree.
+  const int read_errno = errno;
+
+  // A non-blocking read with nothing pending reports EAGAIN. EWOULDBLOCK is
+  // allowed to be a distinct value on some platforms so accept either; the
+  // preprocessor guard avoids a duplicate comparison where they are equal.
+  bool would_block = (read_errno == EAGAIN);
+#if defined(EWOULDBLOCK) && (EWOULDBLOCK != EAGAIN)
+  would_block = would_block || (read_errno == EWOULDBLOCK);
+#endif  // EWOULDBLOCK != EAGAIN
+  if (would_block) {
+    // Would have blocked meaning that there's no data waiting.
+    // NOTE: we purposefully avoid a full status result here as this is a
+    // non-exceptional result.
+    return iree_status_from_code(IREE_STATUS_DEADLINE_EXCEEDED);
+  }
+  return iree_make_status(iree_status_code_from_errno(read_errno),
+                          "fd read failure %d", read_errno);
+}
+
+// Writes a single nonce to |handle| so that waiters on it are signaled.
+// Returns UNIMPLEMENTED for sync files and unhandled types and an errno-derived
+// failure if the underlying write fails.
+iree_status_t iree_wait_primitive_write(iree_wait_handle_t* handle) {
+  int result = -1;
+  switch (handle->type) {
+#if defined(IREE_HAVE_WAIT_TYPE_EVENTFD)
+    case IREE_WAIT_PRIMITIVE_TYPE_EVENT_FD: {
+      // Adds 1 to the eventfd counter.
+      IREE_SYSCALL(result, eventfd_write(handle->value.event.fd, 1ull));
+      break;
+    }
+#endif  // IREE_HAVE_WAIT_TYPE_EVENTFD
+#if defined(IREE_HAVE_WAIT_TYPE_SYNC_FILE)
+    case IREE_WAIT_PRIMITIVE_TYPE_SYNC_FILE:
+      return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                              "sync files not yet implemented");
+#endif  // IREE_HAVE_WAIT_TYPE_SYNC_FILE
+#if defined(IREE_HAVE_WAIT_TYPE_PIPE)
+    case IREE_WAIT_PRIMITIVE_TYPE_PIPE: {
+      // Pushes one byte into the write end of the pipe.
+      char nonce = '\n';
+      IREE_SYSCALL(result, write(handle->value.pipe.write_fd, &nonce, 1));
+      break;
+    }
+#endif  // IREE_HAVE_WAIT_TYPE_PIPE
+    default:
+      return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "unhandled wait type");
+  }
+  if (result < 0) {
+    return iree_make_status(iree_status_code_from_errno(errno),
+                            "fd write failure %d", errno);
+  }
+  // Write completed successfully.
+  return iree_ok_status();
+}
+
+// Drains |handle| by polling it until a read reports that it would block.
+// Returns OK once drained or the first non-deadline failure from the read.
+iree_status_t iree_wait_primitive_clear(iree_wait_handle_t* handle) {
+  // Depending on how the user setup the fd a single read may reset the entire
+  // handle (such as with the default eventfd mode) or several reads may be
+  // needed (such as with semaphores), so keep reading until nothing is left.
+  for (;;) {
+    iree_status_t read_status =
+        iree_wait_primitive_read(handle, IREE_TIME_INFINITE_PAST);
+    if (iree_status_is_deadline_exceeded(read_status)) {
+      // Would have blocked reading which means we've cleared the fd.
+      return iree_ok_status();
+    }
+    if (!iree_status_is_ok(read_status)) {
+      return read_status;
+    }
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// iree_event_t
+//===----------------------------------------------------------------------===//
+
+// Initializes |out_event| as a platform-native wait primitive that starts out
+// signaled when |initial_state| is true.
+iree_status_t iree_event_initialize(bool initial_state,
+                                    iree_event_t* out_event) {
+  iree_status_t status =
+      iree_wait_primitive_create_native(initial_state, out_event);
+  return status;
+}
+
+// Releases |event| by closing its underlying wait primitive.
+void iree_event_deinitialize(iree_event_t* event) {
+ iree_wait_primitive_close(event);
+}
+
+// Signals |event| by writing to its underlying wait primitive. The status of
+// the write is passed to IREE_IGNORE_ERROR and not reported to the caller.
+void iree_event_set(iree_event_t* event) {
+ IREE_IGNORE_ERROR(iree_wait_primitive_write(event));
+}
+
+// Unsignals |event| by clearing its underlying wait primitive. The status of
+// the clear is passed to IREE_IGNORE_ERROR and not reported to the caller.
+void iree_event_reset(iree_event_t* event) {
+ IREE_IGNORE_ERROR(iree_wait_primitive_clear(event));
+}
+
+#endif // !IREE_PLATFORM_WINDOWS
diff --git a/iree/base/wait_handle_posix.h b/iree/base/wait_handle_posix.h
new file mode 100644
index 0000000..63e1ded
--- /dev/null
+++ b/iree/base/wait_handle_posix.h
@@ -0,0 +1,86 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// NOTE: must be first to ensure that we can define settings for all includes.
+#include "iree/base/wait_handle_impl.h"
+
+#ifndef IREE_BASE_WAIT_HANDLE_POSIX_H_
+#define IREE_BASE_WAIT_HANDLE_POSIX_H_
+
+// NOTE: we could be tighter here, but we today only have win32 or not-win32.
+#if !defined(IREE_PLATFORM_WINDOWS)
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// Perform a syscall with a retry on EINTR (spurious wake/signal/etc).
+// |expr| is re-evaluated on every retry and its value is stored in
+// |result_value|. The macro expands to a single statement that requires a
+// trailing semicolon at the use site, making it safe inside unbraced
+// if/else bodies.
+//
+// Usage:
+//  int rv;
+//  IREE_SYSCALL(rv, fcntl(...));
+//  if (rv < 0) { /* failure */ }
+#define IREE_SYSCALL(result_value, expr) \
+  do {                                   \
+    (result_value) = (expr);             \
+  } while ((result_value) < 0 && errno == EINTR)
+
+// NOTE: these are intended for low-level signaling and may expose various
+// platform quirks to the caller. Always prefer using a higher level type such
+// as iree_event_t when possible.
+
+// Creates a wait primitive of the type native to the current platform.
+// May fail if resources are exhausted or wait handles are not supported.
+// The handle must be closed with iree_wait_primitive_close to release its
+// resources.
+iree_status_t iree_wait_primitive_create_native(bool initial_state,
+ iree_wait_handle_t* out_handle);
+
+// Closes an existing handle from iree_wait_primitive_create_native or
+// iree_wait_primitive_clone. Must not be called while there are any waiters on
+// the handle.
+void iree_wait_primitive_close(iree_wait_handle_t* handle);
+
+// Returns true if the two handles are identical in representation.
+// Note that two unique handles may point to the same underlying primitive
+// object (such as when they have been cloned).
+bool iree_wait_primitive_compare_identical(const iree_wait_handle_t* lhs,
+ const iree_wait_handle_t* rhs);
+
+// Returns an fd that can be used to read/wait on the handle.
+// Returns -1 if the handle is invalid.
+int iree_wait_primitive_get_read_fd(const iree_wait_handle_t* handle);
+
+// Reads a nonce from the given handle.
+// IREE_TIME_INFINITE_PAST polls (the call will never block) and returns
+// IREE_STATUS_DEADLINE_EXCEEDED when no nonce is available.
+// NOTE: blocking reads (any other deadline, including
+// IREE_TIME_INFINITE_FUTURE) are not implemented by the POSIX implementation
+// today and return IREE_STATUS_UNIMPLEMENTED.
+iree_status_t iree_wait_primitive_read(iree_wait_handle_t* handle,
+ iree_time_t deadline_ns);
+
+// Writes a nonce to the given handle causing it to signal any waiters.
+// The exact value written is platform/primitive specific.
+iree_status_t iree_wait_primitive_write(iree_wait_handle_t* handle);
+
+// Clears the wait primitive by repeatedly reading values until no more remain.
+// Never blocks the caller.
+iree_status_t iree_wait_primitive_clear(iree_wait_handle_t* handle);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // !IREE_PLATFORM_WINDOWS
+
+#endif // IREE_BASE_WAIT_HANDLE_POSIX_H_
diff --git a/iree/base/wait_handle_test.cc b/iree/base/wait_handle_test.cc
index 5887ace..983bbc2 100644
--- a/iree/base/wait_handle_test.cc
+++ b/iree/base/wait_handle_test.cc
@@ -1,4 +1,4 @@
-// Copyright 2019 Google LLC
+// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -14,538 +14,803 @@
#include "iree/base/wait_handle.h"
-#include <unistd.h>
+#include <thread>
-#include <string>
-#include <thread> // NOLINT
-#include <type_traits>
-
-#include "absl/time/time.h"
-#include "iree/base/status.h"
#include "iree/testing/gtest.h"
#include "iree/testing/status_matchers.h"
-// StatusOr<bool> will be true if the status is ok, which is bad.
-#define ASSERT_STATUSOR_TRUE(x) ASSERT_TRUE(x.value())
-#define ASSERT_STATUSOR_FALSE(x) ASSERT_FALSE(x.value())
-
namespace iree {
namespace {
-using ::testing::_;
-using ::testing::Return;
+// We don't want to wait too long in here but when we are testing that timeouts
+// work as expected we do have to sometimes wait. These are set to hopefully
+// reduce flakes and not hang a build bot forever if something is broken :)
+constexpr iree_duration_t kShortTimeoutNS = 1000000ull; // 1ms
+constexpr iree_duration_t kLongTimeoutNS = 60000000000ull; // 1min
-// Tests the AlwaysSignaling helper.
-TEST(WaitHandleTest, AlwaysSignaling) {
- IREE_ASSERT_OK(WaitHandle::AlwaysSignaling().Wait());
- EXPECT_FALSE(WaitHandle::AlwaysSignaling().DebugString().empty());
+//===----------------------------------------------------------------------===//
+// IREE_WAIT_PRIMITIVE_TYPE_EVENT_FD
+//===----------------------------------------------------------------------===//
+
+#if defined(IREE_HAVE_WAIT_TYPE_EVENTFD)
+
+// TODO(benvanik): tests wrapping external eventfds.
+
+#endif // IREE_HAVE_WAIT_TYPE_EVENTFD
+
+//===----------------------------------------------------------------------===//
+// IREE_WAIT_PRIMITIVE_TYPE_SYNC_FILE
+//===----------------------------------------------------------------------===//
+
+#if defined(IREE_HAVE_WAIT_TYPE_SYNC_FILE)
+
+// TODO(benvanik): tests wrapping external sync files.
+
+#endif // IREE_HAVE_WAIT_TYPE_SYNC_FILE
+
+//===----------------------------------------------------------------------===//
+// IREE_WAIT_PRIMITIVE_TYPE_PIPE
+//===----------------------------------------------------------------------===//
+
+#if defined(IREE_HAVE_WAIT_TYPE_PIPE)
+
+// TODO(benvanik): tests wrapping external pipes.
+
+#endif // IREE_HAVE_WAIT_TYPE_PIPE
+
+//===----------------------------------------------------------------------===//
+// IREE_WAIT_PRIMITIVE_TYPE_WIN32_HANDLE
+//===----------------------------------------------------------------------===//
+
+#if defined(IREE_HAVE_WAIT_TYPE_WIN32_HANDLE)
+
+// TODO(benvanik): tests wrapping external win32 handles.
+
+#endif // IREE_HAVE_WAIT_TYPE_WIN32_HANDLE
+
+//===----------------------------------------------------------------------===//
+// iree_event_t
+//===----------------------------------------------------------------------===//
+// NOTE: this is testing the user-visible behavior of iree_event_t and the use
+// of functions like iree_wait_one is not exhaustive as that is tested
+// elsewhere.
+
+// Tests that we don't leak.
+// Creates an initially-signaled event and immediately deinitializes it without
+// ever waiting on it.
+TEST(Event, Lifetime) {
+ iree_event_t event;
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/true, &event));
+ iree_event_deinitialize(&event);
}
-// Tests the AlwaysFailing helper.
-TEST(WaitHandleTest, AlwaysFailing) {
- ASSERT_FALSE(WaitHandle::AlwaysFailing().Wait().ok());
- EXPECT_FALSE(WaitHandle::AlwaysFailing().DebugString().empty());
+// Tests that polling (IREE_TIME_INFINITE_PAST) an event created unsignaled
+// reports IREE_STATUS_DEADLINE_EXCEEDED.
+TEST(Event, WaitOneInitialFalse) {
+ iree_event_t event;
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/false, &event));
+ IREE_EXPECT_STATUS_IS(IREE_STATUS_DEADLINE_EXCEEDED,
+ iree_wait_one(&event, IREE_TIME_INFINITE_PAST));
+ iree_event_deinitialize(&event);
}
-// Tests the basic lifecycle of a permanently signaled wait handle.
-TEST(WaitHandleTest, LifecyclePermanentSignaled) {
- // Just to be sure it's ok to safely no-op a WaitHandle value.
- WaitHandle wh_never_used;
- (void)wh_never_used;
-
- // Try waiting; should return immediately.
- WaitHandle wh0;
- IREE_ASSERT_OK(wh0.Wait());
-
- // Waits on multiple permanent handles should be ok.
- WaitHandle wh1;
- IREE_ASSERT_OK(WaitHandle::WaitAll({&wh0, &wh1}));
+// Tests that polling (IREE_TIME_INFINITE_PAST) an event created signaled
+// succeeds.
+TEST(Event, WaitOneInitialTrue) {
+ iree_event_t event;
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/true, &event));
+ IREE_EXPECT_OK(iree_wait_one(&event, IREE_TIME_INFINITE_PAST));
+ iree_event_deinitialize(&event);
}
-// Tests moving permanent WaitHandles around.
-TEST(WaitHandleTest, MovePermanent) {
- WaitHandle wh0;
- WaitHandle wh1{std::move(wh0)};
- WaitHandle wh2 = std::move(wh1);
- wh1 = std::move(wh2);
+// Tests the signal lifecycle of a single event using non-blocking polls:
+// unset -> set (remaining signaled across repeated waits) -> reset.
+TEST(Event, SetWait) {
+ iree_event_t event;
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/false, &event));
+
+ // Initially unset.
+ IREE_EXPECT_STATUS_IS(IREE_STATUS_DEADLINE_EXCEEDED,
+ iree_wait_one(&event, IREE_TIME_INFINITE_PAST));
+
+ // Set and wait.
+ iree_event_set(&event);
+ IREE_EXPECT_OK(iree_wait_one(&event, IREE_TIME_INFINITE_PAST));
+
+ // Set should be sticky until reset manually.
+ IREE_EXPECT_OK(iree_wait_one(&event, IREE_TIME_INFINITE_PAST));
+
+ // Resetting should unsignal the event.
+ iree_event_reset(&event);
+ IREE_EXPECT_STATUS_IS(IREE_STATUS_DEADLINE_EXCEEDED,
+ iree_wait_one(&event, IREE_TIME_INFINITE_PAST));
+
+ iree_event_deinitialize(&event);
}
-// Tests moving around real handles (that may require closing).
-TEST(WaitHandleTest, MoveRealHandle) {
- ManualResetEvent fence0;
- WaitHandle wh0 = fence0.OnSet();
- WaitHandle wh1{std::move(wh0)};
- WaitHandle wh2 = std::move(wh1);
- wh1 = std::move(wh2);
+// Tests that we can use set/reset and that certain behavior (such as sets
+// without intervening resets) is allowed. Note that this does not wait and is
+// just testing the client behavior; it's possible to implement these such that
+// a set while another set is pending fails and we want to verify that here.
+TEST(Event, SetReset) {
+ iree_event_t event;
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/false, &event));
- // Now overwrite the handle value to force a close.
- ManualResetEvent fence1;
- WaitHandle wh3 = fence1.OnSet();
- wh1 = std::move(wh3);
- wh1 = WaitHandle(); // Ensure handle dies first.
+ // Starts out unsignaled.
+ IREE_EXPECT_STATUS_IS(IREE_STATUS_DEADLINE_EXCEEDED,
+ iree_wait_one(&event, IREE_TIME_INFINITE_PAST));
+
+ // Two sets in a row with no reset in between; the event is signaled after
+ // each.
+ iree_event_set(&event);
+ IREE_EXPECT_OK(iree_wait_one(&event, IREE_TIME_INFINITE_PAST));
+ iree_event_set(&event);
+ IREE_EXPECT_OK(iree_wait_one(&event, IREE_TIME_INFINITE_PAST));
+
+ // Two resets in a row; the event is unsignaled after each.
+ iree_event_reset(&event);
+ IREE_EXPECT_STATUS_IS(IREE_STATUS_DEADLINE_EXCEEDED,
+ iree_wait_one(&event, IREE_TIME_INFINITE_PAST));
+ iree_event_reset(&event);
+ IREE_EXPECT_STATUS_IS(IREE_STATUS_DEADLINE_EXCEEDED,
+ iree_wait_one(&event, IREE_TIME_INFINITE_PAST));
+
+ // Setting again after the resets signals the event once more.
+ iree_event_set(&event);
+ IREE_EXPECT_OK(iree_wait_one(&event, IREE_TIME_INFINITE_PAST));
+ iree_event_set(&event);
+ IREE_EXPECT_OK(iree_wait_one(&event, IREE_TIME_INFINITE_PAST));
+
+ iree_event_deinitialize(&event);
}
-// Tests the various forms of waiting on a single WaitHandle.
-// Since these just call WaitAll we leave the involved testing to those.
-TEST(WaitHandleTest, SingleWait) {
- WaitHandle wh;
- IREE_ASSERT_OK(wh.Wait());
- IREE_ASSERT_OK(wh.Wait(Now() + absl::Seconds(1)));
- IREE_ASSERT_OK(wh.Wait(absl::Seconds(1)));
- ASSERT_STATUSOR_TRUE(wh.TryWait());
-}
+TEST(Event, BlockingBehavior) {
+ iree_event_t main_to_thread;
+ IREE_ASSERT_OK(
+ iree_event_initialize(/*initial_state=*/false, &main_to_thread));
+ iree_event_t thread_to_main;
+ IREE_ASSERT_OK(
+ iree_event_initialize(/*initial_state=*/false, &thread_to_main));
-// Tests using WaitAll with no valid handles. This should no-op.
-TEST(WaitHandleTest, WaitAllNop) {
- IREE_ASSERT_OK(WaitHandle::WaitAll({}));
- IREE_ASSERT_OK(WaitHandle::WaitAll({nullptr}));
- IREE_ASSERT_OK(WaitHandle::WaitAll({nullptr, nullptr}));
-}
+ // Spinup a thread to signal the event.
+ // Note that it waits on the main_to_thread event until we get further along.
+ bool did_run_thread = false;
+ std::thread thread([&]() {
+ // Wait for main thread to signal (below).
+ IREE_ASSERT_OK(iree_wait_one(&main_to_thread, IREE_TIME_INFINITE_FUTURE));
-// Tests polling with WaitAll with multiple wait handles.
-TEST(WaitHandleTest, WaitAllPoll) {
- ManualResetEvent fence0;
- WaitHandle wh0 = fence0.OnSet();
- ManualResetEvent fence1;
- WaitHandle wh1 = fence1.OnSet();
+ // Set something so we know this ran at all.
+ did_run_thread = true;
- // Poll; should return immediately with timeout.
- ASSERT_TRUE(
- IsDeadlineExceeded(WaitHandle::WaitAll({&wh0, &wh1}, InfinitePast())));
-
- // Notify fence1.
- IREE_ASSERT_OK(fence1.Set());
-
- // Poll; should return immediately with timeout as fence1 is not signaled.
- ASSERT_TRUE(
- IsDeadlineExceeded(WaitHandle::WaitAll({&wh0, &wh1}, InfinitePast())));
-
- // Notify fence0.
- IREE_ASSERT_OK(fence0.Set());
-
- // Poll again; should return immediately with success.
- IREE_ASSERT_OK(WaitHandle::WaitAll({&wh0, &wh1}, InfinitePast()));
-}
-
-// Tests waiting when the first file handle is invalid. This is to verify a
-// workaround for bad poll() behavior with fds[0] == -1.
-TEST(WaitHandleTest, WaitAllWithInvalid0) {
- ManualResetEvent fence;
- WaitHandle wh = fence.OnSet();
-
- // Poll; should return immediately with timeout as fence is not signaled.
- ASSERT_TRUE(
- IsDeadlineExceeded(WaitHandle::WaitAll({nullptr, &wh}, InfinitePast())));
-
- // Notify fence.
- IREE_ASSERT_OK(fence.Set());
-
- // Poll again; should return immediately with success.
- IREE_ASSERT_OK(WaitHandle::WaitAll({nullptr, &wh}, InfinitePast()));
-}
-
-// Tests exceeding the timeout deadline with WaitAll.
-TEST(WaitHandleTest, WaitAllTimeout) {
- ManualResetEvent fence;
- WaitHandle wh = fence.OnSet();
-
- // Wait with timeout on the unsignaled fence:
- // Via polling (should never block):
- ASSERT_TRUE(IsDeadlineExceeded(WaitHandle::WaitAll({&wh}, InfinitePast())));
- ASSERT_STATUSOR_FALSE(WaitHandle::TryWaitAll({&wh}));
- // Via time in the near future (should block):
- ASSERT_TRUE(
- IsDeadlineExceeded(WaitHandle::WaitAll({&wh}, Milliseconds(250))));
- // Via time in the past, should exceed deadline.
- ASSERT_TRUE(
- IsDeadlineExceeded(WaitHandle::WaitAll({&wh}, Milliseconds(-250))));
-
- // Notify and ensure no more timeouts.
- IREE_ASSERT_OK(fence.Set());
- IREE_ASSERT_OK(WaitHandle::WaitAll({&wh}, InfinitePast()));
- ASSERT_STATUSOR_TRUE(WaitHandle::TryWaitAll({&wh}));
- IREE_ASSERT_OK(WaitHandle::WaitAll({&wh}, Milliseconds(250)));
-
- // Via time in the past, should exceed deadline even if signaled.
- ASSERT_TRUE(
- IsDeadlineExceeded(WaitHandle::WaitAll({&wh}, Milliseconds(-250))));
-}
-
-// Tests using WaitAll to wait on other threads.
-TEST(WaitHandleTest, WaitAllThreaded) {
- // Spin up two threads.
- ManualResetEvent fence0;
- std::thread t0{[&]() {
- ::usleep(absl::ToInt64Microseconds(Milliseconds(250)));
- IREE_ASSERT_OK(fence0.Set());
- }};
- ManualResetEvent fence1;
- std::thread t1{[&]() {
- ::usleep(absl::ToInt64Microseconds(Milliseconds(250)));
- IREE_ASSERT_OK(fence1.Set());
- }};
-
- // Wait on both threads to complete.
- WaitHandle wh0 = fence0.OnSet();
- WaitHandle wh1 = fence1.OnSet();
- IREE_ASSERT_OK(WaitHandle::WaitAll({&wh0, &wh1}));
-
- t0.join();
- t1.join();
-}
-
-// Tests using WaitAll with multiple wait handles from the same fence.
-TEST(WaitHandleTest, WaitAllSameSource) {
- ManualResetEvent fence;
- WaitHandle wh0 = fence.OnSet();
- WaitHandle wh1 = fence.OnSet();
- ASSERT_TRUE(
- IsDeadlineExceeded(WaitHandle::WaitAll({&wh0, &wh1}, InfinitePast())));
- IREE_ASSERT_OK(fence.Set());
- IREE_ASSERT_OK(WaitHandle::WaitAll({&wh0, &wh1}));
-}
-
-// Tests using WaitAll with literally the same wait handles.
-TEST(WaitHandleTest, WaitAllSameHandle) {
- ManualResetEvent fence;
- WaitHandle wh = fence.OnSet();
- ASSERT_TRUE(
- IsDeadlineExceeded(WaitHandle::WaitAll({&wh, &wh}, InfinitePast())));
- IREE_ASSERT_OK(fence.Set());
- IREE_ASSERT_OK(WaitHandle::WaitAll({&wh, &wh}));
-}
-
-// Tests WaitAll when a wait handle fails.
-TEST(WaitHandleTest, WaitAllFailure) {
- WaitHandle good_wh;
- // Create a purposefully bad handle to induce an error.
- WaitHandle bad_wh = WaitHandle::AlwaysFailing();
- // Should fail with some posixy error.
- ASSERT_FALSE(WaitHandle::WaitAll({&good_wh, &bad_wh}).ok());
-}
-
-// Tests using WaitAny with no valid handles. This should no-op.
-TEST(WaitHandleTest, WaitAnyNop) {
- ASSERT_TRUE(IsInvalidArgument(WaitHandle::WaitAny({}).status()));
- IREE_ASSERT_OK_AND_ASSIGN(int index, WaitHandle::WaitAny({nullptr}));
- ASSERT_EQ(0, index);
- IREE_ASSERT_OK_AND_ASSIGN(index, WaitHandle::WaitAny({nullptr, nullptr}));
- ASSERT_EQ(0, index);
-}
-
-// Tests polling with WaitAny with multiple wait handles.
-TEST(WaitHandleTest, WaitAnyPoll) {
- ManualResetEvent fence0;
- WaitHandle wh0 = fence0.OnSet();
- ManualResetEvent fence1;
- WaitHandle wh1 = fence1.OnSet();
-
- // Poll; should return immediately with timeout.
- ASSERT_TRUE(IsDeadlineExceeded(
- WaitHandle::WaitAny({&wh0, &wh1}, InfinitePast()).status()));
-
- // Notify fence1.
- IREE_ASSERT_OK(fence1.Set());
-
- // Poll; should return immediately with fence1 signaled.
- IREE_ASSERT_OK_AND_ASSIGN(int index,
- WaitHandle::WaitAny({&wh0, &wh1}, InfinitePast()));
- EXPECT_EQ(1, index);
-
- // Notify fence0.
- IREE_ASSERT_OK(fence0.Set());
-
- // Poll again; should return immediately; which one is signaled is undefined.
- IREE_ASSERT_OK_AND_ASSIGN(index,
- WaitHandle::WaitAny({&wh0, &wh1}, InfinitePast()));
- ASSERT_TRUE(index == 0 || index == 1);
-}
-
-// Tests exceeding the timeout deadline with WaitAny.
-TEST(WaitHandleTest, WaitAnyTimeout) {
- ManualResetEvent fence0;
- WaitHandle wh0 = fence0.OnSet();
- ManualResetEvent fence1;
- WaitHandle wh1 = fence1.OnSet();
-
- // Wait with timeout on the unsignaled fences:
- // Via polling (should never block):
- ASSERT_TRUE(IsDeadlineExceeded(
- WaitHandle::WaitAny({&wh0, &wh1}, InfinitePast()).status()));
- IREE_ASSERT_OK_AND_ASSIGN(int index, WaitHandle::TryWaitAny({&wh0, &wh1}));
- ASSERT_EQ(-1, index);
- // Via time in the near future (should block):
- ASSERT_TRUE(IsDeadlineExceeded(
- WaitHandle::WaitAny({&wh0, &wh1}, Milliseconds(250)).status()));
-
- // Notify one of the fences. Should return immediately.
- IREE_ASSERT_OK(fence1.Set());
- IREE_ASSERT_OK_AND_ASSIGN(index,
- WaitHandle::WaitAny({&wh0, &wh1}, InfinitePast()));
- ASSERT_EQ(1, index);
- IREE_ASSERT_OK_AND_ASSIGN(index, WaitHandle::TryWaitAny({&wh0, &wh1}));
- ASSERT_EQ(1, index);
- IREE_ASSERT_OK_AND_ASSIGN(
- index, WaitHandle::WaitAny({&wh0, &wh1}, Milliseconds(250)));
- ASSERT_EQ(1, index);
-
- // The unnotified fence should still timeout.
- ASSERT_TRUE(
- IsDeadlineExceeded(WaitHandle::WaitAny({&wh0}, InfinitePast()).status()));
- IREE_ASSERT_OK_AND_ASSIGN(index, WaitHandle::TryWaitAny({&wh0}));
- ASSERT_EQ(-1, index);
- ASSERT_TRUE(IsDeadlineExceeded(
- WaitHandle::WaitAny({&wh0}, Milliseconds(250)).status()));
-
- // Notify last fence and ensure complete.
- IREE_ASSERT_OK(fence0.Set());
- IREE_ASSERT_OK_AND_ASSIGN(index, WaitHandle::WaitAny({&wh0}, InfinitePast()));
- ASSERT_EQ(0, index);
- IREE_ASSERT_OK_AND_ASSIGN(index, WaitHandle::TryWaitAny({&wh0}));
- ASSERT_EQ(0, index);
- IREE_ASSERT_OK_AND_ASSIGN(index,
- WaitHandle::WaitAny({&wh0}, Milliseconds(250)));
- ASSERT_EQ(0, index);
-}
-
-// Tests using WaitAny to wait on other threads.
-TEST(WaitHandleTest, WaitAnyThreaded) {
- // Spin up two threads.
- // t1 will wait on t0 such that they will act in sequence.
- ManualResetEvent fence0;
- std::thread t0{[&]() {
- ::usleep(absl::ToInt64Microseconds(Milliseconds(250)));
- IREE_ASSERT_OK(fence0.Set());
- }};
- ManualResetEvent fence1;
- std::thread t1{[&]() {
- IREE_ASSERT_OK(fence0.OnSet().Wait());
- ::usleep(absl::ToInt64Microseconds(Milliseconds(250)));
- IREE_ASSERT_OK(fence1.Set());
- }};
-
- // Wait on both threads. We expect 0 to complete first.
- WaitHandle wh0 = fence0.OnSet();
- WaitHandle wh1 = fence1.OnSet();
- IREE_ASSERT_OK_AND_ASSIGN(int index, WaitHandle::WaitAny({&wh0, &wh1}));
- ASSERT_EQ(0, index);
-
- // Now wait for thread 1.
- IREE_ASSERT_OK_AND_ASSIGN(index, WaitHandle::WaitAny({&wh1}));
- ASSERT_EQ(0, index);
-
- t0.join();
- t1.join();
-}
-
-// Tests using WaitAny with multiple wait handles from the same fence.
-TEST(WaitHandleTest, WaitAnySameSource) {
- ManualResetEvent fence;
- WaitHandle wh0 = fence.OnSet();
- WaitHandle wh1 = fence.OnSet();
- ASSERT_TRUE(IsDeadlineExceeded(
- WaitHandle::WaitAny({&wh0, &wh1}, InfinitePast()).status()));
- IREE_ASSERT_OK(fence.Set());
- IREE_ASSERT_OK_AND_ASSIGN(int index, WaitHandle::WaitAny({&wh0, &wh1}));
- ASSERT_TRUE(index == 0 || index == 1);
-}
-
-// Tests using WaitAny with literally the same wait handles.
-TEST(WaitHandleTest, WaitAnySameHandle) {
- ManualResetEvent fence;
- WaitHandle wh = fence.OnSet();
- ASSERT_TRUE(IsDeadlineExceeded(
- WaitHandle::WaitAny({&wh, &wh}, InfinitePast()).status()));
- IREE_ASSERT_OK(fence.Set());
- IREE_ASSERT_OK_AND_ASSIGN(int index, WaitHandle::WaitAny({&wh, &wh}));
- ASSERT_TRUE(index == 0 || index == 1);
-}
-
-// Tests WaitAny when a wait handle fails.
-TEST(WaitHandleTest, WaitAnyFailure) {
- WaitHandle good_wh;
- // Create a purposefully bad handle to induce an error.
- WaitHandle bad_wh = WaitHandle::AlwaysFailing();
- // Should fail with some posixy error.
- ASSERT_FALSE(WaitHandle::WaitAny({&good_wh, &bad_wh}).ok());
-}
-
-// ManualResetEvent with innards exposed. Meh.
-class ExposedManualResetEvent : public ManualResetEvent {
- public:
- using ManualResetEvent::AcquireFdForWait;
- using ManualResetEvent::TryResolveWakeOnFd;
-};
-
-// Mock type for the WaitableObject methods.
-class MockWaitableObject : public ::testing::StrictMock<WaitableObject> {
- public:
- MockWaitableObject() : ::testing::StrictMock<WaitableObject>() {}
-
- MOCK_METHOD(std::string, DebugString, (), (const, override));
- MOCK_METHOD((StatusOr<std::pair<FdType, int>>), AcquireFdForWait,
- (Time deadline_ns), (override));
- MOCK_METHOD(StatusOr<bool>, TryResolveWakeOnFd, (int fd), (override));
-
- WaitHandle OnSomething() { return WaitHandle(add_ref(this)); }
-};
-
-// Tests normal AcquireFdForWait + TryResolveWakeOnFd use.
-TEST(WaitableObjectTest, AcquireAndResolve) {
- MockWaitableObject mwo;
- WaitHandle wh = mwo.OnSomething();
-
- // Use a MRE for testing, as we can just use its fd.
- ExposedManualResetEvent mre;
-
- // Try waiting; we should see the AcquireFdForWait and then return because
- // the fd has not been resolved.
- EXPECT_CALL(mwo, AcquireFdForWait(_)).WillOnce([&](Time deadline_ns) {
- // Return the valid FD from the MRE.
- return mre.AcquireFdForWait(deadline);
+ // Notify the caller thread.
+ iree_event_set(&thread_to_main);
});
- ASSERT_STATUSOR_FALSE(wh.TryWait());
- // Signal the MRE.
- IREE_ASSERT_OK(mre.Set());
+ // The thread may take some time to spin up; it must wait for us to allow it
+ // to run its body though so we should be fine here.
+ std::this_thread::sleep_for(std::chrono::milliseconds(50));
+ ASSERT_FALSE(did_run_thread);
- // Try waiting again; we should get the AcquireFdForWait and then also get
- // the TryResolveWakeOnFd.
- EXPECT_CALL(mwo, AcquireFdForWait(_)).WillOnce([&](Time deadline_ns) {
- // Return the valid (and now signaled) FD from the MRE.
- return mre.AcquireFdForWait(deadline);
+ // Allow the thread to continue and wait for it to exit.
+ iree_event_set(&main_to_thread);
+ IREE_ASSERT_OK(iree_wait_one(&thread_to_main, IREE_TIME_INFINITE_FUTURE));
+ ASSERT_TRUE(did_run_thread);
+
+ thread.join();
+ iree_event_deinitialize(&main_to_thread);
+ iree_event_deinitialize(&thread_to_main);
+}
+
+//===----------------------------------------------------------------------===//
+// iree_wait_set_t
+//===----------------------------------------------------------------------===//
+
+// Tests basic usage of the wait set API without waiting.
+TEST(WaitSet, Lifetime) {
+ iree_event_t event;
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/false, &event));
+
+ iree_wait_set_t* wait_set = NULL;
+ IREE_ASSERT_OK(
+ iree_wait_set_allocate(128, iree_allocator_system(), &wait_set));
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, event));
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, event));
+ iree_wait_set_erase(wait_set, event);
+ iree_wait_set_clear(wait_set);
+ iree_wait_set_free(wait_set);
+
+ iree_event_deinitialize(&event);
+}
+
+TEST(WaitSet, UnreasonableCapacity) {
+ iree_wait_set_t* wait_set = NULL;
+ IREE_EXPECT_STATUS_IS(
+ IREE_STATUS_INVALID_ARGUMENT,
+ iree_wait_set_allocate(1 * 1024 * 1024, iree_allocator_system(),
+ &wait_set));
+}
+
+// Tests that inserting the same handles multiple times is tracked correctly.
+TEST(WaitSet, Deduplication) {
+ iree_event_t ev_unset, ev_dupe;
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/false, &ev_unset));
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/true, &ev_dupe));
+ iree_wait_set_t* wait_set = NULL;
+ IREE_ASSERT_OK(
+ iree_wait_set_allocate(128, iree_allocator_system(), &wait_set));
+
+ // We want to test for duplication on ev_dupe here so ensure it's added.
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_unset));
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_dupe));
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_dupe));
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_dupe));
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_unset));
+
+ // Wait should succeed immediately because ev_dupe is set (and our wake handle
+ // should be ev_dupe).
+ iree_wait_handle_t wake_handle;
+ IREE_ASSERT_OK(
+ iree_wait_any(wait_set, IREE_TIME_INFINITE_PAST, &wake_handle));
+ EXPECT_EQ(0,
+ memcmp(&ev_dupe.value, &wake_handle.value, sizeof(ev_dupe.value)));
+
+ // Erase the events one at a time and ensure we still get the expected number
+ // of waits on ev_dupe.
+ iree_wait_set_erase(wait_set, wake_handle);
+ IREE_ASSERT_OK(
+ iree_wait_any(wait_set, IREE_TIME_INFINITE_PAST, &wake_handle));
+ EXPECT_EQ(0,
+ memcmp(&ev_dupe.value, &wake_handle.value, sizeof(ev_dupe.value)));
+ iree_wait_set_erase(wait_set, wake_handle);
+ IREE_ASSERT_OK(
+ iree_wait_any(wait_set, IREE_TIME_INFINITE_PAST, &wake_handle));
+ EXPECT_EQ(0,
+ memcmp(&ev_dupe.value, &wake_handle.value, sizeof(ev_dupe.value)));
+ iree_wait_set_erase(wait_set, wake_handle);
+
+ // Now there should just be ev_unset present in the set and a poll will fail.
+ IREE_EXPECT_STATUS_IS(
+ IREE_STATUS_DEADLINE_EXCEEDED,
+ iree_wait_any(wait_set, IREE_TIME_INFINITE_PAST, &wake_handle));
+
+ iree_wait_set_free(wait_set);
+ iree_event_deinitialize(&ev_unset);
+ iree_event_deinitialize(&ev_dupe);
+}
+
+// Tests that clear handles things right in the face of dupes.
+TEST(WaitSet, Clear) {
+ iree_event_t ev_unset, ev_dupe;
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/false, &ev_unset));
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/true, &ev_dupe));
+ iree_wait_set_t* wait_set = NULL;
+ IREE_ASSERT_OK(
+ iree_wait_set_allocate(128, iree_allocator_system(), &wait_set));
+
+ // We want to test for duplication on ev_dupe here.
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_unset));
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_dupe));
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_dupe));
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_dupe));
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_unset));
+
+ // Wait should succeed immediately because ev_dupe is set (and our wake handle
+ // should be ev_dupe).
+ iree_wait_handle_t wake_handle;
+ IREE_ASSERT_OK(
+ iree_wait_any(wait_set, IREE_TIME_INFINITE_PAST, &wake_handle));
+ EXPECT_EQ(0,
+ memcmp(&ev_dupe.value, &wake_handle.value, sizeof(ev_dupe.value)));
+
+ // Erase all events from the set.
+ iree_wait_set_clear(wait_set);
+
+ // No more events remaining; should pass immediately.
+ IREE_ASSERT_OK(
+ iree_wait_any(wait_set, IREE_TIME_INFINITE_PAST, &wake_handle));
+
+ iree_wait_set_free(wait_set);
+ iree_event_deinitialize(&ev_unset);
+ iree_event_deinitialize(&ev_dupe);
+}
+
+// Tests iree_wait_all when polling (deadline_ns = IREE_TIME_INFINITE_PAST).
+TEST(WaitSet, WaitAllPolling) {
+ iree_event_t ev_unset_0, ev_unset_1;
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/false, &ev_unset_0));
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/false, &ev_unset_1));
+ iree_event_t ev_set_0, ev_set_1;
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/true, &ev_set_0));
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/true, &ev_set_1));
+ iree_wait_set_t* wait_set = NULL;
+ IREE_ASSERT_OK(
+ iree_wait_set_allocate(128, iree_allocator_system(), &wait_set));
+
+ // Polls when empty should never block.
+ iree_wait_set_clear(wait_set);
+ IREE_ASSERT_OK(iree_wait_all(wait_set, IREE_TIME_INFINITE_PAST));
+
+ // Polls with only unset handles should never block.
+ iree_wait_set_clear(wait_set);
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_unset_0));
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_unset_1));
+ IREE_EXPECT_STATUS_IS(IREE_STATUS_DEADLINE_EXCEEDED,
+ iree_wait_all(wait_set, IREE_TIME_INFINITE_PAST));
+
+ // Polls with only set handles should return immediately.
+ iree_wait_set_clear(wait_set);
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_set_0));
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_set_1));
+ IREE_ASSERT_OK(iree_wait_all(wait_set, IREE_TIME_INFINITE_PAST));
+
+ // Polls with mixed set/unset should never succeed.
+ iree_wait_set_clear(wait_set);
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_unset_0));
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_unset_1));
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_set_0));
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_set_1));
+ IREE_EXPECT_STATUS_IS(IREE_STATUS_DEADLINE_EXCEEDED,
+ iree_wait_all(wait_set, IREE_TIME_INFINITE_PAST));
+
+ iree_wait_set_free(wait_set);
+ iree_event_deinitialize(&ev_unset_0);
+ iree_event_deinitialize(&ev_unset_1);
+ iree_event_deinitialize(&ev_set_0);
+ iree_event_deinitialize(&ev_set_1);
+}
+
+// Tests iree_wait_all with timeouts (deadline_ns = non-zero).
+TEST(WaitSet, WaitAllTimeout) {
+ iree_event_t ev_unset_0, ev_unset_1;
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/false, &ev_unset_0));
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/false, &ev_unset_1));
+ iree_event_t ev_set_0, ev_set_1;
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/true, &ev_set_0));
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/true, &ev_set_1));
+ iree_wait_set_t* wait_set = NULL;
+ IREE_ASSERT_OK(
+ iree_wait_set_allocate(128, iree_allocator_system(), &wait_set));
+
+ // Timeouts when empty should never block.
+ iree_wait_set_clear(wait_set);
+ IREE_ASSERT_OK(iree_wait_all(wait_set, iree_time_now() + kShortTimeoutNS));
+
+ // Timeouts with only unset handles should block (and then expire).
+ iree_wait_set_clear(wait_set);
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_unset_0));
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_unset_1));
+ constexpr iree_duration_t kShortTimeoutNS = 1000000ull;
+ IREE_EXPECT_STATUS_IS(
+ IREE_STATUS_DEADLINE_EXCEEDED,
+ iree_wait_all(wait_set, iree_time_now() + kShortTimeoutNS));
+
+ // Timeouts with only set handles should return immediately.
+ iree_wait_set_clear(wait_set);
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_set_0));
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_set_1));
+ IREE_ASSERT_OK(iree_wait_all(wait_set, iree_time_now() + kShortTimeoutNS));
+
+ // Timeouts with mixed set/unset should never succeed.
+ iree_wait_set_clear(wait_set);
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_unset_0));
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_unset_1));
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_set_0));
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_set_1));
+ IREE_EXPECT_STATUS_IS(
+ IREE_STATUS_DEADLINE_EXCEEDED,
+ iree_wait_all(wait_set, iree_time_now() + kShortTimeoutNS));
+
+ iree_wait_set_free(wait_set);
+ iree_event_deinitialize(&ev_unset_0);
+ iree_event_deinitialize(&ev_unset_1);
+ iree_event_deinitialize(&ev_set_0);
+ iree_event_deinitialize(&ev_set_1);
+}
+
+// Tests iree_wait_all when blocking (deadline_ns = IREE_TIME_INFINITE_FUTURE).
+TEST(WaitSet, WaitAllBlocking) {
+ iree_event_t thread_to_main;
+ IREE_ASSERT_OK(
+ iree_event_initialize(/*initial_state=*/false, &thread_to_main));
+ iree_event_t ev_set_0, ev_set_1;
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/true, &ev_set_0));
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/true, &ev_set_1));
+ iree_wait_set_t* wait_set = NULL;
+ IREE_ASSERT_OK(
+ iree_wait_set_allocate(128, iree_allocator_system(), &wait_set));
+
+ // Throw in some other set handles so that we are multi-waiting for just the
+ // thread_to_main event to be set.
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_set_0));
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_set_1));
+
+ // Wait forever (no timeout).
+ // We approximate that by forking off a thread to signal our local event. We
+ // can assume that a moderate wait is enough to verify the forever behavior as
+ // otherwise we are probably just messing up the math and will timeout.
+ std::thread thread([&]() {
+ // Notify the caller thread after sleeping (to ensure it's not polling).
+ std::this_thread::sleep_for(std::chrono::milliseconds(50));
+ iree_event_set(&thread_to_main);
});
- EXPECT_CALL(mwo, TryResolveWakeOnFd(_)).WillOnce(Return(true));
- ASSERT_STATUSOR_TRUE(wh.TryWait());
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, thread_to_main));
+ IREE_ASSERT_OK(iree_wait_all(wait_set, IREE_TIME_INFINITE_FUTURE));
+
+ thread.join();
+ iree_wait_set_free(wait_set);
+ iree_event_deinitialize(&thread_to_main);
+ iree_event_deinitialize(&ev_set_0);
+ iree_event_deinitialize(&ev_set_1);
}
-// Tests timing out in AcquireFdForWait.
-TEST(WaitableObjectTest, AcquireFdForWaitTimeout) {
- ManualResetEvent mre;
- WaitHandle always_wait = mre.OnSet();
- WaitHandle always_signal = WaitHandle::AlwaysSignaling();
- MockWaitableObject mwo;
- WaitHandle wh = mwo.OnSomething();
+// Tests iree_wait_all when one or more handles are duplicated.
+TEST(WaitSet, WaitAllDuplicates) {
+ iree_event_t ev_set;
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/true, &ev_set));
+ iree_wait_set_t* wait_set = NULL;
+ IREE_ASSERT_OK(
+ iree_wait_set_allocate(128, iree_allocator_system(), &wait_set));
- // Make the AcquireFdForWait take longer than the timeout. We should hit
- // deadline exceeded even though always_wait hasn't be signaled.
- EXPECT_CALL(mwo, AcquireFdForWait(_)).WillOnce([](Time deadline_ns) {
- ::usleep(absl::ToInt64Microseconds(Milliseconds(10)));
- return std::make_pair(WaitableObject::FdType::kPermanent,
- WaitableObject::kInvalidFd);
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_set));
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_set));
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_set));
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_set));
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_set));
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_set));
+
+ // Wait should succeed immediately because ev_set is set.
+ IREE_ASSERT_OK(iree_wait_all(wait_set, IREE_TIME_INFINITE_PAST));
+
+ iree_wait_set_free(wait_set);
+ iree_event_deinitialize(&ev_set);
+}
+
+// Tests iree_wait_any; note that this is only focused on testing the wait.
+TEST(WaitSet, WaitAny) {
+ iree_event_t ev_unset, ev_set;
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/false, &ev_unset));
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/true, &ev_set));
+ iree_wait_set_t* wait_set = NULL;
+ IREE_ASSERT_OK(
+ iree_wait_set_allocate(128, iree_allocator_system(), &wait_set));
+
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_unset));
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_set));
+
+ // Wait should succeed immediately because ev_set is set (and our wake handle
+ // should be ev_set).
+ iree_wait_handle_t wake_handle;
+ IREE_ASSERT_OK(
+ iree_wait_any(wait_set, IREE_TIME_INFINITE_PAST, &wake_handle));
+ EXPECT_EQ(0, memcmp(&ev_set.value, &wake_handle.value, sizeof(ev_set.value)));
+
+ iree_wait_set_free(wait_set);
+ iree_event_deinitialize(&ev_unset);
+ iree_event_deinitialize(&ev_set);
+}
+
+// Tests iree_wait_any when polling (deadline_ns = IREE_TIME_INFINITE_PAST).
+TEST(WaitSet, WaitAnyPolling) {
+ iree_event_t ev_unset_0, ev_unset_1;
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/false, &ev_unset_0));
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/false, &ev_unset_1));
+ iree_event_t ev_set_0, ev_set_1;
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/true, &ev_set_0));
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/true, &ev_set_1));
+ iree_wait_set_t* wait_set = NULL;
+ IREE_ASSERT_OK(
+ iree_wait_set_allocate(128, iree_allocator_system(), &wait_set));
+
+ iree_wait_handle_t empty_handle;
+ memset(&empty_handle, 0, sizeof(empty_handle));
+
+ // Polls when empty should never block and return an empty wake handle.
+ // This is so that if the caller touches the wake_handle they at least have
+ // initialized memory.
+ iree_wait_set_clear(wait_set);
+ iree_wait_handle_t wake_handle;
+ IREE_ASSERT_OK(
+ iree_wait_any(wait_set, IREE_TIME_INFINITE_PAST, &wake_handle));
+ EXPECT_EQ(0, memcmp(&empty_handle, &wake_handle, sizeof(empty_handle)));
+
+ // Polls with only unset handles should never block.
+ iree_wait_set_clear(wait_set);
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_unset_0));
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_unset_1));
+ IREE_EXPECT_STATUS_IS(
+ IREE_STATUS_DEADLINE_EXCEEDED,
+ iree_wait_any(wait_set, IREE_TIME_INFINITE_PAST, &wake_handle));
+ EXPECT_EQ(0, memcmp(&empty_handle, &wake_handle, sizeof(empty_handle)));
+
+ // Polls with only set handles should return immediately.
+ // Note that which handle is returned is not specified.
+ iree_wait_set_clear(wait_set);
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_set_0));
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_set_1));
+ IREE_ASSERT_OK(
+ iree_wait_any(wait_set, IREE_TIME_INFINITE_PAST, &wake_handle));
+ EXPECT_TRUE(
+ 0 ==
+ memcmp(&ev_set_0.value, &wake_handle.value, sizeof(ev_set_0.value)) ||
+ 0 == memcmp(&ev_set_1.value, &wake_handle.value, sizeof(ev_set_1.value)));
+
+ // Polls with mixed set/unset should return immediately.
+ // Note that which handle is returned is not specified but we know it should
+ // at least be one of the signaled ones.
+ iree_wait_set_clear(wait_set);
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_unset_0));
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_unset_1));
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_set_0));
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_set_1));
+ IREE_ASSERT_OK(
+ iree_wait_any(wait_set, IREE_TIME_INFINITE_PAST, &wake_handle));
+ EXPECT_TRUE(
+ 0 ==
+ memcmp(&ev_set_0.value, &wake_handle.value, sizeof(ev_set_0.value)) ||
+ 0 == memcmp(&ev_set_1.value, &wake_handle.value, sizeof(ev_set_1.value)));
+
+ iree_wait_set_free(wait_set);
+ iree_event_deinitialize(&ev_unset_0);
+ iree_event_deinitialize(&ev_unset_1);
+ iree_event_deinitialize(&ev_set_0);
+ iree_event_deinitialize(&ev_set_1);
+}
+
+// Tests iree_wait_any with timeouts (deadline_ns = non-zero).
+TEST(WaitSet, WaitAnyTimeout) {
+ iree_event_t ev_unset_0, ev_unset_1;
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/false, &ev_unset_0));
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/false, &ev_unset_1));
+ iree_event_t ev_set_0, ev_set_1;
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/true, &ev_set_0));
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/true, &ev_set_1));
+ iree_wait_set_t* wait_set = NULL;
+ IREE_ASSERT_OK(
+ iree_wait_set_allocate(128, iree_allocator_system(), &wait_set));
+
+ iree_wait_handle_t empty_handle;
+ memset(&empty_handle, 0, sizeof(empty_handle));
+
+ // Timeouts when empty should never block.
+ iree_wait_set_clear(wait_set);
+ iree_wait_handle_t wake_handle;
+ IREE_ASSERT_OK(
+ iree_wait_any(wait_set, iree_time_now() + kShortTimeoutNS, &wake_handle));
+ EXPECT_EQ(0, memcmp(&empty_handle, &wake_handle, sizeof(empty_handle)));
+
+ // Timeouts with only unset handles should block (and then expire).
+ iree_wait_set_clear(wait_set);
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_unset_0));
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_unset_1));
+ constexpr iree_duration_t kShortTimeoutNS = 1000000ull;
+ IREE_EXPECT_STATUS_IS(
+ IREE_STATUS_DEADLINE_EXCEEDED,
+ iree_wait_any(wait_set, iree_time_now() + kShortTimeoutNS, &wake_handle));
+ EXPECT_EQ(0, memcmp(&empty_handle, &wake_handle, sizeof(empty_handle)));
+
+ // Timeouts with only set handles should return immediately and have one of
+ // the set handles as the wake handle.
+ iree_wait_set_clear(wait_set);
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_set_0));
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_set_1));
+ IREE_ASSERT_OK(
+ iree_wait_any(wait_set, iree_time_now() + kShortTimeoutNS, &wake_handle));
+ EXPECT_TRUE(
+ 0 ==
+ memcmp(&ev_set_0.value, &wake_handle.value, sizeof(ev_set_0.value)) ||
+ 0 == memcmp(&ev_set_1.value, &wake_handle.value, sizeof(ev_set_1.value)));
+
+ // Timeouts with mixed set/unset should return immediately and have one of the
+ // set handles as the wake handle.
+ iree_wait_set_clear(wait_set);
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_unset_0));
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_unset_1));
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_set_0));
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_set_1));
+ IREE_ASSERT_OK(
+ iree_wait_any(wait_set, iree_time_now() + kShortTimeoutNS, &wake_handle));
+ EXPECT_TRUE(
+ 0 ==
+ memcmp(&ev_set_0.value, &wake_handle.value, sizeof(ev_set_0.value)) ||
+ 0 == memcmp(&ev_set_1.value, &wake_handle.value, sizeof(ev_set_1.value)));
+
+ iree_wait_set_free(wait_set);
+ iree_event_deinitialize(&ev_unset_0);
+ iree_event_deinitialize(&ev_unset_1);
+ iree_event_deinitialize(&ev_set_0);
+ iree_event_deinitialize(&ev_set_1);
+}
+
+// Tests iree_wait_any when blocking (deadline_ns = IREE_TIME_INFINITE_FUTURE).
+TEST(WaitSet, WaitAnyBlocking) {
+ iree_event_t thread_to_main;
+ IREE_ASSERT_OK(
+ iree_event_initialize(/*initial_state=*/false, &thread_to_main));
+ iree_event_t ev_unset_0, ev_unset_1;
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/false, &ev_unset_0));
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/false, &ev_unset_1));
+ iree_wait_set_t* wait_set = NULL;
+ IREE_ASSERT_OK(
+ iree_wait_set_allocate(128, iree_allocator_system(), &wait_set));
+
+ // Throw in some unset handles so that we are multi-waiting for just the
+ // thread_to_main event to be set.
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_unset_0));
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_unset_1));
+
+ // Wait forever (no timeout).
+ // We approximate that by forking off a thread to signal our local event. We
+ // can assume that a moderate wait is enough to verify the forever behavior as
+ // otherwise we are probably just messing up the math and will timeout.
+ std::thread thread([&]() {
+ // Notify the caller thread after sleeping (to ensure it's not polling).
+ std::this_thread::sleep_for(std::chrono::milliseconds(50));
+ iree_event_set(&thread_to_main);
});
- ASSERT_TRUE(IsDeadlineExceeded(
- WaitHandle::WaitAll({&wh, &always_signal}, Now() - Milliseconds(250))));
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, thread_to_main));
+ iree_wait_handle_t wake_handle;
+ IREE_ASSERT_OK(
+ iree_wait_any(wait_set, IREE_TIME_INFINITE_FUTURE, &wake_handle));
+ EXPECT_EQ(0, memcmp(&thread_to_main.value, &wake_handle.value,
+ sizeof(thread_to_main.value)));
+
+ thread.join();
+ iree_wait_set_free(wait_set);
+ iree_event_deinitialize(&thread_to_main);
+ iree_event_deinitialize(&ev_unset_0);
+ iree_event_deinitialize(&ev_unset_1);
}
-// Tests TryResolveWakeOnFd when a handle is a permanent kSignaledFd.
-TEST(WaitableObjectTest, SignaledFd) {
- MockWaitableObject mwo;
- WaitHandle wh = mwo.OnSomething();
+// Tests that an iree_wait_any followed by an iree_wait_set_erase properly
+// chooses the right handle to erase.
+TEST(WaitSet, WaitAnyErase) {
+ iree_event_t ev_unset_0, ev_unset_1;
+ iree_event_t ev_set;
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/false, &ev_unset_0));
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/false, &ev_unset_1));
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/true, &ev_set));
+ iree_wait_set_t* wait_set = NULL;
+ IREE_ASSERT_OK(
+ iree_wait_set_allocate(128, iree_allocator_system(), &wait_set));
- // Return the kSignaledFd handle and expect that we still get our notify call.
- // We can do this multiple times.
- for (int i = 0; i < 4; ++i) {
- EXPECT_CALL(mwo, AcquireFdForWait(_))
- .WillOnce(Return(std::make_pair(WaitableObject::FdType::kPermanent,
- WaitableObject::kSignaledFd)));
- EXPECT_CALL(mwo, TryResolveWakeOnFd(WaitableObject::kSignaledFd))
- .WillOnce(Return(true));
- ASSERT_STATUSOR_TRUE(wh.TryWait());
- }
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_unset_0));
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_set));
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_unset_1));
+
+ // Wait should succeed immediately because ev_set is set (and our wake handle
+ // should be ev_set).
+ iree_wait_handle_t wake_handle;
+ IREE_ASSERT_OK(
+ iree_wait_any(wait_set, IREE_TIME_INFINITE_PAST, &wake_handle));
+ EXPECT_EQ(0, memcmp(&ev_set.value, &wake_handle.value, sizeof(ev_set.value)));
+
+ // Erase the woken handle.
+ // NOTE: to get the behavior we want to test we must pass wake_handle here and
+ // not the ev_set value.
+ iree_wait_set_erase(wait_set, wake_handle);
+
+ // Try to wait again; this time we should timeout because only ev_unset_*
+ // remains in the set.
+ IREE_EXPECT_STATUS_IS(
+ IREE_STATUS_DEADLINE_EXCEEDED,
+ iree_wait_any(wait_set, IREE_TIME_INFINITE_PAST, &wake_handle));
+
+ iree_wait_set_free(wait_set);
+ iree_event_deinitialize(&ev_unset_0);
+ iree_event_deinitialize(&ev_unset_1);
+ iree_event_deinitialize(&ev_set);
}
-// Tests that waiting will not resolve if TryResolveWakeOnFd returns false.
-TEST(WaitableObjectTest, UnresolvedWake) {
- MockWaitableObject mwo;
- WaitHandle wh = mwo.OnSomething();
+// Tests that an iree_wait_any followed by an iree_wait_set_erase properly
+// chooses the right handle to erase (the tail one).
+TEST(WaitSet, WaitAnyEraseTail) {
+ iree_event_t ev_unset, ev_set;
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/false, &ev_unset));
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/true, &ev_set));
+ iree_wait_set_t* wait_set = NULL;
+ IREE_ASSERT_OK(
+ iree_wait_set_allocate(128, iree_allocator_system(), &wait_set));
- // Fail to resolve the first time.
- // Since we are only trying to wait it should bail.
- EXPECT_CALL(mwo, AcquireFdForWait(_))
- .WillOnce(Return(std::make_pair(WaitableObject::FdType::kPermanent,
- WaitableObject::kSignaledFd)));
- EXPECT_CALL(mwo, TryResolveWakeOnFd(WaitableObject::kSignaledFd))
- .WillOnce(Return(false));
- ASSERT_STATUSOR_FALSE(wh.TryWait());
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_unset));
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_set));
- // Resolve on the next try.
- EXPECT_CALL(mwo, AcquireFdForWait(_))
- .WillOnce(Return(std::make_pair(WaitableObject::FdType::kPermanent,
- WaitableObject::kSignaledFd)));
- EXPECT_CALL(mwo, TryResolveWakeOnFd(WaitableObject::kSignaledFd))
- .WillOnce(Return(true));
- ASSERT_STATUSOR_TRUE(wh.TryWait());
+ // Wait should succeed immediately because ev_set is set (and our wake handle
+ // should be ev_set).
+ iree_wait_handle_t wake_handle;
+ IREE_ASSERT_OK(
+ iree_wait_any(wait_set, IREE_TIME_INFINITE_PAST, &wake_handle));
+ EXPECT_EQ(0, memcmp(&ev_set.value, &wake_handle.value, sizeof(ev_set.value)));
+
+ // Erase the woken handle.
+ // NOTE: to get the behavior we want to test we must pass wake_handle here and
+ // not the ev_set value.
+ iree_wait_set_erase(wait_set, wake_handle);
+
+ // Try to wait again; this time we should timeout because only ev_unset
+ // remains in the set.
+ IREE_EXPECT_STATUS_IS(
+ IREE_STATUS_DEADLINE_EXCEEDED,
+ iree_wait_any(wait_set, IREE_TIME_INFINITE_PAST, &wake_handle));
+
+ iree_wait_set_free(wait_set);
+ iree_event_deinitialize(&ev_unset);
+ iree_event_deinitialize(&ev_set);
}
-// Tests the normal lifecycle of a ManualResetEvent.
-TEST(ManualResetEventTest, Lifecycle) {
- ManualResetEvent ev;
- EXPECT_FALSE(ev.DebugString().empty());
- WaitHandle wh0 = ev.OnSet();
- EXPECT_EQ(ev.DebugString(), wh0.DebugString());
- WaitHandle wh1 = ev.OnSet();
- EXPECT_EQ(ev.DebugString(), wh1.DebugString());
- // Should not be set.
- ASSERT_STATUSOR_FALSE(wh0.TryWait());
- ASSERT_STATUSOR_FALSE(wh1.TryWait());
- // Set should be sticky.
- IREE_ASSERT_OK(ev.Set());
- ASSERT_STATUSOR_TRUE(wh0.TryWait());
- ASSERT_STATUSOR_TRUE(wh1.TryWait());
- // Reset should clear.
- IREE_ASSERT_OK(ev.Reset());
- ASSERT_STATUSOR_FALSE(wh0.TryWait());
- ASSERT_STATUSOR_FALSE(wh1.TryWait());
- // Setting again should enable the previous WaitHandles to be signaled.
- IREE_ASSERT_OK(ev.Set());
- ASSERT_STATUSOR_TRUE(wh0.TryWait());
- ASSERT_STATUSOR_TRUE(wh1.TryWait());
+// Tests that an iree_wait_any followed by an iree_wait_set_erase without using
+// the wake_handle still erases the correct handle.
+TEST(WaitSet, WaitAnyEraseSplit) {
+ iree_event_t ev_unset, ev_set;
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/false, &ev_unset));
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/true, &ev_set));
+ iree_wait_set_t* wait_set = NULL;
+ IREE_ASSERT_OK(
+ iree_wait_set_allocate(128, iree_allocator_system(), &wait_set));
+
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_unset));
+ IREE_ASSERT_OK(iree_wait_set_insert(wait_set, ev_set));
+
+ // Wait should succeed immediately because ev_set is set (and our wake handle
+ // should be ev_set).
+ iree_wait_handle_t wake_handle;
+ IREE_ASSERT_OK(
+ iree_wait_any(wait_set, IREE_TIME_INFINITE_PAST, &wake_handle));
+ EXPECT_EQ(0, memcmp(&ev_set.value, &wake_handle.value, sizeof(ev_set.value)));
+
+ // Erase the woken handle *WITHOUT* using the wake_handle.
+ iree_wait_set_erase(wait_set, ev_set);
+
+ // Try to wait again; this time we should timeout because only ev_unset
+ // remains in the set.
+ IREE_EXPECT_STATUS_IS(
+ IREE_STATUS_DEADLINE_EXCEEDED,
+ iree_wait_any(wait_set, IREE_TIME_INFINITE_PAST, &wake_handle));
+
+ iree_wait_set_free(wait_set);
+ iree_event_deinitialize(&ev_unset);
+ iree_event_deinitialize(&ev_set);
}
-// Tests moving ManualResetEvents around.
-TEST(ManualResetEventTest, Move) {
- ManualResetEvent ev0;
- WaitHandle wh = ev0.OnSet();
- ManualResetEvent ev1{std::move(ev0)};
- ManualResetEvent ev2 = std::move(ev1);
- ev1 = std::move(ev2);
- IREE_ASSERT_OK(ev1.Set());
- ASSERT_STATUSOR_TRUE(wh.TryWait());
+// Tests iree_wait_one when polling (deadline_ns = IREE_TIME_INFINITE_PAST).
+TEST(WaitSet, WaitOnePolling) {
+ iree_event_t ev_unset, ev_set;
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/false, &ev_unset));
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/true, &ev_set));
+
+ // Polling (don't block even if unset).
+ IREE_EXPECT_STATUS_IS(IREE_STATUS_DEADLINE_EXCEEDED,
+ iree_wait_one(&ev_unset, IREE_TIME_INFINITE_PAST));
+ IREE_ASSERT_OK(iree_wait_one(&ev_set, IREE_TIME_INFINITE_PAST));
+
+ iree_event_deinitialize(&ev_unset);
+ iree_event_deinitialize(&ev_set);
}
-// Tests redundantly setting and resetting ManualResetEvents.
-TEST(ManualResetEventTest, RedundantUse) {
- ManualResetEvent ev;
- IREE_ASSERT_OK(ev.Reset());
- IREE_ASSERT_OK(ev.Reset());
- ASSERT_FALSE(ev.OnSet().TryWait().value());
- IREE_ASSERT_OK(ev.Set());
- IREE_ASSERT_OK(ev.Set());
- ASSERT_TRUE(ev.OnSet().TryWait().value());
- IREE_ASSERT_OK(ev.Reset());
- ASSERT_FALSE(ev.OnSet().TryWait().value());
+// Tests iree_wait_one with timeouts (deadline_ns = non-zero).
+TEST(WaitSet, WaitOneTimeout) {
+ iree_event_t ev_unset, ev_set;
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/false, &ev_unset));
+ IREE_ASSERT_OK(iree_event_initialize(/*initial_state=*/true, &ev_set));
+
+ // Force a timeout by waiting on an event that'll never get set.
+ IREE_EXPECT_STATUS_IS(
+ IREE_STATUS_DEADLINE_EXCEEDED,
+ iree_wait_one(&ev_unset, iree_time_now() + kShortTimeoutNS));
+
+ // Ensure we return immediately when waiting on a set value (and not wait
+ // 100 years because we messed up our math).
+ IREE_ASSERT_OK(iree_wait_one(&ev_set, iree_time_now() + kLongTimeoutNS));
+
+ iree_event_deinitialize(&ev_unset);
+ iree_event_deinitialize(&ev_set);
}
-// Tests waiting on an initially-set ManualResetEvent;
-TEST(ManualResetEventTest, SetThenWait) {
- ManualResetEvent ev;
- IREE_ASSERT_OK(ev.Set());
- ASSERT_TRUE(ev.OnSet().TryWait().value());
-}
+// Tests iree_wait_one when blocking (deadline_ns = IREE_TIME_INFINITE_FUTURE).
+TEST(WaitSet, WaitOneBlocking) {
+ iree_event_t thread_to_main;
+ IREE_ASSERT_OK(
+ iree_event_initialize(/*initial_state=*/false, &thread_to_main));
-// Tests that dangling an event will not wake waiters.
-// This is intentional (for now); we could with a bit of wrangling make it so
-// that WaitableObjects tracked their waiters and ensured they were all cleaned
-// up, but that seems hard. Don't drop your objects.
-TEST(ManualResetEventTest, NeverSet) {
- ManualResetEvent ev;
- WaitHandle wh = ev.OnSet();
- ASSERT_STATUSOR_FALSE(wh.TryWait());
- // Kill event to unblock waiters.
- ev = ManualResetEvent();
- // Waiter should not have woken.
- ASSERT_STATUSOR_FALSE(wh.TryWait());
+ // Wait forever (no timeout).
+ // We approximate that by forking off a thread to signal our local event. We
+ // can assume that a moderate wait is enough to verify the forever behavior as
+ // otherwise we are probably just messing up the math and will timeout.
+ std::thread thread([&]() {
+ // Notify the caller thread after sleeping (to ensure it's not polling).
+ std::this_thread::sleep_for(std::chrono::milliseconds(50));
+ iree_event_set(&thread_to_main);
+ });
+ IREE_ASSERT_OK(iree_wait_one(&thread_to_main, IREE_TIME_INFINITE_FUTURE));
+
+ thread.join();
+ iree_event_deinitialize(&thread_to_main);
}
} // namespace
diff --git a/iree/base/wait_handle_win32.c b/iree/base/wait_handle_win32.c
new file mode 100644
index 0000000..cf91401
--- /dev/null
+++ b/iree/base/wait_handle_win32.c
@@ -0,0 +1,449 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// NOTE: must be first to ensure that we can define settings for all includes.
+#include "iree/base/wait_handle_impl.h"
+
+#if defined(IREE_PLATFORM_WINDOWS)
+
+#include "iree/base/tracing.h"
+
+//===----------------------------------------------------------------------===//
+// Platform utilities
+//===----------------------------------------------------------------------===//
+
+static_assert(
+ sizeof(iree_wait_primitive_value_t) == sizeof(HANDLE),
+ "win32 HANDLE type must match uintptr size in wait primitive struct");
+
+//===----------------------------------------------------------------------===//
+// iree_wait_primitive_* raw calls
+//===----------------------------------------------------------------------===//
+
+// Duplicates |source_handle| so that |out_target_handle| refers to the same
+// underlying wait primitive through a distinct win32 HANDLE. The duplicate
+// must be released with iree_wait_primitive_close as if it had been created.
+// Returns INVALID_ARGUMENT if |source_handle| is not a win32 HANDLE.
+static iree_status_t iree_wait_primitive_clone(
+    iree_wait_handle_t* source_handle, iree_wait_handle_t* out_target_handle) {
+  const bool is_win32_handle =
+      source_handle->type == IREE_WAIT_PRIMITIVE_TYPE_WIN32_HANDLE;
+  if (!is_win32_handle) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "source wait handle must be a win32 HANDLE");
+  }
+
+  // Both the source and the target of the duplication are this process.
+  HANDLE current_process = GetCurrentProcess();
+  HANDLE duplicate = NULL;
+  BOOL did_duplicate = DuplicateHandle(
+      current_process, (HANDLE)source_handle->value.win32.handle,
+      current_process, &duplicate, 0, FALSE, DUPLICATE_SAME_ACCESS);
+  if (!did_duplicate) {
+    return iree_make_status(
+        iree_status_code_from_win32_error(GetLastError()),
+        "unable to duplicate HANDLE; possibly out of process handles");
+  }
+
+  iree_wait_primitive_value_t cloned_value;
+  memset(&cloned_value, 0, sizeof(cloned_value));
+  cloned_value.win32.handle = (uintptr_t)duplicate;
+  return iree_wait_handle_wrap_primitive(IREE_WAIT_PRIMITIVE_TYPE_WIN32_HANDLE,
+                                         cloned_value, out_target_handle);
+}
+
+// Releases a handle that was created manually or produced by
+// iree_wait_primitive_clone and then deinitializes the wrapper. Callers must
+// ensure there are no waiters on the handle when this is called.
+static void iree_wait_primitive_close(iree_wait_handle_t* handle) {
+  HANDLE native_handle = (HANDLE)handle->value.win32.handle;
+  // A zero value means there is no native HANDLE to close.
+  if (IREE_LIKELY(native_handle != NULL)) CloseHandle(native_handle);
+  iree_wait_handle_deinitialize(handle);
+}
+
+// Returns true when |lhs| and |rhs| refer to the same underlying primitive
+// object, even if the handle values themselves differ.
+static bool iree_wait_primitive_compare(const iree_wait_handle_t* lhs,
+                                        const iree_wait_handle_t* rhs) {
+  // Handles of different types never match.
+  if (lhs->type != rhs->type) return false;
+
+  // Bitwise-identical values match for every primitive type.
+  if (memcmp(&lhs->value, &rhs->value, sizeof(lhs->value)) == 0) return true;
+
+  if (lhs->type == IREE_WAIT_PRIMITIVE_TYPE_WIN32_HANDLE) {
+    // Distinct HANDLE values may still point at the same underlying object
+    // (such as if they have been cloned) so defer to the OS.
+    return CompareObjectHandles((HANDLE)lhs->value.win32.handle,
+                                (HANDLE)rhs->value.win32.handle)
+               ? true
+               : false;
+  }
+  return false;
+}
+
+// Returns true when |lhs| and |rhs| have the same type and a bitwise-identical
+// value. Two handles that differ in representation may still reference the
+// same underlying primitive object (such as when they have been cloned); use
+// iree_wait_primitive_compare to detect that case.
+static bool iree_wait_primitive_compare_identical(
+    const iree_wait_handle_t* lhs, const iree_wait_handle_t* rhs) {
+  if (lhs->type != rhs->type) return false;
+  const int value_diff = memcmp(&lhs->value, &rhs->value, sizeof(lhs->value));
+  return value_diff == 0;
+}
+
+//===----------------------------------------------------------------------===//
+// iree_wait_set_t
+//===----------------------------------------------------------------------===//
+
+struct iree_wait_set_s {
+ iree_allocator_t allocator;  // used by iree_wait_set_free to free the set
+
+ // Total capacity of handles in the set (including duplicates).
+ // This sizes both user_handles and native_handles. Inserts are checked
+ // against it using the total count including duplicates so that capacity
+ // errors do not depend on whether some of the inserted handles happen to be
+ // duplicates (which would be insanely hard to debug).
+ //
+ // If you added 1000 duplicate handles to the set you'd need a handle_capacity
+ // of 1000 even though handle_count (excluding duplicates) would be 1.
+ iree_host_size_t handle_capacity;
+
+ // Total number of handles in the set (including duplicates).
+ // We use this to ensure that we provide consistent capacity errors.
+ iree_host_size_t total_handle_count;
+
+ // Number of handles in the set (excluding duplicates), defining the valid
+ // size of both user_handles and native_handles.
+ iree_host_size_t handle_count;
+
+ // De-duped user-provided handles. iree_wait_handle_t::set_internal.dupe_count
+ // is used to indicate how many additional duplicates there are of a
+ // particular handle. For example, dupe_count=0 means that there are no
+ // duplicates.
+ iree_wait_handle_t* user_handles;
+
+ // Native list of win32 HANDLE we will pass directly to WFMO.
+ // This list may be smaller than the total_handle_count if handles have been
+ // deduplicated.
+ HANDLE* native_handles;
+};
+
+// Allocates an empty wait set able to hold |capacity| handles (counting
+// duplicates) from |allocator| and returns it in |out_set|. The set header,
+// the user handle list, and the native HANDLE list share one allocation.
+// Returns INVALID_ARGUMENT if |capacity| is UINT16_MAX or larger.
+iree_status_t iree_wait_set_allocate(iree_host_size_t capacity,
+                                     iree_allocator_t allocator,
+                                     iree_wait_set_t** out_set) {
+  // Be reasonable; 64 MAXIMUM_WAIT_OBJECTS is low, but 64K objects is too high.
+  if (capacity >= UINT16_MAX) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "wait set capacity of %zu is unreasonably large",
+                            capacity);
+  }
+
+  // Layout: [iree_wait_set_t][user handles x capacity][HANDLE x capacity].
+  const iree_host_size_t user_list_bytes =
+      capacity * sizeof(iree_wait_handle_t);
+  const iree_host_size_t native_list_bytes = capacity * sizeof(HANDLE);
+
+  iree_wait_set_t* set = NULL;
+  IREE_RETURN_IF_ERROR(iree_allocator_malloc(
+      allocator,
+      sizeof(iree_wait_set_t) + user_list_bytes + native_list_bytes,
+      (void**)&set));
+
+  uint8_t* user_list_base = (uint8_t*)set + sizeof(iree_wait_set_t);
+  set->user_handles = (iree_wait_handle_t*)user_list_base;
+  set->native_handles = (HANDLE*)(user_list_base + user_list_bytes);
+  set->allocator = allocator;
+  set->handle_capacity = capacity;
+  iree_wait_set_clear(set);  // zeroes the handle counts
+
+  *out_set = set;
+  return iree_ok_status();
+}
+
+// Frees |set| using the allocator it was allocated with. The handle lists live
+// in the same allocation as the set so no other frees are required.
+void iree_wait_set_free(iree_wait_set_t* set) {
+  iree_allocator_t allocator = set->allocator;
+  iree_allocator_free(allocator, set);
+}
+
+// Inserts |handle| into |set|. A handle that is bitwise identical to one
+// already present only bumps that entry's duplicate count, as APIs like WFMO
+// don't allow duplicate handles in their arguments. Returns
+// RESOURCE_EXHAUSTED when the set capacity or the native WFMO handle limit
+// would be exceeded and UNIMPLEMENTED for non-win32 handle types.
+iree_status_t iree_wait_set_insert(iree_wait_set_t* set,
+                                   iree_wait_handle_t handle) {
+  // Capacity is tracked against the total count including duplicates.
+  if (set->total_handle_count >= set->handle_capacity) {
+    return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
+                            "wait set capacity reached");
+  }
+
+  // Fold duplicates into the existing entry; besides working around the WFMO
+  // restriction this also reduces the native handle count.
+  iree_wait_handle_t* const user_handles_end =
+      set->user_handles + set->handle_count;
+  for (iree_wait_handle_t* it = set->user_handles; it != user_handles_end;
+       ++it) {
+    if (!iree_wait_primitive_compare_identical(it, &handle)) continue;
+    ++it->set_internal.dupe_count;
+    ++set->total_handle_count;
+    return iree_ok_status();
+  }
+
+  // Only win32 HANDLEs can be passed through to the native wait.
+  if (IREE_UNLIKELY(handle.type != IREE_WAIT_PRIMITIVE_TYPE_WIN32_HANDLE)) {
+    return iree_make_status(
+        IREE_STATUS_UNIMPLEMENTED,
+        "unimplemented primitive type %d (expected PERMANENT/WIN32_HANDLE)",
+        (int)handle.type);
+  }
+  HANDLE native_handle = (HANDLE)handle.value.win32.handle;
+
+  // WFMO accepts at most MAXIMUM_WAIT_OBJECTS (64) handles. Supporting more
+  // would require spawning threads that each wait on up to 64 handles and then
+  // waiting on those threads plus the remainder. For example:
+  //   iree_wait_multi(...180 handles...):
+  //     -> spawn th0 and wait on handles 0-63 (64 handles)
+  //     -> spawn th1 and wait on handles 64-127 (64 handles)
+  //     wait on [th0, th1, handles 128-179] (threads + 52 remaining handles)
+  //
+  // Waiting on that many things usually indicates that the application could
+  // coalesce at a higher level instead (by, say, multiplexing sockets onto a
+  // single fd rather than waiting on every unique socket handle via this API).
+  if (native_handle != NULL &&
+      IREE_UNLIKELY(set->handle_count >= MAXIMUM_WAIT_OBJECTS)) {
+    return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
+                            "max wait objects exceeded; only up to %d native "
+                            "wait handles are supported in WFMO",
+                            (int)MAXIMUM_WAIT_OBJECTS);
+  }
+
+  // Append to the tail of both lists; they are kept index-aligned.
+  const iree_host_size_t slot = set->handle_count;
+  iree_wait_handle_t* stored_handle = &set->user_handles[slot];
+  IREE_IGNORE_ERROR(iree_wait_handle_wrap_primitive(handle.type, handle.value,
+                                                    stored_handle));
+  stored_handle->set_internal.dupe_count = 0;  // just us so far
+  set->native_handles[slot] = native_handle;
+  set->handle_count = slot + 1;
+  ++set->total_handle_count;
+
+  return iree_ok_status();
+}
+
+// Removes one instance of |handle| from |set|.
+//
+// If the handle was inserted multiple times only its duplicate count is
+// decremented; the entry is removed from the handle lists once the last
+// instance is erased. Erasing a handle that is not in the set is a no-op.
+//
+// |handle| may be the wake handle returned by iree_wait_any, in which case the
+// index stored in it is used to skip the linear scan.
+void iree_wait_set_erase(iree_wait_set_t* set, iree_wait_handle_t handle) {
+  // Find the user handle in the set. This either requires a linear scan to
+  // find the matching user handle or - if valid - we can use the native index
+  // set after an iree_wait_any wake to do a quick lookup. The hint is only
+  // trusted if it is in range and the entry at that index still matches.
+  iree_host_size_t index = handle.set_internal.index;
+  bool found = index < set->handle_count &&
+               iree_wait_primitive_compare_identical(&set->user_handles[index],
+                                                     &handle);
+  if (IREE_UNLIKELY(!found)) {
+    // Fallback to a linear scan of (hopefully) a small list.
+    for (iree_host_size_t i = 0; i < set->handle_count; ++i) {
+      if (iree_wait_primitive_compare_identical(&set->user_handles[i],
+                                                &handle)) {
+        index = i;
+        found = true;
+        break;
+      }
+    }
+  }
+
+  // The handle is not in the set: bail before touching the lists. Without
+  // this the stale index would be used to read and write out of bounds and
+  // the handle counts would underflow.
+  if (IREE_UNLIKELY(!found)) return;
+
+  // Decrement reference count.
+  iree_wait_handle_t* existing_handle = &set->user_handles[index];
+  if (existing_handle->set_internal.dupe_count > 0) {
+    // Still one or more remaining in the set; leave it in the handle list.
+    --existing_handle->set_internal.dupe_count;
+    --set->total_handle_count;
+    return;
+  }
+
+  // No more references remaining; remove from both handle lists.
+  // Since we make no guarantees about the order of the lists we can just swap
+  // with the last value. handle_count is >= 1 here because a match was found.
+  iree_host_size_t tail_index = set->handle_count - 1;
+  if (tail_index > index) {
+    memcpy(&set->native_handles[index], &set->native_handles[tail_index],
+           sizeof(*set->native_handles));
+    memcpy(&set->user_handles[index], &set->user_handles[tail_index],
+           sizeof(*set->user_handles));
+  }
+  --set->total_handle_count;
+  --set->handle_count;
+}
+
+// Resets |set| to empty by zeroing both handle counts; the handle lists are
+// left as-is since only the first handle_count entries are considered valid.
+void iree_wait_set_clear(iree_wait_set_t* set) {
+  set->handle_count = 0;
+  set->total_handle_count = 0;
+}
+
+// Waits on the handles in |set| until one (|require_all| == false) or all
+// (|require_all| == true) of them are signaled or |deadline_ns| is reached.
+//
+// Returns OK immediately if the set is empty (zeroing |out_wake_handle| if
+// provided), DEADLINE_EXCEEDED if the wait timed out, and DATA_LOSS if a mutex
+// handle was abandoned. On a successful wake |out_wake_handle|, if provided,
+// receives the signaled handle along with its index in the set so that a
+// following iree_wait_set_erase can avoid scanning.
+static iree_status_t iree_wait_multi(iree_wait_set_t* set, bool require_all,
+                                     iree_time_t deadline_ns,
+                                     iree_wait_handle_t* out_wake_handle) {
+  // TODO(benvanik): see if we can use tracy's mutex tracking to make waits
+  // nicer (at least showing signal->wait relations).
+
+  // Early-exit when there's nothing to wait on.
+  if (set->handle_count == 0) {
+    if (out_wake_handle) memset(out_wake_handle, 0, sizeof(*out_wake_handle));
+    return iree_ok_status();
+  }
+
+  // Remap absolute timeout to relative timeout in milliseconds.
+  // An infinite deadline must map to INFINITE explicitly and finite timeouts
+  // must be clamped: the millisecond count is 64-bit and truncating it to a
+  // DWORD would otherwise wrap into an arbitrary (possibly tiny) timeout.
+  DWORD timeout_ms = 0;
+  if (deadline_ns == IREE_TIME_INFINITE_FUTURE) {
+    timeout_ms = INFINITE;
+  } else {
+    uint64_t timeout_ms_wide =
+        (uint64_t)iree_absolute_deadline_to_timeout_ns(deadline_ns) /
+        1000000ull;
+    // INFINITE itself is reserved to mean 'wait forever'.
+    timeout_ms = timeout_ms_wide >= (uint64_t)INFINITE
+                     ? (DWORD)(INFINITE - 1)
+                     : (DWORD)timeout_ms_wide;
+  }
+
+  // Perform the wait; this is allowed to yield the calling thread even if the
+  // timeout_ms is 0 to indicate a poll.
+  DWORD result =
+      WaitForMultipleObjectsEx((DWORD)set->handle_count, set->native_handles,
+                               /*bWaitAll=*/(require_all ? TRUE : FALSE),
+                               timeout_ms, /*bAlertable=*/FALSE);
+
+  if (result == WAIT_TIMEOUT) {
+    // Timeout elapsed while waiting; note that the timeout may have been 0 to
+    // force a poll and be an expected result. We avoid a full status object
+    // here as we don't want to track all that in non-exceptional cases.
+    return iree_status_from_code(IREE_STATUS_DEADLINE_EXCEEDED);
+  } else if (result >= WAIT_OBJECT_0 &&
+             result < WAIT_OBJECT_0 + set->handle_count) {
+    // One (or more) handles were signaled successfully.
+    if (out_wake_handle) {
+      DWORD wake_index = result - WAIT_OBJECT_0;
+      iree_wait_primitive_value_t wake_value;
+      memset(&wake_value, 0, sizeof(wake_value));
+      wake_value.win32.handle = (uintptr_t)set->native_handles[wake_index];
+      IREE_IGNORE_ERROR(iree_wait_handle_wrap_primitive(
+          IREE_WAIT_PRIMITIVE_TYPE_WIN32_HANDLE, wake_value, out_wake_handle));
+
+      // Optimization for wait-wake-erase; this lets us avoid scanning the
+      // native handle list (the kernel already did that for us!).
+      out_wake_handle->set_internal.index = wake_index;
+    }
+    return iree_ok_status();
+  } else if (result >= WAIT_ABANDONED_0 &&
+             result < WAIT_ABANDONED_0 + set->handle_count) {
+    // One (or more) mutex handles were abandoned during the wait.
+    // This happens when a thread holding the mutex dies without releasing it.
+    // This is less common in-process and more for the cross-process situations
+    // where we have duped/opened a remote handle and the remote process dies.
+    // That's a pretty rare situation but not quite unheard of in sandboxing
+    // impls where death is a feature.
+    //
+    // NOTE: we shouldn't get abandoned handles in regular cases - both because
+    // we don't really use mutex handles (though users may provide them) and
+    // that mutex abandonment is exceptional. If you see this you are probably
+    // going to want to look for thread exit messages or zombie processes.
+    DWORD wake_index = result - WAIT_ABANDONED_0;
+    return iree_make_status(
+        IREE_STATUS_DATA_LOSS,
+        "mutex native handle %lu abanonded; shared state is "
+        "(likely) inconsistent",
+        wake_index);
+  } else if (result == WAIT_FAILED) {
+    return iree_make_status(iree_status_code_from_win32_error(GetLastError()),
+                            "WFMO failed");
+  } else {
+    return iree_make_status(IREE_STATUS_INTERNAL,
+                            "WFMO internal error (unimplemented APC?)");
+  }
+}
+
+// Waits until every handle in |set| is signaled or |deadline_ns| is reached.
+// No wake handle is reported; the wait itself is traced as a zone.
+iree_status_t iree_wait_all(iree_wait_set_t* set, iree_time_t deadline_ns) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+  iree_status_t wait_status = iree_wait_multi(
+      set, /*require_all=*/true, deadline_ns, /*out_wake_handle=*/NULL);
+  IREE_TRACE_ZONE_END(z0);
+  return wait_status;
+}
+
+// Waits until at least one handle in |set| is signaled or |deadline_ns| is
+// reached. On success the signaled handle is returned in |out_wake_handle|
+// (when provided); the wait itself is traced as a zone.
+iree_status_t iree_wait_any(iree_wait_set_t* set, iree_time_t deadline_ns,
+                            iree_wait_handle_t* out_wake_handle) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+  const bool require_all = false;
+  iree_status_t wait_status =
+      iree_wait_multi(set, require_all, deadline_ns, out_wake_handle);
+  IREE_TRACE_ZONE_END(z0);
+  return wait_status;
+}
+
+// Waits on a single |handle| until it is signaled or |deadline_ns| is reached.
+//
+// Returns OK if the handle was signaled, DEADLINE_EXCEEDED if the wait timed
+// out (including when polling an unsignaled handle), and DATA_LOSS if the
+// handle is a mutex that was abandoned by its owner.
+iree_status_t iree_wait_one(iree_wait_handle_t* handle,
+                            iree_time_t deadline_ns) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  // Remap absolute timeout to relative timeout in milliseconds.
+  // An infinite deadline must map to INFINITE explicitly and finite timeouts
+  // must be clamped: the millisecond count is 64-bit and truncating it to a
+  // DWORD would otherwise wrap into an arbitrary (possibly tiny) timeout.
+  DWORD timeout_ms = 0;
+  if (deadline_ns == IREE_TIME_INFINITE_FUTURE) {
+    timeout_ms = INFINITE;
+  } else {
+    uint64_t timeout_ms_wide =
+        (uint64_t)iree_absolute_deadline_to_timeout_ns(deadline_ns) /
+        1000000ull;
+    // INFINITE itself is reserved to mean 'wait forever'.
+    timeout_ms = timeout_ms_wide >= (uint64_t)INFINITE
+                     ? (DWORD)(INFINITE - 1)
+                     : (DWORD)timeout_ms_wide;
+  }
+
+  // Perform the wait; this is allowed to yield the calling thread even if the
+  // timeout_ms is 0 to indicate a poll.
+  DWORD result =
+      WaitForSingleObjectEx((HANDLE)handle->value.win32.handle, timeout_ms,
+                            /*bAlertable=*/FALSE);
+
+  iree_status_t status;
+  if (result == WAIT_TIMEOUT) {
+    // Timeout elapsed while waiting; note that the timeout may have been 0 to
+    // force a poll and be an expected result. We avoid a full status object
+    // here as we don't want to track all that in non-exceptional cases.
+    status = iree_status_from_code(IREE_STATUS_DEADLINE_EXCEEDED);
+  } else if (result == WAIT_OBJECT_0) {
+    // Handle was signaled successfully.
+    status = iree_ok_status();
+  } else if (result == WAIT_ABANDONED_0) {
+    // The mutex handle was abandoned during the wait.
+    // This happens when a thread holding the mutex dies without releasing it.
+    // This is less common in-process and more for the cross-process situations
+    // where we have duped/opened a remote handle and the remote process dies.
+    // That's a pretty rare situation but not quite unheard of in sandboxing
+    // impls where death is a feature.
+    //
+    // NOTE: we shouldn't get abandoned handles in regular cases - both because
+    // we don't really use mutex handles (though users may provide them) and
+    // that mutex abandonment is exceptional. If you see this you are probably
+    // going to want to look for thread exit messages or zombie processes.
+    status = iree_make_status(IREE_STATUS_DATA_LOSS,
+                              "mutex native handle abanonded; shared state is "
+                              "(likely) inconsistent");
+  } else if (result == WAIT_FAILED) {
+    status = iree_make_status(iree_status_code_from_win32_error(GetLastError()),
+                              "WFSO failed");
+  } else {
+    status = iree_make_status(IREE_STATUS_INTERNAL,
+                              "WFSO internal error (unimplemented APC?)");
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+//===----------------------------------------------------------------------===//
+// iree_event_t
+//===----------------------------------------------------------------------===//
+
+// Creates an unnamed manual-reset win32 event in the given |initial_state|
+// and wraps it in |out_event|. Returns a status derived from GetLastError if
+// the event cannot be created.
+iree_status_t iree_event_initialize(bool initial_state,
+                                    iree_event_t* out_event) {
+  HANDLE native_event =
+      CreateEvent(/*lpEventAttributes=*/NULL, /*bManualReset=*/TRUE,
+                  /*bInitialState=*/initial_state ? TRUE : FALSE,
+                  /*lpName=*/NULL);
+  if (native_event == NULL) {
+    return iree_make_status(iree_status_code_from_win32_error(GetLastError()),
+                            "unable to create event");
+  }
+
+  iree_wait_primitive_value_t event_value;
+  memset(&event_value, 0, sizeof(event_value));
+  event_value.win32.handle = (uintptr_t)native_event;
+  return iree_wait_handle_wrap_primitive(IREE_WAIT_PRIMITIVE_TYPE_WIN32_HANDLE,
+                                         event_value, out_event);
+}
+
+// Closes the native HANDLE backing |event| (if any) and deinitializes the
+// wait handle wrapper.
+void iree_event_deinitialize(iree_event_t* event) {
+  iree_wait_primitive_close(event);
+}
+
+// Signals |event| via SetEvent; the result of the call is not checked.
+void iree_event_set(iree_event_t* event) {
+  HANDLE native_event = (HANDLE)event->value.win32.handle;
+  SetEvent(native_event);
+}
+
+// Returns |event| to the unsignaled state via ResetEvent; the result of the
+// call is not checked.
+void iree_event_reset(iree_event_t* event) {
+  HANDLE native_event = (HANDLE)event->value.win32.handle;
+  ResetEvent(native_event);
+}
+
+#endif // IREE_PLATFORM_WINDOWS
diff --git a/iree/build_defs.oss.bzl b/iree/build_defs.oss.bzl
index 4d359e7..33fa984 100644
--- a/iree/build_defs.oss.bzl
+++ b/iree/build_defs.oss.bzl
@@ -14,9 +14,6 @@
"""Common Bazel definitions for IREE."""
-load("@com_github_google_flatbuffers//:build_defs.bzl", "flatbuffer_cc_library")
-load("@rules_cc//cc:defs.bzl", _cc_binary = "cc_binary", _cc_library = "cc_library")
-
# Target to the FileCheck binary.
INTREE_FILECHECK_TARGET = "@llvm-project//llvm:FileCheck"
@@ -40,27 +37,6 @@
"//iree/%s/internal:%s_internal" % (path, basename),
]
-# A platform-sensitive list of dependencies for non-test targets using Vulkan.
-PLATFORM_VULKAN_DEPS = select({
- "//iree/hal/vulkan:native_vk": [],
- "//iree/hal/vulkan:swiftshader_vk": [],
- "//conditions:default": [],
-})
-
-# A platform-sensitive list of dependencies for tests using Vulkan.
-PLATFORM_VULKAN_TEST_DEPS = []
-
-# Driver modules that register themselves at link time.
-IREE_DRIVER_MODULES = [
- "//iree/hal/dylib:dylib_driver_module",
- "//iree/hal/vmla:vmla_driver_module",
- "//iree/hal/vulkan:vulkan_driver_module",
- "//iree/hal/llvmjit:llvmjit_driver_module",
-]
-
-# Aliases to the Starlark cc rules.
-cc_library = _cc_library
-
def iree_build_test(name, targets):
"""Dummy rule to ensure that targets build.
@@ -68,34 +44,6 @@
"""
pass
-# The OSS build currently has issues with generating flatbuffer reflections.
-# It is hard-coded to disabled here (and in iree_flatbuffer_cc_library) until triaged/fixed.
-FLATBUFFER_SUPPORTS_REFLECTIONS = False
-
-def iree_flatbuffer_cc_library(**kwargs):
- """Wrapper for the flatbuffer_cc_library."""
-
- # TODO(laurenzo): The bazel rule for reflections seems broken in OSS
- # builds. Fix it and enable by default.
- flatbuffer_cc_library(
- gen_reflections = False,
- **kwargs
- )
-
-def cc_binary(linkopts = [], **kwargs):
- """Wrapper around low-level cc_binary that adds flags."""
- _cc_binary(
- linkopts = linkopts + select({
- "//iree:iree_is_msvc": [],
- "//conditions:default": [
- # Just include libraries that should be presumed in 2020.
- "-ldl",
- "-lpthread",
- ],
- }),
- **kwargs
- )
-
def iree_cmake_extra_content(content = "", inline = False):
"""Tool for inserting arbitrary content during Bazel->CMake conversion.
diff --git a/iree/compiler/BUILD b/iree/compiler/BUILD
index 01d988d..3d97d17 100644
--- a/iree/compiler/BUILD
+++ b/iree/compiler/BUILD
@@ -17,3 +17,9 @@
features = ["layering_check"],
licenses = ["notice"], # Apache 2.0
)
+
+#===------------------------------------------------------------------------===#
+# Public API
+#===------------------------------------------------------------------------===#
+# TODO(#3817): expose :compiler as the public C API.
+# TODO(#3817): expose :cc as the public C++ wrapper API.
diff --git a/iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.cpp b/iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.cpp
index 448d86e..7e888ed 100644
--- a/iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.cpp
+++ b/iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.cpp
@@ -11,12 +11,14 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
+
+#include "iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.h"
+
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
diff --git a/iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.h b/iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.h
index ddda183..0e0f638 100644
--- a/iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.h
+++ b/iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.h
@@ -17,8 +17,9 @@
#include <memory>
+#include "mlir/Pass/Pass.h"
+
namespace mlir {
-class FunctionPass;
namespace iree_compiler {
/// An ad-hoc pass to canonicalize selected loop carried dependencies on
diff --git a/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.cpp b/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.cpp
index 64b0f10..452b662 100644
--- a/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.cpp
+++ b/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.cpp
@@ -12,18 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
-#include "iree/compiler/Dialect/HAL/Utils/TypeUtils.h"
+#include "iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h"
+
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/SliceAnalysis.h"
-#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
-#include "mlir/IR/PatternMatch.h"
#define DEBUG_TYPE "workgroup-calculation"
@@ -69,8 +66,10 @@
// Clone the linalg operation just to compute the loop bounds.
linalg::LinalgOp clonedLinalgOp =
rewriter.clone(*linalgOp.getOperation(), mapper);
- Optional<SmallVector<Value, 4>> bounds =
- getLoopRanges(rewriter, clonedLinalgOp);
+ SmallVector<Range, 4> ranges = clonedLinalgOp.createLoopRanges(rewriter, loc);
+ SmallVector<Value, 4> bounds;
+ bounds.reserve(ranges.size());
+ for (Range r : ranges) bounds.push_back(r.size);
unsigned numParallelLoops = linalgOp.iterator_types()
.getValue()
.take_while([](Attribute attr) -> bool {
@@ -78,8 +77,8 @@
getParallelIteratorTypeName();
})
.size();
- SmallVector<Value, 2> returnVals(
- bounds->begin(), std::next(bounds->begin(), numParallelLoops));
+ SmallVector<Value, 2> returnVals(bounds.begin(),
+ std::next(bounds.begin(), numParallelLoops));
rewriter.eraseOp(clonedLinalgOp);
return returnVals;
}
diff --git a/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h b/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h
index 0ce710f..4e28779 100644
--- a/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h
+++ b/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h
@@ -12,38 +12,23 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#ifndef MLIR_EDGE_BENCHMARKS_STRATEGIES_WORKGROUPCALULCATION_H_
-#define MLIR_EDGE_BENCHMARKS_STRATEGIES_WORKGROUPCALULCATION_H_
+#ifndef IREE_COMPILER_CONVERSION_CODEGENUTILS_GETNUMWORKGROUPS_H_
+#define IREE_COMPILER_CONVERSION_CODEGENUTILS_GETNUMWORKGROUPS_H_
+#include <array>
#include <cstdint>
-namespace llvm {
-class StringRef;
-template <typename T>
-class ArrayRef;
-template <typename T>
-class Optional;
-} // namespace llvm
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/HAL/Utils/TypeUtils.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/StringRef.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
namespace mlir {
-class Location;
-class FuncOp;
-class LogicalResult;
-class PatternRewriter;
-class ConversionPatternRewriter;
-class Value;
-namespace linalg {
-class LinalgOp;
-} // namespace linalg
-
-namespace iree_compiler {
-namespace IREE {
-namespace HAL {
-class InterfaceOp;
-class TensorRewriteAdaptor;
-} // namespace HAL
-} // namespace IREE
-} // namespace iree_compiler
namespace iree_compiler {
/// Generates a function that computes the number of workgroups as
@@ -94,4 +79,4 @@
} // namespace iree_compiler
} // namespace mlir
-#endif // MLIR_EDGE_BENCHMARKS_STRATEGIES_WORKGROUPCALULCATION_H_
+#endif // IREE_COMPILER_CONVERSION_CODEGENUTILS_GETNUMWORKGROUPS_H_
diff --git a/iree/compiler/Conversion/CodegenUtils/MarkerUtils.h b/iree/compiler/Conversion/CodegenUtils/MarkerUtils.h
index b185ae7..5072293 100644
--- a/iree/compiler/Conversion/CodegenUtils/MarkerUtils.h
+++ b/iree/compiler/Conversion/CodegenUtils/MarkerUtils.h
@@ -23,11 +23,10 @@
#define IREE_COMPILER_CONVERSION_CODEGENUTILS_MARKERUTILS_H_
#include "llvm/ADT/ArrayRef.h"
+#include "mlir/IR/Operation.h"
#include "mlir/Support/LLVM.h"
namespace mlir {
-
-class Operation;
namespace iree_compiler {
/// Marker to denote that a linalg operation has been partitioned to
diff --git a/iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.cpp b/iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.cpp
index 654e421..43bbf3e 100644
--- a/iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.cpp
+++ b/iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.cpp
@@ -13,10 +13,8 @@
// limitations under the License.
// -----------------------------------------------------------------------------
-// This is a copy of the matmul strategy infrastructure existing in mlir_edge.
-// This version will be removed once this gets upstreamed to common mlir.
-// Please try to limit changes in this code only minor changes or make sure the
-// changes are applied in mlir_edge as well.
+// This code will be removed once this gets upstreamed to common mlir.
+// Please try to limit changes in this code to minor changes only.
#include "iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h"
diff --git a/iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h b/iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h
index 6048bea..7c8536c 100644
--- a/iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h
+++ b/iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h
@@ -24,12 +24,11 @@
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
+#include "mlir/IR/Function.h"
#include "mlir/Support/LLVM.h"
namespace mlir {
-class FuncOp;
-
/// Abstract Transformation class applied in a sequence that also handles state
/// through markers.
struct Transformation {
diff --git a/iree/compiler/Conversion/HLOToHLO/DecomposeHLOClamp.cpp b/iree/compiler/Conversion/HLOToHLO/DecomposeHLOClamp.cpp
index 8a143f5..a8f64c1 100644
--- a/iree/compiler/Conversion/HLOToHLO/DecomposeHLOClamp.cpp
+++ b/iree/compiler/Conversion/HLOToHLO/DecomposeHLOClamp.cpp
@@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
index 9f33bd1..d178038 100644
--- a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
@@ -32,6 +32,8 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
@@ -44,8 +46,6 @@
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
namespace mlir {
namespace iree_compiler {
@@ -665,44 +665,55 @@
namespace {
/// Converts mhlo.slice operation to linalg.subview + linalg.copy
-struct SliceOpConversion
- : public ConvertToLinalgBufferOp<SliceOpConversion, mhlo::SliceOp> {
- using ConvertToLinalgBufferOp<SliceOpConversion,
- mhlo::SliceOp>::ConvertToLinalgBufferOp;
+struct SliceOpConversion : public OpConversionPattern<mhlo::SliceOp> {
+ SliceOpConversion(MLIRContext *context,
+ TensorToBufferMap const &resultTensorToBufferMap,
+ PatternBenefit benefit = 1)
+ : OpConversionPattern<mhlo::SliceOp>(context, benefit),
+ resultTensorToBufferMap(resultTensorToBufferMap) {}
- LogicalResult apply(mhlo::SliceOp op, ArrayRef<Value> inputBuffers,
- ArrayRef<Value> resultBuffers,
- ConversionPatternRewriter &rewriter) const;
+ LogicalResult matchAndRewrite(
+ mhlo::SliceOp op, ArrayRef<Value> inputBuffers,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto argType = inputBuffers[0].getType().template dyn_cast<ShapedType>();
+ if (!argType || !argType.hasStaticShape()) {
+ return op.emitError("expected static shape");
+ }
+
+ auto resultShape = op.getResult().getType().cast<ShapedType>().getShape();
+ SmallVector<Value, 3> offsets, sizes, strides;
+ for (int i = 0, e = argType.getRank(); i < e; ++i) {
+ Value startIndex = rewriter.create<ConstantIndexOp>(
+ loc, op.start_indices().getValue<int64_t>(i));
+ offsets.push_back(startIndex);
+ Value size = rewriter.create<ConstantIndexOp>(loc, resultShape[i]);
+ sizes.push_back(size);
+ Value stride = rewriter.create<ConstantIndexOp>(
+ loc, op.strides().getValue<int64_t>(i));
+ strides.push_back(stride);
+ }
+ auto subViewOp = rewriter.create<SubViewOp>(loc, inputBuffers[0], offsets,
+ sizes, strides);
+
+ // If the result of the subview is already mapped to a buffer, a copy is
+ // required from the buffer above into the mapped buffer.
+ if (Value bufferForResult =
+ resultTensorToBufferMap.lookup(op.getResult())) {
+ rewriter.create<linalg::CopyOp>(loc, subViewOp, bufferForResult);
+ rewriter.replaceOp(op, bufferForResult);
+ } else {
+ rewriter.replaceOp(op, subViewOp.getResult());
+ }
+
+ return success();
+ }
+
+ private:
+ TensorToBufferMap const &resultTensorToBufferMap;
};
} // namespace
-LogicalResult SliceOpConversion::apply(
- mhlo::SliceOp op, ArrayRef<Value> inputBuffers,
- ArrayRef<Value> resultBuffers, ConversionPatternRewriter &rewriter) const {
- auto loc = op.getLoc();
- auto argType = inputBuffers[0].getType().template dyn_cast<ShapedType>();
- if (!argType || !argType.hasRank()) {
- return op.emitError("expected known-rank args");
- }
-
- SmallVector<Value, 3> offsets, sizes, strides;
- for (int i = 0, e = argType.getRank(); i < e; ++i) {
- Value startIndex = rewriter.create<ConstantIndexOp>(
- loc, op.start_indices().getValue<int64_t>(i));
- offsets.push_back(startIndex);
- Value size = rewriter.create<DimOp>(loc, resultBuffers[0], i);
- sizes.push_back(size);
- Value stride = rewriter.create<ConstantIndexOp>(
- loc, op.strides().getValue<int64_t>(i));
- strides.push_back(stride);
- }
- auto subViewOp =
- rewriter.create<SubViewOp>(loc, inputBuffers[0], offsets, sizes, strides);
- rewriter.create<linalg::CopyOp>(loc, subViewOp, resultBuffers[0]);
-
- return success();
-}
-
//===----------------------------------------------------------------------===//
// mhlo.reduce_window conversion patterns and utility functions.
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp
index 9ae86c7..48707d4 100644
--- a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp
@@ -22,6 +22,8 @@
#include <memory>
#include "iree/compiler/Conversion/HLOToLinalg/Passes.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
@@ -35,8 +37,6 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Conversion/HLOToLinalg/test/slice.mlir b/iree/compiler/Conversion/HLOToLinalg/test/slice.mlir
index 406a62b..0de43b3 100644
--- a/iree/compiler/Conversion/HLOToLinalg/test/slice.mlir
+++ b/iree/compiler/Conversion/HLOToLinalg/test/slice.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file -iree-codegen-hlo-to-linalg-on-buffers -cse %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -iree-codegen-hlo-to-linalg-pipeline -canonicalize %s | IreeFileCheck %s
module {
// CHECK_LABEL: @slice_whole_buffer
@@ -25,19 +25,10 @@
// -----
module {
- // CHECK: #[[MAP:.+]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+ // CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0 * 4 + d1 + 4)>
// CHECK: @slice_whole_stride
- // CHECK-DAG: %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1x4xi32>
// CHECK-DAG: %[[IN:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<3x4xi32>
- // CHECK-DAG: %[[ZERO:.+]] = constant 0 : index
- // CHECK-DAG: %[[ONE:.+]] = constant 1 : index
- // CHECK-DAG: %[[DIM0:.+]] = dim %[[OUT]], %[[ZERO]] : memref<1x4xi32>
- // CHECK-DAG: %[[DIM1:.+]] = dim %[[OUT]], %[[ONE]] : memref<1x4xi32>
- // CHECK: subview %[[IN]]
- // CHECK-SAME: [%[[ONE]], %[[ZERO]]]
- // CHECK-SAME: [%[[DIM0]], %[[DIM1]]]
- // CHECK-SAME: [%[[ONE]], %[[ONE]]]
- // CHECK-SAME: : memref<3x4xi32> to memref<?x?xi32, #[[MAP]]>
+ // CHECK: subview %[[IN]][1, 0] [1, 4] [1, 1] : memref<3x4xi32> to memref<1x4xi32, #[[MAP]]>
// CHECK: linalg.copy
func @slice_whole_stride()
attributes {signature = (tensor<3x4xi32>) -> (tensor<1x4xi32>)} {
@@ -60,19 +51,10 @@
// -----
module {
- // CHECK: #[[MAP:.+]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+ // CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0 * 4 + d1 + 5)>
// CHECK: @slice_stride_part
- // CHECK-DAG: %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1x2xi32>
// CHECK-DAG: %[[IN:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<3x4xi32>
- // CHECK-DAG: %[[ZERO:.+]] = constant 0 : index
- // CHECK-DAG: %[[ONE:.+]] = constant 1 : index
- // CHECK-DAG: %[[DIM0:.+]] = dim %[[OUT]], %[[ZERO]] : memref<1x2xi32>
- // CHECK-DAG: %[[DIM1:.+]] = dim %[[OUT]], %[[ONE]] : memref<1x2xi32>
- // CHECK: subview %[[IN]]
- // CHECK-SAME: [%[[ONE]], %[[ONE]]]
- // CHECK-SAME: [%[[DIM0]], %[[DIM1]]]
- // CHECK-SAME: [%[[ONE]], %[[ONE]]]
- // CHECK-SAME: : memref<3x4xi32> to memref<?x?xi32, #map>
+ // CHECK: subview %[[IN]][1, 1] [1, 2] [1, 1] : memref<3x4xi32> to memref<1x2xi32, #[[MAP]]>
// CHECK: linalg.copy
func @slice_stride_part()
attributes {signature = (tensor<3x4xi32>) -> (tensor<1x2xi32>)} {
@@ -91,3 +73,32 @@
hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write"
}
}
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0 * 4 + d1 + 5)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: func @slice_stride_part
+// CHECK: %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0, operand_result_index = 2 : i32} : memref<1x2xi32>
+// CHECK: %[[IN0:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0, operand_result_index = 0 : i32} : memref<3x4xi32>
+// CHECK: %[[SUBVIEW:.+]] = subview %[[IN0]][1, 1] [1, 2] [1, 1] : memref<3x4xi32> to memref<1x2xi32, #[[MAP0]]>
+// CHECK: %[[IN1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1, operand_result_index = 1 : i32} : memref<1x2xi32>
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%[[SUBVIEW]], %[[IN1]]
+// CHECK-SAME: outs(%[[OUT]]
+module {
+ func @slice_stride_part() {
+ %c0 = constant 0 : index
+ %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 {operand_result_index = 0 : i32} : tensor<3x4xi32>
+ %1 = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 {operand_result_index = 1 : i32} : tensor<1x2xi32>
+ %2 = "mhlo.slice"(%0) {limit_indices = dense<[2, 3]> : tensor<2xi64>, start_indices = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x2xi32>
+ %3 = mhlo.add %2, %1 : tensor<1x2xi32>
+ hal.interface.store.tensor %3, @legacy_io::@ret0, offset = %c0 {operand_result_index = 2 : i32} : tensor<1x2xi32>
+ return
+ }
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+ }
+}
diff --git a/iree/compiler/Conversion/LLVMToLLVM/BUILD b/iree/compiler/Conversion/LLVMToLLVM/BUILD
new file mode 100644
index 0000000..1656647
--- /dev/null
+++ b/iree/compiler/Conversion/LLVMToLLVM/BUILD
@@ -0,0 +1,36 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "LLVMToLLVM",
+ srcs = [
+ "FastExpConversion.cpp",
+ ],
+ hdrs = [
+ "Passes.h",
+ ],
+ deps = [
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:LLVMDialect",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:Transforms",
+ ],
+)
diff --git a/iree/compiler/Conversion/LLVMToLLVM/CMakeLists.txt b/iree/compiler/Conversion/LLVMToLLVM/CMakeLists.txt
new file mode 100644
index 0000000..4b02edc
--- /dev/null
+++ b/iree/compiler/Conversion/LLVMToLLVM/CMakeLists.txt
@@ -0,0 +1,31 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+ NAME
+ LLVMToLLVM
+ HDRS
+ "Passes.h"
+ SRCS
+ "FastExpConversion.cpp"
+ DEPS
+ LLVMSupport
+ MLIRIR
+ MLIRLLVMIR
+ MLIRPass
+ MLIRTransforms
+ PUBLIC
+)
diff --git a/iree/compiler/Conversion/LLVMToLLVM/FastExpConversion.cpp b/iree/compiler/Conversion/LLVMToLLVM/FastExpConversion.cpp
new file mode 100644
index 0000000..f4f41d9
--- /dev/null
+++ b/iree/compiler/Conversion/LLVMToLLVM/FastExpConversion.cpp
@@ -0,0 +1,125 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace {
+
+// Fast polynomial approximation of exp(x) using its reduced range exp(y)
+// where y is in the range [0, ln(2)], let y = x - floor(x / ln(2)) * ln(2)
+// = x - k * ln(2), exp(x) = exp(y) * 2^k. exp(y) is computed with 4th degree
+// polynomial: exp(y) = c0 + c1 * y + c2 * y^2 + c3 * y^3 + c4 * y^4
+struct FastExpConversionPattern : public OpRewritePattern<LLVM::ExpOp> {
+ using OpRewritePattern<LLVM::ExpOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(LLVM::ExpOp op,
+ PatternRewriter &rewriter) const override {
+ constexpr float ln2Const = 0.693147181f;
+ constexpr float ln2InvConst = 1.44269504f;
+
+ // Least squares polynomial fit computed :
+ // cValues = np.polyfit(np.linspace(0, math.log(2), 10000), np.exp(x), 4)
+ constexpr float cValues[5] = {0.05924867f, 0.15514645f, 0.50308552f,
+ 0.99968939f, 1.00000721531f};
+ auto loc = op.getLoc();
+ Value x = op.getOperand();
+
+ auto floatType = LLVM::LLVMType::getFloatTy(rewriter.getContext());
+ auto i32Type = LLVM::LLVMType::getInt32Ty(rewriter.getContext());
+
+ Value ln2 = rewriter.create<LLVM::ConstantOp>(
+ loc, floatType, rewriter.getF32FloatAttr(ln2Const));
+ Value ln2Inv = rewriter.create<LLVM::ConstantOp>(
+ loc, floatType, rewriter.getF32FloatAttr(ln2InvConst));
+
+ // Compute reduced range input y = x - floor(x / ln(2)) * ln(2)
+ Value xL2Inv = rewriter.create<LLVM::FMulOp>(loc, floatType, x, ln2Inv);
+ Value kF32 = rewriter.create<LLVM::FFloorOp>(loc, floatType, xL2Inv);
+ Value kLn2 = rewriter.create<LLVM::FMulOp>(loc, floatType, kF32, ln2);
+ Value y = rewriter.create<LLVM::FSubOp>(loc, floatType, x, kLn2);
+
+ SmallVector<Value, 4> PConst(5);
+ for (int i = 0; i < 5; ++i) {
+ PConst[i] = rewriter.create<LLVM::ConstantOp>(
+ loc, floatType, rewriter.getF32FloatAttr(cValues[i]));
+ }
+ // Evaluate exp(y) = sum(c[i] * y**(4 - i), i) using Horner's scheme.
+ Value expY = rewriter.create<LLVM::FMulOp>(loc, floatType, y, PConst[0]);
+ expY = rewriter.create<LLVM::FAddOp>(loc, floatType, expY, PConst[1]);
+ expY = rewriter.create<LLVM::FMulOp>(loc, floatType, expY, y);
+ expY = rewriter.create<LLVM::FAddOp>(loc, floatType, expY, PConst[2]);
+ expY = rewriter.create<LLVM::FMulOp>(loc, floatType, expY, y);
+ expY = rewriter.create<LLVM::FAddOp>(loc, floatType, expY, PConst[3]);
+ expY = rewriter.create<LLVM::FMulOp>(loc, floatType, expY, y);
+ expY = rewriter.create<LLVM::FAddOp>(loc, floatType, expY, PConst[4]);
+
+ // Compute exp2(k) with integer bitshift:
+ // exp2(k) = f32_bitcast((127 + k) << 23)
+ Value fPBias = rewriter.create<LLVM::ConstantOp>(
+ loc, i32Type, rewriter.getI32IntegerAttr(127));
+ Value k = rewriter.create<LLVM::FPToSIOp>(loc, i32Type, kF32);
+ Value kPlusfPBias = rewriter.create<LLVM::AddOp>(loc, i32Type, k, fPBias);
+ Value shiftConst = rewriter.create<LLVM::ConstantOp>(
+ loc, i32Type, rewriter.getI32IntegerAttr(23));
+ Value twoPowkI =
+ rewriter.create<LLVM::ShlOp>(loc, i32Type, kPlusfPBias, shiftConst);
+ Value twoPowk = rewriter.create<LLVM::BitcastOp>(loc, floatType, twoPowkI);
+ expY = rewriter.create<LLVM::FMulOp>(loc, floatType, expY, twoPowk);
+ rewriter.replaceOp(op, {expY});
+ // TODO(ataei): Handle overflow and underflow cases (e.g |k| > 128).
+ return success();
+ }
+};
+
+struct FastExpConversionPass
+ : public PassWrapper<FastExpConversionPass, OperationPass<ModuleOp>> {
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<LLVM::LLVMDialect>();
+ }
+ void runOnOperation() override;
+};
+
+} // namespace
+
+void populateFastExpConversionPatterns(OwningRewritePatternList &patterns,
+ MLIRContext *context) {
+ patterns.insert<FastExpConversionPattern>(context);
+}
+
+void FastExpConversionPass::runOnOperation() {
+ auto moduleOp = getOperation();
+ auto context = moduleOp.getContext();
+ OwningRewritePatternList patterns;
+ populateFastExpConversionPatterns(patterns, context);
+ applyPatternsAndFoldGreedily(moduleOp, std::move(patterns));
+}
+
+std::unique_ptr<OperationPass<ModuleOp>>
+createFastExpApproximationConversionPass() {
+ return std::make_unique<FastExpConversionPass>();
+}
+
+static PassRegistration<OperationPass<ModuleOp>> pass(
+ "iree-codegen-linalg-to-llvm-fast-exp-conversion-pass",
+ "Convert llvm.intr.exp into its fast polynomial approximation version",
+ [] { return std::make_unique<FastExpConversionPass>(); });
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Conversion/LLVMToLLVM/Passes.h b/iree/compiler/Conversion/LLVMToLLVM/Passes.h
new file mode 100644
index 0000000..e40fe90
--- /dev/null
+++ b/iree/compiler/Conversion/LLVMToLLVM/Passes.h
@@ -0,0 +1,30 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_CONVERSION_LLVMTOLLVM_PASSES_H_
+#define IREE_COMPILER_CONVERSION_LLVMTOLLVM_PASSES_H_
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Creates a pass to rewrite llvm.intr.exp using its reduced range polynomial
+// approximation.
+std::unique_ptr<OperationPass<ModuleOp>>
+createFastExpApproximationConversionPass();
+
+} // namespace iree_compiler
+} // namespace mlir
+#endif // IREE_COMPILER_CONVERSION_LLVMTOLLVM_PASSES_H_
diff --git a/iree/compiler/Conversion/LinalgToLLVM/BUILD b/iree/compiler/Conversion/LinalgToLLVM/BUILD
index c4396c5..5dbbadd 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/BUILD
+++ b/iree/compiler/Conversion/LinalgToLLVM/BUILD
@@ -37,6 +37,7 @@
"//iree/compiler/Conversion/Common",
"//iree/compiler/Conversion/HLOToHLO",
"//iree/compiler/Conversion/HLOToLinalg",
+ "//iree/compiler/Conversion/LLVMToLLVM",
"//iree/compiler/Dialect/HAL/IR",
"//iree/compiler/Dialect/HAL/IR:HALDialect",
"//iree/compiler/Dialect/IREE/IR",
diff --git a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
index 9564f1e..4018a29 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
@@ -49,6 +49,7 @@
iree::compiler::Conversion::Common
iree::compiler::Conversion::HLOToHLO
iree::compiler::Conversion::HLOToLinalg
+ iree::compiler::Conversion::LLVMToLLVM
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::HAL::IR::HALDialect
iree::compiler::Dialect::IREE::IR
diff --git a/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.h b/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.h
index 793c6f0..f2d7711 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.h
+++ b/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.h
@@ -15,13 +15,11 @@
#include <cstdint>
#include "llvm/ADT/SmallVector.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
namespace mlir {
-class Operation;
-class Value;
-class OpBuilder;
-class Operation;
-
namespace iree_compiler {
enum class TilingLevel {
diff --git a/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp b/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
index 6f2f30b..348b425 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
@@ -17,6 +17,7 @@
#include "iree/compiler/Conversion/Common/Attributes.h"
#include "iree/compiler/Conversion/Common/Passes.h"
#include "iree/compiler/Conversion/HLOToHLO/Passes.h"
+#include "iree/compiler/Conversion/LLVMToLLVM/Passes.h"
#include "iree/compiler/Conversion/LinalgToLLVM/Passes.h"
#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
@@ -34,6 +35,12 @@
"linag.matmul"),
llvm::cl::init(false));
+static llvm::cl::opt<bool> fastExpConversion(
+ "iree-codegen-linalg-to-llvm-fast-exp",
+ llvm::cl::desc("If true convert llvm.intr.exp into its range reduced "
+ "polynomial approximation."),
+ llvm::cl::init(false));
+
void addLinalgToLLVMPasses(OpPassManager &passManager) {
// Distribute linalg op among a 3d grid of parallel threads. Tile each
// workgroup thread memory then vectorize the linalg op.
@@ -66,6 +73,11 @@
passManager.addPass(createConvertToLLVMPass());
passManager.addPass(createCanonicalizerPass());
passManager.addPass(createCSEPass());
+
+ // Approximate llvm.intr.exp with a 4th-order polynomial in range [0, ln2].
+ if (fastExpConversion) {
+ passManager.addPass(createFastExpApproximationConversionPass());
+ }
}
void buildLLVMTransformPassPipeline(OpPassManager &passManager) {
diff --git a/iree/compiler/Conversion/LinalgToLLVM/Passes.h b/iree/compiler/Conversion/LinalgToLLVM/Passes.h
index 8dcc089..33c74d4 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/Passes.h
+++ b/iree/compiler/Conversion/LinalgToLLVM/Passes.h
@@ -30,6 +30,9 @@
/// Vectorize linalg ops executed in the same iree.workgroup.
std::unique_ptr<FunctionPass> createLinalgTileAndVectorizeWorkgroupsPass();
+std::unique_ptr<OperationPass<ModuleOp>>
+createFastExpApproximationConversionPass();
+
/// Populates patterns to rewrite linalg::ConvOp into packed img2col operation
/// followed by linalg::MatmulOp.
void populateConvImg2ColMatmulConversionPatterns(
diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/convert_to_llvm.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/convert_to_llvm.mlir
index b6ef3e3..fcae26a 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/test/convert_to_llvm.mlir
+++ b/iree/compiler/Conversion/LinalgToLLVM/test/convert_to_llvm.mlir
@@ -75,13 +75,9 @@
// CHECK: %[[STRIDE_DIM0:.+]] = llvm.mul %[[STRIDE_DIM1:.+]], %[[DIM1_0:.+]]: !llvm.i64
// CHECK: %[[INSERT_DIM0:.+]] = llvm.insertvalue %[[STRIDE_DIM0:.+]], %[[MEMREF3:.+]][4, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[EXTRACT1:.+]] = llvm.extractvalue %[[INSERT_DIM0:.+]][1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[CONST0_3:.+]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK: %[[EXTRACT2:.+]] = llvm.extractvalue %[[INSERT_DIM0:.+]][4, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[MUL1:.+]] = llvm.mul %[[CONST0_2:.+]], %[[EXTRACT2:.+]] : !llvm.i64
-// CHECK: %[[ADD1:.+]] = llvm.add %[[CONST0_3:.+]], %[[MUL1:.+]] : !llvm.i64
-// CHECK: %[[CONST1_2:.+]] = llvm.mlir.constant(1 : index) : !llvm.i64
-// CHECK: %[[MUL2:.+]] = llvm.mul %[[CONST0_2:.+]], %[[CONST1_2:.+]] : !llvm.i64
-// CHECK: %[[ADD2:.+]] = llvm.add %[[ADD1:.+]], %[[MUL2:.+]] : !llvm.i64
+// CHECK: %[[MUL1:.+]] = llvm.mul %[[CONST0_2]], %[[EXTRACT2]] : !llvm.i64
+// CHECK: %[[ADD1:.+]] = llvm.add %[[MUL1]], %[[CONST0_2]] : !llvm.i64
// CHECK: %[[GET_PTR:.+]] = llvm.getelementptr %[[EXTRACT1:.+]][%[[ADD2:.+]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
// CHECK: %[[LOAD:.+]] = llvm.load %[[GET_PTR:.+]] : !llvm.ptr<float>
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/BUILD b/iree/compiler/Conversion/LinalgToSPIRV/BUILD
index d61e74b..0156b47 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/BUILD
+++ b/iree/compiler/Conversion/LinalgToSPIRV/BUILD
@@ -37,6 +37,7 @@
"ConvertToGPUPass.cpp",
"ConvertToSPIRVPass.cpp",
"CooperativeMatrixAnalysis.cpp",
+ "FoldGPUProcessorIDUses.cpp",
"KernelDispatchUtils.cpp",
"LinalgTileAndFusePass.cpp",
"MatMulVectorizationTest.cpp",
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt b/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
index b10d22b..9368a9d 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
@@ -39,6 +39,7 @@
"ConvertToGPUPass.cpp"
"ConvertToSPIRVPass.cpp"
"CooperativeMatrixAnalysis.cpp"
+ "FoldGPUProcessorIDUses.cpp"
"KernelDispatchUtils.cpp"
"LinalgTileAndFusePass.cpp"
"MatMulVectorizationTest.cpp"
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/CooperativeMatrixAnalysis.h b/iree/compiler/Conversion/LinalgToSPIRV/CooperativeMatrixAnalysis.h
index f75133e..8ee737a 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/CooperativeMatrixAnalysis.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/CooperativeMatrixAnalysis.h
@@ -23,13 +23,14 @@
// or not.
//
//===----------------------------------------------------------------------===//
+
#ifndef IREE_COMPILER_CONVERSION_LINALGTOSPIRV_COOPERATIVEMATRIXANALYSIS_H_
#define IREE_COMPILER_CONVERSION_LINALGTOSPIRV_COOPERATIVEMATRIXANALYSIS_H_
+
#include "llvm/ADT/DenseSet.h"
+#include "mlir/IR/Operation.h"
namespace mlir {
-class Operation;
-
namespace iree_compiler {
class CooperativeMatrixAnalysis {
@@ -45,6 +46,8 @@
private:
llvm::DenseSet<mlir::Operation *> usesCooperativeMatrix;
};
+
} // namespace iree_compiler
} // namespace mlir
+
#endif // IREE_COMPILER_CONVERSION_LINALGTOSPIRV_COOPERATIVEMATRIXANALYSIS_H_
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/FoldGPUProcessorIDUses.cpp b/iree/compiler/Conversion/LinalgToSPIRV/FoldGPUProcessorIDUses.cpp
new file mode 100644
index 0000000..1dda99d
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToSPIRV/FoldGPUProcessorIDUses.cpp
@@ -0,0 +1,266 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===- FoldGPUProcessorIDUses.cpp -----------------------------------------===//
+//
+// This file implements patterns and passes for folding GPU processor ID uses.
+//
+//===----------------------------------------------------------------------===//
+
+#include "iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h"
+#include "iree/compiler/Conversion/Common/Attributes.h"
+#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SPIRV/SPIRVAttributes.h"
+#include "mlir/Dialect/SPIRV/TargetAndABI.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "iree-fold-gpu-procid-uses"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+/// Returns true if the given `expr` is a linear expression over one
+/// symbol/dimension.
+///
+/// Note that this function is not meant to check for all linear expression
+/// cases. It only checks that:
+/// 1) No mod/div operations,
+/// 2) For mul operations, one of the operands is a constant.
+/// Also this function assumes `expr` only contains one symbol/dimension.
+bool isLinearExpr(AffineExpr expr) {
+  switch (expr.getKind()) {
+    case mlir::AffineExprKind::Add: {
+      // A sum is linear when both of its operands are.
+      auto binExpr = expr.cast<AffineBinaryOpExpr>();
+      return isLinearExpr(binExpr.getLHS()) && isLinearExpr(binExpr.getRHS());
+    }
+    case mlir::AffineExprKind::Mul: {
+      // A product is accepted only when one side is a constant and the other
+      // side is itself linear.
+      auto binExpr = expr.cast<AffineBinaryOpExpr>();
+      AffineExpr lhs = binExpr.getLHS();
+      AffineExpr rhs = binExpr.getRHS();
+      return (lhs.isa<AffineConstantExpr>() && isLinearExpr(rhs)) ||
+             (rhs.isa<AffineConstantExpr>() && isLinearExpr(lhs));
+    }
+    case mlir::AffineExprKind::Mod:
+    case mlir::AffineExprKind::FloorDiv:
+    case mlir::AffineExprKind::CeilDiv:
+      return false;
+    case mlir::AffineExprKind::Constant:
+    case mlir::AffineExprKind::DimId:
+    case mlir::AffineExprKind::SymbolId:
+      return true;
+  }
+  // Every kind listed above returns from inside the switch. This keeps the
+  // function from falling off its end (undefined behavior) should `getKind()`
+  // ever yield a value that is not handled; such an expression is
+  // conservatively reported as non-linear.
+  return false;
+}
+
+/// Returns `expr` with the symbol `dim` replaced by a constant expression
+/// holding `value`. The constant is created in the context of `expr`.
+AffineExpr replaceSymbolWithValue(AffineExpr expr, AffineSymbolExpr dim,
+                                  int64_t value) {
+  return expr.replace(dim, getAffineConstantExpr(value, expr.getContext()));
+}
+
+/// Converts a dimension string to its corresponding index:
+/// "x" -> 0, "y" -> 1, "z" -> 2.
+///
+/// NOTE(review): the switch has no default case, so `dimension` is presumably
+/// always one of the three strings above -- confirm against callers.
+int dimensionToIndex(StringRef dimension) {
+ return StringSwitch<int>(dimension).Case("x", 0).Case("y", 1).Case("z", 2);
+}
+
+/// Gets the block processor ID's upper bound. This queries the workgroup count
+/// function associated with the function enclosing `blockIDOp`.
+///
+/// Returns llvm::None when the workgroup count function cannot be found, has
+/// no body, does not end in a `ReturnOp` with exactly three operands, or when
+/// the returned value for the queried dimension is not a constant.
+Optional<int64_t> getProcessorIDUpperBound(gpu::BlockIdOp blockIDOp) {
+  auto numWorkgroupsFn = getNumWorkgroupsFn(blockIDOp.getParentOfType<FuncOp>(),
+                                            getNumWorkgroupsFnAttrName());
+  // A declaration-only function has an empty block list, so there is no
+  // terminator to inspect; asking for its last block would be invalid.
+  if (!numWorkgroupsFn || numWorkgroupsFn.isExternal()) return llvm::None;
+
+  Operation *terminator = numWorkgroupsFn.getBlocks().back().getTerminator();
+  auto retOp = dyn_cast<ReturnOp>(terminator);
+  if (!retOp || retOp.getNumOperands() != 3) return llvm::None;
+  LLVM_DEBUG(llvm::dbgs() << "workgroup count function return op: " << retOp
+                          << "\n");
+
+  // The three returned values are indexed by dimension (x, y, z); only a
+  // constant count gives a usable bound.
+  int index = dimensionToIndex(blockIDOp.dimension());
+  IntegerAttr attr;
+  if (!matchPattern(retOp.getOperand(index), m_Constant(&attr)))
+    return llvm::None;
+
+  return attr.getInt();
+}
+
+/// Gets the thread processor ID's upper bound. The bound is read from the
+/// `local_size` of the SPIR-V entry point ABI attribute attached to the
+/// function enclosing `threadIDOp`; llvm::None is returned when that attribute
+/// is absent.
+Optional<int64_t> getProcessorIDUpperBound(gpu::ThreadIdOp threadIDOp) {
+  FuncOp parentFn = threadIDOp.getParentOfType<FuncOp>();
+  auto entryPointABI = parentFn.getAttrOfType<spirv::EntryPointABIAttr>(
+      spirv::getEntryPointABIAttrName());
+  if (!entryPointABI) return llvm::None;
+
+  // Select the local size entry matching the dimension this op queries.
+  int dimIndex = dimensionToIndex(threadIDOp.dimension());
+  auto localSizeIt = entryPointABI.local_size().getIntValues().begin();
+  return (*(localSizeIt + dimIndex)).getZExtValue();
+}
+
+/// Folds `affine.min` ops which have only one symbol operand, which is a
+/// processor ID. For such cases we can use the processor ID's upper bound to
+/// simplify the `affine.min`.
+///
+/// For example, this pattern can simplify the following IR:
+/// ```
+/// %id = "gpu.thread_id"() {dimension = "x"} : () -> index
+/// ...
+/// affine.min affine_map<()[s0] -> (3, s0 * -2 + 225)>()[%id]
+/// ```
+/// to the constant 3 if the upper bound for thread ID along the x dimension
+/// is 112.
+struct FoldAffineMinOverProcessorID : OpRewritePattern<AffineMinOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  /// Replaces `minOp` with a constant index when one of its constant results
+  /// is provably the minimum over the whole processor ID range; otherwise
+  /// reports a match failure and leaves the IR untouched.
+  LogicalResult matchAndRewrite(AffineMinOp minOp,
+                                PatternRewriter &rewriter) const override {
+    LLVM_DEBUG(llvm::dbgs() << "inspecting " << minOp << "\n");
+    MLIRContext *context = minOp.getContext();
+    auto dimensions = minOp.getDimOperands();
+    auto symbols = minOp.getSymbolOperands();
+
+    // We expect the affine.min op to only have one symbol operand.
+    if (!llvm::hasSingleElement(symbols) || !dimensions.empty()) {
+      return rewriter.notifyMatchFailure(
+          minOp, "expected to only have one symbol operand");
+    }
+
+    // And the symbol operand should come from a GPU processor ID.
+    // `getDefiningOp()` returns null when the symbol is a block argument, so
+    // the casts below must tolerate a null operation.
+    Operation *symbolOp = symbols.front().getDefiningOp();
+    auto symbol0 = getAffineSymbolExpr(0, context).cast<AffineSymbolExpr>();
+
+    Optional<int64_t> ub;
+    if (auto blockIDOp = dyn_cast_or_null<gpu::BlockIdOp>(symbolOp)) {
+      ub = getProcessorIDUpperBound(blockIDOp);
+    } else if (auto threadIDOp = dyn_cast_or_null<gpu::ThreadIdOp>(symbolOp)) {
+      ub = getProcessorIDUpperBound(threadIDOp);
+    }
+    if (!ub) {
+      return rewriter.notifyMatchFailure(
+          minOp, "failed to query processor ID upper bound");
+    }
+    // `ub` is only set after one of the casts above succeeded, so `symbolOp`
+    // is known to be non-null here.
+    LLVM_DEBUG(llvm::dbgs() << "processor ID '" << *symbolOp
+                            << "' upper bound: " << *ub << "\n");
+
+    // Look at each result expression. For expressions that are functions of
+    // the input symbol, try to simplify it. We do this by replacing the
+    // symbol with its lower and upper bound. This requires the result
+    // expression to be a linear function of the input symbol.
+    SmallVector<AffineExpr, 4> results;
+    // The indices into `results` where the corresponding AffineExpr is a
+    // constant from the original map. We need to keep track of this so later we
+    // can probe whether the constant is the min.
+    SmallVector<unsigned, 4> cstIndices;
+    for (auto result : minOp.getAffineMap().getResults()) {
+      if (auto cstResult = result.dyn_cast<AffineConstantExpr>()) {
+        results.push_back(cstResult);
+        cstIndices.push_back(results.size() - 1);
+      } else if (isLinearExpr(result)) {
+        // The processor ID ranges over [0, ub - 1]; evaluate the expression
+        // at both ends of that range.
+        results.push_back(simplifyAffineExpr(
+            replaceSymbolWithValue(result, symbol0, 0), 0, 1));
+        results.push_back(simplifyAffineExpr(
+            replaceSymbolWithValue(result, symbol0, *ub - 1), 0, 1));
+        LLVM_DEBUG({
+          auto map = AffineMap::get(0, 1, results, context);
+          llvm::dbgs() << "map after substituting with processor ID bounds: "
+                       << map << "\n";
+        });
+      } else {
+        // We cannot handle such cases. Just bail out on matching the pattern.
+        return rewriter.notifyMatchFailure(
+            minOp, "expected to have a linear function of the symbol");
+      }
+    }
+
+    // Returns true if the given affine expression is a non-negative constant.
+    auto isNonNegativeCstExpr = [](AffineExpr e) {
+      if (auto cst = e.dyn_cast<AffineConstantExpr>())
+        return cst.getValue() >= 0;
+      return false;
+    };
+
+    // Check whether any of the original constant expressions, when subtracted
+    // from all other expressions, produces only >= 0 constants. If so, it is
+    // the min.
+    for (auto cstIndex : cstIndices) {
+      auto candidate = results[cstIndex].cast<AffineConstantExpr>();
+
+      SmallVector<AffineExpr, 4> subExprs;
+      subExprs.reserve(results.size());
+      for (auto r : results) subExprs.push_back(r - candidate);
+
+      AffineMap subMap =
+          simplifyAffineMap(AffineMap::get(0, 1, subExprs, context));
+      LLVM_DEBUG(llvm::dbgs() << "map by subtracting expr '" << candidate
+                              << "': " << subMap << "\n");
+      if (llvm::all_of(subMap.getResults(), isNonNegativeCstExpr)) {
+        rewriter.replaceOpWithNewOp<ConstantIndexOp>(minOp,
+                                                     candidate.getValue());
+        return success();
+      }
+    }
+
+    return failure();
+  }
+};
+
+/// Tests processor ID use folding patterns.
+///
+/// Function pass that greedily applies the patterns collected by
+/// `populateFoldGPUProcessorIDUsesPatterns` to the function it runs on.
+struct FoldGPUProcessIDUsesPass
+    : public PassWrapper<FoldGPUProcessIDUsesPass, FunctionPass> {
+  FoldGPUProcessIDUsesPass() = default;
+
+  /// Declares the dialects this pass relies on.
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<AffineDialect, gpu::GPUDialect>();
+  }
+
+  void runOnFunction() override {
+    MLIRContext *context = &getContext();
+    OwningRewritePatternList patterns;
+    populateFoldGPUProcessorIDUsesPatterns(context, patterns);
+    applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
+  }
+};
+
+}; // namespace
+
+/// Adds to `patterns` the `FoldAffineMinOverProcessorID` pattern together with
+/// the canonicalization patterns of `AffineMinOp`.
+void populateFoldGPUProcessorIDUsesPatterns(
+ MLIRContext *context, OwningRewritePatternList &patterns) {
+ patterns.insert<FoldAffineMinOverProcessorID>(context);
+ AffineMinOp::getCanonicalizationPatterns(patterns, context);
+}
+
+/// Creates a new instance of the pass that folds GPU processor ID uses.
+std::unique_ptr<OperationPass<FuncOp>> createFoldProcessorIDUsesPass() {
+ return std::make_unique<FoldGPUProcessIDUsesPass>();
+}
+
+// Static registration of the pass: its command-line argument, its description,
+// and a callback that constructs a fresh pass instance.
+static PassRegistration<FoldGPUProcessIDUsesPass> pass(
+ "iree-codegen-fold-gpu-procid-uses",
+ "Fold GPU processor ID uses where possible",
+ [] { return std::make_unique<FoldGPUProcessIDUsesPass>(); });
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
index a233a8b..7f0198a 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
@@ -61,6 +61,14 @@
return std::make_tuple(nprocs_x, nprocs / nprocs_x);
}
+namespace {
+/// Launch parameters filled in by the `getOpLaunchConfig` overloads and then
+/// copied into the corresponding `LaunchConfig` members by
+/// `LaunchConfig::init`.
+struct LaunchConfigInfo {
+ // Workgroup size, indexed as {x, y, z}. Defaults to a single invocation.
+ std::array<int64_t, 3> workgroupSize = {1, 1, 1};
+ // Number of subgroups, indexed as {x, y, z}. Defaults to one subgroup.
+ std::array<int64_t, 3> numSubgroups = {1, 1, 1};
+ // Set to true by the matmul configuration when one of its vectorization
+ // specific configurations was selected; false otherwise.
+ bool vectorize = false;
+};
+} // namespace
+
/// For a given operation `op`, compute the following configurations according
/// to SPIR-V `targetEnv` and `options`:
/// 1) number of tiling levels and tile sizes to use (updates `tileSizes`),
@@ -71,8 +79,7 @@
static LogicalResult getOpLaunchConfig(T op, const spirv::TargetEnv &targetEnv,
const SPIRVCodegenOptions &options,
TileSizesListType &tileSizes,
- std::array<int64_t, 3> &workgroupSize,
- std::array<int64_t, 3> &numSubgroups) {
+ LaunchConfigInfo &config) {
return op.emitError("undefined launch config for tiled operation");
}
@@ -82,14 +89,13 @@
const spirv::TargetEnv &targetEnv,
const SPIRVCodegenOptions &options,
TileSizesListType &tileSizes,
- std::array<int64_t, 3> &workgroupSize,
- std::array<int64_t, 3> &numSubgroups) {
+ LaunchConfigInfo &config) {
unsigned maxWorkgroupSize = targetEnv.getResourceLimits()
.max_compute_workgroup_invocations()
.getInt();
- std::tie(workgroupSize[0], workgroupSize[1]) =
+ std::tie(config.workgroupSize[0], config.workgroupSize[1]) =
distributeProcs2D(maxWorkgroupSize);
- workgroupSize[2] = 1;
+ config.workgroupSize[2] = 1;
// This is just being hard-wired for now to be minimal viable, but this can be
// decided better when we have better estimates of device charecteristics.
const int64_t nRowsPerWorkitem = 1;
@@ -102,9 +108,9 @@
tileSizeK = 32;
}
assert(tileSizes.empty());
- SmallVector<int64_t, 4> ts = {nBatchesPerWorkitem,
- nRowsPerWorkitem * workgroupSize[1],
- nColsPerWorkitem * workgroupSize[0], tileSizeK};
+ SmallVector<int64_t, 4> ts = {
+ nBatchesPerWorkitem, nRowsPerWorkitem * config.workgroupSize[1],
+ nColsPerWorkitem * config.workgroupSize[0], tileSizeK};
tileSizes.emplace_back(std::move(ts));
return success();
}
@@ -209,16 +215,36 @@
std::array<int64_t, 3> &numSubgroups) {
if (targetEnv.getVendorID() != spirv::Vendor::ARM) return failure();
+ auto lhsType = op.inputs()[0].getType().cast<MemRefType>();
+ auto rhsType = op.inputs()[1].getType().cast<MemRefType>();
+ assert(lhsType.getElementType() == rhsType.getElementType());
+ // Pick ideal tile size based on the type.
+ SmallVector<int64_t, 4> workgroupLevelTs;
+ if (lhsType.getElementType().isF16()) {
+ workgroupLevelTs.append({16, 64, 8});
+ } else {
+ workgroupLevelTs.append({8, 64, 4});
+ }
+
+ // Fall back to the none vectorize path for cases we don't handle.
+ if (!lhsType.hasStaticShape() || !rhsType.hasStaticShape() ||
+ lhsType.getDimSize(0) % workgroupLevelTs[0] != 0 ||
+ rhsType.getDimSize(0) % workgroupLevelTs[1] != 0 ||
+ lhsType.getDimSize(1) % workgroupLevelTs[2] != 0) {
+ return failure();
+ }
+
workgroupSize[0] = targetEnv.getResourceLimits().subgroup_size().getInt();
workgroupSize[1] = 1;
workgroupSize[2] = 1;
- SmallVector<int64_t, 4> ts = {8, 64, 4};
- tileSizes.emplace_back(ts);
+ tileSizes.emplace_back(workgroupLevelTs);
// No tiling at the subgroup level since this target doesn't use subgroup op
// or shared memory.
tileSizes.emplace_back();
- SmallVector<int64_t, 4> threadTs = {ts[0], ts[1] / workgroupSize[0], ts[2]};
- tileSizes.emplace_back(threadTs);
+ SmallVector<int64_t, 4> invocationLevelTs = {
+ workgroupLevelTs[0], workgroupLevelTs[1] / workgroupSize[0],
+ workgroupLevelTs[2]};
+ tileSizes.emplace_back(invocationLevelTs);
return success();
}
@@ -227,25 +253,27 @@
const spirv::TargetEnv &targetEnv,
const SPIRVCodegenOptions &options,
TileSizesListType &tileSizes,
- std::array<int64_t, 3> &workgroupSize,
- std::array<int64_t, 3> &numSubgroups) {
+ LaunchConfigInfo &config) {
if (options.enableVectorization &&
succeeded(getConfigForCooperativeMatmul(op, targetEnv, options, tileSizes,
- workgroupSize, numSubgroups))) {
+ config.workgroupSize,
+ config.numSubgroups))) {
+ config.vectorize = true;
return success();
} else if (options.enableVectorization &&
succeeded(getTargetSpecificConfig(op, targetEnv, options,
- tileSizes, workgroupSize,
- numSubgroups))) {
+ tileSizes, config.workgroupSize,
+ config.numSubgroups))) {
+ config.vectorize = true;
return success();
}
unsigned maxWorkgroupSize = targetEnv.getResourceLimits()
.max_compute_workgroup_invocations()
.getInt();
- std::tie(workgroupSize[0], workgroupSize[1]) =
+ std::tie(config.workgroupSize[0], config.workgroupSize[1]) =
distributeProcs2D(maxWorkgroupSize);
- workgroupSize[2] = 1;
+ config.workgroupSize[2] = 1;
const int nRowsPerWorkitem = 1;
const int nColsPerWorkitem = 1;
int64_t tileSizeK = 0;
@@ -255,8 +283,9 @@
tileSizeK = 32;
}
assert(tileSizes.empty());
- SmallVector<int64_t, 4> ts = {nRowsPerWorkitem * workgroupSize[1],
- nColsPerWorkitem * workgroupSize[0], tileSizeK};
+ SmallVector<int64_t, 4> ts = {nRowsPerWorkitem * config.workgroupSize[1],
+ nColsPerWorkitem * config.workgroupSize[0],
+ tileSizeK};
tileSizes.emplace_back(std::move(ts));
return success();
}
@@ -266,8 +295,7 @@
const spirv::TargetEnv &targetEnv,
const SPIRVCodegenOptions &options,
TileSizesListType &tileSizes,
- std::array<int64_t, 3> &workgroupSize,
- std::array<int64_t, 3> &numSubgroups) {
+ LaunchConfigInfo &config) {
unsigned maxWorkgroupSize = targetEnv.getResourceLimits()
.max_compute_workgroup_invocations()
.getInt();
@@ -275,7 +303,7 @@
int64_t tileSizeY = maxWorkgroupSize / tileSizeX;
SmallVector<int64_t, 4> ts = {1, tileSizeY, tileSizeX};
tileSizes.emplace_back(std::move(ts));
- workgroupSize = {tileSizeX, tileSizeY, 1};
+ config.workgroupSize = {tileSizeX, tileSizeY, 1};
return success();
}
@@ -283,8 +311,7 @@
static LogicalResult getPoolingOpLaunchConfig(
PoolingOpTy op, const spirv::TargetEnv &targetEnv,
const SPIRVCodegenOptions &options, TileSizesListType &tileSizes,
- std::array<int64_t, 3> &workgroupSize,
- std::array<int64_t, 3> &numSubgroups) {
+ LaunchConfigInfo &config) {
unsigned maxWorkgroupSize = targetEnv.getResourceLimits()
.max_compute_workgroup_invocations()
.getInt();
@@ -299,7 +326,7 @@
ts[ts.size() - 2] = tileSizeY;
ts[ts.size() - 1] = tileSizeX;
tileSizes.emplace_back(std::move(ts));
- workgroupSize = {tileSizeX, tileSizeY, 1};
+ config.workgroupSize = {tileSizeX, tileSizeY, 1};
return success();
}
@@ -308,10 +335,9 @@
LogicalResult getOpLaunchConfig( \
opName op, const spirv::TargetEnv &targetEnv, \
const SPIRVCodegenOptions &options, TileSizesListType &tileSizes, \
- std::array<int64_t, 3> &workgroupSize, \
- std::array<int64_t, 3> &numSubgroups) { \
+ LaunchConfigInfo &config) { \
return getPoolingOpLaunchConfig(op, targetEnv, options, tileSizes, \
- workgroupSize, numSubgroups); \
+ config); \
}
DEFINE_POOLING_OP_CONFIG(linalg::PoolingMaxOp)
@@ -328,7 +354,7 @@
LogicalResult LaunchConfig::init(
MLIRContext *context, const linalg::LinalgDependenceGraph &dependenceGraph,
- const SPIRVCodegenOptions &options, ArrayRef<Operation *> linalgOps) {
+ const SPIRVCodegenOptions &options, ArrayRef<linalg::LinalgOp> linalgOps) {
unsigned numTiledOps = 0;
auto setKey = [&](Operation *op) -> std::string {
std::string key = llvm::formatv("__op_num_{0}__", numTiledOps++).str();
@@ -338,7 +364,7 @@
};
if (!options.workgroupSize.empty()) {
- for (Operation *linalgOp : linalgOps)
+ for (linalg::LinalgOp linalgOp : linalgOps)
tileSizes[setKey(linalgOp)].emplace_back(options.tileSizes.begin(),
options.tileSizes.end());
workgroupSize = {1, 1, 1};
@@ -354,21 +380,21 @@
spirv::TargetEnv targetEnv(spirv::lookupTargetEnv(*linalgOps.begin()));
Optional<linalg::LinalgOp> rootOperation = {};
-
- for (Operation *op : linalgOps) {
-#define DISPATCH(opName) \
- if (auto linalgOp = dyn_cast<opName>(op)) { \
- if (rootOperation) { \
- return linalgOp.emitError( \
- "unhandled multiple root operations in dispatch region"); \
- } \
- rootOperation = cast<linalg::LinalgOp>(linalgOp.getOperation()); \
- TileSizesListType &tileSizesInfo = tileSizes[setKey(*rootOperation)]; \
- if (failed(getOpLaunchConfig(linalgOp, targetEnv, options, tileSizesInfo, \
- workgroupSize, numSubgroups))) { \
- return failure(); \
- } \
- continue; \
+ LaunchConfigInfo config;
+ for (linalg::LinalgOp linalgOp : linalgOps) {
+#define DISPATCH(opName) \
+ if (auto op = dyn_cast<opName>(linalgOp.getOperation())) { \
+ if (rootOperation) { \
+ return op.emitError( \
+ "unhandled multiple root operations in dispatch region"); \
+ } \
+ rootOperation = linalgOp; \
+ TileSizesListType &tileSizesInfo = tileSizes[setKey(*rootOperation)]; \
+ if (failed(getOpLaunchConfig(op, targetEnv, options, tileSizesInfo, \
+ config))) { \
+ return failure(); \
+ } \
+ continue; \
}
DISPATCH(linalg::BatchMatmulOp)
@@ -380,7 +406,9 @@
#undef DISPATCH
}
-
+ workgroupSize = config.workgroupSize;
+ numSubgroups = config.numSubgroups;
+ vectorize = config.vectorize;
if (!rootOperation) {
// No root operations found. Dont need to do anything.
return success();
@@ -407,14 +435,12 @@
// producer and consumer must match for the parallel loops.
for (auto dependence :
dependenceGraph.getDependentOperations(rootOperation.getValue())) {
- Optional<unsigned> viewIndex =
- rootOperation->getIndexOfShapedOperand(dependence.indexingView);
- AffineMap indexingMap = rootOperation->getIndexingMap(*viewIndex);
+ unsigned viewIndex = dependence.indexingOpView.operandIndex;
+ AffineMap indexingMap = rootOperation->getIndexingMap(viewIndex);
linalg::LinalgOp fusedOp =
cast<linalg::LinalgOp>(dependence.dependentOpView.op);
- Optional<unsigned> fusedViewIndex =
- fusedOp.getIndexOfShapedOperand(dependence.dependentOpView.view);
- AffineMap fusedIndexingMap = fusedOp.getIndexingMap(*fusedViewIndex);
+ unsigned fusedViewIndex = dependence.dependentOpView.operandIndex;
+ AffineMap fusedIndexingMap = fusedOp.getIndexingMap(fusedViewIndex);
if (indexingMap.getNumResults() < numOuterParallel ||
fusedIndexingMap.getNumResults() < numOuterParallel ||
!llvm::all_of(
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h
index fe4c453..72ee6d2 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h
@@ -24,28 +24,21 @@
#include <array>
+#include "iree/compiler/Conversion/LinalgToSPIRV/CodeGenOptionUtils.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/FormatVariadic.h"
+#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/IR/Function.h"
#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
namespace mlir {
-class FuncOp;
-class LogicalResult;
-class Operation;
-class PatternRewriter;
-class ShapedType;
-class Value;
-
-namespace linalg {
-class LinalgDependenceGraph;
-class LinalgOp;
-} // namespace linalg
-namespace iree_compiler {
-struct SPIRVCodegenOptions;
-}
-
namespace iree_compiler {
/// Store the tile sizes to use at different levels of tiling as a vector of
@@ -75,7 +68,7 @@
LogicalResult init(MLIRContext *context,
const linalg::LinalgDependenceGraph &dependenceGraph,
const SPIRVCodegenOptions &options,
- ArrayRef<Operation *> linalgOps);
+ ArrayRef<linalg::LinalgOp> linalgOps);
/// Remove attributed added to operations for retrieving tile size
/// information.
@@ -110,6 +103,9 @@
return !getTileSizes(op, level).empty();
}
+ /// Use vectorize transformations.
+ bool useVectorize() const { return vectorize; }
+
protected:
/// Current tile size configuration per operation. They key used here to
/// retrieve the tile size information per operation is the value of a StrAttr
@@ -125,6 +121,9 @@
/// Number of subgroups that are logically distributed along x, y & z.
std::array<int64_t, 3> numSubgroups;
+ /// Use vectorization.
+ bool vectorize = false;
+
private:
/// Retrieves the key to use to get the `tileSizes` for a given
/// `operation`. Returns llvm::None on failure.
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
index 44b0841..34c6632 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
@@ -529,8 +529,10 @@
if (linalgOps.empty()) continue;
LaunchConfig launchConfig;
- SmallVector<Operation *, 4> linalgOpsVec(linalgOps.begin(),
- linalgOps.end());
+ SmallVector<linalg::LinalgOp, 4> linalgOpsVec =
+ llvm::to_vector<4>(llvm::map_range(linalgOps, [](Operation *op) {
+ return cast<linalg::LinalgOp>(op);
+ }));
linalg::Aliases aliases;
linalg::LinalgDependenceGraph dependenceGraph(aliases, linalgOpsVec);
if (failed(launchConfig.init(context, dependenceGraph, options,
@@ -594,7 +596,7 @@
});
}
- if (options.enableVectorization) {
+ if (launchConfig.useVectorize()) {
{
OwningRewritePatternList secondLevelTilingPatterns;
populateTilingToSubgroupPatterns(context, launchConfig,
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Passes.h b/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
index 7c650ca..482d4ae 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
@@ -24,6 +24,10 @@
namespace mlir {
namespace iree_compiler {
+//===----------------------------------------------------------------------===//
+// Passes
+//===----------------------------------------------------------------------===//
+
/// Pass to tile and fuse linalg operations on buffers. The pass takes as
/// argument the `workgroupSize` that the tiling should use. Note that the
/// tile-sizes are the reverse of the workgroup size. So workgroup size along
@@ -59,11 +63,18 @@
/// Pass to apply tiling and vectorization transformations on linagl::MatMulOp.
std::unique_ptr<FunctionPass> createMatMulTileAndVectorizeGPUPass();
-/// Convert memref of scalar to memref of vector of efficent size. This will
+/// Converts memref of scalar to memref of vector of efficient size. This will
/// allow to convert memory accesses to vector load/store in SPIR-V without
/// having pointer bitcast.
std::unique_ptr<OperationPass<ModuleOp>> createVectorizeMemref();
+/// Creates a pass to fold processor ID uses where possible.
+std::unique_ptr<OperationPass<FuncOp>> createFoldProcessorIDUsesPass();
+
+//===----------------------------------------------------------------------===//
+// Pipelines
+//===----------------------------------------------------------------------===//
+
/// Populates passes needed to lower a XLA HLO op to SPIR-V dialect via the
/// structured ops path. The pass manager `pm` in here operate on the module
/// within the IREE::HAL::ExecutableOp. The `workGroupSize` can be used to
@@ -73,9 +84,19 @@
void buildSPIRVTransformPassPipeline(OpPassManager &pm,
const SPIRVCodegenOptions &options);
-/// Populate patterns to tile and distribute linalg operations.
+//===----------------------------------------------------------------------===//
+// Patterns
+//===----------------------------------------------------------------------===//
+
+/// Populates patterns to tile and distribute linalg operations.
void populateLinalgTileAndDistributePatterns(
MLIRContext *context, OwningRewritePatternList &patterns);
+
+/// Populates patterns to fold processor ID uses by using processor counts
+/// information where possible.
+void populateFoldGPUProcessorIDUsesPatterns(MLIRContext *context,
+ OwningRewritePatternList &patterns);
+
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Utils.h b/iree/compiler/Conversion/LinalgToSPIRV/Utils.h
index 972e116..ea1316c 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Utils.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Utils.h
@@ -21,19 +21,19 @@
#define IREE_COMPILER_CONVERSION_LINALGTOSPIRV_UTILS_H_
#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/FoldUtils.h"
namespace mlir {
-class FuncOp;
-class Value;
-class SubViewOp;
-class OperationFolder;
-class OpBuilder;
-class LogicalResult;
-
namespace iree_compiler {
static constexpr int kNumGPUDims = 3;
+
/// Allocation callback for allocation workgroup local memory.
Optional<Value> allocateWorkgroupMemory(OpBuilder &b, SubViewOp subview,
ArrayRef<Value> boundingSubViewSize,
@@ -63,6 +63,7 @@
/// Updates the workgroup size used for the dispatch region.
LogicalResult updateWorkGroupSize(FuncOp funcOp,
ArrayRef<int64_t> workGroupSize);
+
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/VectorizeMemref.cpp b/iree/compiler/Conversion/LinalgToSPIRV/VectorizeMemref.cpp
index 9391676..5a610d3 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/VectorizeMemref.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/VectorizeMemref.cpp
@@ -24,8 +24,8 @@
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/StandardTypes.h"
-constexpr int kVectorizationSizeInBits = 128;
-constexpr int kVecSize = kVectorizationSizeInBits / (sizeof(float) * 8);
+constexpr int kMaxVectorizationSizeInBits = 128;
+constexpr int kMaxVectorNumElements = 4;
namespace mlir {
namespace iree_compiler {
@@ -47,22 +47,6 @@
return true;
}
-/// Returns true of the type is a memref that can be vectorized to
-/// vector<4xi32>. If it returns true also return the uses of memref.
-static bool isMemRefAndVectorizable(Value v,
- SmallVectorImpl<Operation *> &uses) {
- auto memrefType = v.getType().dyn_cast<MemRefType>();
- // To be able to vectorize the memref it needs to be a scalar memref with a
- // static most inner dimension aligned on the vectorization size.
- return memrefType && !memrefType.getElementType().isa<VectorType>() &&
- (kVectorizationSizeInBits % memrefType.getElementTypeBitWidth() ==
- 0) &&
- !ShapedType::isDynamic(memrefType.getShape().back()) &&
- ((memrefType.getElementTypeBitWidth() * memrefType.getShape().back()) %
- kVectorizationSizeInBits ==
- 0) &&
- getUsesIfAllTransferOp(v, uses);
-}
/// Returns the bitwidth of a scalar or vector type.
static Optional<unsigned> getBitWidth(Type type) {
@@ -76,6 +60,37 @@
return {};
}
+/// Computes the vector size in bits to use for a memref from its `uses`: the
+/// smallest vector bit width among the transfer ops, capped at
+/// kMaxVectorizationSizeInBits. Returns 0 if a use is not a vector transfer op
+/// or if the bit width of its vector type cannot be determined.
+static unsigned calculateMemrefVecSize(SmallVectorImpl<Operation *> &uses) {
+  unsigned vecSizeInBits = kMaxVectorizationSizeInBits;
+  for (Operation *user : uses) {
+    if (auto transfer = dyn_cast<VectorTransferOpInterface>(user)) {
+      Optional<unsigned> bitWidth = getBitWidth(transfer.getVectorType());
+      if (bitWidth) {
+        vecSizeInBits = std::min(vecSizeInBits, *bitWidth);
+        continue;
+      }
+    }
+    // Either not a transfer op or an unknown bit width: give up.
+    return 0;
+  }
+  return vecSizeInBits;
+}
+
+/// If the memref is vectorizable return the vector size (in bits) we want to
+/// use, otherwise return 0. If it returns a value greater than 0 it also
+/// returns the memref uses in `uses`.
+static unsigned isMemRefAndVectorizable(Value v,
+                                        SmallVectorImpl<Operation *> &uses) {
+  auto memrefType = v.getType().dyn_cast<MemRefType>();
+  // To be able to vectorize the memref it needs to be a memref of scalars
+  // whose element bit width divides the maximum vectorization size, with a
+  // static innermost dimension, and `getUsesIfAllTransferOp` must succeed.
+  // Rank-0 memrefs have no innermost dimension, so they are rejected before
+  // `getShape().back()` could be evaluated on an empty shape.
+  if (memrefType && memrefType.getRank() > 0 &&
+      !memrefType.getElementType().isa<VectorType>() &&
+      (kMaxVectorizationSizeInBits % memrefType.getElementTypeBitWidth() ==
+       0) &&
+      !ShapedType::isDynamic(memrefType.getShape().back()) &&
+      getUsesIfAllTransferOp(v, uses)) {
+    return calculateMemrefVecSize(uses);
+  }
+  return 0;
+}
+
namespace {
/// Analyze memref usages to decide if it should be vectorized. Right now the
/// logic is to vectorize memref only if it is used by
@@ -85,7 +100,12 @@
explicit MemRefUsageAnalysis(mlir::Operation *);
// Returns true if the memref should be converted to a vector of memref.
- bool vectorizeMemRef(Value v) const { return vectorize.count(v); }
+ bool vectorizeMemRef(Value v) const { return vectorization_size.count(v); }
+
+ // Return the size in bits of the vector we want to use for vectorizing the
+ // memref `v`, as recorded in `vectorization_size`.
+ // NOTE(review): the lookup result is dereferenced without an end() check, so
+ // this presumably must only be called when vectorizeMemRef(v) is true --
+ // confirm against callers.
+ unsigned getMemRefVectorSizeInBits(Value v) const {
+ return vectorization_size.find(v)->second;
+ }
// Returns true if the transfer operation needs to be updated during memref
// vectorization.
bool transferConvert(Operation *op) const { return transferOps.count(op); }
@@ -94,7 +114,7 @@
void analyzeFunc(FuncOp funcOp);
void analyzeAlloc(AllocOp allocOp);
void analyzePlaceholder(IREE::PlaceholderOp placeholderOp);
- llvm::DenseSet<Value> vectorize;
+ llvm::DenseMap<Value, unsigned> vectorization_size;
llvm::DenseSet<Operation *> transferOps;
};
@@ -110,8 +130,8 @@
void MemRefUsageAnalysis::analyzeFunc(FuncOp funcOp) {
for (Value arg : funcOp.getArguments()) {
SmallVector<Operation *, 4> vectorUses;
- if (isMemRefAndVectorizable(arg, vectorUses)) {
- vectorize.insert(arg);
+ if (unsigned vectorSize = isMemRefAndVectorizable(arg, vectorUses)) {
+ vectorization_size.insert(std::make_pair(arg, vectorSize));
transferOps.insert(vectorUses.begin(), vectorUses.end());
}
}
@@ -120,16 +140,17 @@
void MemRefUsageAnalysis::analyzePlaceholder(
IREE::PlaceholderOp placeholderOp) {
SmallVector<Operation *, 4> vectorUses;
- if (isMemRefAndVectorizable(placeholderOp, vectorUses)) {
- vectorize.insert(placeholderOp);
+ if (unsigned vectorSize =
+ isMemRefAndVectorizable(placeholderOp, vectorUses)) {
+ vectorization_size.insert(std::make_pair(placeholderOp, vectorSize));
transferOps.insert(vectorUses.begin(), vectorUses.end());
}
}
void MemRefUsageAnalysis::analyzeAlloc(AllocOp allocOp) {
SmallVector<Operation *, 4> vectorUses;
- if (isMemRefAndVectorizable(allocOp, vectorUses)) {
- vectorize.insert(allocOp);
+ if (unsigned vectorSize = isMemRefAndVectorizable(allocOp, vectorUses)) {
+ vectorization_size.insert(std::make_pair(allocOp, vectorSize));
transferOps.insert(vectorUses.begin(), vectorUses.end());
}
}
@@ -143,6 +164,8 @@
memrefUsageAnalysis(memrefUsageAnalysis) {}
protected:
+ Optional<MemRefType> getVectorizedMemRefType(
+ ConversionPatternRewriter &rewriter, Value memRefValue) const;
const MemRefUsageAnalysis &memrefUsageAnalysis;
};
@@ -241,19 +264,37 @@
}
};
-static Optional<MemRefType> getVectorizedMemRefType(
- ConversionPatternRewriter &rewriter, MemRefType type) {
+/// Decide the new memref of vector type we want to use after vectorization
+/// based on the original type and the vectorization size we want. Since Vulkan
+/// only supports vectors of up to 4 elements, we may re-interpret the memref
+/// using a larger element type. For example:
+/// * memref<1024xf16> vectorized with a size of 64 bits will return
+/// memref<256xvector<4xf16>>
+/// * memref<1024xf16> vectorized with a size of 128 bits will return
+/// memref<128xvector<4xf32>>
+template <typename OpTy>
+Optional<MemRefType> MemRefConversionPattern<OpTy>::getVectorizedMemRefType(
+ ConversionPatternRewriter &rewriter, Value memRefValue) const {
+ unsigned vecSizeInBits =
+ memrefUsageAnalysis.getMemRefVectorSizeInBits(memRefValue);
+ MemRefType type = memRefValue.getType().cast<MemRefType>();
unsigned elemSize = type.getElementTypeBitWidth();
- unsigned vecSize = kVectorizationSizeInBits / elemSize;
- // Pick a new type of element size 32bits.
- Type newElemType = type.getElementType().isa<IntegerType>()
- ? rewriter.getI32Type().cast<Type>()
- : rewriter.getF32Type().cast<Type>();
- Type vecType = VectorType::get(kVecSize, newElemType);
+ unsigned numElements = vecSizeInBits / elemSize;
+ Type elemType = type.getElementType();
+ // If the vector we need to generate is bigger than the max vector size
+ // allowed for loads, use a larger element type.
+ if (numElements > kMaxVectorNumElements) {
+ elemType = elemType.isa<IntegerType>() ? rewriter.getI32Type().cast<Type>()
+ : rewriter.getF32Type().cast<Type>();
+ elemSize = elemType.getIntOrFloatBitWidth();
+ numElements = vecSizeInBits / elemSize;
+ }
+ Type vecType = VectorType::get(numElements, elemType);
SmallVector<int64_t, 2> newShape(type.getShape().begin(),
type.getShape().end());
- if (newShape.back() % vecSize != 0) return {};
- newShape.back() = newShape.back() / vecSize;
+ unsigned ratio = vecSizeInBits / type.getElementTypeBitWidth();
+ if (newShape.back() % ratio != 0) return {};
+ newShape.back() = newShape.back() / ratio;
return MemRefType::get(newShape, vecType, {}, type.getMemorySpace());
}
@@ -263,10 +304,10 @@
LogicalResult matchAndRewrite(
AllocOp alloc, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto memrefType = getVectorizedMemRefType(rewriter, alloc.getType());
+ auto memrefType = getVectorizedMemRefType(rewriter, alloc.getResult());
if (!memrefType) return failure();
- Value newAlloc =
- rewriter.create<AllocOp>(alloc.getLoc(), *memrefType, alloc.value());
+ Value newAlloc = rewriter.create<AllocOp>(alloc.getLoc(), *memrefType,
+ alloc.dynamicSizes());
rewriter.replaceOp(alloc, newAlloc);
return success();
}
@@ -281,7 +322,7 @@
ConversionPatternRewriter &rewriter) const override {
auto memrefType = placeholder.getType().dyn_cast<MemRefType>();
if (!memrefType) return failure();
- auto vecMemRef = getVectorizedMemRefType(rewriter, memrefType);
+ auto vecMemRef = getVectorizedMemRefType(rewriter, placeholder.getResult());
if (!vecMemRef) return failure();
ValueRange dummyOperands;
Value newPlaceholder = rewriter.create<IREE::PlaceholderOp>(
@@ -309,8 +350,7 @@
TypeConverter typeConverter;
for (const auto &arg : llvm::enumerate(funcOp.getArguments())) {
if (memrefUsageAnalysis.vectorizeMemRef(arg.value())) {
- if (auto memrefType = getVectorizedMemRefType(
- rewriter, arg.value().getType().cast<MemRefType>())) {
+ if (auto memrefType = getVectorizedMemRefType(rewriter, arg.value())) {
signatureConverter.addInputs(arg.index(), *memrefType);
continue;
}
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_spirv.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_spirv.mlir
index 0be6dc4..ea292d3 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_spirv.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_spirv.mlir
@@ -6,7 +6,7 @@
func @push_constant() {
// CHECK: %[[INDEX_0:.+]] = spv.constant 0 : i32
// CHECK: %[[INDEX_1:.+]] = spv.constant 2 : i32
- // CHECK: %[[ADDR:.+]] = spv._address_of @__push_constant_var__ : !spv.ptr<!spv.struct<(!spv.array<5 x i32, stride=4> [0])>, PushConstant>
+ // CHECK: %[[ADDR:.+]] = spv.mlir.addressof @__push_constant_var__ : !spv.ptr<!spv.struct<(!spv.array<5 x i32, stride=4> [0])>, PushConstant>
// CHECK: %[[AC:.+]] = spv.AccessChain %[[ADDR]][%[[INDEX_0]], %[[INDEX_1]]] : !spv.ptr<!spv.struct<(!spv.array<5 x i32, stride=4> [0])>, PushConstant>
// CHECK: spv.Load "PushConstant" %[[AC]] : i32
%0 = hal.interface.load.constant offset = 2 : index
@@ -49,8 +49,8 @@
// CHECK: spv.func @resource_bindings_in_entry_func1()
func @resource_bindings_in_entry_func1() {
- // CHECK: spv._address_of @[[FUNC1_ARG:.+]]
- // CHECK: spv._address_of @[[FUNC1_RET:.+]]
+ // CHECK: spv.mlir.addressof @[[FUNC1_ARG:.+]]
+ // CHECK: spv.mlir.addressof @[[FUNC1_RET:.+]]
%0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<4x4xf32>
%1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<4xvector<4xf32>>
return
@@ -58,8 +58,8 @@
// CHECK: spv.func @resource_bindings_in_entry_func2()
func @resource_bindings_in_entry_func2() {
- // CHECK: spv._address_of @[[FUNC2_ARG]]
- // CHECK: spv._address_of @[[FUNC2_RET]]
+ // CHECK: spv.mlir.addressof @[[FUNC2_ARG]]
+ // CHECK: spv.mlir.addressof @[[FUNC2_RET]]
%0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<4x4xf32>
%1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<4x4xf32>
return
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/fold-gpu-procid-uses.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/fold-gpu-procid-uses.mlir
new file mode 100644
index 0000000..1d45673
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/fold-gpu-procid-uses.mlir
@@ -0,0 +1,101 @@
+// RUN: iree-opt -split-input-file -iree-codegen-fold-gpu-procid-uses %s | IreeFileCheck %s
+
+module {
+ // CHECK-LABEL: func @fold_block_id_x()
+ func @fold_block_id_x() -> index attributes {hal.num_workgroups_fn = @num_workgroups} {
+ // CHECK: %[[cst:.+]] = constant 3
+ // CHECK: return %[[cst]]
+ %0 = "gpu.block_id"() {dimension = "x"} : () -> index
+ %1 = affine.min affine_map<()[s0] -> (3, s0 * -2 + 225)>()[%0]
+ return %1: index
+ }
+
+ // CHECK-LABEL: func @fold_block_id_y()
+ func @fold_block_id_y() -> index attributes {hal.num_workgroups_fn = @num_workgroups} {
+ // CHECK: %[[cst:.+]] = constant 8
+ // CHECK: return %[[cst]]
+ %0 = "gpu.block_id"() {dimension = "y"} : () -> index
+ %1 = affine.min affine_map<()[s0] -> (8, s0 * -1 + s0 * -1 + s0 * -1 + 131)>()[%0]
+ return %1: index
+ }
+
+ // CHECK-LABEL: func @fold_block_id_z()
+ func @fold_block_id_z() -> index attributes {hal.num_workgroups_fn = @num_workgroups} {
+ // CHECK: %[[cst:.+]] = constant 11
+ // CHECK: return %[[cst]]
+ %0 = "gpu.block_id"() {dimension = "z"} : () -> index
+ %1 = affine.min affine_map<()[s0] -> (11, s0 + 15)>()[%0]
+ return %1: index
+ }
+
+ func @num_workgroups() -> (index, index, index) {
+ %x = constant 112: index
+ %y = constant 42: index
+ %z = constant 1: index
+ return %x, %y, %z: index, index, index
+ }
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_thread_id_x()
+func @fold_thread_id_x() -> index attributes {spv.entry_point_abi = {local_size = dense<[8, 2, 1]> : vector<3xi32>}} {
+ // CHECK: %[[cst:.+]] = constant 7
+ // CHECK: return %[[cst]]
+ %0 = "gpu.thread_id"() {dimension = "x"} : () -> index
+ %1 = affine.min affine_map<()[s0] -> (7, s0 * -1 + s0 * -1 + 21)>()[%0]
+ return %1: index
+}
+
+// CHECK-LABEL: func @fold_thread_id_y()
+func @fold_thread_id_y() -> index attributes {spv.entry_point_abi = {local_size = dense<[8, 2, 1]> : vector<3xi32>}} {
+ // CHECK: %[[cst:.+]] = constant 11
+ // CHECK: return %[[cst]]
+ %0 = "gpu.thread_id"() {dimension = "y"} : () -> index
+ %1 = affine.min affine_map<()[s0] -> (11, s0 * -3 + 14)>()[%0]
+ return %1: index
+}
+
+// CHECK-LABEL: func @fold_thread_id_z()
+func @fold_thread_id_z() -> index attributes {spv.entry_point_abi = {local_size = dense<[8, 2, 1]> : vector<3xi32>}} {
+ // CHECK: %[[cst:.+]] = constant 21
+ // CHECK: return %[[cst]]
+ %0 = "gpu.thread_id"() {dimension = "z"} : () -> index
+ %1 = affine.min affine_map<()[s0] -> (21, s0 + (s0 + 21))>()[%0]
+ return %1: index
+}
+
+// -----
+
+// CHECK-LABEL: func @does_not_fold_mod()
+func @does_not_fold_mod() -> index attributes {spv.entry_point_abi = {local_size = dense<[8, 2, 1]> : vector<3xi32>}} {
+ // CHECK: affine.min
+ %0 = "gpu.thread_id"() {dimension = "z"} : () -> index
+ %1 = affine.min affine_map<()[s0] -> (21, s0 mod 5)>()[%0]
+ return %1: index
+}
+
+// CHECK-LABEL: func @does_not_fold_div()
+func @does_not_fold_div() -> index attributes {spv.entry_point_abi = {local_size = dense<[8, 2, 1]> : vector<3xi32>}} {
+ // CHECK: affine.min
+ %0 = "gpu.thread_id"() {dimension = "z"} : () -> index
+ %1 = affine.min affine_map<()[s0] -> (21, s0 ceildiv 5)>()[%0]
+ return %1: index
+}
+
+// CHECK-LABEL: func @does_not_fold_symbol_mul_symbol()
+func @does_not_fold_symbol_mul_symbol() -> index attributes {spv.entry_point_abi = {local_size = dense<[8, 2, 1]> : vector<3xi32>}} {
+ // CHECK: affine.min
+ %0 = "gpu.thread_id"() {dimension = "z"} : () -> index
+ %1 = affine.min affine_map<()[s0] -> (21, s0 * s0)>()[%0]
+ return %1: index
+}
+
+// CHECK-LABEL: func @does_not_fold_if_cst_not_lower_bound()
+func @does_not_fold_if_cst_not_lower_bound() -> index attributes {spv.entry_point_abi = {local_size = dense<[8, 2, 1]> : vector<3xi32>}} {
+ // CHECK: affine.min
+ %0 = "gpu.thread_id"() {dimension = "x"} : () -> index
+ // 5 is in %0's range of [0,7] so we cannot fold the following into 5 or 0.
+ %1 = affine.min affine_map<()[s0] -> (5, s0)>()[%0]
+ return %1: index
+}
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
index a115363..593a693 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
@@ -353,7 +353,7 @@
// CHECK: %[[LBY_2:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
// CHECK: %[[LBX_2:.+]] = affine.apply #[[MAP3]]()[%[[BIDX]]]
// CHECK: %[[VIEW2:.+]] = subview %[[RET0]][%[[LBY_2]], %[[LBX_2]]]
-// CHECK: %[[VIEW3:.+]] = subview %[[RET0]][%[[LBY_2]], %[[LBX_2]]]
+// CHECK: %[[VIEW3:.+]] = subview %[[RET0]][%[[LBY]], %[[LBX]]]
// CHECK: linalg.fill(%[[VIEW3]], %{{.+}})
// CHECK-SAME: "workgroup"
// CHECK: linalg.matmul
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_fused_vectorization.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_fused_vectorization.mlir
index 2bcaf93..33e1131 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_fused_vectorization.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_fused_vectorization.mlir
@@ -50,3 +50,56 @@
// CHECK: scf.yield
// CHECK-COUNT-8: vector.transfer_write %[[FOR_RES]]
// CHECK: return
+
+// -----
+
+module attributes {
+ spv.target_env =
+ #spv.target_env<#spv.vce<v1.3,
+ [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess,
+ StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess,
+ UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform,
+ GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot,
+ GroupNonUniformShuffle, GroupNonUniformShuffleRelative, VariablePointers,
+ VariablePointersStorageBuffer],
+ [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage,
+ SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>,
+ ARM:IntegratedGPU,
+ {max_compute_shared_memory_size = 32768 : i32,
+ max_compute_workgroup_invocations = 512 : i32,
+ max_compute_workgroup_size = dense<512> : vector<3xi32>,
+ subgroup_size = 16 : i32}>} {
+ func @matmul_static_shape_f16()
+ attributes {vkspv.num_workgroups_fn = @matmul_static_shape__num_workgroups__} {
+ %arg0 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@arg0, operand_result_num = 0 : i32} : memref<4096x4096xf16>
+ %arg1 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@arg1, operand_result_num = 1 : i32} : memref<4096x4096xf16>
+ %ret0 = iree.placeholder for "interface buffer"
+ {binding = @legacy_io::@ret0, operand_result_num = 2 : i32} : memref<4096x4096xf16>
+ %cst = constant 0.000000e+00 : f16
+ linalg.fill(%ret0, %cst) : memref<4096x4096xf16>, f16
+ linalg.matmul ins(%arg0, %arg1 : memref<4096x4096xf16>, memref<4096x4096xf16>)
+ outs(%ret0 : memref<4096x4096xf16>)
+ return
+ }
+ func @matmul_static_shape__num_workgroups__
+ (!shapex.ranked_shape<[4096, 4096]>, !shapex.ranked_shape<[4096, 4096]>,
+ !shapex.ranked_shape<[4096, 4096]>) -> (index, index, index)
+ attributes {sym_visibility = "private"}
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+ }
+}
+
+// CHECK-LABEL: func @matmul_static_shape_f16
+// CHECK-COUNT-16: vector.transfer_write
+// CHECK-COUNT-16: vector.transfer_read
+// CHECK: %[[FOR_RES:.+]]:16 = scf.for
+// CHECK-COUNT-40: vector.transfer_read
+// CHECK-COUNT-64: vector.contract
+// CHECK: scf.yield
+// CHECK-COUNT-16: vector.transfer_write %[[FOR_RES]]
+// CHECK: return
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/memref_vecrotization.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/memref_vecrotization.mlir
index 9a87e8c..251daaa 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/memref_vecrotization.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/memref_vecrotization.mlir
@@ -59,3 +59,54 @@
hal.interface.binding @ret0, set=3, binding=4, type="StorageBuffer", access="Write"
}
+// -----
+
+// CHECK-LABEL: func @resource_copy_f16
+// CHECK: %[[A:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<4096x1024xvector<4xf16>>
+// CHECK: %[[B:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<4096x1024xvector<4xf16>>
+// CHECK: %[[V:.+]] = load %[[A]][%{{.*}}, %{{.*}}] : memref<4096x1024xvector<4xf16>>
+// CHECK: store %[[V]], %[[B]][%{{.*}}, %{{.*}}] : memref<4096x1024xvector<4xf16>>
+// CHECK: %[[MAT:.+]] = vector.transfer_read %[[A]][%{{.*}}, %{{.*}}], %{{.*}} : memref<4096x1024xvector<4xf16>>, vector<32x8xf16>
+// CHECK: vector.transfer_write %[[MAT]], %[[B]][%{{.*}}, %{{.*}}] {{.*}} : vector<32x8xf16>, memref<4096x1024xvector<4xf16>>
+func @resource_copy_f16() {
+ %cst = constant 0.000000e+00 : f16
+ %c0 = constant 0 : index
+ %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<4096x4096xf16>
+ %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<4096x4096xf16>
+ %v = vector.transfer_read %0[%c0, %c0], %cst : memref<4096x4096xf16>, vector<1x4xf16>
+ vector.transfer_write %v, %1[%c0, %c0] : vector<1x4xf16>, memref<4096x4096xf16>
+ %mat = vector.transfer_read %0[%c0, %c0], %cst : memref<4096x4096xf16>, vector<32x8xf16>
+ vector.transfer_write %mat, %1[%c0, %c0] : vector<32x8xf16>, memref<4096x4096xf16>
+ return
+}
+
+hal.interface @legacy_io attributes {push_constants = 5 : i32, sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=1, binding=2, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=3, binding=4, type="StorageBuffer", access="Write"
+}
+
+// -----
+
+// CHECK-LABEL: func @resource_copy_8xf16
+// CHECK: %[[A:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<4096x512xvector<4xf32>>
+// CHECK: %[[B:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<4096x512xvector<4xf32>>
+// CHECK: %[[V:.+]] = load %[[A]][%{{.*}}, %{{.*}}] : memref<4096x512xvector<4xf32>>
+// CHECK: store %[[V]], %[[B]][%{{.*}}, %{{.*}}] : memref<4096x512xvector<4xf32>>
+// CHECK: %[[MAT:.+]] = vector.transfer_read %[[A]][%{{.*}}, %{{.*}}], %{{.*}} : memref<4096x512xvector<4xf32>>, vector<32x8xf16>
+// CHECK: vector.transfer_write %[[MAT]], %[[B]][%{{.*}}, %{{.*}}] {{.*}} : vector<32x8xf16>, memref<4096x512xvector<4xf32>>
+func @resource_copy_8xf16() {
+ %cst = constant 0.000000e+00 : f16
+ %c0 = constant 0 : index
+ %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<4096x4096xf16>
+ %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<4096x4096xf16>
+ %v = vector.transfer_read %0[%c0, %c0], %cst : memref<4096x4096xf16>, vector<1x8xf16>
+ vector.transfer_write %v, %1[%c0, %c0] : vector<1x8xf16>, memref<4096x4096xf16>
+ %mat = vector.transfer_read %0[%c0, %c0], %cst : memref<4096x4096xf16>, vector<32x8xf16>
+ vector.transfer_write %mat, %1[%c0, %c0] : vector<32x8xf16>, memref<4096x4096xf16>
+ return
+}
+
+hal.interface @legacy_io attributes {push_constants = 5 : i32, sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=1, binding=2, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=3, binding=4, type="StorageBuffer", access="Write"
+}
diff --git a/iree/compiler/Conversion/LinalgToVector/LoadStoreVectorization.cpp b/iree/compiler/Conversion/LinalgToVector/LoadStoreVectorization.cpp
index e5358c5..0d7b7aa 100644
--- a/iree/compiler/Conversion/LinalgToVector/LoadStoreVectorization.cpp
+++ b/iree/compiler/Conversion/LinalgToVector/LoadStoreVectorization.cpp
@@ -205,8 +205,7 @@
/*initTensors*/ ValueRange{}, genericOp.indexing_mapsAttr(),
genericOp.iterator_types(),
/*doc=*/nullptr,
- /*library_call=*/nullptr,
- /*symbol_source=*/nullptr);
+ /*library_call=*/nullptr, genericOp.sparseAttr());
Region &newRegion = newOp.region();
rewriter.inlineRegionBefore(genericOp.getRegion(), newRegion,
diff --git a/iree/compiler/Conversion/init_conversions.h b/iree/compiler/Conversion/init_conversions.h
index 3b8e8bb..6d72146 100644
--- a/iree/compiler/Conversion/init_conversions.h
+++ b/iree/compiler/Conversion/init_conversions.h
@@ -47,6 +47,7 @@
static bool init_once = []() {
// LinalgToSPIRV
createConvertToGPUPass();
+ createFoldProcessorIDUsesPass();
createLinalgTileAndFusePass(SPIRVCodegenOptions());
createSplitDispatchFunctionPass();
createVectorToGPUPass();
diff --git a/iree/compiler/Dialect/Flow/Analysis/Dispatchability.cpp b/iree/compiler/Dialect/Flow/Analysis/Dispatchability.cpp
index 16a05ef..c3dd2f4 100644
--- a/iree/compiler/Dialect/Flow/Analysis/Dispatchability.cpp
+++ b/iree/compiler/Dialect/Flow/Analysis/Dispatchability.cpp
@@ -18,10 +18,10 @@
#include "iree/compiler/Dialect/Flow/Utils/DispatchUtils.h"
#include "llvm/ADT/SetVector.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/SymbolTable.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/ConvertHLOToFlow.cpp b/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/ConvertHLOToFlow.cpp
index 0eb58e5..650f19e 100644
--- a/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/ConvertHLOToFlow.cpp
+++ b/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/ConvertHLOToFlow.cpp
@@ -18,12 +18,12 @@
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Dialect/Flow/IR/BUILD b/iree/compiler/Dialect/Flow/IR/BUILD
index 7b54e4b..ae3b48d 100644
--- a/iree/compiler/Dialect/Flow/IR/BUILD
+++ b/iree/compiler/Dialect/Flow/IR/BUILD
@@ -104,7 +104,7 @@
"//iree/compiler/Dialect/IREE/IR:td_files",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:include/mlir/IR/SymbolInterfaces.td",
- "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
+ "@llvm-project//mlir:SideEffectTdFiles",
],
)
@@ -120,6 +120,6 @@
"//iree/compiler/Dialect/IREE/IR:td_files",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:include/mlir/IR/SymbolInterfaces.td",
- "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
+ "@llvm-project//mlir:SideEffectTdFiles",
],
)
diff --git a/iree/compiler/Dialect/Flow/Transforms/CreateBenchmarkFuncs.cpp b/iree/compiler/Dialect/Flow/Transforms/CreateBenchmarkFuncs.cpp
index cbfdebe..25b8fdc 100644
--- a/iree/compiler/Dialect/Flow/Transforms/CreateBenchmarkFuncs.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/CreateBenchmarkFuncs.cpp
@@ -127,8 +127,7 @@
auto variableOp =
moduleBuilder.create<VariableOp>(loc, name,
/*isMutable=*/false, inputType, attr);
- SymbolTable::setSymbolVisibility(variableOp,
- SymbolTable::Visibility::Private);
+ variableOp.setPrivate();
variableOp.setAttr("noinline", UnitAttr::get(moduleBuilder.getContext()));
auto lookupOp = blockBuilder.create<IREE::Flow::VariableLoadOp>(
loc, inputType, variableOp.getName());
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp
index 176f09d..049b5c1 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp
@@ -18,8 +18,8 @@
#include "iree/compiler/Dialect/Flow/Utils/DispatchUtils.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "llvm/Support/Debug.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#define DEBUG_TYPE "iree-detail"
@@ -31,12 +31,6 @@
namespace {
// TODO(laurenzo): Every one of these should have better support and removed
// from this exclusion list eventually.
-bool isUnsupportedFusionOp(Operation *op) {
- return isa<mhlo::ConcatenateOp, mhlo::ConvOp, mhlo::DotGeneralOp, mhlo::DotOp,
- mhlo::PadOp, mhlo::ReduceOp, mhlo::ReduceWindowOp, mhlo::SliceOp,
- mhlo::TorchIndexSelectOp>(op);
-}
-
// Allowlist of ops that materialize to a an index-permuted copy of some kind
// if they exist standalone. Generally we try to avoid anchoring on these,
// letting them fuse into more meaningful ops as possible.
@@ -182,6 +176,18 @@
return FusionType::DISABLED;
}
+// TODO(b/144530470): replace with tablegen attributes/interfaces.
+bool OpDispatchPolicy::isUnsupportedFusionOp(Operation *op) {
+ return isa<mhlo::ConcatenateOp, mhlo::ConvOp, mhlo::DotGeneralOp, mhlo::DotOp,
+ mhlo::PadOp, mhlo::ReduceOp, mhlo::ReduceWindowOp,
+ mhlo::TorchIndexSelectOp>(op) ||
+ isRootOnlyOp(op);
+}
+
+bool OpDispatchPolicy::isRootOnlyOp(Operation *op) {
+ return isa<mhlo::SliceOp>(op);
+}
+
} // namespace Flow
} // namespace IREE
} // namespace iree_compiler
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.h b/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.h
index 22666b3..ee9299f 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.h
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.h
@@ -41,6 +41,12 @@
OpDispatchPolicy(Dispatchability &dispatchability)
: dispatchability(dispatchability) {}
+ // Returns true if |op| is not able to fuse with either producer or consumer.
+ static bool isUnsupportedFusionOp(Operation *op);
+
+ // Returns true if |op| can only be a root op.
+ static bool isRootOnlyOp(Operation *op);
+
// Returns true if the given |op| can be dispatched in all cases.
// Other passes may handle special cases of these ops but this initial
// identification is conservative.
diff --git a/iree/compiler/Dialect/Flow/Transforms/FlattenTuplesInCFG.cpp b/iree/compiler/Dialect/Flow/Transforms/FlattenTuplesInCFG.cpp
index d6bfd26..86f886a 100644
--- a/iree/compiler/Dialect/Flow/Transforms/FlattenTuplesInCFG.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/FlattenTuplesInCFG.cpp
@@ -15,6 +15,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/iterator_range.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
@@ -22,7 +23,6 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/Utils.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp b/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp
index b3c48a0..0982a12 100644
--- a/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp
@@ -13,11 +13,13 @@
// limitations under the License.
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/Flow/Transforms/DispatchConfig.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
@@ -29,7 +31,6 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/Utils.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#define DEBUG_TYPE "iree-dispatch"
@@ -164,13 +165,17 @@
value.getDefiningOp()->isBeforeInBlock(op)) {
// Can't depend on |op| as it is defined prior to it.
return false;
+ } else if (value.getDefiningOp()->getBlock() == op->getBlock() &&
+ !value.getDefiningOp()->isBeforeInBlock(op)) {
+ // |op| is defined before one of |value| operands.
+ return true;
}
for (auto operand : value.getDefiningOp()->getOperands()) {
if (doesValueDependOnOperation(operand, op)) {
return true;
}
}
- return true;
+ return false;
}
// Returns true if |rhs| transitively depends on any out of |lhs|.
@@ -196,10 +201,9 @@
// that substituting library calls is easier.
for (auto &block : regionOp.body().getBlocks()) {
for (auto &op : block) {
- // TODO(b/144530470): replace with tablegen attributes/interfaces.
- if (isa<mhlo::ConcatenateOp, mhlo::ConvOp, mhlo::DotGeneralOp,
- mhlo::DotOp, mhlo::PadOp, mhlo::ReduceOp, mhlo::ReduceWindowOp,
- mhlo::SliceOp, mhlo::TorchIndexSelectOp>(op)) {
+ // A root-only op is mergeable.
+ if (OpDispatchPolicy::isUnsupportedFusionOp(&op) &&
+ !OpDispatchPolicy::isRootOnlyOp(&op)) {
return false;
}
}
@@ -207,6 +211,24 @@
return regionOp.body().getBlocks().size() == 1;
}
+// Returns true if |rhs| has an op that can only be a root op and that would
+// lose this characteristic if the two dispatch regions were merged.
+bool rhsHasRootOnlyOp(DispatchRegionOp &lhs, DispatchRegionOp &rhs) {
+ auto &rhsBlock = rhs.body().front();
+ auto lhsArgs = llvm::to_vector<8>(lhs.args());
+ auto rhsArgs = llvm::to_vector<8>(rhs.args());
+ for (int rhsOpIdx = 0; rhsOpIdx < rhsArgs.size(); ++rhsOpIdx) {
+ for (int lhsResultIdx = 0; lhsResultIdx < lhs.getNumResults();
+ ++lhsResultIdx) {
+ if (rhsArgs[rhsOpIdx] != lhs.getResult(lhsResultIdx)) continue;
+ for (auto *user : rhsBlock.getArgument(rhsOpIdx).getUsers()) {
+ if (OpDispatchPolicy::isRootOnlyOp(user)) return true;
+ }
+ }
+ }
+ return false;
+}
+
// Merges |rhs| into |lhs| and returns the new |lhs| op.
// Precondition: !areDispatchRegionsTransitivelyDependent
DispatchRegionOp mergeDispatchRegions(DispatchRegionOp &lhs,
@@ -341,6 +363,10 @@
LLVM_DEBUG(llvm::dbgs()
<< " -REGION CONTAINS NON-TRIVIAL CONTROL FLOW-\n");
}
+ if (rhsHasRootOnlyOp(lhs, rhs)) {
+ LLVM_DEBUG(llvm::dbgs() << " -RHS REGION HAS ROOT OP-\n");
+ continue;
+ }
mergableRegions[i] = mergeDispatchRegions(lhs, rhs);
if (!mergableRegions[i]) {
return failure();
diff --git a/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp b/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp
index 039c829..db6856c 100644
--- a/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp
@@ -16,6 +16,9 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
+#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
@@ -24,9 +27,6 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Dialect/Flow/Transforms/IdentifyDispatchRegions.cpp b/iree/compiler/Dialect/Flow/Transforms/IdentifyDispatchRegions.cpp
index d636867..50593ff 100644
--- a/iree/compiler/Dialect/Flow/Transforms/IdentifyDispatchRegions.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/IdentifyDispatchRegions.cpp
@@ -28,6 +28,7 @@
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
@@ -40,7 +41,6 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/Utils.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#define DEBUG_TYPE "iree-dispatch-detail"
diff --git a/iree/compiler/Dialect/Flow/Transforms/LegalizeInputTypes.cpp b/iree/compiler/Dialect/Flow/Transforms/LegalizeInputTypes.cpp
index 0b5d964..3dd2e56 100644
--- a/iree/compiler/Dialect/Flow/Transforms/LegalizeInputTypes.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/LegalizeInputTypes.cpp
@@ -13,6 +13,7 @@
// limitations under the License.
#include "iree/compiler/Dialect/Flow/Conversion/TypeConverter.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
@@ -26,7 +27,6 @@
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Utils.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp b/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp
index 7d271a8..ec2c0cd 100644
--- a/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp
@@ -183,8 +183,7 @@
regionOp.getLoc(), namePrefix, {dispatchFuncOp},
parentFuncOp.getParentOfType<ModuleOp>(), dispatchableFuncOps);
executableOp.getOperation()->moveBefore(parentFuncOp);
- SymbolTable::setSymbolVisibility(executableOp,
- SymbolTable::Visibility::Private);
+ executableOp.setPrivate();
// Add dispatch export pointing at the function.
OpBuilder builder(executableOp.body());
diff --git a/iree/compiler/Dialect/Flow/Transforms/OutlineLargeConstantsPass.cpp b/iree/compiler/Dialect/Flow/Transforms/OutlineLargeConstantsPass.cpp
index 98bb36c..84c7b11 100644
--- a/iree/compiler/Dialect/Flow/Transforms/OutlineLargeConstantsPass.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/OutlineLargeConstantsPass.cpp
@@ -94,8 +94,7 @@
auto variableOp = moduleBuilder.create<IREE::Flow::VariableOp>(
largeConstantOp.getLoc(), name, /*isMutable=*/false,
largeConstantOp.getType(), largeConstantOp.getValue());
- SymbolTable::setSymbolVisibility(variableOp,
- SymbolTable::Visibility::Private);
+ variableOp.setPrivate();
replacements.emplace_back(largeConstantOp, variableOp);
// Prevent the variable from being re-inlined if the canonicalizer runs.
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
index cf093c5..d012ba5 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -18,10 +18,10 @@
#include "iree/compiler/Dialect/Shape/Conversion/Passes.h"
#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
+#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/Passes.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Dialect/Flow/Transforms/PrePostPartitioningConversion.cpp b/iree/compiler/Dialect/Flow/Transforms/PrePostPartitioningConversion.cpp
index f25f67e..d62568e 100644
--- a/iree/compiler/Dialect/Flow/Transforms/PrePostPartitioningConversion.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/PrePostPartitioningConversion.cpp
@@ -19,14 +19,14 @@
#include "iree/compiler/Dialect/Flow/Conversion/TypeConverter.h"
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/Shape/Transforms/Patterns.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Module.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Dialect/Flow/Transforms/RematerializeDispatchConstants.cpp b/iree/compiler/Dialect/Flow/Transforms/RematerializeDispatchConstants.cpp
index fc9d792..431f0e3 100644
--- a/iree/compiler/Dialect/Flow/Transforms/RematerializeDispatchConstants.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/RematerializeDispatchConstants.cpp
@@ -16,6 +16,7 @@
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "llvm/Support/Debug.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
@@ -28,7 +29,6 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/Utils.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Dialect/Flow/Transforms/StripAndSplatConstantVariables.cpp b/iree/compiler/Dialect/Flow/Transforms/StripAndSplatConstantVariables.cpp
index ec708b7..e19e802 100644
--- a/iree/compiler/Dialect/Flow/Transforms/StripAndSplatConstantVariables.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/StripAndSplatConstantVariables.cpp
@@ -72,8 +72,7 @@
builder.setInsertionPointAfter(op);
auto newOp = builder.create<VariableOp>(
op.getLoc(), op.sym_name(), op.is_mutable(), op.type(), newValue);
- SymbolTable::setSymbolVisibility(newOp,
- SymbolTable::getSymbolVisibility(op));
+ newOp.setVisibility(op.getVisibility());
newOp.setAttr("noinline", UnitAttr::get(builder.getContext()));
op.erase();
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/fold_compatible_dispatch_regions.mlir b/iree/compiler/Dialect/Flow/Transforms/test/fold_compatible_dispatch_regions.mlir
index e08d671..c324ef7 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/fold_compatible_dispatch_regions.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/fold_compatible_dispatch_regions.mlir
@@ -84,3 +84,81 @@
// CHECK-NEXT: flow.return %3 : tensor<4x4xf32>
// CHECK-NEXT: }
// CHECK-NEXT: return %[[R2]] : tensor<4x4xf32>
+
+// -----
+
+module {
+ flow.variable @var1 dense<1.000000e+00> : tensor<4xf32>
+ flow.variable @var2 dense<2.000000e+00> : tensor<4xf32>
+ func @notDominate() -> tensor<4xf32> {
+ %c4 = constant 4 : index
+ %0 = flow.variable.load @var1 : tensor<4xf32>
+ %1 = flow.dispatch.region[%c4 : index](%arg0 = %0 : tensor<4xf32>) -> tensor<4xf32> {
+ %4 = mhlo.add %arg0, %arg0 : tensor<4xf32>
+ flow.return %4 : tensor<4xf32>
+ }
+ %2 = flow.variable.load @var2 : tensor<4xf32>
+ %3 = flow.dispatch.region[%c4 : index](%arg0 = %0 : tensor<4xf32>, %arg1 = %2 : tensor<4xf32>) -> tensor<4xf32> {
+ %4 = mhlo.subtract %arg1, %arg0 : tensor<4xf32>
+ flow.return %4 : tensor<4xf32>
+ }
+ return %3 : tensor<4xf32>
+ }
+}
+// CHECK-LABEL: func @notDominate
+// CHECK: flow.dispatch.region
+// CHECK: flow.dispatch.region
+
+// -----
+
+module {
+ flow.variable @var1 dense<1.000000e+00> : tensor<4xf32>
+ flow.variable @var2 dense<2.000000e+00> : tensor<4xf32>
+ func @dominate() -> tensor<4xf32> {
+ %c4 = constant 4 : index
+ %0 = flow.variable.load @var1 : tensor<4xf32>
+ %1 = flow.variable.load @var2 : tensor<4xf32>
+ %2 = flow.dispatch.region[%c4 : index](%arg0 = %0 : tensor<4xf32>) -> tensor<4xf32> {
+ %4 = mhlo.add %arg0, %arg0 : tensor<4xf32>
+ flow.return %4 : tensor<4xf32>
+ }
+ %3 = flow.dispatch.region[%c4 : index](%arg0 = %0 : tensor<4xf32>, %arg1 = %1 : tensor<4xf32>) -> tensor<4xf32> {
+ %4 = mhlo.subtract %arg1, %arg0 : tensor<4xf32>
+ flow.return %4 : tensor<4xf32>
+ }
+ return %3 : tensor<4xf32>
+ }
+}
+// CHECK-LABEL: func @dominate
+// CHECK: flow.dispatch.region
+// CHECK-NOT: flow.dispatch.region
+
+// -----
+
+// Tests that an op that can only be a root op fuses with its consumer but not
+// its producer. This test uses a dummy workload to exercise the root-only op
+// functionality.
+module {
+ func @rootOnlyOp(%arg0: tensor<3x4xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
+ %c0 = constant 0 : index
+ %0 = flow.dispatch.region[%c0 : index](%arg2 = %arg0 : tensor<3x4xi32>) -> tensor<3x4xi32> {
+ %3 = mhlo.add %arg2, %arg2 : tensor<3x4xi32>
+ flow.return %3 : tensor<3x4xi32>
+ }
+ %1 = flow.dispatch.region[%c0 : index](%arg2 = %0 : tensor<3x4xi32>) -> tensor<1x2xi32> {
+ %3 = "mhlo.slice"(%arg2) {limit_indices = dense<[2, 3]> : tensor<2xi64>, start_indices = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x2xi32>
+ flow.return %3 : tensor<1x2xi32>
+ }
+ %2 = flow.dispatch.region[%c0 : index](%arg2 = %1 : tensor<1x2xi32>, %arg3 = %arg1 : tensor<1x2xi32>) -> tensor<1x2xi32> {
+ %3 = mhlo.multiply %arg2, %arg3 : tensor<1x2xi32>
+ flow.return %3 : tensor<1x2xi32>
+ }
+ return %2 : tensor<1x2xi32>
+ }
+}
+// CHECK-LABEL: func @rootOnlyOp
+// CHECK: flow.dispatch.region
+// CHECK-NEXT: mhlo.add
+// CHECK: flow.dispatch.region
+// CHECK-NEXT: mhlo.slice
+// CHECK-NEXT: mhlo.multiply
diff --git a/iree/compiler/Dialect/Flow/Utils/DispatchUtils.cpp b/iree/compiler/Dialect/Flow/Utils/DispatchUtils.cpp
index 506d2d2..eaf4516 100644
--- a/iree/compiler/Dialect/Flow/Utils/DispatchUtils.cpp
+++ b/iree/compiler/Dialect/Flow/Utils/DispatchUtils.cpp
@@ -17,10 +17,10 @@
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
#include "llvm/ADT/SetVector.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Dialect/Flow/Utils/WorkloadUtils.cpp b/iree/compiler/Dialect/Flow/Utils/WorkloadUtils.cpp
index 3b11700..5e19ea9 100644
--- a/iree/compiler/Dialect/Flow/Utils/WorkloadUtils.cpp
+++ b/iree/compiler/Dialect/Flow/Utils/WorkloadUtils.cpp
@@ -20,12 +20,12 @@
#include "iree/compiler/Dialect/Shape/IR/Builders.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/LLVM.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertVariableOps.cpp b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertVariableOps.cpp
index 195b7a0..c3086b9 100644
--- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertVariableOps.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertVariableOps.cpp
@@ -43,8 +43,8 @@
auto uniqueName = (Twine("__") + variableOp.getName() + "_initializer").str();
auto initializerFuncOp =
rewriter.create<FuncOp>(variableOp.getLoc(), uniqueName, initializerType);
- auto *entryBlock = initializerFuncOp.addEntryBlock();
- rewriter.setInsertionPointToEnd(entryBlock);
+ rewriter.createBlock(&initializerFuncOp.getBody(), initializerFuncOp.begin(),
+ initializerType.getInputs());
// Create const and return ops.
auto constValue = rewriter.create<ConstantOp>(loc, immediateElements);
@@ -82,8 +82,7 @@
variableOp.getLoc(), variableOp.sym_name(), variableOp.is_mutable(),
converter.convertType(variableOp.type()), initializer, initialValue,
llvm::to_vector<4>(variableOp.getDialectAttrs()));
- SymbolTable::setSymbolVisibility(
- newOp, SymbolTable::getSymbolVisibility(variableOp));
+ newOp.setVisibility(variableOp.getVisibility());
rewriter.replaceOp(variableOp, {});
return success();
}
diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertConstantOps.cpp b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertConstantOps.cpp
index e1e4460..ec5fafe 100644
--- a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertConstantOps.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertConstantOps.cpp
@@ -34,8 +34,7 @@
auto rodataName = (op.sym_name() + storageOp.sym_name()).str();
auto rodataOp = rewriter.create<IREE::VM::RodataOp>(
storageOp.getLoc(), rodataName, storageOp.value());
- SymbolTable::setSymbolVisibility(rodataOp,
- SymbolTable::Visibility::Private);
+ rodataOp.setPrivate();
}
rewriter.eraseOp(op);
return success();
diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp
index 6043159..8d0e010 100644
--- a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp
@@ -79,8 +79,7 @@
IREE::HAL::stringifyExecutableFormat(binaryOp.format()).lower())
.str(),
binaryOp.data());
- SymbolTable::setSymbolVisibility(rodataOp,
- SymbolTable::Visibility::Private);
+ rodataOp.setPrivate();
rodataOps.push_back(rodataOp);
}
rewriter.restoreInsertionPoint(insertPoint);
diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/constant_ops.mlir b/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/constant_ops.mlir
index 17dae9c..4692292 100644
--- a/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/constant_ops.mlir
+++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/constant_ops.mlir
@@ -13,8 +13,8 @@
// CHECK: vm.global.ref @pool_storage0_buffer init(@pool_storage0_buffer_initializer) : !vm.ref<!hal.buffer>
hal.variable @pool_storage0_buffer init(@pool_storage0_buffer_initializer) : !hal.buffer attributes {sym_visibility = "private"}
-// CHECK: vm.func @pool_storage0_buffer_initializer() -> !vm.ref<!hal.buffer>
-func @pool_storage0_buffer_initializer() -> !hal.buffer attributes {sym_visibility = "private"} {
+// CHECK: vm.func private @pool_storage0_buffer_initializer() -> !vm.ref<!hal.buffer>
+func private @pool_storage0_buffer_initializer() -> !hal.buffer {
%c0 = constant 0 : index
%c16 = constant 16 : index
%dev = hal.ex.shared_device : !hal.device
@@ -32,8 +32,8 @@
// CHECK: vm.global.ref @pool_splats init(@pool_splats_initializer) : !vm.ref<!hal.buffer>
hal.variable @pool_splats init(@pool_splats_initializer) : !hal.buffer attributes {sym_visibility = "private"}
-// CHECK: vm.func @pool_splats_initializer() -> !vm.ref<!hal.buffer>
-func @pool_splats_initializer() -> !hal.buffer attributes {sym_visibility = "private"} {
+// CHECK: vm.func private @pool_splats_initializer() -> !vm.ref<!hal.buffer>
+func private @pool_splats_initializer() -> !hal.buffer {
%c64 = constant 64 : index
%c0 = constant 0 : index
%c4 = constant 4 : index
diff --git a/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp b/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
index 66749ce..8dc89ba 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
+++ b/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
@@ -57,8 +57,7 @@
auto newOp = rewriter.create<VariableOp>(op.getLoc(), op.sym_name(),
op.is_mutable(), op.type(),
constResult);
- SymbolTable::setSymbolVisibility(newOp,
- SymbolTable::getSymbolVisibility(op));
+ newOp.setVisibility(op.getVisibility());
rewriter.replaceOp(op, {});
return success();
}
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/iree/compiler/Dialect/HAL/IR/HALOps.cpp
index 8c337a0..6b7e466 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOps.cpp
+++ b/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -1392,6 +1392,14 @@
data));
}
+void ExecutableBinaryOp::build(OpBuilder &builder, OperationState &state,
+ uint32_t format, DenseIntElementsAttr data) {
+ ensureTerminator(*state.addRegion(), builder, state.location);
+ state.addAttribute(
+ "format", builder.getIntegerAttr(builder.getIntegerType(32), format));
+ state.addAttribute("data", data);
+}
+
static ParseResult parseExecutableBinaryOp(OpAsmParser &parser,
OperationState *result) {
auto *body = result->addRegion();
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.td b/iree/compiler/Dialect/HAL/IR/HALOps.td
index e294c8d..529cb5f 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -1931,6 +1931,7 @@
let skipDefaultBuilders = 1;
let builders = [
OpBuilderDAG<(ins "uint32_t":$format, "std::vector<uint8_t>":$data)>,
+ OpBuilderDAG<(ins "uint32_t":$format, "DenseIntElementsAttr":$data)>,
];
let extraClassDeclaration = [{
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/BUILD b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/BUILD
index 594a0be..51983ae 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/BUILD
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/BUILD
@@ -38,12 +38,14 @@
"LLVMAOTTarget.h",
],
deps = [
- ":LLVMAOTTargetLinker",
+ ":LinkerTool",
+ "//iree/base:flatcc",
"//iree/compiler/Dialect/HAL/Target",
"//iree/compiler/Dialect/HAL/Target/LLVM:LLVMBaseTarget",
"//iree/compiler/Dialect/HAL/Target/LLVM:LLVMIRPasses",
"//iree/compiler/Dialect/HAL/Target/LLVM:LLVMTargetOptions",
- "//iree/schemas:dylib_executable_def_cc_fbs",
+ "//iree/compiler/Utils",
+ "//iree/schemas:dylib_executable_def_c_fbs",
"@llvm-project//llvm:AArch64AsmParser",
"@llvm-project//llvm:AArch64CodeGen",
"@llvm-project//llvm:ARMAsmParser",
@@ -52,21 +54,25 @@
"@llvm-project//llvm:Support",
"@llvm-project//llvm:X86AsmParser",
"@llvm-project//llvm:X86CodeGen",
+ "@llvm-project//mlir:LLVMDialect",
"@llvm-project//mlir:TargetLLVMIR",
],
)
cc_library(
- name = "LLVMAOTTargetLinker",
- hdrs = ["LLVMAOTTargetLinker.h"],
- deps = platform_trampoline_deps("LLVMAOTTargetLinker", "compiler/Dialect/HAL/Target/LLVM/AOT"),
+ name = "LinkerTool",
+ srcs = ["LinkerTool.cpp"],
+ hdrs = ["LinkerTool.h"],
+ deps = platform_trampoline_deps("LinkerTools", "compiler/Dialect/HAL/Target/LLVM/AOT"),
)
cc_library(
- name = "LLVMAOTTargetLinker_hdrs",
- hdrs = ["LLVMAOTTargetLinker.h"],
+ name = "LinkerTool_hdrs",
+ hdrs = ["LinkerTool.h"],
deps = [
- "//iree/base:status",
"//iree/compiler/Dialect/HAL/Target/LLVM:LLVMTargetOptions",
+ "@llvm-project//llvm:Core",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:Support",
],
)
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/CMakeLists.txt
index ff46a62..0a97912 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/CMakeLists.txt
@@ -26,7 +26,7 @@
SRCS
"LLVMAOTTarget.cpp"
DEPS
- ::LLVMAOTTargetLinker
+ ::LinkerTool
LLVMAArch64AsmParser
LLVMAArch64CodeGen
LLVMARMAsmParser
@@ -35,32 +35,39 @@
LLVMSupport
LLVMX86AsmParser
LLVMX86CodeGen
+ MLIRLLVMIR
MLIRTargetLLVMIR
+ iree::base::flatcc
iree::compiler::Dialect::HAL::Target
iree::compiler::Dialect::HAL::Target::LLVM::LLVMBaseTarget
iree::compiler::Dialect::HAL::Target::LLVM::LLVMIRPasses
iree::compiler::Dialect::HAL::Target::LLVM::LLVMTargetOptions
- iree::schemas::dylib_executable_def_cc_fbs
+ iree::compiler::Utils
+ iree::schemas::dylib_executable_def_c_fbs
PUBLIC
)
iree_cc_library(
NAME
- LLVMAOTTargetLinker
+ LinkerTool
HDRS
- "LLVMAOTTargetLinker.h"
+ "LinkerTool.h"
+ SRCS
+ "LinkerTool.cpp"
DEPS
- iree::compiler::Dialect::HAL::Target::LLVM::AOT::internal::LLVMAOTTargetLinker_internal
+ iree::compiler::Dialect::HAL::Target::LLVM::AOT::internal::LinkerTools_internal
PUBLIC
)
iree_cc_library(
NAME
- LLVMAOTTargetLinker_hdrs
+ LinkerTool_hdrs
HDRS
- "LLVMAOTTargetLinker.h"
+ "LinkerTool.h"
DEPS
- iree::base::status
+ LLVMCore
+ LLVMSupport
+ MLIRSupport
iree::compiler::Dialect::HAL::Target::LLVM::LLVMTargetOptions
PUBLIC
)
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LLVMAOTTarget.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LLVMAOTTarget.cpp
index 6060600..03224a8 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LLVMAOTTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LLVMAOTTarget.cpp
@@ -16,14 +16,16 @@
#include <cstdlib>
-#include "iree/compiler/Dialect/HAL/Target/LLVM/AOT/LLVMAOTTargetLinker.h"
+#include "iree/compiler/Dialect/HAL/Target/LLVM/AOT/LinkerTool.h"
#include "iree/compiler/Dialect/HAL/Target/LLVM/LLVMBaseTarget.h"
#include "iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRPasses.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
-#include "iree/schemas/dylib_executable_def_generated.h"
+#include "iree/compiler/Utils/FlatbufferUtils.h"
+#include "iree/schemas/dylib_executable_def_builder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/TargetSelect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Target/LLVMIR.h"
namespace mlir {
@@ -46,95 +48,153 @@
// multi-threading issues.
llvm::LLVMContext context;
- // Remove all private functions, e.g tile size calcuations.
- SmallVector<FuncOp, 4> nonPublicFn;
- for (auto func : targetOp.getInnerModule().getOps<FuncOp>()) {
- if (SymbolTable::getSymbolVisibility(func) !=
- SymbolTable::Visibility::Public) {
- nonPublicFn.push_back(func);
- }
+ // We name our files after the executable name so that they are easy to
+ // track both during compilation (logs/artifacts/etc), as outputs (final
+ // intermediate code/binary files), and at runtime (loaded
+ // libraries/symbols/etc).
+ auto libraryName =
+ targetOp.getParentOfType<IREE::HAL::ExecutableOp>().getName().str();
+
+ // TODO(#3737): don't add functions we don't want to serialize to the
+ // module. Right now workgroup count calculation functions end up in here
+ // as std.func ops and not just the llvm.func ops we expect.
+ auto illegalFuncOps =
+ llvm::to_vector<4>(targetOp.getInnerModule().getOps<FuncOp>());
+ for (auto funcOp : illegalFuncOps) {
+ funcOp.erase();
}
- for (auto func : nonPublicFn) {
- func.erase();
- }
+
+ llvm::Triple targetTriple(options_.targetTriple);
+ targetOp.getInnerModule().setAttr(
+ LLVM::LLVMDialect::getTargetTripleAttrName(),
+ executableBuilder.getStringAttr(targetTriple.str()));
// At this moment we are leaving MLIR LLVM dialect land translating module
// into target independent LLVMIR.
- auto llvmModule =
- mlir::translateModuleToLLVMIR(targetOp.getInnerModule(), context);
+ auto llvmModule = mlir::translateModuleToLLVMIR(targetOp.getInnerModule(),
+ context, libraryName);
if (!llvmModule) {
- return failure();
+ return targetOp.emitError() << "failed to translate the MLIR LLVM "
+ "dialect to the native llvm::Module";
}
- iree::DyLibExecutableDefT dyLibExecutableDef;
- // Create invocation function an populate entry_points.
- auto entryPointOps = targetOp.getBlock().getOps<ExecutableEntryPointOp>();
-
- for (auto entryPointOp : entryPointOps) {
- dyLibExecutableDef.entry_points.push_back(
- std::string(entryPointOp.sym_name()));
+ // Try to grab a linker tool based on the options (and target environment).
+ auto linkerTool = LinkerTool::getForTarget(targetTriple, options_);
+ if (!linkerTool) {
+ return mlir::emitError(targetOp.getLoc())
+ << "failed to find a target linker for the given target triple '"
+ << options_.targetTriple << "'";
}
- // LLVMIR opt passes.
+ // Configure the module with any code generation options required later by
+ // linking (such as initializer functions).
+ auto entryPointNames = llvm::to_vector<8>(
+ llvm::map_range(targetOp.getBlock().getOps<ExecutableEntryPointOp>(),
+ [&](auto op) { return op.getName(); }));
+ if (failed(
+ linkerTool->configureModule(llvmModule.get(), entryPointNames))) {
+ return targetOp.emitError()
+ << "failed to configure LLVM module for target linker";
+ }
+
+ // LLVM opt passes that perform code generation optimizations/transformation
+ // similar to what a frontend would do before passing to linking.
auto targetMachine = createTargetMachine(options_);
if (!targetMachine) {
- targetOp.emitError("Can't create target machine for target triple: " +
- options_.targetTriple);
- return failure();
+ return mlir::emitError(targetOp.getLoc())
+ << "failed to create target machine for target triple '"
+ << options_.targetTriple << "'";
}
-
llvmModule->setDataLayout(targetMachine->createDataLayout());
llvmModule->setTargetTriple(targetMachine->getTargetTriple().str());
-
if (failed(
runLLVMIRPasses(options_, targetMachine.get(), llvmModule.get()))) {
- return targetOp.emitError(
- "Can't build LLVMIR opt passes for ExecutableOp module");
+ return targetOp.emitError()
+ << "failed to run LLVM-IR opt passes for IREE::HAL::ExecutableOp "
+ "targeting '"
+ << options_.targetTriple << "'";
}
- std::string objData;
- if (failed(runEmitObjFilePasses(targetMachine.get(), llvmModule.get(),
- &objData))) {
- return targetOp.emitError("Can't compile LLVMIR module to an obj");
- }
-
- std::string sharedLibData;
- const char *linkerToolPath = std::getenv("IREE_LLVMAOT_LINKER_PATH");
- if (linkerToolPath != nullptr) {
- auto sharedLibDataStatus = linkLLVMAOTObjects(linkerToolPath, objData);
- if (!sharedLibDataStatus.ok()) {
- return targetOp.emitError(
- "Can't link executable and generate target dylib, using linker "
- "toolchain: '" +
- std::string(linkerToolPath) + "'");
+ // Emit object files.
+ SmallVector<Artifact, 4> objectFiles;
+ {
+ // NOTE: today we just use a single object file, however if we wanted to
+ // scale code generation and linking we'd want to generate one per
+ // function (or something like that).
+ std::string objectData;
+ if (failed(runEmitObjFilePasses(targetMachine.get(), llvmModule.get(),
+ &objectData))) {
+ return targetOp.emitError()
+ << "failed to compile LLVM-IR module to an object file";
}
- sharedLibData = sharedLibDataStatus.value();
- } else {
- auto sharedLibDataStatus = linkLLVMAOTObjectsWithLLDElf(objData);
- if (!sharedLibDataStatus.ok()) {
- return targetOp.emitError(
- "Can't link executable and generate target dylib using "
- "lld::elf::link");
- }
- sharedLibData = sharedLibDataStatus.value();
+ auto objectFile = Artifact::createTemporary(libraryName, "obj");
+ auto &os = objectFile.outputFile->os();
+ os << objectData;
+ os.flush();
+ os.close();
+ objectFiles.push_back(std::move(objectFile));
}
- dyLibExecutableDef.library_embedded = {sharedLibData.begin(),
- sharedLibData.end()};
- ::flatbuffers::FlatBufferBuilder fbb;
- auto executableOffset =
- iree::DyLibExecutableDef::Pack(fbb, &dyLibExecutableDef);
- iree::FinishDyLibExecutableDefBuffer(fbb, executableOffset);
- std::vector<uint8_t> bytes;
- bytes.resize(fbb.GetSize());
- std::memcpy(bytes.data(), fbb.GetBufferPointer(), bytes.size());
+ // Link the generated object files into a dylib.
+ auto linkArtifactsOr =
+ linkerTool->linkDynamicLibrary(libraryName, objectFiles);
+ if (!linkArtifactsOr.hasValue()) {
+ return mlir::emitError(targetOp.getLoc())
+ << "failed to link executable and generate target dylib using "
+ "linker toolchain "
+ << linkerTool->getToolPath();
+ }
+ auto &linkArtifacts = linkArtifactsOr.getValue();
+ if (options_.keepLinkerArtifacts) {
+ mlir::emitRemark(targetOp.getLoc())
+ << "Linker artifacts for " << targetOp.getName() << " preserved:\n"
+ << " " << linkArtifacts.libraryFile.path;
+ linkArtifacts.keepAllFiles();
+ }
+
+    // Embed debug symbols at the end of the flatbuffer by adding them first in
+    // the bottom-up builder.
+ FlatbufferBuilder builder;
+ flatbuffers_uint8_vec_ref_t debugDatabaseRef = 0;
+ flatbuffers_string_ref_t debugDatabaseFilenameRef = 0;
+ if (options_.debugSymbols && linkArtifacts.debugFile.outputFile) {
+ debugDatabaseRef = builder.streamUint8Vec([&](raw_ostream &stream) {
+ return linkArtifacts.debugFile.readInto(stream);
+ });
+ debugDatabaseFilenameRef = builder.createString(
+ llvm::sys::path::filename(linkArtifacts.debugFile.path));
+ }
+
+ // Embed entire dynamic library output.
+ flatbuffers_uint8_vec_ref_t libraryEmbeddedRef =
+ builder.streamUint8Vec([&](raw_ostream &stream) {
+ return linkArtifacts.libraryFile.readInto(stream);
+ });
+ if (!libraryEmbeddedRef) {
+ return targetOp.emitError() << "failed to read back dylib temp file at "
+ << linkArtifacts.libraryFile.path;
+ }
+
+    // Entry point names up front.
+ // TODO(#3580): these won't be needed in the executable_library world.
+ auto entryPointsRef = builder.createStringVec(llvm::map_range(
+ targetOp.getBlock().getOps<ExecutableEntryPointOp>(),
+ [&](ExecutableEntryPointOp op) { return op.getName(); }));
+
+ iree_DyLibExecutableDef_start_as_root(builder);
+ iree_DyLibExecutableDef_entry_points_add(builder, entryPointsRef);
+ iree_DyLibExecutableDef_library_embedded_add(builder, libraryEmbeddedRef);
+ iree_DyLibExecutableDef_debug_database_filename_add(
+ builder, debugDatabaseFilenameRef);
+ iree_DyLibExecutableDef_debug_database_embedded_add(builder,
+ debugDatabaseRef);
+ iree_DyLibExecutableDef_end_as_root(builder);
// Add the binary data to the target executable.
executableBuilder.create<IREE::HAL::ExecutableBinaryOp>(
targetOp.getLoc(),
static_cast<uint32_t>(IREE::HAL::ExecutableFormat::DyLib),
- std::move(bytes));
-
+ builder.getBufferAttr(executableBuilder.getContext()));
return success();
}
};
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LLVMAOTTargetLinker.h b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LLVMAOTTargetLinker.h
deleted file mode 100644
index d6d5220..0000000
--- a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LLVMAOTTargetLinker.h
+++ /dev/null
@@ -1,41 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-
-#ifndef IREE_COMPILER_DIALECT_HAL_TARGET_LLVM_AOT_LLVMAOTTARGETLINKER_H_
-#define IREE_COMPILER_DIALECT_HAL_TARGET_LLVM_AOT_LLVMAOTTARGETLINKER_H_
-
-#include <string>
-
-#include "iree/base/status.h"
-#include "iree/compiler/Dialect/HAL/Target/LLVM/LLVMTargetOptions.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace IREE {
-namespace HAL {
-
-// Calls linker tool to link objData and returns shared library blob.
-iree::StatusOr<std::string> linkLLVMAOTObjects(
- const std::string& linkerToolPath, const std::string& objData);
-// Use lld::elf::link for linking objData and returns shared library blob.
-iree::StatusOr<std::string> linkLLVMAOTObjectsWithLLDElf(
- const std::string& objData);
-
-} // namespace HAL
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_DIALECT_HAL_TARGET_LLVM_AOT_LLVMAOTTARGETLINKER_H_
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LinkerTool.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LinkerTool.cpp
new file mode 100644
index 0000000..16a8526
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LinkerTool.cpp
@@ -0,0 +1,126 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Dialect/HAL/Target/LLVM/AOT/LinkerTool.h"
+
+#define DEBUG_TYPE "llvmaot-linker"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace HAL {
+
+// static
+Artifact Artifact::createTemporary(StringRef prefix, StringRef suffix) {
+ llvm::SmallString<32> filePath;
+ if (std::error_code error =
+ llvm::sys::fs::createTemporaryFile(prefix, suffix, filePath)) {
+ llvm::errs() << "failed to generate temporary file: " << error.message();
+ return {};
+ }
+ std::error_code error;
+ auto file = std::make_unique<llvm::ToolOutputFile>(filePath, error,
+ llvm::sys::fs::OF_None);
+ if (error) {
+ llvm::errs() << "failed to open temporary file '" << filePath
+ << "': " << error.message();
+ return {};
+ }
+ return {filePath.str().str(), std::move(file)};
+}
+
+// static
+Artifact Artifact::createVariant(StringRef basePath, StringRef suffix) {
+ SmallString<32> filePath(basePath);
+ llvm::sys::path::replace_extension(filePath, suffix);
+ std::error_code error;
+ auto file = std::make_unique<llvm::ToolOutputFile>(filePath, error,
+ llvm::sys::fs::OF_Append);
+ if (error) {
+ llvm::errs() << "failed to open temporary file '" << filePath
+ << "': " << error.message();
+ return {};
+ }
+ return {filePath.str().str(), std::move(file)};
+}
+
+Optional<std::vector<int8_t>> Artifact::read() const {
+ auto fileData = llvm::MemoryBuffer::getFile(path);
+ if (!fileData) {
+ llvm::errs() << "failed to load library output file '" << path << "'";
+ return llvm::None;
+ }
+ auto sourceBuffer = fileData.get()->getBuffer();
+ std::vector<int8_t> resultBuffer(sourceBuffer.size());
+ std::memcpy(resultBuffer.data(), sourceBuffer.data(), sourceBuffer.size());
+ return resultBuffer;
+}
+
+bool Artifact::readInto(raw_ostream &targetStream) const {
+ // NOTE: we could make this much more efficient if we read in the file a
+ // chunk at a time and piped it along to targetStream. I couldn't find
+ // anything in LLVM that did this, for some crazy reason, but since we are
+ // dealing with binaries that can be 10+MB here it'd be nice if we could avoid
+ // reading them all into memory.
+ auto fileData = llvm::MemoryBuffer::getFile(path);
+ if (!fileData) {
+ llvm::errs() << "failed to load library output file '" << path << "'";
+ return false;
+ }
+ auto sourceBuffer = fileData.get()->getBuffer();
+ targetStream.write(sourceBuffer.data(), sourceBuffer.size());
+ return true;
+}
+
+void Artifact::close() { outputFile->os().close(); }
+
+void Artifacts::keepAllFiles() {
+  if (libraryFile.outputFile) libraryFile.outputFile->keep();
+  if (debugFile.outputFile) debugFile.outputFile->keep();
+  for (auto &file : otherFiles) {
+    if (file.outputFile) file.outputFile->keep();  // skip failed artifacts
+  }
+}
+
+std::string LinkerTool::getToolPath() const {
+ char *linkerPath = std::getenv("IREE_LLVMAOT_LINKER_PATH");
+ if (linkerPath) {
+ return std::string(linkerPath);
+ } else {
+ return "";
+ }
+}
+
+LogicalResult LinkerTool::runLinkCommand(const std::string &commandLine) {
+ LLVM_DEBUG(llvm::dbgs() << "Running linker command:\n" << commandLine);
+#if defined(_MSC_VER)
+ // It's easy to run afoul of quoting rules on Windows (such as when using
+ // spaces in the linker environment variable). See:
+ // https://stackoverflow.com/a/9965141
+ auto quotedCommandLine = "\"" + commandLine + "\"";
+ int exitCode = system(quotedCommandLine.c_str());
+#else
+ int exitCode = system(commandLine.c_str());
+#endif // _MSC_VER
+ if (exitCode == 0) return success();
+ llvm::errs() << "Linking failed; command line returned exit code " << exitCode
+ << ":\n\n"
+ << commandLine << "\n\n";
+ return failure();
+}
+
+} // namespace HAL
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LinkerTool.h b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LinkerTool.h
new file mode 100644
index 0000000..8cf4484
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LinkerTool.h
@@ -0,0 +1,120 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+#ifndef IREE_COMPILER_DIALECT_HAL_TARGET_LLVM_AOT_LINKERTOOL_H_
+#define IREE_COMPILER_DIALECT_HAL_TARGET_LLVM_AOT_LINKERTOOL_H_
+
+#include <string>
+
+#include "iree/compiler/Dialect/HAL/Target/LLVM/LLVMTargetOptions.h"
+#include "llvm/ADT/Triple.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/ToolOutputFile.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace HAL {
+
+struct Artifact {
+ // Creates an output file path/container pair.
+ // By default the file will be deleted when the link completes; callers must
+ // use llvm::ToolOutputFile::keep() to prevent deletion upon success (or if
+ // leaving artifacts for debugging).
+ static Artifact createTemporary(StringRef prefix, StringRef suffix);
+
+ // Creates an output file derived from the given file's path with a new
+ // suffix.
+ static Artifact createVariant(StringRef basePath, StringRef suffix);
+
+ Artifact() = default;
+ Artifact(std::string path, std::unique_ptr<llvm::ToolOutputFile> outputFile)
+ : path(std::move(path)), outputFile(std::move(outputFile)) {}
+
+ std::string path;
+ std::unique_ptr<llvm::ToolOutputFile> outputFile;
+
+ // Reads the artifact file contents as bytes.
+ Optional<std::vector<int8_t>> read() const;
+
+ // Reads the artifact file and writes it into the given |stream|.
+ bool readInto(raw_ostream& targetStream) const;
+
+ // Closes the ostream of the file while preserving the temporary entry on
+ // disk. Use this if files need to be modified by external tools that may
+ // require exclusive access.
+ void close();
+};
+
+struct Artifacts {
+ // File containing the linked library (DLL, ELF, etc).
+ Artifact libraryFile;
+
+ // Optional file containing associated debug information (if stored
+ // separately, such as PDB files).
+ Artifact debugFile;
+
+ // Other files associated with linking.
+ SmallVector<Artifact, 4> otherFiles;
+
+ // Keeps all of the artifacts around after linking completes. Useful for
+ // debugging.
+ void keepAllFiles();
+};
+
+// Base type for linker tools that can turn object files into shared objects.
+class LinkerTool {
+ public:
+ // Gets an instance of a linker tool for the given target options. This may
+ // be a completely different toolchain than that of the host.
+ static std::unique_ptr<LinkerTool> getForTarget(
+ llvm::Triple& targetTriple, LLVMTargetOptions& targetOptions);
+
+ explicit LinkerTool(llvm::Triple targetTriple,
+ LLVMTargetOptions targetOptions)
+ : targetTriple(std::move(targetTriple)),
+ targetOptions(std::move(targetOptions)) {}
+
+ virtual ~LinkerTool() = default;
+
+ // Returns the path to the linker tool binary.
+ virtual std::string getToolPath() const;
+
+ // Configures a module prior to compilation with any additional
+ // functions/exports it may need, such as shared object initializer functions.
+ virtual LogicalResult configureModule(
+ llvm::Module* llvmModule, ArrayRef<StringRef> entryPointNames) = 0;
+
+ // Links the given object files into a dynamically loadable library.
+ // The resulting library (and other associated artifacts) will be returned on
+ // success.
+ virtual Optional<Artifacts> linkDynamicLibrary(
+ StringRef libraryName, ArrayRef<Artifact> objectFiles) = 0;
+
+ protected:
+ // Runs the given command line on the shell, logging failures.
+ LogicalResult runLinkCommand(const std::string& commandLine);
+
+ llvm::Triple targetTriple;
+ LLVMTargetOptions targetOptions;
+};
+
+} // namespace HAL
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_DIALECT_HAL_TARGET_LLVM_AOT_LINKERTOOL_H_
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/BUILD b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/BUILD
index 6dadc26..bc8a6e0 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/BUILD
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/BUILD
@@ -19,11 +19,16 @@
)
cc_library(
- name = "LLVMAOTTargetLinker_internal",
- srcs = ["LLVMAOTTargetLinker.cpp"],
+ name = "LinkerTools_internal",
+ srcs = [
+ "LinkerTools.cpp",
+ "UnixLinkerTool.cpp",
+ "WindowsLinkerTool.cpp",
+ ],
deps = [
- "//iree/base:status",
- "//iree/compiler/Dialect/HAL/Target/LLVM/AOT:LLVMAOTTargetLinker_hdrs",
+ "//iree/compiler/Dialect/HAL/Target/LLVM/AOT:LinkerTool_hdrs",
+ "@llvm-project//llvm:Core",
"@llvm-project//llvm:Support",
+ "@llvm-project//mlir:Support",
],
)
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/CMakeLists.txt
index 01cb459..aafc4a5 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/CMakeLists.txt
@@ -16,12 +16,15 @@
iree_cc_library(
NAME
- LLVMAOTTargetLinker_internal
+ LinkerTools_internal
SRCS
- "LLVMAOTTargetLinker.cpp"
+ "LinkerTools.cpp"
+ "UnixLinkerTool.cpp"
+ "WindowsLinkerTool.cpp"
DEPS
+ LLVMCore
LLVMSupport
- iree::base::status
- iree::compiler::Dialect::HAL::Target::LLVM::AOT::LLVMAOTTargetLinker_hdrs
+ MLIRSupport
+ iree::compiler::Dialect::HAL::Target::LLVM::AOT::LinkerTool_hdrs
PUBLIC
)
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/LLVMAOTTargetLinker.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/LLVMAOTTargetLinker.cpp
deleted file mode 100644
index 8acd2d6..0000000
--- a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/LLVMAOTTargetLinker.cpp
+++ /dev/null
@@ -1,79 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Dialect/HAL/Target/LLVM/AOT/LLVMAOTTargetLinker.h"
-
-#include "iree/base/status.h"
-#include "llvm/Support/ToolOutputFile.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace IREE {
-namespace HAL {
-
-iree::StatusOr<std::string> linkLLVMAOTObjects(
- const std::string& linkerToolPath, const std::string& objData) {
- llvm::SmallString<32> objFilePath, dylibFilePath;
- if (std::error_code error = llvm::sys::fs::createTemporaryFile(
- "llvmaot_dylibs", "objfile", objFilePath)) {
- return iree::InternalErrorBuilder(IREE_LOC)
- << "Failed to generate temporary file for objfile : '"
- << error.message() << "'";
- }
- if (std::error_code error = llvm::sys::fs::createTemporaryFile(
- "llvmaot_dylibs", "dylibfile", dylibFilePath)) {
- return iree::InternalErrorBuilder(IREE_LOC)
- << "Failed to generate temporary file for dylib : '"
- << error.message() << "'";
- }
- std::error_code error;
- auto outputFile = std::make_unique<llvm::ToolOutputFile>(
- objFilePath, error, llvm::sys::fs::F_None);
- if (error) {
- return iree::InternalErrorBuilder(IREE_LOC)
- << "Failed to open temporary objfile '" << objFilePath.c_str()
- << "' for dylib : '" << error.message() << "'";
- }
-
- outputFile->os() << objData;
- outputFile->os().flush();
-
- auto linkingCmd =
- (linkerToolPath + " -shared " + objFilePath + " -o " + dylibFilePath)
- .str();
- int systemRet = system(linkingCmd.c_str());
- if (systemRet != 0) {
- return iree::InternalErrorBuilder(IREE_LOC)
- << linkingCmd << " failed with exit code " << systemRet;
- }
-
- auto dylibData = llvm::MemoryBuffer::getFile(dylibFilePath);
- if (!dylibData) {
- return iree::InternalErrorBuilder(IREE_LOC)
- << "Failed to read temporary dylib file '" << dylibFilePath.c_str()
- << "'";
- }
- return dylibData.get()->getBuffer().str();
-}
-
-iree::StatusOr<std::string> linkLLVMAOTObjectsWithLLDElf(
- const std::string& objData) {
- return iree::UnimplementedErrorBuilder(IREE_LOC)
- << "linkLLVMAOTObjectsWithLLD not implemented yet!";
-}
-
-} // namespace HAL
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/LinkerTools.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/LinkerTools.cpp
new file mode 100644
index 0000000..6fc4b93
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/LinkerTools.cpp
@@ -0,0 +1,43 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Dialect/HAL/Target/LLVM/AOT/LinkerTool.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace HAL {
+
+// TODO(benvanik): add other platforms:
+// createMacLinkerTool using ld64.lld
+// createWasmLinkerTool wasm-ld
+
+std::unique_ptr<LinkerTool> createUnixLinkerTool(
+ llvm::Triple &targetTriple, LLVMTargetOptions &targetOptions);
+std::unique_ptr<LinkerTool> createWindowsLinkerTool(
+ llvm::Triple &targetTriple, LLVMTargetOptions &targetOptions);
+
+// static
+std::unique_ptr<LinkerTool> LinkerTool::getForTarget(
+ llvm::Triple &targetTriple, LLVMTargetOptions &targetOptions) {
+ const bool isWindowsTarget =
+ targetTriple.isOSWindows() || targetTriple.isWindowsMSVCEnvironment();
+ return isWindowsTarget ? createWindowsLinkerTool(targetTriple, targetOptions)
+ : createUnixLinkerTool(targetTriple, targetOptions);
+}
+
+} // namespace HAL
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/UnixLinkerTool.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/UnixLinkerTool.cpp
new file mode 100644
index 0000000..5e14e8a
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/UnixLinkerTool.cpp
@@ -0,0 +1,97 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Dialect/HAL/Target/LLVM/AOT/LinkerTool.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/Support/FormatVariadic.h"
+
+#define DEBUG_TYPE "llvmaot-linker"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace HAL {
+
+// Unix linker (ld-like); for ELF files.
+class UnixLinkerTool : public LinkerTool {
+ public:
+ using LinkerTool::LinkerTool;
+
+ std::string getToolPath() const override {
+ auto toolPath = LinkerTool::getToolPath();
+ return toolPath.empty() ? "ld.lld" : toolPath;
+ }
+
+ LogicalResult configureModule(llvm::Module *llvmModule,
+ ArrayRef<StringRef> entryPointNames) override {
+ // Enable frame pointers to ensure that stack unwinding works, e.g. in
+ // Tracy. In principle this could also be achieved by enabling unwind
+ // tables, but we tried that and that didn't work in Tracy (which uses
+ // libbacktrace), while enabling frame pointers worked.
+ // https://github.com/google/iree/issues/3957
+ for (auto &func : *llvmModule) {
+ auto attrs = func.getAttributes();
+ attrs = attrs.addAttribute(llvmModule->getContext(),
+ llvm::AttributeList::FunctionIndex,
+ "frame-pointer", "all");
+ func.setAttributes(attrs);
+ }
+ return success();
+ }
+
+ Optional<Artifacts> linkDynamicLibrary(
+ StringRef libraryName, ArrayRef<Artifact> objectFiles) override {
+ Artifacts artifacts;
+
+ // Create the shared object name; if we only have a single input object we
+ // can just reuse that.
+ if (objectFiles.size() == 1) {
+ artifacts.libraryFile =
+ Artifact::createVariant(objectFiles.front().path, "so");
+ } else {
+ artifacts.libraryFile = Artifact::createTemporary(libraryName, "so");
+ }
+ artifacts.libraryFile.close();
+
+ SmallVector<std::string, 8> flags = {
+ getToolPath(),
+ "-shared",
+ "-o \"" + artifacts.libraryFile.path + "\"",
+ };
+
+ // TODO(ataei): add flags based on targetTriple.isAndroid(), like
+ // -static-libstdc++ (if this is needed, which it shouldn't be).
+
+ // Link all input objects, quoted since paths may contain spaces. We are not
+ // linking whole-archive so that unused codegen outputs can be dropped.
+ for (auto &objectFile : objectFiles) {
+ flags.push_back("\"" + objectFile.path + "\"");
+ }
+
+ auto commandLine = llvm::join(flags, " ");
+ if (failed(runLinkCommand(commandLine))) return llvm::None;
+ return artifacts;
+ }
+};
+
+std::unique_ptr<LinkerTool> createUnixLinkerTool(
+ llvm::Triple &targetTriple, LLVMTargetOptions &targetOptions) {
+ return std::make_unique<UnixLinkerTool>(targetTriple, targetOptions);
+}
+
+} // namespace HAL
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/WindowsLinkerTool.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/WindowsLinkerTool.cpp
new file mode 100644
index 0000000..40632e3
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/WindowsLinkerTool.cpp
@@ -0,0 +1,290 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Dialect/HAL/Target/LLVM/AOT/LinkerTool.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/Support/FormatVariadic.h"
+
+#define DEBUG_TYPE "llvmaot-linker"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace HAL {
+
+// Windows linker (MSVC link.exe-like); for DLL files.
+class WindowsLinkerTool : public LinkerTool {
+ public:
+ using LinkerTool::LinkerTool;
+
+ std::string getToolPath() const override {
+ auto toolPath = LinkerTool::getToolPath();
+ return toolPath.empty() ? "lld-link" : toolPath;
+ }
+
+ LogicalResult configureModule(llvm::Module *llvmModule,
+ ArrayRef<StringRef> entryPointNames) override {
+ auto &ctx = llvmModule->getContext();
+
+ // Create a _DllMainCRTStartup replacement that does not initialize the CRT.
+ // This is required to prevent a bunch of CRT junk (locale, errno, TLS, etc)
+ // from getting emitted in such a way that it cannot be stripped by LTCG.
+ // Since we don't emit code using the CRT (beyond memset/memcpy) this is
+ // fine and can reduce binary sizes by 50-100KB.
+ //
+ // More info:
+ // https://docs.microsoft.com/en-us/cpp/build/run-time-library-behavior?view=vs-2019
+ {
+ auto dwordType = llvm::IntegerType::get(ctx, 32);
+ auto ptrType = llvm::PointerType::getUnqual(dwordType);
+ auto entry = cast<llvm::Function>(
+ llvmModule
+ ->getOrInsertFunction("iree_dll_main", dwordType, ptrType,
+ dwordType, ptrType)
+ .getCallee());
+ entry->setCallingConv(llvm::CallingConv::X86_StdCall);
+ entry->setDLLStorageClass(
+ llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
+ entry->setLinkage(llvm::GlobalValue::LinkageTypes::ExternalLinkage);
+ auto *block = llvm::BasicBlock::Create(ctx, "entry", entry);
+ llvm::IRBuilder<> builder(block);
+ auto one = llvm::ConstantInt::get(dwordType, 1, false);
+ builder.CreateRet(one);
+ }
+
+ // For now we ensure that our entry points are exported (via linker
+ // directives embedded in the object file) and in a compatible calling
+ // convention.
+ // TODO(benvanik): switch to executable libraries w/ internal functions.
+ for (auto entryPointName : entryPointNames) {
+ auto *entryPointFn = llvmModule->getFunction(entryPointName);
+ entryPointFn->setCallingConv(llvm::CallingConv::X86_StdCall);
+ entryPointFn->setDLLStorageClass(
+ llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
+ entryPointFn->setLinkage(
+ llvm::GlobalValue::LinkageTypes::ExternalLinkage);
+ entryPointFn->setVisibility(
+ llvm::GlobalValue::VisibilityTypes::DefaultVisibility);
+ entryPointFn->addFnAttr(llvm::Attribute::UWTable);
+ }
+
+ return success();
+ }
+
+ Optional<Artifacts> linkDynamicLibrary(
+ StringRef libraryName, ArrayRef<Artifact> objectFiles) override {
+ Artifacts artifacts;
+
+ // Create the shared object name; if we only have a single input object we
+ // can just reuse that.
+ if (objectFiles.size() == 1) {
+ artifacts.libraryFile =
+ Artifact::createVariant(objectFiles.front().path, "dll");
+ } else {
+ artifacts.libraryFile = Artifact::createTemporary(libraryName, "dll");
+ }
+
+ // link.exe doesn't like the files being opened. We don't use them as
+ // streams so close them all now before running the linker.
+ artifacts.libraryFile.close();
+
+ // We need a full path for the PDB and I hate strings in LLVM grumble.
+ SmallString<32> pdbPath(artifacts.libraryFile.path);
+ llvm::sys::path::replace_extension(pdbPath, "pdb");
+
+ SmallVector<std::string, 8> flags = {
+ getToolPath(),
+
+ // Useful when debugging linking/loading issues:
+ // "/verbose",
+
+ // https://docs.microsoft.com/en-us/cpp/build/reference/dll-build-a-dll?view=vs-2019
+ // Builds a DLL and exports functions with the dllexport storage class.
+ "/dll",
+
+ // Forces a fixed timestamp to ensure files are reproducible across
+ // builds. Undocumented but accepted by both link and lld-link.
+ // https://blog.conan.io/2019/09/02/Deterministic-builds-with-C-C++.html
+ "/Brepro",
+
+ // https://docs.microsoft.com/en-us/cpp/build/reference/nodefaultlib-ignore-libraries?view=vs-2019
+ // Ignore any libraries that are specified by the platform as we
+ // directly provide the ones we want.
+ "/nodefaultlib",
+
+ // https://docs.microsoft.com/en-us/cpp/build/reference/incremental-link-incrementally?view=vs-2019
+ // Disable incremental linking as we are only ever linking in one-shot
+ // mode to temp files. This avoids additional file padding and ordering
+ // restrictions that enable incremental linking. Our other options will
+ // prevent incremental linking in most cases, but it doesn't hurt to be
+ // explicit.
+ "/incremental:no",
+
+ // https://docs.microsoft.com/en-us/cpp/build/reference/guard-enable-guard-checks?view=vs-2019
+ // No control flow guard lookup (indirect branch verification).
+ "/guard:no",
+
+ // https://docs.microsoft.com/en-us/cpp/build/reference/safeseh-image-has-safe-exception-handlers?view=vs-2019
+ // We don't want exception unwind tables in our output.
+ "/safeseh:no",
+
+ // https://docs.microsoft.com/en-us/cpp/build/reference/entry-entry-point-symbol?view=vs-2019
+ // Use our entry point instead of the standard CRT one; ensures that we
+ // pull in no global state from the CRT.
+ "/entry:iree_dll_main",
+
+ // https://docs.microsoft.com/en-us/cpp/build/reference/debug-generate-debug-info?view=vs-2019
+ // Copies all PDB information into the final PDB so that we can use the
+ // same PDB across multiple machines.
+ "/debug:full",
+
+ // https://docs.microsoft.com/en-us/cpp/build/reference/pdb-use-program-database
+ // Generates the PDB file containing the debug information.
+ ("/pdb:\"" + pdbPath + "\"").str(),
+
+ // https://docs.microsoft.com/en-us/cpp/build/reference/pdbaltpath-use-alternate-pdb-path?view=vs-2019
+ // Forces the PDB we generate to be referenced in the DLL as just a
+ // relative path to the DLL itself. This allows us to move the PDBs
+ // along with the build DLLs across machines.
+ "/pdbaltpath:%_PDB%",
+
+ // https://docs.microsoft.com/en-us/cpp/build/reference/out-output-file-name?view=vs-2019
+ // Target for linker output (quoted; may contain spaces). The base name of
+ // this path will be used for additional output files (like the map and pdb).
+ "/out:\"" + artifacts.libraryFile.path + "\"",
+ };
+
+ if (targetOptions.optLevel.getSpeedupLevel() >= 2 ||
+ targetOptions.optLevel.getSizeLevel() >= 2) {
+ // https://docs.microsoft.com/en-us/cpp/build/reference/opt-optimizations?view=vs-2019
+ // Enable all the fancy optimizations.
+ flags.push_back("/opt:ref,icf,lbr");
+ }
+
+ // SDK and MSVC paths.
+ // These rely on the environment variables provided by the
+ // vcvarsall or VsDevCmd ("Developer Command Prompt") scripts. They can also
+ // be specified manually.
+ //
+ // We could also check to see if vswhere is installed and query that in the
+ // event of missing environment variables; that would eliminate the need for
+ // specifying things from for example IDEs that may not bring in the vcvars.
+ //
+ /* Example values:
+ UCRTVersion=10.0.18362.0
+ UniversalCRTSdkDir=C:\Program Files (x86)\Windows Kits\10\
+ VCToolsInstallDir=C:\Program Files (x86)\Microsoft Visual
+ Studio\2019\Preview\VC\Tools\MSVC\14.28.29304\
+ */
+ if (!getenv("VCToolsInstallDir") || !getenv("UniversalCRTSdkDir")) {
+ llvm::errs() << "required environment for lld-link/link not specified; "
+ "ensure you are building from a shell where "
+ "vcvarsall/VsDevCmd.bat/etc has been used\n";
+ return llvm::None;
+ }
+ const char *arch;
+ if (targetTriple.isARM() && targetTriple.isArch32Bit()) {
+ arch = "arm";
+ } else if (targetTriple.isAArch64()) {  // isARM() only matches 32-bit ARM.
+ arch = "arm64";
+ } else if (targetTriple.isX86() && targetTriple.isArch32Bit()) {
+ arch = "x86";
+ } else if (targetTriple.isX86()) {
+ arch = "x64";
+ } else {
+ llvm::errs() << "unsupported Windows target triple (no arch libs): "
+ << targetTriple.str() << "\n";
+ return llvm::None;
+ }
+ flags.push_back(
+ llvm::formatv("/libpath:\"{0}\\lib\\{1}\"", "%VCToolsInstallDir%", arch)
+ .str());
+ flags.push_back(llvm::formatv("/libpath:\"{0}\\Lib\\{1}\\ucrt\\{2}\"",
+ "%UniversalCRTSdkDir%", "%UCRTVersion%", arch)
+ .str());
+ flags.push_back(llvm::formatv("/libpath:\"{0}\\Lib\\{1}\\um\\{2}\"",
+ "%UniversalCRTSdkDir%", "%UCRTVersion%", arch)
+ .str());
+
+ // We need to link against different libraries based on our configuration
+ // matrix (dynamic/static and debug/release).
+ int libIndex = 0;
+ if (targetOptions.optLevel.getSpeedupLevel() == 0) {
+ libIndex += 0; // debug
+ } else {
+ libIndex += 2; // release
+ }
+ libIndex += targetOptions.linkStatic ? 1 : 0;
+
+ // The required libraries for linking DLLs:
+ // https://docs.microsoft.com/en-us/cpp/c-runtime-library/crt-library-features?view=msvc-160
+ //
+ // NOTE: there are only static versions of msvcrt as it's the startup code.
+ static const char *kMSVCRTLibs[4] = {
+ /* debug/dynamic */ "msvcrtd.lib",
+ /* debug/static */ "msvcrtd.lib",
+ /* release/dynamic */ "msvcrt.lib",
+ /* release/static */ "msvcrt.lib",
+ };
+ static const char *kVCRuntimeLibs[4] = {
+ /* debug/dynamic */ "vcruntimed.lib",
+ /* debug/static */ "libvcruntimed.lib",
+ /* release/dynamic */ "vcruntime.lib",
+ /* release/static */ "libvcruntime.lib",
+ };
+ static const char *kUCRTLibs[4] = {
+ /* debug/dynamic */ "ucrtd.lib",
+ /* debug/static */ "libucrtd.lib",
+ /* release/dynamic */ "ucrt.lib",
+ /* release/static */ "libucrt.lib",
+ };
+ flags.push_back(kMSVCRTLibs[libIndex]);
+ flags.push_back(kVCRuntimeLibs[libIndex]);
+ flags.push_back(kUCRTLibs[libIndex]);
+ flags.push_back("kernel32.lib");
+
+ // Link all input objects. Note that we are not linking whole-archive as we
+ // want to allow dropping of unused codegen outputs.
+ for (auto &objectFile : objectFiles) {
+ flags.push_back("\"" + objectFile.path + "\"");  // Paths may have spaces.
+ }
+
+ auto commandLine = llvm::join(flags, " ");
+ if (failed(runLinkCommand(commandLine))) return llvm::None;
+
+ // PDB file gets generated with the same path + .pdb.
+ artifacts.debugFile =
+ Artifact::createVariant(artifacts.libraryFile.path, "pdb");
+
+ // We currently discard some of the other file outputs (like the .exp
+ // listing the exported symbols) as we don't need them.
+ artifacts.otherFiles.push_back(
+ Artifact::createVariant(artifacts.libraryFile.path, "exp"));
+ artifacts.otherFiles.push_back(
+ Artifact::createVariant(artifacts.libraryFile.path, "lib"));
+
+ return artifacts;
+ }
+};
+
+std::unique_ptr<LinkerTool> createWindowsLinkerTool(
+ llvm::Triple &targetTriple, LLVMTargetOptions &targetOptions) {
+ return std::make_unique<WindowsLinkerTool>(targetTriple, targetOptions);
+}
+
+} // namespace HAL
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/BUILD b/iree/compiler/Dialect/HAL/Target/LLVM/BUILD
index 9c608c2..9040b96 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/BUILD
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/BUILD
@@ -80,6 +80,7 @@
"LLVMTargetOptions.h",
],
deps = [
+ "@llvm-project//llvm:MC",
"@llvm-project//llvm:Passes",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:Target",
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt
index 1f0f77f..416f32e 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt
@@ -66,6 +66,7 @@
SRCS
"LLVMTargetOptions.cpp"
DEPS
+ LLVMMC
LLVMPasses
LLVMSupport
LLVMTarget
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/IR/BUILD b/iree/compiler/Dialect/HAL/Target/LLVM/IR/BUILD
index eed06b7..10b20b5 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/IR/BUILD
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/IR/BUILD
@@ -37,12 +37,14 @@
"LLVMIRTarget.h",
],
deps = [
+ "//iree/base:flatcc",
"//iree/compiler/Dialect/HAL/Target",
"//iree/compiler/Dialect/HAL/Target/LLVM:LLVMBaseTarget",
"//iree/compiler/Dialect/HAL/Target/LLVM:LLVMIRPasses",
"//iree/compiler/Dialect/HAL/Target/LLVM:LLVMTargetOptions",
"//iree/compiler/Dialect/Shape/IR",
- "//iree/schemas:llvmir_executable_def_cc_fbs",
+ "//iree/compiler/Utils",
+ "//iree/schemas:llvmir_executable_def_c_fbs",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:TargetLLVMIR",
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/IR/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/LLVM/IR/CMakeLists.txt
index c490eb0..650ce95 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/IR/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/IR/CMakeLists.txt
@@ -29,11 +29,13 @@
LLVMCore
LLVMSupport
MLIRTargetLLVMIR
+ iree::base::flatcc
iree::compiler::Dialect::HAL::Target
iree::compiler::Dialect::HAL::Target::LLVM::LLVMBaseTarget
iree::compiler::Dialect::HAL::Target::LLVM::LLVMIRPasses
iree::compiler::Dialect::HAL::Target::LLVM::LLVMTargetOptions
iree::compiler::Dialect::Shape::IR
- iree::schemas::llvmir_executable_def_cc_fbs
+ iree::compiler::Utils
+ iree::schemas::llvmir_executable_def_c_fbs
PUBLIC
)
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/IR/LLVMIRTarget.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/IR/LLVMIRTarget.cpp
index b18ecbf..9567eae 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/IR/LLVMIRTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/IR/LLVMIRTarget.cpp
@@ -17,7 +17,8 @@
#include "iree/compiler/Dialect/HAL/Target/LLVM/LLVMBaseTarget.h"
#include "iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRPasses.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
-#include "iree/schemas/llvmir_executable_def_generated.h"
+#include "iree/compiler/Utils/FlatbufferUtils.h"
+#include "iree/schemas/llvmir_executable_def_builder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/TargetSelect.h"
@@ -63,15 +64,6 @@
return targetOp.emitError("Failed to translate executable to LLVM IR");
}
- // Create invocation function an populate entry_points.
- iree::LLVMIRExecutableDefT llvmIrExecutableDef;
- auto entryPointOps =
- targetOp.getBlock().getOps<IREE::HAL::ExecutableEntryPointOp>();
- for (auto entryPointOp : entryPointOps) {
- llvmIrExecutableDef.entry_points.push_back(
- std::string(entryPointOp.sym_name()));
- }
-
// LLVMIR opt passes.
auto targetMachine = createTargetMachine(options_);
if (!targetMachine) {
@@ -86,30 +78,27 @@
"Can't build LLVMIR opt passes for ExecutableOp module");
}
- // Serialize LLVM module.
- std::string bufferString;
- llvm::raw_string_ostream ostream(bufferString);
- llvmModule->print(ostream, nullptr);
- ostream.flush();
+ // Serialize LLVM module directly into flatbuffer.
+ FlatbufferBuilder builder;
+ auto bitcodeModuleRef = builder.streamUint8Vec([&](raw_ostream &stream) {
+ llvmModule->print(stream, nullptr);
+ return true;
+ });
- // Creates executable bytes.
- llvmIrExecutableDef.llvmir_module = {bufferString.begin(),
- bufferString.end()};
+ auto entryPointsRef = builder.createStringVec(llvm::map_range(
+ targetOp.getBlock().getOps<ExecutableEntryPointOp>(),
+ [&](ExecutableEntryPointOp op) { return op.sym_name(); }));
- ::flatbuffers::FlatBufferBuilder fbb;
- auto executableOffset =
- iree::LLVMIRExecutableDef::Pack(fbb, &llvmIrExecutableDef);
- iree::FinishLLVMIRExecutableDefBuffer(fbb, executableOffset);
- std::vector<uint8_t> bytes;
- bytes.resize(fbb.GetSize());
- std::memcpy(bytes.data(), fbb.GetBufferPointer(), bytes.size());
+ iree_LLVMIRExecutableDef_start_as_root(builder);
+ iree_LLVMIRExecutableDef_entry_points_add(builder, entryPointsRef);
+ iree_LLVMIRExecutableDef_bitcode_module_add(builder, bitcodeModuleRef);
+ iree_LLVMIRExecutableDef_end_as_root(builder);
// Add the binary data to the target executable.
executableBuilder.create<IREE::HAL::ExecutableBinaryOp>(
targetOp.getLoc(),
static_cast<uint32_t>(IREE::HAL::ExecutableFormat::LLVM),
- std::move(bytes));
-
+ builder.getBufferAttr(executableBuilder.getContext()));
return success();
}
};
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMBaseTarget.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMBaseTarget.cpp
index 23370b5..4861dd3 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMBaseTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMBaseTarget.cpp
@@ -80,17 +80,40 @@
buildLLVMTransformPassPipeline(passManager);
}
+static FileLineColLoc findFirstFileLoc(Location baseLoc) {
+ // A direct file location is the answer itself.
+ if (auto fileLoc = baseLoc.dyn_cast<FileLineColLoc>()) return fileLoc;
+ // Fused locations are searched depth-first, in order, for a file location.
+ if (auto fusedLoc = baseLoc.dyn_cast<FusedLoc>()) {
+ for (auto &nestedLoc : fusedLoc.getLocations()) {
+ if (auto found = findFirstFileLoc(nestedLoc)) return found;
+ }
+ }
+ return FileLineColLoc{};  // No file location anywhere in |baseLoc|.
+}
+
+static std::string guessModuleName(mlir::ModuleOp moduleOp) {
+ // Module name, else first file location's stem, else "module" (no file loc).
+ auto name = moduleOp.getName();
+ if (name && !name->empty()) return name->str();
+ FileLineColLoc loc = findFirstFileLoc(moduleOp.getLoc());
+ return loc ? llvm::sys::path::stem(loc.getFilename()).str() : "module";
+}
+
LogicalResult LLVMBaseTargetBackend::linkExecutables(mlir::ModuleOp moduleOp) {
OpBuilder builder = OpBuilder::atBlockBegin(moduleOp.getBody());
auto executableOps =
llvm::to_vector<8>(moduleOp.getOps<IREE::HAL::ExecutableOp>());
+ // Guess a module name, if needed, to make the output files readable.
+ auto moduleName = guessModuleName(moduleOp);
+
// Create our new "linked" hal.executable.
- std::string linkedExecutableName = llvm::formatv("linked_{0}", name());
+ std::string linkedExecutableName =
+ llvm::formatv("{0}_linked_{1}", moduleName, name());
auto linkedExecutableOp = builder.create<IREE::HAL::ExecutableOp>(
moduleOp.getLoc(), linkedExecutableName);
- SymbolTable::setSymbolVisibility(linkedExecutableOp,
- SymbolTable::Visibility::Private);
+ linkedExecutableOp.setPrivate();
// Add our hal.executable.target with an empty module.
builder.setInsertionPointToStart(linkedExecutableOp.getBody());
auto linkedTargetOp = builder.create<IREE::HAL::ExecutableTargetOp>(
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRPasses.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRPasses.cpp
index de66446..65cd442 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRPasses.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRPasses.cpp
@@ -31,16 +31,33 @@
namespace IREE {
namespace HAL {
+static llvm::CodeGenOpt::Level passBuilderOptLevelToCodeGenOptLevel(
+ const llvm::PassBuilder::OptimizationLevel &level) {
+ switch (level.getSpeedupLevel()) {
+ case 0:
+ return llvm::CodeGenOpt::None;
+ case 1:
+ return llvm::CodeGenOpt::Less;
+ case 2:
+ default:
+ return llvm::CodeGenOpt::Default;
+ case 3:
+ return llvm::CodeGenOpt::Aggressive;
+ }
+}
+
std::unique_ptr<llvm::TargetMachine> createTargetMachine(
const LLVMTargetOptions &targetOptions) {
std::string errorMessage;
auto target = llvm::TargetRegistry::lookupTarget(targetOptions.targetTriple,
errorMessage);
if (!target) return nullptr;
- // TODO(ataei): Once we have an AOT backend pass cpu and cpu-features
std::unique_ptr<llvm::TargetMachine> machine(target->createTargetMachine(
- targetOptions.targetTriple, "generic" /* cpu e.g k8*/,
- "" /* cpu features e.g avx512fma*/, targetOptions.options, {}));
+ targetOptions.targetTriple, targetOptions.targetCPU /* cpu e.g k8*/,
+ targetOptions.targetCPUFeatures /* cpu features e.g avx512fma*/,
+ targetOptions.options, {}, {},
+ passBuilderOptLevelToCodeGenOptLevel(targetOptions.optLevel),
+ /*JIT=*/false));
return machine;
}
@@ -68,10 +85,12 @@
passBuilder.registerLoopAnalyses(loopAnalysisManager);
passBuilder.crossRegisterProxies(loopAnalysisManager, functionAnalysisManager,
cGSCCAnalysisManager, moduleAnalysisManager);
- llvm::ModulePassManager modulePassManager;
- modulePassManager =
- passBuilder.buildPerModuleDefaultPipeline(options.optLevel);
- modulePassManager.run(*module, moduleAnalysisManager);
+ if (options.optLevel != llvm::PassBuilder::OptimizationLevel::O0) {
+ llvm::ModulePassManager modulePassManager;
+ modulePassManager =
+ passBuilder.buildPerModuleDefaultPipeline(options.optLevel);
+ modulePassManager.run(*module, moduleAnalysisManager);
+ }
if (llvm::verifyModule(*module)) return failure();
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMTargetOptions.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMTargetOptions.cpp
index e71beac..649dbf0 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMTargetOptions.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMTargetOptions.cpp
@@ -15,6 +15,7 @@
#include "iree/compiler/Dialect/HAL/Target/LLVM/LLVMTargetOptions.h"
#include "llvm/ADT/APFloat.h"
+#include "llvm/MC/SubtargetFeature.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Host.h"
#include "llvm/Target/TargetOptions.h"
@@ -26,17 +27,34 @@
LLVMTargetOptions getDefaultLLVMTargetOptions() {
LLVMTargetOptions targetOptions;
+
// Host target triple.
targetOptions.targetTriple = llvm::sys::getDefaultTargetTriple();
+ targetOptions.targetCPU = llvm::sys::getHostCPUName().str();
+ {
+ llvm::SubtargetFeatures features;
+ llvm::StringMap<bool> hostFeatures;
+ if (llvm::sys::getHostCPUFeatures(hostFeatures)) {
+ for (auto &feature : hostFeatures) {
+ features.AddFeature(feature.first(), feature.second);
+ }
+ }
+ targetOptions.targetCPUFeatures = features.getString();
+ }
+
// LLVM loop optimization options.
targetOptions.pipelineTuningOptions.LoopInterleaving = true;
targetOptions.pipelineTuningOptions.LoopVectorization = true;
targetOptions.pipelineTuningOptions.LoopUnrolling = true;
+
// LLVM SLP Auto vectorizer.
targetOptions.pipelineTuningOptions.SLPVectorization = true;
+
// LLVM -O3.
+ // TODO(benvanik): add an option for this.
targetOptions.optLevel = llvm::PassBuilder::OptimizationLevel::O3;
targetOptions.options.FloatABIType = llvm::FloatABI::Hard;
+
return targetOptions;
}
@@ -46,16 +64,57 @@
static llvm::cl::opt<std::string> clTargetTriple(
"iree-llvm-target-triple", llvm::cl::desc("LLVM target machine triple"),
llvm::cl::init(llvmTargetOptions.targetTriple));
- static llvm::cl::opt<bool> clSoftFloat(
- "iree-llvm-enable-msoft-float-abi",
+ static llvm::cl::opt<std::string> clTargetCPU(
+ "iree-llvm-target-cpu",
+ llvm::cl::desc(
+ "LLVM target machine CPU; use 'host' for your host native CPU"),
+ llvm::cl::init("generic"));
+ static llvm::cl::opt<std::string> clTargetCPUFeatures(
+ "iree-llvm-target-cpu-features",
+ llvm::cl::desc("LLVM target machine CPU features; use 'host' for your "
+ "host native CPU"),
+ llvm::cl::init(""));
+ llvmTargetOptions.targetTriple = clTargetTriple;
+ if (clTargetCPU != "host") {
+ llvmTargetOptions.targetCPU = clTargetCPU;
+ }
+ if (clTargetCPUFeatures != "host") {
+ llvmTargetOptions.targetCPUFeatures = clTargetCPUFeatures;
+ }
+
+ static llvm::cl::opt<llvm::FloatABI::ABIType> clTargetFloatABI(
+ "iree-llvm-target-float-abi",
llvm::cl::desc("LLVM target codegen enables soft float abi e.g "
"-mfloat-abi=softfp"),
- llvm::cl::init(false));
+ llvm::cl::init(llvmTargetOptions.options.FloatABIType),
+ llvm::cl::values(
+ clEnumValN(llvm::FloatABI::Default, "default", "Default (softfp)"),
+ clEnumValN(llvm::FloatABI::Soft, "soft",
+ "Software floating-point emulation"),
+ clEnumValN(llvm::FloatABI::Hard, "hard",
+ "Hardware floating-point instructions")));
+ llvmTargetOptions.options.FloatABIType = clTargetFloatABI;
- llvmTargetOptions.targetTriple = clTargetTriple;
- if (clSoftFloat) {
- llvmTargetOptions.options.FloatABIType = llvm::FloatABI::Soft;
- }
+ static llvm::cl::opt<bool> clDebugSymbols(
+ "iree-llvm-debug-symbols",
+ llvm::cl::desc("Generate and embed debug information (DWARF, PDB, etc)"),
+ llvm::cl::init(llvmTargetOptions.debugSymbols));
+ llvmTargetOptions.debugSymbols = clDebugSymbols;
+
+ static llvm::cl::opt<bool> clLinkStatic(
+ "iree-llvm-link-static",
+ llvm::cl::desc(
+ "Links system libraries into binaries statically to isolate them "
+ "from platform dependencies needed at runtime"),
+ llvm::cl::init(llvmTargetOptions.linkStatic));
+ llvmTargetOptions.linkStatic = clLinkStatic;
+
+ static llvm::cl::opt<bool> clKeepLinkerArtifacts(
+ "iree-llvm-keep-linker-artifacts",
+ llvm::cl::desc("Keep LLVM linker target artifacts (.so/.dll/etc)"),
+ llvm::cl::init(llvmTargetOptions.keepLinkerArtifacts));
+ llvmTargetOptions.keepLinkerArtifacts = clKeepLinkerArtifacts;
+
return llvmTargetOptions;
}
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMTargetOptions.h b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMTargetOptions.h
index 4893566..ba454b6 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMTargetOptions.h
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMTargetOptions.h
@@ -24,10 +24,28 @@
namespace HAL {
struct LLVMTargetOptions {
+ // Target machine configuration.
+ std::string targetTriple;
+ std::string targetCPU;
+ std::string targetCPUFeatures;
+
llvm::PipelineTuningOptions pipelineTuningOptions;
llvm::PassBuilder::OptimizationLevel optLevel;
llvm::TargetOptions options;
- std::string targetTriple;
+
+ // Include debug information in output files (PDB, DWARF, etc).
+ // Though this can be set independently from the optLevel (so -O3 with debug
+ // information is valid), it may significantly change the output program
+ // and benchmarking results.
+ bool debugSymbols = true;
+
+ // Link any required runtime libraries into the produced binaries statically.
+ // This increases resulting binary size but enables the binaries to be used on
+ // any machine without requiring matching system libraries to be installed.
+ bool linkStatic = false;
+
+ // True to keep linker artifacts for debugging.
+ bool keepLinkerArtifacts = false;
};
// Returns LLVMTargetOptions struct initialized with the
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/test/binaryop_test.mlir b/iree/compiler/Dialect/HAL/Target/LLVM/test/binary_op.mlir
similarity index 90%
rename from iree/compiler/Dialect/HAL/Target/LLVM/test/binaryop_test.mlir
rename to iree/compiler/Dialect/HAL/Target/LLVM/test/binary_op.mlir
index 0a66d94..41f1939 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/test/binaryop_test.mlir
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/test/binary_op.mlir
@@ -11,7 +11,7 @@
}
}
-// CHECK-LABEL: hal.executable @linked_llvm_ir
+// CHECK-LABEL: hal.executable @binary_op_linked_llvm_ir
// CHECK-DAG: hal.executable.binary attributes {
// CHECK-SAME: data = dense
// CHECK-SAME: format = 1280071245 : i32} {
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/test/matmul_op.mlir b/iree/compiler/Dialect/HAL/Target/LLVM/test/matmul_op.mlir
index 6744a70..1a48518 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/test/matmul_op.mlir
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/test/matmul_op.mlir
@@ -11,7 +11,7 @@
}
}
-// CHECK-LABEL: hal.executable @linked_llvm_ir
+// CHECK-LABEL: hal.executable @matmul_op_linked_llvm_ir
// CHECK-DAG: hal.executable.binary attributes {
// CHECK-SAME: data = dense
// CHECK-SAME: format = 1280071245 : i32} {
diff --git a/iree/compiler/Dialect/HAL/Target/MetalSPIRV/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/MetalSPIRV/CMakeLists.txt
index 60352f0..2934ce9 100644
--- a/iree/compiler/Dialect/HAL/Target/MetalSPIRV/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Target/MetalSPIRV/CMakeLists.txt
@@ -32,10 +32,11 @@
MLIRSPIRV
MLIRSPIRVSerialization
MLIRSupport
- flatbuffers
+ iree::base::flatcc
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::Target::SPIRVCommon
- iree::schemas::metal_executable_def_cc_fbs
+ iree::compiler::Utils
+ iree::schemas::metal_executable_def_c_fbs
PUBLIC
)
@@ -48,8 +49,7 @@
"SPIRVToMSL.cpp"
DEPS
LLVMSupport
+ MLIRSupport
spirv-cross-msl
- INCLUDES
- ${PROJECT_SOURCE_DIR}/third_party
PUBLIC
)
diff --git a/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp b/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp
index 612d93b..58e2172 100644
--- a/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp
@@ -14,12 +14,12 @@
#include "iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.h"
-#include "flatbuffers/flatbuffers.h"
#include "iree/compiler/Conversion/Common/Attributes.h"
#include "iree/compiler/Dialect/HAL/Target/MetalSPIRV/SPIRVToMSL.h"
#include "iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
-#include "iree/schemas/metal_executable_def_generated.h"
+#include "iree/compiler/Utils/FlatbufferUtils.h"
+#include "iree/schemas/metal_executable_def_builder.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
@@ -50,16 +50,6 @@
spirv::getDefaultResourceLimits(context));
}
-// Returns a list of entry point names matching the expected export ordinals.
-static std::vector<std::string> populateEntryPointNames(
- spirv::ModuleOp spvModuleOp) {
- std::vector<std::string> entryPointNames;
- spvModuleOp.walk([&](spirv::EntryPointOp entryPointOp) {
- entryPointNames.push_back(std::string(entryPointOp.fn()));
- });
- return entryPointNames;
-}
-
class MetalSPIRVTargetBackend : public SPIRVTargetBackend {
public:
MetalSPIRVTargetBackend(MetalSPIRVTargetOptions options)
@@ -88,17 +78,18 @@
// The runtime uses ordinals instead of names but Metal requires function
// names for constructing pipeline states. Get an ordered list of the entry
// point names.
- std::vector<std::string> entryPoints;
+ SmallVector<StringRef, 8> entryPointNames;
if (auto scheduleAttr = innerModuleOp.getAttrOfType<ArrayAttr>(
iree_compiler::getEntryPointScheduleAttrName())) {
// We have multiple entry points in this module. Make sure the order
// specified in the schedule attribute is respected.
for (Attribute entryPoint : scheduleAttr) {
- entryPoints.emplace_back(
- entryPoint.cast<StringAttr>().getValue().str());
+ entryPointNames.push_back(entryPoint.cast<StringAttr>().getValue());
}
} else {
- entryPoints = populateEntryPointNames(spvModuleOp);
+ spvModuleOp.walk([&](spirv::EntryPointOp entryPointOp) {
+ entryPointNames.push_back(entryPointOp.fn());
+ });
}
// 1. Serialize the spirv::ModuleOp into binary format.
@@ -109,7 +100,7 @@
// 2. Cross compile SPIR-V to MSL source code.
llvm::SmallVector<MetalShader, 2> mslShaders;
- for (const std::string &entryPoint : entryPoints) {
+ for (const auto &entryPoint : entryPointNames) {
llvm::Optional<MetalShader> mslShader = crossCompileSPIRVToMSL(
// We can use ArrayRef here given spvBinary reserves 0 bytes on stack.
llvm::makeArrayRef(spvBinary.data(), spvBinary.size()), entryPoint);
@@ -129,30 +120,32 @@
// to invoke them in C++.
// 4. Pack the MTLLibrary and metadata into a flatbuffer.
- iree::MetalExecutableDefT metalExecutableDef;
- metalExecutableDef.entry_points = entryPoints;
- for (auto &shader : mslShaders) {
- metalExecutableDef.shader_sources.push_back(std::move(shader.source));
- const auto &sizes = shader.threadgroupSize;
- metalExecutableDef.threadgroup_sizes.push_back(
- {sizes.x, sizes.y, sizes.z});
- }
+ FlatbufferBuilder builder;
- // Pack the executable definition and get the bytes with the proper header.
- // The header is used to verify the contents at runtime.
- ::flatbuffers::FlatBufferBuilder fbb;
- auto executableOffset =
- iree::MetalExecutableDef::Pack(fbb, &metalExecutableDef);
- iree::FinishMetalExecutableDefBuffer(fbb, executableOffset);
- std::vector<uint8_t> bytes;
- bytes.resize(fbb.GetSize());
- std::memcpy(bytes.data(), fbb.GetBufferPointer(), bytes.size());
+ auto shaderSourcesRef = builder.createStringVec(llvm::map_range(
+ mslShaders, [&](const MetalShader &shader) { return shader.source; }));
+
+ iree_MetalThreadgroupSize_vec_start(builder);
+ for (auto &shader : mslShaders) {
+ iree_MetalThreadgroupSize_vec_push_create(
+ builder, shader.threadgroupSize.x, shader.threadgroupSize.y,
+ shader.threadgroupSize.z);
+ }
+ auto threadgroupSizesRef = iree_MetalThreadgroupSize_vec_end(builder);
+
+ auto entryPointNamesRef = builder.createStringVec(entryPointNames);
+
+ iree_MetalExecutableDef_start_as_root(builder);
+ iree_MetalExecutableDef_entry_points_add(builder, entryPointNamesRef);
+ iree_MetalExecutableDef_threadgroup_sizes_add(builder, threadgroupSizesRef);
+ iree_MetalExecutableDef_shader_sources_add(builder, shaderSourcesRef);
+ iree_MetalExecutableDef_end_as_root(builder);
// 5. Add the binary data to the target executable.
executableBuilder.create<IREE::HAL::ExecutableBinaryOp>(
targetOp.getLoc(),
static_cast<uint32_t>(IREE::HAL::ExecutableFormat::Metal),
- std::move(bytes));
+ builder.getBufferAttr(executableBuilder.getContext()));
return success();
}
diff --git a/iree/compiler/Dialect/HAL/Target/MetalSPIRV/SPIRVToMSL.cpp b/iree/compiler/Dialect/HAL/Target/MetalSPIRV/SPIRVToMSL.cpp
index 5d89aa3..0d23fd6 100644
--- a/iree/compiler/Dialect/HAL/Target/MetalSPIRV/SPIRVToMSL.cpp
+++ b/iree/compiler/Dialect/HAL/Target/MetalSPIRV/SPIRVToMSL.cpp
@@ -19,7 +19,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
-#include "spirv_cross/spirv_msl.hpp"
+#include "third_party/spirv_cross/spirv_msl.hpp"
#define DEBUG_TYPE "spirv-to-msl"
@@ -32,9 +32,9 @@
using CompilerMSL::CompilerMSL;
MetalShader::ThreadGroupSize getWorkgroupSizeForEntryPoint(
- const std::string& entryName) {
+ StringRef entryName) {
const auto& entryPoint = get_entry_point(
- entryName, spv::ExecutionModel::ExecutionModelGLCompute);
+ entryName.str(), spv::ExecutionModel::ExecutionModelGLCompute);
const auto& workgroupSize = entryPoint.workgroup_size;
// TODO(antiagainst): support specialization constant.
if (workgroupSize.constant != 0) return {0, 0, 0};
@@ -104,13 +104,13 @@
} // namespace
llvm::Optional<MetalShader> crossCompileSPIRVToMSL(
- llvm::ArrayRef<uint32_t> spvBinary, const std::string& entryPoint) {
+ llvm::ArrayRef<uint32_t> spvBinary, StringRef entryPoint) {
SPIRVToMSLCompiler spvCrossCompiler(spvBinary.data(), spvBinary.size());
// All spirv-cross operations work on the current entry point. It should be
// set right after the cross compiler construction.
spvCrossCompiler.set_entry_point(
- entryPoint, spv::ExecutionModel::ExecutionModelGLCompute);
+ entryPoint.str(), spv::ExecutionModel::ExecutionModelGLCompute);
// Explicitly set the argument buffer index for each SPIR-V resource variable.
auto descriptors = spvCrossCompiler.getBufferSetBindingPairs();
diff --git a/iree/compiler/Dialect/HAL/Target/MetalSPIRV/SPIRVToMSL.h b/iree/compiler/Dialect/HAL/Target/MetalSPIRV/SPIRVToMSL.h
index ab805d1..79cb609 100644
--- a/iree/compiler/Dialect/HAL/Target/MetalSPIRV/SPIRVToMSL.h
+++ b/iree/compiler/Dialect/HAL/Target/MetalSPIRV/SPIRVToMSL.h
@@ -21,6 +21,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/StringRef.h"
+#include "mlir/Support/LLVM.h"
namespace mlir {
namespace iree_compiler {
@@ -37,7 +38,7 @@
// Cross compiles SPIR-V into Metal Shading Language source code for the
// compute shader with |entryPoint|. Returns llvm::None on failure.
llvm::Optional<MetalShader> crossCompileSPIRVToMSL(
- llvm::ArrayRef<uint32_t> spvBinary, const std::string& entryPoint);
+ llvm::ArrayRef<uint32_t> spvBinary, StringRef entryPoint);
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/BUILD b/iree/compiler/Dialect/HAL/Target/VMLA/BUILD
index f3f0f08..7276df9 100644
--- a/iree/compiler/Dialect/HAL/Target/VMLA/BUILD
+++ b/iree/compiler/Dialect/HAL/Target/VMLA/BUILD
@@ -37,6 +37,7 @@
"VMLATarget.h",
],
deps = [
+ "//iree/base:flatcc",
"//iree/compiler/Dialect/Flow/IR",
"//iree/compiler/Dialect/HAL/Target",
"//iree/compiler/Dialect/VM/Conversion",
@@ -45,8 +46,8 @@
"//iree/compiler/Dialect/VM/Transforms",
"//iree/compiler/Dialect/VMLA/IR:VMLADialect",
"//iree/compiler/Dialect/VMLA/Transforms",
- "//iree/schemas:vmla_executable_def_cc_fbs",
- "@com_github_google_flatbuffers//:flatbuffers",
+ "//iree/compiler/Utils",
+ "//iree/schemas:vmla_executable_def_c_fbs",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/VMLA/CMakeLists.txt
index 43de52b..8f6194d 100644
--- a/iree/compiler/Dialect/HAL/Target/VMLA/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Target/VMLA/CMakeLists.txt
@@ -30,7 +30,7 @@
MLIRIR
MLIRPass
MLIRSupport
- flatbuffers
+ iree::base::flatcc
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::Target
iree::compiler::Dialect::VM::Conversion
@@ -39,6 +39,7 @@
iree::compiler::Dialect::VM::Transforms
iree::compiler::Dialect::VMLA::IR::VMLADialect
iree::compiler::Dialect::VMLA::Transforms
- iree::schemas::vmla_executable_def_cc_fbs
+ iree::compiler::Utils
+ iree::schemas::vmla_executable_def_c_fbs
PUBLIC
)
diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp b/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp
index 564f03d..1f40ef5 100644
--- a/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp
@@ -14,7 +14,6 @@
#include "iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.h"
-#include "flatbuffers/flatbuffers.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
#include "iree/compiler/Dialect/VM/Conversion/ConversionTarget.h"
@@ -23,11 +22,10 @@
#include "iree/compiler/Dialect/VM/Transforms/Passes.h"
#include "iree/compiler/Dialect/VMLA/IR/VMLADialect.h"
#include "iree/compiler/Dialect/VMLA/Transforms/Passes.h"
-#include "iree/schemas/vmla_executable_def_generated.h"
-#include "llvm/ADT/ScopeExit.h"
+#include "iree/compiler/Utils/FlatbufferUtils.h"
+#include "iree/schemas/vmla_executable_def_builder.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
-#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OperationSupport.h"
@@ -117,8 +115,7 @@
// Create our new "linked" hal.executable.
auto linkedExecutableOp = builder.create<IREE::HAL::ExecutableOp>(
moduleOp.getLoc(), "linked_vmla");
- SymbolTable::setSymbolVisibility(linkedExecutableOp,
- SymbolTable::Visibility::Private);
+ linkedExecutableOp.setPrivate();
// Add our VMLA hal.executable.target with an empty module.
builder.setInsertionPointToStart(linkedExecutableOp.getBody());
auto linkedTargetOp = builder.create<IREE::HAL::ExecutableTargetOp>(
@@ -217,35 +214,30 @@
LogicalResult serializeExecutable(IREE::HAL::ExecutableTargetOp targetOp,
OpBuilder &executableBuilder) override {
- // Serialize the VM module to bytes.
- std::string byteStreamValue;
- llvm::raw_string_ostream byte_stream(byteStreamValue);
+ // Serialize the VM module to bytes directly into a flatbuffer.
+ FlatbufferBuilder builder;
IREE::VM::BytecodeTargetOptions bytecodeOptions;
- if (failed(translateModuleToBytecode(targetOp.getInnerModule(),
- bytecodeOptions, byte_stream))) {
+ auto dataRef = builder.streamUint8Vec([&](raw_ostream &stream) {
+ return succeeded(translateModuleToBytecode(targetOp.getInnerModule(),
+ bytecodeOptions, stream));
+ });
+ if (!dataRef) {
return targetOp.emitError() << "failed to serialize converted VM module";
}
// Pack the executable definition and get the bytes with the proper header.
// The header is used to verify the contents at runtime.
- ::flatbuffers::FlatBufferBuilder fbb;
- iree::VMLAExecutableDefT vmlaExecutableDef;
- vmlaExecutableDef.bytecode_module.resize(byteStreamValue.size());
- std::memcpy(vmlaExecutableDef.bytecode_module.data(),
- byteStreamValue.data(), byteStreamValue.size());
- auto executableOffset =
- iree::VMLAExecutableDef::Pack(fbb, &vmlaExecutableDef);
- iree::FinishVMLAExecutableDefBuffer(fbb, executableOffset);
- std::vector<uint8_t> bytes;
- bytes.resize(fbb.GetSize());
- std::memcpy(bytes.data(), fbb.GetBufferPointer(), bytes.size());
+ iree_VMLAExecutableDef_start_as_root(builder);
+ iree_VMLAExecutableDef_bytecode_module_add(builder, dataRef);
+ iree_VMLAExecutableDef_end_as_root(builder);
// Add the binary data to the target executable.
+ // NOTE: this snapshots the flatbuffer builder data at the time it is called
+ // and future changes will not be observed.
executableBuilder.create<IREE::HAL::ExecutableBinaryOp>(
targetOp.getLoc(),
static_cast<uint32_t>(IREE::HAL::ExecutableFormat::VMLA),
- std::move(bytes));
-
+ builder.getBufferAttr(executableBuilder.getContext()));
return success();
}
diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/BUILD b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/BUILD
index db09599..4062aa6 100644
--- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/BUILD
+++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/BUILD
@@ -37,6 +37,7 @@
"VulkanSPIRVTarget.h",
],
deps = [
+ "//iree/base:flatcc",
"//iree/compiler/Conversion/Common",
"//iree/compiler/Conversion/LinalgToSPIRV",
"//iree/compiler/Conversion/LinalgToSPIRV:CodeGenOptionUtils",
@@ -45,8 +46,8 @@
"//iree/compiler/Dialect/HAL/Target/SPIRVCommon",
"//iree/compiler/Dialect/Vulkan/IR",
"//iree/compiler/Dialect/Vulkan/Utils",
- "//iree/schemas:spirv_executable_def_cc_fbs",
- "@com_github_google_flatbuffers//:flatbuffers",
+ "//iree/compiler/Utils",
+ "//iree/schemas:spirv_executable_def_c_fbs",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Affine",
"@llvm-project//mlir:GPUDialect",
diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/CMakeLists.txt
index 8fba6f5..c230f04 100644
--- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/CMakeLists.txt
@@ -36,7 +36,7 @@
MLIRSPIRVSerialization
MLIRSupport
MLIRVector
- flatbuffers
+ iree::base::flatcc
iree::compiler::Conversion::Common
iree::compiler::Conversion::LinalgToSPIRV
iree::compiler::Conversion::LinalgToSPIRV::CodeGenOptionUtils
@@ -45,6 +45,7 @@
iree::compiler::Dialect::HAL::Target::SPIRVCommon
iree::compiler::Dialect::Vulkan::IR
iree::compiler::Dialect::Vulkan::Utils
- iree::schemas::spirv_executable_def_cc_fbs
+ iree::compiler::Utils
+ iree::schemas::spirv_executable_def_c_fbs
PUBLIC
)
diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
index b678f36..166f979 100644
--- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
@@ -14,7 +14,6 @@
#include "iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.h"
-#include "flatbuffers/flatbuffers.h"
#include "iree/compiler/Conversion/Common/Attributes.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/CodeGenOptionUtils.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
@@ -23,7 +22,8 @@
#include "iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.h"
#include "iree/compiler/Dialect/Vulkan/IR/VulkanDialect.h"
#include "iree/compiler/Dialect/Vulkan/Utils/TargetEnvUtils.h"
-#include "iree/schemas/spirv_executable_def_generated.h"
+#include "iree/compiler/Utils/FlatbufferUtils.h"
+#include "iree/schemas/spirv_executable_def_builder.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -86,16 +86,6 @@
return {};
}
-// Returns a list of entry point names matching the expected export ordinals.
-static std::vector<std::string> populateEntryPointNames(
- spirv::ModuleOp spvModuleOp) {
- std::vector<std::string> entryPointNames;
- spvModuleOp.walk([&](spirv::EntryPointOp entryPointOp) {
- entryPointNames.push_back(std::string(entryPointOp.fn()));
- });
- return entryPointNames;
-}
-
class VulkanSPIRVTargetBackend : public SPIRVTargetBackend {
public:
VulkanSPIRVTargetBackend(VulkanSPIRVTargetOptions options)
@@ -135,53 +125,47 @@
LogicalResult serializeExecutable(IREE::HAL::ExecutableTargetOp targetOp,
OpBuilder &executableBuilder) override {
- iree::SpirVExecutableDefT spirvExecutableDef;
-
ModuleOp innerModuleOp = targetOp.getInnerModule();
auto spvModuleOp = *innerModuleOp.getOps<spirv::ModuleOp>().begin();
+ // Serialize the spirv::ModuleOp into the binary that we will embed in the
+ // final flatbuffer.
+ FlatbufferBuilder builder;
+ SmallVector<uint32_t, 256> spvBinary;
+ if (failed(spirv::serialize(spvModuleOp, spvBinary)) || spvBinary.empty()) {
+ return targetOp.emitError() << "failed to serialize spv.module";
+ }
+ auto spvCodeRef = flatbuffers_uint32_vec_create(builder, spvBinary.data(),
+ spvBinary.size());
+
// The sequencer and runtime use ordinals instead of names. We provide the
// list of entry point names here that are then passed in
// VkShaderModuleCreateInfo.
+ SmallVector<StringRef, 8> entryPointNames;
if (auto scheduleAttr = innerModuleOp.getAttrOfType<ArrayAttr>(
iree_compiler::getEntryPointScheduleAttrName())) {
// We have multiple entry points in this module. Make sure the order
// specified in the schedule attribute is respected.
for (Attribute entryPoint : scheduleAttr) {
- spirvExecutableDef.entry_points.emplace_back(
- entryPoint.cast<StringAttr>().getValue().str());
+ entryPointNames.push_back(entryPoint.cast<StringAttr>().getValue());
}
} else {
- spirvExecutableDef.entry_points = populateEntryPointNames(spvModuleOp);
+ spvModuleOp.walk([&](spirv::EntryPointOp entryPointOp) {
+ entryPointNames.push_back(entryPointOp.fn());
+ });
}
+ auto entryPointsRef = builder.createStringVec(entryPointNames);
- // Serialize the spirv::ModuleOp into the binary that we will embed in the
- // final flatbuffer.
- SmallVector<uint32_t, 256> spvBinary;
- if (failed(spirv::serialize(spvModuleOp, spvBinary))) {
- return targetOp.emitError() << "failed to serialize spv.module";
- }
- spirvExecutableDef.code = {spvBinary.begin(), spvBinary.end()};
- if (spirvExecutableDef.code.empty()) {
- return targetOp.emitError()
- << "failed to translate and serialize SPIR-V executable";
- }
-
- // Pack the executable definition and get the bytes with the proper header.
- // The header is used to verify the contents at runtime.
- ::flatbuffers::FlatBufferBuilder fbb;
- auto executableOffset =
- iree::SpirVExecutableDef::Pack(fbb, &spirvExecutableDef);
- iree::FinishSpirVExecutableDefBuffer(fbb, executableOffset);
- std::vector<uint8_t> bytes;
- bytes.resize(fbb.GetSize());
- std::memcpy(bytes.data(), fbb.GetBufferPointer(), bytes.size());
+ iree_SpirVExecutableDef_start_as_root(builder);
+ iree_SpirVExecutableDef_entry_points_add(builder, entryPointsRef);
+ iree_SpirVExecutableDef_code_add(builder, spvCodeRef);
+ iree_SpirVExecutableDef_end_as_root(builder);
// Add the binary data to the target executable.
executableBuilder.create<IREE::HAL::ExecutableBinaryOp>(
targetOp.getLoc(),
static_cast<uint32_t>(IREE::HAL::ExecutableFormat::SpirV),
- std::move(bytes));
+ builder.getBufferAttr(executableBuilder.getContext()));
return success();
}
diff --git a/iree/compiler/Dialect/HAL/Transforms/IdentifyConstantPools.cpp b/iree/compiler/Dialect/HAL/Transforms/IdentifyConstantPools.cpp
index 5969b6b..97a193e 100644
--- a/iree/compiler/Dialect/HAL/Transforms/IdentifyConstantPools.cpp
+++ b/iree/compiler/Dialect/HAL/Transforms/IdentifyConstantPools.cpp
@@ -169,7 +169,7 @@
.create<ConstantPoolOp>(moduleBuilder.getUnknownLoc(),
poolName, bufferConstraints);
moduleSymbolTable.insert(poolOp, moduleBuilder.getInsertionPoint());
- SymbolTable::setSymbolVisibility(poolOp, SymbolTable::Visibility::Private);
+ poolOp.setPrivate();
// Replace each variable and keep track of the mapping from variable->value.
// This allows us to do one run through the module to replace usages as a
@@ -187,8 +187,7 @@
// Create the constant in the pool.
auto valueOp = poolBuilder.create<ConstantPoolValueOp>(
variableOp.getLoc(), variableOp.getName(), value);
- SymbolTable::setSymbolVisibility(valueOp,
- SymbolTable::Visibility::Nested);
+ valueOp.setNested();
// If the variable is an immutable constant and used in compatible
// ways we can turn them into constant loads instead. These will avoid
@@ -245,8 +244,7 @@
auto initializerFunc = moduleBuilder.create<FuncOp>(
variableOp.getLoc(), initializerName,
moduleBuilder.getFunctionType({}, {variableOp.type()}));
- SymbolTable::setSymbolVisibility(initializerFunc,
- SymbolTable::Visibility::Private);
+ initializerFunc.setPrivate();
variableOp.removeAttr("initial_value");
variableOp.setAttr("initializer",
moduleBuilder.getSymbolRefAttr(initializerFunc));
diff --git a/iree/compiler/Dialect/HAL/Transforms/MaterializeConstantPoolBuffers.cpp b/iree/compiler/Dialect/HAL/Transforms/MaterializeConstantPoolBuffers.cpp
index d5e6b16..190fc9c 100644
--- a/iree/compiler/Dialect/HAL/Transforms/MaterializeConstantPoolBuffers.cpp
+++ b/iree/compiler/Dialect/HAL/Transforms/MaterializeConstantPoolBuffers.cpp
@@ -90,8 +90,7 @@
storageOp.getLoc(), variableName, /*isMutable=*/false,
IREE::HAL::BufferType::get(context));
moduleSymbolTable.insert(variableOp, insertionPoint);
- SymbolTable::setSymbolVisibility(variableOp,
- SymbolTable::Visibility::Private);
+ variableOp.setPrivate();
// Find all the spans in the pool that map into this storage buffer so that
// we can update them with their runtime offsets. Note that since we are
@@ -123,8 +122,7 @@
auto initializerFunc = FuncOp::create(
storageOp.getLoc(), initializerName,
builder.getFunctionType({}, {IREE::HAL::BufferType::get(context)}));
- SymbolTable::setSymbolVisibility(initializerFunc,
- SymbolTable::Visibility::Private);
+ initializerFunc.setPrivate();
auto funcBuilder = OpBuilder::atBlockBegin(initializerFunc.addEntryBlock());
@@ -180,8 +178,7 @@
variableLoc, variableName, /*isMutable=*/false,
IREE::HAL::BufferType::get(context));
moduleSymbolTable.insert(variableOp, insertionPoint);
- SymbolTable::setSymbolVisibility(variableOp,
- SymbolTable::Visibility::Private);
+ variableOp.setPrivate();
// Compute the ranges for all the splats at runtime and the required buffer
// size based on the constraints provided.
@@ -232,8 +229,7 @@
auto initializerFunc = FuncOp::create(
variableLoc, initializerName,
builder.getFunctionType({}, {IREE::HAL::BufferType::get(context)}));
- SymbolTable::setSymbolVisibility(initializerFunc,
- SymbolTable::Visibility::Private);
+ initializerFunc.setPrivate();
auto funcBuilder = OpBuilder::atBlockBegin(initializerFunc.addEntryBlock());
diff --git a/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp b/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
index 3c08282..5f2747e 100644
--- a/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
+++ b/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
@@ -147,11 +147,8 @@
auto thunkFuncType = FunctionType::get({}, {}, clonedFuncOp.getContext());
auto thunkFuncOp = FuncOp::create(clonedFuncOp.getLoc(),
clonedFuncOp.getName(), thunkFuncType);
- SymbolTable::setSymbolVisibility(thunkFuncOp,
- SymbolTable::Visibility::Public);
clonedFuncOp.setName((clonedFuncOp.getName() + "_impl").str());
- SymbolTable::setSymbolVisibility(clonedFuncOp,
- SymbolTable::Visibility::Private);
+ clonedFuncOp.setPrivate();
clonedFuncOp.getParentRegion()->getBlocks().front().push_front(thunkFuncOp);
// For now we only support tensor types, so bindings are in order.
@@ -286,8 +283,7 @@
// Copy interface bindings into the target module so symbol references work.
auto inlinedInterfaceOp = interfaceOp.clone();
- SymbolTable::setSymbolVisibility(inlinedInterfaceOp,
- SymbolTable::Visibility::Private);
+ inlinedInterfaceOp.setPrivate();
targetOp.getInnerModule().push_back(inlinedInterfaceOp);
}
return success();
@@ -369,8 +365,7 @@
builder.setInsertionPointAfter(sourceOp);
auto exectuableOp = builder.create<IREE::HAL::ExecutableOp>(
sourceOp.getLoc(), sourceOp.getName());
- SymbolTable::setSymbolVisibility(exectuableOp,
- SymbolTable::Visibility::Private);
+ exectuableOp.setPrivate();
// Add IO ops to define the bindings and how parameters are passed.
auto interfaceOp = declareInterfaceIO(sourceOp, exectuableOp);
diff --git a/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp b/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp
index 1908bef..48331c0 100644
--- a/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp
+++ b/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp
@@ -99,8 +99,7 @@
auto executableType = ExecutableType::get(executableOp.getContext());
auto variableOp = moduleBuilder.create<VariableOp>(
executableOp.getLoc(), symbolName, /*isMutable=*/true, executableType);
- SymbolTable::setSymbolVisibility(variableOp,
- SymbolTable::Visibility::Private);
+ variableOp.setPrivate();
executableCache_.try_emplace(executableOp.sym_name(), variableOp);
return variableOp;
}
@@ -121,14 +120,12 @@
loc, symbolName,
/*isMutable=*/false, layoutType, StringRef(initializerName),
llvm::None);
- SymbolTable::setSymbolVisibility(variableOp,
- SymbolTable::Visibility::Private);
+ variableOp.setPrivate();
descriptorSetLayoutCache_.try_emplace(bindingsAttr, variableOp);
auto initializerOp = moduleBuilder.create<FuncOp>(
loc, initializerName, moduleBuilder.getFunctionType({}, {layoutType}));
- SymbolTable::setSymbolVisibility(initializerOp,
- SymbolTable::Visibility::Private);
+ initializerOp.setPrivate();
auto *block = initializerOp.addEntryBlock();
OpBuilder blockBuilder = OpBuilder::atBlockEnd(block);
auto deviceValue = blockBuilder.createOrFold<ExSharedDeviceOp>(loc);
@@ -165,14 +162,12 @@
auto variableOp = moduleBuilder.create<VariableOp>(
loc, symbolName, /*isMutable=*/false, layoutType,
StringRef(initializerName), llvm::None);
- SymbolTable::setSymbolVisibility(variableOp,
- SymbolTable::Visibility::Private);
+ variableOp.setPrivate();
executableLayoutCache_.try_emplace(setLayoutsArrayAttr, variableOp);
auto initializerOp = moduleBuilder.create<FuncOp>(
loc, initializerName, moduleBuilder.getFunctionType({}, {layoutType}));
- SymbolTable::setSymbolVisibility(initializerOp,
- SymbolTable::Visibility::Private);
+ initializerOp.setPrivate();
auto *block = initializerOp.addEntryBlock();
OpBuilder blockBuilder = OpBuilder::atBlockEnd(block);
SmallVector<Value, 4> setLayoutValues;
@@ -208,14 +203,12 @@
// TODO(#1146): we define this as public right now to ensure it remains
// after DCE.
- SymbolTable::setSymbolVisibility(variableOp,
- SymbolTable::Visibility::Public);
+ // variableOp.setPublic();
auto initializerOp = moduleBuilder.create<FuncOp>(
loc, initializerName,
moduleBuilder.getFunctionType({}, {executableCacheType}));
- SymbolTable::setSymbolVisibility(initializerOp,
- SymbolTable::Visibility::Private);
+ initializerOp.setPrivate();
auto *block = initializerOp.addEntryBlock();
OpBuilder blockBuilder = OpBuilder::atBlockEnd(block);
auto deviceValue = blockBuilder.createOrFold<ExSharedDeviceOp>(loc);
diff --git a/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp b/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp
index 7a90486..16fe669 100644
--- a/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp
+++ b/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp
@@ -73,14 +73,12 @@
auto initializerOp = moduleBuilder.create<FuncOp>(
fusedLoc, variableName + "_initializer",
moduleBuilder.getFunctionType({}, {moduleBuilder.getI1Type()}));
- SymbolTable::setSymbolVisibility(initializerOp,
- SymbolTable::Visibility::Private);
+ initializerOp.setPrivate();
moduleBuilder.setInsertionPoint(initializerOp);
auto variableOp = moduleBuilder.create<IREE::HAL::VariableOp>(
fusedLoc, variableName,
/*isMutable=*/false, initializerOp);
- SymbolTable::setSymbolVisibility(variableOp,
- SymbolTable::Visibility::Private);
+ variableOp.setPrivate();
moduleBuilder.setInsertionPointAfter(initializerOp);
auto funcBuilder = OpBuilder::atBlockBegin(initializerOp.addEntryBlock());
diff --git a/iree/compiler/Dialect/HAL/Transforms/PackConstantPoolStorage.cpp b/iree/compiler/Dialect/HAL/Transforms/PackConstantPoolStorage.cpp
index ae2bc46..474f3f6 100644
--- a/iree/compiler/Dialect/HAL/Transforms/PackConstantPoolStorage.cpp
+++ b/iree/compiler/Dialect/HAL/Transforms/PackConstantPoolStorage.cpp
@@ -78,8 +78,7 @@
auto splatOp = builder.create<ConstantPoolSplatOp>(
splatValueOp.getLoc(), splatValueOp.getName(), splatValueOp.value(),
SymbolRefAttr{}, ByteRangeAttr{});
- SymbolTable::setSymbolVisibility(splatOp,
- SymbolTable::Visibility::Nested);
+ splatOp.setNested();
splatValueOp.erase();
}
@@ -100,8 +99,7 @@
.create<ConstantStorageOp>(storageBufferLoc, "_storage",
storageBuffer.data);
poolSymbolTable.insert(storageBufferOp);
- SymbolTable::setSymbolVisibility(storageBufferOp,
- SymbolTable::Visibility::Nested);
+ storageBufferOp.setNested();
// TODO(benvanik): specify alignment attribute for file serialization
// (minStorageBufferOffsetAlignment) and get vm.rodata handling it.
@@ -120,8 +118,7 @@
APInt(64, constantSpan.length),
poolOp.getContext()),
SymbolRefAttr{}, ByteRangeAttr{});
- SymbolTable::setSymbolVisibility(spanOp,
- SymbolTable::Visibility::Nested);
+ spanOp.setNested();
valueOp.erase();
}
}
diff --git a/iree/compiler/Dialect/HAL/Utils/DeviceSwitchBuilder.h b/iree/compiler/Dialect/HAL/Utils/DeviceSwitchBuilder.h
index 15fc7b0..8cf1285 100644
--- a/iree/compiler/Dialect/HAL/Utils/DeviceSwitchBuilder.h
+++ b/iree/compiler/Dialect/HAL/Utils/DeviceSwitchBuilder.h
@@ -171,8 +171,6 @@
resultTypes_(resultTypes),
device_(device),
rewriter_(rewriter) {
- // FIXME: Keep the same listener as the provided builder.
- rewriter_.setListener(nullptr);
}
// Pushes a new condition onto the stack and returns a builder that must have
@@ -213,11 +211,18 @@
loc_, resultTypes_, device_, conditionAttrs, conditionArgs);
for (int i = 0; i < caseOps_.size(); ++i) {
Region &targetRegion = switchOp.getRegion(i);
- Block *entryBlock = new Block;
- targetRegion.push_back(entryBlock);
- BlockAndValueMapping mapper;
+
+ SmallVector<Type, 4> entryTypes;
for (auto arg : conditionArgs[i]) {
- mapper.map(arg, entryBlock->addArgument(arg.getType()));
+ entryTypes.push_back(arg.getType());
+ }
+ Block *entryBlock =
+ rewriter_.createBlock(&targetRegion, targetRegion.end(), entryTypes);
+ rewriter_.setInsertionPointAfter(switchOp);
+
+ BlockAndValueMapping mapper;
+ for (auto arg : llvm::zip(conditionArgs[i], entryBlock->getArguments())) {
+ mapper.map(std::get<0>(arg), std::get<1>(arg));
}
Region &sourceRegion = caseOps_[i].getRegion(0);
@@ -229,7 +234,7 @@
rewriter_.mergeBlocks(secondBlock, entryBlock,
entryBlock->getArguments().take_front(
secondBlock->getNumArguments()));
- caseOps_[i].erase();
+ rewriter_.eraseOp(caseOps_[i]);
}
return switchOp;
}
diff --git a/iree/compiler/Dialect/HAL/Utils/TypeUtils.cpp b/iree/compiler/Dialect/HAL/Utils/TypeUtils.cpp
index 18dc06c..65e4663 100644
--- a/iree/compiler/Dialect/HAL/Utils/TypeUtils.cpp
+++ b/iree/compiler/Dialect/HAL/Utils/TypeUtils.cpp
@@ -21,13 +21,13 @@
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "llvm/ADT/Optional.h"
#include "llvm/Support/Debug.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Dialect/IREE/IR/BUILD b/iree/compiler/Dialect/IREE/IR/BUILD
index f0d67e1..e04be16 100644
--- a/iree/compiler/Dialect/IREE/IR/BUILD
+++ b/iree/compiler/Dialect/IREE/IR/BUILD
@@ -65,7 +65,7 @@
td_srcs = [
":td_files",
"@llvm-project//mlir:OpBaseTdFiles",
- "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
+ "@llvm-project//mlir:SideEffectTdFiles",
],
)
@@ -79,6 +79,6 @@
td_srcs = [
":td_files",
"@llvm-project//mlir:OpBaseTdFiles",
- "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
+ "@llvm-project//mlir:SideEffectTdFiles",
],
)
diff --git a/iree/compiler/Dialect/IREE/Tools/BUILD b/iree/compiler/Dialect/IREE/Tools/BUILD
index f046f8e..0411cff 100644
--- a/iree/compiler/Dialect/IREE/Tools/BUILD
+++ b/iree/compiler/Dialect/IREE/Tools/BUILD
@@ -18,16 +18,9 @@
licenses = ["notice"], # Apache 2.0
)
-cc_library(
- name = "Tools",
+filegroup(
+ name = "GenSrcs",
srcs = [
"StructAttrGen.cpp",
],
- deps = [
- "@llvm-project//llvm:Support",
- "@llvm-project//llvm:TableGen",
- "@llvm-project//mlir:Support",
- "@llvm-project//mlir:TableGen",
- ],
- alwayslink = 1,
)
diff --git a/iree/compiler/Dialect/IREE/Tools/CMakeLists.txt b/iree/compiler/Dialect/IREE/Tools/CMakeLists.txt
index 620c3a1..8b864e5 100644
--- a/iree/compiler/Dialect/IREE/Tools/CMakeLists.txt
+++ b/iree/compiler/Dialect/IREE/Tools/CMakeLists.txt
@@ -13,17 +13,3 @@
# limitations under the License.
iree_add_all_subdirs()
-
-iree_cc_library(
- NAME
- Tools
- SRCS
- "StructAttrGen.cpp"
- DEPS
- LLVMSupport
- LLVMTableGen
- MLIRSupport
- MLIRTableGen
- ALWAYSLINK
- PUBLIC
-)
diff --git a/iree/compiler/Dialect/Sequence/IR/BUILD b/iree/compiler/Dialect/Sequence/IR/BUILD
index 9f4d05c..430d094 100644
--- a/iree/compiler/Dialect/Sequence/IR/BUILD
+++ b/iree/compiler/Dialect/Sequence/IR/BUILD
@@ -99,7 +99,7 @@
"//iree/compiler/Dialect/IREE/IR:td_files",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:include/mlir/IR/SymbolInterfaces.td",
- "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
+ "@llvm-project//mlir:SideEffectTdFiles",
],
)
@@ -115,6 +115,6 @@
"//iree/compiler/Dialect/IREE/IR:td_files",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:include/mlir/IR/SymbolInterfaces.td",
- "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
+ "@llvm-project//mlir:SideEffectTdFiles",
],
)
diff --git a/iree/compiler/Dialect/Shape/IR/BUILD b/iree/compiler/Dialect/Shape/IR/BUILD
index c4ca5e5..96fe609 100644
--- a/iree/compiler/Dialect/Shape/IR/BUILD
+++ b/iree/compiler/Dialect/Shape/IR/BUILD
@@ -77,7 +77,7 @@
"//iree/compiler/Dialect/IREE/IR:td_files",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
- "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
+ "@llvm-project//mlir:SideEffectTdFiles",
"@llvm-project//mlir:include/mlir/Interfaces/ViewLikeInterface.td",
"@llvm-project//mlir:include/mlir/IR/OpAsmInterface.td",
],
@@ -95,7 +95,7 @@
"//iree/compiler/Dialect/IREE/IR:td_files",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
- "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
+ "@llvm-project//mlir:SideEffectTdFiles",
"@llvm-project//mlir:include/mlir/Interfaces/ViewLikeInterface.td",
"@llvm-project//mlir:include/mlir/IR/OpAsmInterface.td",
],
diff --git a/iree/compiler/Dialect/Shape/IR/Folders.cpp b/iree/compiler/Dialect/Shape/IR/Folders.cpp
index 86db03f..fc15793 100644
--- a/iree/compiler/Dialect/Shape/IR/Folders.cpp
+++ b/iree/compiler/Dialect/Shape/IR/Folders.cpp
@@ -16,6 +16,7 @@
#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
#include "iree/compiler/Utils/PatternUtils.h"
#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
@@ -306,10 +307,23 @@
return success();
}
+LogicalResult fromExtentTensorOfCastIndexBypass(
+ FromExtentTensorOp op, FromExtentTensorOp::Adaptor operands,
+ PatternRewriter &rewriter) {
+ auto toOp = dyn_cast_or_null<IndexCastOp>(op.extent_tensor().getDefiningOp());
+ if (!toOp) {
+ return failure();
+ }
+
+ rewriter.replaceOpWithNewOp<FromExtentTensorOp>(op, op.getType(), toOp.in());
+ return success();
+}
+
void FromExtentTensorOp::getCanonicalizationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
insertGreedyPattern(patterns, context,
fromExtentTensorOfToExtentTensorIsIdentity);
+ insertGreedyPattern(patterns, context, fromExtentTensorOfCastIndexBypass);
}
//===----------------------------------------------------------------------===//
@@ -351,6 +365,7 @@
insertConversionPattern(patterns, context, safeCastCompatibleShapePattern);
insertConversionPattern(patterns, context,
fromExtentTensorOfToExtentTensorIsIdentity);
+ insertConversionPattern(patterns, context, fromExtentTensorOfCastIndexBypass);
}
} // namespace Shape
diff --git a/iree/compiler/Dialect/Shape/Plugins/XLA/XlaHloShapeBuilder.cpp b/iree/compiler/Dialect/Shape/Plugins/XLA/XlaHloShapeBuilder.cpp
index 44b905e..19d3812 100644
--- a/iree/compiler/Dialect/Shape/Plugins/XLA/XlaHloShapeBuilder.cpp
+++ b/iree/compiler/Dialect/Shape/Plugins/XLA/XlaHloShapeBuilder.cpp
@@ -19,9 +19,9 @@
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/Optional.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Value.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
using llvm::None;
using llvm::Optional;
@@ -34,8 +34,8 @@
namespace {
template <typename HloOp>
-Value rewriteXlaBinaryElementwiseOpShape(RankedShapeType resultShape, HloOp op,
- OpBuilder &builder) {
+Value rewriteXlaNaryElementwiseOpShape(RankedShapeType resultShape, HloOp op,
+ OpBuilder &builder) {
if (!op) return nullptr;
SmallVector<Value, 4> inputOperands(op.getOperands());
return buildCastInputsToResultShape(op.getLoc(), resultShape, inputOperands,
@@ -448,7 +448,7 @@
// NOTE: Most of these *should not* be "custom ops". They should be coming
// from declarative shape information, but that doesn't exist yet.
#define INSERT_EW_OP(OpTy) \
- b.insertOpRankedShapeBuilder<OpTy>(rewriteXlaBinaryElementwiseOpShape<OpTy>);
+ b.insertOpRankedShapeBuilder<OpTy>(rewriteXlaNaryElementwiseOpShape<OpTy>);
INSERT_EW_OP(AddOp);
INSERT_EW_OP(Atan2Op);
INSERT_EW_OP(DivOp);
@@ -462,6 +462,7 @@
INSERT_EW_OP(ShiftRightLogicalOp);
INSERT_EW_OP(SubOp);
INSERT_EW_OP(CompareOp);
+ INSERT_EW_OP(ClampOp);
b.insertOpRankedShapeBuilder<SelectOp>(rewriteSelectOp);
b.insertOpRankedShapeBuilder<DotOp>(rewriteXlaDotOpShape);
diff --git a/iree/compiler/Dialect/Shape/Transforms/ConvertHLOToShapeDialectPass.cpp b/iree/compiler/Dialect/Shape/Transforms/ConvertHLOToShapeDialectPass.cpp
index 1b06f2d..d4a8ebe 100644
--- a/iree/compiler/Dialect/Shape/Transforms/ConvertHLOToShapeDialectPass.cpp
+++ b/iree/compiler/Dialect/Shape/Transforms/ConvertHLOToShapeDialectPass.cpp
@@ -18,13 +18,13 @@
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/DialectConversion.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Dialect/Shape/Transforms/Passes.h b/iree/compiler/Dialect/Shape/Transforms/Passes.h
index 21541ac..765a99e 100644
--- a/iree/compiler/Dialect/Shape/Transforms/Passes.h
+++ b/iree/compiler/Dialect/Shape/Transforms/Passes.h
@@ -19,11 +19,9 @@
#include "iree/compiler/Dialect/Shape/IR/ShapeInterface.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
namespace mlir {
-
-class OpPassManager;
-
namespace iree_compiler {
namespace Shape {
diff --git a/iree/compiler/Dialect/VM/Analysis/test/register_allocation.mlir b/iree/compiler/Dialect/VM/Analysis/test/register_allocation.mlir
index b9dfe09..6f0aea7 100644
--- a/iree/compiler/Dialect/VM/Analysis/test/register_allocation.mlir
+++ b/iree/compiler/Dialect/VM/Analysis/test/register_allocation.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(test-iree-vm-register-allocation)' %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(vm.func(test-iree-vm-register-allocation))' %s | IreeFileCheck %s
// CHECK-LABEL: @module
vm.module @module {
diff --git a/iree/compiler/Dialect/VM/Analysis/test/value_liveness.mlir b/iree/compiler/Dialect/VM/Analysis/test/value_liveness.mlir
index 3997bab..755dc67 100644
--- a/iree/compiler/Dialect/VM/Analysis/test/value_liveness.mlir
+++ b/iree/compiler/Dialect/VM/Analysis/test/value_liveness.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(test-iree-vm-value-liveness)' %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(vm.func(test-iree-vm-value-liveness))' %s | IreeFileCheck %s
// CHECK-LABEL: @module
vm.module @module {
diff --git a/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp b/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp
index 67c4111..65c77f9 100644
--- a/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp
+++ b/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp
@@ -40,8 +40,7 @@
if (!existingOp) {
auto clonedOp = cast<IREE::VM::ImportOp>(targetBuilder.clone(*importOp));
clonedOp.setName(fullName);
- SymbolTable::setSymbolVisibility(clonedOp,
- SymbolTable::Visibility::Private);
+ clonedOp.setPrivate();
}
});
return success();
@@ -123,7 +122,7 @@
rewriter.setInsertionPoint(funcOp);
auto rodataOp =
rewriter.create<IREE::VM::RodataOp>(loc, safeIdentifier, utf8Bytes);
- SymbolTable::setSymbolVisibility(rodataOp, SymbolTable::Visibility::Private);
+ rodataOp.setPrivate();
rewriter.restoreInsertionPoint(insertPoint);
// Load the UTF-8 bytes to pass as a value.
diff --git a/iree/compiler/Dialect/VM/IR/BUILD b/iree/compiler/Dialect/VM/IR/BUILD
index fcf853d..78488ef 100644
--- a/iree/compiler/Dialect/VM/IR/BUILD
+++ b/iree/compiler/Dialect/VM/IR/BUILD
@@ -98,7 +98,7 @@
"@llvm-project//mlir:include/mlir/IR/SymbolInterfaces.td",
"@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td",
"@llvm-project//mlir:include/mlir/Interfaces/ControlFlowInterfaces.td",
- "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
+ "@llvm-project//mlir:SideEffectTdFiles",
],
)
@@ -116,7 +116,7 @@
"@llvm-project//mlir:include/mlir/IR/SymbolInterfaces.td",
"@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td",
"@llvm-project//mlir:include/mlir/Interfaces/ControlFlowInterfaces.td",
- "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
+ "@llvm-project//mlir:SideEffectTdFiles",
],
)
@@ -149,6 +149,6 @@
"@llvm-project//mlir:include/mlir/IR/SymbolInterfaces.td",
"@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td",
"@llvm-project//mlir:include/mlir/Interfaces/ControlFlowInterfaces.td",
- "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
+ "@llvm-project//mlir:SideEffectTdFiles",
],
)
diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/BUILD b/iree/compiler/Dialect/VM/Target/Bytecode/BUILD
index 8a1f4db..c525f57 100644
--- a/iree/compiler/Dialect/VM/Target/Bytecode/BUILD
+++ b/iree/compiler/Dialect/VM/Target/Bytecode/BUILD
@@ -25,8 +25,8 @@
"//iree/compiler/Dialect/VM/Analysis",
"//iree/compiler/Dialect/VM/IR",
"//iree/compiler/Dialect/VM/Transforms",
- "//iree/schemas:bytecode_module_def_cc_fbs",
- "@com_github_google_flatbuffers//:flatbuffers",
+ "//iree/compiler/Utils",
+ "//iree/schemas:bytecode_module_def_c_fbs",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp b/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp
index ecfee52f..d63d70f 100644
--- a/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp
+++ b/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp
@@ -16,8 +16,6 @@
#include <algorithm>
-#include "flatbuffers/flatbuffers.h"
-#include "flatbuffers/minireflect.h"
#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
#include "iree/compiler/Dialect/IREE/IR/IREETypes.h"
#include "iree/compiler/Dialect/IREE/Transforms/Passes.h"
@@ -28,7 +26,9 @@
#include "iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.h"
#include "iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.h"
#include "iree/compiler/Dialect/VM/Transforms/Passes.h"
-#include "iree/schemas/bytecode_module_def_generated.h"
+#include "iree/compiler/Utils/FlatbufferUtils.h"
+#include "iree/schemas/bytecode_module_def_builder.h"
+#include "iree/schemas/bytecode_module_def_json_printer.h"
#include "llvm/Support/ErrorHandling.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Diagnostics.h"
@@ -48,10 +48,6 @@
namespace {
-using flatbuffers::FlatBufferBuilder;
-using flatbuffers::Offset;
-using flatbuffers::Vector;
-
struct ModuleCounts {
int importFuncs = 0;
int exportFuncs = 0;
@@ -204,20 +200,6 @@
return success();
}
-// Returns a vector of tables of type T or None if |contents| is empty.
-template <typename T>
-static Optional<Offset<Vector<Offset<T>>>> createOptionalVector(
- const std::vector<Offset<T>> &contents, FlatBufferBuilder &fbb) {
- if (contents.empty()) return llvm::None;
- return fbb.CreateVector(contents);
-}
-template <typename T>
-static Optional<Offset<Vector<T>>> createOptionalVector(
- const std::vector<T> &contents, FlatBufferBuilder &fbb) {
- if (contents.empty()) return llvm::None;
- return fbb.CreateVector(contents);
-}
-
// Encodes a type (or a tuple of nested types) to a calling convention string.
//
// Examples:
@@ -331,89 +313,103 @@
return std::string(s.data(), s.size());
}
-// Populates common fields for FunctionSignatureDefs of all function types.
-static void populateFunctionSignatureDef(FunctionType functionType,
- llvm::DenseMap<Type, int> &typeTable,
- iree::vm::FunctionSignatureDefT &fsd) {
- for (auto type : functionType.getInputs()) {
- if (auto refPtrType = type.dyn_cast<IREE::VM::RefType>()) {
- type = refPtrType.getObjectType();
- }
- fsd.argument_types.push_back(typeTable.lookup(type));
+// Creates a FunctionSignatureDef based on the given function metadata.
+// Some fields are not used on all signature defs and added only when present on
+// the argument objects/attrs.
+static iree_vm_FunctionSignatureDef_ref_t createFunctionSignatureDef(
+ FunctionType functionType, llvm::DenseMap<Type, int> &typeTable,
+ StringRef callingConvention,
+ iree_vm_ReflectionAttrDef_vec_ref_t reflectionAttrsRef,
+ FlatbufferBuilder &fbb) {
+ auto resultTypesRef = fbb.createInt32Vec(
+ llvm::map_range(functionType.getResults(), [&](Type type) {
+ if (auto refPtrType = type.dyn_cast<IREE::VM::RefType>()) {
+ type = refPtrType.getObjectType();
+ }
+ return typeTable.lookup(type);
+ }));
+ auto argumentTypesRef = fbb.createInt32Vec(
+ llvm::map_range(functionType.getInputs(), [&](Type type) {
+ if (auto refPtrType = type.dyn_cast<IREE::VM::RefType>()) {
+ type = refPtrType.getObjectType();
+ }
+ return typeTable.lookup(type);
+ }));
+
+ auto callingConventionRef = fbb.createString(callingConvention);
+
+ // If the signature would be empty then let's avoid writing the empty table.
+ if (!argumentTypesRef && !resultTypesRef && !callingConventionRef &&
+ !reflectionAttrsRef) {
+ return 0;
}
- for (auto type : functionType.getResults()) {
- if (auto refPtrType = type.dyn_cast<IREE::VM::RefType>()) {
- type = refPtrType.getObjectType();
- }
- fsd.result_types.push_back(typeTable.lookup(type));
- }
+
+ iree_vm_FunctionSignatureDef_start(fbb);
+ iree_vm_FunctionSignatureDef_argument_types_add(fbb, argumentTypesRef);
+ iree_vm_FunctionSignatureDef_result_types_add(fbb, resultTypesRef);
+ iree_vm_FunctionSignatureDef_calling_convention_add(fbb,
+ callingConventionRef);
+ iree_vm_FunctionSignatureDef_reflection_attrs_add(fbb, reflectionAttrsRef);
+ return iree_vm_FunctionSignatureDef_end(fbb);
}
// Returns a serialized function signature.
-static Offset<iree::vm::FunctionSignatureDef> makeImportFunctionSignatureDef(
+static iree_vm_FunctionSignatureDef_ref_t makeImportFunctionSignatureDef(
IREE::VM::ImportOp importOp, llvm::DenseMap<Type, int> &typeTable,
- FlatBufferBuilder &fbb) {
- // Common attributes.
- iree::vm::FunctionSignatureDefT fsd;
- populateFunctionSignatureDef(importOp.getType(), typeTable, fsd);
-
+ FlatbufferBuilder &fbb) {
// Generate the signature calling convention string based on types.
auto cconv = makeImportCallingConventionString(importOp);
if (!cconv.hasValue()) return {};
- fsd.calling_convention = cconv.getValue();
-
- return iree::vm::FunctionSignatureDef::Pack(fbb, &fsd);
+ return createFunctionSignatureDef(importOp.getType(), typeTable,
+ cconv.getValue(), /*reflectionAttrsRef=*/0,
+ fbb);
}
// Returns a serialized function signature.
-static Offset<iree::vm::FunctionSignatureDef> makeExportFunctionSignatureDef(
+static iree_vm_FunctionSignatureDef_ref_t makeExportFunctionSignatureDef(
IREE::VM::ExportOp exportOp, IREE::VM::FuncOp funcOp,
- llvm::DenseMap<Type, int> &typeTable, FlatBufferBuilder &fbb) {
- // Common attributes.
- iree::vm::FunctionSignatureDefT fsd;
- populateFunctionSignatureDef(funcOp.getType(), typeTable, fsd);
-
+ llvm::DenseMap<Type, int> &typeTable, FlatbufferBuilder &fbb) {
// Generate the signature calling convention string based on types.
auto cconv = makeCallingConventionString(funcOp);
if (!cconv.hasValue()) return {};
- fsd.calling_convention = cconv.getValue();
-
- return iree::vm::FunctionSignatureDef::Pack(fbb, &fsd);
+ return createFunctionSignatureDef(funcOp.getType(), typeTable,
+ cconv.getValue(), /*reflectionAttrsRef=*/0,
+ fbb);
}
// Returns a serialized function signature.
-static Offset<iree::vm::FunctionSignatureDef> makeInternalFunctionSignatureDef(
+static iree_vm_FunctionSignatureDef_ref_t makeInternalFunctionSignatureDef(
IREE::VM::FuncOp funcOp, llvm::DenseMap<Type, int> &typeTable,
- FlatBufferBuilder &fbb) {
- // Common attributes.
- iree::vm::FunctionSignatureDefT fsd;
- populateFunctionSignatureDef(funcOp.getType(), typeTable, fsd);
-
+ FlatbufferBuilder &fbb) {
// Generate the signature calling convention string based on types.
// TODO(benvanik): only do this on exports. The runtime currently looks on
// internal functions, though, so we have to have it here.
auto cconv = makeCallingConventionString(funcOp);
if (!cconv.hasValue()) return {};
- fsd.calling_convention = cconv.getValue();
// Reflection attributes.
// TODO(benvanik): move these to exports (or remove entirely).
+ iree_vm_ReflectionAttrDef_vec_ref_t reflectionAttrsRef = 0;
if (auto reflectionAttrs =
funcOp.getAttrOfType<DictionaryAttr>("iree.reflection")) {
- llvm::SmallVector<Offset<iree::vm::ReflectionAttrDef>, 4>
- reflectionAttrItems;
+ SmallVector<iree_vm_ReflectionAttrDef_ref_t, 4> reflectionAttrRefs;
for (auto reflectionAttr : reflectionAttrs) {
auto key = reflectionAttr.first.strref();
auto value = reflectionAttr.second.dyn_cast<StringAttr>();
if (!value || key.empty()) continue;
- auto rattr = std::make_unique<iree::vm::ReflectionAttrDefT>();
- rattr->key = key.str();
- rattr->value = value.getValue().str();
- fsd.reflection_attrs.push_back(std::move(rattr));
+ // NOTE: if we actually want to keep these we should dedupe them (as the
+ // keys and likely several of the values are shared across all functions).
+ auto valueRef = fbb.createString(value.getValue());
+ auto keyRef = fbb.createString(key);
+ reflectionAttrRefs.push_back(
+ iree_vm_ReflectionAttrDef_create(fbb, keyRef, valueRef));
}
+ reflectionAttrsRef = iree_vm_ReflectionAttrDef_vec_create(
+ fbb, reflectionAttrRefs.data(), reflectionAttrRefs.size());
}
- return iree::vm::FunctionSignatureDef::Pack(fbb, &fsd);
+ return createFunctionSignatureDef(funcOp.getType(), typeTable,
+ cconv.getValue(), reflectionAttrsRef, fbb);
}
// Builds a complete BytecodeModuleDef FlatBuffer object in |fbb|.
@@ -426,9 +422,9 @@
// has been packed into the top-level table. This results in a messier function
// here during serialization but a much more trivial (and cache-friendly)
// representation at runtime.
-static Offset<iree::vm::BytecodeModuleDef> buildFlatBufferModule(
- BytecodeTargetOptions targetOptions, IREE::VM::ModuleOp moduleOp,
- FlatBufferBuilder &fbb) {
+static LogicalResult buildFlatBufferModule(BytecodeTargetOptions targetOptions,
+ IREE::VM::ModuleOp moduleOp,
+ FlatbufferBuilder &fbb) {
SymbolTable symbolTable(moduleOp);
auto symbolCounts = computeModuleSymbolCounts(moduleOp);
@@ -456,17 +452,24 @@
// Serialize read-only data first so that it ends up at the end of the file.
// This is where large things like parameters live and we don't want that to
// get paged in until it is needed.
- std::vector<Offset<Vector<uint8_t>>> rodataContentOffsets;
- rodataContentOffsets.reserve(rodataOps.size());
- for (auto rodataOp : rodataOps) {
- auto dataOffset =
+ //
+ // NOTE: flatbuffers are built bottom-up; after each rodata we serialize we
+ // move *backward* in the file and prepend the next, meaning that if we
+ // were to serialize all rodata we'd have it in the opposite order as we do
+ // in the IR. Though this isn't required for correctness, enabling file
+ // layout planning by preserving the order in the IR is useful.
+ SmallVector<flatbuffers_uint8_vec_ref_t, 8> rodataContentRefs;
+ rodataContentRefs.reserve(rodataOps.size());
+ for (auto rodataOp : llvm::reverse(rodataOps)) {
+ auto rodataRef =
serializeConstant(rodataOp.getLoc(), rodataOp.value(), fbb);
- if (dataOffset.IsNull()) {
- rodataOp.emitOpError() << "failed to encode";
- return {};
+ if (!rodataRef) {
+ return rodataOp.emitOpError() << "failed to encode";
}
- rodataContentOffsets.push_back(dataOffset);
+ rodataContentRefs.push_back(rodataRef);
}
+ // List of references needs to be swapped forward (we wrote backward).
+ std::reverse(rodataContentRefs.begin(), rodataContentRefs.end());
// Find all types in the module to build the type table.
// Note that we don't emit it yet as we want to keep it near the top of the
@@ -478,28 +481,32 @@
}
// Serialize function bytecode one at a time and then merge at the end.
- std::vector<std::vector<uint8_t>> bytecodeDataParts;
- std::vector<iree::vm::FunctionDescriptor> functionDescriptors;
- bytecodeDataParts.reserve(internalFuncOps.size());
- functionDescriptors.reserve(internalFuncOps.size());
+ SmallVector<std::vector<uint8_t>, 8> bytecodeDataParts;
+ SmallVector<iree_vm_FunctionDescriptor_t, 8> functionDescriptors;
+ bytecodeDataParts.resize(internalFuncOps.size());
+ functionDescriptors.resize(internalFuncOps.size());
size_t totalBytecodeLength = 0;
- for (auto funcOp : internalFuncOps) {
- auto encodedFunction =
- BytecodeEncoder::encodeFunction(funcOp, typeOrdinalMap, symbolTable);
+ for (auto funcOp : llvm::enumerate(internalFuncOps)) {
+ auto encodedFunction = BytecodeEncoder::encodeFunction(
+ funcOp.value(), typeOrdinalMap, symbolTable);
if (!encodedFunction) {
- funcOp.emitError() << "failed to encode function bytecode";
- return {};
+ return funcOp.value().emitError() << "failed to encode function bytecode";
}
- functionDescriptors.push_back(iree::vm::FunctionDescriptor(
- totalBytecodeLength, encodedFunction->bytecodeData.size(),
- encodedFunction->i32RegisterCount, encodedFunction->refRegisterCount));
+ iree_vm_FunctionDescriptor_assign(
+ &functionDescriptors[funcOp.index()], totalBytecodeLength,
+ encodedFunction->bytecodeData.size(), encodedFunction->i32RegisterCount,
+ encodedFunction->refRegisterCount);
totalBytecodeLength += encodedFunction->bytecodeData.size();
- bytecodeDataParts.push_back(std::move(encodedFunction->bytecodeData));
+ bytecodeDataParts[funcOp.index()] =
+ std::move(encodedFunction->bytecodeData);
}
- // TODO(benvanik): compression? deduping?
- uint8_t *bytecodeDataPtr = nullptr;
- auto bytecodeDataOffset = fbb.CreateUninitializedVector<uint8_t>(
- totalBytecodeLength, &bytecodeDataPtr);
+ flatbuffers_uint8_vec_start(fbb);
+ uint8_t *bytecodeDataPtr =
+ flatbuffers_uint8_vec_extend(fbb, totalBytecodeLength);
+ // NOTE: we need to ensure we clear the output data in case we have gaps for
+ // alignment (where otherwise uninitialized memory might sneak in and be bad
+ // for both security and determinism).
+ memset(bytecodeDataPtr, 0, totalBytecodeLength);
size_t currentBytecodeOffset = 0;
for (const auto &it : llvm::enumerate(internalFuncOps)) {
int ordinal = it.index();
@@ -508,102 +515,109 @@
data.size());
currentBytecodeOffset += data.size();
}
+ auto bytecodeDataRef = flatbuffers_uint8_vec_end(fbb);
+
+ // Encode the function descriptors adjacent to the bytecode data; they are
+ // always accessed together. Descriptor 0 is likely within a few hundred bytes
+ // of the referenced bytecode data offset 0, and from there we are at least
+ // able to hope sequential readahead caching helps; if not, at least we
+ // hopefully don't fault on the first function call every time.
+ auto functionDescriptorsRef = iree_vm_FunctionDescriptor_vec_create(
+ fbb, functionDescriptors.data(), functionDescriptors.size());
// Serialize metadata that should be near the front of the file.
- std::vector<Offset<iree::vm::RodataSegmentDef>> rodataSegmentOffsets;
- rodataSegmentOffsets.reserve(rodataOps.size());
- for (auto rodataContentOffset : rodataContentOffsets) {
- iree::vm::RodataSegmentDefBuilder rsd(fbb);
- rsd.add_data(rodataContentOffset);
- rodataSegmentOffsets.push_back(rsd.Finish());
- }
- std::vector<Offset<iree::vm::RwdataSegmentDef>> rwdataSegmentOffsets;
- std::vector<Offset<iree::vm::TypeDef>> typeOffsets;
- typeOffsets.reserve(typeTable.size());
- for (auto &typeDef : typeTable) {
- auto nameOffset = fbb.CreateString(typeDef.full_name);
- iree::vm::TypeDefBuilder tdb(fbb);
- tdb.add_full_name(nameOffset);
- typeOffsets.push_back(tdb.Finish());
- }
- std::vector<Offset<iree::vm::ImportFunctionDef>> importFuncOffsets;
- importFuncOffsets.reserve(importFuncOps.size());
- for (auto importOp : importFuncOps) {
- auto nameOffset = fbb.CreateString(importOp.getName().str());
- auto signatureOffset =
- makeImportFunctionSignatureDef(importOp, typeOrdinalMap, fbb);
- iree::vm::ImportFunctionDefBuilder ifd(fbb);
- ifd.add_full_name(nameOffset);
- ifd.add_signature(signatureOffset);
- importFuncOffsets.push_back(ifd.Finish());
- }
- std::vector<Offset<iree::vm::ExportFunctionDef>> exportFuncOffsets;
- exportFuncOffsets.reserve(exportFuncOps.size());
- for (auto exportOp : exportFuncOps) {
- auto nameOffset = fbb.CreateString(exportOp.export_name().str());
- auto funcOp = symbolTable.lookup<IREE::VM::FuncOp>(exportOp.function_ref());
- auto signatureOffset =
- makeExportFunctionSignatureDef(exportOp, funcOp, typeOrdinalMap, fbb);
- iree::vm::ExportFunctionDefBuilder efd(fbb);
- efd.add_local_name(nameOffset);
- efd.add_signature(signatureOffset);
- efd.add_internal_ordinal(funcOp.ordinal().getValue().getLimitedValue());
- exportFuncOffsets.push_back(efd.Finish());
- }
- std::vector<Offset<iree::vm::InternalFunctionDef>> internalFuncOffsets;
+ auto rodataSegmentRefs = llvm::to_vector<8>(
+ llvm::map_range(rodataContentRefs, [&](auto rodataContentRef) {
+ iree_vm_RodataSegmentDef_start(fbb);
+ iree_vm_RodataSegmentDef_data_add(fbb, rodataContentRef);
+ return iree_vm_RodataSegmentDef_end(fbb);
+ }));
+ SmallVector<iree_vm_RwdataSegmentDef_ref_t, 8> rwdataSegmentRefs;
+ // NOTE: rwdata currently unused.
+ auto typeRefs =
+ llvm::to_vector<8>(llvm::map_range(typeTable, [&](auto typeDef) {
+ auto fullNameRef = fbb.createString(typeDef.full_name);
+ iree_vm_TypeDef_start(fbb);
+ iree_vm_TypeDef_full_name_add(fbb, fullNameRef);
+ return iree_vm_TypeDef_end(fbb);
+ }));
+ auto importFuncRefs =
+ llvm::to_vector<8>(llvm::map_range(importFuncOps, [&](auto importOp) {
+ auto fullNameRef = fbb.createString(importOp.getName());
+ auto signatureRef =
+ makeImportFunctionSignatureDef(importOp, typeOrdinalMap, fbb);
+ iree_vm_ImportFunctionDef_start(fbb);
+ iree_vm_ImportFunctionDef_full_name_add(fbb, fullNameRef);
+ iree_vm_ImportFunctionDef_signature_add(fbb, signatureRef);
+ return iree_vm_ImportFunctionDef_end(fbb);
+ }));
+ auto exportFuncRefs =
+ llvm::to_vector<8>(llvm::map_range(exportFuncOps, [&](auto exportOp) {
+ auto localNameRef = fbb.createString(exportOp.export_name());
+ auto funcOp =
+ symbolTable.lookup<IREE::VM::FuncOp>(exportOp.function_ref());
+ auto signatureRef = makeExportFunctionSignatureDef(exportOp, funcOp,
+ typeOrdinalMap, fbb);
+ iree_vm_ExportFunctionDef_start(fbb);
+ iree_vm_ExportFunctionDef_local_name_add(fbb, localNameRef);
+ iree_vm_ExportFunctionDef_signature_add(fbb, signatureRef);
+ iree_vm_ExportFunctionDef_internal_ordinal_add(
+ fbb, funcOp.ordinal().getValue().getLimitedValue());
+ return iree_vm_ExportFunctionDef_end(fbb);
+ }));
+ SmallVector<iree_vm_InternalFunctionDef_ref_t, 8> internalFuncRefs;
if (!targetOptions.stripSymbols) {
- internalFuncOffsets.reserve(internalFuncOps.size());
+ internalFuncRefs.reserve(internalFuncOps.size());
for (auto funcOp : internalFuncOps) {
- auto nameOffset = fbb.CreateString(funcOp.getName().str());
- auto signatureOffset =
+ auto localNameRef = fbb.createString(funcOp.getName());
+ auto signatureRef =
makeInternalFunctionSignatureDef(funcOp, typeOrdinalMap, fbb);
- iree::vm::InternalFunctionDefBuilder ifd(fbb);
- ifd.add_local_name(nameOffset);
- ifd.add_signature(signatureOffset);
- internalFuncOffsets.push_back(ifd.Finish());
+ iree_vm_InternalFunctionDef_start(fbb);
+ iree_vm_InternalFunctionDef_local_name_add(fbb, localNameRef);
+ iree_vm_InternalFunctionDef_signature_add(fbb, signatureRef);
+ internalFuncRefs.push_back(iree_vm_InternalFunctionDef_end(fbb));
}
}
- auto functionDescriptorsOffset =
- fbb.CreateVectorOfStructs(functionDescriptors);
- auto rodataSegmentsOffset = createOptionalVector(rodataSegmentOffsets, fbb);
- auto rwdataSegmentsOffset = createOptionalVector(rwdataSegmentOffsets, fbb);
- auto internalFuncsOffset = fbb.CreateVector(internalFuncOffsets);
- auto exportFuncsOffset = fbb.CreateVector(exportFuncOffsets);
- auto importFuncsOffset = createOptionalVector(importFuncOffsets, fbb);
- auto typesOffset = fbb.CreateVector(typeOffsets);
+ // NOTE: we keep the vectors clustered here so that we can hopefully keep the
+ // pages mapped at runtime; vector dereferences in flatbuffers require
+ // touching these structs to get length/etc and as such we don't want to be
+ // gathering from all over the file (with giant rodata chunks and such in
+ // between) just to perform a bounds check and dereference into another
+ // part of the file.
+ auto rodataSegmentsRef = fbb.createOffsetVecDestructive(rodataSegmentRefs);
+ auto rwdataSegmentsRef = fbb.createOffsetVecDestructive(rwdataSegmentRefs);
+ auto internalFuncsRef = fbb.createOffsetVecDestructive(internalFuncRefs);
+ auto exportFuncsOffset = fbb.createOffsetVecDestructive(exportFuncRefs);
+ auto importFuncsRef = fbb.createOffsetVecDestructive(importFuncRefs);
+ auto typesRef = fbb.createOffsetVecDestructive(typeRefs);
- Optional<Offset<iree::vm::ModuleStateDef>> moduleStateDef;
+ iree_vm_ModuleStateDef_ref_t moduleStateDef = 0;
if (symbolCounts.globalBytes || symbolCounts.globalRefs) {
- iree::vm::ModuleStateDefBuilder msd(fbb);
- msd.add_global_bytes_capacity(symbolCounts.globalBytes);
- msd.add_global_ref_count(symbolCounts.globalRefs);
- moduleStateDef = msd.Finish();
+ iree_vm_ModuleStateDef_start(fbb);
+ iree_vm_ModuleStateDef_global_bytes_capacity_add(fbb,
+ symbolCounts.globalBytes);
+ iree_vm_ModuleStateDef_global_ref_count_add(fbb, symbolCounts.globalRefs);
+ moduleStateDef = iree_vm_ModuleStateDef_end(fbb);
}
- auto nameOffset = fbb.CreateString(
- moduleOp.sym_name().empty() ? "module" : moduleOp.sym_name().str());
+ auto moduleNameRef = fbb.createString(
+ moduleOp.sym_name().empty() ? "module" : moduleOp.sym_name());
- iree::vm::BytecodeModuleDefBuilder bmd(fbb);
- bmd.add_name(nameOffset);
- bmd.add_types(typesOffset);
- if (importFuncsOffset) {
- bmd.add_imported_functions(importFuncsOffset.getValue());
- }
- bmd.add_exported_functions(exportFuncsOffset);
- bmd.add_internal_functions(internalFuncsOffset);
- if (moduleStateDef) {
- bmd.add_module_state(moduleStateDef.getValue());
- }
- if (rwdataSegmentsOffset) {
- bmd.add_rwdata_segments(rwdataSegmentsOffset.getValue());
- }
- if (rodataSegmentsOffset) {
- bmd.add_rodata_segments(rodataSegmentsOffset.getValue());
- }
- bmd.add_function_descriptors(functionDescriptorsOffset);
- bmd.add_bytecode_data(bytecodeDataOffset);
- return bmd.Finish();
+ iree_vm_BytecodeModuleDef_start_as_root(fbb);
+ iree_vm_BytecodeModuleDef_name_add(fbb, moduleNameRef);
+ iree_vm_BytecodeModuleDef_types_add(fbb, typesRef);
+ iree_vm_BytecodeModuleDef_imported_functions_add(fbb, importFuncsRef);
+ iree_vm_BytecodeModuleDef_exported_functions_add(fbb, exportFuncsOffset);
+ iree_vm_BytecodeModuleDef_internal_functions_add(fbb, internalFuncsRef);
+ iree_vm_BytecodeModuleDef_module_state_add(fbb, moduleStateDef);
+ iree_vm_BytecodeModuleDef_rodata_segments_add(fbb, rodataSegmentsRef);
+ iree_vm_BytecodeModuleDef_rwdata_segments_add(fbb, rwdataSegmentsRef);
+ iree_vm_BytecodeModuleDef_function_descriptors_add(fbb,
+ functionDescriptorsRef);
+ iree_vm_BytecodeModuleDef_bytecode_data_add(fbb, bytecodeDataRef);
+ iree_vm_BytecodeModuleDef_end_as_root(fbb);
+ return success();
}
LogicalResult translateModuleToBytecode(IREE::VM::ModuleOp moduleOp,
@@ -639,28 +653,30 @@
// the module header in memory. This ensures that when we map the file only
// the first few pages need to be accessed to get the metadata and the rest
// can be large bulk data.
- FlatBufferBuilder fbb;
- auto moduleDef = buildFlatBufferModule(targetOptions, moduleOp, fbb);
- if (moduleDef.IsNull()) {
+ FlatbufferBuilder fbb;
+ if (failed(buildFlatBufferModule(targetOptions, moduleOp, fbb))) {
return moduleOp.emitError()
<< "failed to build FlatBuffer BytecodeModuleDef";
}
- iree::vm::FinishBytecodeModuleDefBuffer(fbb, moduleDef);
- const uint8_t *flatbufferBytes = fbb.GetBufferPointer();
- size_t flatbufferByteSize = fbb.GetSize();
-
switch (targetOptions.outputFormat) {
case BytecodeOutputFormat::kFlatBufferBinary:
- output.write(reinterpret_cast<const char *>(flatbufferBytes),
- flatbufferByteSize);
+ if (failed(fbb.copyToStream(output))) {
+ return moduleOp.emitError()
+ << "failed to copy flatbuffer emitter contents to output stream "
+ "- possibly out of memory";
+ }
break;
case BytecodeOutputFormat::kFlatBufferText: {
- flatbuffers::ToStringVisitor toStringVisitor("\n", false, " ", false);
- flatbuffers::IterateFlatBuffer(flatbufferBytes,
- iree::vm::BytecodeModuleDefTypeTable(),
- &toStringVisitor);
- output << toStringVisitor.s << "\n";
+ if (failed(fbb.printJsonToStream(/*pretty=*/true,
+ /*includeDefaults=*/false,
+ bytecode_module_def_print_json,
+ output))) {
+ return moduleOp.emitError()
+ << "failed to print flatbuffer emitter contents to output "
+ "stream - possibly out of memory, possibly unprintable "
+ "structure";
+ }
break;
}
default:
diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/CMakeLists.txt b/iree/compiler/Dialect/VM/Target/Bytecode/CMakeLists.txt
index 73aeb14..24a13c5 100644
--- a/iree/compiler/Dialect/VM/Target/Bytecode/CMakeLists.txt
+++ b/iree/compiler/Dialect/VM/Target/Bytecode/CMakeLists.txt
@@ -35,12 +35,12 @@
MLIRSupport
MLIRTransforms
MLIRTranslation
- flatbuffers
iree::compiler::Dialect::IREE::IR
iree::compiler::Dialect::IREE::Transforms
iree::compiler::Dialect::VM::Analysis
iree::compiler::Dialect::VM::IR
iree::compiler::Dialect::VM::Transforms
- iree::schemas::bytecode_module_def_cc_fbs
+ iree::compiler::Utils
+ iree::schemas::bytecode_module_def_c_fbs
PUBLIC
)
diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.cpp b/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.cpp
index 180535a..64e1079 100644
--- a/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.cpp
+++ b/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.cpp
@@ -14,7 +14,6 @@
#include "iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.h"
-#include "flatbuffers/flatbuffers.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/StandardTypes.h"
@@ -24,24 +23,16 @@
namespace IREE {
namespace VM {
-namespace {
-
-using flatbuffers::FlatBufferBuilder;
-using flatbuffers::Offset;
-using flatbuffers::Vector;
-
-} // namespace
-
// TODO(benvanik): switch to LLVM's BinaryStreamWriter to handle endianness.
-static Offset<Vector<uint8_t>> serializeConstantI8Array(
- DenseIntElementsAttr attr, FlatBufferBuilder &fbb) {
+static flatbuffers_uint8_vec_ref_t serializeConstantI8Array(
+ DenseIntElementsAttr attr, FlatbufferBuilder &fbb) {
// vm.rodata and other very large constants end up as this; since i8 is i8
// everywhere (endianness doesn't matter when you have one byte :) we can
// directly access the data and memcpy.
- uint8_t *bytePtr = nullptr;
- auto byteVector =
- fbb.CreateUninitializedVector(attr.getNumElements() * 1, &bytePtr);
+ flatbuffers_uint8_vec_start(fbb);
+ uint8_t *bytePtr =
+ flatbuffers_uint8_vec_extend(fbb, attr.getNumElements() * sizeof(int8_t));
if (attr.isSplat()) {
// NOTE: this is a slow path and we should have eliminated it earlier on
// during constant op conversion.
@@ -52,72 +43,72 @@
auto rawData = attr.getRawData();
std::memcpy(bytePtr, rawData.data(), rawData.size());
}
- return byteVector;
+ return flatbuffers_uint8_vec_end(fbb);
}
-static Offset<Vector<uint8_t>> serializeConstantI16Array(
- DenseIntElementsAttr attr, FlatBufferBuilder &fbb) {
- uint8_t *bytePtr = nullptr;
- auto byteVector =
- fbb.CreateUninitializedVector(attr.getNumElements() * 2, &bytePtr);
+static flatbuffers_uint8_vec_ref_t serializeConstantI16Array(
+ DenseIntElementsAttr attr, FlatbufferBuilder &fbb) {
+ flatbuffers_uint8_vec_start(fbb);
+ uint8_t *bytePtr = flatbuffers_uint8_vec_extend(
+ fbb, attr.getNumElements() * sizeof(int16_t));
uint16_t *nativePtr = reinterpret_cast<uint16_t *>(bytePtr);
for (const APInt &value : attr.getIntValues()) {
*(nativePtr++) = value.extractBitsAsZExtValue(16, 0) & UINT16_MAX;
}
- return byteVector;
+ return flatbuffers_uint8_vec_end(fbb);
}
-static Offset<Vector<uint8_t>> serializeConstantI32Array(
- DenseIntElementsAttr attr, FlatBufferBuilder &fbb) {
- uint8_t *bytePtr = nullptr;
- auto byteVector =
- fbb.CreateUninitializedVector(attr.getNumElements() * 4, &bytePtr);
+static flatbuffers_uint8_vec_ref_t serializeConstantI32Array(
+ DenseIntElementsAttr attr, FlatbufferBuilder &fbb) {
+ flatbuffers_uint8_vec_start(fbb);
+ uint8_t *bytePtr = flatbuffers_uint8_vec_extend(
+ fbb, attr.getNumElements() * sizeof(int32_t));
uint32_t *nativePtr = reinterpret_cast<uint32_t *>(bytePtr);
for (const APInt &value : attr.getIntValues()) {
*(nativePtr++) = value.extractBitsAsZExtValue(32, 0) & UINT32_MAX;
}
- return byteVector;
+ return flatbuffers_uint8_vec_end(fbb);
}
-static Offset<Vector<uint8_t>> serializeConstantI64Array(
- DenseIntElementsAttr attr, FlatBufferBuilder &fbb) {
- uint8_t *bytePtr = nullptr;
- auto byteVector =
- fbb.CreateUninitializedVector(attr.getNumElements() * 8, &bytePtr);
+static flatbuffers_uint8_vec_ref_t serializeConstantI64Array(
+ DenseIntElementsAttr attr, FlatbufferBuilder &fbb) {
+ flatbuffers_uint8_vec_start(fbb);
+ uint8_t *bytePtr = flatbuffers_uint8_vec_extend(
+ fbb, attr.getNumElements() * sizeof(int64_t));
uint64_t *nativePtr = reinterpret_cast<uint64_t *>(bytePtr);
for (const APInt &value : attr.getIntValues()) {
*(nativePtr++) = value.extractBitsAsZExtValue(64, 0) & UINT64_MAX;
}
- return byteVector;
+ return flatbuffers_uint8_vec_end(fbb);
}
-static Offset<Vector<uint8_t>> serializeConstantF32Array(
- DenseFPElementsAttr attr, FlatBufferBuilder &fbb) {
- uint8_t *bytePtr = nullptr;
- auto byteVector =
- fbb.CreateUninitializedVector(attr.getNumElements() * 4, &bytePtr);
+static flatbuffers_uint8_vec_ref_t serializeConstantF32Array(
+ DenseFPElementsAttr attr, FlatbufferBuilder &fbb) {
+ flatbuffers_uint8_vec_start(fbb);
+ uint8_t *bytePtr =
+ flatbuffers_uint8_vec_extend(fbb, attr.getNumElements() * sizeof(float));
float *nativePtr = reinterpret_cast<float *>(bytePtr);
for (const APFloat &value : attr.getFloatValues()) {
*(nativePtr++) = value.convertToFloat();
}
- return byteVector;
+ return flatbuffers_uint8_vec_end(fbb);
}
-static Offset<Vector<uint8_t>> serializeConstantF64Array(
- DenseFPElementsAttr attr, FlatBufferBuilder &fbb) {
- uint8_t *bytePtr = nullptr;
- auto byteVector =
- fbb.CreateUninitializedVector(attr.getNumElements() * 8, &bytePtr);
+static flatbuffers_uint8_vec_ref_t serializeConstantF64Array(
+ DenseFPElementsAttr attr, FlatbufferBuilder &fbb) {
+ flatbuffers_uint8_vec_start(fbb);
+ uint8_t *bytePtr =
+ flatbuffers_uint8_vec_extend(fbb, attr.getNumElements() * sizeof(double));
double *nativePtr = reinterpret_cast<double *>(bytePtr);
for (const APFloat &value : attr.getFloatValues()) {
*(nativePtr++) = value.convertToDouble();
}
- return byteVector;
+ return flatbuffers_uint8_vec_end(fbb);
}
-Offset<Vector<uint8_t>> serializeConstant(Location loc,
- ElementsAttr elementsAttr,
- FlatBufferBuilder &fbb) {
+flatbuffers_uint8_vec_ref_t serializeConstant(Location loc,
+ ElementsAttr elementsAttr,
+ FlatbufferBuilder &fbb) {
if (auto attr = elementsAttr.dyn_cast<DenseIntElementsAttr>()) {
switch (attr.getType().getElementTypeBitWidth()) {
case 8:
diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.h b/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.h
index 04d62a3..56471a6 100644
--- a/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.h
+++ b/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.h
@@ -15,7 +15,8 @@
#ifndef IREE_COMPILER_DIALECT_VM_TARGET_BYTECODE_CONSTANTENCODER_H_
#define IREE_COMPILER_DIALECT_VM_TARGET_BYTECODE_CONSTANTENCODER_H_
-#include "flatbuffers/flatbuffers.h"
+#include "iree/compiler/Utils/FlatbufferUtils.h"
+#include "iree/schemas/bytecode_module_def_builder.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Location.h"
@@ -25,9 +26,9 @@
namespace VM {
// Serializes a constant attribute to the FlatBuffer as a binary blob.
-flatbuffers::Offset<flatbuffers::Vector<uint8_t>> serializeConstant(
- Location loc, ElementsAttr elementsAttr,
- flatbuffers::FlatBufferBuilder &fbb);
+flatbuffers_uint8_vec_ref_t serializeConstant(Location loc,
+ ElementsAttr elementsAttr,
+ FlatbufferBuilder &fbb);
} // namespace VM
} // namespace IREE
diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/test/constant_encoding.mlir b/iree/compiler/Dialect/VM/Target/Bytecode/test/constant_encoding.mlir
index 4389809..fc4d1a2 100644
--- a/iree/compiler/Dialect/VM/Target/Bytecode/test/constant_encoding.mlir
+++ b/iree/compiler/Dialect/VM/Target/Bytecode/test/constant_encoding.mlir
@@ -1,20 +1,50 @@
// RUN: iree-translate -split-input-file -iree-vm-ir-to-bytecode-module -iree-vm-bytecode-module-output-format=flatbuffer-text %s | IreeFileCheck %s
-// CHECK: name: "constants"
+// CHECK: "name": "constants"
vm.module @constants {
vm.export @func
vm.func @func() {
vm.return
}
- // CHECK: rodata_segments: [ {
+ // CHECK: "rodata_segments": [{
- // CHECK: data: [ 1, 2, 3 ]
+ // CHECK: "data": [
+ // CHECK-NEXT: 1,
+ // CHECK-NEXT: 2,
+ // CHECK-NEXT: 3
+ // CHECK-NEXT: ]
vm.rodata @dense_i8s dense<[1, 2, 3]> : tensor<3xi8>
- // CHECK: data: [ 0, 0, 128, 63, 0, 0, 0, 64, 0, 0, 64, 64 ]
+ // CHECK: "data": [
+ // CHECK-NEXT: 0,
+ // CHECK-NEXT: 0,
+ // CHECK-NEXT: 128,
+ // CHECK-NEXT: 63,
+ // CHECK-NEXT: 0,
+ // CHECK-NEXT: 0,
+ // CHECK-NEXT: 0,
+ // CHECK-NEXT: 64,
+ // CHECK-NEXT: 0,
+ // CHECK-NEXT: 0,
+ // CHECK-NEXT: 64,
+ // CHECK-NEXT: 64
+ // CHECK-NEXT: ]
vm.rodata @dense_float32s dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<3xf32>
- // CHECK: data: [ 0, 0, 128, 63, 0, 0, 128, 63, 0, 0, 128, 63 ]
+ // CHECK: "data": [
+ // CHECK-NEXT: 0,
+ // CHECK-NEXT: 0,
+ // CHECK-NEXT: 128,
+ // CHECK-NEXT: 63,
+ // CHECK-NEXT: 0,
+ // CHECK-NEXT: 0,
+ // CHECK-NEXT: 128,
+ // CHECK-NEXT: 63,
+ // CHECK-NEXT: 0,
+ // CHECK-NEXT: 0,
+ // CHECK-NEXT: 128,
+ // CHECK-NEXT: 63
+ // CHECK-NEXT: ]
vm.rodata @splat_float32s dense<1.000000e+00> : tensor<3xf32>
}
diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/test/module_encoding_smoke.mlir b/iree/compiler/Dialect/VM/Target/Bytecode/test/module_encoding_smoke.mlir
index 18caab9..6c5e100 100644
--- a/iree/compiler/Dialect/VM/Target/Bytecode/test/module_encoding_smoke.mlir
+++ b/iree/compiler/Dialect/VM/Target/Bytecode/test/module_encoding_smoke.mlir
@@ -1,24 +1,32 @@
// RUN: iree-translate -split-input-file -iree-vm-ir-to-bytecode-module -iree-vm-bytecode-module-output-format=flatbuffer-text %s | IreeFileCheck %s
-// CHECK: name: "simple_module"
+// CHECK: "name": "simple_module"
vm.module @simple_module {
- // CHECK: types: [ {
- // CHECK: full_name: "i32"
+ // CHECK: "types": [{
+ // CHECK: "full_name": "i32"
- // CHECK: exported_functions:
- // CHECK: local_name: "func"
+ // CHECK: "exported_functions":
+ // CHECK: "local_name": "func"
vm.export @func
- // CHECK: internal_functions:
- // CHECK: local_name: "func"
+ // CHECK: "internal_functions":
+ // CHECK: "local_name": "func"
vm.func @func(%arg0 : i32) -> i32 {
vm.return %arg0 : i32
}
- // CHECK: function_descriptors:
- // CHECK-NEXT: bytecode_offset: 0
- // CHECK-NEXT: bytecode_length: 5
- // CHECK-NEXT: i32_register_count: 1
- // CHECK-NEXT: ref_register_count: 0
- // CHECK: bytecode_data: [ 84, 1, 0, 0, 0 ]
+ // CHECK: "function_descriptors":
+ // CHECK-NEXT: {
+ // CHECK-NEXT: "bytecode_offset": 0
+ // CHECK-NEXT: "bytecode_length": 5
+ // CHECK-NEXT: "i32_register_count": 1
+ // CHECK-NEXT: "ref_register_count": 0
+ // CHECK-NEXT: }
+ // CHECK: "bytecode_data": [
+ // CHECK-NEXT: 84,
+ // CHECK-NEXT: 1,
+ // CHECK-NEXT: 0,
+ // CHECK-NEXT: 0,
+ // CHECK-NEXT: 0
+ // CHECK-NEXT: ]
}
diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/test/reflection_attrs.mlir b/iree/compiler/Dialect/VM/Target/Bytecode/test/reflection_attrs.mlir
index 30c0d8b..0c69043 100644
--- a/iree/compiler/Dialect/VM/Target/Bytecode/test/reflection_attrs.mlir
+++ b/iree/compiler/Dialect/VM/Target/Bytecode/test/reflection_attrs.mlir
@@ -3,12 +3,12 @@
// CHECK-LABEL: simple_module
vm.module @simple_module {
vm.export @func
- // CHECK: internal_functions:
- // CHECK: reflection_attrs:
- // CHECK: key: "f"
- // CHECK: value: "FOOBAR"
- // CHECK: key: "fv"
- // CHECK: value: "INFINITY"
+ // CHECK: "internal_functions":
+ // CHECK: "reflection_attrs":
+ // CHECK: "key": "f"
+ // CHECK: "value": "FOOBAR"
+ // CHECK: "key": "fv"
+ // CHECK: "value": "INFINITY"
vm.func @func(%arg0 : i32) -> i32
attributes { iree.reflection = { f = "FOOBAR", fv = "INFINITY" } }
{
diff --git a/iree/compiler/Dialect/VM/Tools/BUILD b/iree/compiler/Dialect/VM/Tools/BUILD
index b31b9f9..1b0c741 100644
--- a/iree/compiler/Dialect/VM/Tools/BUILD
+++ b/iree/compiler/Dialect/VM/Tools/BUILD
@@ -18,17 +18,10 @@
licenses = ["notice"], # Apache 2.0
)
-cc_library(
- name = "Tools",
+filegroup(
+ name = "GenSrcs",
srcs = [
"VMOpEncoderGen.cpp",
"VMOpTableGen.cpp",
],
- deps = [
- "@llvm-project//llvm:Support",
- "@llvm-project//llvm:TableGen",
- "@llvm-project//mlir:Support",
- "@llvm-project//mlir:TableGen",
- ],
- alwayslink = 1,
)
diff --git a/iree/compiler/Dialect/VM/Tools/CMakeLists.txt b/iree/compiler/Dialect/VM/Tools/CMakeLists.txt
index d13afd5..8b864e5 100644
--- a/iree/compiler/Dialect/VM/Tools/CMakeLists.txt
+++ b/iree/compiler/Dialect/VM/Tools/CMakeLists.txt
@@ -13,18 +13,3 @@
# limitations under the License.
iree_add_all_subdirs()
-
-iree_cc_library(
- NAME
- Tools
- SRCS
- "VMOpEncoderGen.cpp"
- "VMOpTableGen.cpp"
- DEPS
- LLVMSupport
- LLVMTableGen
- MLIRSupport
- MLIRTableGen
- ALWAYSLINK
- PUBLIC
-)
diff --git a/iree/compiler/Dialect/VM/Transforms/HoistInlinedRodata.cpp b/iree/compiler/Dialect/VM/Transforms/HoistInlinedRodata.cpp
index 1c8204f..3261758 100644
--- a/iree/compiler/Dialect/VM/Transforms/HoistInlinedRodata.cpp
+++ b/iree/compiler/Dialect/VM/Transforms/HoistInlinedRodata.cpp
@@ -60,8 +60,7 @@
(funcOp.getName() + "_const").str(),
inlineOp.value());
moduleSymbolTable.insert(rodataOp, moduleBuilder.getInsertionPoint());
- SymbolTable::setSymbolVisibility(rodataOp,
- SymbolTable::Visibility::Private);
+ rodataOp.setPrivate();
replaceInlineOpWithRodataRef(inlineOp, rodataOp);
}
}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertConvOps.cpp b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertConvOps.cpp
index f5bc7c1..3640212 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertConvOps.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertConvOps.cpp
@@ -21,6 +21,7 @@
#include "iree/compiler/Dialect/VMLA/IR/VMLATypes.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/BitVector.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
@@ -31,7 +32,6 @@
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Transforms/DialectConversion.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
index 5b842a5..c3d2a91 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
@@ -24,6 +24,8 @@
#include "iree/compiler/Dialect/VMLA/IR/VMLAOps.h"
#include "iree/compiler/Dialect/VMLA/IR/VMLATypes.h"
#include "llvm/ADT/STLExtras.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
@@ -34,8 +36,6 @@
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Transforms/DialectConversion.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir {
namespace iree_compiler {
@@ -843,6 +843,8 @@
typeConverter, context);
patterns.insert<VMLAOpConversion<mhlo::LogOp, IREE::VMLA::LogOp>>(
typeConverter, context);
+ patterns.insert<VMLAOpConversion<mhlo::CeilOp, IREE::VMLA::CeilOp>>(
+ typeConverter, context);
patterns.insert<VMLAOpConversion<mhlo::FloorOp, IREE::VMLA::FloorOp>>(
typeConverter, context);
patterns.insert<VMLAOpConversion<mhlo::RoundOp, IREE::VMLA::RoundOp>>(
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertReductionOps.cpp b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertReductionOps.cpp
index 1a2ddd3..60fb6ff 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertReductionOps.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertReductionOps.cpp
@@ -19,6 +19,7 @@
#include "iree/compiler/Dialect/VMLA/IR/VMLADialect.h"
#include "iree/compiler/Dialect/VMLA/IR/VMLAOps.h"
#include "iree/compiler/Dialect/VMLA/IR/VMLATypes.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
@@ -29,7 +30,6 @@
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Transforms/DialectConversion.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace iree_compiler {
@@ -186,6 +186,16 @@
srcOp.getLoc(), operand, operandShape, initValue, initValueShape,
rewriter.getI32IntegerAttr(dimension), dst, dstShape,
TypeAttr::get(elementType));
+ } else if (isa<mhlo::AndOp>(computeOp)) {
+ rewriter.create<IREE::VMLA::ReduceAndOp>(
+ srcOp.getLoc(), operand, operandShape, initValue, initValueShape,
+ rewriter.getI32IntegerAttr(dimension), dst, dstShape,
+ TypeAttr::get(elementType));
+ } else if (isa<mhlo::OrOp>(computeOp)) {
+ rewriter.create<IREE::VMLA::ReduceOrOp>(
+ srcOp.getLoc(), operand, operandShape, initValue, initValueShape,
+ rewriter.getI32IntegerAttr(dimension), dst, dstShape,
+ TypeAttr::get(elementType));
} else {
computeOp.emitRemark() << "unsupported builtin reduction operation";
return failure();
diff --git a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
index a2facbc..b2adbd9 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
@@ -359,6 +359,8 @@
VMLA_TYPED_IMPORT_OP(IREE::VMLA::ReduceSumOp, "vmla.reduce.sum");
VMLA_TYPED_IMPORT_OP(IREE::VMLA::ReduceMinOp, "vmla.reduce.min");
VMLA_TYPED_IMPORT_OP(IREE::VMLA::ReduceMaxOp, "vmla.reduce.max");
+ VMLA_TYPED_IMPORT_OP(IREE::VMLA::ReduceAndOp, "vmla.reduce.and");
+ VMLA_TYPED_IMPORT_OP(IREE::VMLA::ReduceOrOp, "vmla.reduce.or");
VMLA_TYPED_IMPORT_OP(IREE::VMLA::PoolingSumOp, "vmla.pooling.sum");
VMLA_TYPED_IMPORT_OP(IREE::VMLA::PoolingMinOp, "vmla.pooling.min");
diff --git a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td b/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
index fe6185c..ae878d8 100644
--- a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
+++ b/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
@@ -601,6 +601,8 @@
def VMLA_ReduceSumOp : VMLA_ReduceOp<"reduce.sum">;
def VMLA_ReduceMinOp : VMLA_ReduceOp<"reduce.min">;
def VMLA_ReduceMaxOp : VMLA_ReduceOp<"reduce.max">;
+def VMLA_ReduceAndOp : VMLA_ReduceOp<"reduce.and">;
+def VMLA_ReduceOrOp : VMLA_ReduceOp<"reduce.or">;
class VMLA_PoolingOp<string mnemonic, list<OpTrait> traits = []> :
VMLA_ElementTypeOp<mnemonic, !listconcat(traits, [VMLA_IncludeShapes])> {
diff --git a/iree/compiler/Dialect/VMLA/Transforms/Conversion.cpp b/iree/compiler/Dialect/VMLA/Transforms/Conversion.cpp
index 5c5ce94..45d36f7 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/Conversion.cpp
+++ b/iree/compiler/Dialect/VMLA/Transforms/Conversion.cpp
@@ -25,13 +25,13 @@
#include "iree/compiler/Dialect/VMLA/IR/VMLATypes.h"
#include "iree/compiler/Dialect/VMLA/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/DialectConversion.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Dialect/VMLA/Transforms/Passes.cpp b/iree/compiler/Dialect/VMLA/Transforms/Passes.cpp
index d29b463..8061fca 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/VMLA/Transforms/Passes.cpp
@@ -18,9 +18,9 @@
#include "iree/compiler/Dialect/IREE/Transforms/Passes.h"
#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
+#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/Passes.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp b/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp
index fef7f26..3097f96 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp
+++ b/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp
@@ -21,6 +21,8 @@
#include "iree/compiler/Dialect/VMLA/IR/VMLATypes.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/BitVector.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
@@ -34,8 +36,6 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Dialect/VMLA/Transforms/UnrollReductions.cpp b/iree/compiler/Dialect/VMLA/Transforms/UnrollReductions.cpp
index 289ce23..2305342 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/UnrollReductions.cpp
+++ b/iree/compiler/Dialect/VMLA/Transforms/UnrollReductions.cpp
@@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Dialect/VMLA/vmla.imports.mlir b/iree/compiler/Dialect/VMLA/vmla.imports.mlir
index 2302fe1..5277b53 100644
--- a/iree/compiler/Dialect/VMLA/vmla.imports.mlir
+++ b/iree/compiler/Dialect/VMLA/vmla.imports.mlir
@@ -475,6 +475,20 @@
%dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...
)
+vm.import @reduce.and.i8(
+ %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
+ %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ...,
+ %dimension : i32,
+ %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...
+)
+
+vm.import @reduce.or.i8(
+ %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
+ %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ...,
+ %dimension : i32,
+ %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...
+)
+
vm.import @pooling.sum.i8(
%src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
%init : !vm.ref<!vmla.buffer>, %init_shape : i32 ...,
diff --git a/iree/compiler/Translation/test/smoketest.mlir b/iree/compiler/Translation/test/smoketest.mlir
index 69a8d97..3b19456 100644
--- a/iree/compiler/Translation/test/smoketest.mlir
+++ b/iree/compiler/Translation/test/smoketest.mlir
@@ -1,30 +1,36 @@
// RUN: iree-translate -split-input-file -iree-hal-target-backends=vmla -iree-mlir-to-vm-bytecode-module -iree-vm-bytecode-module-output-format=flatbuffer-text %s | IreeFileCheck %s
-// CHECK-LABEL: name: "simple_module"
+// CHECK-LABEL: "name": "simple_module"
module @simple_module {
-// CHECK: exported_functions:
-// CHECK: local_name: "func"
+// CHECK: "exported_functions":
+// CHECK: "local_name": "func"
-// CHECK: internal_functions:
-// CHECK: local_name: "func"
+// CHECK: "internal_functions":
+// CHECK: "local_name": "func"
func @func(%arg0 : i32) -> i32 attributes { iree.module.export } {
return %arg0 : i32
}
-// CHECK: function_descriptors:
-// CHECK-NEXT: bytecode_offset: 0
-// CHECK-NEXT: bytecode_length: 5
-// CHECK-NEXT: i32_register_count: 1
-// CHECK-NEXT: ref_register_count: 0
-// CHECK: bytecode_data: [ 84, 1, 0, 0,
+// CHECK: "function_descriptors":
+// CHECK-NEXT: {
+// CHECK-NEXT: "bytecode_offset": 0
+// CHECK-NEXT: "bytecode_length": 5
+// CHECK-NEXT: "i32_register_count": 1
+// CHECK-NEXT: "ref_register_count": 0
+// CHECK-NEXT: }
+// CHECK: "bytecode_data": [
+// CHECK-NEXT: 84,
+// CHECK-NEXT: 1,
+// CHECK-NEXT: 0,
+// CHECK-NEXT: 0,
}
// -----
-// CHECK-LABEL: name: "do_not_optimize_module"
+// CHECK-LABEL: "name": "do_not_optimize_module"
module @do_not_optimize_module {
-// CHECK: exported_functions:
-// CHECK: local_name: "add"
+// CHECK: "exported_functions":
+// CHECK: "local_name": "add"
func @add() -> i32 attributes { iree.module.export } {
%c1 = constant 1 : i32
%unf_c1 = iree.do_not_optimize(%c1) : i32
@@ -36,12 +42,12 @@
// -----
-// CHECK-LABEL: name: "hal_usage"
+// CHECK-LABEL: "name": "hal_usage"
module @hal_usage {
-// CHECK: imported_functions:
-// CHECK: full_name: "hal.command_buffer.dispatch"
-// CHECK: exported_functions:
-// CHECK: local_name: "hloElementwiseOps"
+// CHECK: "imported_functions":
+// CHECK: "full_name": "hal.command_buffer.dispatch"
+// CHECK: "exported_functions":
+// CHECK: "local_name": "hloElementwiseOps"
func @hloElementwiseOps(%arg0 : tensor<4xf32>) -> tensor<4xf32> attributes {iree.module.export} {
%0 = mhlo.add %arg0, %arg0 : tensor<4xf32>
%1 = mhlo.subtract %0, %arg0 : tensor<4xf32>
diff --git a/iree/compiler/Utils/BUILD b/iree/compiler/Utils/BUILD
index ee02d94..3f6aa07 100644
--- a/iree/compiler/Utils/BUILD
+++ b/iree/compiler/Utils/BUILD
@@ -23,13 +23,16 @@
cc_library(
name = "Utils",
srcs = [
+ "FlatbufferUtils.cpp",
"GraphUtils.cpp",
],
hdrs = [
+ "FlatbufferUtils.h",
"GraphUtils.h",
"PatternUtils.h",
],
deps = [
+ "//iree/base:flatcc",
"//iree/compiler/Dialect/IREE/IR",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
diff --git a/iree/compiler/Utils/CMakeLists.txt b/iree/compiler/Utils/CMakeLists.txt
index 6e7286e..6db54f5 100644
--- a/iree/compiler/Utils/CMakeLists.txt
+++ b/iree/compiler/Utils/CMakeLists.txt
@@ -18,9 +18,11 @@
NAME
Utils
HDRS
+ "FlatbufferUtils.h"
"GraphUtils.h"
"PatternUtils.h"
SRCS
+ "FlatbufferUtils.cpp"
"GraphUtils.cpp"
DEPS
LLVMSupport
@@ -30,6 +32,7 @@
MLIRSupport
MLIRTransformUtils
MLIRTransforms
+ iree::base::flatcc
iree::compiler::Dialect::IREE::IR
tensorflow::mlir_hlo
PUBLIC
diff --git a/iree/compiler/Utils/FlatbufferUtils.cpp b/iree/compiler/Utils/FlatbufferUtils.cpp
new file mode 100644
index 0000000..487a053
--- /dev/null
+++ b/iree/compiler/Utils/FlatbufferUtils.cpp
@@ -0,0 +1,127 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Utils/FlatbufferUtils.h"
+
+#include <vector>
+
+#include "mlir/IR/StandardTypes.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Combines all pages of the flatbuffer builder into a single contiguous byte
+// buffer and returns the result.
+//
+// NOTE: this is an alloc/copy. We need to have a single contiguous buffer to
+// pass into the elements factory function and the data we have in the
+// builder is paged. If we end up with a custom attribute type for this that
+// does not support storage uniquing then we can directly allocate and copy
+// the pages into the buffer without the extra copy.
+static SmallVector<uint8_t, 32> cloneBufferIntoContiguousBytes(
+    FlatbufferBuilder &fbb) {
+  SmallVector<uint8_t, 32> packedData(flatcc_builder_get_buffer_size(fbb));
+  void *result =
+      flatcc_builder_copy_buffer(fbb, packedData.data(), packedData.size());
+  assert(result && "flatcc_emitter_t impl failed (non-default?)");
+  (void)result;  // Assert-only; keeps NDEBUG builds warning-free.
+  return packedData;
+}
+
+// The wrapped flatcc builder is initialized here and cleared on destruction.
+FlatbufferBuilder::FlatbufferBuilder() { flatcc_builder_init(&builder); }
+
+FlatbufferBuilder::~FlatbufferBuilder() { flatcc_builder_clear(&builder); }
+
+flatbuffers_uint8_vec_ref_t FlatbufferBuilder::streamUint8Vec(
+    std::function<bool(raw_ostream &stream)> fn) {
+  flatbuffers_uint8_vec_start(*this);
+  raw_flatbuffer_uint8_vec_ostream os(*this);
+  // If |fn| reports failure return 0 without ending the vector.
+  if (!fn(os)) return 0;
+  os.flush();
+  return flatbuffers_uint8_vec_end(*this);
+}
+
+DenseIntElementsAttr FlatbufferBuilder::getBufferAttr(MLIRContext *context) {
+  // The attribute constructor inspects every byte for uniquing, so the paged
+  // builder contents are first flattened into one contiguous buffer.
+  auto bytes = cloneBufferIntoContiguousBytes(*this);
+
+  // NOTE: ew. OpaqueAttr may be better? It does equality checks but won't try
+  // to unique and would let us get a mutable buffer out.
+  auto i8Type = IntegerType::get(8, context);
+  auto bufferType =
+      VectorType::get({static_cast<int64_t>(bytes.size())}, i8Type);
+  return DenseIntElementsAttr::get(bufferType, std::move(bytes));
+}
+
+LogicalResult FlatbufferBuilder::copyToStream(llvm::raw_ostream &output) {
+  // NOTE: expected to be the default emitter.
+  auto *emit = reinterpret_cast<flatcc_emitter_t *>(
+      flatcc_builder_get_emit_context(*this));
+  if (!emit->front) return failure();
+  // One page only: the |used| bytes start at the front cursor.
+  if (emit->front == emit->back) {
+    output.write(reinterpret_cast<char *>(emit->front_cursor), emit->used);
+    return success();
+  }
+
+  // Several pages: partial front page, full middle pages, partial back page.
+  output.write(reinterpret_cast<char *>(emit->front_cursor),
+               FLATCC_EMITTER_PAGE_SIZE - emit->front_left);
+  flatcc_emitter_page_t *page = emit->front->next;
+  for (; page != emit->back; page = page->next) {
+    output.write(reinterpret_cast<char *>(page->page),
+                 FLATCC_EMITTER_PAGE_SIZE);
+  }
+  output.write(reinterpret_cast<char *>(page->page),
+               FLATCC_EMITTER_PAGE_SIZE - emit->back_left);
+  return success();
+}
+
+LogicalResult FlatbufferBuilder::printJsonToStream(
+    bool pretty, bool includeDefaults, print_json_fn_t print_json_fn,
+    llvm::raw_ostream &output) {
+  // The printer needs the flatbuffer as one contiguous run of bytes; clone it.
+  auto bytes = cloneBufferIntoContiguousBytes(*this);
+  const char *bytesBegin = reinterpret_cast<const char *>(bytes.data());
+
+  flatcc_json_printer_t jsonPrinter;
+  flatcc_json_printer_init_dynamic_buffer(&jsonPrinter, /*buffer_size=*/0);
+  flatcc_json_printer_set_indent(&jsonPrinter, pretty ? 2 : 0);
+  flatcc_json_printer_set_skip_default(&jsonPrinter, !includeDefaults);
+  flatcc_json_printer_set_force_default(&jsonPrinter, includeDefaults);
+
+  // Print into the dynamically-resizing buffer. May fail if OOM.
+  const int printResult =
+      print_json_fn(&jsonPrinter, bytesBegin, bytes.size());
+  if (printResult == -1) {
+    flatcc_json_printer_clear(&jsonPrinter);
+    return failure();
+  }
+
+  // Take the buffer from the printer; note that it is 0 terminated and can be
+  // used directly as a cstr if needed. It is released with free() below.
+  size_t jsonSize = 0;
+  char *jsonBytes = reinterpret_cast<char *>(
+      flatcc_json_printer_finalize_dynamic_buffer(&jsonPrinter, &jsonSize));
+  output.write(jsonBytes, jsonSize);
+  free(jsonBytes);
+
+  return success();
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Utils/FlatbufferUtils.h b/iree/compiler/Utils/FlatbufferUtils.h
new file mode 100644
index 0000000..301634d
--- /dev/null
+++ b/iree/compiler/Utils/FlatbufferUtils.h
@@ -0,0 +1,184 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_UTILS_FLATBUFFERUTILS_H_
+#define IREE_COMPILER_UTILS_FLATBUFFERUTILS_H_
+
+#include <functional>
+
+#include "llvm/Support/raw_ostream.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+
+// NOTE: order matters here as some of the LLVM includes conflict.
+#include "iree/base/flatcc.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// RAII wrapper for flatcc_builder_t; pass to functions requiring a builder.
+//
+// Usage:
+//   FlatbufferBuilder builder;
+//   // NOTE: flatbuffers are built bottom-up so we first generate our [uint8]:
+//   auto dataRef = builder.streamUint8Vec(...);
+//   // ... and then start the table that references it:
+//   my_type_start_as_root(builder);
+//   my_type_uint8_vec_field_add(builder, dataRef);
+//   my_type_end_as_root(builder);
+//   // ... and finally capture the results as an mlir::Attribute.
+//   auto attr = builder.getBufferAttr(mlirContext);
+class FlatbufferBuilder {
+ public:
+  FlatbufferBuilder();
+  ~FlatbufferBuilder();
+  // Not copyable: the destructor must clear |builder| exactly once.
+  FlatbufferBuilder(const FlatbufferBuilder &) = delete;
+  FlatbufferBuilder &operator=(const FlatbufferBuilder &) = delete;
+
+  operator flatcc_builder_t *() { return &builder; }
+
+  // Creates a string with the given string contents (including zeros).
+  flatbuffers_string_ref_t createString(StringRef value) {
+    if (value.empty()) return 0;
+    return flatbuffers_string_create(*this, value.data(), value.size());
+  }
+
+  // Creates a string vector containing all strings in the given range.
+  template <typename RangeTy>
+  flatbuffers_string_vec_ref_t createStringVec(RangeTy &&Range) {
+    auto stringRefs =
+        llvm::to_vector<8>(llvm::map_range(Range, [&](StringRef value) {
+          return flatbuffers_string_create(*this, value.data(), value.size());
+        }));
+    if (stringRefs.empty()) return 0;
+    return flatbuffers_string_vec_create(*this, stringRefs.data(),
+                                         stringRefs.size());
+  }
+
+  // Creates an offset vector with the given values. The source values will not
+  // be modified.
+  flatbuffers_vec_ref_t createOffsetVec(ArrayRef<flatcc_builder_ref_t> values) {
+    if (values.empty()) return 0;
+    return flatcc_builder_create_offset_vector(*this, values.data(),
+                                               values.size());
+  }
+
+  // Creates an offset vector with the given values.
+  // Unlike createOffsetVec this will destroy the input values array during
+  // serialization but be much faster.
+  flatbuffers_vec_ref_t createOffsetVecDestructive(
+      SmallVectorImpl<flatcc_builder_ref_t> &values) {
+    if (values.empty()) return 0;
+    return flatcc_builder_create_offset_vector_direct(*this, values.data(),
+                                                      values.size());
+  }
+
+  // Creates an [int32] vec with the contents of the given range.
+  template <typename RangeTy>
+  flatbuffers_int32_vec_ref_t createInt32Vec(RangeTy &&Range) {
+    if (llvm::empty(Range)) return 0;
+    flatbuffers_int32_vec_start(*this);
+    for (int32_t v : Range) flatbuffers_int32_vec_push_create(*this, v);
+    return flatbuffers_int32_vec_end(*this);
+  }
+
+  // Provides a raw_ostream that |fn| can use to directly stream into a [uint8]
+  // in the flatbuffer builder.
+  //
+  // Usage:
+  //   auto ref = builder.streamUint8Vec([&](llvm::raw_ostream &stream) {
+  //     stream << "foo";
+  //     return true;
+  //   });
+  //   ...
+  //   my_type_uint8_vec_field_add(builder, ref);  // use vec reference
+  //   ...
+  flatbuffers_uint8_vec_ref_t streamUint8Vec(
+      std::function<bool(raw_ostream &stream)> fn);
+
+  // Captures the current contents of the flatbuffer builder and returns them
+  // as a shaped `vector<SIZExi8>` dense attr. The builder is left unmodified.
+  DenseIntElementsAttr getBufferAttr(MLIRContext *context);
+
+  // Copies the current contents of the flatbuffer builder to the target output
+  // stream. The builder is left unmodified.
+  //
+  // This avoids the large allocation otherwise needed to stitch together all
+  // of the pages allocated in the emitter as the flatbuffer was constructed;
+  // here we just walk each page and write it out in order, allocation-free.
+  LogicalResult copyToStream(llvm::raw_ostream &output);
+
+  using print_json_fn_t = int (*)(flatcc_json_printer_t *ctx, const char *buf,
+                                  size_t bufsiz);
+
+  // Prints the flatbuffer in its canonical JSON format to the given stream.
+  // The builder is left unmodified.
+  //
+  // |pretty| enables newlines and indentation; somewhat useful for lit testing
+  // (as large byte buffers end up with a byte per line!).
+  //
+  // |includeDefaults| will force all values, including those that would not
+  // be serialized to the binary format due to the default value (0, etc) being
+  // omitted.
+  //
+  // NOTE: JSON representations will also differ structurally from the binary
+  // format as reused tables are printed wherever they are used as opposed to
+  // referencing the same bytes; meaning that this can't be used to verify that
+  // we are correctly memoizing strings/structures/etc.
+  LogicalResult printJsonToStream(bool pretty, bool includeDefaults,
+                                  print_json_fn_t print_json_fn,
+                                  llvm::raw_ostream &output);
+
+ private:
+  flatcc_builder_t builder;
+};
+
+// Allows streaming bytes directly into a flatbuffer `[uint8]` field.
+// The ostream runs unbuffered and routes all writes into pages allocated by
+// the flatbuffer builder; |pos| counts the bytes so tell() stays accurate.
+//
+// Usage:
+//   flatbuffers_uint8_vec_start(builder);
+//   raw_flatbuffer_uint8_vec_ostream stream(builder);
+//   stream << "foo";
+//   stream.flush();  // *********** IMPORTANT ***********
+//   flatbuffers_uint8_vec_ref_t ref = flatbuffers_uint8_vec_end(builder);
+class raw_flatbuffer_uint8_vec_ostream : public llvm::raw_ostream {
+ public:
+  explicit raw_flatbuffer_uint8_vec_ostream(flatcc_builder_t *builder)
+      : raw_ostream(/*unbuffered=*/true), builder(builder) {}
+
+  ~raw_flatbuffer_uint8_vec_ostream() override { flush(); }
+
+ private:
+  void write_impl(const char *Ptr, size_t Size) override {
+    flatbuffers_uint8_vec_append(builder,
+                                 reinterpret_cast<const uint8_t *>(Ptr), Size);
+    pos += Size;
+  }
+
+  uint64_t current_pos() const override { return pos; }
+
+  flatcc_builder_t *builder;
+  uint64_t pos = 0;
+};
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_UTILS_FLATBUFFERUTILS_H_
diff --git a/iree/hal/BUILD b/iree/hal/BUILD
index 154a8df..fbe7237 100644
--- a/iree/hal/BUILD
+++ b/iree/hal/BUILD
@@ -22,18 +22,10 @@
licenses = ["notice"], # Apache 2.0
)
-cc_library(
- name = "allocator",
- srcs = ["allocator.cc"],
- hdrs = ["allocator.h"],
- deps = [
- ":buffer",
- "//iree/base:ref_ptr",
- "//iree/base:status",
- "//iree/base:tracing",
- "@com_google_absl//absl/types:span",
- ],
-)
+#===------------------------------------------------------------------------===#
+# Public API
+#===------------------------------------------------------------------------===#
+# TODO(benvanik): rename to :hal
cc_library(
name = "api",
@@ -44,30 +36,18 @@
],
visibility = ["//visibility:public"],
deps = [
- ":api_hdrs",
- ":buffer",
- ":command_buffer",
- ":device",
- ":driver",
":driver_registry",
+ ":hal",
":heap_buffer",
- ":semaphore",
"//iree/base:api",
- "//iree/base:memory",
+ "//iree/base:core_headers",
"//iree/base:ref_ptr",
"//iree/base:tracing",
"//iree/hal/host:host_local_allocator",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
- ],
-)
-
-cc_library(
- name = "api_hdrs",
- hdrs = ["api.h"],
- deps = [
- "//iree/base:api_hdrs",
+ "@half//:includes",
],
)
@@ -76,7 +56,7 @@
srcs = ["api_string_util_test.cc"],
deps = [
":api",
- "//iree/base:memory",
+ "//iree/base:core_headers",
"//iree/base:status",
"//iree/testing:gtest",
"//iree/testing:gtest_main",
@@ -85,18 +65,52 @@
],
)
+#===------------------------------------------------------------------------===#
+# Implementation
+#===------------------------------------------------------------------------===#
+# TODO(benvanik): rename to :cc and expose via an api_cc.h.
+
cc_library(
- name = "buffer",
- srcs = ["buffer.cc"],
- hdrs = ["buffer.h"],
+ name = "hal",
+ srcs = [
+ "allocator.cc",
+ "buffer.cc",
+ "command_buffer.cc",
+ "deferred_buffer.cc",
+ "executable_cache.cc",
+ ],
+ hdrs = [
+ "allocator.h",
+ "buffer.h",
+ "command_buffer.h",
+ "command_queue.h",
+ "debug_capture_manager.h",
+ "deferred_buffer.h",
+ "descriptor_set.h",
+ "descriptor_set_layout.h",
+ "device.h",
+ "device_info.h",
+ "device_placement.h",
+ "driver.h",
+ "event.h",
+ "executable.h",
+ "executable_cache.h",
+ "executable_format.h",
+ "executable_layout.h",
+ "executable_spec.h",
+ "resource.h",
+ "semaphore.h",
+ "stack_trace.h",
+ ],
deps = [
- ":resource",
- "//iree/base:bitfield",
+ "//iree/base:core_headers",
"//iree/base:logging",
+ "//iree/base:ref_ptr",
"//iree/base:status",
+ "//iree/base:time",
+ "//iree/base:tracing",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
- "@com_google_absl//absl/types:variant",
],
)
@@ -107,7 +121,7 @@
"buffer_test.cc",
],
deps = [
- ":buffer",
+ ":hal",
":heap_buffer",
"//iree/base:status",
"//iree/testing:gtest",
@@ -116,72 +130,11 @@
],
)
-cc_library(
- name = "command_buffer",
- srcs = ["command_buffer.cc"],
- hdrs = ["command_buffer.h"],
- deps = [
- ":buffer",
- ":descriptor_set",
- ":event",
- ":executable",
- ":executable_layout",
- ":resource",
- "//iree/base:bitfield",
- "//iree/base:status",
- ],
-)
-
-cc_library(
- name = "command_buffer_validation",
- srcs = ["command_buffer_validation.cc"],
- hdrs = ["command_buffer_validation.h"],
- deps = [
- ":allocator",
- ":command_buffer",
- "//iree/base:logging",
- "//iree/base:status",
- "@com_google_absl//absl/strings",
- ],
-)
-
-cc_library(
- name = "command_queue",
- hdrs = ["command_queue.h"],
- deps = [
- ":command_buffer",
- ":semaphore",
- "//iree/base:bitfield",
- "//iree/base:status",
- "//iree/base:time",
- "@com_google_absl//absl/types:span",
- ],
-)
-
-cc_library(
- name = "debug_capture_manager",
- hdrs = ["debug_capture_manager.h"],
- deps = [
- "//iree/base:status",
- ],
-)
-
-cc_library(
- name = "deferred_buffer",
- srcs = ["deferred_buffer.cc"],
- hdrs = ["deferred_buffer.h"],
- deps = [
- ":allocator",
- ":buffer",
- "//iree/base:status",
- ],
-)
-
cc_test(
name = "deferred_buffer_test",
srcs = ["deferred_buffer_test.cc"],
deps = [
- ":deferred_buffer",
+ ":hal",
":heap_buffer",
"//iree/hal/testing:mock_allocator",
"//iree/testing:gtest",
@@ -190,68 +143,34 @@
],
)
-cc_library(
- name = "descriptor_set",
- hdrs = ["descriptor_set.h"],
- deps = [
- ":buffer",
- ":resource",
- "@com_google_absl//absl/strings",
- ],
-)
+#===------------------------------------------------------------------------===#
+# Debugging utilities and tools
+#===------------------------------------------------------------------------===#
cc_library(
- name = "descriptor_set_layout",
- hdrs = ["descriptor_set_layout.h"],
+ name = "command_buffer_validation",
+ srcs = ["command_buffer_validation.cc"],
+ hdrs = ["command_buffer_validation.h"],
deps = [
- ":buffer",
- ":resource",
- ],
-)
-
-cc_library(
- name = "device",
- hdrs = ["device.h"],
- deps = [
- ":allocator",
- ":buffer",
- ":command_queue",
- ":descriptor_set",
- ":descriptor_set_layout",
- ":device_info",
- ":event",
- ":executable_cache",
- ":executable_layout",
- ":semaphore",
- "//iree/base:ref_ptr",
+ ":hal",
+ "//iree/base:logging",
"//iree/base:status",
- "//iree/base:target_platform",
- "//iree/base:time",
- ],
-)
-
-cc_library(
- name = "device_info",
- hdrs = ["device_info.h"],
- deps = [
- "//iree/base:bitfield",
"@com_google_absl//absl/strings",
],
)
+#===------------------------------------------------------------------------===#
+# Internal device management and driver registry
+#===------------------------------------------------------------------------===#
+# TODO(benvanik): port these to C and merge into main API.
+
cc_library(
name = "device_manager",
srcs = ["device_manager.cc"],
hdrs = ["device_manager.h"],
deps = [
- ":allocator",
- ":buffer",
- ":command_queue",
- ":device",
- ":device_placement",
- ":executable_format",
+ ":hal",
":heap_buffer",
- ":semaphore",
"//iree/base:status",
"//iree/base:time",
"//iree/base:tracing",
@@ -261,29 +180,11 @@
)
cc_library(
- name = "device_placement",
- hdrs = ["device_placement.h"],
-)
-
-cc_library(
- name = "driver",
- hdrs = ["driver.h"],
- deps = [
- ":debug_capture_manager",
- ":device",
- ":device_info",
- "//iree/base:ref_ptr",
- "//iree/base:status",
- ],
-)
-
-cc_library(
name = "driver_registry",
srcs = ["driver_registry.cc"],
hdrs = ["driver_registry.h"],
deps = [
- ":driver",
- "//iree/base:initializer",
+ ":hal",
"//iree/base:ref_ptr",
"//iree/base:status",
"@com_google_absl//absl/base:core_headers",
@@ -291,87 +192,18 @@
],
)
-cc_library(
- name = "event",
- hdrs = ["event.h"],
- deps = [
- ":resource",
- ],
-)
-
-cc_library(
- name = "executable",
- hdrs = ["executable.h"],
- deps = [":resource"],
-)
-
-cc_library(
- name = "executable_cache",
- srcs = ["executable_cache.cc"],
- hdrs = ["executable_cache.h"],
- deps = [
- ":executable",
- ":executable_format",
- ":executable_layout",
- ":executable_spec",
- "//iree/base:bitfield",
- "//iree/base:ref_ptr",
- "//iree/base:status",
- ],
-)
-
-cc_library(
- name = "executable_format",
- hdrs = ["executable_format.h"],
-)
-
-cc_library(
- name = "executable_layout",
- hdrs = ["executable_layout.h"],
- deps = [":resource"],
-)
-
-cc_library(
- name = "executable_spec",
- hdrs = ["executable_spec.h"],
- deps = [
- ":executable_format",
- "@com_google_absl//absl/types:span",
- ],
-)
+#===------------------------------------------------------------------------===#
+# Internal implementation details
+#===------------------------------------------------------------------------===#
cc_library(
name = "heap_buffer",
srcs = ["heap_buffer.cc"],
hdrs = ["heap_buffer.h"],
deps = [
- ":allocator",
- ":buffer",
+ ":hal",
"//iree/base:status",
"//iree/base:tracing",
"//iree/hal/host:host_buffer",
],
)
-
-cc_library(
- name = "resource",
- hdrs = ["resource.h"],
- deps = [
- "//iree/base:ref_ptr",
- ],
-)
-
-cc_library(
- name = "semaphore",
- hdrs = ["semaphore.h"],
- deps = [
- ":resource",
- "//iree/base:status",
- "//iree/base:time",
- ],
-)
-
-cc_library(
- name = "stack_trace",
- hdrs = ["stack_trace.h"],
-)
diff --git a/iree/hal/CMakeLists.txt b/iree/hal/CMakeLists.txt
index 3804321..9d88801 100644
--- a/iree/hal/CMakeLists.txt
+++ b/iree/hal/CMakeLists.txt
@@ -16,22 +16,6 @@
iree_cc_library(
NAME
- allocator
- HDRS
- "allocator.h"
- SRCS
- "allocator.cc"
- DEPS
- ::buffer
- absl::span
- iree::base::ref_ptr
- iree::base::status
- iree::base::tracing
- PUBLIC
-)
-
-iree_cc_library(
- NAME
api
HDRS
"api.h"
@@ -39,35 +23,20 @@
SRCS
"api.cc"
DEPS
- ::api_hdrs
- ::buffer
- ::command_buffer
- ::device
- ::driver
::driver_registry
+ ::hal
::heap_buffer
- ::semaphore
absl::inlined_vector
absl::span
absl::strings
iree::base::api
- iree::base::memory
+ iree::base::core_headers
iree::base::ref_ptr
iree::base::tracing
iree::hal::host::host_local_allocator
PUBLIC
)
-iree_cc_library(
- NAME
- api_hdrs
- HDRS
- "api.h"
- DEPS
- iree::base::api_hdrs
- PUBLIC
-)
-
iree_cc_test(
NAME
api_string_util_test
@@ -77,7 +46,7 @@
::api
absl::inlined_vector
absl::strings
- iree::base::memory
+ iree::base::core_headers
iree::base::status
iree::testing::gtest
iree::testing::gtest_main
@@ -85,19 +54,44 @@
iree_cc_library(
NAME
- buffer
+ hal
HDRS
+ "allocator.h"
"buffer.h"
+ "command_buffer.h"
+ "command_queue.h"
+ "debug_capture_manager.h"
+ "deferred_buffer.h"
+ "descriptor_set.h"
+ "descriptor_set_layout.h"
+ "device.h"
+ "device_info.h"
+ "device_placement.h"
+ "driver.h"
+ "event.h"
+ "executable.h"
+ "executable_cache.h"
+ "executable_format.h"
+ "executable_layout.h"
+ "executable_spec.h"
+ "resource.h"
+ "semaphore.h"
+ "stack_trace.h"
SRCS
+ "allocator.cc"
"buffer.cc"
+ "command_buffer.cc"
+ "deferred_buffer.cc"
+ "executable_cache.cc"
DEPS
- ::resource
absl::span
absl::strings
- absl::variant
- iree::base::bitfield
+ iree::base::core_headers
iree::base::logging
+ iree::base::ref_ptr
iree::base::status
+ iree::base::time
+ iree::base::tracing
PUBLIC
)
@@ -108,7 +102,7 @@
"buffer_mapping_test.cc"
"buffer_test.cc"
DEPS
- ::buffer
+ ::hal
::heap_buffer
absl::span
iree::base::status
@@ -116,23 +110,18 @@
iree::testing::gtest_main
)
-iree_cc_library(
+iree_cc_test(
NAME
- command_buffer
- HDRS
- "command_buffer.h"
+ deferred_buffer_test
SRCS
- "command_buffer.cc"
+ "deferred_buffer_test.cc"
DEPS
- ::buffer
- ::descriptor_set
- ::event
- ::executable
- ::executable_layout
- ::resource
- iree::base::bitfield
- iree::base::status
- PUBLIC
+ ::hal
+ ::heap_buffer
+ absl::memory
+ iree::hal::testing::mock_allocator
+ iree::testing::gtest
+ iree::testing::gtest_main
)
iree_cc_library(
@@ -143,8 +132,7 @@
SRCS
"command_buffer_validation.cc"
DEPS
- ::allocator
- ::command_buffer
+ ::hal
absl::strings
iree::base::logging
iree::base::status
@@ -153,130 +141,14 @@
iree_cc_library(
NAME
- command_queue
- HDRS
- "command_queue.h"
- DEPS
- ::command_buffer
- ::semaphore
- absl::span
- iree::base::bitfield
- iree::base::status
- iree::base::time
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- debug_capture_manager
- HDRS
- "debug_capture_manager.h"
- DEPS
- iree::base::status
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- deferred_buffer
- HDRS
- "deferred_buffer.h"
- SRCS
- "deferred_buffer.cc"
- DEPS
- ::allocator
- ::buffer
- iree::base::status
- PUBLIC
-)
-
-iree_cc_test(
- NAME
- deferred_buffer_test
- SRCS
- "deferred_buffer_test.cc"
- DEPS
- ::deferred_buffer
- ::heap_buffer
- absl::memory
- iree::hal::testing::mock_allocator
- iree::testing::gtest
- iree::testing::gtest_main
-)
-
-iree_cc_library(
- NAME
- descriptor_set
- HDRS
- "descriptor_set.h"
- DEPS
- ::buffer
- ::resource
- absl::strings
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- descriptor_set_layout
- HDRS
- "descriptor_set_layout.h"
- DEPS
- ::buffer
- ::resource
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- device
- HDRS
- "device.h"
- DEPS
- ::allocator
- ::buffer
- ::command_queue
- ::descriptor_set
- ::descriptor_set_layout
- ::device_info
- ::event
- ::executable_cache
- ::executable_layout
- ::semaphore
- iree::base::ref_ptr
- iree::base::status
- iree::base::target_platform
- iree::base::time
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- device_info
- HDRS
- "device_info.h"
- DEPS
- absl::strings
- iree::base::bitfield
- PUBLIC
-)
-
-iree_cc_library(
- NAME
device_manager
HDRS
"device_manager.h"
SRCS
"device_manager.cc"
DEPS
- ::allocator
- ::buffer
- ::command_queue
- ::device
- ::device_placement
- ::executable_format
+ ::hal
::heap_buffer
- ::semaphore
absl::span
absl::synchronization
iree::base::status
@@ -287,38 +159,15 @@
iree_cc_library(
NAME
- device_placement
- HDRS
- "device_placement.h"
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- driver
- HDRS
- "driver.h"
- DEPS
- ::debug_capture_manager
- ::device
- ::device_info
- iree::base::ref_ptr
- iree::base::status
- PUBLIC
-)
-
-iree_cc_library(
- NAME
driver_registry
HDRS
"driver_registry.h"
SRCS
"driver_registry.cc"
DEPS
- ::driver
+ ::hal
absl::core_headers
absl::synchronization
- iree::base::initializer
iree::base::ref_ptr
iree::base::status
PUBLIC
@@ -326,113 +175,15 @@
iree_cc_library(
NAME
- event
- HDRS
- "event.h"
- DEPS
- ::resource
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- executable
- HDRS
- "executable.h"
- DEPS
- ::resource
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- executable_cache
- HDRS
- "executable_cache.h"
- SRCS
- "executable_cache.cc"
- DEPS
- ::executable
- ::executable_format
- ::executable_layout
- ::executable_spec
- iree::base::bitfield
- iree::base::ref_ptr
- iree::base::status
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- executable_format
- HDRS
- "executable_format.h"
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- executable_layout
- HDRS
- "executable_layout.h"
- DEPS
- ::resource
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- executable_spec
- HDRS
- "executable_spec.h"
- DEPS
- ::executable_format
- absl::span
- PUBLIC
-)
-
-iree_cc_library(
- NAME
heap_buffer
HDRS
"heap_buffer.h"
SRCS
"heap_buffer.cc"
DEPS
- ::allocator
- ::buffer
+ ::hal
iree::base::status
iree::base::tracing
iree::hal::host::host_buffer
PUBLIC
)
-
-iree_cc_library(
- NAME
- resource
- HDRS
- "resource.h"
- DEPS
- iree::base::ref_ptr
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- semaphore
- HDRS
- "semaphore.h"
- DEPS
- ::resource
- iree::base::status
- iree::base::time
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- stack_trace
- HDRS
- "stack_trace.h"
- PUBLIC
-)
diff --git a/iree/hal/api.cc b/iree/hal/api.cc
index 1adc1ff..956ac7b 100644
--- a/iree/hal/api.cc
+++ b/iree/hal/api.cc
@@ -39,6 +39,7 @@
#include "iree/hal/heap_buffer.h"
#include "iree/hal/host/host_local_allocator.h"
#include "iree/hal/semaphore.h"
+#include "third_party/half/half.hpp"
namespace iree {
namespace hal {
@@ -294,9 +295,16 @@
reinterpret_cast<uint64_t*>(out_data))
? iree_ok_status()
: iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT);
- case IREE_HAL_ELEMENT_TYPE_FLOAT_16:
- return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
- "float16 parsing not implemented");
+ case IREE_HAL_ELEMENT_TYPE_FLOAT_16: {
+ float temp = 0;
+ if (!absl::SimpleAtof(absl::string_view(data_str.data, data_str.size),
+ &temp)) {
+ return iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT);
+ }
+ *reinterpret_cast<uint16_t*>(out_data) =
+ half_float::detail::float2half<std::round_to_nearest>(temp);
+ return iree_ok_status();
+ }
case IREE_HAL_ELEMENT_TYPE_FLOAT_32:
return absl::SimpleAtof(absl::string_view(data_str.data, data_str.size),
reinterpret_cast<float*>(out_data))
@@ -407,8 +415,10 @@
*reinterpret_cast<const uint64_t*>(data.data));
break;
case IREE_HAL_ELEMENT_TYPE_FLOAT_16:
- return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
- "parser for float16 not yet implemented");
+ n = std::snprintf(buffer, buffer ? buffer_capacity : 0, "%G",
+ half_float::detail::half2float<float>(
+ *reinterpret_cast<const uint16_t*>(data.data)));
+ break;
case IREE_HAL_ELEMENT_TYPE_FLOAT_32:
n = std::snprintf(buffer, buffer ? buffer_capacity : 0, "%G",
*reinterpret_cast<const float*>(data.data));
@@ -419,7 +429,7 @@
break;
default: {
// Treat any unknown format as binary.
- n = 2 * element_size;
+ n = 2 * (int)element_size;
if (buffer && buffer_capacity > n) {
iree_hal_bytes_to_hex_string(data.data, buffer, element_size);
buffer[n] = 0;
@@ -870,6 +880,19 @@
return handle->WriteData(target_offset, source_buffer, data_length);
}
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_copy_data(
+ iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset,
+ iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
+ iree_device_size_t data_length) {
+ IREE_TRACE_SCOPE0("iree_hal_buffer_copy_data");
+ IREE_ASSERT_ARGUMENT(source_buffer);
+ IREE_ASSERT_ARGUMENT(target_buffer);
+ // Handles are reinterpreted as C++ Buffers; the copy runs on the target.
+ auto* source = reinterpret_cast<Buffer*>(source_buffer);
+ auto* target = reinterpret_cast<Buffer*>(target_buffer);
+ return target->CopyData(target_offset, source, source_offset, data_length);
+}
+
IREE_API_EXPORT iree_status_t iree_hal_buffer_map(
iree_hal_buffer_t* buffer, iree_hal_memory_access_t memory_access,
iree_device_size_t byte_offset, iree_device_size_t byte_length,
@@ -1684,19 +1707,19 @@
// We need to allocate storage to marshal in the semaphores. Ideally we'd
// change the C++ API to make this 1:1 with a reinterpret_cast, however that
// makes the C API more difficult. Bleh.
- int total_semaphore_count = 0;
- for (int i = 0; i < batch_count; ++i) {
+ iree_host_size_t total_semaphore_count = 0;
+ for (iree_host_size_t i = 0; i < batch_count; ++i) {
total_semaphore_count += batches[i].wait_semaphores.count;
total_semaphore_count += batches[i].signal_semaphores.count;
}
absl::InlinedVector<SemaphoreValue, 4> semaphore_values(
total_semaphore_count);
absl::InlinedVector<SubmissionBatch, 2> dst_batches(batch_count);
- int base_semaphore_index = 0;
- for (int i = 0; i < batch_count; ++i) {
+ iree_host_size_t base_semaphore_index = 0;
+ for (iree_host_size_t i = 0; i < batch_count; ++i) {
const auto& src_batch = batches[i];
auto& dst_batch = dst_batches[i];
- for (int j = 0; j < src_batch.wait_semaphores.count; ++j) {
+ for (iree_host_size_t j = 0; j < src_batch.wait_semaphores.count; ++j) {
semaphore_values[base_semaphore_index + j] = {
reinterpret_cast<Semaphore*>(src_batch.wait_semaphores.semaphores[j]),
src_batch.wait_semaphores.payload_values[j]};
@@ -1708,7 +1731,7 @@
dst_batch.command_buffers =
iree::ReinterpretSpan<CommandBuffer*>(absl::MakeConstSpan(
src_batch.command_buffers, src_batch.command_buffer_count));
- for (int j = 0; j < src_batch.signal_semaphores.count; ++j) {
+ for (iree_host_size_t j = 0; j < src_batch.signal_semaphores.count; ++j) {
semaphore_values[base_semaphore_index + j] = {
reinterpret_cast<Semaphore*>(
src_batch.signal_semaphores.semaphores[j]),
@@ -1878,6 +1901,11 @@
}
*out_driver_count = available_drivers.size();
+ *out_driver_names = NULL;
+ if (available_drivers.empty()) {
+ return iree_ok_status();
+ }
+
iree_string_view_t* driver_name_storage = nullptr;
IREE_RETURN_IF_ERROR(iree_allocator_malloc(
allocator,
diff --git a/iree/hal/api.h b/iree/hal/api.h
index fa1c11f..4e0295f 100644
--- a/iree/hal/api.h
+++ b/iree/hal/api.h
@@ -598,7 +598,7 @@
// with real device allocators and will likely incur a copy if used.
IREE_API_EXPORT iree_status_t IREE_API_CALL
iree_hal_allocator_create_host_local(iree_allocator_t allocator,
- iree_hal_allocator** out_allocator);
+ iree_hal_allocator_t** out_allocator);
// Retains the given |allocator| for the caller.
IREE_API_EXPORT void IREE_API_CALL
@@ -717,6 +717,12 @@
iree_hal_buffer_t* buffer, iree_device_size_t target_offset,
const void* source_buffer, iree_device_size_t data_length);
+// Copies data from the provided |source_buffer| into the |target_buffer|.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_copy_data(
+ iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset,
+ iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
+ iree_device_size_t data_length);
+
// Maps the buffer to be accessed as a host pointer into |out_mapped_memory|.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_map(
iree_hal_buffer_t* buffer, iree_hal_memory_access_t memory_access,
diff --git a/iree/hal/api_string_util_test.cc b/iree/hal/api_string_util_test.cc
index 44a5711..a148af1 100644
--- a/iree/hal/api_string_util_test.cc
+++ b/iree/hal/api_string_util_test.cc
@@ -554,6 +554,8 @@
IsOkAndHolds(Eq(IREE_HAL_ELEMENT_TYPE_UINT_16)));
EXPECT_THAT(ParseElementType("f32"),
IsOkAndHolds(Eq(IREE_HAL_ELEMENT_TYPE_FLOAT_32)));
+ EXPECT_THAT(ParseElementType("f16"),
+ IsOkAndHolds(Eq(IREE_HAL_ELEMENT_TYPE_FLOAT_16)));
EXPECT_THAT(ParseElementType("x64"),
IsOkAndHolds(Eq(IREE_HAL_ELEMENT_TYPE_OPAQUE_64)));
EXPECT_THAT(ParseElementType("*64"),
@@ -1000,6 +1002,7 @@
expect_round_trip("4xi16=0 -1 2 3");
expect_round_trip("4xu16=0 1 2 3");
expect_round_trip("2x2xi32=[0 1][2 3]");
+ expect_round_trip("4xf16=0 0.5 2 3");
expect_round_trip("4xf32=0 1.1 2 3");
expect_round_trip("4xf64=0 1.1 2 3");
expect_round_trip("1x2x3xi8=[[0 1 2][3 4 5]]");
diff --git a/iree/hal/buffer.cc b/iree/hal/buffer.cc
index 13c756f..f798952 100644
--- a/iree/hal/buffer.cc
+++ b/iree/hal/buffer.cc
@@ -22,7 +22,6 @@
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
-#include "absl/types/variant.h"
#include "iree/base/status.h"
namespace iree {
diff --git a/iree/hal/buffer.h b/iree/hal/buffer.h
index 3db906b..a7e8c01 100644
--- a/iree/hal/buffer.h
+++ b/iree/hal/buffer.h
@@ -59,7 +59,6 @@
#include <utility>
#include "absl/types/span.h"
-#include "absl/types/variant.h"
#include "iree/base/bitfield.h"
#include "iree/base/logging.h"
#include "iree/base/status.h"
diff --git a/iree/hal/command_buffer.h b/iree/hal/command_buffer.h
index 293e2d5..a94b532 100644
--- a/iree/hal/command_buffer.h
+++ b/iree/hal/command_buffer.h
@@ -32,7 +32,7 @@
// A bitfield specifying the mode of operation for a command buffer.
enum class CommandBufferMode : uint32_t {
// Command buffer will be submitted once and never used again.
- // This may enable in-place patching of command buffers that reduce overhead
+ // This may enable in-place patching of command buffers that reduces overhead
// when it's known that command buffers will not be reused.
kOneShot = 1 << 0,
};
diff --git a/iree/hal/cts/BUILD b/iree/hal/cts/BUILD
index cbee0cf..b915ee1 100644
--- a/iree/hal/cts/BUILD
+++ b/iree/hal/cts/BUILD
@@ -14,8 +14,6 @@
# Conformance Test Suite (CTS) for HAL implementations.
-load("//iree:build_defs.oss.bzl", "PLATFORM_VULKAN_TEST_DEPS")
-
package(
default_visibility = ["//visibility:public"],
features = ["layering_check"],
@@ -33,14 +31,9 @@
deps = [
"//iree/base:status",
"//iree/hal:driver_registry",
+ "//iree/hal/drivers",
"//iree/testing:gtest",
-
- # HAL driver modules.
- "//iree/hal/dylib:dylib_driver_module", # build-cleaner: keep
- "//iree/hal/llvmjit:llvmjit_driver_module", # build-cleaner: keep
- "//iree/hal/vmla:vmla_driver_module", # build-cleaner: keep
- "//iree/hal/vulkan:vulkan_driver_module", # build-cleaner: keep
- ] + PLATFORM_VULKAN_TEST_DEPS,
+ ],
)
cc_test(
diff --git a/iree/hal/cts/CMakeLists.txt b/iree/hal/cts/CMakeLists.txt
index 1271a0c..f48aee9 100644
--- a/iree/hal/cts/CMakeLists.txt
+++ b/iree/hal/cts/CMakeLists.txt
@@ -14,8 +14,6 @@
iree_add_all_subdirs()
-# bazel_to_cmake: DO NOT EDIT, IREE_HAL_DRIVER_MODULES is custom logic
-
iree_cc_library(
NAME
cts_test_base
@@ -26,9 +24,8 @@
DEPS
iree::base::status
iree::hal::driver_registry
+ iree::hal::drivers
iree::testing::gtest
- iree::testing::gtest_main
- ${IREE_HAL_DRIVER_MODULES}
TESTONLY
PUBLIC
)
@@ -43,6 +40,7 @@
iree::base::status
iree::hal::driver_registry
iree::testing::gtest
+ iree::testing::gtest_main
)
iree_cc_test(
@@ -55,6 +53,7 @@
iree::base::status
iree::hal::driver_registry
iree::testing::gtest
+ iree::testing::gtest_main
)
iree_cc_test(
@@ -67,6 +66,7 @@
iree::base::status
iree::hal::driver_registry
iree::testing::gtest
+ iree::testing::gtest_main
)
iree_cc_test(
@@ -79,6 +79,7 @@
iree::base::status
iree::hal::driver_registry
iree::testing::gtest
+ iree::testing::gtest_main
)
iree_cc_test(
@@ -90,6 +91,7 @@
::cts_test_base
iree::hal::driver_registry
iree::testing::gtest
+ iree::testing::gtest_main
)
iree_cc_test(
@@ -101,4 +103,5 @@
::cts_test_base
iree::hal::driver_registry
iree::testing::gtest
+ iree::testing::gtest_main
)
diff --git a/iree/hal/cts/allocator_test.cc b/iree/hal/cts/allocator_test.cc
index 1608c62..f8e20b3 100644
--- a/iree/hal/cts/allocator_test.cc
+++ b/iree/hal/cts/allocator_test.cc
@@ -77,10 +77,10 @@
allocator_->CanUseBufferLike(allocator_, memory_type, usage, usage));
}
-INSTANTIATE_TEST_SUITE_P(AllDrivers, AllocatorTest,
- ::testing::ValuesIn(DriverRegistry::shared_registry()
- ->EnumerateAvailableDrivers()),
- GenerateTestName());
+INSTANTIATE_TEST_SUITE_P(
+ AllDrivers, AllocatorTest,
+ ::testing::ValuesIn(CtsTestBase::EnumerateAvailableDrivers()),
+ GenerateTestName());
} // namespace cts
} // namespace hal
diff --git a/iree/hal/cts/buffer_test.cc b/iree/hal/cts/buffer_test.cc
index 26eecc7..dc63dd0 100644
--- a/iree/hal/cts/buffer_test.cc
+++ b/iree/hal/cts/buffer_test.cc
@@ -375,10 +375,10 @@
EXPECT_THAT(actual_data, ElementsAre(0, 1, 0xCC, 0xDD, 4, 5, 6));
}
-INSTANTIATE_TEST_SUITE_P(AllDrivers, BufferTest,
- ::testing::ValuesIn(DriverRegistry::shared_registry()
- ->EnumerateAvailableDrivers()),
- GenerateTestName());
+INSTANTIATE_TEST_SUITE_P(
+ AllDrivers, BufferTest,
+ ::testing::ValuesIn(CtsTestBase::EnumerateAvailableDrivers()),
+ GenerateTestName());
} // namespace cts
} // namespace hal
diff --git a/iree/hal/cts/command_buffer_test.cc b/iree/hal/cts/command_buffer_test.cc
index 5b6e4d6..234e755 100644
--- a/iree/hal/cts/command_buffer_test.cc
+++ b/iree/hal/cts/command_buffer_test.cc
@@ -237,10 +237,10 @@
// TODO(scotttodd): UpdateBuffer, Dispatch, Sync, etc.
-INSTANTIATE_TEST_SUITE_P(AllDrivers, CommandBufferTest,
- ::testing::ValuesIn(DriverRegistry::shared_registry()
- ->EnumerateAvailableDrivers()),
- GenerateTestName());
+INSTANTIATE_TEST_SUITE_P(
+ AllDrivers, CommandBufferTest,
+ ::testing::ValuesIn(CtsTestBase::EnumerateAvailableDrivers()),
+ GenerateTestName());
} // namespace cts
} // namespace hal
diff --git a/iree/hal/cts/command_queue_test.cc b/iree/hal/cts/command_queue_test.cc
index 411cd58..84feaae 100644
--- a/iree/hal/cts/command_queue_test.cc
+++ b/iree/hal/cts/command_queue_test.cc
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include <mutex>
+
#include "iree/base/status.h"
#include "iree/hal/cts/cts_test_base.h"
#include "iree/hal/driver_registry.h"
@@ -142,6 +144,10 @@
}
std::vector<std::string> GetSupportedDrivers() {
+ static std::once_flag register_once;
+ std::call_once(register_once, [] {
+ IREE_CHECK_OK(iree_hal_register_all_available_drivers());
+ });
auto drivers = DriverRegistry::shared_registry()->EnumerateAvailableDrivers();
auto it = drivers.begin();
while (it != drivers.end()) {
diff --git a/iree/hal/cts/cts_test_base.h b/iree/hal/cts/cts_test_base.h
index 5d1fde9..fc2a94a 100644
--- a/iree/hal/cts/cts_test_base.h
+++ b/iree/hal/cts/cts_test_base.h
@@ -16,10 +16,12 @@
#define IREE_HAL_CTS_CTS_TEST_BASE_H_
#include <map>
+#include <mutex>
#include <set>
#include "iree/base/status.h"
#include "iree/hal/driver_registry.h"
+#include "iree/hal/drivers/init.h"
#include "iree/testing/gtest.h"
#include "iree/testing/status_matchers.h"
@@ -29,6 +31,15 @@
// Common setup for tests parameterized across all registered drivers.
class CtsTestBase : public ::testing::TestWithParam<std::string> {
+ public:
+ static std::vector<std::string> EnumerateAvailableDrivers() {
+ // Registers the linked-in drivers once, ahead of the first enumeration.
+ static std::once_flag drivers_registered;
+ std::call_once(drivers_registered,
+ [] { IREE_CHECK_OK(iree_hal_register_all_available_drivers()); });
+ return DriverRegistry::shared_registry()->EnumerateAvailableDrivers();
+ }
+
protected:
// Per-test-suite set-up. This is called before the first test in this test
// suite. We use it to set up drivers that must be reused between test cases
diff --git a/iree/hal/cts/driver_test.cc b/iree/hal/cts/driver_test.cc
index 9b85ccb..2fea7d8 100644
--- a/iree/hal/cts/driver_test.cc
+++ b/iree/hal/cts/driver_test.cc
@@ -30,16 +30,16 @@
TEST_P(DriverTest, EnumerateAndCreateAvailableDevices) {
IREE_ASSERT_OK_AND_ASSIGN(auto devices, driver_->EnumerateAvailableDevices());
- for (int i = 0; i < devices.size(); ++i) {
+ for (iree_host_size_t i = 0; i < devices.size(); ++i) {
IREE_ASSERT_OK_AND_ASSIGN(auto device, driver_->CreateDevice(devices[i]));
IREE_LOG(INFO) << "Device #" << i << " details:\n" << device->DebugString();
}
}
-INSTANTIATE_TEST_SUITE_P(AllDrivers, DriverTest,
- ::testing::ValuesIn(DriverRegistry::shared_registry()
- ->EnumerateAvailableDrivers()),
- GenerateTestName());
+INSTANTIATE_TEST_SUITE_P(
+ AllDrivers, DriverTest,
+ ::testing::ValuesIn(CtsTestBase::EnumerateAvailableDrivers()),
+ GenerateTestName());
} // namespace cts
} // namespace hal
diff --git a/iree/hal/cts/semaphore_test.cc b/iree/hal/cts/semaphore_test.cc
index c4299ca..727596e 100644
--- a/iree/hal/cts/semaphore_test.cc
+++ b/iree/hal/cts/semaphore_test.cc
@@ -141,10 +141,10 @@
thread.join();
}
-INSTANTIATE_TEST_SUITE_P(AllDrivers, SemaphoreTest,
- ::testing::ValuesIn(DriverRegistry::shared_registry()
- ->EnumerateAvailableDrivers()),
- GenerateTestName());
+INSTANTIATE_TEST_SUITE_P(
+ AllDrivers, SemaphoreTest,
+ ::testing::ValuesIn(CtsTestBase::EnumerateAvailableDrivers()),
+ GenerateTestName());
} // namespace cts
} // namespace hal
diff --git a/iree/hal/driver_registry.cc b/iree/hal/driver_registry.cc
index 81322ea..79fd907 100644
--- a/iree/hal/driver_registry.cc
+++ b/iree/hal/driver_registry.cc
@@ -82,6 +82,3 @@
} // namespace hal
} // namespace iree
-
-IREE_REGISTER_MODULE_INITIALIZER(
- iree_hal, ::iree::hal::DriverRegistry::shared_registry());
diff --git a/iree/hal/driver_registry.h b/iree/hal/driver_registry.h
index 8d182ec..f0675c2 100644
--- a/iree/hal/driver_registry.h
+++ b/iree/hal/driver_registry.h
@@ -20,7 +20,6 @@
#include "absl/base/thread_annotations.h"
#include "absl/synchronization/mutex.h"
-#include "iree/base/initializer.h"
#include "iree/base/ref_ptr.h"
#include "iree/base/status.h"
#include "iree/hal/driver.h"
@@ -78,7 +77,4 @@
} // namespace hal
} // namespace iree
-IREE_DECLARE_MODULE_INITIALIZER(iree_hal);
-IREE_REQUIRE_MODULE_LINKED(iree_hal);
-
#endif // IREE_HAL_DRIVER_REGISTRY_H_
diff --git a/iree/hal/drivers/BUILD b/iree/hal/drivers/BUILD
new file mode 100644
index 0000000..8c04a1d
--- /dev/null
+++ b/iree/hal/drivers/BUILD
@@ -0,0 +1,35 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["-layering_check"], # buildozer: disable=no-layering-check, allow indirect headers
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "drivers",
+ srcs = ["init.c"],
+ hdrs = ["init.h"],
+ deps = [
+ "//iree/base:api",
+ "//iree/base:tracing",
+ ] + [
+ # TODO(*): select() and only pull in based on build configuration.
+ "//iree/hal/dylib/registration",
+ "//iree/hal/llvmjit/registration",
+ "//iree/hal/vmla/registration",
+ "//iree/hal/vulkan/registration/google_internal:registration",
+ ],
+)
diff --git a/iree/hal/drivers/CMakeLists.txt b/iree/hal/drivers/CMakeLists.txt
new file mode 100644
index 0000000..388f199
--- /dev/null
+++ b/iree/hal/drivers/CMakeLists.txt
@@ -0,0 +1,46 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# bazel_to_cmake: DO NOT EDIT (custom configuration vars)
+
+set(IREE_HAL_DRIVER_MODULES) # Registration libraries of the enabled drivers.
+if(IREE_HAL_DRIVER_DYLIB)
+ list(APPEND IREE_HAL_DRIVER_MODULES iree::hal::dylib::registration)
+endif()
+if(IREE_HAL_DRIVER_LLVM)
+ list(APPEND IREE_HAL_DRIVER_MODULES iree::hal::llvmjit::registration)
+endif()
+if(IREE_HAL_DRIVER_METAL)
+ list(APPEND IREE_HAL_DRIVER_MODULES iree::hal::metal::registration)
+endif()
+if(IREE_HAL_DRIVER_VMLA)
+ list(APPEND IREE_HAL_DRIVER_MODULES iree::hal::vmla::registration)
+endif()
+if(IREE_HAL_DRIVER_VULKAN)
+ list(APPEND IREE_HAL_DRIVER_MODULES iree::hal::vulkan::registration)
+endif()
+
+iree_cc_library(
+ NAME
+ drivers
+ HDRS
+ "init.h"
+ SRCS
+ "init.c"
+ DEPS
+ iree::base::api
+ iree::base::tracing
+ ${IREE_HAL_DRIVER_MODULES}
+ PUBLIC
+)
diff --git a/iree/hal/drivers/init.c b/iree/hal/drivers/init.c
new file mode 100644
index 0000000..38482f4
--- /dev/null
+++ b/iree/hal/drivers/init.c
@@ -0,0 +1,69 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/hal/drivers/init.h"
+
+#include "iree/base/tracing.h"
+
+#if defined(IREE_HAL_HAVE_DYLIB_DRIVER_MODULE)
+#include "iree/hal/dylib/registration/driver_module.h"
+#endif // IREE_HAL_HAVE_DYLIB_DRIVER_MODULE
+
+#if defined(IREE_HAL_HAVE_LLVMJIT_DRIVER_MODULE)
+#include "iree/hal/llvmjit/registration/driver_module.h"
+#endif // IREE_HAL_HAVE_LLVMJIT_DRIVER_MODULE
+
+#if defined(IREE_HAL_HAVE_METAL_DRIVER_MODULE)
+#include "iree/hal/metal/registration/driver_module.h"
+#endif // IREE_HAL_HAVE_METAL_DRIVER_MODULE
+
+#if defined(IREE_HAL_HAVE_VMLA_DRIVER_MODULE)
+#include "iree/hal/vmla/registration/driver_module.h"
+#endif // IREE_HAL_HAVE_VMLA_DRIVER_MODULE
+
+#if defined(IREE_HAL_HAVE_VULKAN_DRIVER_MODULE)
+#include "iree/hal/vulkan/registration/google_internal/driver_module.h"
+#endif // IREE_HAL_HAVE_VULKAN_DRIVER_MODULE
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_register_all_available_drivers(void) {
+ IREE_TRACE_ZONE_BEGIN(z0);
+ // A driver is registered only if its IREE_HAL_HAVE_* macro is defined.
+#if defined(IREE_HAL_HAVE_DYLIB_DRIVER_MODULE)
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(z0,
+ iree_hal_dylib_driver_module_register());
+#endif // IREE_HAL_HAVE_DYLIB_DRIVER_MODULE
+
+#if defined(IREE_HAL_HAVE_LLVMJIT_DRIVER_MODULE)
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(z0,
+ iree_hal_llvmjit_driver_module_register());
+#endif // IREE_HAL_HAVE_LLVMJIT_DRIVER_MODULE
+
+#if defined(IREE_HAL_HAVE_METAL_DRIVER_MODULE)
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(z0,
+ iree_hal_metal_driver_module_register());
+#endif // IREE_HAL_HAVE_METAL_DRIVER_MODULE
+
+#if defined(IREE_HAL_HAVE_VMLA_DRIVER_MODULE)
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, iree_hal_vmla_driver_module_register());
+#endif // IREE_HAL_HAVE_VMLA_DRIVER_MODULE
+
+#if defined(IREE_HAL_HAVE_VULKAN_DRIVER_MODULE)
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(z0,
+ iree_hal_vulkan_driver_module_register());
+#endif // IREE_HAL_HAVE_VULKAN_DRIVER_MODULE
+
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
diff --git a/iree/hal/drivers/init.h b/iree/hal/drivers/init.h
new file mode 100644
index 0000000..6ab9c40
--- /dev/null
+++ b/iree/hal/drivers/init.h
@@ -0,0 +1,38 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_DRIVERS_INIT_H_
+#define IREE_HAL_DRIVERS_INIT_H_
+
+#include "iree/base/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// Registers all drivers that were linked into the current binary based on the
+// build configuration. Note that there may be no drivers available.
+//
+// This only registers IREE core drivers (those under iree/hal/). User-provided
+// drivers must be directly registered or directly created, though a user could
+// create their own user_register_all_available_drivers() that calls this as
+// well as registering their drivers.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_register_all_available_drivers(void);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_HAL_DRIVERS_INIT_H_
diff --git a/iree/hal/dylib/BUILD b/iree/hal/dylib/BUILD
index 9c19b61..49a5711 100644
--- a/iree/hal/dylib/BUILD
+++ b/iree/hal/dylib/BUILD
@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-# HAL implementation for executing functions provided by dynamic libraries.
-
load("//iree:build_defs.oss.bzl", "iree_cmake_extra_content")
package(
@@ -31,70 +29,32 @@
)
cc_library(
- name = "dylib_device",
- srcs = ["dylib_device.cc"],
- hdrs = ["dylib_device.h"],
- deps = [
- ":dylib_executable_cache",
- "//iree/base:tracing",
- "//iree/hal/host:host_local_device",
+ name = "dylib",
+ srcs = [
+ "dylib_device.cc",
+ "dylib_driver.cc",
+ "dylib_executable.cc",
+ "dylib_executable_cache.cc",
],
-)
-
-cc_library(
- name = "dylib_driver",
- srcs = ["dylib_driver.cc"],
- hdrs = ["dylib_driver.h"],
- deps = [
- ":dylib_device",
- "//iree/hal:device_info",
- "//iree/hal:driver",
- "//iree/hal/host/serial:serial_scheduling_model",
+ hdrs = [
+ "dylib_device.h",
+ "dylib_driver.h",
+ "dylib_executable.h",
+ "dylib_executable_cache.h",
],
-)
-
-cc_library(
- name = "dylib_driver_module",
- srcs = ["dylib_driver_module.cc"],
- deps = [
- ":dylib_driver",
- "//iree/base:init",
- "//iree/base:status",
- "//iree/hal:driver_registry",
- ],
- alwayslink = 1,
-)
-
-cc_library(
- name = "dylib_executable",
- srcs = ["dylib_executable.cc"],
- hdrs = ["dylib_executable.h"],
deps = [
"//iree/base:dynamic_library",
"//iree/base:file_io",
"//iree/base:file_path",
+ "//iree/base:flatcc",
"//iree/base:status",
"//iree/base:tracing",
- "//iree/hal:executable",
- "//iree/hal:executable_spec",
+ "//iree/hal",
"//iree/hal/host:host_executable",
- "//iree/schemas:dylib_executable_def_cc_fbs",
- "@com_github_google_flatbuffers//:flatbuffers",
+ "//iree/hal/host:host_local_device",
+ "//iree/hal/host/serial:serial_scheduling_model",
+ "//iree/schemas:dylib_executable_def_c_fbs",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/types:span",
],
)
-
-cc_library(
- name = "dylib_executable_cache",
- srcs = ["dylib_executable_cache.cc"],
- hdrs = ["dylib_executable_cache.h"],
- deps = [
- ":dylib_executable",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:executable",
- "//iree/hal:executable_cache",
- "//iree/hal:executable_format",
- ],
-)
diff --git a/iree/hal/dylib/CMakeLists.txt b/iree/hal/dylib/CMakeLists.txt
index b057837..03d278d 100644
--- a/iree/hal/dylib/CMakeLists.txt
+++ b/iree/hal/dylib/CMakeLists.txt
@@ -20,83 +20,30 @@
iree_cc_library(
NAME
- dylib_device
+ dylib
HDRS
"dylib_device.h"
+ "dylib_driver.h"
+ "dylib_executable.h"
+ "dylib_executable_cache.h"
SRCS
"dylib_device.cc"
- DEPS
- ::dylib_executable_cache
- iree::base::tracing
- iree::hal::host::host_local_device
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- dylib_driver
- HDRS
- "dylib_driver.h"
- SRCS
"dylib_driver.cc"
- DEPS
- ::dylib_device
- iree::hal::device_info
- iree::hal::driver
- iree::hal::host::serial::serial_scheduling_model
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- dylib_driver_module
- SRCS
- "dylib_driver_module.cc"
- DEPS
- ::dylib_driver
- iree::base::init
- iree::base::status
- iree::hal::driver_registry
- ALWAYSLINK
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- dylib_executable
- HDRS
- "dylib_executable.h"
- SRCS
"dylib_executable.cc"
+ "dylib_executable_cache.cc"
DEPS
absl::inlined_vector
absl::span
- flatbuffers
iree::base::dynamic_library
iree::base::file_io
iree::base::file_path
+ iree::base::flatcc
iree::base::status
iree::base::tracing
- iree::hal::executable
- iree::hal::executable_spec
+ iree::hal
iree::hal::host::host_executable
- iree::schemas::dylib_executable_def_cc_fbs
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- dylib_executable_cache
- HDRS
- "dylib_executable_cache.h"
- SRCS
- "dylib_executable_cache.cc"
- DEPS
- ::dylib_executable
- iree::base::status
- iree::base::tracing
- iree::hal::executable
- iree::hal::executable_cache
- iree::hal::executable_format
+ iree::hal::host::host_local_device
+ iree::hal::host::serial::serial_scheduling_model
+ iree::schemas::dylib_executable_def_c_fbs
PUBLIC
)
diff --git a/iree/hal/dylib/dylib_executable.cc b/iree/hal/dylib/dylib_executable.cc
index b15f99f..92a8ea7 100644
--- a/iree/hal/dylib/dylib_executable.cc
+++ b/iree/hal/dylib/dylib_executable.cc
@@ -14,10 +14,62 @@
#include "iree/hal/dylib/dylib_executable.h"
-#include "flatbuffers/flatbuffers.h"
#include "iree/base/file_io.h"
#include "iree/base/file_path.h"
-#include "iree/schemas/dylib_executable_def_generated.h"
+
+// flatcc schemas:
+#include "iree/base/flatcc.h"
+#include "iree/schemas/dylib_executable_def_reader.h"
+#include "iree/schemas/dylib_executable_def_verifier.h"
+
+// NOTE: starting to port this to C.
+
+// Verifies the structure of the flatbuffer so that we can avoid doing so during
+// runtime. Beyond the flatcc structural checks this rejects unnamed entry
+// points and a missing/empty library_embedded; an empty entry point list is
+// accepted. No further bounds checks should be needed after this succeeds.
+static iree_status_t iree_hal_dylib_executable_flatbuffer_verify(
+ iree_const_byte_span_t flatbuffer_data) {
+ if (!flatbuffer_data.data || flatbuffer_data.data_length < 16) {
+ return iree_make_status(
+ IREE_STATUS_INVALID_ARGUMENT,
+ "flatbuffer data is not present or less than 16 bytes (%zu total)",
+ flatbuffer_data.data_length);
+ }
+
+ // Run flatcc generated verification. This ensures all pointers are in-bounds
+ // and that we can safely walk the file, but not that the actual contents of
+ // the flatbuffer meet our expectations.
+ int verify_ret = iree_DyLibExecutableDef_verify_as_root(
+ flatbuffer_data.data, flatbuffer_data.data_length);
+ if (verify_ret != flatcc_verify_ok) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "flatbuffer verification failed: %s",
+ flatcc_verify_error_string(verify_ret));
+ }
+
+ iree_DyLibExecutableDef_table_t executable_def =
+ iree_DyLibExecutableDef_as_root(flatbuffer_data.data);
+
+ flatbuffers_string_vec_t entry_points_vec =
+ iree_DyLibExecutableDef_entry_points_get(executable_def);
+ size_t entry_point_count = flatbuffers_string_vec_len(entry_points_vec);
+ for (size_t i = 0; i < entry_point_count; ++i) {
+ if (!flatbuffers_string_len(
+ flatbuffers_string_vec_at(entry_points_vec, i))) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "executable entry point %zu has no name", i);
+ }
+ }
+
+ if (!flatbuffers_uint8_vec_len(
+ iree_DyLibExecutableDef_library_embedded_get(executable_def))) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "executable library_embedded is missing/empty");
+ }
+
+ return iree_ok_status();
+}
namespace iree {
namespace hal {
@@ -34,9 +86,13 @@
DyLibExecutable::~DyLibExecutable() {
IREE_TRACE_SCOPE0("DyLibExecutable::dtor");
- // TODO(benvanik): move to an atexit handler when tracing is enabled.
- // executable_library_.release();
+#if IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION
+ // Leak the library when tracing, since the profiler may still be reading it.
+ // TODO(benvanik): move to an atexit handler instead, verify with ASAN/MSAN
+ executable_library_.release();
+#else
executable_library_.reset();
+#endif // IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION
for (const auto& file_path : temp_file_paths_) {
file_io::DeleteFile(file_path).IgnoreError();
}
@@ -45,21 +101,22 @@
Status DyLibExecutable::Initialize(ExecutableSpec spec) {
IREE_TRACE_SCOPE0("DyLibExecutable::Initialize");
- auto dylib_executable_def =
- ::flatbuffers::GetRoot<DyLibExecutableDef>(spec.executable_data.data());
-
- if (!dylib_executable_def->entry_points() ||
- dylib_executable_def->entry_points()->size() == 0) {
- return InvalidArgumentErrorBuilder(IREE_LOC) << "No entry points defined";
- }
- if (!dylib_executable_def->library_embedded() ||
- dylib_executable_def->library_embedded()->size() == 0) {
- return InvalidArgumentErrorBuilder(IREE_LOC) << "No embedded library";
- }
+ // Verify and fetch the executable flatbuffer wrapper.
+ iree_const_byte_span_t executable_data = iree_make_const_byte_span(
+ spec.executable_data.data(), spec.executable_data.size());
+ IREE_RETURN_IF_ERROR(
+ iree_hal_dylib_executable_flatbuffer_verify(executable_data));
+ iree_DyLibExecutableDef_table_t executable_def =
+ iree_DyLibExecutableDef_as_root(executable_data.data);
// Write the embedded library out to a temp file, since all of the dynamic
// library APIs work with files. We could instead use in-memory files on
// platforms where that is convenient.
+ //
+ // TODO(#3845): use dlopen on an fd with either dlopen(/proc/self/fd/NN),
+ // fdlopen, or android_dlopen_ext to avoid needing to write the file to disk.
+ // Can fall back to memfd_create + dlopen where available, and from that to
+ // disk (maybe just windows/mac).
std::string base_name = "dylib_executable";
IREE_ASSIGN_OR_RETURN(auto library_temp_path,
file_io::GetTempFile(base_name));
@@ -73,50 +130,51 @@
library_temp_path += ".so";
#endif
- absl::string_view embedded_library_data(
- reinterpret_cast<const char*>(
- dylib_executable_def->library_embedded()->data()),
- dylib_executable_def->library_embedded()->size());
- IREE_RETURN_IF_ERROR(
- file_io::SetFileContents(library_temp_path, embedded_library_data));
+ flatbuffers_uint8_vec_t embedded_library_vec =
+ iree_DyLibExecutableDef_library_embedded_get(executable_def);
+ IREE_RETURN_IF_ERROR(file_io::SetFileContents(
+ library_temp_path,
+ absl::string_view(reinterpret_cast<const char*>(embedded_library_vec),
+ flatbuffers_uint8_vec_len(embedded_library_vec))));
IREE_ASSIGN_OR_RETURN(executable_library_,
DynamicLibrary::Load(library_temp_path.c_str()));
- if (dylib_executable_def->debug_database_filename() &&
- dylib_executable_def->debug_database_embedded()) {
+ flatbuffers_string_t debug_database_filename =
+ iree_DyLibExecutableDef_debug_database_filename_get(executable_def);
+ flatbuffers_uint8_vec_t debug_database_embedded_vec =
+ iree_DyLibExecutableDef_debug_database_embedded_get(executable_def);
+ if (flatbuffers_string_len(debug_database_filename) &&
+ flatbuffers_uint8_vec_len(debug_database_embedded_vec)) {
IREE_TRACE_SCOPE0("DyLibExecutable::AttachDebugDatabase");
- absl::string_view debug_database_filename(
- dylib_executable_def->debug_database_filename()->data(),
- dylib_executable_def->debug_database_filename()->size());
- absl::string_view debug_database_data(
- reinterpret_cast<const char*>(
- dylib_executable_def->debug_database_embedded()->data()),
- dylib_executable_def->debug_database_embedded()->size());
auto debug_database_path = file_path::JoinPaths(
- file_path::DirectoryName(library_temp_path), debug_database_filename);
+ file_path::DirectoryName(library_temp_path),
+ absl::string_view(debug_database_filename,
+ flatbuffers_string_len(debug_database_filename)));
temp_file_paths_.push_back(debug_database_path);
- IREE_IGNORE_ERROR(
- file_io::SetFileContents(debug_database_path, debug_database_data));
+ IREE_IGNORE_ERROR(file_io::SetFileContents(
+ debug_database_path,
+ absl::string_view(
+ reinterpret_cast<const char*>(debug_database_embedded_vec),
+ flatbuffers_uint8_vec_len(debug_database_embedded_vec))));
executable_library_->AttachDebugDatabase(debug_database_path.c_str());
}
- const auto& entry_points = *dylib_executable_def->entry_points();
- entry_functions_.resize(entry_points.size());
-#if IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION
- entry_names_.resize(entry_points.size());
-#endif // IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION
- for (int i = 0; i < entry_functions_.size(); ++i) {
- void* symbol = executable_library_->GetSymbol(entry_points[i]->c_str());
+ flatbuffers_string_vec_t entry_points =
+ iree_DyLibExecutableDef_entry_points_get(executable_def);
+ entry_functions_.resize(flatbuffers_string_vec_len(entry_points));
+ IREE_TRACE(entry_names_.resize(flatbuffers_string_vec_len(entry_points)));
+ for (size_t i = 0; i < entry_functions_.size(); ++i) {
+ flatbuffers_string_t entry_point =
+ flatbuffers_string_vec_at(entry_points, i);
+ void* symbol = executable_library_->GetSymbol(entry_point);
if (!symbol) {
return NotFoundErrorBuilder(IREE_LOC)
- << "Could not find symbol: " << entry_points[i];
+ << "Could not find symbol: " << entry_point;
}
entry_functions_[i] = symbol;
-#if IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION
- entry_names_[i] = entry_points[i]->c_str();
-#endif // IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION
+ IREE_TRACE(entry_names_[i] = entry_point);
}
return OkStatus();
@@ -125,9 +183,7 @@
struct DyLibDispatchState : public HostExecutable::DispatchState {
DyLibDispatchState() = default;
-#if IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION
- const char* entry_name = nullptr;
-#endif // IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION
+ IREE_TRACE(const char* entry_name = nullptr);
void* entry_function = nullptr;
std::array<void*, 32> args;
@@ -144,9 +200,7 @@
}
auto dispatch_state = make_ref<DyLibDispatchState>();
-#if IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION
- dispatch_state->entry_name = entry_names_[params.entry_point];
-#endif // IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION
+ IREE_TRACE(dispatch_state->entry_name = entry_names_[params.entry_point]);
dispatch_state->entry_function = entry_functions_[params.entry_point];
int binding_count = 0;
diff --git a/iree/hal/dylib/dylib_executable.h b/iree/hal/dylib/dylib_executable.h
index f27151b..424c096 100644
--- a/iree/hal/dylib/dylib_executable.h
+++ b/iree/hal/dylib/dylib_executable.h
@@ -52,9 +52,7 @@
std::unique_ptr<DynamicLibrary> executable_library_;
std::vector<void*> entry_functions_;
-#if IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION
- std::vector<const char*> entry_names_;
-#endif // IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION
+ IREE_TRACE(std::vector<const char*> entry_names_);
};
} // namespace dylib
diff --git a/iree/hal/dylib/registration/BUILD b/iree/hal/dylib/registration/BUILD
new file mode 100644
index 0000000..ec475fa
--- /dev/null
+++ b/iree/hal/dylib/registration/BUILD
@@ -0,0 +1,46 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//iree:build_defs.oss.bzl", "iree_cmake_extra_content")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_cmake_extra_content(
+ content = """
+if(NOT ${IREE_HAL_DRIVER_DYLIB})
+ return()
+endif()
+""",
+ inline = True,
+)
+
+cc_library(
+ name = "registration",
+ srcs = ["driver_module.cc"],
+ hdrs = ["driver_module.h"],
+ defines = [
+ "IREE_HAL_HAVE_DYLIB_DRIVER_MODULE=1",
+ ],
+ deps = [
+ "//iree/base:flags",
+ "//iree/base:status",
+ "//iree/hal:api",
+ "//iree/hal:driver_registry",
+ "//iree/hal/dylib",
+ ],
+)
diff --git a/iree/hal/dylib/registration/CMakeLists.txt b/iree/hal/dylib/registration/CMakeLists.txt
new file mode 100644
index 0000000..07fd286
--- /dev/null
+++ b/iree/hal/dylib/registration/CMakeLists.txt
@@ -0,0 +1,37 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_add_all_subdirs()
+
+if(NOT ${IREE_HAL_DRIVER_DYLIB})
+ return()
+endif()
+
+iree_cc_library(
+ NAME
+ registration
+ HDRS
+ "driver_module.h"
+ SRCS
+ "driver_module.cc"
+ DEPS
+ iree::base::flags
+ iree::base::status
+ iree::hal::api
+ iree::hal::driver_registry
+ iree::hal::dylib
+ DEFINES
+ "IREE_HAL_HAVE_DYLIB_DRIVER_MODULE=1"
+ PUBLIC
+)
diff --git a/iree/hal/dylib/dylib_driver_module.cc b/iree/hal/dylib/registration/driver_module.cc
similarity index 74%
rename from iree/hal/dylib/dylib_driver_module.cc
rename to iree/hal/dylib/registration/driver_module.cc
index d71d3fa..b9a920f 100644
--- a/iree/hal/dylib/dylib_driver_module.cc
+++ b/iree/hal/dylib/registration/driver_module.cc
@@ -12,9 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include <memory>
+#include "iree/hal/dylib/registration/driver_module.h"
-#include "iree/base/init.h"
#include "iree/base/status.h"
#include "iree/hal/driver_registry.h"
#include "iree/hal/dylib/dylib_driver.h"
@@ -31,8 +30,8 @@
} // namespace hal
} // namespace iree
-IREE_REGISTER_MODULE_INITIALIZER(iree_hal_dylib_driver, {
- IREE_QCHECK_OK(::iree::hal::DriverRegistry::shared_registry()->Register(
- "dylib", ::iree::hal::dylib::CreateDyLibDriver));
-});
-IREE_REGISTER_MODULE_INITIALIZER_SEQUENCE(iree_hal, iree_hal_dylib_driver);
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_dylib_driver_module_register() {
+ return ::iree::hal::DriverRegistry::shared_registry()->Register(
+ "dylib", ::iree::hal::dylib::CreateDyLibDriver);
+}
diff --git a/iree/hal/dylib/registration/driver_module.h b/iree/hal/dylib/registration/driver_module.h
new file mode 100644
index 0000000..230a69d
--- /dev/null
+++ b/iree/hal/dylib/registration/driver_module.h
@@ -0,0 +1,33 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_DYLIB_REGISTRATION_DRIVER_MODULE_H_
+#define IREE_HAL_DYLIB_REGISTRATION_DRIVER_MODULE_H_
+
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// DEPRECATED: this entire driver will be removed soon.
+// TODO(#3580): remove this entire driver w/ iree_hal_executable_library_t.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_dylib_driver_module_register();
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_HAL_DYLIB_REGISTRATION_DRIVER_MODULE_H_
diff --git a/iree/hal/host/BUILD b/iree/hal/host/BUILD
index a26fa0f..a968eed 100644
--- a/iree/hal/host/BUILD
+++ b/iree/hal/host/BUILD
@@ -28,7 +28,7 @@
deps = [
"//iree/base:status",
"//iree/base:tracing",
- "//iree/hal:semaphore",
+ "//iree/hal",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/synchronization",
@@ -56,7 +56,7 @@
"//iree/base:logging",
"//iree/base:status",
"//iree/base:tracing",
- "//iree/hal:buffer",
+ "//iree/hal",
],
)
@@ -65,8 +65,7 @@
srcs = ["host_descriptor_set.cc"],
hdrs = ["host_descriptor_set.h"],
deps = [
- "//iree/hal:descriptor_set",
- "//iree/hal:descriptor_set_layout",
+ "//iree/hal",
"@com_google_absl//absl/container:inlined_vector",
],
)
@@ -76,8 +75,7 @@
hdrs = ["host_executable.h"],
deps = [
"//iree/base:status",
- "//iree/hal:descriptor_set",
- "//iree/hal:executable",
+ "//iree/hal",
],
)
@@ -86,9 +84,8 @@
srcs = ["host_executable_layout.cc"],
hdrs = ["host_executable_layout.h"],
deps = [
- "//iree/base:memory",
- "//iree/hal:descriptor_set_layout",
- "//iree/hal:executable_layout",
+ "//iree/base:core_headers",
+ "//iree/hal",
"@com_google_absl//absl/container:inlined_vector",
],
)
@@ -101,8 +98,7 @@
":host_buffer",
"//iree/base:status",
"//iree/base:tracing",
- "//iree/hal:allocator",
- "//iree/hal:buffer",
+ "//iree/hal",
],
)
@@ -115,12 +111,11 @@
":host_executable_layout",
":host_local_allocator",
":scheduling_model",
- "//iree/base:memory",
+ "//iree/base:core_headers",
"//iree/base:status",
"//iree/base:tracing",
- "//iree/hal:allocator",
+ "//iree/hal",
"//iree/hal:command_buffer_validation",
- "//iree/hal:device",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:span",
@@ -136,7 +131,7 @@
"//iree/base:intrusive_list",
"//iree/base:status",
"//iree/base:tracing",
- "//iree/hal:command_buffer",
+ "//iree/hal",
],
)
@@ -145,7 +140,7 @@
srcs = ["nop_event.cc"],
hdrs = ["nop_event.h"],
deps = [
- "//iree/hal:event",
+ "//iree/hal",
],
)
@@ -153,6 +148,6 @@
name = "scheduling_model",
hdrs = ["scheduling_model.h"],
deps = [
- "//iree/hal:command_queue",
+ "//iree/hal",
],
)
diff --git a/iree/hal/host/CMakeLists.txt b/iree/hal/host/CMakeLists.txt
index 953a777..5c926c6 100644
--- a/iree/hal/host/CMakeLists.txt
+++ b/iree/hal/host/CMakeLists.txt
@@ -28,7 +28,7 @@
absl::synchronization
iree::base::status
iree::base::tracing
- iree::hal::semaphore
+ iree::hal
PUBLIC
)
@@ -56,7 +56,7 @@
iree::base::logging
iree::base::status
iree::base::tracing
- iree::hal::buffer
+ iree::hal
PUBLIC
)
@@ -69,8 +69,7 @@
"host_descriptor_set.cc"
DEPS
absl::inlined_vector
- iree::hal::descriptor_set
- iree::hal::descriptor_set_layout
+ iree::hal
PUBLIC
)
@@ -81,8 +80,7 @@
"host_executable.h"
DEPS
iree::base::status
- iree::hal::descriptor_set
- iree::hal::executable
+ iree::hal
PUBLIC
)
@@ -95,9 +93,8 @@
"host_executable_layout.cc"
DEPS
absl::inlined_vector
- iree::base::memory
- iree::hal::descriptor_set_layout
- iree::hal::executable_layout
+ iree::base::core_headers
+ iree::hal
PUBLIC
)
@@ -112,8 +109,7 @@
::host_buffer
iree::base::status
iree::base::tracing
- iree::hal::allocator
- iree::hal::buffer
+ iree::hal
PUBLIC
)
@@ -132,12 +128,11 @@
absl::core_headers
absl::memory
absl::span
- iree::base::memory
+ iree::base::core_headers
iree::base::status
iree::base::tracing
- iree::hal::allocator
+ iree::hal
iree::hal::command_buffer_validation
- iree::hal::device
PUBLIC
)
@@ -153,7 +148,7 @@
iree::base::intrusive_list
iree::base::status
iree::base::tracing
- iree::hal::command_buffer
+ iree::hal
PUBLIC
)
@@ -165,7 +160,7 @@
SRCS
"nop_event.cc"
DEPS
- iree::hal::event
+ iree::hal
PUBLIC
)
@@ -175,6 +170,6 @@
HDRS
"scheduling_model.h"
DEPS
- iree::hal::command_queue
+ iree::hal
PUBLIC
)
diff --git a/iree/hal/host/host_executable.h b/iree/hal/host/host_executable.h
index 4885885..cb2e6c9 100644
--- a/iree/hal/host/host_executable.h
+++ b/iree/hal/host/host_executable.h
@@ -41,7 +41,7 @@
// Grid parameters shared for all tiles within a dispatch.
struct DispatchParams {
// Entry point within the executable.
- int32_t entry_point = 0;
+ size_t entry_point = 0;
// Total workgroup XYZ count for the grid.
std::array<uint32_t, 3> workgroup_count;
diff --git a/iree/hal/host/serial/BUILD b/iree/hal/host/serial/BUILD
index 6bba64a..9228d98 100644
--- a/iree/hal/host/serial/BUILD
+++ b/iree/hal/host/serial/BUILD
@@ -28,8 +28,7 @@
deps = [
"//iree/base:status",
"//iree/base:tracing",
- "//iree/hal:command_queue",
- "//iree/hal:semaphore",
+ "//iree/hal",
"//iree/hal/host/serial:serial_submission_queue",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/synchronization",
@@ -43,7 +42,7 @@
":async_command_queue",
"//iree/base:status",
"//iree/base:time",
- "//iree/hal:command_queue",
+ "//iree/hal",
"//iree/hal/host/serial:serial_submission_queue",
"//iree/hal/testing:mock_command_buffer",
"//iree/hal/testing:mock_command_queue",
@@ -60,7 +59,7 @@
deps = [
"//iree/base:status",
"//iree/base:tracing",
- "//iree/hal:command_buffer",
+ "//iree/hal",
"//iree/hal/host:host_descriptor_set",
"//iree/hal/host:host_executable",
"//iree/hal/host:host_executable_layout",
@@ -76,7 +75,7 @@
":async_command_queue",
":serial_command_processor",
":serial_submission_queue",
- "//iree/base:memory",
+ "//iree/base:core_headers",
"//iree/base:status",
"//iree/base:tracing",
"//iree/hal/host:condvar_semaphore",
@@ -95,7 +94,7 @@
"//iree/base:intrusive_list",
"//iree/base:status",
"//iree/base:tracing",
- "//iree/hal:command_queue",
+ "//iree/hal",
"//iree/hal/host:condvar_semaphore",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:inlined_vector",
diff --git a/iree/hal/host/serial/CMakeLists.txt b/iree/hal/host/serial/CMakeLists.txt
index 1077854..86b7da5 100644
--- a/iree/hal/host/serial/CMakeLists.txt
+++ b/iree/hal/host/serial/CMakeLists.txt
@@ -26,9 +26,8 @@
absl::synchronization
iree::base::status
iree::base::tracing
- iree::hal::command_queue
+ iree::hal
iree::hal::host::serial::serial_submission_queue
- iree::hal::semaphore
PUBLIC
)
@@ -42,7 +41,7 @@
absl::memory
iree::base::status
iree::base::time
- iree::hal::command_queue
+ iree::hal
iree::hal::host::serial::serial_submission_queue
iree::hal::testing::mock_command_buffer
iree::hal::testing::mock_command_queue
@@ -61,7 +60,7 @@
absl::inlined_vector
iree::base::status
iree::base::tracing
- iree::hal::command_buffer
+ iree::hal
iree::hal::host::host_descriptor_set
iree::hal::host::host_executable
iree::hal::host::host_executable_layout
@@ -80,7 +79,7 @@
::serial_command_processor
::serial_submission_queue
absl::inlined_vector
- iree::base::memory
+ iree::base::core_headers
iree::base::status
iree::base::tracing
iree::hal::host::condvar_semaphore
@@ -104,7 +103,7 @@
iree::base::intrusive_list
iree::base::status
iree::base::tracing
- iree::hal::command_queue
+ iree::hal
iree::hal::host::condvar_semaphore
PUBLIC
)
diff --git a/iree/hal/llvmjit/BUILD b/iree/hal/llvmjit/BUILD
index 2e9d275..d7f5535 100644
--- a/iree/hal/llvmjit/BUILD
+++ b/iree/hal/llvmjit/BUILD
@@ -31,75 +31,33 @@
)
cc_library(
- name = "llvmjit_device",
- srcs = ["llvmjit_device.cc"],
- hdrs = ["llvmjit_device.h"],
- deps = [
- ":llvmjit_executable_cache",
- "//iree/base:tracing",
- "//iree/hal/host:host_local_device",
+ name = "llvmjit",
+ srcs = [
+ "llvmjit_device.cc",
+ "llvmjit_driver.cc",
+ "llvmjit_executable.cc",
+ "llvmjit_executable_cache.cc",
],
-)
-
-cc_library(
- name = "llvmjit_driver",
- srcs = ["llvmjit_driver.cc"],
- hdrs = ["llvmjit_driver.h"],
- deps = [
- ":llvmjit_device",
- "//iree/hal:device_info",
- "//iree/hal:driver",
- "//iree/hal/host/serial:serial_scheduling_model",
- "@llvm-project//llvm:ExecutionEngine",
+ hdrs = [
+ "llvmjit_device.h",
+ "llvmjit_driver.h",
+ "llvmjit_executable.h",
+ "llvmjit_executable_cache.h",
],
-)
-
-cc_library(
- name = "llvmjit_driver_module",
- srcs = ["llvmjit_driver_module.cc"],
deps = [
- ":llvmjit_driver",
- "//iree/base:init",
- "//iree/base:status",
- "//iree/hal:driver_registry",
- "@llvm-project//llvm:Support",
- #TODO(ataei): Link with native target dep.
- "@llvm-project//llvm:X86CodeGen",
- ],
- alwayslink = 1,
-)
-
-cc_library(
- name = "llvmjit_executable",
- srcs = ["llvmjit_executable.cc"],
- hdrs = ["llvmjit_executable.h"],
- deps = [
+ "//iree/base:flatcc",
"//iree/base:status",
"//iree/base:tracing",
- "//iree/hal:buffer",
- "//iree/hal:executable",
- "//iree/hal:executable_spec",
+ "//iree/hal",
"//iree/hal/host:host_executable",
- "//iree/schemas:llvmir_executable_def_cc_fbs",
- "@com_github_google_flatbuffers//:flatbuffers",
+ "//iree/hal/host:host_local_device",
+ "//iree/hal/host/serial:serial_scheduling_model",
+ "//iree/schemas:llvmir_executable_def_c_fbs",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:AsmParser",
"@llvm-project//llvm:Core",
+ "@llvm-project//llvm:ExecutionEngine",
"@llvm-project//llvm:OrcJIT",
"@llvm-project//llvm:Support",
],
)
-
-cc_library(
- name = "llvmjit_executable_cache",
- srcs = ["llvmjit_executable_cache.cc"],
- hdrs = ["llvmjit_executable_cache.h"],
- deps = [
- ":llvmjit_executable",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:executable",
- "//iree/hal:executable_cache",
- "//iree/hal:executable_format",
- ],
-)
diff --git a/iree/hal/llvmjit/CMakeLists.txt b/iree/hal/llvmjit/CMakeLists.txt
index 6faf869..ea303d1 100644
--- a/iree/hal/llvmjit/CMakeLists.txt
+++ b/iree/hal/llvmjit/CMakeLists.txt
@@ -20,87 +20,31 @@
iree_cc_library(
NAME
- llvmjit_device
+ llvmjit
HDRS
"llvmjit_device.h"
+ "llvmjit_driver.h"
+ "llvmjit_executable.h"
+ "llvmjit_executable_cache.h"
SRCS
"llvmjit_device.cc"
- DEPS
- ::llvmjit_executable_cache
- iree::base::tracing
- iree::hal::host::host_local_device
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- llvmjit_driver
- HDRS
- "llvmjit_driver.h"
- SRCS
"llvmjit_driver.cc"
- DEPS
- ::llvmjit_device
- LLVMExecutionEngine
- iree::hal::device_info
- iree::hal::driver
- iree::hal::host::serial::serial_scheduling_model
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- llvmjit_driver_module
- SRCS
- "llvmjit_driver_module.cc"
- DEPS
- ::llvmjit_driver
- LLVMSupport
- LLVMX86CodeGen
- iree::base::init
- iree::base::status
- iree::hal::driver_registry
- ALWAYSLINK
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- llvmjit_executable
- HDRS
- "llvmjit_executable.h"
- SRCS
"llvmjit_executable.cc"
+ "llvmjit_executable_cache.cc"
DEPS
LLVMAsmParser
LLVMCore
+ LLVMExecutionEngine
LLVMOrcJIT
LLVMSupport
absl::span
- flatbuffers
+ iree::base::flatcc
iree::base::status
iree::base::tracing
- iree::hal::buffer
- iree::hal::executable
- iree::hal::executable_spec
+ iree::hal
iree::hal::host::host_executable
- iree::schemas::llvmir_executable_def_cc_fbs
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- llvmjit_executable_cache
- HDRS
- "llvmjit_executable_cache.h"
- SRCS
- "llvmjit_executable_cache.cc"
- DEPS
- ::llvmjit_executable
- iree::base::status
- iree::base::tracing
- iree::hal::executable
- iree::hal::executable_cache
- iree::hal::executable_format
+ iree::hal::host::host_local_device
+ iree::hal::host::serial::serial_scheduling_model
+ iree::schemas::llvmir_executable_def_c_fbs
PUBLIC
)
diff --git a/iree/hal/llvmjit/llvmjit_executable.cc b/iree/hal/llvmjit/llvmjit_executable.cc
index 391a217..5dd1a82 100644
--- a/iree/hal/llvmjit/llvmjit_executable.cc
+++ b/iree/hal/llvmjit/llvmjit_executable.cc
@@ -17,11 +17,9 @@
#include <iostream>
#include <memory>
-#include "flatbuffers/flatbuffers.h"
#include "iree/base/tracing.h"
#include "iree/hal/buffer.h"
#include "iree/hal/executable.h"
-#include "iree/schemas/llvmir_executable_def_generated.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/AsmParser/Parser.h"
@@ -31,6 +29,60 @@
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SourceMgr.h"
+// flatcc schemas:
+#include "iree/base/flatcc.h"
+#include "iree/schemas/llvmir_executable_def_reader.h"
+#include "iree/schemas/llvmir_executable_def_verifier.h"
+
+// NOTE: starting to port this to C.
+
+// Verifies the structure of the flatbuffer so that we can avoid doing so during
+// runtime. Beyond the flatcc structural checks this rejects unnamed entry
+// points and a missing/empty bitcode_module; an empty entry point list is
+// accepted. No further bounds checks should be needed after this succeeds.
+static iree_status_t iree_hal_llvmir_executable_flatbuffer_verify(
+ iree_const_byte_span_t flatbuffer_data) {
+ if (!flatbuffer_data.data || flatbuffer_data.data_length < 16) {
+ return iree_make_status(
+ IREE_STATUS_INVALID_ARGUMENT,
+ "flatbuffer data is not present or less than 16 bytes (%zu total)",
+ flatbuffer_data.data_length);
+ }
+
+ // Run flatcc generated verification. This ensures all pointers are in-bounds
+ // and that we can safely walk the file, but not that the actual contents of
+ // the flatbuffer meet our expectations.
+ int verify_ret = iree_LLVMIRExecutableDef_verify_as_root(
+ flatbuffer_data.data, flatbuffer_data.data_length);
+ if (verify_ret != flatcc_verify_ok) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "flatbuffer verification failed: %s",
+ flatcc_verify_error_string(verify_ret));
+ }
+
+ iree_LLVMIRExecutableDef_table_t executable_def =
+ iree_LLVMIRExecutableDef_as_root(flatbuffer_data.data);
+
+ flatbuffers_string_vec_t entry_points_vec =
+ iree_LLVMIRExecutableDef_entry_points_get(executable_def);
+ size_t entry_point_count = flatbuffers_string_vec_len(entry_points_vec);
+ for (size_t i = 0; i < entry_point_count; ++i) {
+ if (!flatbuffers_string_len(
+ flatbuffers_string_vec_at(entry_points_vec, i))) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "executable entry point %zu has no name", i);
+ }
+ }
+
+ if (!flatbuffers_uint8_vec_len(
+ iree_LLVMIRExecutableDef_bitcode_module_get(executable_def))) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "executable bitcode_module is missing/empty");
+ }
+
+ return iree_ok_status();
+}
+
namespace iree {
namespace hal {
namespace llvmjit {
@@ -40,13 +92,20 @@
ExecutableSpec spec, bool allow_aliasing_data) {
IREE_TRACE_SCOPE0("LLVMJITExecutable::Load");
- auto module_def =
- ::flatbuffers::GetRoot<LLVMIRExecutableDef>(spec.executable_data.data());
- auto data =
- reinterpret_cast<const char*>(module_def->llvmir_module()->data());
- const int size = module_def->llvmir_module()->size();
+ // Verify and fetch the executable flatbuffer wrapper.
+ iree_const_byte_span_t executable_data = iree_make_const_byte_span(
+ spec.executable_data.data(), spec.executable_data.size());
+ IREE_RETURN_IF_ERROR(
+ iree_hal_llvmir_executable_flatbuffer_verify(executable_data));
+ iree_LLVMIRExecutableDef_table_t executable_def =
+ iree_LLVMIRExecutableDef_as_root(executable_data.data);
+
+ flatbuffers_uint8_vec_t bitcode_module_vec =
+ iree_LLVMIRExecutableDef_bitcode_module_get(executable_def);
auto mem_buffer = llvm::MemoryBuffer::getMemBufferCopy(
- llvm::StringRef(data, size), "llvm-ir");
+ llvm::StringRef(reinterpret_cast<const char*>(bitcode_module_vec),
+ flatbuffers_uint8_vec_len(bitcode_module_vec)),
+ "llvm-ir");
auto llvm_context = std::make_unique<llvm::LLVMContext>();
llvm::SMDiagnostic sm_diagnostic;
auto module = llvm::parseAssembly(*mem_buffer, sm_diagnostic, *llvm_context);
@@ -55,7 +114,6 @@
<< "Can't parse LLVMIR Module: " << sm_diagnostic.getMessage().str();
}
auto dataLayout = module->getDataLayout();
- const auto entry_points = module_def->entry_points();
llvm::orc::ThreadSafeModule thread_safe_module(std::move(module),
std::move(llvm_context));
auto ll_jit = llvm::cantFail(llvm::orc::LLJITBuilder().create());
@@ -67,29 +125,35 @@
<< llvm::toString(std::move(err));
}
- auto dylib_serarch_generator =
+ auto llvmjit_serarch_generator =
llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(
dataLayout.getGlobalPrefix());
- if (!dylib_serarch_generator) {
+ if (!llvmjit_serarch_generator) {
return UnavailableErrorBuilder(IREE_LOC)
<< "Can't resolve symbols in current process: "
- << llvm::toString(dylib_serarch_generator.takeError());
+ << llvm::toString(llvmjit_serarch_generator.takeError());
}
- auto& main_jitdylib = ll_jit->getMainJITDylib();
- main_jitdylib.addGenerator(std::move(dylib_serarch_generator.get()));
+ auto& main_jitllvmjit = ll_jit->getMainJITDylib();
+ main_jitllvmjit.addGenerator(std::move(llvmjit_serarch_generator.get()));
auto executable =
make_ref<LLVMJITExecutable>(spec, std::move(ll_jit), allow_aliasing_data);
- for (const auto func_name : *entry_points) {
- auto func_symbol = executable->ll_jit_->lookup(func_name->str());
+ flatbuffers_string_vec_t entry_points =
+ iree_LLVMIRExecutableDef_entry_points_get(executable_def);
+ executable->symbols_.resize(flatbuffers_string_vec_len(entry_points));
+ for (size_t i = 0; i < flatbuffers_string_vec_len(entry_points); ++i) {
+ flatbuffers_string_t entry_point =
+ flatbuffers_string_vec_at(entry_points, i);
+ auto func_symbol = executable->ll_jit_->lookup(
+ llvm::StringRef(entry_point, flatbuffers_string_len(entry_point)));
if (!func_symbol) {
return NotFoundErrorBuilder(IREE_LOC)
- << "Can't JIT compile function '" << func_name
+ << "Can't JIT compile function '" << entry_point
<< "': " << llvm::toString(func_symbol.takeError());
}
- executable->symbols_.push_back(func_symbol.get());
+ executable->symbols_[i] = func_symbol.get();
}
return executable;
diff --git a/iree/hal/llvmjit/llvmjit_executable.h b/iree/hal/llvmjit/llvmjit_executable.h
index f672381..6c5b3bc 100644
--- a/iree/hal/llvmjit/llvmjit_executable.h
+++ b/iree/hal/llvmjit/llvmjit_executable.h
@@ -20,7 +20,6 @@
#include "iree/base/status.h"
#include "iree/hal/executable_spec.h"
#include "iree/hal/host/host_executable.h"
-#include "iree/schemas/llvmir_executable_def_generated.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ExecutionEngine/Orc/LLJIT.h"
diff --git a/iree/hal/llvmjit/registration/BUILD b/iree/hal/llvmjit/registration/BUILD
new file mode 100644
index 0000000..b484138
--- /dev/null
+++ b/iree/hal/llvmjit/registration/BUILD
@@ -0,0 +1,54 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//iree:build_defs.oss.bzl", "iree_cmake_extra_content")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_cmake_extra_content(
+ content = """
+if(${IREE_HAL_DRIVER_LLVM})
+""",
+ inline = True,
+)
+
+cc_library(
+ name = "registration",
+ srcs = ["driver_module.cc"],
+ hdrs = ["driver_module.h"],
+ defines = [
+ "IREE_HAL_HAVE_LLVMJIT_DRIVER_MODULE=1",
+ ],
+ deps = [
+ "//iree/base:flags",
+ "//iree/base:status",
+ "//iree/hal:api",
+ "//iree/hal:driver_registry",
+ "//iree/hal/llvmjit",
+ "@llvm-project//llvm:Support",
+ # TODO(ataei): Link with native target dep.
+ "@llvm-project//llvm:X86CodeGen",
+ ],
+)
+
+iree_cmake_extra_content(
+ content = """
+endif()
+""",
+ inline = True,
+)
diff --git a/iree/hal/llvmjit/registration/CMakeLists.txt b/iree/hal/llvmjit/registration/CMakeLists.txt
new file mode 100644
index 0000000..af7d688
--- /dev/null
+++ b/iree/hal/llvmjit/registration/CMakeLists.txt
@@ -0,0 +1,39 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_add_all_subdirs()
+
+if(${IREE_HAL_DRIVER_LLVM})
+
+iree_cc_library(
+ NAME
+ registration
+ HDRS
+ "driver_module.h"
+ SRCS
+ "driver_module.cc"
+ DEPS
+ LLVMSupport
+ LLVMX86CodeGen
+ iree::base::flags
+ iree::base::status
+ iree::hal::api
+ iree::hal::driver_registry
+ iree::hal::llvmjit
+ DEFINES
+ "IREE_HAL_HAVE_LLVMJIT_DRIVER_MODULE=1"
+ PUBLIC
+)
+
+endif()
diff --git a/iree/hal/llvmjit/llvmjit_driver_module.cc b/iree/hal/llvmjit/registration/driver_module.cc
similarity index 76%
rename from iree/hal/llvmjit/llvmjit_driver_module.cc
rename to iree/hal/llvmjit/registration/driver_module.cc
index db1837f..544467c 100644
--- a/iree/hal/llvmjit/llvmjit_driver_module.cc
+++ b/iree/hal/llvmjit/registration/driver_module.cc
@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include <memory>
+#include "iree/hal/llvmjit/registration/driver_module.h"
-#include "iree/base/init.h"
+#include "iree/base/flags.h"
#include "iree/base/status.h"
#include "iree/hal/driver_registry.h"
#include "iree/hal/llvmjit/llvmjit_driver.h"
@@ -34,8 +34,8 @@
} // namespace hal
} // namespace iree
-IREE_REGISTER_MODULE_INITIALIZER(iree_hal_llvm_driver, {
- IREE_QCHECK_OK(::iree::hal::DriverRegistry::shared_registry()->Register(
- "llvm", ::iree::hal::llvmjit::CreateLLVMJITDriver));
-});
-IREE_REGISTER_MODULE_INITIALIZER_SEQUENCE(iree_hal, iree_hal_llvm_driver);
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_llvmjit_driver_module_register() {
+ return ::iree::hal::DriverRegistry::shared_registry()->Register(
+ "llvm", ::iree::hal::llvmjit::CreateLLVMJITDriver);
+}
diff --git a/iree/hal/llvmjit/registration/driver_module.h b/iree/hal/llvmjit/registration/driver_module.h
new file mode 100644
index 0000000..f7a7c85
--- /dev/null
+++ b/iree/hal/llvmjit/registration/driver_module.h
@@ -0,0 +1,31 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_LLVMJIT_REGISTRATION_DRIVER_MODULE_H_
+#define IREE_HAL_LLVMJIT_REGISTRATION_DRIVER_MODULE_H_
+
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_llvmjit_driver_module_register();
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_HAL_LLVMJIT_REGISTRATION_DRIVER_MODULE_H_
diff --git a/iree/hal/metal/CMakeLists.txt b/iree/hal/metal/CMakeLists.txt
index 7cdb1f5..b8b5f0e 100644
--- a/iree/hal/metal/CMakeLists.txt
+++ b/iree/hal/metal/CMakeLists.txt
@@ -16,215 +16,51 @@
return()
endif()
+iree_add_all_subdirs()
+
iree_cc_library(
NAME
- metal_capture_manager
+ metal
HDRS
+ "metal_buffer.h"
"metal_capture_manager.h"
+ "metal_command_buffer.h"
+ "metal_command_queue.h"
+ "metal_device.h"
+ "metal_direct_allocator.h"
+ "metal_driver.h"
+ "metal_kernel_library.h"
+ "metal_pipeline_argument_buffer.h"
+ "metal_pipeline_cache.h"
+ "metal_shared_event.h"
SRCS
+ "metal_buffer.mm"
"metal_capture_manager.mm"
+ "metal_command_buffer.mm"
+ "metal_command_queue.mm"
+ "metal_device.mm"
+ "metal_direct_allocator.mm"
+ "metal_driver.mm"
+ "metal_kernel_library.mm"
+ "metal_pipeline_argument_buffer.cc"
+ "metal_pipeline_cache.mm"
+ "metal_shared_event.mm"
DEPS
+ absl::flat_hash_map
+ absl::inlined_vector
+ absl::memory
+ absl::span
+ absl::strings
+ iree::base::flatcc
iree::base::file_io
iree::base::logging
iree::base::status
- iree::base::tracing
- iree::hal::debug_capture_manager
- LINKOPTS
- "-framework Metal"
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- metal_command_buffer
- HDRS
- "metal_command_buffer.h"
- SRCS
- "metal_command_buffer.mm"
- DEPS
- ::metal_kernel_library
- ::metal_pipeline_argument_buffer
- absl::flat_hash_map
- absl::inlined_vector
- iree::base::status
- iree::base::tracing
- iree::hal::command_buffer
- LINKOPTS
- "-framework Metal"
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- metal_command_queue
- HDRS
- "metal_command_queue.h"
- SRCS
- "metal_command_queue.mm"
- DEPS
- ::metal_command_buffer
- ::metal_shared_event
- iree::base::status
iree::base::time
iree::base::tracing
- iree::hal::command_queue
- LINKOPTS
- "-framework Metal"
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- metal_device
- HDRS
- "metal_device.h"
- SRCS
- "metal_device.mm"
- DEPS
- ::metal_capture_manager
- ::metal_command_buffer
- ::metal_command_queue
- ::metal_direct_allocator
- ::metal_pipeline_argument_buffer
- ::metal_pipeline_cache
- ::metal_shared_event
- absl::strings
- absl::span
- iree::base::status
- iree::base::time
- iree::base::tracing
- iree::hal::allocator
- iree::hal::command_queue
- iree::hal::device
- iree::hal::driver
- iree::hal::semaphore
- LINKOPTS
- "-framework Metal"
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- metal_direct_allocator
- HDRS
- "metal_buffer.h"
- "metal_direct_allocator.h"
- SRCS
- "metal_buffer.mm"
- "metal_direct_allocator.mm"
- DEPS
- absl::memory
- iree::base::logging
- iree::base::status
- iree::base::tracing
- iree::hal::allocator
- LINKOPTS
- "-framework Metal"
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- metal_driver
- HDRS
- "metal_driver.h"
- SRCS
- "metal_driver.mm"
- DEPS
- ::metal_capture_manager
- ::metal_device
- iree::base::status
- iree::base::tracing
- iree::hal::device_info
- iree::hal::driver
- LINKOPTS
- "-framework Metal"
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- metal_driver_module
- SRCS
- "metal_driver_module.cc"
- DEPS
- ::metal_driver
- absl::flags
- iree::base::init
- iree::base::status
- iree::hal::driver_registry
- ALWAYSLINK
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- metal_kernel_library
- HDRS
- "metal_kernel_library.h"
- SRCS
- "metal_kernel_library.mm"
- DEPS
- absl::inlined_vector
- iree::base::memory
- iree::base::status
- iree::base::tracing
- iree::hal::executable
- iree::hal::executable_spec
- iree::schemas::metal_executable_def_cc_fbs
+ iree::hal
+ iree::schemas::metal_executable_def_c_fbs
LINKOPTS
"-framework Foundation"
"-framework Metal"
PUBLIC
)
-
-iree_cc_library(
- NAME
- metal_pipeline_argument_buffer
- HDRS
- "metal_pipeline_argument_buffer.h"
- SRCS
- "metal_pipeline_argument_buffer.cc"
- DEPS
- absl::inlined_vector
- absl::span
- iree::hal::descriptor_set_layout
- iree::hal::executable_layout
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- metal_pipeline_cache
- HDRS
- "metal_pipeline_cache.h"
- SRCS
- "metal_pipeline_cache.mm"
- DEPS
- ::metal_kernel_library
- flatbuffers
- iree::base::status
- iree::base::tracing
- iree::hal::executable
- iree::hal::executable_cache
- iree::hal::executable_format
- iree::schemas::metal_executable_def_cc_fbs
- LINKOPTS
- "-framework Metal"
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- metal_shared_event
- HDRS
- "metal_shared_event.h"
- SRCS
- "metal_shared_event.mm"
- DEPS
- iree::base::tracing
- iree::hal::semaphore
- LINKOPTS
- "-framework Metal"
- PUBLIC
-)
diff --git a/iree/hal/metal/metal_command_buffer.mm b/iree/hal/metal/metal_command_buffer.mm
index 34ab86f..b40369f 100644
--- a/iree/hal/metal/metal_command_buffer.mm
+++ b/iree/hal/metal/metal_command_buffer.mm
@@ -368,11 +368,11 @@
}
IREE_DVLOG(2) << "Dispatch workgroup count: (" << workgroups[0] << ", " << workgroups[1] << ", "
- << workgroups[2] << "), workgroup size: (" << workgroup_size.x() << ", "
- << workgroup_size.y() << ", " << workgroup_size.z() << ")";
- [compute_encoder dispatchThreadgroups:MTLSizeMake(workgroups[0], workgroups[1], workgroups[2])
- threadsPerThreadgroup:MTLSizeMake(workgroup_size.x(), workgroup_size.y(),
- workgroup_size.z())];
+ << workgroups[2] << "), workgroup size: (" << workgroup_size.x << ", "
+ << workgroup_size.y << ", " << workgroup_size.z << ")";
+ [compute_encoder
+ dispatchThreadgroups:MTLSizeMake(workgroups[0], workgroups[1], workgroups[2])
+ threadsPerThreadgroup:MTLSizeMake(workgroup_size.x, workgroup_size.y, workgroup_size.z)];
return OkStatus();
}
diff --git a/iree/hal/metal/metal_kernel_library.h b/iree/hal/metal/metal_kernel_library.h
index 2fbf2f4..d68bc4e 100644
--- a/iree/hal/metal/metal_kernel_library.h
+++ b/iree/hal/metal/metal_kernel_library.h
@@ -24,7 +24,11 @@
#include "iree/hal/executable.h"
#include "iree/hal/executable_cache.h"
#include "iree/hal/executable_spec.h"
-#include "iree/schemas/metal_executable_def_generated.h"
+
+// flatcc schemas:
+#include "iree/base/flatcc.h"
+#include "iree/schemas/metal_executable_def_reader.h"
+#include "iree/schemas/metal_executable_def_verifier.h"
namespace iree {
namespace hal {
@@ -41,7 +45,7 @@
public:
static StatusOr<ref_ptr<MetalKernelLibrary>> Create(
id<MTLDevice> device, ExecutableCachingModeBitfield mode,
- const MetalExecutableDef& metal_executable_def);
+ const ExecutableSpec& spec);
~MetalKernelLibrary() override;
bool supports_debugging() const override { return false; }
@@ -50,7 +54,7 @@
StatusOr<id<MTLFunction>> GetKernelForEntryPoint(int ordinal) const;
// Returns the threadgroup size for the entry point with the given |ordinal|.
- StatusOr<MetalThreadgroupSize> GetThreadgroupSizeForEntryPoint(
+ StatusOr<iree_MetalThreadgroupSize_t> GetThreadgroupSizeForEntryPoint(
int ordinal) const;
// Returns the pipeline state object for the entry point with the given
@@ -61,7 +65,7 @@
private:
struct KernelObjects {
id<MTLFunction> function;
- MetalThreadgroupSize threadgroup_size;
+ iree_MetalThreadgroupSize_t threadgroup_size;
// Baked pipeline state object.
id<MTLComputePipelineState> pipeline_state;
};
@@ -69,17 +73,13 @@
// Creates a MetalKernelLibrary assuming all Metal objects are already
// retained before passing in.
MetalKernelLibrary(id<MTLDevice> device,
- absl::InlinedVector<id<MTLLibrary>, 1> libraries,
- absl::InlinedVector<KernelObjects, 1> kernel_objects,
- std::string tag);
-
- // Tag coming from Metal executable FlatBuffer.
- std::string tag_;
+ absl::InlinedVector<id<MTLLibrary>, 4> libraries,
+ absl::InlinedVector<KernelObjects, 4> kernel_objects);
id<MTLDevice> device_;
- absl::InlinedVector<id<MTLLibrary>, 1> libraries_;
- absl::InlinedVector<KernelObjects, 1> kernel_objects_;
+ absl::InlinedVector<id<MTLLibrary>, 4> libraries_;
+ absl::InlinedVector<KernelObjects, 4> kernel_objects_;
};
} // namespace metal
diff --git a/iree/hal/metal/metal_kernel_library.mm b/iree/hal/metal/metal_kernel_library.mm
index 5e60ebe..8ea7e6a 100644
--- a/iree/hal/metal/metal_kernel_library.mm
+++ b/iree/hal/metal/metal_kernel_library.mm
@@ -18,42 +18,99 @@
#include "iree/base/status.h"
#include "iree/base/tracing.h"
+// NOTE: starting to port this to ObjC.
+
+// Verifies the structure of the flatbuffer so that we can avoid doing so during
+// runtime. There are still some conditions we must be aware of (such as omitted
+// names on functions with internal linkage), however we shouldn't need to
+// bounds check anything within the flatbuffer after this succeeds.
+static iree_status_t iree_hal_metal_executable_flatbuffer_verify(
+ iree_const_byte_span_t flatbuffer_data) {
+ if (!flatbuffer_data.data || flatbuffer_data.data_length < 16) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "flatbuffer data is not present or less than 16 bytes (%zu total)",
+ flatbuffer_data.data_length);
+ }
+
+ // Run flatcc generated verification. This ensures all pointers are in-bounds
+ // and that we can safely walk the file, but not that the actual contents of
+ // the flatbuffer meet our expectations.
+ int verify_ret =
+ iree_MetalExecutableDef_verify_as_root(flatbuffer_data.data, flatbuffer_data.data_length);
+ if (verify_ret != flatcc_verify_ok) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "flatbuffer verification failed: %s",
+ flatcc_verify_error_string(verify_ret));
+ }
+
+ iree_MetalExecutableDef_table_t executable_def =
+ iree_MetalExecutableDef_as_root(flatbuffer_data.data);
+
+ flatbuffers_string_vec_t entry_points_vec =
+ iree_MetalExecutableDef_entry_points_get(executable_def);
+ size_t entry_point_count = flatbuffers_string_vec_len(entry_points_vec);
+ for (size_t i = 0; i < entry_point_count; ++i) {
+ if (!flatbuffers_string_len(flatbuffers_string_vec_at(entry_points_vec, i))) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "executable entry point %zu has no name", i);
+ }
+ }
+
+ iree_MetalThreadgroupSize_vec_t threadgroup_sizes_vec =
+ iree_MetalExecutableDef_threadgroup_sizes(executable_def);
+ size_t threadgroup_size_count = iree_MetalThreadgroupSize_vec_len(threadgroup_sizes_vec);
+ if (!threadgroup_size_count) {
+ return InvalidArgumentErrorBuilder(IREE_LOC) << "No threadgroup sizes present";
+ }
+
+ flatbuffers_string_vec_t shader_sources_vec =
+ iree_MetalExecutableDef_shader_sources_get(executable_def);
+ size_t shader_source_count = flatbuffers_string_vec_len(shader_sources_vec);
+ for (size_t i = 0; i < shader_source_count; ++i) {
+ if (!flatbuffers_string_len(flatbuffers_string_vec_at(shader_sources_vec, i))) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "executable shader source %zu is empty",
+ i);
+ }
+ }
+
+ if (entry_point_count != threadgroup_size_count || entry_point_count != shader_source_count) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "mismatch among the numbers of entry points (%zu), thread group sizes "
+ "(%zu), and source strings (%zu)",
+ entry_point_count, threadgroup_size_count, shader_source_count);
+ }
+
+ return iree_ok_status();
+}
+
namespace iree {
namespace hal {
namespace metal {
// static
-StatusOr<ref_ptr<MetalKernelLibrary>> MetalKernelLibrary::Create(
- id<MTLDevice> device, ExecutableCachingModeBitfield mode,
- const MetalExecutableDef& metal_executable_def) {
+StatusOr<ref_ptr<MetalKernelLibrary>> MetalKernelLibrary::Create(id<MTLDevice> device,
+ ExecutableCachingModeBitfield mode,
+ const ExecutableSpec& spec) {
IREE_TRACE_SCOPE0("MetalKernelLibrary::Create");
- if (!metal_executable_def.entry_points() || metal_executable_def.entry_points()->size() == 0) {
- return InvalidArgumentErrorBuilder(IREE_LOC) << "No entry points defined";
- }
- if (!metal_executable_def.threadgroup_sizes() ||
- metal_executable_def.threadgroup_sizes()->size() == 0) {
- return InvalidArgumentErrorBuilder(IREE_LOC) << "No threadgroup sizes present";
- }
- if (!metal_executable_def.shader_sources() ||
- metal_executable_def.shader_sources()->size() == 0) {
- return InvalidArgumentErrorBuilder(IREE_LOC) << "No MSL source string present";
- }
- const auto& entry_points = *metal_executable_def.entry_points();
- const auto& threadgroup_sizes = *metal_executable_def.threadgroup_sizes();
- const auto& msl_sources = *metal_executable_def.shader_sources();
+ // Verify and fetch the executable flatbuffer wrapper.
+ iree_const_byte_span_t executable_data =
+ iree_make_const_byte_span(spec.executable_data.data(), spec.executable_data.size());
+ IREE_RETURN_IF_ERROR(iree_hal_metal_executable_flatbuffer_verify(executable_data));
+ iree_MetalExecutableDef_table_t executable_def =
+ iree_MetalExecutableDef_as_root(executable_data.data);
- if (entry_points.size() != threadgroup_sizes.size() ||
- entry_points.size() != msl_sources.size()) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Mismatch among the numbers of entry points, thread group sizes, and source strings";
- }
+ flatbuffers_string_vec_t entry_points_vec =
+ iree_MetalExecutableDef_entry_points_get(executable_def);
+ iree_MetalThreadgroupSize_vec_t threadgroup_sizes_vec =
+ iree_MetalExecutableDef_threadgroup_sizes(executable_def);
+ flatbuffers_string_vec_t shader_sources_vec =
+ iree_MetalExecutableDef_shader_sources_get(executable_def);
// Compile each MSL source string into a MTLLibrary and get the MTLFunction for the entry point to
// build the pipeline state object.
- absl::InlinedVector<id<MTLLibrary>, 1> libraries;
- absl::InlinedVector<KernelObjects, 1> kernel_objects;
+ absl::InlinedVector<id<MTLLibrary>, 4> libraries;
+ absl::InlinedVector<KernelObjects, 4> kernel_objects;
MTLCompileOptions* msl_compile_options = [MTLCompileOptions new];
msl_compile_options.languageVersion = MTLLanguageVersion2_0;
@@ -71,12 +128,15 @@
// debugging purposes but bad for performance. Enable offline compilation and make that as the
// default.
- for (int i = 0; i < msl_sources.size(); ++i) {
+ for (size_t entry_ordinal = 0; entry_ordinal < flatbuffers_string_vec_len(shader_sources_vec);
+ ++entry_ordinal) {
+ flatbuffers_string_t entry_point = flatbuffers_string_vec_at(entry_points_vec, entry_ordinal);
@autoreleasepool {
NSError* error = nil;
- NSString* shader_source = [NSString stringWithCString:msl_sources[i]->c_str()
- encoding:[NSString defaultCStringEncoding]];
+ NSString* shader_source =
+ [NSString stringWithCString:flatbuffers_string_vec_at(shader_sources_vec, entry_ordinal)
+ encoding:[NSString defaultCStringEncoding]];
id<MTLLibrary> library = [device newLibraryWithSource:shader_source
options:msl_compile_options
error:&error];
@@ -89,17 +149,17 @@
}
libraries.push_back(library);
- NSString* entry_point = [NSString stringWithCString:entry_points[i]->c_str()
- encoding:[NSString defaultCStringEncoding]];
- id<MTLFunction> function = [library newFunctionWithName:entry_point];
+ id<MTLFunction> function = [library
+ newFunctionWithName:[NSString stringWithCString:entry_point
+ encoding:[NSString defaultCStringEncoding]]];
if (!function) {
NSLog(@"Failed to create MTLFunction");
#ifndef NDEBUG
NSLog(@"Original MSL source: %@", shader_source);
#endif
return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Cannot find entry point '" << entry_points[i] << "' in shader source index "
- << i;
+ << "Cannot find entry point '" << entry_point << "' in shader source index "
+ << entry_ordinal;
}
id<MTLComputePipelineState> pso = [device newComputePipelineStateWithFunction:function
@@ -112,21 +172,19 @@
return InvalidArgumentErrorBuilder(IREE_LOC) << "Invalid MSL source";
}
- kernel_objects.push_back(KernelObjects{function, *threadgroup_sizes[i], pso});
+ kernel_objects.push_back(
+ KernelObjects{function, {static_cast<uint32_t>(iree_MetalThreadgroupSize__size())}, pso});
}
}
- std::string tag = metal_executable_def.tag() ? metal_executable_def.tag()->str() : "";
- return assign_ref(new MetalKernelLibrary([device retain], std::move(libraries),
- std::move(kernel_objects), std::move(tag)));
+ return assign_ref(
+ new MetalKernelLibrary([device retain], std::move(libraries), std::move(kernel_objects)));
}
MetalKernelLibrary::MetalKernelLibrary(id<MTLDevice> device,
- absl::InlinedVector<id<MTLLibrary>, 1> libraries,
- absl::InlinedVector<KernelObjects, 1> kernel_objects,
- std::string tag)
- : tag_(std::move(tag)),
- device_(device),
+ absl::InlinedVector<id<MTLLibrary>, 4> libraries,
+ absl::InlinedVector<KernelObjects, 4> kernel_objects)
+ : device_(device),
libraries_(std::move(libraries)),
kernel_objects_(std::move(kernel_objects)) {}
@@ -146,7 +204,7 @@
return kernel_objects_[ordinal].function;
}
-StatusOr<MetalThreadgroupSize> MetalKernelLibrary::GetThreadgroupSizeForEntryPoint(
+StatusOr<iree_MetalThreadgroupSize_t> MetalKernelLibrary::GetThreadgroupSizeForEntryPoint(
int ordinal) const {
if (ordinal < 0 || ordinal >= kernel_objects_.size()) {
return OutOfRangeErrorBuilder(IREE_LOC) << "Invalid entry point ordinal: " << ordinal;
diff --git a/iree/hal/metal/metal_pipeline_cache.mm b/iree/hal/metal/metal_pipeline_cache.mm
index 2c44b0b..1e98770 100644
--- a/iree/hal/metal/metal_pipeline_cache.mm
+++ b/iree/hal/metal/metal_pipeline_cache.mm
@@ -14,12 +14,10 @@
#include "iree/hal/metal/metal_pipeline_cache.h"
-#include "flatbuffers/flatbuffers.h"
#include "iree/base/status.h"
#include "iree/base/tracing.h"
#include "iree/hal/executable_format.h"
#include "iree/hal/metal/metal_kernel_library.h"
-#include "iree/schemas/metal_executable_def_generated.h"
namespace iree {
namespace hal {
@@ -37,19 +35,9 @@
ExecutableLayout* executable_layout, ExecutableCachingModeBitfield mode,
const ExecutableSpec& spec) {
IREE_TRACE_SCOPE0("MetalPipelineCache::PrepareExecutable");
- if (spec.executable_data.size() <= 4 ||
- !MetalExecutableDefBufferHasIdentifier(spec.executable_data.data())) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Supplied executable data does not contain a MetalExecutableDef";
- }
-
- // Get the Metal executable def flatbuffer.
- const auto& metal_executable_def =
- *::flatbuffers::GetRoot<MetalExecutableDef>(spec.executable_data.data());
// Create the Metal library (which may itself own many pipeline states).
- IREE_ASSIGN_OR_RETURN(auto executable,
- MetalKernelLibrary::Create(metal_device_, mode, metal_executable_def));
+ IREE_ASSIGN_OR_RETURN(auto executable, MetalKernelLibrary::Create(metal_device_, mode, spec));
return executable;
}
diff --git a/iree/hal/metal/registration/CMakeLists.txt b/iree/hal/metal/registration/CMakeLists.txt
new file mode 100644
index 0000000..9a5e57c
--- /dev/null
+++ b/iree/hal/metal/registration/CMakeLists.txt
@@ -0,0 +1,38 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_add_all_subdirs()
+
+if(${IREE_HAL_DRIVER_METAL})
+
+iree_cc_library(
+ NAME
+ registration
+ HDRS
+ "driver_module.h"
+ SRCS
+ "driver_module.cc"
+ DEPS
+ absl::flags
+ iree::base::flags
+ iree::base::status
+ iree::hal::api
+ iree::hal::driver_registry
+ iree::hal::metal
+ DEFINES
+ "IREE_HAL_HAVE_METAL_DRIVER_MODULE=1"
+ PUBLIC
+)
+
+endif()
diff --git a/iree/hal/metal/metal_driver_module.cc b/iree/hal/metal/registration/driver_module.cc
similarity index 78%
rename from iree/hal/metal/metal_driver_module.cc
rename to iree/hal/metal/registration/driver_module.cc
index 48308ae..2b445f3 100644
--- a/iree/hal/metal/metal_driver_module.cc
+++ b/iree/hal/metal/registration/driver_module.cc
@@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include <memory>
+#include "iree/hal/metal/registration/driver_module.h"
#include "absl/flags/flag.h"
-#include "iree/base/init.h"
+#include "iree/base/flags.h"
#include "iree/base/status.h"
#include "iree/hal/driver_registry.h"
#include "iree/hal/metal/metal_driver.h"
@@ -28,7 +28,6 @@
namespace iree {
namespace hal {
namespace metal {
-namespace {
StatusOr<ref_ptr<Driver>> CreateMetalDriver() {
MetalDriverOptions options;
@@ -37,13 +36,12 @@
return MetalDriver::Create(options);
}
-} // namespace
} // namespace metal
} // namespace hal
} // namespace iree
-IREE_REGISTER_MODULE_INITIALIZER(iree_hal_metal_driver, {
- IREE_QCHECK_OK(::iree::hal::DriverRegistry::shared_registry()->Register(
- "metal", ::iree::hal::metal::CreateMetalDriver));
-});
-IREE_REGISTER_MODULE_INITIALIZER_SEQUENCE(iree_hal, iree_hal_metal_driver);
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_metal_driver_module_register() {
+ return ::iree::hal::DriverRegistry::shared_registry()->Register(
+ "metal", ::iree::hal::metal::CreateMetalDriver);
+}
diff --git a/iree/hal/metal/registration/driver_module.h b/iree/hal/metal/registration/driver_module.h
new file mode 100644
index 0000000..1b5c9c0
--- /dev/null
+++ b/iree/hal/metal/registration/driver_module.h
@@ -0,0 +1,31 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_METAL_REGISTRATION_DRIVER_MODULE_H_
+#define IREE_HAL_METAL_REGISTRATION_DRIVER_MODULE_H_
+
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_metal_driver_module_register();
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_HAL_METAL_REGISTRATION_DRIVER_MODULE_H_
diff --git a/iree/hal/testing/BUILD b/iree/hal/testing/BUILD
index d69e68c..e66768f 100644
--- a/iree/hal/testing/BUILD
+++ b/iree/hal/testing/BUILD
@@ -25,7 +25,7 @@
testonly = True,
hdrs = ["mock_allocator.h"],
deps = [
- "//iree/hal:allocator",
+ "//iree/hal",
"//iree/testing:gtest",
],
)
@@ -35,7 +35,7 @@
testonly = True,
hdrs = ["mock_command_buffer.h"],
deps = [
- "//iree/hal:command_buffer",
+ "//iree/hal",
"//iree/testing:gtest",
],
)
@@ -45,7 +45,7 @@
testonly = True,
hdrs = ["mock_command_queue.h"],
deps = [
- "//iree/hal:command_queue",
+ "//iree/hal",
"//iree/testing:gtest",
],
)
diff --git a/iree/hal/testing/CMakeLists.txt b/iree/hal/testing/CMakeLists.txt
index e925403..53fa5a8 100644
--- a/iree/hal/testing/CMakeLists.txt
+++ b/iree/hal/testing/CMakeLists.txt
@@ -20,7 +20,7 @@
HDRS
"mock_allocator.h"
DEPS
- iree::hal::allocator
+ iree::hal
iree::testing::gtest
TESTONLY
PUBLIC
@@ -32,7 +32,7 @@
HDRS
"mock_command_buffer.h"
DEPS
- iree::hal::command_buffer
+ iree::hal
iree::testing::gtest
TESTONLY
PUBLIC
@@ -44,7 +44,7 @@
HDRS
"mock_command_queue.h"
DEPS
- iree::hal::command_queue
+ iree::hal
iree::testing::gtest
TESTONLY
PUBLIC
diff --git a/iree/hal/vmla/BUILD b/iree/hal/vmla/BUILD
index d029bbf..cc7204d 100644
--- a/iree/hal/vmla/BUILD
+++ b/iree/hal/vmla/BUILD
@@ -57,7 +57,7 @@
srcs = ["op_kernels_test.cc"],
deps = [
":op_kernels",
- "//iree/base:memory",
+ "//iree/base:core_headers",
"//iree/testing:gtest",
"//iree/testing:gtest_main",
"@com_google_absl//absl/container:inlined_vector",
@@ -65,108 +65,55 @@
)
cc_library(
- name = "vmla_cache",
- srcs = ["vmla_cache.cc"],
- hdrs = ["vmla_cache.h"],
- deps = [
- ":vmla_executable",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:allocator",
- "//iree/hal:executable",
- "//iree/hal:executable_cache",
- "//iree/hal:executable_format",
- "//iree/vm:instance",
- "//iree/vm:module",
- ],
-)
-
-cc_library(
- name = "vmla_device",
- srcs = ["vmla_device.cc"],
- hdrs = ["vmla_device.h"],
- deps = [
- ":vmla_cache",
- "//iree/base:memory",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:command_queue",
- "//iree/hal:device",
- "//iree/hal/host:host_local_device",
- "//iree/vm:instance",
- "//iree/vm:module",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/memory",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:span",
- ],
-)
-
-cc_library(
- name = "vmla_driver",
- srcs = ["vmla_driver.cc"],
- hdrs = ["vmla_driver.h"],
- deps = [
- ":vmla_device",
- ":vmla_module",
- "//iree/base:tracing",
- "//iree/hal:device_info",
- "//iree/hal:driver",
- "//iree/hal/host/serial:serial_scheduling_model",
- "//iree/vm:instance",
- "//iree/vm:module",
- ],
-)
-
-cc_library(
- name = "vmla_driver_module",
- srcs = ["vmla_driver_module.cc"],
- deps = [
- ":vmla_driver",
- "//iree/base:init",
- "//iree/base:status",
- "//iree/hal:driver_registry",
- ],
- alwayslink = 1,
-)
-
-cc_library(
- name = "vmla_executable",
- srcs = ["vmla_executable.cc"],
- hdrs = ["vmla_executable.h"],
- deps = [
- ":vmla_module",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:executable",
- "//iree/hal:executable_spec",
- "//iree/hal/host:host_buffer",
- "//iree/hal/host:host_executable",
- "//iree/schemas:vmla_executable_def_cc_fbs",
- "//iree/vm:bytecode_module",
- "//iree/vm:context",
- "//iree/vm:instance",
- "//iree/vm:invocation",
- "//iree/vm:list",
- "//iree/vm:module",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/types:span",
- ],
-)
-
-cc_library(
- name = "vmla_module",
- srcs = ["vmla_module.cc"],
- hdrs = ["vmla_module.h"],
+ name = "op_module",
+ srcs = ["op_module.cc"],
+ hdrs = ["op_module.h"],
deps = [
":op_kernels",
"//iree/base:api",
- "//iree/base:memory",
+ "//iree/base:core_headers",
"//iree/base:ref_ptr",
"//iree/base:status",
"//iree/base:tracing",
"//iree/vm",
- "//iree/vm:native_module_cc",
+ "//iree/vm:cc",
+ "@com_google_absl//absl/types:span",
+ ],
+)
+
+cc_library(
+ name = "vmla",
+ srcs = [
+ "vmla_cache.cc",
+ "vmla_device.cc",
+ "vmla_driver.cc",
+ "vmla_executable.cc",
+ ],
+ hdrs = [
+ "vmla_cache.h",
+ "vmla_device.h",
+ "vmla_driver.h",
+ "vmla_executable.h",
+ ],
+ deps = [
+ ":op_module",
+ "//iree/base:api",
+ "//iree/base:core_headers",
+ "//iree/base:flatcc",
+ "//iree/base:ref_ptr",
+ "//iree/base:status",
+ "//iree/base:tracing",
+ "//iree/hal",
+ "//iree/hal/host:host_buffer",
+ "//iree/hal/host:host_executable",
+ "//iree/hal/host:host_local_device",
+ "//iree/hal/host/serial:serial_scheduling_model",
+ "//iree/schemas:vmla_executable_def_c_fbs",
+ "//iree/vm",
+ "//iree/vm:bytecode_module",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)
diff --git a/iree/hal/vmla/CMakeLists.txt b/iree/hal/vmla/CMakeLists.txt
index eb22436..b25a0dc 100644
--- a/iree/hal/vmla/CMakeLists.txt
+++ b/iree/hal/vmla/CMakeLists.txt
@@ -47,131 +47,63 @@
DEPS
::op_kernels
absl::inlined_vector
- iree::base::memory
+ iree::base::core_headers
iree::testing::gtest
iree::testing::gtest_main
)
iree_cc_library(
NAME
- vmla_cache
+ op_module
HDRS
- "vmla_cache.h"
+ "op_module.h"
SRCS
- "vmla_cache.cc"
- DEPS
- ::vmla_executable
- iree::base::status
- iree::base::tracing
- iree::hal::allocator
- iree::hal::executable
- iree::hal::executable_cache
- iree::hal::executable_format
- iree::vm::instance
- iree::vm::module
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- vmla_device
- HDRS
- "vmla_device.h"
- SRCS
- "vmla_device.cc"
- DEPS
- ::vmla_cache
- absl::inlined_vector
- absl::memory
- absl::span
- absl::strings
- iree::base::memory
- iree::base::status
- iree::base::tracing
- iree::hal::command_queue
- iree::hal::device
- iree::hal::host::host_local_device
- iree::vm::instance
- iree::vm::module
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- vmla_driver
- HDRS
- "vmla_driver.h"
- SRCS
- "vmla_driver.cc"
- DEPS
- ::vmla_device
- ::vmla_module
- iree::base::tracing
- iree::hal::device_info
- iree::hal::driver
- iree::hal::host::serial::serial_scheduling_model
- iree::vm::instance
- iree::vm::module
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- vmla_driver_module
- SRCS
- "vmla_driver_module.cc"
- DEPS
- ::vmla_driver
- iree::base::init
- iree::base::status
- iree::hal::driver_registry
- ALWAYSLINK
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- vmla_executable
- HDRS
- "vmla_executable.h"
- SRCS
- "vmla_executable.cc"
- DEPS
- ::vmla_module
- absl::inlined_vector
- absl::span
- iree::base::status
- iree::base::tracing
- iree::hal::executable
- iree::hal::executable_spec
- iree::hal::host::host_buffer
- iree::hal::host::host_executable
- iree::schemas::vmla_executable_def_cc_fbs
- iree::vm::bytecode_module
- iree::vm::context
- iree::vm::instance
- iree::vm::invocation
- iree::vm::list
- iree::vm::module
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- vmla_module
- HDRS
- "vmla_module.h"
- SRCS
- "vmla_module.cc"
+ "op_module.cc"
DEPS
::op_kernels
absl::span
iree::base::api
- iree::base::memory
+ iree::base::core_headers
iree::base::ref_ptr
iree::base::status
iree::base::tracing
iree::vm
- iree::vm::native_module_cc
+ iree::vm::cc
+ PUBLIC
+)
+
+iree_cc_library(
+ NAME
+ vmla
+ HDRS
+ "vmla_cache.h"
+ "vmla_device.h"
+ "vmla_driver.h"
+ "vmla_executable.h"
+ SRCS
+ "vmla_cache.cc"
+ "vmla_device.cc"
+ "vmla_driver.cc"
+ "vmla_executable.cc"
+ DEPS
+ ::op_module
+ absl::inlined_vector
+ absl::memory
+ absl::span
+ absl::strings
+ iree::base::api
+ iree::base::core_headers
+ iree::base::flatcc
+ iree::base::ref_ptr
+ iree::base::status
+ iree::base::tracing
+ iree::hal
+ iree::hal::host::host_buffer
+ iree::hal::host::host_executable
+ iree::hal::host::host_local_device
+ iree::hal::host::serial::serial_scheduling_model
+ iree::schemas::vmla_executable_def_c_fbs
+ iree::vm
+ iree::vm::bytecode_module
PUBLIC
)
diff --git a/iree/hal/vmla/op_kernels.h b/iree/hal/vmla/op_kernels.h
index b2b9cda..a7a5cae 100644
--- a/iree/hal/vmla/op_kernels.h
+++ b/iree/hal/vmla/op_kernels.h
@@ -455,6 +455,22 @@
ShapeSpan src_shape, ShapeSpan dst_shape);
};
+struct ReduceAnd {
+ template <typename T>
+ static Status Execute(absl::Span<const T> src_buffer,
+ absl::Span<const T> init_buffer,
+ absl::Span<T> dst_buffer, int32_t dimension,
+ ShapeSpan src_shape, ShapeSpan dst_shape);
+};
+
+struct ReduceOr {
+ template <typename T>
+ static Status Execute(absl::Span<const T> src_buffer,
+ absl::Span<const T> init_buffer,
+ absl::Span<T> dst_buffer, int32_t dimension,
+ ShapeSpan src_shape, ShapeSpan dst_shape);
+};
+
struct PoolingSum {
template <typename T>
static Status Execute(absl::Span<const T> src_buffer,
diff --git a/iree/hal/vmla/op_kernels_generic.h b/iree/hal/vmla/op_kernels_generic.h
index bcee0df..f3ec435 100644
--- a/iree/hal/vmla/op_kernels_generic.h
+++ b/iree/hal/vmla/op_kernels_generic.h
@@ -96,7 +96,7 @@
size_t element_size) {
absl::InlinedVector<size_t, 6> strides(shape.size());
strides.back() = element_size;
- for (int i = shape.size() - 2; i >= 0; --i) {
+ for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i) {
strides[i] = strides[i + 1] * shape[i + 1];
}
return strides;
@@ -110,7 +110,7 @@
absl::Span<const int32_t> dst_indices,
absl::Span<const int32_t> lengths) {
if (lengths.size() > 1) {
- for (int i = 0; i < lengths[0]; ++i) {
+ for (int32_t i = 0; i < lengths[0]; ++i) {
size_t src_offset = src_strides[0] * (src_indices[0] + i);
size_t dst_offset = dst_strides[0] * (dst_indices[0] + i);
CopyRegion(src_buffer.subspan(src_offset), src_strides.subspan(1),
@@ -904,6 +904,20 @@
}
};
+struct AndKernel {
+ template <typename T>
+ inline void operator()(T* value0, const T value1) {
+ *value0 = *value0 && value1;
+ }
+};
+
+struct OrKernel {
+ template <typename T>
+ inline void operator()(T* value0, const T value1) {
+ *value0 = *value0 || value1;
+ }
+};
+
template <typename T, typename KernelImpl>
inline void ReduceDimension(absl::Span<const T> src_buffer,
absl::Span<T> dst_buffer, ShapeSpan src_shape,
@@ -1028,6 +1042,24 @@
src_buffer, init_buffer, dst_buffer, dimension, src_shape, dst_shape);
}
+template <typename T>
+Status ReduceAnd::Execute(absl::Span<const T> src_buffer,
+ absl::Span<const T> init_buffer,
+ absl::Span<T> dst_buffer, int32_t dimension,
+ ShapeSpan src_shape, ShapeSpan dst_shape) {
+ return impl::GenericReduce<T, impl::AndKernel>(
+ src_buffer, init_buffer, dst_buffer, dimension, src_shape, dst_shape);
+}
+
+template <typename T>
+Status ReduceOr::Execute(absl::Span<const T> src_buffer,
+ absl::Span<const T> init_buffer,
+ absl::Span<T> dst_buffer, int32_t dimension,
+ ShapeSpan src_shape, ShapeSpan dst_shape) {
+ return impl::GenericReduce<T, impl::OrKernel>(
+ src_buffer, init_buffer, dst_buffer, dimension, src_shape, dst_shape);
+}
+
namespace impl {
template <typename T, typename KernelImpl>
@@ -1035,12 +1067,12 @@
absl::Span<const int> src_indices,
ShapeSpan src_shape, T init_value,
ShapeSpan window_dimensions, T* dst_value) {
- int rank = src_shape.size();
+ size_t rank = src_shape.size();
absl::InlinedVector<int, 8> window_indices(rank, 0);
auto getSrcValue = [&]() -> T {
- int flat_idx = 0;
- for (int i = 0; i < rank; ++i) {
- int idx = src_indices[i] + window_indices[i];
+ size_t flat_idx = 0;
+ for (size_t i = 0; i < rank; ++i) {
+ size_t idx = src_indices[i] + window_indices[i];
if (idx < 0 || idx >= src_shape[i]) return init_value;
flat_idx = flat_idx * src_shape[i] + idx;
}
@@ -1048,7 +1080,7 @@
};
*dst_value = init_value;
- for (int i = 0, e = GetElementCount(window_dimensions); i < e; ++i) {
+ for (size_t i = 0, e = GetElementCount(window_dimensions); i < e; ++i) {
KernelImpl()(dst_value, getSrcValue());
IncrementShapeIndex(absl::MakeSpan(window_indices), window_dimensions);
}
@@ -1060,11 +1092,11 @@
ShapeSpan src_shape, ShapeSpan dst_shape,
ShapeSpan window_dimensions, ShapeSpan strides,
ShapeSpan pad_low) {
- int rank = src_shape.size();
+ size_t rank = src_shape.size();
absl::InlinedVector<int, 8> src_indices(rank, 0);
absl::InlinedVector<int, 8> dst_indices(rank, 0);
- for (int i = 0, e = GetElementCount(dst_shape); i < e; ++i) {
- for (int j = 0; j < rank; ++j) {
+ for (size_t i = 0, e = GetElementCount(dst_shape); i < e; ++i) {
+ for (size_t j = 0; j < rank; ++j) {
src_indices[j] = dst_indices[j] * strides[j] - pad_low[j];
}
ComputePoolingWindow<T, KernelImpl>(src_buffer, src_indices, src_shape,
diff --git a/iree/hal/vmla/op_kernels_test.cc b/iree/hal/vmla/op_kernels_test.cc
index 877457b..feb1dee 100644
--- a/iree/hal/vmla/op_kernels_test.cc
+++ b/iree/hal/vmla/op_kernels_test.cc
@@ -39,7 +39,7 @@
}
template <typename T>
-std::vector<T> MakeIota(int size) {
+std::vector<T> MakeIota(size_t size) {
std::vector<T> v(size);
std::iota(v.begin(), v.end(), static_cast<T>(1));
return v;
@@ -390,7 +390,7 @@
absl::MakeSpan(dst_buffer),
dimension, src_shape, dst_shape));
- for (int i = 0; i < dst_buffer.size(); ++i) {
+ for (size_t i = 0; i < dst_buffer.size(); ++i) {
EXPECT_NEAR(expected_dst[i], dst_buffer[i], kEpsilon);
}
}
@@ -409,7 +409,7 @@
absl::MakeSpan(dst_buffer),
dimension, src_shape, dst_shape));
- for (int i = 0; i < dst_buffer.size(); ++i) {
+ for (size_t i = 0; i < dst_buffer.size(); ++i) {
EXPECT_NEAR(expected_dst[i], dst_buffer[i], kEpsilon);
}
}
@@ -421,8 +421,8 @@
Shape strides = {1, 2, 3, 1};
Shape pad_low = {0, 0, 0, 0};
std::vector<int> src_buffer = MakeIota<int>(GetShapeElementCount(src_shape));
- std::vector<int> init_buffer(1, 0.0f);
- std::vector<int> dst_buffer(GetShapeElementCount(dst_shape), 0.0f);
+ std::vector<int> init_buffer(1, 0);
+ std::vector<int> dst_buffer(GetShapeElementCount(dst_shape), 0);
std::vector<int> expected_dst = {9, 12, 21, 24};
IREE_EXPECT_OK(PoolingMax::Execute<int>(
@@ -442,8 +442,8 @@
Shape strides = {1, 1};
Shape pad_low = {1, 1};
std::vector<int> src_buffer = MakeIota<int>(GetShapeElementCount(src_shape));
- std::vector<int> init_buffer(1, 100.0);
- std::vector<int> dst_buffer(GetShapeElementCount(dst_shape), 0.0f);
+ std::vector<int> init_buffer(1, 100);
+ std::vector<int> dst_buffer(GetShapeElementCount(dst_shape), 0);
std::vector<int> expected_dst = {1, 1, 2, 1, 1, 2};
IREE_EXPECT_OK(PoolingMin::Execute<int>(
@@ -467,7 +467,7 @@
IREE_EXPECT_OK(PoolingSum::Execute<float>(
src_buffer, init_buffer, absl::MakeSpan(dst_buffer), src_shape, dst_shape,
window_sizes, strides, pad_low));
- for (int i = 0; i < dst_buffer.size(); ++i) {
+ for (size_t i = 0; i < dst_buffer.size(); ++i) {
EXPECT_NEAR(expected_dst[i], dst_buffer[i], kEpsilon);
}
}
@@ -485,10 +485,10 @@
std::vector<float> filter_buffer(GetShapeElementCount(filter_shape));
std::vector<float> expected_dst = {1310, 1466, 1622, 1778,
2090, 2246, 2402, 2558};
- for (int i = 0; i < GetShapeElementCount(input_shape); ++i) {
- input_buffer[i] = i + 1;
+ for (size_t i = 0; i < GetShapeElementCount(input_shape); ++i) {
+ input_buffer[i] = static_cast<float>(i + 1);
if (i < GetShapeElementCount(filter_shape)) {
- filter_buffer[i] = i + 1;
+ filter_buffer[i] = static_cast<float>(i + 1);
}
}
std::vector<float> dst_buffer(GetShapeElementCount(dst_shape), 0.0f);
@@ -498,7 +498,7 @@
absl::MakeSpan(dst_buffer), dst_shape, strides, pad_h, pad_w,
lhs_dilation, rhs_dilation, 1));
- for (int i = 0; i < dst_buffer.size(); ++i) {
+ for (size_t i = 0; i < dst_buffer.size(); ++i) {
EXPECT_NEAR(expected_dst[i], dst_buffer[i], kEpsilon);
}
}
@@ -518,7 +518,7 @@
1124, 1196, 1346, 1424, 1256, 1340, 1502, 1592, 1388, 1484, 1658,
1760, 1520, 1628, 1814, 1928, 1784, 1916, 2126, 2264, 1916, 2060,
2282, 2432, 2048, 2204, 2438, 2600, 2180, 2348, 2594, 2768};
- for (int i = 0; i < GetShapeElementCount(input_shape); ++i) {
+ for (size_t i = 0; i < GetShapeElementCount(input_shape); ++i) {
input_buffer[i] = i + 1;
if (i < GetShapeElementCount(filter_shape)) {
filter_buffer[i] = i + 1;
@@ -531,7 +531,7 @@
absl::MakeSpan(dst_buffer), dst_shape, strides, pad_h, pad_w,
lhs_dilation, rhs_dilation, 2));
- for (int i = 0; i < dst_buffer.size(); ++i) {
+ for (size_t i = 0; i < dst_buffer.size(); ++i) {
EXPECT_NEAR(expected_dst[i], dst_buffer[i], kEpsilon);
}
}
diff --git a/iree/hal/vmla/vmla_module.cc b/iree/hal/vmla/op_module.cc
similarity index 98%
rename from iree/hal/vmla/vmla_module.cc
rename to iree/hal/vmla/op_module.cc
index 571ed63..ff7ead3 100644
--- a/iree/hal/vmla/vmla_module.cc
+++ b/iree/hal/vmla/op_module.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "iree/hal/vmla/vmla_module.h"
+#include "iree/hal/vmla/op_module.h"
#include <cstdint>
@@ -129,8 +129,8 @@
constexpr int Interface::kMaxBindings;
void Interface::Reset() {
- for (int i = 0; i < bindings_.size(); ++i) {
- for (int j = 0; j < bindings_[i].size(); ++j) {
+ for (size_t i = 0; i < bindings_.size(); ++i) {
+ for (size_t j = 0; j < bindings_[i].size(); ++j) {
bindings_[i][j] = {};
}
}
@@ -150,7 +150,7 @@
<< "Constant value overflow; have " << values.size()
<< " but max is " << kMaxConstants;
}
- for (int i = 0; i < values.size(); ++i) {
+ for (size_t i = 0; i < values.size(); ++i) {
constants_[i] = values[i];
}
return OkStatus();
@@ -814,6 +814,8 @@
IREE_VMLA_REDUCTION_OP(ReduceMaxI16, kernels::ReduceMax, int16_t);
IREE_VMLA_REDUCTION_OP(ReduceMaxI32, kernels::ReduceMax, int32_t);
IREE_VMLA_REDUCTION_OP(ReduceMaxF32, kernels::ReduceMax, float);
+ IREE_VMLA_REDUCTION_OP(ReduceAndI8, kernels::ReduceAnd, int8_t);
+ IREE_VMLA_REDUCTION_OP(ReduceOrI8, kernels::ReduceOr, int8_t);
#define IREE_VMLA_POOLING_OP(name, kernel, type) \
Status name(const vm::ref<Buffer>& src, iree_vmla_shape_t src_shape, \
@@ -1028,6 +1030,8 @@
vm::MakeNativeFunction("reduce.max.i16", &VMLAModuleState::ReduceMaxI16),
vm::MakeNativeFunction("reduce.max.i32", &VMLAModuleState::ReduceMaxI32),
vm::MakeNativeFunction("reduce.max.f32", &VMLAModuleState::ReduceMaxF32),
+ vm::MakeNativeFunction("reduce.and.i8", &VMLAModuleState::ReduceAndI8),
+ vm::MakeNativeFunction("reduce.or.i8", &VMLAModuleState::ReduceOrI8),
vm::MakeNativeFunction("pooling.sum.i8", &VMLAModuleState::PoolingSumI8),
vm::MakeNativeFunction("pooling.sum.i16", &VMLAModuleState::PoolingSumI16),
diff --git a/iree/hal/vmla/vmla_module.h b/iree/hal/vmla/op_module.h
similarity index 97%
rename from iree/hal/vmla/vmla_module.h
rename to iree/hal/vmla/op_module.h
index 9bb3860..f94a671 100644
--- a/iree/hal/vmla/vmla_module.h
+++ b/iree/hal/vmla/op_module.h
@@ -16,8 +16,8 @@
// linked into the same library, because of this we can avoid the C shims and
// directly use C++ types.
-#ifndef IREE_HAL_VMLA_VMLA_MODULE_H_
-#define IREE_HAL_VMLA_VMLA_MODULE_H_
+#ifndef IREE_HAL_VMLA_OP_MODULE_H_
+#define IREE_HAL_VMLA_OP_MODULE_H_
#include <cstdint>
@@ -136,4 +136,4 @@
IREE_VM_DECLARE_TYPE_ADAPTERS(Buffer, iree::hal::vmla::Buffer);
IREE_VM_DECLARE_TYPE_ADAPTERS(Interface, iree::hal::vmla::Interface);
-#endif // IREE_HAL_VMLA_VMLA_MODULE_H_
+#endif // IREE_HAL_VMLA_OP_MODULE_H_
diff --git a/iree/hal/vmla/registration/BUILD b/iree/hal/vmla/registration/BUILD
new file mode 100644
index 0000000..c37ffce
--- /dev/null
+++ b/iree/hal/vmla/registration/BUILD
@@ -0,0 +1,51 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//iree:build_defs.oss.bzl", "iree_cmake_extra_content")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_cmake_extra_content(
+ content = """
+if(${IREE_HAL_DRIVER_VMLA})
+""",
+ inline = True,
+)
+
+cc_library(
+ name = "registration",
+ srcs = ["driver_module.cc"],
+ hdrs = ["driver_module.h"],
+ defines = [
+ "IREE_HAL_HAVE_VMLA_DRIVER_MODULE=1",
+ ],
+ deps = [
+ "//iree/base:flags",
+ "//iree/base:status",
+ "//iree/hal:api",
+ "//iree/hal:driver_registry",
+ "//iree/hal/vmla",
+ ],
+)
+
+iree_cmake_extra_content(
+ content = """
+endif()
+""",
+ inline = True,
+)
diff --git a/iree/hal/vmla/registration/CMakeLists.txt b/iree/hal/vmla/registration/CMakeLists.txt
new file mode 100644
index 0000000..04aee07
--- /dev/null
+++ b/iree/hal/vmla/registration/CMakeLists.txt
@@ -0,0 +1,37 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_add_all_subdirs()
+
+if(${IREE_HAL_DRIVER_VMLA})
+
+iree_cc_library(
+ NAME
+ registration
+ HDRS
+ "driver_module.h"
+ SRCS
+ "driver_module.cc"
+ DEPS
+ iree::base::flags
+ iree::base::status
+ iree::hal::api
+ iree::hal::driver_registry
+ iree::hal::vmla
+ DEFINES
+ "IREE_HAL_HAVE_VMLA_DRIVER_MODULE=1"
+ PUBLIC
+)
+
+endif()
diff --git a/iree/hal/vmla/vmla_driver_module.cc b/iree/hal/vmla/registration/driver_module.cc
similarity index 63%
rename from iree/hal/vmla/vmla_driver_module.cc
rename to iree/hal/vmla/registration/driver_module.cc
index b051db6..94877a1 100644
--- a/iree/hal/vmla/vmla_driver_module.cc
+++ b/iree/hal/vmla/registration/driver_module.cc
@@ -1,4 +1,4 @@
-// Copyright 2019 Google LLC
+// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,9 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include <memory>
+#include "iree/hal/vmla/registration/driver_module.h"
-#include "iree/base/init.h"
#include "iree/base/status.h"
#include "iree/hal/driver_registry.h"
#include "iree/hal/vmla/vmla_driver.h"
@@ -22,17 +21,17 @@
namespace iree {
namespace hal {
namespace vmla {
-namespace {
-StatusOr<ref_ptr<Driver>> CreateVMLADriver() { return VMLADriver::Create(); }
+static StatusOr<ref_ptr<Driver>> CreateVMLADriver() {
+ return VMLADriver::Create();
+}
-} // namespace
} // namespace vmla
} // namespace hal
} // namespace iree
-IREE_REGISTER_MODULE_INITIALIZER(iree_hal_vmla_driver, {
- IREE_QCHECK_OK(::iree::hal::DriverRegistry::shared_registry()->Register(
- "vmla", ::iree::hal::vmla::CreateVMLADriver));
-});
-IREE_REGISTER_MODULE_INITIALIZER_SEQUENCE(iree_hal, iree_hal_vmla_driver);
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_vmla_driver_module_register() {
+ return ::iree::hal::DriverRegistry::shared_registry()->Register(
+ "vmla", ::iree::hal::vmla::CreateVMLADriver);
+}
diff --git a/iree/hal/vmla/registration/driver_module.h b/iree/hal/vmla/registration/driver_module.h
new file mode 100644
index 0000000..ec7ff1a
--- /dev/null
+++ b/iree/hal/vmla/registration/driver_module.h
@@ -0,0 +1,33 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_VMLA_REGISTRATION_DRIVER_MODULE_H_
+#define IREE_HAL_VMLA_REGISTRATION_DRIVER_MODULE_H_
+
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// Registers the "vmla" driver factory with the shared HAL driver registry.
+// DEPRECATED; TODO(#3580): remove this driver w/ iree_hal_executable_library_t.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_vmla_driver_module_register(void);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_HAL_VMLA_REGISTRATION_DRIVER_MODULE_H_
diff --git a/iree/hal/vmla/vmla_cache.h b/iree/hal/vmla/vmla_cache.h
index 62650ab..560d4bd 100644
--- a/iree/hal/vmla/vmla_cache.h
+++ b/iree/hal/vmla/vmla_cache.h
@@ -18,8 +18,7 @@
#include "iree/hal/allocator.h"
#include "iree/hal/executable.h"
#include "iree/hal/executable_cache.h"
-#include "iree/vm/instance.h"
-#include "iree/vm/module.h"
+#include "iree/vm/api.h"
namespace iree {
namespace hal {
diff --git a/iree/hal/vmla/vmla_device.h b/iree/hal/vmla/vmla_device.h
index 78f23e6..3d3f3aa 100644
--- a/iree/hal/vmla/vmla_device.h
+++ b/iree/hal/vmla/vmla_device.h
@@ -17,8 +17,7 @@
#include "iree/base/memory.h"
#include "iree/hal/host/host_local_device.h"
-#include "iree/vm/instance.h"
-#include "iree/vm/module.h"
+#include "iree/vm/api.h"
namespace iree {
namespace hal {
diff --git a/iree/hal/vmla/vmla_driver.cc b/iree/hal/vmla/vmla_driver.cc
index c518efa..54696b1 100644
--- a/iree/hal/vmla/vmla_driver.cc
+++ b/iree/hal/vmla/vmla_driver.cc
@@ -19,9 +19,8 @@
#include "iree/base/tracing.h"
#include "iree/hal/device_info.h"
#include "iree/hal/host/serial/serial_scheduling_model.h"
+#include "iree/hal/vmla/op_module.h"
#include "iree/hal/vmla/vmla_device.h"
-#include "iree/hal/vmla/vmla_module.h"
-#include "iree/vm/module.h"
namespace iree {
namespace hal {
diff --git a/iree/hal/vmla/vmla_driver.h b/iree/hal/vmla/vmla_driver.h
index 3d8cb37..c701f4d 100644
--- a/iree/hal/vmla/vmla_driver.h
+++ b/iree/hal/vmla/vmla_driver.h
@@ -16,8 +16,7 @@
#define IREE_HAL_VMLA_VMLA_DRIVER_H_
#include "iree/hal/driver.h"
-#include "iree/vm/instance.h"
-#include "iree/vm/module.h"
+#include "iree/vm/api.h"
namespace iree {
namespace hal {
diff --git a/iree/hal/vmla/vmla_executable.cc b/iree/hal/vmla/vmla_executable.cc
index 4cc447a..da0cff2 100644
--- a/iree/hal/vmla/vmla_executable.cc
+++ b/iree/hal/vmla/vmla_executable.cc
@@ -17,12 +17,53 @@
#include "iree/base/status.h"
#include "iree/base/tracing.h"
#include "iree/hal/host/host_buffer.h"
-#include "iree/hal/vmla/vmla_module.h"
-#include "iree/schemas/vmla_executable_def_generated.h"
+#include "iree/hal/vmla/op_module.h"
#include "iree/vm/bytecode_module.h"
-#include "iree/vm/invocation.h"
-#include "iree/vm/list.h"
-#include "iree/vm/module.h"
+
+// flatcc schemas:
+#include "iree/base/flatcc.h"
+#include "iree/schemas/vmla_executable_def_reader.h"
+#include "iree/schemas/vmla_executable_def_verifier.h"
+
+// NOTE: starting to port this to C.
+
+// Verifies the structure of the flatbuffer so that we can avoid doing so during
+// runtime; returns IREE_STATUS_INVALID_ARGUMENT if any check fails. Some
+// conditions remain (such as omitted names on functions with internal linkage),
+// but nothing in the flatbuffer should need bounds checks after this succeeds.
+static iree_status_t iree_hal_vmla_executable_flatbuffer_verify(
+    iree_const_byte_span_t flatbuffer_data) {
+  if (!flatbuffer_data.data || flatbuffer_data.data_length < 16) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "flatbuffer data is not present or less than 16 bytes (%zu total)",
+        flatbuffer_data.data_length);
+  }
+
+  // Run flatcc generated verification. This ensures all pointers are in-bounds
+  // and that we can safely walk the file, but not that the actual contents of
+  // the flatbuffer meet our expectations.
+  int verify_ret = iree_VMLAExecutableDef_verify_as_root(
+      flatbuffer_data.data, flatbuffer_data.data_length);
+  if (verify_ret != flatcc_verify_ok) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "flatbuffer verification failed: %s",
+                            flatcc_verify_error_string(verify_ret));
+  }
+
+  iree_VMLAExecutableDef_table_t executable_def =
+      iree_VMLAExecutableDef_as_root(flatbuffer_data.data);
+  // The vector length is unsigned: test for zero (a `< 0` check never fires).
+  if (flatbuffers_uint8_vec_len(
+          iree_VMLAExecutableDef_bytecode_module_get(executable_def)) == 0) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "executable bytecode_module is missing/empty");
+  }
+
+  // NOTE: we don't check the actual bytecode module contents here; it's opaque
+  // to us and passed on to the VM.
+  return iree_ok_status();
+}
namespace iree {
namespace hal {
@@ -61,34 +102,28 @@
iree_vm_module_t* vmla_module) {
IREE_TRACE_SCOPE0("VMLAExecutable::Initialize");
- if (spec_.executable_data.size() < 16) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Flatbuffer data is not present or less than 16 bytes";
- } else if (!iree::VMLAExecutableDefBufferHasIdentifier(
- spec_.executable_data.data())) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Flatbuffer data does not have bytecode module identifier";
- }
-
- const auto* executable_def = ::flatbuffers::GetRoot<iree::VMLAExecutableDef>(
- spec_.executable_data.data());
- if (!executable_def || !executable_def->bytecode_module()) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Failed getting root from flatbuffer data";
- }
+ // Verify and fetch the executable flatbuffer wrapper.
+ iree_const_byte_span_t executable_data = iree_make_const_byte_span(
+ spec_.executable_data.data(), spec_.executable_data.size());
+ IREE_RETURN_IF_ERROR(
+ iree_hal_vmla_executable_flatbuffer_verify(executable_data));
+ iree_VMLAExecutableDef_table_t executable_def =
+ iree_VMLAExecutableDef_as_root(executable_data.data);
// Load bytecode module from the executable spec.
+ flatbuffers_uint8_vec_t bytecode_module_vec =
+ iree_VMLAExecutableDef_bytecode_module_get(executable_def);
+ iree_const_byte_span_t bytecode_module_data = iree_make_const_byte_span(
+ bytecode_module_vec, flatbuffers_uint8_vec_len(bytecode_module_vec));
iree_vm_module_t* bytecode_module = nullptr;
IREE_RETURN_IF_ERROR(iree_vm_bytecode_module_create(
- iree_const_byte_span_t{reinterpret_cast<const uint8_t*>(
- executable_def->bytecode_module()->data()),
- executable_def->bytecode_module()->size()},
- iree_allocator_null(), iree_allocator_system(), &bytecode_module))
+ bytecode_module_data, iree_allocator_null(), iree_allocator_system(),
+ &bytecode_module))
<< "Failed to load executable bytecode module";
entry_functions_.resize(
iree_vm_module_signature(bytecode_module).export_function_count);
- for (int i = 0; i < entry_functions_.size(); ++i) {
+ for (size_t i = 0; i < entry_functions_.size(); ++i) {
IREE_RETURN_IF_ERROR(iree_vm_module_lookup_function_by_ordinal(
bytecode_module, IREE_VM_FUNCTION_LINKAGE_EXPORT, i,
&entry_functions_[i], nullptr));
@@ -134,7 +169,7 @@
auto* interface = &dispatch_state->interface;
IREE_RETURN_IF_ERROR(interface->SetConstants(params.push_constants->values));
- for (int set_ordinal = 0; set_ordinal < params.set_bindings.size();
+ for (size_t set_ordinal = 0; set_ordinal < params.set_bindings.size();
++set_ordinal) {
for (const auto& binding : params.set_bindings[set_ordinal]) {
// TODO(benvanik): plumb binding directly into VMLA to avoid this.
@@ -167,7 +202,7 @@
/*element_type=*/nullptr,
/*interface*/ 1 + /*workgroup_xyz[3]*/ 3, &input_list));
iree_vm_list_push_ref_retain(input_list, &dispatch_state->interface_ref);
- for (int i = 0; i < workgroup_xyz.size(); ++i) {
+ for (size_t i = 0; i < workgroup_xyz.size(); ++i) {
iree_vm_value_t value = iree_vm_value_make_i32(workgroup_xyz[i]);
iree_vm_list_push_value(input_list, &value);
}
diff --git a/iree/hal/vmla/vmla_executable.h b/iree/hal/vmla/vmla_executable.h
index 0e79b36..7eb3c27 100644
--- a/iree/hal/vmla/vmla_executable.h
+++ b/iree/hal/vmla/vmla_executable.h
@@ -22,9 +22,7 @@
#include "iree/base/status.h"
#include "iree/hal/executable_spec.h"
#include "iree/hal/host/host_executable.h"
-#include "iree/vm/context.h"
-#include "iree/vm/instance.h"
-#include "iree/vm/module.h"
+#include "iree/vm/api.h"
namespace iree {
namespace hal {
diff --git a/iree/hal/vulkan/BUILD b/iree/hal/vulkan/BUILD
index a00ccef..7c44f47 100644
--- a/iree/hal/vulkan/BUILD
+++ b/iree/hal/vulkan/BUILD
@@ -14,7 +14,7 @@
# HAL implementation using Vulkan and (likely) SPIR-V executables.
-load("//iree:build_defs.oss.bzl", "PLATFORM_VULKAN_TEST_DEPS")
+load("//iree:build_defs.oss.bzl", "iree_cmake_extra_content")
package(
default_visibility = ["//visibility:public"],
@@ -22,20 +22,12 @@
licenses = ["notice"], # Apache 2.0
)
-# --define=IREE_VK=native to use the native Vulkan drivers (and real hardware).
-config_setting(
- name = "native_vk",
- values = {
- "define": "IREE_VK=native",
- },
-)
-
-# --define=IREE_VK=swiftshader to use SwiftShader.
-config_setting(
- name = "swiftshader_vk",
- values = {
- "define": "IREE_VK=swiftshader",
- },
+iree_cmake_extra_content(
+ content = """
+if(NOT ${IREE_HAL_DRIVER_VULKAN})
+ return()
+endif()
+""",
)
cc_library(
@@ -44,360 +36,66 @@
hdrs = ["api.h"],
visibility = ["//visibility:public"],
deps = [
- ":dynamic_symbols",
- ":extensibility_util",
- ":vulkan_device",
- ":vulkan_driver",
+ ":utils",
+ ":vulkan",
"//iree/base:api",
"//iree/base:tracing",
"//iree/hal:api",
- "@iree_vulkan_headers//:vulkan_headers_no_prototypes",
],
)
cc_library(
- name = "debug_reporter",
- srcs = ["debug_reporter.cc"],
- hdrs = ["debug_reporter.h"],
- deps = [
- ":dynamic_symbols",
- ":status_util",
- "//iree/base:status",
- "//iree/base:tracing",
- "@iree_vulkan_headers//:vulkan_headers_no_prototypes",
+ name = "utils",
+ srcs = [
+ "debug_reporter.cc",
+ "dynamic_symbols.cc",
+ "extensibility_util.cc",
+ "renderdoc_capture_manager.cc",
+ "status_util.cc",
+ "timepoint_util.cc",
],
-)
-
-cc_library(
- name = "descriptor_pool_cache",
- srcs = ["descriptor_pool_cache.cc"],
- hdrs = ["descriptor_pool_cache.h"],
- deps = [
- ":dynamic_symbols",
- ":handle_util",
- ":status_util",
- "//iree/base:ref_ptr",
- "//iree/base:tracing",
- "@com_google_absl//absl/container:inlined_vector",
- ],
-)
-
-cc_library(
- name = "descriptor_set_arena",
- srcs = ["descriptor_set_arena.cc"],
- hdrs = ["descriptor_set_arena.h"],
- deps = [
- ":descriptor_pool_cache",
- ":pipeline_executable",
- ":status_util",
- ":vma_allocator",
- "//iree/base:alignment",
- "//iree/base:arena",
- "//iree/base:math",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:command_buffer",
- ],
-)
-
-cc_library(
- name = "direct_command_buffer",
- srcs = ["direct_command_buffer.cc"],
- hdrs = ["direct_command_buffer.h"],
- deps = [
- ":descriptor_pool_cache",
- ":descriptor_set_arena",
- ":dynamic_symbols",
- ":handle_util",
- ":native_descriptor_set",
- ":native_event",
- ":pipeline_executable",
- ":pipeline_executable_layout",
- ":status_util",
- ":vma_allocator",
- "//iree/base:math",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:command_buffer",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/synchronization",
- "@iree_vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
-
-cc_library(
- name = "direct_command_queue",
- srcs = ["direct_command_queue.cc"],
- hdrs = ["direct_command_queue.h"],
- deps = [
- ":direct_command_buffer",
- ":dynamic_symbols",
- ":handle_util",
- ":native_timeline_semaphore",
- ":status_util",
- "//iree/base:api",
- "//iree/base:arena",
- "//iree/base:memory",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:command_queue",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/synchronization",
- "@iree_vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
-
-cc_library(
- name = "dynamic_symbols",
- srcs = ["dynamic_symbols.cc"],
hdrs = [
+ "debug_reporter.h",
"dynamic_symbol_tables.h",
"dynamic_symbols.h",
+ "extensibility_util.h",
+ "handle_util.h",
+ "renderdoc_capture_manager.h",
+ "status_util.h",
+ "timepoint_util.h",
+ "vulkan_headers.h",
],
deps = [
+ "//iree/base:core_headers",
"//iree/base:dynamic_library",
- "//iree/base:ref_ptr",
- "//iree/base:status",
- "//iree/base:target_platform",
- "//iree/base:tracing",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/memory",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:span",
- "@iree_vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
-
-cc_test(
- name = "dynamic_symbols_test",
- srcs = ["dynamic_symbols_test.cc"],
- tags = ["driver=vulkan"],
- deps = [
- ":status_util",
- ":dynamic_symbols",
- "//iree/testing:gtest",
- "//iree/testing:gtest_main",
- ] + PLATFORM_VULKAN_TEST_DEPS,
-)
-
-cc_library(
- name = "emulated_timeline_semaphore",
- srcs = ["emulated_timeline_semaphore.cc"],
- hdrs = ["emulated_timeline_semaphore.h"],
- deps = [
- ":dynamic_symbols",
- ":handle_util",
- ":status_util",
- ":timepoint_util",
- "//iree/base:api",
"//iree/base:intrusive_list",
+ "//iree/base:logging",
"//iree/base:ref_ptr",
"//iree/base:status",
"//iree/base:time",
"//iree/base:tracing",
- "//iree/hal:semaphore",
+ "//iree/hal",
"@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
- "@com_google_absl//absl/utility",
- "@iree_vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
-
-cc_library(
- name = "extensibility_util",
- srcs = ["extensibility_util.cc"],
- hdrs = ["extensibility_util.h"],
- deps = [
- ":dynamic_symbols",
- ":status_util",
- "//iree/base:memory",
- "//iree/base:status",
- "//iree/base:tracing",
"@com_google_absl//absl/types:span",
- "@iree_vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
-
-cc_library(
- name = "handle_util",
- hdrs = ["handle_util.h"],
- deps = [
- ":dynamic_symbols",
- ":extensibility_util",
- "//iree/base:ref_ptr",
- "@com_google_absl//absl/synchronization",
- "@com_google_absl//absl/utility",
- "@iree_vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
-
-cc_library(
- name = "native_event",
- srcs = ["native_event.cc"],
- hdrs = ["native_event.h"],
- deps = [
- ":handle_util",
- "//iree/hal:event",
- "@iree_vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
-
-cc_library(
- name = "native_descriptor_set",
- srcs = ["native_descriptor_set.cc"],
- hdrs = ["native_descriptor_set.h"],
- deps = [
- ":handle_util",
- "//iree/hal:descriptor_set",
- "@iree_vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
-
-cc_library(
- name = "native_timeline_semaphore",
- srcs = ["native_timeline_semaphore.cc"],
- hdrs = ["native_timeline_semaphore.h"],
- deps = [
- ":handle_util",
- ":status_util",
- "//iree/base:tracing",
- "//iree/hal:semaphore",
- "@com_google_absl//absl/synchronization",
- "@iree_vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
-
-cc_library(
- name = "pipeline_cache",
- srcs = ["pipeline_cache.cc"],
- hdrs = ["pipeline_cache.h"],
- deps = [
- ":handle_util",
- ":pipeline_executable",
- ":status_util",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:executable",
- "//iree/hal:executable_cache",
- "//iree/hal:executable_format",
- "//iree/schemas:spirv_executable_def_cc_fbs",
- "@com_github_google_flatbuffers//:flatbuffers",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/synchronization",
- "@iree_vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
-
-cc_library(
- name = "pipeline_executable",
- srcs = ["pipeline_executable.cc"],
- hdrs = ["pipeline_executable.h"],
- deps = [
- ":handle_util",
- ":native_descriptor_set",
- ":pipeline_executable_layout",
- ":status_util",
- "//iree/base:memory",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:executable",
- "//iree/hal:executable_cache",
- "//iree/hal:executable_layout",
- "//iree/hal:executable_spec",
- "//iree/schemas:spirv_executable_def_cc_fbs",
- "@com_google_absl//absl/container:inlined_vector",
- "@iree_vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
-
-cc_library(
- name = "pipeline_executable_layout",
- srcs = ["pipeline_executable_layout.cc"],
- hdrs = ["pipeline_executable_layout.h"],
- deps = [
- ":handle_util",
- "//iree/hal:descriptor_set_layout",
- "//iree/hal:executable_layout",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/types:span",
- "@iree_vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
-
-cc_library(
- name = "renderdoc_capture_manager",
- srcs = ["renderdoc_capture_manager.cc"],
- hdrs = ["renderdoc_capture_manager.h"],
- deps = [
- "//iree/base:dynamic_library",
- "//iree/base:logging",
- "//iree/base:status",
- "//iree/base:target_platform",
- "//iree/base:tracing",
- "//iree/hal:debug_capture_manager",
- "@com_google_absl//absl/types:span",
+ "@iree_vulkan_headers//:vulkan_headers",
"@renderdoc_api//:renderdoc_app",
],
)
-cc_library(
- name = "serializing_command_queue",
- srcs = ["serializing_command_queue.cc"],
- hdrs = ["serializing_command_queue.h"],
- deps = [
- ":direct_command_buffer",
- ":dynamic_symbols",
- ":emulated_timeline_semaphore",
- ":handle_util",
- ":status_util",
- ":timepoint_util",
- "//iree/base:api",
- "//iree/base:intrusive_list",
- "//iree/base:memory",
- "//iree/base:ref_ptr",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:command_buffer",
- "//iree/hal:command_queue",
- "//iree/hal:semaphore",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/synchronization",
- "@com_google_absl//absl/types:span",
- ],
-)
-
-cc_library(
- name = "status_util",
- srcs = ["status_util.cc"],
- hdrs = ["status_util.h"],
- deps = [
- "//iree/base:status",
- "@iree_vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
-
-cc_library(
- name = "timepoint_util",
- srcs = ["timepoint_util.cc"],
- hdrs = ["timepoint_util.h"],
- deps = [
- ":dynamic_symbols",
- ":handle_util",
- ":status_util",
- "//iree/base:intrusive_list",
- "//iree/base:ref_ptr",
- "//iree/base:status",
- "//iree/base:time",
- "//iree/base:tracing",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/synchronization",
- "@com_google_absl//absl/utility",
- "@iree_vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
+# TODO(suderman): Re-enable once testing solution is found.
+# cc_test(
+# name = "dynamic_symbols_test",
+# srcs = ["dynamic_symbols_test.cc"],
+# tags = ["driver=vulkan"],
+# deps = [
+# ":utils",
+# "//iree/testing:gtest",
+# "//iree/testing:gtest_main",
+# ],
+# )
cc_library(
name = "vma_allocator",
@@ -411,106 +109,74 @@
"vma_allocator.h",
"vma_buffer.h",
],
- copts = [
- # Only needed in the implementation cc and not by external users.
- "-DVMA_STATIC_VULKAN_FUNCTIONS=0",
- ] + select({
- "//iree:iree_is_msvc": [],
- "//conditions:default": [
- "-Wno-thread-safety-attributes", # External code.
- ],
- }),
deps = [
- ":dynamic_symbols",
- ":handle_util",
- ":status_util",
+ ":utils",
+ "//iree/base:core_headers",
"//iree/base:logging",
"//iree/base:status",
"//iree/base:tracing",
- "//iree/hal:allocator",
- "//iree/hal:buffer",
+ "//iree/hal",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/synchronization",
- "@iree_vulkan_headers//:vulkan_headers_no_prototypes",
"@vulkan_memory_allocator//:impl_header_only",
],
)
cc_library(
- name = "vulkan_device",
- srcs = ["vulkan_device.cc"],
- hdrs = ["vulkan_device.h"],
+ name = "vulkan",
+ srcs = [
+ "descriptor_pool_cache.cc",
+ "descriptor_set_arena.cc",
+ "direct_command_buffer.cc",
+ "direct_command_queue.cc",
+ "emulated_timeline_semaphore.cc",
+ "native_descriptor_set.cc",
+ "native_event.cc",
+ "native_timeline_semaphore.cc",
+ "pipeline_cache.cc",
+ "pipeline_executable.cc",
+ "pipeline_executable_layout.cc",
+ "serializing_command_queue.cc",
+ "vulkan_device.cc",
+ "vulkan_driver.cc",
+ ],
+ hdrs = [
+ "descriptor_pool_cache.h",
+ "descriptor_set_arena.h",
+ "direct_command_buffer.h",
+ "direct_command_queue.h",
+ "emulated_timeline_semaphore.h",
+ "native_descriptor_set.h",
+ "native_event.h",
+ "native_timeline_semaphore.h",
+ "pipeline_cache.h",
+ "pipeline_executable.h",
+ "pipeline_executable_layout.h",
+ "serializing_command_queue.h",
+ "vulkan_device.h",
+ "vulkan_driver.h",
+ ],
deps = [
- ":descriptor_pool_cache",
- ":direct_command_buffer",
- ":direct_command_queue",
- ":dynamic_symbols",
- ":emulated_timeline_semaphore",
- ":extensibility_util",
- ":handle_util",
- ":native_descriptor_set",
- ":native_event",
- ":native_timeline_semaphore",
- ":pipeline_cache",
- ":pipeline_executable_layout",
- ":serializing_command_queue",
- ":status_util",
+ ":utils",
":vma_allocator",
- "//iree/base:math",
- "//iree/base:memory",
+ "//iree/base:api",
+ "//iree/base:arena",
+ "//iree/base:core_headers",
+ "//iree/base:flatcc",
+ "//iree/base:intrusive_list",
+ "//iree/base:ref_ptr",
"//iree/base:status",
"//iree/base:time",
"//iree/base:tracing",
- "//iree/hal:allocator",
+ "//iree/hal",
"//iree/hal:command_buffer_validation",
- "//iree/hal:command_queue",
- "//iree/hal:debug_capture_manager",
- "//iree/hal:device",
- "//iree/hal:driver",
- "//iree/hal:semaphore",
+ "//iree/schemas:spirv_executable_def_c_fbs",
+ "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
- "@iree_vulkan_headers//:vulkan_headers_no_prototypes",
],
)
-
-cc_library(
- name = "vulkan_driver",
- srcs = ["vulkan_driver.cc"],
- hdrs = ["vulkan_driver.h"],
- deps = [
- ":debug_reporter",
- ":dynamic_symbols",
- ":extensibility_util",
- ":renderdoc_capture_manager",
- ":status_util",
- ":vulkan_device",
- "//iree/base:memory",
- "//iree/base:status",
- "//iree/base:target_platform",
- "//iree/base:tracing",
- "//iree/hal:device_info",
- "//iree/hal:driver",
- "@com_google_absl//absl/container:inlined_vector",
- "@iree_vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
-
-cc_library(
- name = "vulkan_driver_module",
- srcs = ["vulkan_driver_module.cc"],
- deps = [
- ":dynamic_symbols",
- ":vulkan_driver",
- "//iree/base:init",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:driver_registry",
- "@com_google_absl//absl/flags:flag",
- ],
- alwayslink = 1,
-)
diff --git a/iree/hal/vulkan/CMakeLists.txt b/iree/hal/vulkan/CMakeLists.txt
index 19e3685..3bdeeac 100644
--- a/iree/hal/vulkan/CMakeLists.txt
+++ b/iree/hal/vulkan/CMakeLists.txt
@@ -16,9 +16,7 @@
return()
endif()
-set(VMA_SRC_ROOT
- "${IREE_ROOT_DIR}/third_party/vulkan_memory_allocator/src/"
-)
+iree_add_all_subdirs()
iree_cc_library(
NAME
@@ -27,159 +25,52 @@
"api.h"
SRCS
"api.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
DEPS
+ ::utils
+ ::vulkan
iree::base::api
iree::base::tracing
iree::hal::api
- iree::hal::vulkan::dynamic_symbols
- iree::hal::vulkan::extensibility_util
- iree::hal::vulkan::vulkan_device
- iree::hal::vulkan::vulkan_driver
- Vulkan::Headers
PUBLIC
)
iree_cc_library(
NAME
- debug_reporter
+ utils
HDRS
"debug_reporter.h"
- SRCS
- "debug_reporter.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
- DEPS
- iree::base::status
- iree::base::tracing
- iree::hal::vulkan::dynamic_symbols
- iree::hal::vulkan::status_util
- Vulkan::Headers
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- descriptor_pool_cache
- HDRS
- "descriptor_pool_cache.h"
- SRCS
- "descriptor_pool_cache.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
- DEPS
- absl::inlined_vector
- iree::base::ref_ptr
- iree::base::tracing
- iree::hal::vulkan::dynamic_symbols
- iree::hal::vulkan::handle_util
- iree::hal::vulkan::status_util
- Vulkan::Headers
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- descriptor_set_arena
- HDRS
- "descriptor_set_arena.h"
- SRCS
- "descriptor_set_arena.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
- DEPS
- iree::base::alignment
- iree::base::arena
- iree::base::math
- iree::base::status
- iree::base::tracing
- iree::hal::command_buffer
- iree::hal::vulkan::descriptor_pool_cache
- iree::hal::vulkan::pipeline_executable
- iree::hal::vulkan::status_util
- iree::hal::vulkan::vma_allocator
- Vulkan::Headers
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- direct_command_buffer
- HDRS
- "direct_command_buffer.h"
- SRCS
- "direct_command_buffer.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
- DEPS
- absl::core_headers
- absl::inlined_vector
- absl::synchronization
- iree::base::math
- iree::base::status
- iree::base::tracing
- iree::hal::command_buffer
- iree::hal::vulkan::descriptor_pool_cache
- iree::hal::vulkan::descriptor_set_arena
- iree::hal::vulkan::dynamic_symbols
- iree::hal::vulkan::handle_util
- iree::hal::vulkan::native_event
- iree::hal::vulkan::pipeline_executable
- iree::hal::vulkan::status_util
- iree::hal::vulkan::vma_allocator
- Vulkan::Headers
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- direct_command_queue
- HDRS
- "direct_command_queue.h"
- SRCS
- "direct_command_queue.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
- DEPS
- absl::core_headers
- absl::synchronization
- iree::base::arena
- iree::base::memory
- iree::base::status
- iree::base::time
- iree::base::tracing
- iree::hal::command_queue
- iree::hal::vulkan::direct_command_buffer
- iree::hal::vulkan::dynamic_symbols
- iree::hal::vulkan::handle_util
- iree::hal::vulkan::native_timeline_semaphore
- iree::hal::vulkan::status_util
- Vulkan::Headers
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- dynamic_symbols
- HDRS
"dynamic_symbol_tables.h"
"dynamic_symbols.h"
+ "extensibility_util.h"
+ "handle_util.h"
+ "renderdoc_capture_manager.h"
+ "status_util.h"
+ "timepoint_util.h"
+ "vulkan_headers.h"
SRCS
+ "debug_reporter.cc"
"dynamic_symbols.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
+ "extensibility_util.cc"
+ "renderdoc_capture_manager.cc"
+ "status_util.cc"
+ "timepoint_util.cc"
DEPS
+ Vulkan::Headers
absl::core_headers
absl::memory
absl::span
absl::strings
+ absl::synchronization
+ iree::base::core_headers
iree::base::dynamic_library
+ iree::base::intrusive_list
+ iree::base::logging
iree::base::ref_ptr
iree::base::status
- iree::base::target_platform
+ iree::base::time
iree::base::tracing
- Vulkan::Headers
+ iree::hal
+ renderdoc_api::renderdoc_app
PUBLIC
)
@@ -188,278 +79,12 @@
dynamic_symbols_test
SRCS
"dynamic_symbols_test.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
DEPS
- iree::hal::vulkan::status_util
- iree::hal::vulkan::dynamic_symbols
+ ::utils
iree::testing::gtest
iree::testing::gtest_main
LABELS
- "nokokoro"
-)
-
-iree_cc_library(
- NAME
- emulated_timeline_semaphore
- HDRS
- "emulated_timeline_semaphore.h"
- SRCS
- "emulated_timeline_semaphore.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
- DEPS
- ::handle_util
- ::status_util
- ::timepoint_util
- absl::inlined_vector
- absl::synchronization
- iree::base::intrusive_list
- iree::base::status
- iree::base::time
- iree::base::tracing
- iree::hal::semaphore
- Vulkan::Headers
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- extensibility_util
- HDRS
- "extensibility_util.h"
- SRCS
- "extensibility_util.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
- DEPS
- absl::span
- iree::base::memory
- iree::base::status
- iree::base::tracing
- iree::hal::vulkan::dynamic_symbols
- iree::hal::vulkan::status_util
- Vulkan::Headers
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- handle_util
- HDRS
- "handle_util.h"
- COPTS
- "-DVK_NO_PROTOTYPES"
- DEPS
- absl::inlined_vector
- absl::synchronization
- absl::utility
- iree::base::ref_ptr
- iree::hal::command_queue
- iree::hal::vulkan::dynamic_symbols
- iree::hal::vulkan::extensibility_util
- Vulkan::Headers
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- native_descriptor_set
- HDRS
- native_descriptor_set.h
- SRCS
- native_descriptor_set.cc
- COPTS
- "-DVK_NO_PROTOTYPES"
- DEPS
- ::handle_util
- iree::hal::descriptor_set
- Vulkan::Headers
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- native_event
- HDRS
- "native_event.h"
- SRCS
- "native_event.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
- DEPS
- ::handle_util
- iree::hal::event
- Vulkan::Headers
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- native_timeline_semaphore
- HDRS
- "native_timeline_semaphore.h"
- SRCS
- "native_timeline_semaphore.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
- DEPS
- ::handle_util
- ::status_util
- iree::base::tracing
- iree::hal::semaphore
- Vulkan::Headers
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- pipeline_cache
- HDRS
- "pipeline_cache.h"
- SRCS
- "pipeline_cache.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
- DEPS
- absl::inlined_vector
- absl::synchronization
- flatbuffers
- iree::base::status
- iree::base::tracing
- iree::hal::executable
- iree::hal::executable_cache
- iree::hal::executable_format
- iree::hal::vulkan::handle_util
- iree::hal::vulkan::pipeline_executable
- iree::hal::vulkan::status_util
- iree::schemas::spirv_executable_def_cc_fbs
- Vulkan::Headers
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- pipeline_executable
- HDRS
- "pipeline_executable.h"
- SRCS
- "pipeline_executable.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
- DEPS
- ::handle_util
- ::native_descriptor_set
- ::pipeline_executable_layout
- ::status_util
- absl::inlined_vector
- iree::base::memory
- iree::base::status
- iree::base::tracing
- iree::hal::executable
- iree::hal::executable_cache
- iree::hal::executable_layout
- iree::hal::executable_spec
- iree::schemas::spirv_executable_def_cc_fbs
- Vulkan::Headers
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- pipeline_executable_layout
- HDRS
- "pipeline_executable_layout.h"
- SRCS
- "pipeline_executable_layout.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
- DEPS
- absl::inlined_vector
- absl::span
- iree::hal::descriptor_set_layout
- iree::hal::executable_layout
- iree::hal::vulkan::handle_util
- Vulkan::Headers
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- renderdoc_capture_manager
- HDRS
- "renderdoc_capture_manager.h"
- SRCS
- "renderdoc_capture_manager.cc"
- DEPS
- iree::base::dynamic_library
- iree::base::logging
- iree::base::status
- iree::base::target_platform
- iree::base::tracing
- iree::hal::debug_capture_manager
- renderdoc_api::renderdoc_app
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- serializing_command_queue
- HDRS
- "serializing_command_queue.h"
- SRCS
- "serializing_command_queue.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
- DEPS
- ::direct_command_buffer
- ::emulated_timeline_semaphore
- ::handle_util
- ::status_util
- ::timepoint_util
- absl::inlined_vector
- absl::synchronization
- iree::base::status
- iree::base::tracing
- iree::hal::command_queue
- Vulkan::Headers
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- status_util
- HDRS
- "status_util.h"
- SRCS
- "status_util.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
- DEPS
- iree::base::status
- Vulkan::Headers
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- timepoint_util
- HDRS
- "timepoint_util.h"
- SRCS
- "timepoint_util.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
- DEPS
- ::handle_util
- absl::synchronization
- iree::base::intrusive_list
- iree::base::ref_ptr
- iree::base::status
- iree::base::time
- iree::base::tracing
- Vulkan::Headers
- PUBLIC
+ "driver=vulkan"
)
iree_cc_library(
@@ -473,120 +98,73 @@
"internal_vk_mem_alloc.h"
"vma_allocator.cc"
"vma_buffer.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
- # Only needed in the implementation cc and not by external users.
- "-DVMA_STATIC_VULKAN_FUNCTIONS=0"
- # Silence some warnings.
- "-Wno-nullability-completeness"
- # Note: IREE_DEFAULT_COPTS sets -Wthread-safety-analysis, so we can't
- # disable here without switching off of iree_cc_library or refactoring
- # "-Wno-thread-safety-analysis"
- INCLUDES
- ${VMA_SRC_ROOT}
DEPS
+ ::utils
absl::flat_hash_map
absl::memory
absl::synchronization
+ iree::base::core_headers
iree::base::logging
iree::base::status
iree::base::tracing
- iree::hal::allocator
- iree::hal::buffer
- iree::hal::vulkan::dynamic_symbols
- iree::hal::vulkan::handle_util
- iree::hal::vulkan::status_util
- Vulkan::Headers
+ iree::hal
+ vulkan_memory_allocator
PUBLIC
)
iree_cc_library(
NAME
- vulkan_device
+ vulkan
HDRS
+ "descriptor_pool_cache.h"
+ "descriptor_set_arena.h"
+ "direct_command_buffer.h"
+ "direct_command_queue.h"
+ "emulated_timeline_semaphore.h"
+ "native_descriptor_set.h"
+ "native_event.h"
+ "native_timeline_semaphore.h"
+ "pipeline_cache.h"
+ "pipeline_executable.h"
+ "pipeline_executable_layout.h"
+ "serializing_command_queue.h"
"vulkan_device.h"
+ "vulkan_driver.h"
SRCS
+ "descriptor_pool_cache.cc"
+ "descriptor_set_arena.cc"
+ "direct_command_buffer.cc"
+ "direct_command_queue.cc"
+ "emulated_timeline_semaphore.cc"
+ "native_descriptor_set.cc"
+ "native_event.cc"
+ "native_timeline_semaphore.cc"
+ "pipeline_cache.cc"
+ "pipeline_executable.cc"
+ "pipeline_executable_layout.cc"
+ "serializing_command_queue.cc"
"vulkan_device.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
+ "vulkan_driver.cc"
DEPS
- ::descriptor_pool_cache
- ::direct_command_buffer
- ::direct_command_queue
- ::dynamic_symbols
- ::emulated_timeline_semaphore
- ::extensibility_util
- ::handle_util
- ::native_descriptor_set
- ::native_timeline_semaphore
- ::pipeline_cache
- ::pipeline_executable_layout
- ::serializing_command_queue
- ::status_util
+ ::utils
::vma_allocator
+ absl::core_headers
absl::inlined_vector
absl::memory
+ absl::span
absl::strings
absl::synchronization
- absl::span
- iree::base::math
- iree::base::memory
+ iree::base::api
+ iree::base::arena
+ iree::base::core_headers
+ iree::base::flatcc
+ iree::base::intrusive_list
+ iree::base::ref_ptr
iree::base::status
iree::base::time
iree::base::tracing
- iree::hal::allocator
+ iree::hal
iree::hal::command_buffer_validation
- iree::hal::command_queue
- iree::hal::debug_capture_manager
- iree::hal::device
- iree::hal::driver
- iree::hal::semaphore
- Vulkan::Headers
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- vulkan_driver
- HDRS
- "vulkan_driver.h"
- SRCS
- "vulkan_driver.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
- DEPS
- absl::inlined_vector
- iree::base::memory
- iree::base::status
- iree::base::target_platform
- iree::base::tracing
- iree::hal::device_info
- iree::hal::driver
- iree::hal::vulkan::debug_reporter
- iree::hal::vulkan::dynamic_symbols
- iree::hal::vulkan::extensibility_util
- iree::hal::vulkan::renderdoc_capture_manager
- iree::hal::vulkan::status_util
- iree::hal::vulkan::vulkan_device
- Vulkan::Headers
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- vulkan_driver_module
- SRCS
- "vulkan_driver_module.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
- DEPS
- absl::flags
- iree::base::init
- iree::base::status
- iree::base::tracing
- iree::hal::driver_registry
- iree::hal::vulkan::dynamic_symbols
- iree::hal::vulkan::vulkan_driver
- ALWAYSLINK
+ iree::schemas::spirv_executable_def_c_fbs
PUBLIC
)
diff --git a/iree/hal/vulkan/api.h b/iree/hal/vulkan/api.h
index c3abe0c..6ad56a4 100644
--- a/iree/hal/vulkan/api.h
+++ b/iree/hal/vulkan/api.h
@@ -17,7 +17,7 @@
#ifndef IREE_HAL_VULKAN_API_H_
#define IREE_HAL_VULKAN_API_H_
-#include <vulkan/vulkan.h>
+#include "iree/hal/vulkan/vulkan_headers.h"
#include "iree/base/api.h"
#include "iree/hal/api.h"
diff --git a/iree/hal/vulkan/debug_reporter.h b/iree/hal/vulkan/debug_reporter.h
index ffc594b..edcdbb3 100644
--- a/iree/hal/vulkan/debug_reporter.h
+++ b/iree/hal/vulkan/debug_reporter.h
@@ -15,7 +15,7 @@
#ifndef IREE_HAL_VULKAN_DEBUG_REPORTER_H_
#define IREE_HAL_VULKAN_DEBUG_REPORTER_H_
-#include <vulkan/vulkan.h>
+#include "iree/hal/vulkan/vulkan_headers.h"
#include "iree/base/status.h"
#include "iree/hal/vulkan/dynamic_symbols.h"
diff --git a/iree/hal/vulkan/descriptor_pool_cache.cc b/iree/hal/vulkan/descriptor_pool_cache.cc
index 32723f5..6feea16 100644
--- a/iree/hal/vulkan/descriptor_pool_cache.cc
+++ b/iree/hal/vulkan/descriptor_pool_cache.cc
@@ -65,7 +65,7 @@
std::array<VkDescriptorPoolSize, 1> pool_sizes;
pool_sizes[0].type = descriptor_type;
pool_sizes[0].descriptorCount = max_descriptor_count;
- create_info.poolSizeCount = pool_sizes.size();
+ create_info.poolSizeCount = static_cast<uint32_t>(pool_sizes.size());
create_info.pPoolSizes = pool_sizes.data();
DescriptorPool descriptor_pool;
diff --git a/iree/hal/vulkan/descriptor_set_arena.cc b/iree/hal/vulkan/descriptor_set_arena.cc
index 97b8b86..238384f 100644
--- a/iree/hal/vulkan/descriptor_set_arena.cc
+++ b/iree/hal/vulkan/descriptor_set_arena.cc
@@ -131,7 +131,7 @@
// Pick a bucket based on the number of descriptors required.
// NOTE: right now we are 1:1 with bindings.
- int required_descriptor_count = bindings.size() * 1;
+ int required_descriptor_count = static_cast<int>(bindings.size() * 1);
int max_descriptor_count =
std::max(8, RoundUpToNearestPow2(required_descriptor_count));
int bucket = TrailingZeros(max_descriptor_count >> 3);
@@ -194,7 +194,8 @@
// descriptor sets we will need and what buffers they will point to (without
// doing just as much work as actually recording the buffer to try to find
// out).
- syms().vkUpdateDescriptorSets(*logical_device_, write_infos.size(),
+ syms().vkUpdateDescriptorSets(*logical_device_,
+ static_cast<uint32_t>(write_infos.size()),
write_infos.data(), 0, nullptr);
// Bind the descriptor set.
@@ -219,7 +220,8 @@
// command buffer and prevent the need for our own pooling mechanisms.
syms().vkCmdPushDescriptorSetKHR(
command_buffer, VK_PIPELINE_BIND_POINT_COMPUTE,
- executable_layout->handle(), set, write_infos.size(), write_infos.data());
+ executable_layout->handle(), set,
+ static_cast<uint32_t>(write_infos.size()), write_infos.data());
return OkStatus();
}
diff --git a/iree/hal/vulkan/direct_command_buffer.cc b/iree/hal/vulkan/direct_command_buffer.cc
index 0222c73..b5a9aab 100644
--- a/iree/hal/vulkan/direct_command_buffer.cc
+++ b/iree/hal/vulkan/direct_command_buffer.cc
@@ -230,8 +230,10 @@
syms()->vkCmdPipelineBarrier(
command_buffer_, ConvertPipelineStageFlags(source_stage_mask),
ConvertPipelineStageFlags(target_stage_mask), /*dependencyFlags=*/0,
- memory_barrier_infos.size(), memory_barrier_infos.data(),
- buffer_barrier_infos.size(), buffer_barrier_infos.data(), 0, nullptr);
+ static_cast<uint32_t>(memory_barrier_infos.size()),
+ memory_barrier_infos.data(),
+ static_cast<uint32_t>(buffer_barrier_infos.size()),
+ buffer_barrier_infos.data(), 0, nullptr);
return OkStatus();
}
@@ -296,12 +298,14 @@
info.size = buffer_barrier.length;
}
- syms()->vkCmdWaitEvents(
- command_buffer_, event_handles.size(), event_handles.data(),
- ConvertPipelineStageFlags(source_stage_mask),
- ConvertPipelineStageFlags(target_stage_mask), memory_barrier_infos.size(),
- memory_barrier_infos.data(), buffer_barrier_infos.size(),
- buffer_barrier_infos.data(), 0, nullptr);
+ syms()->vkCmdWaitEvents(command_buffer_, event_handles.size(),
+ event_handles.data(),
+ ConvertPipelineStageFlags(source_stage_mask),
+ ConvertPipelineStageFlags(target_stage_mask),
+ static_cast<uint32_t>(memory_barrier_infos.size()),
+ memory_barrier_infos.data(),
+ static_cast<uint32_t>(buffer_barrier_infos.size()),
+ buffer_barrier_infos.data(), 0, nullptr);
return OkStatus();
}
@@ -385,8 +389,9 @@
syms()->vkCmdPushConstants(
command_buffer_, device_executable_layout->handle(),
- VK_SHADER_STAGE_COMPUTE_BIT, offset * sizeof(uint32_t),
- values.size() * sizeof(uint32_t), values.data());
+ VK_SHADER_STAGE_COMPUTE_BIT,
+ static_cast<uint32_t>(offset * sizeof(uint32_t)),
+ static_cast<uint32_t>(values.size() * sizeof(uint32_t)), values.data());
return OkStatus();
}
@@ -424,8 +429,9 @@
device_descriptor_set->handle()};
syms()->vkCmdBindDescriptorSets(
command_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE,
- device_executable_layout->handle(), set, descriptor_sets.size(),
- descriptor_sets.data(), dynamic_offsets_i32.size(),
+ device_executable_layout->handle(), set,
+ static_cast<uint32_t>(descriptor_sets.size()), descriptor_sets.data(),
+ static_cast<uint32_t>(dynamic_offsets_i32.size()),
dynamic_offsets_i32.data());
return OkStatus();
diff --git a/iree/hal/vulkan/direct_command_buffer.h b/iree/hal/vulkan/direct_command_buffer.h
index 20f637d..8808c6e 100644
--- a/iree/hal/vulkan/direct_command_buffer.h
+++ b/iree/hal/vulkan/direct_command_buffer.h
@@ -15,7 +15,7 @@
#ifndef IREE_HAL_VULKAN_DIRECT_COMMAND_BUFFER_H_
#define IREE_HAL_VULKAN_DIRECT_COMMAND_BUFFER_H_
-#include <vulkan/vulkan.h>
+#include "iree/hal/vulkan/vulkan_headers.h"
#include "iree/hal/command_buffer.h"
#include "iree/hal/vulkan/descriptor_pool_cache.h"
diff --git a/iree/hal/vulkan/direct_command_queue.cc b/iree/hal/vulkan/direct_command_queue.cc
index 7523b09..71e8e33 100644
--- a/iree/hal/vulkan/direct_command_queue.cc
+++ b/iree/hal/vulkan/direct_command_queue.cc
@@ -87,21 +87,25 @@
submit_info->sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
submit_info->pNext = timeline_submit_info;
- submit_info->waitSemaphoreCount = wait_semaphore_handles.size();
+ submit_info->waitSemaphoreCount =
+ static_cast<uint32_t>(wait_semaphore_handles.size());
submit_info->pWaitSemaphores = wait_semaphore_handles.data();
submit_info->pWaitDstStageMask = wait_dst_stage_masks.data();
- submit_info->commandBufferCount = command_buffer_handles.size();
+ submit_info->commandBufferCount =
+ static_cast<uint32_t>(command_buffer_handles.size());
submit_info->pCommandBuffers = command_buffer_handles.data();
- submit_info->signalSemaphoreCount = signal_semaphore_handles.size();
+ submit_info->signalSemaphoreCount =
+ static_cast<uint32_t>(signal_semaphore_handles.size());
submit_info->pSignalSemaphores = signal_semaphore_handles.data();
timeline_submit_info->sType =
VK_STRUCTURE_TYPE_TIMELINE_SEMAPHORE_SUBMIT_INFO;
timeline_submit_info->pNext = nullptr;
- timeline_submit_info->waitSemaphoreValueCount = wait_semaphore_values.size();
+ timeline_submit_info->waitSemaphoreValueCount =
+ static_cast<uint32_t>(wait_semaphore_values.size());
timeline_submit_info->pWaitSemaphoreValues = wait_semaphore_values.data();
timeline_submit_info->signalSemaphoreValueCount =
- signal_semaphore_values.size();
+ static_cast<uint32_t>(signal_semaphore_values.size());
timeline_submit_info->pSignalSemaphoreValues = signal_semaphore_values.data();
return OkStatus();
@@ -125,7 +129,8 @@
{
absl::MutexLock lock(&queue_mutex_);
VK_RETURN_IF_ERROR(syms()->vkQueueSubmit(
- queue_, submit_infos.size(), submit_infos.data(), VK_NULL_HANDLE));
+ queue_, static_cast<uint32_t>(submit_infos.size()), submit_infos.data(),
+ VK_NULL_HANDLE));
}
return OkStatus();
diff --git a/iree/hal/vulkan/direct_command_queue.h b/iree/hal/vulkan/direct_command_queue.h
index 1921bfa..b0d3611 100644
--- a/iree/hal/vulkan/direct_command_queue.h
+++ b/iree/hal/vulkan/direct_command_queue.h
@@ -15,7 +15,7 @@
#ifndef IREE_HAL_VULKAN_DIRECT_COMMAND_QUEUE_H_
#define IREE_HAL_VULKAN_DIRECT_COMMAND_QUEUE_H_
-#include <vulkan/vulkan.h>
+#include "iree/hal/vulkan/vulkan_headers.h"
#include <cstdint>
#include <string>
diff --git a/iree/hal/vulkan/dynamic_symbols.h b/iree/hal/vulkan/dynamic_symbols.h
index c7651d2..8e74875 100644
--- a/iree/hal/vulkan/dynamic_symbols.h
+++ b/iree/hal/vulkan/dynamic_symbols.h
@@ -15,7 +15,7 @@
#ifndef IREE_HAL_VULKAN_DYNAMIC_SYMBOLS_H_
#define IREE_HAL_VULKAN_DYNAMIC_SYMBOLS_H_
-#include <vulkan/vulkan.h>
+#include "iree/hal/vulkan/vulkan_headers.h"
#include <cstdint>
#include <functional>
diff --git a/iree/hal/vulkan/emulated_timeline_semaphore.cc b/iree/hal/vulkan/emulated_timeline_semaphore.cc
index fb34bcc..475aa33 100644
--- a/iree/hal/vulkan/emulated_timeline_semaphore.cc
+++ b/iree/hal/vulkan/emulated_timeline_semaphore.cc
@@ -16,7 +16,6 @@
#include "absl/container/inlined_vector.h"
#include "absl/synchronization/mutex.h"
-#include "absl/utility/utility.h"
#include "iree/base/time.h"
#include "iree/base/tracing.h"
#include "iree/hal/vulkan/dynamic_symbols.h"
diff --git a/iree/hal/vulkan/emulated_timeline_semaphore.h b/iree/hal/vulkan/emulated_timeline_semaphore.h
index 4280d33..ccf0029 100644
--- a/iree/hal/vulkan/emulated_timeline_semaphore.h
+++ b/iree/hal/vulkan/emulated_timeline_semaphore.h
@@ -15,7 +15,7 @@
#ifndef IREE_HAL_VULKAN_ENUMLATED_TIMELINE_SEMAPHORE_H_
#define IREE_HAL_VULKAN_ENUMLATED_TIMELINE_SEMAPHORE_H_
-#include <vulkan/vulkan.h>
+#include "iree/hal/vulkan/vulkan_headers.h"
#include <atomic>
#include <vector>
diff --git a/iree/hal/vulkan/extensibility_util.h b/iree/hal/vulkan/extensibility_util.h
index b39b2fc..5aa6998 100644
--- a/iree/hal/vulkan/extensibility_util.h
+++ b/iree/hal/vulkan/extensibility_util.h
@@ -17,7 +17,7 @@
#ifndef IREE_HAL_VULKAN_EXTENSIBILITY_UTIL_H_
#define IREE_HAL_VULKAN_EXTENSIBILITY_UTIL_H_
-#include <vulkan/vulkan.h>
+#include "iree/hal/vulkan/vulkan_headers.h"
#include <vector>
diff --git a/iree/hal/vulkan/handle_util.h b/iree/hal/vulkan/handle_util.h
index 147acaa..efe463af 100644
--- a/iree/hal/vulkan/handle_util.h
+++ b/iree/hal/vulkan/handle_util.h
@@ -24,11 +24,11 @@
#ifndef IREE_HAL_VULKAN_HANDLE_UTIL_H_
#define IREE_HAL_VULKAN_HANDLE_UTIL_H_
-#include <vulkan/vulkan.h>
+#include "iree/hal/vulkan/vulkan_headers.h"
#include "absl/synchronization/mutex.h"
-#include "absl/utility/utility.h"
#include "iree/base/ref_ptr.h"
+#include "iree/base/status.h"
#include "iree/hal/vulkan/dynamic_symbols.h"
#include "iree/hal/vulkan/extensibility_util.h"
@@ -50,7 +50,7 @@
VkDeviceHandle(const VkDeviceHandle&) = delete;
VkDeviceHandle& operator=(const VkDeviceHandle&) = delete;
VkDeviceHandle(VkDeviceHandle&& other) noexcept
- : value_(absl::exchange(other.value_,
+ : value_(iree::exchange(other.value_,
static_cast<VkDevice>(VK_NULL_HANDLE))),
syms_(std::move(other.syms_)),
enabled_extensions_(other.enabled_extensions_),
@@ -94,7 +94,7 @@
VkCommandPoolHandle& operator=(const VkCommandPoolHandle&) = delete;
VkCommandPoolHandle(VkCommandPoolHandle&& other) noexcept
: logical_device_(std::move(other.logical_device_)),
- value_(absl::exchange(other.value_,
+ value_(iree::exchange(other.value_,
static_cast<VkCommandPool>(VK_NULL_HANDLE))) {}
VkCommandPoolHandle& operator=(VkCommandPoolHandle&& other) {
std::swap(logical_device_, other.logical_device_);
diff --git a/iree/hal/vulkan/internal_vk_mem_alloc.cc b/iree/hal/vulkan/internal_vk_mem_alloc.cc
index 95747a7..3899f7f 100644
--- a/iree/hal/vulkan/internal_vk_mem_alloc.cc
+++ b/iree/hal/vulkan/internal_vk_mem_alloc.cc
@@ -68,9 +68,7 @@
};
#define VMA_RW_MUTEX AbslVmaRWMutex
-#define VMA_DYNAMIC_VULKAN_FUNCTIONS 0
-
#define VMA_IMPLEMENTATION
-#include "vk_mem_alloc.h"
+#include "iree/hal/vulkan/internal_vk_mem_alloc.h"
-#endif
+#endif // !VULKAN_MEMORY_ALLOCATOR_EXTERNAL_IMPL
diff --git a/iree/hal/vulkan/internal_vk_mem_alloc.h b/iree/hal/vulkan/internal_vk_mem_alloc.h
index 591ce7e..4541e9f 100644
--- a/iree/hal/vulkan/internal_vk_mem_alloc.h
+++ b/iree/hal/vulkan/internal_vk_mem_alloc.h
@@ -15,6 +15,17 @@
#ifndef IREE_HAL_VULKAN_INTERNAL_VK_MEM_ALLOC_H_
#define IREE_HAL_VULKAN_INTERNAL_VK_MEM_ALLOC_H_
-#include "vk_mem_alloc.h"
+// NOTE: ensure our vulkan headers are used (as we define VK_NO_PROTOTYPES).
+#include "iree/hal/vulkan/vulkan_headers.h"
+// Force all Vulkan calls to go through an indirect pVulkanFunctions interface.
+// https://gpuopen-librariesandsdks.github.io/VulkanMemoryAllocator/html/configuration.html
+#define VMA_STATIC_VULKAN_FUNCTIONS 0
+
+// Prevent VMA from querying for dynamic functions we may not have provided.
+// We want to be able to print nice errors or decide whether something is ok
+// to be omitted and not have VMA poking around where it shouldn't.
+#define VMA_DYNAMIC_VULKAN_FUNCTIONS 0
+
+#include "vk_mem_alloc.h"
#endif // IREE_HAL_VULKAN_INTERNAL_VK_MEM_ALLOC_H_
diff --git a/iree/hal/vulkan/native_descriptor_set.h b/iree/hal/vulkan/native_descriptor_set.h
index 9bcb6f1..b251c3f 100644
--- a/iree/hal/vulkan/native_descriptor_set.h
+++ b/iree/hal/vulkan/native_descriptor_set.h
@@ -15,7 +15,7 @@
#ifndef IREE_HAL_VULKAN_NATIVE_DESCRIPTOR_SET_H_
#define IREE_HAL_VULKAN_NATIVE_DESCRIPTOR_SET_H_
-#include <vulkan/vulkan.h>
+#include "iree/hal/vulkan/vulkan_headers.h"
#include "iree/hal/descriptor_set.h"
#include "iree/hal/vulkan/handle_util.h"
diff --git a/iree/hal/vulkan/native_event.h b/iree/hal/vulkan/native_event.h
index 691ef6d..5db3f5d 100644
--- a/iree/hal/vulkan/native_event.h
+++ b/iree/hal/vulkan/native_event.h
@@ -15,7 +15,7 @@
#ifndef IREE_HAL_VULKAN_NATIVE_EVENT_H_
#define IREE_HAL_VULKAN_NATIVE_EVENT_H_
-#include <vulkan/vulkan.h>
+#include "iree/hal/vulkan/vulkan_headers.h"
#include "iree/hal/event.h"
#include "iree/hal/vulkan/handle_util.h"
diff --git a/iree/hal/vulkan/native_timeline_semaphore.h b/iree/hal/vulkan/native_timeline_semaphore.h
index a5d3a93..0a22510 100644
--- a/iree/hal/vulkan/native_timeline_semaphore.h
+++ b/iree/hal/vulkan/native_timeline_semaphore.h
@@ -15,7 +15,7 @@
#ifndef IREE_HAL_VULKAN_NATIVE_TIMELINE_SEMAPHORE_H_
#define IREE_HAL_VULKAN_NATIVE_TIMELINE_SEMAPHORE_H_
-#include <vulkan/vulkan.h>
+#include "iree/hal/vulkan/vulkan_headers.h"
#include "absl/synchronization/mutex.h"
#include "iree/hal/semaphore.h"
diff --git a/iree/hal/vulkan/pipeline_cache.cc b/iree/hal/vulkan/pipeline_cache.cc
index dd59678..5404cf9 100644
--- a/iree/hal/vulkan/pipeline_cache.cc
+++ b/iree/hal/vulkan/pipeline_cache.cc
@@ -14,13 +14,10 @@
#include "iree/hal/vulkan/pipeline_cache.h"
-#include "absl/synchronization/mutex.h"
-#include "flatbuffers/flatbuffers.h"
#include "iree/base/status.h"
#include "iree/base/tracing.h"
#include "iree/hal/executable_format.h"
#include "iree/hal/vulkan/status_util.h"
-#include "iree/schemas/spirv_executable_def_generated.h"
namespace iree {
namespace hal {
@@ -39,15 +36,6 @@
ExecutableLayout* executable_layout, ExecutableCachingModeBitfield mode,
const ExecutableSpec& spec) {
IREE_TRACE_SCOPE0("PipelineCache::PrepareExecutable");
- if (spec.executable_data.size() <= 4 ||
- !SpirVExecutableDefBufferHasIdentifier(spec.executable_data.data())) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Supplied executable data does not contain a SpirVExecutableDef";
- }
-
- // Get the SPIR-V executable def flatbuffer.
- const auto& spirv_executable_def =
- *::flatbuffers::GetRoot<SpirVExecutableDef>(spec.executable_data.data());
// Create the executable (which may itself own many pipelines).
IREE_ASSIGN_OR_RETURN(
@@ -56,7 +44,7 @@
add_ref(logical_device_),
/*pipeline_cache=*/VK_NULL_HANDLE,
static_cast<PipelineExecutableLayout*>(executable_layout), mode,
- spirv_executable_def));
+ spec));
return executable;
}
diff --git a/iree/hal/vulkan/pipeline_cache.h b/iree/hal/vulkan/pipeline_cache.h
index 143d8b6..75b9353 100644
--- a/iree/hal/vulkan/pipeline_cache.h
+++ b/iree/hal/vulkan/pipeline_cache.h
@@ -15,14 +15,13 @@
#ifndef IREE_HAL_VULKAN_PIPELINE_CACHE_H_
#define IREE_HAL_VULKAN_PIPELINE_CACHE_H_
-#include <vulkan/vulkan.h>
+#include "iree/hal/vulkan/vulkan_headers.h"
#include "absl/container/inlined_vector.h"
#include "iree/hal/executable.h"
#include "iree/hal/executable_cache.h"
#include "iree/hal/vulkan/handle_util.h"
#include "iree/hal/vulkan/pipeline_executable.h"
-#include "iree/schemas/spirv_executable_def_generated.h"
namespace iree {
namespace hal {
diff --git a/iree/hal/vulkan/pipeline_executable.cc b/iree/hal/vulkan/pipeline_executable.cc
index 17a3d9b..b954f24 100644
--- a/iree/hal/vulkan/pipeline_executable.cc
+++ b/iree/hal/vulkan/pipeline_executable.cc
@@ -20,45 +20,74 @@
#include "iree/base/tracing.h"
#include "iree/hal/vulkan/status_util.h"
+// flatcc schemas:
+#include "iree/base/flatcc.h"
+#include "iree/schemas/spirv_executable_def_reader.h"
+#include "iree/schemas/spirv_executable_def_verifier.h"
+
+// NOTE: starting to port this to C.
+
+// Verifies the structure of the flatbuffer so that we can avoid doing so during
+// runtime; returns OK or INVALID_ARGUMENT/UNIMPLEMENTED describing the problem.
+// Some conditions remain the caller's concern (such as omitted names on
+// functions with internal linkage), but no bounds checks are needed afterwards.
+static iree_status_t iree_hal_spirv_executable_flatbuffer_verify(
+ iree_const_byte_span_t flatbuffer_data) {
+ if (!flatbuffer_data.data || flatbuffer_data.data_length < 16) {
+ return iree_make_status(
+ IREE_STATUS_INVALID_ARGUMENT,
+ "flatbuffer data is not present or less than 16 bytes (%zu total)",
+ flatbuffer_data.data_length);
+ }
+
+ // Run flatcc generated verification: this ensures all pointers are in-bounds
+ // and the file is safe to walk, but not that its contents are what we expect.
+ int verify_ret = iree_SpirVExecutableDef_verify_as_root(
+ flatbuffer_data.data, flatbuffer_data.data_length);
+ if (verify_ret != flatcc_verify_ok) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "flatbuffer verification failed: %s",
+ flatcc_verify_error_string(verify_ret));
+ }
+
+ iree_SpirVExecutableDef_table_t executable_def =
+ iree_SpirVExecutableDef_as_root(flatbuffer_data.data);
+
+ flatbuffers_string_vec_t entry_points_vec =
+ iree_SpirVExecutableDef_entry_points_get(executable_def);
+ size_t entry_point_count = flatbuffers_string_vec_len(entry_points_vec);
+ for (size_t i = 0; i < entry_point_count; ++i) {
+ if (!flatbuffers_string_len(
+ flatbuffers_string_vec_at(entry_points_vec, i))) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "executable entry point %zu has no name", i);
+ }
+ }
+
+ // The vector length is unsigned, so missing/empty code shows up as exactly 0.
+ if (flatbuffers_uint32_vec_len(
+ iree_SpirVExecutableDef_code_get(executable_def)) == 0) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "executable SPIR-V code is missing/empty");
+ }
+
+ // TODO(benvanik): pull PopulateSpecializationInfo from history and update.
+ // For now the compiler isn't generating them, and we don't use them.
+ if (iree_SpirVExecutableDef_specialization_info_is_present(executable_def)) {
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+ "executable uses SPIR-V specialization constants; "
+ "they need to be revived");
+ }
+
+ return iree_ok_status();
+}
+
namespace iree {
namespace hal {
namespace vulkan {
namespace {
-// Generates the baked specialization constant data based on the flatbuffer.
-// We only support uint32_t right now so this is easy.
-// Note that the returned vectors are referenced by pointers in |out_info| and
-// must remain valid until the info is no longer in use.
-std::pair<std::vector<VkSpecializationMapEntry>, std::vector<uint8_t>>
-PopulateSpecializationInfo(const VkSpecializationInfoDef* info_def) {
- int entry_count =
- info_def && info_def->map_entries() ? info_def->map_entries()->size() : 0;
- if (!entry_count) {
- return {};
- }
-
- std::vector<VkSpecializationMapEntry> entries;
- entries.reserve(entry_count);
- std::vector<uint8_t> data;
- data.resize(entry_count * sizeof(uint32_t));
-
- uint32_t offset = 0;
- for (const auto* entry_def : *info_def->map_entries()) {
- if (!entry_def) continue;
- entries.push_back({});
- auto& entry = entries.back();
- entry.constantID = entry_def->constant_id();
- entry.offset = offset;
- entry.size = sizeof(uint32_t);
- uint32_t value = entry_def->uint32_value();
- std::memcpy(data.data() + offset, &value, sizeof(value));
- offset += entry.size;
- }
-
- return {std::move(entries), std::move(data)};
-}
-
class VkShaderModuleHandle : public RefObject<VkShaderModuleHandle> {
public:
explicit VkShaderModuleHandle(const ref_ptr<VkDeviceHandle>& logical_device)
@@ -99,47 +128,40 @@
StatusOr<ref_ptr<PipelineExecutable>> PipelineExecutable::Create(
ref_ptr<VkDeviceHandle> logical_device, VkPipelineCache pipeline_cache,
PipelineExecutableLayout* executable_layout,
- ExecutableCachingModeBitfield mode,
- const SpirVExecutableDef& spirv_executable_def) {
+ ExecutableCachingModeBitfield mode, const ExecutableSpec& spec) {
IREE_TRACE_SCOPE0("PipelineExecutable::Create");
const auto& syms = logical_device->syms();
- if (!spirv_executable_def.entry_points() ||
- spirv_executable_def.entry_points()->size() == 0) {
- return InvalidArgumentErrorBuilder(IREE_LOC) << "No entry points defined";
- }
- if (!spirv_executable_def.code()) {
- return InvalidArgumentErrorBuilder(IREE_LOC) << "No SPIR-V code present";
- }
- const auto& code = *spirv_executable_def.code();
+
+ // Verify and fetch the executable flatbuffer wrapper.
+ iree_const_byte_span_t executable_data = iree_make_const_byte_span(
+ spec.executable_data.data(), spec.executable_data.size());
+ IREE_RETURN_IF_ERROR(
+ iree_hal_spirv_executable_flatbuffer_verify(executable_data));
+ iree_SpirVExecutableDef_table_t executable_def =
+ iree_SpirVExecutableDef_as_root(executable_data.data);
// Create the shader module.
VkShaderModuleCreateInfo shader_module_create_info;
shader_module_create_info.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
shader_module_create_info.pNext = nullptr;
shader_module_create_info.flags = 0;
- shader_module_create_info.codeSize = code.size() * sizeof(uint32_t);
- shader_module_create_info.pCode = code.data();
+ flatbuffers_uint32_vec_t code_vec =
+ iree_SpirVExecutableDef_code_get(executable_def);
+ shader_module_create_info.codeSize =
+ flatbuffers_uint32_vec_len(code_vec) * sizeof(uint32_t);
+ shader_module_create_info.pCode = code_vec;
VkShaderModuleHandle shader_module(add_ref(logical_device));
VK_RETURN_IF_ERROR(syms->vkCreateShaderModule(
*logical_device, &shader_module_create_info, logical_device->allocator(),
shader_module.mutable_value()));
- // Specialization info is currently constant against all entry points.
- std::vector<VkSpecializationMapEntry> spec_entries;
- std::vector<uint8_t> spec_data;
- std::tie(spec_entries, spec_data) =
- PopulateSpecializationInfo(spirv_executable_def.specialization_info());
- VkSpecializationInfo specialization_info;
- specialization_info.mapEntryCount = spec_entries.size();
- specialization_info.pMapEntries = spec_entries.data();
- specialization_info.dataSize = spec_data.size();
- specialization_info.pData = spec_data.data();
-
// Create pipelines for each entry point.
- const auto& entry_points = *spirv_executable_def.entry_points();
+ flatbuffers_string_vec_t entry_points_vec =
+ iree_SpirVExecutableDef_entry_points_get(executable_def);
absl::InlinedVector<VkComputePipelineCreateInfo, 1> pipeline_create_infos;
- pipeline_create_infos.resize(entry_points.size());
- for (int entry_ordinal = 0; entry_ordinal < entry_points.size();
+ pipeline_create_infos.resize(flatbuffers_string_vec_len(entry_points_vec));
+ for (size_t entry_ordinal = 0;
+ entry_ordinal < flatbuffers_string_vec_len(entry_points_vec);
++entry_ordinal) {
auto& pipeline_create_info = pipeline_create_infos[entry_ordinal];
pipeline_create_info.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
@@ -163,26 +185,25 @@
stage_create_info.flags = 0;
stage_create_info.stage = VK_SHADER_STAGE_COMPUTE_BIT;
stage_create_info.module = shader_module;
- stage_create_info.pName = entry_points[entry_ordinal]->c_str();
- stage_create_info.pSpecializationInfo = &specialization_info;
+ stage_create_info.pName =
+ flatbuffers_string_vec_at(entry_points_vec, entry_ordinal);
+ stage_create_info.pSpecializationInfo = NULL;
}
absl::InlinedVector<VkPipeline, 1> pipelines;
- pipelines.resize(entry_points.size());
+ pipelines.resize(flatbuffers_string_vec_len(entry_points_vec));
// Some ICDs appear to leak in here, out of our control.
// Warning: leak checks remain disabled if an error is returned.
IREE_DISABLE_LEAK_CHECKS();
VK_RETURN_IF_ERROR(syms->vkCreateComputePipelines(
- *logical_device, pipeline_cache, pipeline_create_infos.size(),
+ *logical_device, pipeline_cache,
+ static_cast<uint32_t>(pipeline_create_infos.size()),
pipeline_create_infos.data(), logical_device->allocator(),
pipelines.data()));
IREE_ENABLE_LEAK_CHECKS();
- auto executable = make_ref<PipelineExecutable>(std::move(logical_device),
- std::move(pipelines));
- executable->tag_ =
- spirv_executable_def.tag() ? spirv_executable_def.tag()->str() : "";
- return executable;
+ return make_ref<PipelineExecutable>(std::move(logical_device),
+ std::move(pipelines));
}
PipelineExecutable::PipelineExecutable(
diff --git a/iree/hal/vulkan/pipeline_executable.h b/iree/hal/vulkan/pipeline_executable.h
index 839b3dd..25f5955 100644
--- a/iree/hal/vulkan/pipeline_executable.h
+++ b/iree/hal/vulkan/pipeline_executable.h
@@ -15,7 +15,7 @@
#ifndef IREE_HAL_VULKAN_PIPELINE_EXECUTABLE_H_
#define IREE_HAL_VULKAN_PIPELINE_EXECUTABLE_H_
-#include <vulkan/vulkan.h>
+#include "iree/hal/vulkan/vulkan_headers.h"
#include <vector>
@@ -28,7 +28,6 @@
#include "iree/hal/vulkan/handle_util.h"
#include "iree/hal/vulkan/native_descriptor_set.h"
#include "iree/hal/vulkan/pipeline_executable_layout.h"
-#include "iree/schemas/spirv_executable_def_generated.h"
namespace iree {
namespace hal {
@@ -39,8 +38,7 @@
static StatusOr<ref_ptr<PipelineExecutable>> Create(
ref_ptr<VkDeviceHandle> logical_device, VkPipelineCache pipeline_cache,
PipelineExecutableLayout* executable_layout,
- ExecutableCachingModeBitfield mode,
- const SpirVExecutableDef& spirv_executable_def);
+ ExecutableCachingModeBitfield mode, const ExecutableSpec& spec);
PipelineExecutable(ref_ptr<VkDeviceHandle> logical_device,
absl::InlinedVector<VkPipeline, 1> pipelines);
@@ -56,7 +54,6 @@
private:
ref_ptr<VkDeviceHandle> logical_device_;
- std::string tag_;
// One pipeline per entry point.
absl::InlinedVector<VkPipeline, 1> pipelines_;
diff --git a/iree/hal/vulkan/pipeline_executable_layout.h b/iree/hal/vulkan/pipeline_executable_layout.h
index 00b7ec7..fd52f2b 100644
--- a/iree/hal/vulkan/pipeline_executable_layout.h
+++ b/iree/hal/vulkan/pipeline_executable_layout.h
@@ -15,7 +15,7 @@
#ifndef IREE_HAL_VULKAN_PIPELINE_EXECUTABLE_LAYOUT_H_
#define IREE_HAL_VULKAN_PIPELINE_EXECUTABLE_LAYOUT_H_
-#include <vulkan/vulkan.h>
+#include "iree/hal/vulkan/vulkan_headers.h"
#include "absl/container/inlined_vector.h"
#include "absl/types/span.h"
diff --git a/iree/hal/vulkan/registration/BUILD b/iree/hal/vulkan/registration/BUILD
new file mode 100644
index 0000000..c5b81f4
--- /dev/null
+++ b/iree/hal/vulkan/registration/BUILD
@@ -0,0 +1,54 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//iree:build_defs.oss.bzl", "iree_cmake_extra_content")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_cmake_extra_content(
+ content = """
+if(${IREE_HAL_DRIVER_VULKAN})
+""",
+ inline = True,
+)
+
+cc_library(
+ name = "registration",
+ srcs = ["driver_module.cc"],
+ hdrs = ["driver_module.h"],
+ defines = [
+ "IREE_HAL_HAVE_VULKAN_DRIVER_MODULE=1",
+ ],
+ deps = [
+ "//iree/base:flags",
+ "//iree/base:status",
+ "//iree/base:tracing",
+ "//iree/hal:api",
+ "//iree/hal:driver_registry",
+ "//iree/hal/vulkan",
+ "//iree/hal/vulkan:utils",
+ "@com_google_absl//absl/flags:flag",
+ ],
+)
+
+iree_cmake_extra_content(
+ content = """
+endif()
+""",
+ inline = True,
+)
diff --git a/iree/hal/vulkan/registration/CMakeLists.txt b/iree/hal/vulkan/registration/CMakeLists.txt
new file mode 100644
index 0000000..4cc0fa1
--- /dev/null
+++ b/iree/hal/vulkan/registration/CMakeLists.txt
@@ -0,0 +1,40 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_add_all_subdirs()
+
+if(${IREE_HAL_DRIVER_VULKAN})
+
+iree_cc_library(
+ NAME
+ registration
+ HDRS
+ "driver_module.h"
+ SRCS
+ "driver_module.cc"
+ DEPS
+ absl::flags
+ iree::base::flags
+ iree::base::status
+ iree::base::tracing
+ iree::hal::api
+ iree::hal::driver_registry
+ iree::hal::vulkan
+ iree::hal::vulkan::utils
+ DEFINES
+ "IREE_HAL_HAVE_VULKAN_DRIVER_MODULE=1"
+ PUBLIC
+)
+
+endif()
diff --git a/iree/hal/vulkan/vulkan_driver_module.cc b/iree/hal/vulkan/registration/driver_module.cc
similarity index 93%
rename from iree/hal/vulkan/vulkan_driver_module.cc
rename to iree/hal/vulkan/registration/driver_module.cc
index 9bd1cd0..9495f28 100644
--- a/iree/hal/vulkan/vulkan_driver_module.cc
+++ b/iree/hal/vulkan/registration/driver_module.cc
@@ -1,4 +1,4 @@
-// Copyright 2019 Google LLC
+// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include <memory>
+#include "iree/hal/vulkan/registration/driver_module.h"
#include "absl/flags/flag.h"
-#include "iree/base/init.h"
+#include "iree/base/flags.h"
#include "iree/base/status.h"
#include "iree/base/tracing.h"
#include "iree/hal/driver_registry.h"
@@ -132,8 +132,8 @@
} // namespace hal
} // namespace iree
-IREE_REGISTER_MODULE_INITIALIZER(iree_hal_vulkan_driver, {
- IREE_QCHECK_OK(::iree::hal::DriverRegistry::shared_registry()->Register(
- "vulkan", ::iree::hal::vulkan::CreateVulkanDriver));
-});
-IREE_REGISTER_MODULE_INITIALIZER_SEQUENCE(iree_hal, iree_hal_vulkan_driver);
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_vulkan_driver_module_register() {
+ return ::iree::hal::DriverRegistry::shared_registry()->Register(
+ "vulkan", ::iree::hal::vulkan::CreateVulkanDriver);
+}
diff --git a/iree/hal/vulkan/registration/driver_module.h b/iree/hal/vulkan/registration/driver_module.h
new file mode 100644
index 0000000..42f4d04
--- /dev/null
+++ b/iree/hal/vulkan/registration/driver_module.h
@@ -0,0 +1,34 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_VULKAN_REGISTRATION_DRIVER_MODULE_H_
+#define IREE_HAL_VULKAN_REGISTRATION_DRIVER_MODULE_H_
+
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// Registers the Vulkan HAL driver under the name "vulkan" with the shared
+// driver registry and returns the registry's status. Declared with (void) so
+// that C callers get a real prototype instead of an unspecified parameter list.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_vulkan_driver_module_register(void);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_HAL_VULKAN_REGISTRATION_DRIVER_MODULE_H_
diff --git a/iree/hal/vulkan/renderdoc_capture_manager.cc b/iree/hal/vulkan/renderdoc_capture_manager.cc
index 039e9c0..b8b06ce 100644
--- a/iree/hal/vulkan/renderdoc_capture_manager.cc
+++ b/iree/hal/vulkan/renderdoc_capture_manager.cc
@@ -19,8 +19,7 @@
#include "iree/base/target_platform.h"
#include "iree/base/tracing.h"
-#if defined(IREE_PLATFORM_WINDOWS)
-#else
+#if !defined(IREE_PLATFORM_WINDOWS)
#include <dlfcn.h>
#endif // IREE_PLATFORM_WINDOWS
diff --git a/iree/hal/vulkan/serializing_command_queue.cc b/iree/hal/vulkan/serializing_command_queue.cc
index 1fc19fa..f4c3ba8 100644
--- a/iree/hal/vulkan/serializing_command_queue.cc
+++ b/iree/hal/vulkan/serializing_command_queue.cc
@@ -134,20 +134,20 @@
arena->AllocateSpan<VkSemaphore>(wait_semaphores.size());
auto wait_dst_stage_masks =
arena->AllocateSpan<VkPipelineStageFlags>(wait_semaphores.size());
- for (int i = 0, e = wait_semaphores.size(); i < e; ++i) {
+ for (size_t i = 0, e = wait_semaphores.size(); i < e; ++i) {
wait_semaphore_handles[i] = wait_semaphores[i];
wait_dst_stage_masks[i] = dst_stage_mask;
}
auto signal_semaphore_handles =
arena->AllocateSpan<VkSemaphore>(signal_semaphores.size());
- for (int i = 0, e = signal_semaphores.size(); i < e; ++i) {
+ for (size_t i = 0, e = signal_semaphores.size(); i < e; ++i) {
signal_semaphore_handles[i] = signal_semaphores[i];
}
auto command_buffer_handles =
arena->AllocateSpan<VkCommandBuffer>(command_buffers.size());
- for (int i = 0, e = command_buffers.size(); i < e; ++i) {
+ for (size_t i = 0, e = command_buffers.size(); i < e; ++i) {
const auto& command_buffer = command_buffers[i];
auto* direct_command_buffer =
static_cast<DirectCommandBuffer*>(command_buffer->impl());
@@ -156,12 +156,15 @@
submit_info->sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
submit_info->pNext = nullptr;
- submit_info->waitSemaphoreCount = wait_semaphore_handles.size();
+ submit_info->waitSemaphoreCount =
+ static_cast<uint32_t>(wait_semaphore_handles.size());
submit_info->pWaitSemaphores = wait_semaphore_handles.data();
submit_info->pWaitDstStageMask = wait_dst_stage_masks.data();
- submit_info->commandBufferCount = command_buffer_handles.size();
+ submit_info->commandBufferCount =
+ static_cast<uint32_t>(command_buffer_handles.size());
submit_info->pCommandBuffers = command_buffer_handles.data();
- submit_info->signalSemaphoreCount = signal_semaphore_handles.size();
+ submit_info->signalSemaphoreCount =
+ static_cast<uint32_t>(signal_semaphore_handles.size());
submit_info->pSignalSemaphores = signal_semaphore_handles.data();
}
@@ -188,7 +191,7 @@
IREE_DVLOG(2) << "SerializingCommandQueue::Submit";
absl::MutexLock lock(&mutex_);
- for (int i = 0; i < batches.size(); ++i) {
+ for (size_t i = 0; i < batches.size(); ++i) {
// Grab a fence for this submission first. This will be used to check the
// progress of emulated timeline semaphores later.
IREE_ASSIGN_OR_RETURN(auto fence, fence_pool_->Acquire());
@@ -282,13 +285,13 @@
if (submit_infos.empty()) return false;
auto infos = arena.AllocateSpan<VkSubmitInfo>(submit_infos.size());
- for (int i = 0, e = submit_infos.size(); i < e; ++i) {
+ for (size_t i = 0, e = submit_infos.size(); i < e; ++i) {
infos[i] = submit_infos[i];
}
// Note: We might be able to batch the submission but it involves non-trivial
// fence handling. We can handle that if really needed.
- for (int i = 0, e = submit_infos.size(); i < e; ++i) {
+ for (size_t i = 0, e = submit_infos.size(); i < e; ++i) {
VK_RETURN_IF_ERROR(syms()->vkQueueSubmit(
queue_, /*submitCount=*/1, &submit_infos[i], submit_fences[i]));
}
@@ -353,9 +356,9 @@
fences.reserve(pending_fences_.size());
for (const auto& fence : pending_fences_) fences.push_back(fence->value());
- VkResult result =
- syms()->vkWaitForFences(*logical_device_, fences.size(), fences.data(),
- /*waitAll=*/VK_TRUE, timeout_ns);
+ VkResult result = syms()->vkWaitForFences(
+ *logical_device_, static_cast<uint32_t>(fences.size()), fences.data(),
+ /*waitAll=*/VK_TRUE, timeout_ns);
switch (result) {
case VK_SUCCESS:
@@ -394,7 +397,8 @@
fences.reserve(pending_fences_.size());
for (const auto& fence : pending_fences_) fences.push_back(fence->value());
- syms()->vkWaitForFences(*logical_device_, fences.size(), fences.data(),
+ syms()->vkWaitForFences(*logical_device_,
+ static_cast<uint32_t>(fences.size()), fences.data(),
/*waitAll=*/VK_TRUE, /*timeout=*/UINT64_MAX);
// Clear the list. Fences will be automatically returned back to the queue
// after refcount reaches 0.
diff --git a/iree/hal/vulkan/serializing_command_queue.h b/iree/hal/vulkan/serializing_command_queue.h
index 3437eaf..a955a3c 100644
--- a/iree/hal/vulkan/serializing_command_queue.h
+++ b/iree/hal/vulkan/serializing_command_queue.h
@@ -15,7 +15,7 @@
#ifndef IREE_HAL_VULKAN_SERIALIZING_COMMAND_QUEUE_H_
#define IREE_HAL_VULKAN_SERIALIZING_COMMAND_QUEUE_H_
-#include <vulkan/vulkan.h>
+#include "iree/hal/vulkan/vulkan_headers.h"
#include <memory>
#include <string>
diff --git a/iree/hal/vulkan/status_util.h b/iree/hal/vulkan/status_util.h
index 2a866b5..798111c 100644
--- a/iree/hal/vulkan/status_util.h
+++ b/iree/hal/vulkan/status_util.h
@@ -15,7 +15,7 @@
#ifndef IREE_HAL_VULKAN_STATUS_UTIL_H_
#define IREE_HAL_VULKAN_STATUS_UTIL_H_
-#include <vulkan/vulkan.h>
+#include "iree/hal/vulkan/vulkan_headers.h"
#include "iree/base/status.h"
diff --git a/iree/hal/vulkan/timepoint_util.cc b/iree/hal/vulkan/timepoint_util.cc
index a096d4d..c14ea26 100644
--- a/iree/hal/vulkan/timepoint_util.cc
+++ b/iree/hal/vulkan/timepoint_util.cc
@@ -17,7 +17,6 @@
#include <memory>
#include "absl/synchronization/mutex.h"
-#include "absl/utility/utility.h"
#include "iree/base/time.h"
#include "iree/base/tracing.h"
#include "iree/hal/vulkan/dynamic_symbols.h"
diff --git a/iree/hal/vulkan/timepoint_util.h b/iree/hal/vulkan/timepoint_util.h
index f162753..5fe35e4 100644
--- a/iree/hal/vulkan/timepoint_util.h
+++ b/iree/hal/vulkan/timepoint_util.h
@@ -15,7 +15,7 @@
#ifndef IREE_HAL_VULKAN_TIMEPOINT_UTIL_H_
#define IREE_HAL_VULKAN_TIMEPOINT_UTIL_H_
-#include <vulkan/vulkan.h>
+#include "iree/hal/vulkan/vulkan_headers.h"
#include <atomic>
#include <vector>
diff --git a/iree/hal/vulkan/vma_allocator.h b/iree/hal/vulkan/vma_allocator.h
index ccd677e..6acf935 100644
--- a/iree/hal/vulkan/vma_allocator.h
+++ b/iree/hal/vulkan/vma_allocator.h
@@ -15,7 +15,7 @@
#ifndef IREE_HAL_VULKAN_VMA_ALLOCATOR_H_
#define IREE_HAL_VULKAN_VMA_ALLOCATOR_H_
-#include <vulkan/vulkan.h>
+#include "iree/hal/vulkan/vulkan_headers.h"
#include <memory>
diff --git a/iree/hal/vulkan/vma_buffer.h b/iree/hal/vulkan/vma_buffer.h
index b768f71..dddb55b 100644
--- a/iree/hal/vulkan/vma_buffer.h
+++ b/iree/hal/vulkan/vma_buffer.h
@@ -15,10 +15,10 @@
#ifndef IREE_HAL_VULKAN_VMA_BUFFER_H_
#define IREE_HAL_VULKAN_VMA_BUFFER_H_
-#include <vulkan/vulkan.h>
+#include "iree/hal/vulkan/vulkan_headers.h"
#include "iree/hal/buffer.h"
-#include "vk_mem_alloc.h"
+#include "iree/hal/vulkan/internal_vk_mem_alloc.h"
namespace iree {
namespace hal {
diff --git a/iree/hal/vulkan/vulkan_device.cc b/iree/hal/vulkan/vulkan_device.cc
index a975b3a..a5e2cd0 100644
--- a/iree/hal/vulkan/vulkan_device.cc
+++ b/iree/hal/vulkan/vulkan_device.cc
@@ -173,7 +173,7 @@
uint64_t compute_queue_count = CountOnes64(compute_queue_set.queue_indices);
for (uint32_t i = 0; i < compute_queue_count; ++i) {
- if (!(compute_queue_set.queue_indices & (1 << i))) continue;
+ if (!(compute_queue_set.queue_indices & (1ull << i))) continue;
VkQueue queue = VK_NULL_HANDLE;
syms->vkGetDeviceQueue(*logical_device,
@@ -195,7 +195,7 @@
uint64_t transfer_queue_count = CountOnes64(transfer_queue_set.queue_indices);
for (uint32_t i = 0; i < transfer_queue_count; ++i) {
- if (!(transfer_queue_set.queue_indices & (1 << i))) continue;
+ if (!(transfer_queue_set.queue_indices & (1ull << i))) continue;
VkQueue queue = VK_NULL_HANDLE;
syms->vkGetDeviceQueue(*logical_device,
@@ -367,7 +367,7 @@
QueueSet compute_queue_set = {};
compute_queue_set.queue_family_index = queue_family_info.dispatch_index;
for (uint32_t i = 0; i < queue_family_info.dispatch_queue_count; ++i) {
- compute_queue_set.queue_indices |= 1 << i;
+ compute_queue_set.queue_indices |= 1ull << i;
}
QueueSet transfer_queue_set = {};
transfer_queue_set.queue_family_index = queue_family_info.transfer_index;
@@ -377,7 +377,7 @@
base_queue_index = queue_family_info.dispatch_index;
}
for (uint32_t i = 0; i < queue_family_info.transfer_queue_count; ++i) {
- transfer_queue_set.queue_indices |= 1 << (i + base_queue_index);
+ transfer_queue_set.queue_indices |= 1ull << (i + base_queue_index);
}
// Emulate timeline semaphores if associated functions are not defined.
diff --git a/iree/hal/vulkan/vulkan_device.h b/iree/hal/vulkan/vulkan_device.h
index 456bc19..fbc10d3 100644
--- a/iree/hal/vulkan/vulkan_device.h
+++ b/iree/hal/vulkan/vulkan_device.h
@@ -15,7 +15,7 @@
#ifndef IREE_HAL_VULKAN_VULKAN_DEVICE_H_
#define IREE_HAL_VULKAN_VULKAN_DEVICE_H_
-#include <vulkan/vulkan.h>
+#include "iree/hal/vulkan/vulkan_headers.h"
#include <functional>
#include <memory>
diff --git a/iree/hal/vulkan/vulkan_driver.cc b/iree/hal/vulkan/vulkan_driver.cc
index 9f76455..0e1c7b5 100644
--- a/iree/hal/vulkan/vulkan_driver.cc
+++ b/iree/hal/vulkan/vulkan_driver.cc
@@ -114,9 +114,11 @@
create_info.pNext = nullptr;
create_info.flags = 0;
create_info.pApplicationInfo = &app_info;
- create_info.enabledLayerCount = enabled_layer_names.size();
+ create_info.enabledLayerCount =
+ static_cast<uint32_t>(enabled_layer_names.size());
create_info.ppEnabledLayerNames = enabled_layer_names.data();
- create_info.enabledExtensionCount = enabled_extension_names.size();
+ create_info.enabledExtensionCount =
+ static_cast<uint32_t>(enabled_extension_names.size());
create_info.ppEnabledExtensionNames = enabled_extension_names.data();
// If we have the debug_utils extension then we can chain a one-shot messenger
diff --git a/iree/hal/vulkan/vulkan_driver.h b/iree/hal/vulkan/vulkan_driver.h
index 9c8bcce..ed3e462 100644
--- a/iree/hal/vulkan/vulkan_driver.h
+++ b/iree/hal/vulkan/vulkan_driver.h
@@ -15,7 +15,7 @@
#ifndef IREE_HAL_VULKAN_VULKAN_DRIVER_H_
#define IREE_HAL_VULKAN_VULKAN_DRIVER_H_
-#include <vulkan/vulkan.h>
+#include "iree/hal/vulkan/vulkan_headers.h"
#include <memory>
#include <vector>
diff --git a/iree/hal/vulkan/vulkan_headers.h b/iree/hal/vulkan/vulkan_headers.h
new file mode 100644
index 0000000..9acafd1
--- /dev/null
+++ b/iree/hal/vulkan/vulkan_headers.h
@@ -0,0 +1,46 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_VULKAN_VULKAN_HEADERS_H_
+#define IREE_HAL_VULKAN_VULKAN_HEADERS_H_
+
+// We exclusively use Vulkan via queried function pointers. To ensure that there
+// are no accidental calls to the linker-loaded implicit functions we just
+// compile them all out.
+//
+// Code under iree/hal/vulkan/ *MUST NOT* directly include vulkan.h or any
+// header that includes it without VK_NO_PROTOTYPES first being defined. This
+// means that this iree/hal/vulkan/vulkan_headers.h file must usually be
+// included first in all files using it.
+//
+// From there, use iree/hal/vulkan/dynamic_symbols.h to plumb the dynamically
+// resolved symbols to any code that may need to make Vulkan calls. See that
+// header for more information: in general we try to keep our required set of
+// symbols minimal to avoid binary size/runtime memory/linker time so symbols
+// are only added as needed.
+//
+// Other non-core code may choose not to disable the prototypes, though doing
+// so is not recommended for anything beyond samples.
+//
+// There are several reasons to resolve Vulkan symbols dynamically, such as
+// supporting platforms without Vulkan or with differing Vulkan versions where
+// not all symbols may be available.
+//
+// See this article for more information:
+// https://djang86.blogspot.com/2019/01/what-is-vknoprototypes.html
+#define VK_NO_PROTOTYPES 1
+
+#include <vulkan/vulkan.h>
+
+#endif // IREE_HAL_VULKAN_VULKAN_HEADERS_H_
diff --git a/iree/modules/check/BUILD b/iree/modules/check/BUILD
index cfaedff..4686574 100644
--- a/iree/modules/check/BUILD
+++ b/iree/modules/check/BUILD
@@ -12,13 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-load(
- "//iree:build_defs.oss.bzl",
- "IREE_DRIVER_MODULES",
- "PLATFORM_VULKAN_DEPS",
- iree_cc_binary = "cc_binary",
-)
-
package(
default_visibility = ["//visibility:public"],
features = ["layering_check"],
@@ -34,38 +27,39 @@
"//iree/base:logging",
"//iree/base:status",
"//iree/hal:api",
- "//iree/hal/vmla:vmla_driver_module",
+ "//iree/hal/vmla/registration",
"//iree/modules/hal",
"//iree/testing:gtest",
"//iree/testing:gtest_main",
"//iree/vm",
"//iree/vm:bytecode_module",
- "//iree/vm:ref_cc",
+ "//iree/vm:cc",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)
-iree_cc_binary(
+cc_binary(
name = "iree-check-module",
testonly = True,
srcs = ["iree-check-module-main.cc"],
deps = [
":native_module",
- "@com_google_absl//absl/flags:flag",
- "@com_google_absl//absl/strings",
"//iree/base:api",
+ "//iree/base:core_headers",
"//iree/base:file_io",
- "//iree/base:init",
+ "//iree/base:flags",
"//iree/base:status",
- "//iree/base:target_platform",
"//iree/base:tracing",
+ "//iree/hal/drivers",
"//iree/modules/hal",
"//iree/testing:gtest",
"//iree/tools/utils:vm_util",
"//iree/vm:bytecode_module",
- ] + PLATFORM_VULKAN_DEPS + IREE_DRIVER_MODULES,
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/strings",
+ ],
)
cc_library(
@@ -80,7 +74,7 @@
"//iree/modules/hal",
"//iree/testing:gtest",
"//iree/vm",
- "//iree/vm:native_module_cc",
+ "//iree/vm:cc",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
],
diff --git a/iree/modules/check/CMakeLists.txt b/iree/modules/check/CMakeLists.txt
index 56dea0e..3d8dc84 100644
--- a/iree/modules/check/CMakeLists.txt
+++ b/iree/modules/check/CMakeLists.txt
@@ -29,13 +29,13 @@
iree::base::logging
iree::base::status
iree::hal::api
- iree::hal::vmla::vmla_driver_module
+ iree::hal::vmla::registration
iree::modules::hal
iree::testing::gtest
iree::testing::gtest_main
iree::vm
iree::vm::bytecode_module
- iree::vm::ref_cc
+ iree::vm::cc
)
endif()
@@ -51,16 +51,16 @@
absl::flags
absl::strings
iree::base::api
+ iree::base::core_headers
iree::base::file_io
- iree::base::init
+ iree::base::flags
iree::base::status
- iree::base::target_platform
iree::base::tracing
+ iree::hal::drivers
iree::modules::hal
iree::testing::gtest
iree::tools::utils::vm_util
iree::vm::bytecode_module
- ${IREE_HAL_DRIVER_MODULES}
TESTONLY
)
@@ -80,7 +80,7 @@
iree::modules::hal
iree::testing::gtest
iree::vm
- iree::vm::native_module_cc
+ iree::vm::cc
TESTONLY
PUBLIC
)
diff --git a/iree/modules/check/check_test.cc b/iree/modules/check/check_test.cc
index 4c30063..7086085 100644
--- a/iree/modules/check/check_test.cc
+++ b/iree/modules/check/check_test.cc
@@ -21,6 +21,7 @@
#include "iree/base/logging.h"
#include "iree/base/status.h"
#include "iree/hal/api.h"
+#include "iree/hal/vmla/registration/driver_module.h"
#include "iree/modules/check/native_module.h"
#include "iree/modules/hal/hal_module.h"
#include "iree/testing/gtest.h"
@@ -35,6 +36,7 @@
class CheckTest : public ::testing::Test {
protected:
static void SetUpTestSuite() {
+ IREE_CHECK_OK(iree_hal_vmla_driver_module_register());
// TODO(benvanik): move to instance-based registration.
IREE_ASSERT_OK(iree_hal_module_register_types());
@@ -414,8 +416,8 @@
TEST_F(CheckTest, ExpectAlmostEqNearIdenticalBufferSuccess) {
vm::ref<iree_hal_buffer_view_t> lhs;
vm::ref<iree_hal_buffer_view_t> rhs;
- float lhs_contents[] = {1.0, 1.99999, 0.00001, 4};
- float rhs_contents[] = {1.00001, 2.0, 0, 4};
+ float lhs_contents[] = {1.0f, 1.99999f, 0.00001f, 4.0f};
+ float rhs_contents[] = {1.00001f, 2.0f, 0.0f, 4.0f};
int32_t shape[] = {4};
ASSERT_NO_FATAL_FAILURE(CreateFloat32BufferView(lhs_contents, shape, &lhs));
ASSERT_NO_FATAL_FAILURE(CreateFloat32BufferView(rhs_contents, shape, &rhs));
diff --git a/iree/modules/check/iree-check-module-main.cc b/iree/modules/check/iree-check-module-main.cc
index 7e3687f..35f57dc 100644
--- a/iree/modules/check/iree-check-module-main.cc
+++ b/iree/modules/check/iree-check-module-main.cc
@@ -19,10 +19,11 @@
#include "absl/strings/string_view.h"
#include "iree/base/api.h"
#include "iree/base/file_io.h"
-#include "iree/base/init.h"
+#include "iree/base/flags.h"
#include "iree/base/status.h"
#include "iree/base/target_platform.h"
#include "iree/base/tracing.h"
+#include "iree/hal/drivers/init.h"
#include "iree/modules/check/native_module.h"
#include "iree/modules/hal/hal_module.h"
#include "iree/testing/gtest.h"
@@ -35,7 +36,7 @@
#if defined(IREE_PLATFORM_WINDOWS)
#include <fcntl.h>
#include <io.h>
-#define IREE_FORCE_BINARY_STDIN() setmode(_fileno(stdin), O_BINARY)
+#define IREE_FORCE_BINARY_STDIN() _setmode(_fileno(stdin), O_BINARY)
#else
#define IREE_FORCE_BINARY_STDIN()
#endif // IREE_PLATFORM_WINDOWS
@@ -111,8 +112,8 @@
std::array<iree_vm_module_t*, 3> modules = {hal_module, check_module,
input_module};
auto module_signature = iree_vm_module_signature(input_module);
- for (int ordinal = 0; ordinal < module_signature.export_function_count;
- ++ordinal) {
+ for (iree_host_size_t ordinal = 0;
+ ordinal < module_signature.export_function_count; ++ordinal) {
iree_vm_function_t function;
iree_string_view_t export_name_sv;
IREE_RETURN_IF_ERROR(iree_vm_module_lookup_function_by_ordinal(
@@ -173,7 +174,8 @@
} // namespace
extern "C" int main(int argc, char** argv) {
- InitializeEnvironment(&argc, &argv);
+ iree_flags_parse_checked(&argc, &argv);
+ IREE_CHECK_OK(iree_hal_register_all_available_drivers());
::testing::InitGoogleTest(&argc, argv);
IREE_FORCE_BINARY_STDIN();
diff --git a/iree/modules/check/native_module.cc b/iree/modules/check/native_module.cc
index f35968f..5686a1f 100644
--- a/iree/modules/check/native_module.cc
+++ b/iree/modules/check/native_module.cc
@@ -61,7 +61,7 @@
template <typename T>
Status ExpectAllTrue(iree_byte_span_t bytes) {
- EXPECT_THAT(AbslSpan<T>(bytes), Each(Not(0)));
+ EXPECT_THAT(AbslSpan<T>(bytes), Each(Not(T(0))));
return OkStatus();
}
diff --git a/iree/modules/hal/BUILD b/iree/modules/hal/BUILD
index bdea2da..6780425 100644
--- a/iree/modules/hal/BUILD
+++ b/iree/modules/hal/BUILD
@@ -25,10 +25,10 @@
deps = [
"//iree/base:api",
"//iree/base:tracing",
+ "//iree/hal",
"//iree/hal:api",
- "//iree/hal:device",
"//iree/vm",
- "//iree/vm:native_module_cc",
+ "//iree/vm:cc",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
diff --git a/iree/modules/hal/CMakeLists.txt b/iree/modules/hal/CMakeLists.txt
index 21fc736..636a96f 100644
--- a/iree/modules/hal/CMakeLists.txt
+++ b/iree/modules/hal/CMakeLists.txt
@@ -28,9 +28,9 @@
absl::span
iree::base::api
iree::base::tracing
+ iree::hal
iree::hal::api
- iree::hal::device
iree::vm
- iree::vm::native_module_cc
+ iree::vm::cc
PUBLIC
)
diff --git a/iree/modules/hal/hal_module.cc b/iree/modules/hal/hal_module.cc
index be8a730..9725e35 100644
--- a/iree/modules/hal/hal_module.cc
+++ b/iree/modules/hal/hal_module.cc
@@ -193,7 +193,7 @@
size_t buffer_length = source->data.data_length;
if (length == -1) {
- length = buffer_length;
+ length = static_cast<size_t>(buffer_length);
}
if (length < 0 || offset < 0 || offset > buffer_length ||
offset + length > buffer_length) {
diff --git a/iree/modules/strings/BUILD b/iree/modules/strings/BUILD
index 3a8dcf3..fcb1a40 100644
--- a/iree/modules/strings/BUILD
+++ b/iree/modules/strings/BUILD
@@ -23,10 +23,7 @@
"//iree/hal:api",
"//iree/modules/hal",
"//iree/vm",
- "//iree/vm:module",
- "//iree/vm:native_module_cc",
- "//iree/vm:ref",
- "//iree/vm:stack",
+ "//iree/vm:cc",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
@@ -43,13 +40,13 @@
"//iree/base:api",
"//iree/base:logging",
"//iree/hal:api",
- "//iree/hal/vmla:vmla_driver_module",
+ "//iree/hal/vmla/registration",
"//iree/modules/hal",
"//iree/testing:gtest",
"//iree/testing:gtest_main",
"//iree/vm",
"//iree/vm:bytecode_module",
- "//iree/vm:ref_cc",
+ "//iree/vm:cc",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@com_google_benchmark//:benchmark",
diff --git a/iree/modules/strings/CMakeLists.txt b/iree/modules/strings/CMakeLists.txt
index b65fa0b..5aac9ba 100644
--- a/iree/modules/strings/CMakeLists.txt
+++ b/iree/modules/strings/CMakeLists.txt
@@ -34,10 +34,7 @@
iree::hal::api
iree::modules::hal
iree::vm
- iree::vm::module
- iree::vm::native_module_cc
- iree::vm::ref
- iree::vm::stack
+ iree::vm::cc
PUBLIC
)
@@ -55,13 +52,13 @@
iree::base::api
iree::base::logging
iree::hal::api
- iree::hal::vmla::vmla_driver_module
+ iree::hal::vmla::registration
iree::modules::hal
iree::testing::gtest
iree::testing::gtest_main
iree::vm
iree::vm::bytecode_module
- iree::vm::ref_cc
+ iree::vm::cc
)
iree_bytecode_module(
diff --git a/iree/modules/strings/api.cc b/iree/modules/strings/api.cc
index 34c3fb2..ee364ec 100644
--- a/iree/modules/strings/api.cc
+++ b/iree/modules/strings/api.cc
@@ -25,10 +25,8 @@
#include "iree/modules/strings/api.h"
#include "iree/modules/strings/api_detail.h"
#include "iree/modules/strings/strings_module.h"
-#include "iree/vm/module.h"
+#include "iree/vm/api.h"
#include "iree/vm/native_module_cc.h"
-#include "iree/vm/ref.h"
-#include "iree/vm/stack.h"
extern "C" iree_status_t strings_string_create(iree_string_view_t value,
iree_allocator_t allocator,
@@ -141,7 +139,7 @@
if (!tensor || !rank) {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT);
}
- *rank = tensor->rank;
+ *rank = static_cast<int32_t>(tensor->rank);
return iree_ok_status();
}
diff --git a/iree/modules/strings/strings_module.cc b/iree/modules/strings/strings_module.cc
index 2d43cef..c8ea45f 100644
--- a/iree/modules/strings/strings_module.cc
+++ b/iree/modules/strings/strings_module.cc
@@ -112,7 +112,7 @@
std::string str;
str.reserve(string_length);
StringTensorToStringHelper(str_tensor->values, str_tensor->shape,
- str_tensor->rank, &str);
+ static_cast<int32_t>(str_tensor->rank), &str);
IREE_RETURN_IF_ERROR(strings_string_create(
iree_make_cstring_view(str.c_str()), allocator_, &new_string));
diff --git a/iree/modules/strings/strings_module_test.cc b/iree/modules/strings/strings_module_test.cc
index 63cb269..68414ad 100644
--- a/iree/modules/strings/strings_module_test.cc
+++ b/iree/modules/strings/strings_module_test.cc
@@ -21,6 +21,7 @@
#include "iree/base/api.h"
#include "iree/base/logging.h"
#include "iree/hal/api.h"
+#include "iree/hal/vmla/registration/driver_module.h"
#include "iree/modules/hal/hal_module.h"
#include "iree/modules/strings/api.h"
#include "iree/modules/strings/api_detail.h"
@@ -40,6 +41,10 @@
class StringsModuleTest : public ::testing::Test {
protected:
+ static void SetUpTestSuite() {
+ IREE_CHECK_OK(iree_hal_vmla_driver_module_register());
+ }
+
virtual void SetUp() {
IREE_CHECK_OK(iree_vm_instance_create(iree_allocator_system(), &instance_));
@@ -301,7 +306,7 @@
std::vector<iree_string_view_t> out_strings(expected.size());
IREE_ASSERT_OK(strings_string_tensor_get_elements(
output_tensor, out_strings.data(), out_strings.size(), 0));
- for (int i = 0; i < expected.size(); i++) {
+ for (iree_host_size_t i = 0; i < expected.size(); i++) {
EXPECT_EQ(iree_string_view_compare(out_strings[i], expected[i]), 0)
<< "Expected: " << expected[i].data << " found "
<< out_strings[i].data;
diff --git a/iree/modules/tensorlist/BUILD b/iree/modules/tensorlist/BUILD
index 4cdfb5d..74deecd 100644
--- a/iree/modules/tensorlist/BUILD
+++ b/iree/modules/tensorlist/BUILD
@@ -36,13 +36,13 @@
"//iree/base:api",
"//iree/base:logging",
"//iree/hal:api",
- "//iree/hal/vmla:vmla_driver_module",
+ "//iree/hal/vmla/registration",
"//iree/modules/hal",
"//iree/testing:gtest",
"//iree/testing:gtest_main",
"//iree/vm",
"//iree/vm:bytecode_module",
- "//iree/vm:ref_cc",
+ "//iree/vm:cc",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
@@ -61,7 +61,7 @@
"//iree/hal:api",
"//iree/modules/hal",
"//iree/vm",
- "//iree/vm:native_module_cc",
+ "//iree/vm:cc",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/types:span",
],
diff --git a/iree/modules/tensorlist/CMakeLists.txt b/iree/modules/tensorlist/CMakeLists.txt
index 2eb6691..1092601 100644
--- a/iree/modules/tensorlist/CMakeLists.txt
+++ b/iree/modules/tensorlist/CMakeLists.txt
@@ -41,13 +41,13 @@
iree::base::api
iree::base::logging
iree::hal::api
- iree::hal::vmla::vmla_driver_module
+ iree::hal::vmla::registration
iree::modules::hal
iree::testing::gtest
iree::testing::gtest_main
iree::vm
iree::vm::bytecode_module
- iree::vm::ref_cc
+ iree::vm::cc
)
iree_cc_library(
@@ -66,6 +66,6 @@
iree::hal::api
iree::modules::hal
iree::vm
- iree::vm::native_module_cc
+ iree::vm::cc
PUBLIC
)
diff --git a/iree/modules/tensorlist/tensorlist_test.cc b/iree/modules/tensorlist/tensorlist_test.cc
index 6eab915..3293d57 100644
--- a/iree/modules/tensorlist/tensorlist_test.cc
+++ b/iree/modules/tensorlist/tensorlist_test.cc
@@ -21,6 +21,7 @@
#include "iree/base/api.h"
#include "iree/base/logging.h"
#include "iree/hal/api.h"
+#include "iree/hal/vmla/registration/driver_module.h"
#include "iree/modules/hal/hal_module.h"
#include "iree/modules/tensorlist/native_module.h"
#include "iree/modules/tensorlist/tensorlist_test_module.h"
@@ -36,6 +37,10 @@
class TensorListModulesTest : public ::testing::Test {
protected:
+ static void SetUpTestSuite() {
+ IREE_CHECK_OK(iree_hal_vmla_driver_module_register());
+ }
+
virtual void SetUp() {
IREE_CHECK_OK(iree_vm_instance_create(iree_allocator_system(), &instance_));
diff --git a/iree/samples/custom_modules/BUILD b/iree/samples/custom_modules/BUILD
index b7c91e6..c980e20 100644
--- a/iree/samples/custom_modules/BUILD
+++ b/iree/samples/custom_modules/BUILD
@@ -46,13 +46,13 @@
"//iree/base:api",
"//iree/base:logging",
"//iree/hal:api",
- "//iree/hal/vmla:vmla_driver_module",
+ "//iree/hal/vmla/registration",
"//iree/modules/hal",
"//iree/testing:gtest",
"//iree/testing:gtest_main",
"//iree/vm",
"//iree/vm:bytecode_module",
- "//iree/vm:ref_cc",
+ "//iree/vm:cc",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/strings",
],
@@ -67,6 +67,6 @@
"//iree/hal:api",
"//iree/modules/hal",
"//iree/vm",
- "//iree/vm:native_module_cc",
+ "//iree/vm:cc",
],
)
diff --git a/iree/samples/custom_modules/CMakeLists.txt b/iree/samples/custom_modules/CMakeLists.txt
index 4a28aec..f7942f7 100644
--- a/iree/samples/custom_modules/CMakeLists.txt
+++ b/iree/samples/custom_modules/CMakeLists.txt
@@ -45,13 +45,13 @@
iree::base::api
iree::base::logging
iree::hal::api
- iree::hal::vmla::vmla_driver_module
+ iree::hal::vmla::registration
iree::modules::hal
iree::testing::gtest
iree::testing::gtest_main
iree::vm
iree::vm::bytecode_module
- iree::vm::ref_cc
+ iree::vm::cc
)
iree_cc_library(
@@ -66,6 +66,6 @@
iree::hal::api
iree::modules::hal
iree::vm
- iree::vm::native_module_cc
+ iree::vm::cc
PUBLIC
)
diff --git a/iree/samples/custom_modules/custom_modules_test.cc b/iree/samples/custom_modules/custom_modules_test.cc
index e975a01..18ebd30 100644
--- a/iree/samples/custom_modules/custom_modules_test.cc
+++ b/iree/samples/custom_modules/custom_modules_test.cc
@@ -19,6 +19,7 @@
#include "iree/base/api.h"
#include "iree/base/logging.h"
#include "iree/hal/api.h"
+#include "iree/hal/vmla/registration/driver_module.h"
#include "iree/modules/hal/hal_module.h"
#include "iree/samples/custom_modules/custom_modules_test_module.h"
#include "iree/samples/custom_modules/native_module.h"
@@ -32,6 +33,10 @@
class CustomModulesTest : public ::testing::Test {
protected:
+ static void SetUpTestSuite() {
+ IREE_CHECK_OK(iree_hal_vmla_driver_module_register());
+ }
+
virtual void SetUp() {
IREE_CHECK_OK(iree_vm_instance_create(iree_allocator_system(), &instance_));
diff --git a/iree/samples/custom_modules/dialect/BUILD b/iree/samples/custom_modules/dialect/BUILD
index c5a49ea..fa0bc78 100644
--- a/iree/samples/custom_modules/dialect/BUILD
+++ b/iree/samples/custom_modules/dialect/BUILD
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-load("//iree:build_defs.oss.bzl", iree_cc_binary = "cc_binary")
load("//build_tools/embed_data:build_defs.bzl", "cc_embed_data")
load("//build_tools/bazel:tblgen.bzl", "gentbl")
@@ -89,7 +88,7 @@
],
)
-iree_cc_binary(
+cc_binary(
name = "custom-opt",
srcs = ["custom-opt-main.cc"],
deps = [
@@ -103,7 +102,7 @@
],
)
-iree_cc_binary(
+cc_binary(
name = "custom-translate",
srcs = ["custom-translate-main.cc"],
deps = [
diff --git a/iree/samples/emitc_modules/CMakeLists.txt b/iree/samples/emitc_modules/CMakeLists.txt
index be48f3f..c699032 100644
--- a/iree/samples/emitc_modules/CMakeLists.txt
+++ b/iree/samples/emitc_modules/CMakeLists.txt
@@ -36,12 +36,7 @@
DEPS
::add_module_cc
iree::base::api
- iree::vm::c_funcs
- iree::vm::context
- iree::vm::instance
- iree::vm::native_module
- iree::vm::ref
- iree::vm::stack
+ iree::vm
PUBLIC
)
@@ -56,10 +51,7 @@
iree::base::status
iree::testing::gtest
iree::testing::gtest_main
- iree::vm::context
- iree::vm::instance
- iree::vm::invocation
- iree::vm::list
- iree::vm::ref_cc
+ iree::vm
+ iree::vm::cc
)
endif()
diff --git a/iree/samples/emitc_modules/add_module_test.cc b/iree/samples/emitc_modules/add_module_test.cc
index 619fefa..bb739ae 100644
--- a/iree/samples/emitc_modules/add_module_test.cc
+++ b/iree/samples/emitc_modules/add_module_test.cc
@@ -17,10 +17,7 @@
#include "iree/base/status.h"
#include "iree/testing/gtest.h"
#include "iree/testing/status_matchers.h"
-#include "iree/vm/context.h"
-#include "iree/vm/instance.h"
-#include "iree/vm/invocation.h"
-#include "iree/vm/list.h"
+#include "iree/vm/api.h"
#include "iree/vm/ref_cc.h"
namespace iree {
diff --git a/iree/samples/emitc_modules/add_module_test.h b/iree/samples/emitc_modules/add_module_test.h
index 07f61c6..c6223d2 100644
--- a/iree/samples/emitc_modules/add_module_test.h
+++ b/iree/samples/emitc_modules/add_module_test.h
@@ -12,11 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "iree/vm/context.h"
-#include "iree/vm/instance.h"
-#include "iree/vm/native_module.h"
-#include "iree/vm/ref.h"
-#include "iree/vm/stack.h"
+#include "iree/vm/api.h"
// This would be generated together with the functions in the header
#include "iree/samples/emitc_modules/add_module.module"
diff --git a/iree/samples/simple_embedding/BUILD b/iree/samples/simple_embedding/BUILD
index ad2e027..db314ba 100644
--- a/iree/samples/simple_embedding/BUILD
+++ b/iree/samples/simple_embedding/BUILD
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-load("//iree:build_defs.oss.bzl", "PLATFORM_VULKAN_TEST_DEPS", "iree_cmake_extra_content")
load("//iree/tools:compilation.bzl", "iree_bytecode_module")
package(
@@ -21,15 +20,6 @@
licenses = ["notice"], # Apache 2.0
)
-iree_cmake_extra_content(
- content = """
-if(NOT ${IREE_TARGET_BACKEND_VMLA} OR NOT ${IREE_TARGET_BACKEND_VULKAN-SPIRV}
- OR NOT ${IREE_HAL_DRIVER_VMLA} OR NOT ${IREE_HAL_DRIVER_VULKAN})
- return()
-endif()
-""",
-)
-
iree_bytecode_module(
name = "simple_embedding_test_bytecode_module",
src = "simple_embedding_test.mlir",
@@ -48,25 +38,20 @@
# For AddressSanitizer when using Vulkan + a local Nvidia GPU
"//iree/tools:sanitizer_suppressions.txt",
],
- tags = ["driver=vulkan"],
deps = [
":simple_embedding_test_bytecode_module_cc",
"//iree/base:api",
"//iree/base:logging",
"//iree/hal:api",
+ "//iree/hal/drivers",
"//iree/modules/hal",
"//iree/testing:gtest",
"//iree/testing:gtest_main",
"//iree/vm",
"//iree/vm:bytecode_module",
- "//iree/vm:ref_cc",
+ "//iree/vm:cc",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
-
- # These are the drivers we support running with and can produce
- # executables for from the source MLIR.
- "//iree/hal/vmla:vmla_driver_module", # build-cleaner: keep
- "//iree/hal/vulkan:vulkan_driver_module", # build-cleaner: keep
- ] + PLATFORM_VULKAN_TEST_DEPS,
+ ],
)
diff --git a/iree/samples/simple_embedding/CMakeLists.txt b/iree/samples/simple_embedding/CMakeLists.txt
index 1ceb8a0..fb0b717 100644
--- a/iree/samples/simple_embedding/CMakeLists.txt
+++ b/iree/samples/simple_embedding/CMakeLists.txt
@@ -12,11 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-if(NOT ${IREE_TARGET_BACKEND_VMLA} OR NOT ${IREE_TARGET_BACKEND_VULKAN-SPIRV}
- OR NOT ${IREE_HAL_DRIVER_VMLA} OR NOT ${IREE_HAL_DRIVER_VULKAN})
- return()
-endif()
-
iree_add_all_subdirs()
iree_bytecode_module(
@@ -48,14 +43,11 @@
iree::base::api
iree::base::logging
iree::hal::api
- iree::hal::vmla::vmla_driver_module
- iree::hal::vulkan::vulkan_driver_module
+ iree::hal::drivers
iree::modules::hal
iree::testing::gtest
iree::testing::gtest_main
iree::vm
iree::vm::bytecode_module
- iree::vm::ref_cc
- LABELS
- "driver=vulkan"
+ iree::vm::cc
)
diff --git a/iree/samples/simple_embedding/simple_embedding_test.cc b/iree/samples/simple_embedding/simple_embedding_test.cc
index aaea632..bc2c8cd 100644
--- a/iree/samples/simple_embedding/simple_embedding_test.cc
+++ b/iree/samples/simple_embedding/simple_embedding_test.cc
@@ -19,6 +19,7 @@
#include "iree/base/api.h"
#include "iree/base/logging.h"
#include "iree/hal/api.h"
+#include "iree/hal/drivers/init.h"
#include "iree/modules/hal/hal_module.h"
#include "iree/testing/gtest.h"
#include "iree/testing/status_matchers.h"
@@ -44,15 +45,25 @@
// Builds a list of tests to run based on the linked in driver modules.
std::vector<TestParams> GetAvailableDriverTestParams() {
+ IREE_CHECK_OK(iree_hal_register_all_available_drivers());
+
std::vector<TestParams> all_test_params;
iree_string_view_t* driver_names = nullptr;
iree_host_size_t driver_count = 0;
IREE_CHECK_OK(iree_hal_driver_registry_query_available_drivers(
iree_allocator_system(), &driver_names, &driver_count));
- for (int i = 0; i < driver_count; ++i) {
+ for (iree_host_size_t i = 0; i < driver_count; ++i) {
TestParams test_params;
test_params.driver_name =
std::string(driver_names[i].data, driver_names[i].size);
+
+ // TODO(#3843): this whole file stopped being useful a long time ago as a
+ // "simple" embedded test. This is a hack to work around its bustedness.
+ if (test_params.driver_name == "dylib" ||
+ test_params.driver_name == "llvm") {
+ continue;
+ }
+
all_test_params.push_back(std::move(test_params));
}
iree_allocator_free(iree_allocator_system(), driver_names);
diff --git a/iree/samples/vulkan/BUILD b/iree/samples/vulkan/BUILD
index 89c927f..f005ede 100644
--- a/iree/samples/vulkan/BUILD
+++ b/iree/samples/vulkan/BUILD
@@ -48,11 +48,12 @@
deps = [
":simple_mul_bytecode_module_cc",
"//iree/base:main",
+ "//iree/hal/vulkan/registration",
"//iree/modules/hal",
"//iree/testing/vulkan:vulkan_gui_util",
"//iree/vm",
"//iree/vm:bytecode_module",
- "//iree/vm:ref_cc",
+ "//iree/vm:cc",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/types:span",
],
diff --git a/iree/samples/vulkan/CMakeLists.txt b/iree/samples/vulkan/CMakeLists.txt
index c3de407..4c7f665 100644
--- a/iree/samples/vulkan/CMakeLists.txt
+++ b/iree/samples/vulkan/CMakeLists.txt
@@ -39,12 +39,6 @@
return()
endif()
-if(${CMAKE_HOST_SYSTEM_NAME} STREQUAL "Windows")
- set(_GUI_LINKOPTS "-SUBSYSTEM:WINDOWS")
-else()
- set(_GUI_LINKOPTS "")
-endif()
-
iree_cc_binary(
NAME
vulkan_inference_gui
@@ -53,12 +47,13 @@
DEPS
absl::base
iree::base::main
+ iree::hal::vulkan::registration
iree::modules::hal
iree::samples::vulkan::simple_mul_bytecode_module_cc
iree::testing::vulkan::vulkan_gui_util
iree::vm
iree::vm::bytecode_module
- iree::vm::ref_cc
+ iree::vm::cc
LINKOPTS
- "${_GUI_LINKOPTS}"
+ "${IREE_TARGET_GUI_LINKOPTS}"
)
diff --git a/iree/samples/vulkan/vulkan_inference_gui.cc b/iree/samples/vulkan/vulkan_inference_gui.cc
index 6183580..c1fb686 100644
--- a/iree/samples/vulkan/vulkan_inference_gui.cc
+++ b/iree/samples/vulkan/vulkan_inference_gui.cc
@@ -23,6 +23,7 @@
#include "iree/base/api.h"
#include "iree/hal/api.h"
#include "iree/hal/vulkan/api.h"
+#include "iree/hal/vulkan/registration/driver_module.h"
#include "iree/modules/hal/hal_module.h"
#include "iree/vm/api.h"
#include "iree/vm/bytecode_module.h"
@@ -93,10 +94,11 @@
std::vector<const char*> layers = GetInstanceLayers(iree_vulkan_features);
std::vector<const char*> extensions =
GetInstanceExtensions(window, iree_vulkan_features);
- SetupVulkan(iree_vulkan_features, layers.data(), layers.size(),
- extensions.data(), extensions.size(), g_Allocator, &g_Instance,
- &g_QueueFamily, &g_PhysicalDevice, &g_Queue, &g_Device,
- &g_DescriptorPool);
+ SetupVulkan(iree_vulkan_features, layers.data(),
+ static_cast<uint32_t>(layers.size()), extensions.data(),
+ static_cast<uint32_t>(extensions.size()), g_Allocator,
+ &g_Instance, &g_QueueFamily, &g_PhysicalDevice, &g_Queue,
+ &g_Device, &g_DescriptorPool);
// Create Window Surface
VkSurfaceKHR surface;
@@ -174,9 +176,6 @@
// --------------------------------------------------------------------------
// Setup IREE.
- // This call to |iree_api_init| is not technically required, but it is
- // included for completeness.
- IREE_CHECK_OK(iree_api_init(&argc, &argv));
// Check API version.
iree_api_version_t actual_version;
@@ -188,7 +187,8 @@
IREE_LOG(FATAL) << "Unsupported runtime API version " << actual_version;
}
- // Register HAL module types.
+ // Register HAL drivers and VM module types.
+ IREE_CHECK_OK(iree_hal_vulkan_driver_module_register());
IREE_CHECK_OK(iree_hal_module_register_types());
// Create a runtime Instance.
diff --git a/iree/schemas/BUILD b/iree/schemas/BUILD
index abe8fa7..98fd357 100644
--- a/iree/schemas/BUILD
+++ b/iree/schemas/BUILD
@@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-load("//iree:build_defs.oss.bzl", "FLATBUFFER_SUPPORTS_REFLECTIONS", "iree_build_test", "iree_flatbuffer_cc_library")
+load("//iree:build_defs.oss.bzl", "iree_build_test")
load("//build_tools/bazel:iree_flatcc.bzl", "iree_flatbuffer_c_library")
-load("//build_tools/embed_data:build_defs.bzl", "cc_embed_data")
package(
default_visibility = ["//visibility:public"],
@@ -22,89 +21,57 @@
licenses = ["notice"], # Apache 2.0
)
-FLATC_ARGS = [
- # Preserve workspace-relative include paths in generated code.
- "--keep-prefix",
- # Use C++11 'enum class' for enums.
- "--scoped-enums",
- # Include reflection tables used for dumping debug representations.
- "--reflect-names",
- # Generate FooT types for unpack/pack support. Note that this should only
- # be used in tooling as the code size/runtime overhead is non-trivial.
- "--gen-object-api",
+FLATCC_ARGS = [
+ "--reader",
+ "--builder",
+ "--verifier",
+ "--json",
]
iree_flatbuffer_c_library(
name = "bytecode_module_def_c_fbs",
srcs = ["bytecode_module_def.fbs"],
- flatcc_args = [
- "--reader",
- "--builder",
- "--verifier",
- ],
+ flatcc_args = FLATCC_ARGS,
)
-iree_flatbuffer_cc_library(
- name = "bytecode_module_def_cc_fbs",
- srcs = ["bytecode_module_def.fbs"],
- flatc_args = FLATC_ARGS,
-)
-
-iree_flatbuffer_cc_library(
- name = "dylib_executable_def_cc_fbs",
+iree_flatbuffer_c_library(
+ name = "dylib_executable_def_c_fbs",
srcs = ["dylib_executable_def.fbs"],
- flatc_args = FLATC_ARGS,
+ flatcc_args = FLATCC_ARGS,
)
-iree_flatbuffer_cc_library(
- name = "llvmir_executable_def_cc_fbs",
+iree_flatbuffer_c_library(
+ name = "llvmir_executable_def_c_fbs",
srcs = ["llvmir_executable_def.fbs"],
- flatc_args = FLATC_ARGS,
+ flatcc_args = FLATCC_ARGS,
)
-iree_flatbuffer_cc_library(
- name = "metal_executable_def_cc_fbs",
+iree_flatbuffer_c_library(
+ name = "metal_executable_def_c_fbs",
srcs = ["metal_executable_def.fbs"],
- flatc_args = FLATC_ARGS,
+ flatcc_args = FLATCC_ARGS,
)
-iree_flatbuffer_cc_library(
- name = "spirv_executable_def_cc_fbs",
+iree_flatbuffer_c_library(
+ name = "spirv_executable_def_c_fbs",
srcs = ["spirv_executable_def.fbs"],
- flatc_args = FLATC_ARGS,
+ flatcc_args = FLATCC_ARGS,
)
-iree_flatbuffer_cc_library(
- name = "vmla_executable_def_cc_fbs",
+iree_flatbuffer_c_library(
+ name = "vmla_executable_def_c_fbs",
srcs = ["vmla_executable_def.fbs"],
- flatc_args = FLATC_ARGS,
+ flatcc_args = FLATCC_ARGS,
)
iree_build_test(
name = "schema_build_test",
targets = [
- ":bytecode_module_def_cc_fbs",
- ":dylib_executable_def_cc_fbs",
- ":llvmir_executable_def_cc_fbs",
- ":metal_executable_def_cc_fbs",
- ":spirv_executable_def_cc_fbs",
- ":vmla_executable_def_cc_fbs",
+ ":bytecode_module_def_c_fbs",
+ ":dylib_executable_def_c_fbs",
+ ":llvmir_executable_def_c_fbs",
+ ":metal_executable_def_c_fbs",
+ ":spirv_executable_def_c_fbs",
+ ":vmla_executable_def_c_fbs",
],
)
-
-REFLECTION_SRCS = [] if not FLATBUFFER_SUPPORTS_REFLECTIONS else [
- "bytecode_module_def.bfbs",
- "dylib_executable_def.bfbs",
- "llvmir_executable_def.bfbs",
- "metal_executable_def.bfbs",
- "spirv_executable_def.bfbs",
- "vmla_executable_def.bfbs",
-]
-
-cc_embed_data(
- name = "reflection_data",
- srcs = REFLECTION_SRCS,
- cc_file_output = "reflection_data.cc",
- cpp_namespace = "iree::schemas",
- h_file_output = "reflection_data.h",
-)
diff --git a/iree/schemas/CMakeLists.txt b/iree/schemas/CMakeLists.txt
index e62b2e2..68d435c 100644
--- a/iree/schemas/CMakeLists.txt
+++ b/iree/schemas/CMakeLists.txt
@@ -23,95 +23,71 @@
"--reader"
"--builder"
"--verifier"
+ "--json"
PUBLIC
)
-flatbuffer_cc_library(
+flatbuffer_c_library(
NAME
- bytecode_module_def_cc_fbs
- SRCS
- "bytecode_module_def.fbs"
- FLATC_ARGS
- "--keep-prefix"
- "--scoped-enums"
- "--reflect-names"
- "--gen-object-api"
- PUBLIC
-)
-
-flatbuffer_cc_library(
- NAME
- dylib_executable_def_cc_fbs
+ dylib_executable_def_c_fbs
SRCS
"dylib_executable_def.fbs"
- FLATC_ARGS
- "--keep-prefix"
- "--scoped-enums"
- "--reflect-names"
- "--gen-object-api"
+ FLATCC_ARGS
+ "--reader"
+ "--builder"
+ "--verifier"
+ "--json"
PUBLIC
)
-flatbuffer_cc_library(
+flatbuffer_c_library(
NAME
- llvmir_executable_def_cc_fbs
+ llvmir_executable_def_c_fbs
SRCS
"llvmir_executable_def.fbs"
- FLATC_ARGS
- "--keep-prefix"
- "--scoped-enums"
- "--reflect-names"
- "--gen-object-api"
+ FLATCC_ARGS
+ "--reader"
+ "--builder"
+ "--verifier"
+ "--json"
PUBLIC
)
-flatbuffer_cc_library(
+flatbuffer_c_library(
NAME
- metal_executable_def_cc_fbs
+ metal_executable_def_c_fbs
SRCS
"metal_executable_def.fbs"
- FLATC_ARGS
- "--keep-prefix"
- "--scoped-enums"
- "--reflect-names"
- "--gen-object-api"
+ FLATCC_ARGS
+ "--reader"
+ "--builder"
+ "--verifier"
+ "--json"
PUBLIC
)
-flatbuffer_cc_library(
+flatbuffer_c_library(
NAME
- spirv_executable_def_cc_fbs
+ spirv_executable_def_c_fbs
SRCS
"spirv_executable_def.fbs"
- FLATC_ARGS
- "--keep-prefix"
- "--scoped-enums"
- "--reflect-names"
- "--gen-object-api"
+ FLATCC_ARGS
+ "--reader"
+ "--builder"
+ "--verifier"
+ "--json"
PUBLIC
)
-flatbuffer_cc_library(
+flatbuffer_c_library(
NAME
- vmla_executable_def_cc_fbs
+ vmla_executable_def_c_fbs
SRCS
"vmla_executable_def.fbs"
- FLATC_ARGS
- "--keep-prefix"
- "--scoped-enums"
- "--reflect-names"
- "--gen-object-api"
- PUBLIC
-)
-
-iree_cc_embed_data(
- NAME
- reflection_data
- CC_FILE_OUTPUT
- "reflection_data.cc"
- H_FILE_OUTPUT
- "reflection_data.h"
- CPP_NAMESPACE
- "iree::schemas"
+ FLATCC_ARGS
+ "--reader"
+ "--builder"
+ "--verifier"
+ "--json"
PUBLIC
)
diff --git a/iree/schemas/dylib_executable_def.fbs b/iree/schemas/dylib_executable_def.fbs
index 5690096..8c30493 100644
--- a/iree/schemas/dylib_executable_def.fbs
+++ b/iree/schemas/dylib_executable_def.fbs
@@ -26,10 +26,10 @@
// An embedded (as opposed to external) dynamic library file.
// TODO(scotttodd): List of embedded files?
// TODO(scotttodd): Format of files, platform information (x86/arm/etc.)
- library_embedded:[byte];
+ library_embedded:[ubyte];
debug_database_filename:string;
- debug_database_embedded:[byte];
+ debug_database_embedded:[ubyte];
// TODO(scotttodd): Relative file path from this flatbuffer file
}
diff --git a/iree/schemas/llvmir_executable_def.fbs b/iree/schemas/llvmir_executable_def.fbs
index 77c0b61..5897900 100644
--- a/iree/schemas/llvmir_executable_def.fbs
+++ b/iree/schemas/llvmir_executable_def.fbs
@@ -25,7 +25,7 @@
// A map of entry points to string names with the same order as in the executable op.
entry_points:[string];
// A serialized llvm::Module object.
- llvmir_module:[byte];
+ bitcode_module:[ubyte];
}
root_type LLVMIRExecutableDef;
diff --git a/iree/schemas/metal_executable_def.fbs b/iree/schemas/metal_executable_def.fbs
index 7ad4b29..50f0eff 100644
--- a/iree/schemas/metal_executable_def.fbs
+++ b/iree/schemas/metal_executable_def.fbs
@@ -29,10 +29,6 @@
// This information is used to create MTLLibrary, MTLFunction and pipeline
// state objects.
table MetalExecutableDef {
- // Reserved implementation-specific value that can be passed from compiler
- // to runtime.
- tag:string;
-
// A map of entry point ordinals to string names as used in the shader
// library.
entry_points:[string];
diff --git a/iree/schemas/spirv_executable_def.fbs b/iree/schemas/spirv_executable_def.fbs
index 71cf330..250ea84 100644
--- a/iree/schemas/spirv_executable_def.fbs
+++ b/iree/schemas/spirv_executable_def.fbs
@@ -46,10 +46,6 @@
// This information is used to create the VkShaderModule, VkPipelineLayout, and
// any required VkDescriptorSetLayouts.
table SpirVExecutableDef {
- // Reserved implementation-specific value that can be passed from compiler to
- // runtime.
- tag:string;
-
// A map of entry point ordinals to string names as used in the shader module.
entry_points:[string];
diff --git a/iree/schemas/vmla_executable_def.fbs b/iree/schemas/vmla_executable_def.fbs
index 455b506..e85972c 100644
--- a/iree/schemas/vmla_executable_def.fbs
+++ b/iree/schemas/vmla_executable_def.fbs
@@ -27,7 +27,7 @@
// We embed the entire flatbuffer contents opaquely to allow for easier
// manipulation of the files (such that we can just slice out the bytes and
// dump them to a file, etc).
- bytecode_module:[byte];
+ bytecode_module:[ubyte];
}
root_type VMLAExecutableDef;
diff --git a/iree/test/e2e/hackability/BUILD b/iree/test/e2e/hackability/BUILD
new file mode 100644
index 0000000..5964b42
--- /dev/null
+++ b/iree/test/e2e/hackability/BUILD
@@ -0,0 +1,35 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Tests for end-to-end IREE support of specific features to prevent regression.
+# These should focus on support by IREE itself, not for issues with specific runner tools. Place
+# those tests in https://github.com/google/iree/tree/main/iree/tools/test/
+
+load("//iree:lit_test.bzl", "iree_lit_test_suite")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_lit_test_suite(
+ name = "lit",
+ srcs = glob(["*.mlir"]),
+ data = [
+ "//iree/tools:IreeFileCheck",
+ "//iree/tools:iree-run-mlir",
+ ],
+ tags = ["hostonly"],
+)
diff --git a/iree/test/e2e/hackability/CMakeLists.txt b/iree/test/e2e/hackability/CMakeLists.txt
new file mode 100644
index 0000000..e2b4157
--- /dev/null
+++ b/iree/test/e2e/hackability/CMakeLists.txt
@@ -0,0 +1,28 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_add_all_subdirs()
+
+file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir)
+iree_lit_test_suite(
+ NAME
+ lit
+ SRCS
+ "${_GLOB_X_MLIR}"
+ DATA
+ iree::tools::IreeFileCheck
+ iree::tools::iree-run-mlir
+ LABELS
+ "hostonly"
+)
diff --git a/iree/test/e2e/hackability/flow_partitioned.mlir b/iree/test/e2e/hackability/flow_partitioned.mlir
new file mode 100644
index 0000000..19ba210
--- /dev/null
+++ b/iree/test/e2e/hackability/flow_partitioned.mlir
@@ -0,0 +1,21 @@
+// RUN: iree-run-mlir -export-all -iree-hal-target-backends=vmla %s | IreeFileCheck %s
+// RUN: [[ $IREE_LLVMJIT_DISABLE == 1 ]] || (iree-run-mlir -export-all -iree-hal-target-backends=llvm-ir %s | IreeFileCheck %s)
+
+flow.executable @ex0 {
+ flow.dispatch.entry @dispatch0 attributes {workload = 4 : index}
+ module {
+ func @dispatch0(%arg0: tensor<4xf32>) -> tensor<4xf32> {
+ %0 = mhlo.add %arg0, %arg0 : tensor<4xf32>
+ return %0 : tensor<4xf32>
+ }
+ }
+}
+
+// CHECK-LABEL: EXEC @staticShapedFn
+func @staticShapedFn() -> tensor<4xf32> {
+ %input = iree.unfoldable_constant dense<[-1.0, 2.0, -3.0, 4.0]> : tensor<4xf32>
+ %workload = constant 4 : index
+ %0 = flow.dispatch @ex0::@dispatch0[%workload : index](%input) : (tensor<4xf32>) -> tensor<4xf32>
+ return %0 : tensor<4xf32>
+}
+// CHECK: 4xf32=-2 4 -6 8
diff --git a/iree/test/e2e/llvmir_specific/BUILD b/iree/test/e2e/llvmir_specific/BUILD
index 8770dc3..f0bf5de 100644
--- a/iree/test/e2e/llvmir_specific/BUILD
+++ b/iree/test/e2e/llvmir_specific/BUILD
@@ -34,3 +34,15 @@
driver = "llvm",
target_backend = "llvm-ir",
)
+
+iree_check_single_backend_test_suite(
+ name = "check_llvm-ir-exponential_fast",
+ srcs = [
+ "exponential.mlir",
+ ],
+ compiler_flags = [
+ "-iree-codegen-linalg-to-llvm-fast-exp=true",
+ ],
+ driver = "llvm",
+ target_backend = "llvm-ir",
+)
diff --git a/iree/test/e2e/llvmir_specific/CMakeLists.txt b/iree/test/e2e/llvmir_specific/CMakeLists.txt
index 0d00ec1..496841a 100644
--- a/iree/test/e2e/llvmir_specific/CMakeLists.txt
+++ b/iree/test/e2e/llvmir_specific/CMakeLists.txt
@@ -26,3 +26,16 @@
COMPILER_FLAGS
"-iree-codegen-linalg-to-llvm-conv-img2col-conversion=true"
)
+
+iree_check_single_backend_test_suite(
+ NAME
+ check_llvm-ir-exponential_fast
+ SRCS
+ "exponential.mlir"
+ TARGET_BACKEND
+ llvm-ir
+ DRIVER
+ llvm
+ COMPILER_FLAGS
+ "-iree-codegen-linalg-to-llvm-fast-exp=true"
+)
diff --git a/iree/test/e2e/llvmir_specific/exponential.mlir b/iree/test/e2e/llvmir_specific/exponential.mlir
new file mode 100644
index 0000000..6e91326
--- /dev/null
+++ b/iree/test/e2e/llvmir_specific/exponential.mlir
@@ -0,0 +1,27 @@
+func @tensor() attributes { iree.module.export } {
+ %input = iree.unfoldable_constant dense<[0.0, 1.0, 2.0, 4.0]> : tensor<4xf32>
+ %result = "mhlo.exponential"(%input) : (tensor<4xf32>) -> tensor<4xf32>
+ check.expect_almost_eq_const(%result, dense<[1.0, 2.7183, 7.3891, 54.5981]> : tensor<4xf32>) : tensor<4xf32>
+ return
+}
+
+func @scalar() attributes { iree.module.export } {
+ %input = iree.unfoldable_constant dense<1.0> : tensor<f32>
+ %result = "mhlo.exponential"(%input) : (tensor<f32>) -> tensor<f32>
+ check.expect_almost_eq_const(%result, dense<2.7183> : tensor<f32>) : tensor<f32>
+ return
+}
+
+func @double() attributes { iree.module.export } {
+ %input = iree.unfoldable_constant dense<1.0> : tensor<f64>
+ %result = "mhlo.exponential"(%input) : (tensor<f64>) -> tensor<f64>
+ check.expect_almost_eq_const(%result, dense<2.7183> : tensor<f64>) : tensor<f64>
+ return
+}
+
+func @negative() attributes { iree.module.export } {
+ %input = iree.unfoldable_constant dense<-1.0> : tensor<f32>
+ %result = "mhlo.exponential"(%input) : (tensor<f32>) -> tensor<f32>
+ check.expect_almost_eq_const(%result, dense<0.367879> : tensor<f32>) : tensor<f32>
+ return
+}
diff --git a/iree/test/e2e/structural/BUILD b/iree/test/e2e/structural/BUILD
index 5b41267..1489655 100644
--- a/iree/test/e2e/structural/BUILD
+++ b/iree/test/e2e/structural/BUILD
@@ -20,12 +20,39 @@
licenses = ["notice"], # Apache 2.0
)
-# TODO(#2395): This fails on conversion to buffers for both LLVM-ir and SPIR-V.
+# TODO(#2395): Enable all the tests for both LLVM-ir and SPIR-V.
iree_check_single_backend_test_suite(
- name = "check_vmla-reference",
- srcs = [
- "fused_dispatch_region.mlir",
- ],
+ name = "check_vmla_vmla",
+ srcs = glob(["*.mlir"]),
driver = "vmla",
target_backend = "vmla",
)
+
+iree_check_single_backend_test_suite(
+ name = "check_vulkan-spirv_vulkan",
+ srcs = [
+ "slice_add.mlir",
+ ],
+ driver = "vulkan",
+ target_backend = "vulkan-spirv",
+)
+
+# TODO(ataei): Enable dylib-llvm-aot tests.
+# See: https://github.com/google/iree/issues/2645
+iree_check_single_backend_test_suite(
+ name = "check_llvm-ir_llvm",
+ srcs = [
+ "slice_add.mlir",
+ ],
+ driver = "llvm",
+ target_backend = "llvm-ir",
+)
+
+test_suite(
+ name = "check",
+ tests = [
+ ":check_llvm-ir_llvm",
+ ":check_vmla_vmla",
+ ":check_vulkan-spirv_vulkan",
+ ],
+)
diff --git a/iree/test/e2e/structural/CMakeLists.txt b/iree/test/e2e/structural/CMakeLists.txt
index 4db28f5..6c87e88 100644
--- a/iree/test/e2e/structural/CMakeLists.txt
+++ b/iree/test/e2e/structural/CMakeLists.txt
@@ -14,13 +14,36 @@
iree_add_all_subdirs()
+file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir)
iree_check_single_backend_test_suite(
NAME
- check_vmla-reference
+ check_vmla_vmla
SRCS
- "fused_dispatch_region.mlir"
+ "${_GLOB_X_MLIR}"
TARGET_BACKEND
vmla
DRIVER
vmla
)
+
+iree_check_single_backend_test_suite(
+ NAME
+ check_vulkan-spirv_vulkan
+ SRCS
+ "slice_add.mlir"
+ TARGET_BACKEND
+ vulkan-spirv
+ DRIVER
+ vulkan
+)
+
+iree_check_single_backend_test_suite(
+ NAME
+ check_llvm-ir_llvm
+ SRCS
+ "slice_add.mlir"
+ TARGET_BACKEND
+ llvm-ir
+ DRIVER
+ llvm
+)
diff --git a/iree/test/e2e/structural/gather_concat.mlir b/iree/test/e2e/structural/gather_concat.mlir
new file mode 100644
index 0000000..a5a848e
--- /dev/null
+++ b/iree/test/e2e/structural/gather_concat.mlir
@@ -0,0 +1,22 @@
+func @gather_concat() attributes { iree.module.export } {
+ %input = iree.unfoldable_constant dense<[
+ [[05, 06, 07, 08]],
+ [[09, 10, 11, 12]]]> : tensor<2x1x4xi32>
+ %start_indices = iree.unfoldable_constant dense<0> : tensor<i64>
+ %suffix = iree.unfoldable_constant dense<[[1, 2]]> : tensor<1x2xi32>
+ %workload = constant 12 : index
+ %result = flow.dispatch.region[%workload: index](%arg0 = %input : tensor<2x1x4xi32>, %arg1 = %start_indices : tensor<i64>, %arg2 = %suffix : tensor<1x2xi32>) -> tensor<1x6xi32> {
+ %0 = "mhlo.gather"(%arg0, %arg1) {
+ dimension_numbers = {
+ collapsed_slice_dims = dense<0> : tensor<1xi64>,
+ index_vector_dim = 0 : i64,
+ offset_dims = dense<[0, 1]> : tensor<2xi64>,
+ start_index_map = dense<0> : tensor<1xi64>},
+ slice_sizes = dense<[1, 1, 4]> : tensor<3xi64>
+ } : (tensor<2x1x4xi32>, tensor<i64>) -> tensor<1x4xi32>
+ %1 = "mhlo.concatenate"(%0, %arg2) {dimension = 1 : i64} : (tensor<1x4xi32>, tensor<1x2xi32>) -> tensor<1x6xi32>
+ flow.return %1 : tensor<1x6xi32>
+ }
+ check.expect_eq_const(%result, dense<[[5, 6, 7, 8, 1, 2]]> : tensor<1x6xi32>) : tensor<1x6xi32>
+ return
+}
diff --git a/iree/test/e2e/structural/slice_add.mlir b/iree/test/e2e/structural/slice_add.mlir
new file mode 100644
index 0000000..d6d301b
--- /dev/null
+++ b/iree/test/e2e/structural/slice_add.mlir
@@ -0,0 +1,84 @@
+func @slice_whole_buffer() attributes { iree.module.export } {
+ %input0 = iree.unfoldable_constant dense<[
+ [01, 02, 03, 04],
+ [05, 06, 07, 08],
+ [09, 10, 11, 12]]> : tensor<3x4xi32>
+ %input1 = iree.unfoldable_constant dense<10> : tensor<3x4xi32>
+ %workload = constant 12 : index
+ %result = flow.dispatch.region[%workload: index](%arg0 = %input0 : tensor<3x4xi32>, %arg1 = %input1 : tensor<3x4xi32>) -> tensor<3x4xi32> {
+ %0 = "mhlo.slice"(%arg0) {
+ start_indices = dense<[0, 0]> : tensor<2xi64>,
+ limit_indices = dense<[3, 4]> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>
+ } : (tensor<3x4xi32>) -> tensor<3x4xi32>
+ %1 = mhlo.add %0, %arg1 : tensor<3x4xi32>
+ flow.return %1 : tensor<3x4xi32>
+ }
+ check.expect_eq_const(%result, dense<[
+ [11, 12, 13, 14],
+ [15, 16, 17, 18],
+ [19, 20, 21, 22]]> : tensor<3x4xi32>) : tensor<3x4xi32>
+ return
+}
+
+func @slice_whole_stride() attributes { iree.module.export } {
+ %input0 = iree.unfoldable_constant dense<[
+ [01, 02, 03, 04],
+ [05, 06, 07, 08],
+ [09, 10, 11, 12]]> : tensor<3x4xi32>
+ %input1 = iree.unfoldable_constant dense<10> : tensor<1x4xi32>
+ %workload = constant 4 : index
+ %result = flow.dispatch.region[%workload: index](%arg0 = %input0 : tensor<3x4xi32>, %arg1 = %input1 : tensor<1x4xi32>) -> tensor<1x4xi32> {
+ %0 = "mhlo.slice"(%arg0) {
+ start_indices = dense<[1, 0]> : tensor<2xi64>,
+ limit_indices = dense<[2, 4]> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>
+ } : (tensor<3x4xi32>) -> tensor<1x4xi32>
+ %1 = mhlo.add %0, %arg1 : tensor<1x4xi32>
+ flow.return %1 : tensor<1x4xi32>
+ }
+ check.expect_eq_const(%result, dense<[[15, 16, 17, 18]]> : tensor<1x4xi32>) : tensor<1x4xi32>
+ return
+}
+
+func @slice_stride_part() attributes { iree.module.export } {
+ %input0 = iree.unfoldable_constant dense<[
+ [01, 02, 03, 04],
+ [05, 06, 07, 08],
+ [09, 10, 11, 12]]> : tensor<3x4xi32>
+ %input1 = iree.unfoldable_constant dense<10> : tensor<1x2xi32>
+ %workload = constant 2 : index
+ %result = flow.dispatch.region[%workload: index](%arg0 = %input0 : tensor<3x4xi32>, %arg1 = %input1 : tensor<1x2xi32>) -> tensor<1x2xi32> {
+ %0 = "mhlo.slice"(%arg0) {
+ start_indices = dense<[1, 1]> : tensor<2xi64>,
+ limit_indices = dense<[2, 3]> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>
+ } : (tensor<3x4xi32>) -> tensor<1x2xi32>
+ %1 = mhlo.add %0, %arg1 : tensor<1x2xi32>
+ flow.return %1 : tensor<1x2xi32>
+ }
+ check.expect_eq_const(%result, dense<[[16, 17]]> : tensor<1x2xi32>) : tensor<1x2xi32>
+ return
+}
+
+func @slice_multi_stride() attributes { iree.module.export } {
+ %input0 = iree.unfoldable_constant dense<[
+ [01, 02, 03, 04],
+ [05, 06, 07, 08],
+ [09, 10, 11, 12]]> : tensor<3x4xi32>
+ %input1 = iree.unfoldable_constant dense<10> : tensor<2x4xi32>
+ %workload = constant 8 : index
+ %result = flow.dispatch.region[%workload: index](%arg0 = %input0 : tensor<3x4xi32>, %arg1 = %input1 : tensor<2x4xi32>) -> tensor<2x4xi32> {
+ %0 = "mhlo.slice"(%arg0) {
+ start_indices = dense<[1, 0]> : tensor<2xi64>,
+ limit_indices = dense<[3, 4]> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>
+ } : (tensor<3x4xi32>) -> tensor<2x4xi32>
+ %1 = mhlo.add %0, %arg1 : tensor<2x4xi32>
+ flow.return %1 : tensor<2x4xi32>
+ }
+ check.expect_eq_const(%result, dense<[
+ [15, 16, 17, 18],
+ [19, 20, 21, 22]]> : tensor<2x4xi32>) : tensor<2x4xi32>
+ return
+}
diff --git a/iree/test/e2e/xla_ops/BUILD b/iree/test/e2e/xla_ops/BUILD
index a3c9060..95e6802 100644
--- a/iree/test/e2e/xla_ops/BUILD
+++ b/iree/test/e2e/xla_ops/BUILD
@@ -17,7 +17,6 @@
# functionality of that op (though may make use of other ops where necessary). Tests should be
# written using the IREE Check framework and should always pass on the reference VMLA backend.
# See https://google.github.io/iree/TestingGuide#iree-core-end-to-end-tests.
-# TODO(hanchung): Reorganize/fix existing tests so the above is true.
load("//build_tools/bazel:iree_check_test.bzl", "iree_check_single_backend_test_suite")
load("//iree:build_defs.oss.bzl", "iree_cmake_extra_content")
@@ -103,7 +102,6 @@
"exponential_minus_one.mlir",
"floor.mlir",
"gather.mlir",
- "gather_concat.mlir",
"iota.mlir",
"log.mlir",
"log_plus_one.mlir",
@@ -156,7 +154,6 @@
"exponential_minus_one.mlir",
"floor.mlir",
"gather.mlir",
- "gather_concat.mlir",
"iota.mlir",
"log.mlir",
"log_plus_one.mlir",
diff --git a/iree/test/e2e/xla_ops/CMakeLists.txt b/iree/test/e2e/xla_ops/CMakeLists.txt
index 1f89648..dd49ad1 100644
--- a/iree/test/e2e/xla_ops/CMakeLists.txt
+++ b/iree/test/e2e/xla_ops/CMakeLists.txt
@@ -91,7 +91,6 @@
"exponential_minus_one.mlir"
"floor.mlir"
"gather.mlir"
- "gather_concat.mlir"
"iota.mlir"
"log.mlir"
"log_plus_one.mlir"
@@ -144,7 +143,6 @@
"exponential_minus_one.mlir"
"floor.mlir"
"gather.mlir"
- "gather_concat.mlir"
"iota.mlir"
"log.mlir"
"log_plus_one.mlir"
diff --git a/iree/test/e2e/xla_ops/gather_concat.mlir b/iree/test/e2e/xla_ops/gather_concat.mlir
deleted file mode 100644
index c781a4d..0000000
--- a/iree/test/e2e/xla_ops/gather_concat.mlir
+++ /dev/null
@@ -1,18 +0,0 @@
-func @gather_concat() attributes { iree.module.export } {
- %input = iree.unfoldable_constant dense<[
- [[05, 06, 07, 08]],
- [[09, 10, 11, 12]]]> : tensor<2x1x4xi32>
- %start_indices = iree.unfoldable_constant dense<0> : tensor<i64>
- %gath = "mhlo.gather"(%input, %start_indices) {
- dimension_numbers = {
- collapsed_slice_dims = dense<0> : tensor<1xi64>,
- index_vector_dim = 0 : i64,
- offset_dims = dense<[0, 1]> : tensor<2xi64>,
- start_index_map = dense<0> : tensor<1xi64>},
- slice_sizes = dense<[1, 1, 4]> : tensor<3xi64>
- } : (tensor<2x1x4xi32>, tensor<i64>) -> tensor<1x4xi32>
- %suffix = iree.unfoldable_constant dense<[[1, 2]]> : tensor<1x2xi32>
- %res = "mhlo.concatenate"(%gath, %suffix) {dimension = 1 : i64} : (tensor<1x4xi32>, tensor<1x2xi32>) -> tensor<1x6xi32>
- check.expect_eq_const(%res, dense<[[5, 6, 7, 8, 1, 2]]> : tensor<1x6xi32>) : tensor<1x6xi32>
- return
-}
diff --git a/iree/testing/BUILD b/iree/testing/BUILD
index 0ad9324..f31a4d5 100644
--- a/iree/testing/BUILD
+++ b/iree/testing/BUILD
@@ -25,7 +25,7 @@
testonly = True,
srcs = ["benchmark_main.cc"],
deps = [
- "//iree/base:init",
+ "//iree/base:flags",
"@com_google_benchmark//:benchmark",
],
)
@@ -52,7 +52,7 @@
tags = ["keep_dep"],
deps = [
":gtest",
- "//iree/base:init",
+ "//iree/base:flags",
"@com_google_googletest//:gtest",
],
)
diff --git a/iree/testing/CMakeLists.txt b/iree/testing/CMakeLists.txt
index bdb2c6f..c790757 100644
--- a/iree/testing/CMakeLists.txt
+++ b/iree/testing/CMakeLists.txt
@@ -21,7 +21,7 @@
"benchmark_main.cc"
DEPS
benchmark
- iree::base::init
+ iree::base::flags
TESTONLY
PUBLIC
)
@@ -51,7 +51,7 @@
::gtest
gmock
gtest
- iree::base::init
+ iree::base::flags
TESTONLY
PUBLIC
)
diff --git a/iree/testing/benchmark_main.cc b/iree/testing/benchmark_main.cc
index 50df16d..690f456 100644
--- a/iree/testing/benchmark_main.cc
+++ b/iree/testing/benchmark_main.cc
@@ -13,13 +13,13 @@
// limitations under the License.
#include "benchmark/benchmark.h"
-#include "iree/base/init.h"
+#include "iree/base/flags.h"
namespace iree {
extern "C" int main(int argc, char** argv) {
::benchmark::Initialize(&argc, argv);
- iree::InitializeEnvironment(&argc, &argv);
+ iree_flags_parse_checked(&argc, &argv);
::benchmark::RunSpecifiedBenchmarks();
return 0;
}
diff --git a/iree/testing/gtest_main.cc b/iree/testing/gtest_main.cc
index 6fc7e78..d665aa8 100644
--- a/iree/testing/gtest_main.cc
+++ b/iree/testing/gtest_main.cc
@@ -13,10 +13,10 @@
// limitations under the License.
#include "gtest/gtest.h"
-#include "iree/base/init.h"
+#include "iree/base/flags.h"
extern "C" int main(int argc, char** argv) {
- ::iree::InitializeEnvironment(&argc, &argv);
+ iree_flags_parse_checked(&argc, &argv);
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
diff --git a/iree/testing/vulkan/BUILD b/iree/testing/vulkan/BUILD
index 4e235aa..3f40467 100644
--- a/iree/testing/vulkan/BUILD
+++ b/iree/testing/vulkan/BUILD
@@ -12,11 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-load(
- "//iree:build_defs.oss.bzl",
- "PLATFORM_VULKAN_DEPS",
-)
-
package(
default_visibility = ["//visibility:public"],
features = ["layering_check"],
@@ -37,7 +32,7 @@
"//iree/hal/vulkan:api",
"@dear_imgui",
"@dear_imgui//:imgui_sdl_vulkan",
- "@iree_vulkan_headers//:vulkan_headers_no_prototypes",
+ "@iree_vulkan_headers//:vulkan_headers",
"@sdl2//:SDL2",
"@vulkan_sdk//:sdk",
],
@@ -58,17 +53,17 @@
],
deps = [
":vulkan_gui_util",
- "//iree/base:init",
+ "//iree/base:flags",
"//iree/base:main",
"//iree/base:status",
"//iree/base:tracing",
+ "//iree/hal/vulkan/registration",
"//iree/modules/hal",
- "//iree/hal/vulkan:vulkan_driver_module",
"//iree/tools/utils:vm_util",
"//iree/vm",
"//iree/vm:bytecode_module",
- "//iree/vm:ref_cc",
+ "//iree/vm:cc",
"@com_google_absl//absl/flags:flag",
"@sdl2//:SDL2",
- ] + PLATFORM_VULKAN_DEPS,
+ ],
)
diff --git a/iree/testing/vulkan/CMakeLists.txt b/iree/testing/vulkan/CMakeLists.txt
index 6ceb5dc..a4e6a03 100644
--- a/iree/testing/vulkan/CMakeLists.txt
+++ b/iree/testing/vulkan/CMakeLists.txt
@@ -45,12 +45,6 @@
Vulkan::Vulkan
)
-if(${CMAKE_HOST_SYSTEM_NAME} STREQUAL "Windows")
- set(_GUI_LINKOPTS "-SUBSYSTEM:WINDOWS")
-else()
- set(_GUI_LINKOPTS "")
-endif()
-
iree_cc_binary(
NAME
iree-run-module-vulkan-gui
@@ -60,15 +54,15 @@
::vulkan_gui_util
absl::flags
iree::base::file_io
- iree::base::init
+ iree::base::flags
iree::base::main
iree::base::status
iree::base::tracing
- iree::hal::vulkan::vulkan_driver_module
+ iree::hal::vulkan::registration
iree::modules::hal
iree::tools::utils::vm_util
iree::vm
iree::vm::bytecode_module
LINKOPTS
- "${_GUI_LINKOPTS}"
+ "${IREE_TARGET_GUI_LINKOPTS}"
)
diff --git a/iree/testing/vulkan/iree-run-module-vulkan-gui-main.cc b/iree/testing/vulkan/iree-run-module-vulkan-gui-main.cc
index 4ee342e..eec2dfe 100644
--- a/iree/testing/vulkan/iree-run-module-vulkan-gui-main.cc
+++ b/iree/testing/vulkan/iree-run-module-vulkan-gui-main.cc
@@ -20,9 +20,10 @@
// Other dependencies (helpers, etc.)
#include "absl/flags/flag.h"
#include "iree/base/file_io.h"
-#include "iree/base/init.h"
+#include "iree/base/flags.h"
#include "iree/base/main.h"
#include "iree/base/status.h"
+#include "iree/hal/vulkan/registration/driver_module.h"
#include "iree/modules/hal/hal_module.h"
#include "iree/tools/utils/vm_util.h"
#include "iree/vm/api.h"
@@ -146,7 +147,8 @@
} // namespace iree
int iree::IreeMain(int argc, char** argv) {
- iree::InitializeEnvironment(&argc, &argv);
+ iree_flags_parse_checked(&argc, &argv);
+ IREE_CHECK_OK(iree_hal_vulkan_driver_module_register());
// --------------------------------------------------------------------------
// Create a window.
@@ -248,9 +250,6 @@
// --------------------------------------------------------------------------
// Setup IREE.
- // This call to |iree_api_init| is not technically required, but it is
- // included for completeness.
- IREE_CHECK_OK(iree_api_init(&argc, &argv));
// Check API version.
iree_api_version_t actual_version;
diff --git a/iree/testing/vulkan/vulkan_gui_util.cc b/iree/testing/vulkan/vulkan_gui_util.cc
index 5bb7c15..4c93bb3 100644
--- a/iree/testing/vulkan/vulkan_gui_util.cc
+++ b/iree/testing/vulkan/vulkan_gui_util.cc
@@ -194,12 +194,13 @@
VkQueueFamilyProperties* queues = (VkQueueFamilyProperties*)malloc(
sizeof(VkQueueFamilyProperties) * count);
vkGetPhysicalDeviceQueueFamilyProperties(*physical_device, &count, queues);
- for (uint32_t i = 0; i < count; i++)
+ for (uint32_t i = 0; i < count; i++) {
if (queues[i].queueFlags &
(VK_QUEUE_GRAPHICS_BIT | VK_QUEUE_COMPUTE_BIT)) {
*queue_family_index = i;
break;
}
+ }
free(queues);
IM_ASSERT(*queue_family_index != (uint32_t)-1);
}
@@ -218,7 +219,8 @@
create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
create_info.queueCreateInfoCount = 1;
create_info.pQueueCreateInfos = &queue_info;
- create_info.enabledExtensionCount = device_extensions.size();
+ create_info.enabledExtensionCount =
+ static_cast<uint32_t>(device_extensions.size());
create_info.ppEnabledExtensionNames = device_extensions.data();
err = vkCreateDevice(*physical_device, &create_info, allocator, device);
check_vk_result(err);
diff --git a/iree/testing/vulkan/vulkan_gui_util.h b/iree/testing/vulkan/vulkan_gui_util.h
index ea097cd..c43cc5d 100644
--- a/iree/testing/vulkan/vulkan_gui_util.h
+++ b/iree/testing/vulkan/vulkan_gui_util.h
@@ -12,11 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// IREE's Vulkan HAL is built with VK_NO_PROTOTYPES so Vulkan can be loaded
-// dynamically. Using utilities defined in this header means to link against
-// the Vulkan SDK statically, so we want prototypes to be included.
-#undef VK_NO_PROTOTYPES
-
#include <SDL.h>
#include <SDL_vulkan.h>
#include <vulkan/vulkan.h>
diff --git a/iree/tools/BUILD b/iree/tools/BUILD
index 27fc2fb..fb581b1 100644
--- a/iree/tools/BUILD
+++ b/iree/tools/BUILD
@@ -15,13 +15,6 @@
# Misc tools used to optimize, translate, and evaluate IREE.
# Compiler tooling, like the compiler, is not designed to run on device and is tagged as "hostonly".
-load(
- "//iree:build_defs.oss.bzl",
- "IREE_DRIVER_MODULES",
- "PLATFORM_VULKAN_DEPS",
- iree_cc_binary = "cc_binary",
-)
-
package(
default_visibility = ["//visibility:public"],
features = ["layering_check"],
@@ -33,36 +26,34 @@
"sanitizer_suppressions.txt",
])
-iree_cc_binary(
+cc_binary(
name = "iree-benchmark-module",
testonly = True,
srcs = ["iree-benchmark-module-main.cc"],
deps = [
+ "//iree/base:file_io",
+ "//iree/base:flags",
+ "//iree/base:status",
+ "//iree/base:tracing",
+ "//iree/hal/drivers",
+ "//iree/modules/hal",
+ "//iree/tools/utils:vm_util",
+ "//iree/vm",
+ "//iree/vm:bytecode_module",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/flags:parse",
"@com_google_absl//absl/flags:usage",
"@com_google_absl//absl/strings",
"@com_google_benchmark//:benchmark",
- "//iree/base:init",
- "//iree/base:file_io",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/modules/hal",
- "//iree/tools/utils:vm_util",
- "//iree/vm",
- "//iree/vm:bytecode_module",
- ] + PLATFORM_VULKAN_DEPS + IREE_DRIVER_MODULES,
+ ],
)
-iree_cc_binary(
+cc_binary(
name = "iree-dump-module",
srcs = ["iree-dump-module-main.cc"],
deps = [
- "//iree/base:file_io", # build-cleaner: keep
- "//iree/base:flatbuffer_util",
- "//iree/base:init",
- "//iree/schemas:bytecode_module_def_cc_fbs",
- "@com_github_google_flatbuffers//:flatbuffers",
+ "//iree/base:file_mapping",
+ "//iree/schemas:bytecode_module_def_c_fbs",
],
)
@@ -211,7 +202,7 @@
],
)
-iree_cc_binary(
+cc_binary(
name = "iree-opt",
tags = ["hostonly"],
deps = [
@@ -219,17 +210,16 @@
],
)
-iree_cc_binary(
+cc_binary(
name = "iree-run-mlir",
srcs = ["iree-run-mlir-main.cc"],
tags = ["hostonly"],
deps = [
":init_passes_and_dialects",
":init_targets",
- "@com_google_absl//absl/flags:flag",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:span",
"//iree/base:api",
+ "//iree/base:flags",
+ "//iree/base:status",
"//iree/base:tracing",
"//iree/compiler/Dialect/Flow/Transforms",
"//iree/compiler/Dialect/HAL/Transforms",
@@ -237,45 +227,52 @@
"//iree/compiler/Dialect/VM/Target:init_targets",
"//iree/compiler/Dialect/VM/Target/Bytecode",
"//iree/compiler/Dialect/VM/Transforms",
+ "//iree/compiler/Translation:IREEVM",
"//iree/hal:api",
+ "//iree/hal/drivers",
"//iree/modules/hal",
"//iree/tools/utils:vm_util",
"//iree/vm",
"//iree/vm:bytecode_module",
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
- "//iree/base:init",
- "//iree/base:status",
- "//iree/compiler/Translation:IREEVM",
- ] + PLATFORM_VULKAN_DEPS + IREE_DRIVER_MODULES,
+ ],
)
-iree_cc_binary(
+cc_binary(
name = "iree-run-module",
srcs = ["iree-run-module-main.cc"],
deps = [
- "@com_google_absl//absl/flags:flag",
- "@com_google_absl//absl/strings",
"//iree/base:file_io",
- "//iree/base:init",
+ "//iree/base:flags",
"//iree/base:status",
"//iree/base:tracing",
+ "//iree/hal/drivers",
"//iree/modules/hal",
"//iree/tools/utils:vm_util",
"//iree/vm",
"//iree/vm:bytecode_module",
- ] + PLATFORM_VULKAN_DEPS + IREE_DRIVER_MODULES,
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/strings",
+ ],
)
-iree_cc_binary(
+cc_binary(
name = "iree-tblgen",
+ srcs = [
+ "//iree/compiler/Dialect/IREE/Tools:GenSrcs",
+ "//iree/compiler/Dialect/VM/Tools:GenSrcs",
+ ],
tags = ["hostonly"],
deps = [
- "//iree/compiler/Dialect/IREE/Tools",
- "//iree/compiler/Dialect/VM/Tools",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//llvm:TableGen",
"@llvm-project//mlir:MlirTableGenMain",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TableGen",
@@ -305,7 +302,7 @@
],
)
-iree_cc_binary(
+cc_binary(
name = "iree-translate",
tags = ["hostonly"],
deps = [
diff --git a/iree/tools/CMakeLists.txt b/iree/tools/CMakeLists.txt
index 68f5d20..de1c9b9 100644
--- a/iree/tools/CMakeLists.txt
+++ b/iree/tools/CMakeLists.txt
@@ -68,15 +68,15 @@
absl::flags_usage
absl::strings
benchmark
- iree::base::init
+ iree::base::flags
iree::base::file_io
iree::base::status
iree::base::tracing
+ iree::hal::drivers
iree::modules::hal
iree::tools::utils::vm_util
iree::vm
iree::vm::bytecode_module
- ${IREE_HAL_DRIVER_MODULES}
TESTONLY
)
@@ -88,11 +88,9 @@
SRCS
"iree-dump-module-main.cc"
DEPS
- flatbuffers
- iree::base::file_io
- iree::base::flatbuffer_util
- iree::base::init
- iree::schemas::bytecode_module_def_cc_fbs
+ flatcc::runtime
+ iree::base::file_mapping
+ iree::schemas::bytecode_module_def_c_fbs
)
iree_cc_binary(
@@ -106,14 +104,14 @@
absl::flags
absl::strings
iree::base::file_io
- iree::base::init
+ iree::base::flags
iree::base::status
iree::base::tracing
+ iree::hal::drivers
iree::modules::hal
iree::tools::utils::vm_util
iree::vm
iree::vm::bytecode_module
- ${IREE_HAL_DRIVER_MODULES}
)
if(${IREE_ENABLE_MLIR})
@@ -122,13 +120,14 @@
iree-tblgen
SRCS
"${IREE_ROOT_DIR}/third_party/llvm-project/mlir/tools/mlir-tblgen/mlir-tblgen.cpp"
+ "${IREE_SOURCE_DIR}/iree/compiler/Dialect/IREE/Tools/StructAttrGen.cpp"
+ "${IREE_SOURCE_DIR}/iree/compiler/Dialect/VM/Tools/VMOpEncoderGen.cpp"
+ "${IREE_SOURCE_DIR}/iree/compiler/Dialect/VM/Tools/VMOpTableGen.cpp"
DEPS
+ LLVMSupport
+ LLVMTableGen
MLIRSupport
MLIRTableGen
- iree::compiler::Dialect::IREE::Tools
- iree::compiler::Dialect::VM::Tools
- LINKOPTS
- "-lpthread"
HOSTONLY
)
endif()
@@ -328,8 +327,7 @@
absl::span
absl::strings
iree::base::api
- iree::base::init
- iree::base::source_location
+ iree::base::flags
iree::base::status
iree::base::tracing
iree::compiler::Dialect::Flow::Transforms
@@ -340,11 +338,11 @@
iree::compiler::Dialect::VM::Transforms
iree::compiler::Translation::IREEVM
iree::hal::api
+ iree::hal::drivers
iree::modules::hal
iree::tools::utils::vm_util
iree::vm
iree::vm::bytecode_module
- ${IREE_HAL_DRIVER_MODULES}
HOSTONLY
)
endif(${IREE_BUILD_COMPILER})
diff --git a/iree/tools/android/run_module_app/CMakeLists.txt b/iree/tools/android/run_module_app/CMakeLists.txt
index 42c2f23..2c90920 100644
--- a/iree/tools/android/run_module_app/CMakeLists.txt
+++ b/iree/tools/android/run_module_app/CMakeLists.txt
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-if (NOT ANDROID)
+if(NOT ANDROID)
return()
endif()
@@ -40,15 +40,14 @@
DEPS
::android_native_app_glue
absl::strings
- iree::base::initializer
iree::base::status
+ iree::hal::drivers
iree::modules::hal
iree::tools::utils::vm_util
iree::vm
- ${IREE_HAL_DRIVER_MODULES}
LINKOPTS
- android
- log
+ "-landroid"
+ "-llog"
SHARED
WHOLEARCHIVE
)
diff --git a/iree/tools/android/run_module_app/src/main.cc b/iree/tools/android/run_module_app/src/main.cc
index 7250ab3..b123626 100644
--- a/iree/tools/android/run_module_app/src/main.cc
+++ b/iree/tools/android/run_module_app/src/main.cc
@@ -20,8 +20,8 @@
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
-#include "iree/base/initializer.h"
#include "iree/base/status.h"
+#include "iree/hal/drivers/init.h"
#include "iree/modules/hal/hal_module.h"
#include "iree/tools/utils/vm_util.h"
#include "iree/vm/api.h"
@@ -166,7 +166,7 @@
// trigger the workload.
std::this_thread::sleep_for(std::chrono::seconds(2));
- IREE_RUN_MODULE_INITIALIZERS();
+ IREE_CHECK_OK(iree_hal_register_all_available_drivers());
ModuleLoader loader(app);
StatusOr<IreeModuleInvocation> invocation = loader.LoadModuleInvocation();
diff --git a/iree/tools/init_xla_dialects.h b/iree/tools/init_xla_dialects.h
index 04c5a97..cae6b31 100644
--- a/iree/tools/init_xla_dialects.h
+++ b/iree/tools/init_xla_dialects.h
@@ -18,10 +18,10 @@
#ifndef IREE_TOOLS_INIT_XLA_DIALECTS_H_
#define IREE_TOOLS_INIT_XLA_DIALECTS_H_
+#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir/IR/Dialect.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
namespace mlir {
diff --git a/iree/tools/iree-benchmark-module-main.cc b/iree/tools/iree-benchmark-module-main.cc
index 235895e..e6afebd 100644
--- a/iree/tools/iree-benchmark-module-main.cc
+++ b/iree/tools/iree-benchmark-module-main.cc
@@ -18,9 +18,10 @@
#include "absl/strings/string_view.h"
#include "benchmark/benchmark.h"
#include "iree/base/file_io.h"
-#include "iree/base/init.h"
+#include "iree/base/flags.h"
#include "iree/base/status.h"
#include "iree/base/tracing.h"
+#include "iree/hal/drivers/init.h"
#include "iree/modules/hal/hal_module.h"
#include "iree/tools/utils/vm_util.h"
#include "iree/vm/api.h"
@@ -223,7 +224,7 @@
iree_vm_function_t function;
iree_vm_module_signature_t signature =
input_module_->signature(input_module_->self);
- for (int i = 0; i < signature.export_function_count; ++i) {
+ for (iree_host_size_t i = 0; i < signature.export_function_count; ++i) {
iree_string_view_t name;
IREE_CHECK_OK(input_module_->get_function(input_module_->self,
IREE_VM_FUNCTION_LINKAGE_EXPORT,
@@ -295,7 +296,8 @@
absl::flags_internal::UsageFlagsAction::kHandleUsage,
absl::flags_internal::OnUndefinedFlag::kIgnoreUndefined);
::benchmark::Initialize(&argc, argv);
- iree::InitializeEnvironment(&argc, &argv);
+ iree_flags_parse_checked(&argc, &argv);
+ IREE_CHECK_OK(iree_hal_register_all_available_drivers());
iree::IREEBenchmark iree_benchmark;
auto status = iree_benchmark.Register();
diff --git a/iree/tools/iree-dump-module-main.cc b/iree/tools/iree-dump-module-main.cc
index 2fd0242..5ed60ad 100644
--- a/iree/tools/iree-dump-module-main.cc
+++ b/iree/tools/iree-dump-module-main.cc
@@ -16,163 +16,36 @@
#include <string>
#include <utility>
-#include "flatbuffers/flatbuffers.h"
-#include "flatbuffers/minireflect.h"
-#include "flatbuffers/reflection.h"
-#include "flatbuffers/util.h"
-#include "iree/base/flatbuffer_util.h"
-#include "iree/base/init.h"
-#include "iree/schemas/bytecode_module_def_generated.h"
+#include "iree/base/file_mapping.h"
+#include "iree/schemas/bytecode_module_def_json_printer.h"
-namespace {
-
-using ::flatbuffers::ElementaryType;
-using ::flatbuffers::NumToString;
-using ::flatbuffers::String;
-using ::flatbuffers::TypeTable;
-
-struct TruncatingToStringVisitor : public ::flatbuffers::IterationVisitor {
- std::string s;
- std::string d;
-
- bool is_truncating_vector = false;
- int vector_depth = 0;
-
- explicit TruncatingToStringVisitor(std::string delimiter)
- : d(std::move(delimiter)) {}
-
- void StartSequence() override {
- if (is_truncating_vector) return;
- s += "{";
- s += d;
- }
- void EndSequence() override {
- if (is_truncating_vector) return;
- s += d;
- s += "}";
- }
- void Field(size_t field_idx, size_t set_idx, ElementaryType type,
- bool is_vector, const TypeTable* type_table, const char* name,
- const uint8_t* val) override {
- if (is_truncating_vector) return;
- if (!val) return;
- if (set_idx) {
- s += ",";
- s += d;
- }
- if (name) {
- s += name;
- s += ": ";
- }
- }
- template <typename T>
- void Named(T x, const char* name) {
- if (name) {
- s += name;
- } else {
- s += NumToString(x);
- }
- }
- void UType(uint8_t x, const char* name) override {
- if (is_truncating_vector) return;
- Named(x, name);
- }
- void Bool(bool x) override {
- if (is_truncating_vector) return;
- s += x ? "true" : "false";
- }
- void Char(int8_t x, const char* name) override {
- if (is_truncating_vector) return;
- Named(x, name);
- }
- void UChar(uint8_t x, const char* name) override {
- if (is_truncating_vector) return;
- Named(x, name);
- }
- void Short(int16_t x, const char* name) override {
- if (is_truncating_vector) return;
- Named(x, name);
- }
- void UShort(uint16_t x, const char* name) override {
- if (is_truncating_vector) return;
- Named(x, name);
- }
- void Int(int32_t x, const char* name) override {
- if (is_truncating_vector) return;
- Named(x, name);
- }
- void UInt(uint32_t x, const char* name) override {
- if (is_truncating_vector) return;
- Named(x, name);
- }
- void Long(int64_t x) override {
- if (is_truncating_vector) return;
- s += NumToString(x);
- }
- void ULong(uint64_t x) override {
- if (is_truncating_vector) return;
- s += NumToString(x);
- }
- void Float(float x) override {
- if (is_truncating_vector) return;
- s += NumToString(x);
- }
- void Double(double x) override {
- if (is_truncating_vector) return;
- s += NumToString(x);
- }
- void String(const struct String* str) override {
- if (is_truncating_vector) return;
- ::flatbuffers::EscapeString(str->c_str(), str->size(), &s, true, false);
- }
- void Unknown(const uint8_t*) override {
- if (is_truncating_vector) return;
- s += "(?)";
- }
- void StartVector() override {
- ++vector_depth;
- if (is_truncating_vector) return;
- s += "[ ";
- }
- void EndVector() override {
- --vector_depth;
- if (vector_depth == 0) {
- is_truncating_vector = false;
- }
- if (is_truncating_vector) return;
- s += " ]";
- }
- void Element(size_t i, ElementaryType type, const TypeTable* type_table,
- const uint8_t* val) override {
- if (is_truncating_vector) return;
- if (i > 1024) {
- if (!is_truncating_vector) {
- s += ", ...";
- is_truncating_vector = true;
- }
- } else if (i) {
- s += ", ";
- }
- }
-};
-
-} // namespace
-
+// Today we just print to JSON. We could do something more useful (size
+// analysis, etc), but JSON should be enough.
+//
+// We could also move all of this into iree-translate (mlir -> vmfb -> json),
+// though having a tiny little tool not reliant on LLVM is nice (can run this
+// on a device).
extern "C" int main(int argc, char** argv) {
- iree::InitializeEnvironment(&argc, &argv);
-
if (argc < 2) {
- std::cerr << "Syntax: iree-dump-module filename\n";
+ std::cerr << "Syntax: iree-dump-module module.vmfb > module.json\n";
return 1;
}
- std::string module_path = argv[1];
- auto module_fb = iree::FlatBufferFile<iree::vm::BytecodeModuleDef>::LoadFile(
- iree::vm::BytecodeModuleDefIdentifier(), module_path)
- .value();
- TruncatingToStringVisitor tos_visitor("\n");
- auto object = reinterpret_cast<const uint8_t*>(module_fb->root());
- flatbuffers::IterateObject(object, module_fb->root()->MiniReflectTypeTable(),
- &tos_visitor);
- std::cout << tos_visitor.s << std::endl;
+ auto module_mapping_or = iree::FileMapping::OpenRead(argv[1]);
+ if (!module_mapping_or.ok()) {
+ std::cerr << module_mapping_or.status();
+ return 1;
+ }
+ auto module_mapping = std::move(module_mapping_or.value());
+ auto module_buffer = module_mapping->data();
+
+ // Print direct to stdout.
+ flatcc_json_printer_t printer;
+ flatcc_json_printer_init(&printer, /*fp=*/nullptr);
+ flatcc_json_printer_set_skip_default(&printer, true);
+ bytecode_module_def_print_json(
+ &printer, reinterpret_cast<const char*>(module_buffer.data()),
+ module_buffer.size());
+ flatcc_json_printer_clear(&printer);
+
return 0;
}
diff --git a/iree/tools/iree-run-mlir-main.cc b/iree/tools/iree-run-mlir-main.cc
index b0fb141..2e7cc17 100644
--- a/iree/tools/iree-run-mlir-main.cc
+++ b/iree/tools/iree-run-mlir-main.cc
@@ -44,7 +44,7 @@
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "iree/base/api.h"
-#include "iree/base/init.h"
+#include "iree/base/flags.h"
#include "iree/base/status.h"
#include "iree/base/tracing.h"
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
@@ -56,6 +56,7 @@
#include "iree/compiler/Dialect/VM/Transforms/Passes.h"
#include "iree/compiler/Translation/IREEVM.h"
#include "iree/hal/api.h"
+#include "iree/hal/drivers/init.h"
#include "iree/modules/hal/hal_module.h"
#include "iree/tools/init_dialects.h"
#include "iree/tools/init_targets.h"
@@ -165,7 +166,7 @@
iree_host_size_t driver_count = 0;
IREE_RETURN_IF_ERROR(iree_hal_driver_registry_query_available_drivers(
iree_allocator_system(), &driver_names, &driver_count));
- for (int i = 0; i < driver_count; ++i) {
+ for (iree_host_size_t i = 0; i < driver_count; ++i) {
target_backends.push_back(
std::string(driver_names[i].data, driver_names[i].size));
}
@@ -378,7 +379,8 @@
Status evaluate_status = OkStatus();
auto module_signature = iree_vm_module_signature(bytecode_module);
- for (int i = 0; i < module_signature.export_function_count; ++i) {
+ for (iree_host_size_t i = 0; i < module_signature.export_function_count;
+ ++i) {
evaluate_status = run_function(i);
if (!evaluate_status.ok()) {
break;
@@ -521,7 +523,8 @@
}
argc_absl += run_args_flag.size();
char** argv_absl_ptr = argv_absl.data();
- iree::InitializeEnvironment(&argc_absl, &argv_absl_ptr);
+ iree_flags_parse_checked(&argc_absl, &argv_absl_ptr);
+ IREE_CHECK_OK(iree_hal_register_all_available_drivers());
auto status = RunFile(input_file_flag, registry);
if (!status.ok()) {
diff --git a/iree/tools/iree-run-module-main.cc b/iree/tools/iree-run-module-main.cc
index 66c2fcc..dec0f6a 100644
--- a/iree/tools/iree-run-module-main.cc
+++ b/iree/tools/iree-run-module-main.cc
@@ -17,9 +17,10 @@
#include "absl/flags/flag.h"
#include "absl/strings/string_view.h"
#include "iree/base/file_io.h"
-#include "iree/base/init.h"
+#include "iree/base/flags.h"
#include "iree/base/status.h"
#include "iree/base/tracing.h"
+#include "iree/hal/drivers/init.h"
#include "iree/modules/hal/hal_module.h"
#include "iree/tools/utils/vm_util.h"
#include "iree/vm/api.h"
@@ -156,7 +157,8 @@
} // namespace
extern "C" int main(int argc, char** argv) {
- iree::InitializeEnvironment(&argc, &argv);
+ iree_flags_parse_checked(&argc, &argv);
+ IREE_CHECK_OK(iree_hal_register_all_available_drivers());
IREE_CHECK_OK(Run());
return 0;
}
diff --git a/iree/tools/utils/BUILD b/iree/tools/utils/BUILD
index 4aa8c6b..6d630f7 100644
--- a/iree/tools/utils/BUILD
+++ b/iree/tools/utils/BUILD
@@ -31,7 +31,7 @@
"//iree/modules/hal",
"//iree/vm",
"//iree/vm:bytecode_module",
- "//iree/vm:ref_cc",
+ "//iree/vm:cc",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
@@ -44,7 +44,7 @@
":vm_util",
"//iree/base:api",
"//iree/hal:api",
- "//iree/hal/vmla:vmla_driver_module",
+ "//iree/hal/vmla/registration",
"//iree/modules/hal",
"//iree/testing:gtest",
"//iree/testing:gtest_main",
diff --git a/iree/tools/utils/CMakeLists.txt b/iree/tools/utils/CMakeLists.txt
index b0a9f7b..c684006 100644
--- a/iree/tools/utils/CMakeLists.txt
+++ b/iree/tools/utils/CMakeLists.txt
@@ -31,7 +31,7 @@
iree::modules::hal
iree::vm
iree::vm::bytecode_module
- iree::vm::ref_cc
+ iree::vm::cc
PUBLIC
)
@@ -45,7 +45,7 @@
absl::strings
iree::base::api
iree::hal::api
- iree::hal::vmla::vmla_driver_module
+ iree::hal::vmla::registration
iree::modules::hal
iree::testing::gtest
iree::testing::gtest_main
diff --git a/iree/tools/utils/vm_util_test.cc b/iree/tools/utils/vm_util_test.cc
index 17f924b..5fd34d8 100644
--- a/iree/tools/utils/vm_util_test.cc
+++ b/iree/tools/utils/vm_util_test.cc
@@ -19,6 +19,7 @@
#include "absl/strings/str_cat.h"
#include "iree/base/api.h"
#include "iree/hal/api.h"
+#include "iree/hal/vmla/registration/driver_module.h"
#include "iree/modules/hal/hal_module.h"
#include "iree/testing/gtest.h"
#include "iree/testing/status_matchers.h"
@@ -29,6 +30,10 @@
class VmUtilTest : public ::testing::Test {
protected:
+ static void SetUpTestSuite() {
+ IREE_CHECK_OK(iree_hal_vmla_driver_module_register());
+ }
+
virtual void SetUp() {
IREE_ASSERT_OK(iree_hal_module_register_types());
IREE_ASSERT_OK(CreateDevice("vmla", &device_));
diff --git a/iree/vm/BUILD b/iree/vm/BUILD
index 3f3d93e..f69b517 100644
--- a/iree/vm/BUILD
+++ b/iree/vm/BUILD
@@ -1,4 +1,16 @@
-# Bytecode VM.
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
load("//iree/tools:compilation.bzl", "iree_bytecode_module")
load("//build_tools/bazel:tblgen.bzl", "gentbl")
@@ -9,36 +21,157 @@
licenses = ["notice"], # Apache 2.0
)
+#===------------------------------------------------------------------------===#
+# Public API
+#===------------------------------------------------------------------------===#
+
cc_library(
- name = "builtin_types",
- srcs = ["builtin_types.c"],
- hdrs = ["builtin_types.h"],
+ name = "vm",
+ hdrs = [
+ "api.h",
+ ],
deps = [
- ":list",
- ":ref",
+ ":impl",
+ "//iree/base:api",
+ ],
+)
+
+# TODO(benvanik): make these srcs and only expose an api_cc.h.
+cc_library(
+ name = "cc",
+ hdrs = [
+ "module_abi_packing.h",
+ "native_module_cc.h",
+ "ref_cc.h",
+ ],
+ deps = [
+ ":vm",
+ "//iree/base:api",
+ "//iree/base:ref_ptr",
+ "//iree/base:status",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
+ "@com_google_absl//absl/types:span",
+ ],
+)
+
+#===------------------------------------------------------------------------===#
+# Implementation
+#===------------------------------------------------------------------------===#
+
+cc_library(
+ name = "impl",
+ srcs = [
+ "builtin_types.c",
+ "context.c",
+ "instance.c",
+ "invocation.c",
+ "list.c",
+ "module.c",
+ "native_module.c",
+ "ref.c",
+ "stack.c",
+ ],
+ hdrs = [
+ "builtin_types.h",
+ "context.h",
+ "instance.h",
+ "invocation.h",
+ "list.h",
+ "module.h",
+ "native_module.h",
+ "ref.h",
+ "stack.h",
+ "type_def.h",
+ "value.h",
+ ],
+ deps = [
+ "//iree/base:api",
+ "//iree/base:core_headers",
+ "//iree/base:tracing",
+ ],
+)
+
+cc_test(
+ name = "list_test",
+ srcs = ["list_test.cc"],
+ deps = [
+ ":impl",
+ "//iree/base:api",
+ "//iree/base:ref_ptr",
+ "//iree/testing:gtest",
+ "//iree/testing:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "native_module_test",
+ srcs = ["native_module_test.cc"],
+ deps = [
+ ":cc",
+ ":impl",
+ ":native_module_test_hdrs",
+ "//iree/base:api",
+ "//iree/base:status",
+ "//iree/testing:gtest",
+ "//iree/testing:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "native_module_test_hdrs",
+ hdrs = [
+ "native_module_test.h",
+ ],
+ deps = [
+ ":impl",
"//iree/base:api",
],
)
cc_test(
- name = "bytecode_dispatch_test",
- srcs = ["bytecode_dispatch_test.cc"],
+ name = "native_module_benchmark",
+ srcs = ["native_module_benchmark.cc"],
deps = [
- ":builtin_types",
- ":bytecode_module",
- ":context",
- ":instance",
- ":invocation",
- ":module",
+ ":impl",
+ ":native_module_test_hdrs",
+ "//iree/base:api",
"//iree/base:logging",
- "//iree/base:status",
- "//iree/testing:gtest",
- "//iree/testing:gtest_main",
- "//iree/vm/test:all_bytecode_modules_cc",
- "@com_google_absl//absl/strings",
+ "//iree/testing:benchmark_main",
+ "@com_google_benchmark//:benchmark",
],
)
+cc_test(
+ name = "ref_test",
+ srcs = ["ref_test.cc"],
+ deps = [
+ ":impl",
+ "//iree/base:api",
+ "//iree/base:ref_ptr",
+ "//iree/testing:gtest",
+ "//iree/testing:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "stack_test",
+ srcs = ["stack_test.cc"],
+ deps = [
+ ":impl",
+ "//iree/base:api",
+ "//iree/base:ref_ptr",
+ "//iree/testing:gtest",
+ "//iree/testing:gtest_main",
+ ],
+)
+
+#===------------------------------------------------------------------------===#
+# Bytecode interpreter module
+#===------------------------------------------------------------------------===#
+
cc_library(
name = "bytecode_module",
srcs = [
@@ -52,19 +185,50 @@
"bytecode_module.h",
],
deps = [
- ":builtin_types",
":bytecode_op_table_gen",
- ":list",
- ":module",
- ":ref",
- ":stack",
- ":type_def",
- ":value",
- "//iree/base:alignment",
+ ":vm",
"//iree/base:api",
+ "//iree/base:core_headers",
+ "//iree/base:flatcc",
"//iree/base:tracing",
"//iree/schemas:bytecode_module_def_c_fbs",
- "@com_github_dvidelabs_flatcc//:runtime",
+ ],
+)
+
+# TODO(benvanik): see if we can remove this; not good to have this dep.
+gentbl(
+ name = "bytecode_op_table_gen",
+ tbl_outs = [
+ ("-gen-iree-vm-op-table-defs", "bytecode_op_table.h"),
+ ],
+ tblgen = "//iree/tools:iree-tblgen",
+ td_file = "//iree/compiler/Dialect/VM/IR:VMOps.td",
+ td_srcs = [
+ "//iree/compiler/Dialect/IREE/IR:td_files",
+ "//iree/compiler/Dialect/VM/IR:td_files",
+ "@llvm-project//mlir:OpBaseTdFiles",
+ "@llvm-project//mlir:include/mlir/IR/SymbolInterfaces.td",
+ "@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td",
+ "@llvm-project//mlir:include/mlir/Interfaces/ControlFlowInterfaces.td",
+ "@llvm-project//mlir:SideEffectTdFiles",
+ ],
+)
+
+cc_test(
+ name = "bytecode_module_test",
+ srcs = [
+ "bytecode_dispatch_test.cc",
+ "bytecode_module_test.cc",
+ ],
+ deps = [
+ ":bytecode_module",
+ ":vm",
+ "//iree/base:logging",
+ "//iree/base:status",
+ "//iree/testing:gtest",
+ "//iree/testing:gtest_main",
+ "//iree/vm/test:all_bytecode_modules_cc",
+ "@com_google_absl//absl/strings",
],
)
@@ -74,11 +238,7 @@
deps = [
":bytecode_module",
":bytecode_module_benchmark_module_cc",
- ":context",
- ":instance",
- ":module",
- ":native_module",
- ":stack",
+ ":vm",
"//iree/base:api",
"//iree/base:logging",
"//iree/testing:benchmark_main",
@@ -113,33 +273,9 @@
flags = ["-iree-vm-ir-to-bytecode-module"],
)
-cc_test(
- name = "bytecode_module_test",
- srcs = ["bytecode_module_test.cc"],
- deps = [
- ":bytecode_module",
- "//iree/testing:gtest",
- "//iree/testing:gtest_main",
- ],
-)
-
-gentbl(
- name = "bytecode_op_table_gen",
- tbl_outs = [
- ("-gen-iree-vm-op-table-defs", "bytecode_op_table.h"),
- ],
- tblgen = "//iree/tools:iree-tblgen",
- td_file = "//iree/compiler/Dialect/VM/IR:VMOps.td",
- td_srcs = [
- "//iree/compiler/Dialect/IREE/IR:td_files",
- "//iree/compiler/Dialect/VM/IR:td_files",
- "@llvm-project//mlir:OpBaseTdFiles",
- "@llvm-project//mlir:include/mlir/IR/SymbolInterfaces.td",
- "@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td",
- "@llvm-project//mlir:include/mlir/Interfaces/ControlFlowInterfaces.td",
- "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
- ],
-)
+#===------------------------------------------------------------------------===#
+# Emit-C modules
+#===------------------------------------------------------------------------===#
cc_library(
name = "c_funcs",
@@ -147,254 +283,3 @@
"c_funcs.h",
],
)
-
-cc_library(
- name = "context",
- srcs = ["context.c"],
- hdrs = ["context.h"],
- deps = [
- ":instance",
- ":module",
- ":stack",
- "//iree/base:api",
- "//iree/base:atomics",
- "//iree/base:tracing",
- ],
-)
-
-cc_library(
- name = "instance",
- srcs = ["instance.c"],
- hdrs = ["instance.h"],
- deps = [
- ":builtin_types",
- "//iree/base:api",
- "//iree/base:atomics",
- "//iree/base:tracing",
- ],
-)
-
-cc_library(
- name = "invocation",
- srcs = ["invocation.c"],
- hdrs = ["invocation.h"],
- deps = [
- ":context",
- ":list",
- ":module",
- "//iree/base:api",
- "//iree/base:tracing",
- ],
-)
-
-cc_library(
- name = "list",
- srcs = ["list.c"],
- hdrs = ["list.h"],
- deps = [
- ":ref",
- ":type_def",
- ":value",
- "//iree/base:alignment",
- "//iree/base:api",
- ],
-)
-
-cc_test(
- name = "list_test",
- srcs = ["list_test.cc"],
- deps = [
- ":builtin_types",
- ":list",
- "//iree/base:api",
- "//iree/base:ref_ptr",
- "//iree/testing:gtest",
- "//iree/testing:gtest_main",
- ],
-)
-
-cc_library(
- name = "module",
- srcs = ["module.c"],
- hdrs = ["module.h"],
- deps = [
- ":ref",
- "//iree/base:alignment",
- "//iree/base:api",
- "//iree/base:atomics",
- "//iree/base:tracing",
- ],
-)
-
-cc_library(
- name = "native_module",
- srcs = ["native_module.c"],
- hdrs = ["native_module.h"],
- deps = [
- ":module",
- ":stack",
- "//iree/base:api",
- ],
-)
-
-cc_test(
- name = "native_module_benchmark",
- srcs = ["native_module_benchmark.cc"],
- deps = [
- ":module",
- ":native_module",
- ":native_module_test_hdrs",
- ":stack",
- "//iree/base:api",
- "//iree/base:logging",
- "//iree/testing:benchmark_main",
- "@com_google_benchmark//:benchmark",
- ],
-)
-
-cc_test(
- name = "native_module_test",
- srcs = ["native_module_test.cc"],
- deps = [
- ":context",
- ":instance",
- ":invocation",
- ":list",
- ":native_module_test_hdrs",
- ":ref_cc",
- "//iree/base:api",
- "//iree/base:status",
- "//iree/testing:gtest",
- "//iree/testing:gtest_main",
- ],
-)
-
-cc_library(
- name = "native_module_test_hdrs",
- hdrs = [
- "native_module_test.h",
- ],
- deps = [
- ":context",
- ":instance",
- ":native_module",
- ":ref",
- ":stack",
- "//iree/base:api",
- ],
-)
-
-cc_library(
- name = "native_module_cc",
- hdrs = [
- "module_abi_packing.h",
- "native_module_cc.h",
- ],
- deps = [
- ":builtin_types",
- ":module",
- ":ref",
- ":ref_cc",
- ":stack",
- "//iree/base:api",
- "//iree/base:ref_ptr",
- "//iree/base:status",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:optional",
- "@com_google_absl//absl/types:span",
- ],
-)
-
-cc_library(
- name = "ref",
- srcs = ["ref.c"],
- hdrs = ["ref.h"],
- deps = [
- "//iree/base:api",
- "//iree/base:atomics",
- ],
-)
-
-cc_test(
- name = "ref_test",
- srcs = ["ref_test.cc"],
- deps = [
- ":ref",
- "//iree/base:api",
- "//iree/base:ref_ptr",
- "//iree/testing:gtest",
- "//iree/testing:gtest_main",
- ],
-)
-
-cc_library(
- name = "ref_cc",
- hdrs = ["ref_cc.h"],
- deps = [
- ":ref",
- "//iree/base:api",
- "@com_google_absl//absl/base:core_headers",
- ],
-)
-
-cc_library(
- name = "stack",
- srcs = ["stack.c"],
- hdrs = ["stack.h"],
- deps = [
- ":module",
- ":ref",
- "//iree/base:alignment",
- "//iree/base:api",
- "//iree/base:tracing",
- ],
-)
-
-cc_test(
- name = "stack_test",
- srcs = ["stack_test.cc"],
- deps = [
- ":ref",
- ":stack",
- "//iree/base:api",
- "//iree/base:ref_ptr",
- "//iree/testing:gtest",
- "//iree/testing:gtest_main",
- ],
-)
-
-cc_library(
- name = "type_def",
- hdrs = ["type_def.h"],
- deps = [
- ":ref",
- ":value",
- ],
-)
-
-cc_library(
- name = "value",
- hdrs = ["value.h"],
-)
-
-cc_library(
- name = "vm",
- hdrs = [
- "api.h",
- ],
- deps = [
- ":builtin_types",
- ":context",
- ":instance",
- ":invocation",
- ":list",
- ":module",
- ":native_module",
- ":ref",
- ":stack",
- ":type_def",
- ":value",
- "//iree/base:api",
- ],
-)
diff --git a/iree/vm/CMakeLists.txt b/iree/vm/CMakeLists.txt
index 48d6544..2546732 100644
--- a/iree/vm/CMakeLists.txt
+++ b/iree/vm/CMakeLists.txt
@@ -16,36 +16,144 @@
iree_cc_library(
NAME
- builtin_types
+ vm
+ HDRS
+ "api.h"
+ DEPS
+ ::impl
+ iree::base::api
+ PUBLIC
+)
+
+iree_cc_library(
+ NAME
+ cc
+ HDRS
+ "module_abi_packing.h"
+ "native_module_cc.h"
+ "ref_cc.h"
+ DEPS
+ ::vm
+ absl::core_headers
+ absl::inlined_vector
+ absl::optional
+ absl::span
+ absl::strings
+ iree::base::api
+ iree::base::ref_ptr
+ iree::base::status
+ PUBLIC
+)
+
+iree_cc_library(
+ NAME
+ impl
HDRS
"builtin_types.h"
+ "context.h"
+ "instance.h"
+ "invocation.h"
+ "list.h"
+ "module.h"
+ "native_module.h"
+ "ref.h"
+ "stack.h"
+ "type_def.h"
+ "value.h"
SRCS
"builtin_types.c"
+ "context.c"
+ "instance.c"
+ "invocation.c"
+ "list.c"
+ "module.c"
+ "native_module.c"
+ "ref.c"
+ "stack.c"
DEPS
- ::list
- ::ref
+ iree::base::api
+ iree::base::core_headers
+ iree::base::tracing
+ PUBLIC
+)
+
+iree_cc_test(
+ NAME
+ list_test
+ SRCS
+ "list_test.cc"
+ DEPS
+ ::impl
+ iree::base::api
+ iree::base::ref_ptr
+ iree::testing::gtest
+ iree::testing::gtest_main
+)
+
+iree_cc_test(
+ NAME
+ native_module_test
+ SRCS
+ "native_module_test.cc"
+ DEPS
+ ::cc
+ ::impl
+ ::native_module_test_hdrs
+ iree::base::api
+ iree::base::status
+ iree::testing::gtest
+ iree::testing::gtest_main
+)
+
+iree_cc_library(
+ NAME
+ native_module_test_hdrs
+ HDRS
+ "native_module_test.h"
+ DEPS
+ ::impl
iree::base::api
PUBLIC
)
iree_cc_test(
NAME
- bytecode_dispatch_test
+ native_module_benchmark
SRCS
- "bytecode_dispatch_test.cc"
+ "native_module_benchmark.cc"
DEPS
- ::builtin_types
- ::bytecode_module
- ::context
- ::instance
- ::invocation
- ::module
- absl::strings
+ ::impl
+ ::native_module_test_hdrs
+ benchmark
+ iree::base::api
iree::base::logging
- iree::base::status
+ iree::testing::benchmark_main
+)
+
+iree_cc_test(
+ NAME
+ ref_test
+ SRCS
+ "ref_test.cc"
+ DEPS
+ ::impl
+ iree::base::api
+ iree::base::ref_ptr
iree::testing::gtest
iree::testing::gtest_main
- iree::vm::test::all_bytecode_modules_cc
+)
+
+iree_cc_test(
+ NAME
+ stack_test
+ SRCS
+ "stack_test.cc"
+ DEPS
+ ::impl
+ iree::base::api
+ iree::base::ref_ptr
+ iree::testing::gtest
+ iree::testing::gtest_main
)
iree_cc_library(
@@ -60,21 +168,43 @@
"bytecode_module_impl.h"
"bytecode_op_table.h"
DEPS
- ::builtin_types
- ::list
- ::module
- ::ref
- ::stack
- ::type_def
- ::value
- flatcc::runtime
- iree::base::alignment
+ ::vm
iree::base::api
+ iree::base::core_headers
+ iree::base::flatcc
iree::base::tracing
iree::schemas::bytecode_module_def_c_fbs
PUBLIC
)
+iree_tablegen_library(
+ NAME
+ bytecode_op_table_gen
+ TD_FILE
+ "${IREE_ROOT_DIR}/iree/compiler/Dialect/VM/IR/VMOps.td"
+ OUTS
+ -gen-iree-vm-op-table-defs bytecode_op_table.h
+ TBLGEN
+ IREE
+)
+
+iree_cc_test(
+ NAME
+ bytecode_module_test
+ SRCS
+ "bytecode_dispatch_test.cc"
+ "bytecode_module_test.cc"
+ DEPS
+ ::bytecode_module
+ ::vm
+ absl::strings
+ iree::base::logging
+ iree::base::status
+ iree::testing::gtest
+ iree::testing::gtest_main
+ iree::vm::test::all_bytecode_modules_cc
+)
+
iree_cc_test(
NAME
bytecode_module_benchmark
@@ -83,11 +213,7 @@
DEPS
::bytecode_module
::bytecode_module_benchmark_module_cc
- ::context
- ::instance
- ::module
- ::native_module
- ::stack
+ ::vm
absl::inlined_vector
absl::strings
benchmark
@@ -132,28 +258,6 @@
PUBLIC
)
-iree_cc_test(
- NAME
- bytecode_module_test
- SRCS
- "bytecode_module_test.cc"
- DEPS
- ::bytecode_module
- iree::testing::gtest
- iree::testing::gtest_main
-)
-
-iree_tablegen_library(
- NAME
- bytecode_op_table_gen
- TD_FILE
- "${IREE_ROOT_DIR}/iree/compiler/Dialect/VM/IR/VMOps.td"
- OUTS
- -gen-iree-vm-op-table-defs bytecode_op_table.h
- TBLGEN
- IREE
-)
-
iree_cc_library(
NAME
c_funcs
@@ -161,290 +265,3 @@
"c_funcs.h"
PUBLIC
)
-
-iree_cc_library(
- NAME
- context
- HDRS
- "context.h"
- SRCS
- "context.c"
- DEPS
- ::instance
- ::module
- ::stack
- iree::base::api
- iree::base::atomics
- iree::base::tracing
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- instance
- HDRS
- "instance.h"
- SRCS
- "instance.c"
- DEPS
- ::builtin_types
- iree::base::api
- iree::base::atomics
- iree::base::tracing
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- invocation
- HDRS
- "invocation.h"
- SRCS
- "invocation.c"
- DEPS
- ::context
- ::list
- ::module
- iree::base::api
- iree::base::tracing
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- list
- HDRS
- "list.h"
- SRCS
- "list.c"
- DEPS
- ::ref
- ::type_def
- ::value
- iree::base::alignment
- iree::base::api
- PUBLIC
-)
-
-iree_cc_test(
- NAME
- list_test
- SRCS
- "list_test.cc"
- DEPS
- ::builtin_types
- ::list
- iree::base::api
- iree::base::ref_ptr
- iree::testing::gtest
- iree::testing::gtest_main
-)
-
-iree_cc_library(
- NAME
- module
- HDRS
- "module.h"
- SRCS
- "module.c"
- DEPS
- ::ref
- iree::base::alignment
- iree::base::api
- iree::base::atomics
- iree::base::tracing
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- native_module
- HDRS
- "native_module.h"
- SRCS
- "native_module.c"
- DEPS
- ::module
- ::stack
- iree::base::api
- PUBLIC
-)
-
-iree_cc_test(
- NAME
- native_module_benchmark
- SRCS
- "native_module_benchmark.cc"
- DEPS
- ::module
- ::native_module
- ::native_module_test_hdrs
- ::stack
- benchmark
- iree::base::api
- iree::base::logging
- iree::testing::benchmark_main
-)
-
-iree_cc_test(
- NAME
- native_module_test
- SRCS
- "native_module_test.cc"
- DEPS
- ::context
- ::instance
- ::invocation
- ::list
- ::native_module_test_hdrs
- ::ref_cc
- iree::base::api
- iree::base::status
- iree::testing::gtest
- iree::testing::gtest_main
-)
-
-iree_cc_library(
- NAME
- native_module_test_hdrs
- HDRS
- "native_module_test.h"
- DEPS
- ::context
- ::instance
- ::native_module
- ::ref
- ::stack
- iree::base::api
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- native_module_cc
- HDRS
- "module_abi_packing.h"
- "native_module_cc.h"
- DEPS
- ::builtin_types
- ::module
- ::ref
- ::ref_cc
- ::stack
- absl::inlined_vector
- absl::optional
- absl::span
- absl::strings
- iree::base::api
- iree::base::ref_ptr
- iree::base::status
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- ref
- HDRS
- "ref.h"
- SRCS
- "ref.c"
- DEPS
- iree::base::api
- iree::base::atomics
- PUBLIC
-)
-
-iree_cc_test(
- NAME
- ref_test
- SRCS
- "ref_test.cc"
- DEPS
- ::ref
- iree::base::api
- iree::base::ref_ptr
- iree::testing::gtest
- iree::testing::gtest_main
-)
-
-iree_cc_library(
- NAME
- ref_cc
- HDRS
- "ref_cc.h"
- DEPS
- ::ref
- absl::core_headers
- iree::base::api
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- stack
- HDRS
- "stack.h"
- SRCS
- "stack.c"
- DEPS
- ::module
- ::ref
- iree::base::alignment
- iree::base::api
- iree::base::tracing
- PUBLIC
-)
-
-iree_cc_test(
- NAME
- stack_test
- SRCS
- "stack_test.cc"
- DEPS
- ::ref
- ::stack
- iree::base::api
- iree::base::ref_ptr
- iree::testing::gtest
- iree::testing::gtest_main
-)
-
-iree_cc_library(
- NAME
- type_def
- HDRS
- "type_def.h"
- DEPS
- ::ref
- ::value
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- value
- HDRS
- "value.h"
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- vm
- HDRS
- "api.h"
- DEPS
- ::builtin_types
- ::context
- ::instance
- ::invocation
- ::list
- ::module
- ::native_module
- ::ref
- ::stack
- ::type_def
- ::value
- iree::base::api
- PUBLIC
-)
diff --git a/iree/vm/bytecode_dispatch.c b/iree/vm/bytecode_dispatch.c
index bb5fbbf..bb5241f 100644
--- a/iree/vm/bytecode_dispatch.c
+++ b/iree/vm/bytecode_dispatch.c
@@ -15,8 +15,8 @@
#include <string.h>
#include "iree/base/tracing.h"
+#include "iree/vm/api.h"
#include "iree/vm/bytecode_dispatch_util.h"
-#include "iree/vm/list.h"
//===----------------------------------------------------------------------===//
// Math utilities, kept here to limit dependencies
diff --git a/iree/vm/bytecode_dispatch_test.cc b/iree/vm/bytecode_dispatch_test.cc
index caca978..b87dfc5 100644
--- a/iree/vm/bytecode_dispatch_test.cc
+++ b/iree/vm/bytecode_dispatch_test.cc
@@ -23,12 +23,8 @@
#include "iree/base/logging.h"
#include "iree/base/status.h"
#include "iree/testing/gtest.h"
-#include "iree/vm/builtin_types.h"
+#include "iree/vm/api.h"
#include "iree/vm/bytecode_module.h"
-#include "iree/vm/context.h"
-#include "iree/vm/instance.h"
-#include "iree/vm/invocation.h"
-#include "iree/vm/module.h"
// Compiled module embedded here to avoid file IO:
#include "iree/vm/test/all_bytecode_modules.h"
diff --git a/iree/vm/bytecode_module.c b/iree/vm/bytecode_module.c
index caa89d3..52a5a36 100644
--- a/iree/vm/bytecode_module.c
+++ b/iree/vm/bytecode_module.c
@@ -17,9 +17,8 @@
#include "iree/base/alignment.h"
#include "iree/base/api.h"
#include "iree/base/tracing.h"
+#include "iree/vm/api.h"
#include "iree/vm/bytecode_module_impl.h"
-#include "iree/vm/ref.h"
-#include "iree/vm/stack.h"
// Perform an strcmp between a flatbuffers string and an IREE string view.
static bool iree_vm_flatbuffer_strcmp(flatbuffers_string_t lhs,
@@ -175,10 +174,6 @@
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"imports[%zu] missing full_name", i);
}
- if (!iree_vm_ImportFunctionDef_signature(import_def)) {
- return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
- "imports[%zu] missing signature", i);
- }
}
for (size_t i = 0; i < iree_vm_ExportFunctionDef_vec_len(exported_functions);
@@ -195,10 +190,6 @@
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"exports[%zu] missing local_name", i);
}
- if (!iree_vm_ExportFunctionDef_signature(export_def)) {
- return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
- "exports[%zu] missing signature", i);
- }
iree_host_size_t internal_ordinal =
iree_vm_ExportFunctionDef_internal_ordinal(export_def);
if (internal_ordinal >=
@@ -221,10 +212,6 @@
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"functions[%zu] missing body", i);
}
- if (!iree_vm_InternalFunctionDef_signature(function_def)) {
- return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
- "functions[%zu] missing signature", i);
- }
iree_vm_FunctionDescriptor_struct_t function_descriptor =
iree_vm_FunctionDescriptor_vec_at(function_descriptors, i);
@@ -398,11 +385,17 @@
iree_vm_InternalFunctionDef_table_t function_def =
iree_vm_InternalFunctionDef_vec_at(internal_functions, ordinal);
- iree_vm_FunctionSignatureDef_table_t signature =
+ iree_vm_FunctionSignatureDef_table_t signature_def =
iree_vm_InternalFunctionDef_signature(function_def);
+ if (!signature_def) {
+ return iree_make_status(
+ IREE_STATUS_NOT_FOUND,
+ "reflection attribute at index %zu not found; no signature", index);
+ }
iree_vm_ReflectionAttrDef_vec_t reflection_attrs =
- iree_vm_FunctionSignatureDef_reflection_attrs(signature);
- if (index >= iree_vm_ReflectionAttrDef_vec_len(reflection_attrs)) {
+ iree_vm_FunctionSignatureDef_reflection_attrs(signature_def);
+ if (!reflection_attrs ||
+ index >= iree_vm_ReflectionAttrDef_vec_len(reflection_attrs)) {
return iree_make_status(IREE_STATUS_NOT_FOUND,
"reflection attribute at index %zu not found",
index);
@@ -502,14 +495,15 @@
static iree_host_size_t iree_vm_bytecode_module_layout_state(
iree_vm_BytecodeModuleDef_table_t module_def,
iree_vm_bytecode_module_state_t* state) {
- iree_vm_ModuleStateDef_table_t module_state =
+ iree_vm_ModuleStateDef_table_t module_state_def =
iree_vm_BytecodeModuleDef_module_state(module_def);
iree_host_size_t rwdata_storage_capacity = 0;
iree_host_size_t global_ref_count = 0;
- if (module_state) {
+ if (module_state_def) {
rwdata_storage_capacity =
- iree_vm_ModuleStateDef_global_bytes_capacity(module_state);
- global_ref_count = iree_vm_ModuleStateDef_global_ref_count(module_state);
+ iree_vm_ModuleStateDef_global_bytes_capacity(module_state_def);
+ global_ref_count =
+ iree_vm_ModuleStateDef_global_ref_count(module_state_def);
}
iree_host_size_t rodata_ref_count = iree_vm_RodataSegmentDef_vec_len(
iree_vm_BytecodeModuleDef_rodata_segments(module_def));
@@ -694,9 +688,12 @@
iree_vm_BytecodeModuleDef_internal_functions(module->def);
iree_vm_InternalFunctionDef_table_t function_def =
iree_vm_InternalFunctionDef_vec_at(internal_functions, function.ordinal);
+ iree_vm_FunctionSignatureDef_table_t signature_def =
+ iree_vm_InternalFunctionDef_signature(function_def);
flatbuffers_string_t calling_convention =
- iree_vm_FunctionSignatureDef_calling_convention(
- iree_vm_InternalFunctionDef_signature(function_def));
+ signature_def
+ ? iree_vm_FunctionSignatureDef_calling_convention(signature_def)
+ : 0;
iree_vm_function_signature_t signature;
memset(&signature, 0, sizeof(signature));
signature.calling_convention.data = calling_convention;
diff --git a/iree/vm/bytecode_module.h b/iree/vm/bytecode_module.h
index 2db39f5..8f694fc 100644
--- a/iree/vm/bytecode_module.h
+++ b/iree/vm/bytecode_module.h
@@ -18,7 +18,7 @@
#include <stdint.h>
#include "iree/base/api.h"
-#include "iree/vm/module.h"
+#include "iree/vm/api.h"
#ifdef __cplusplus
extern "C" {
diff --git a/iree/vm/bytecode_module_benchmark.cc b/iree/vm/bytecode_module_benchmark.cc
index 2b88ba1..529ff3a 100644
--- a/iree/vm/bytecode_module_benchmark.cc
+++ b/iree/vm/bytecode_module_benchmark.cc
@@ -19,13 +19,9 @@
#include "benchmark/benchmark.h"
#include "iree/base/api.h"
#include "iree/base/logging.h"
+#include "iree/vm/api.h"
#include "iree/vm/bytecode_module.h"
#include "iree/vm/bytecode_module_benchmark_module.h"
-#include "iree/vm/context.h"
-#include "iree/vm/instance.h"
-#include "iree/vm/module.h"
-#include "iree/vm/native_module.h"
-#include "iree/vm/stack.h"
namespace {
@@ -82,7 +78,7 @@
static iree_status_t RunFunction(benchmark::State& state,
absl::string_view function_name,
absl::Span<const int32_t> i32_args,
- int result_count, int batch_size = 1) {
+ int result_count, int64_t batch_size = 1) {
iree_vm_instance_t* instance = NULL;
IREE_CHECK_OK(iree_vm_instance_create(iree_allocator_system(), &instance));
@@ -294,7 +290,7 @@
return i;
};
while (state.KeepRunningBatch(state.range(0))) {
- int ret = loop(state.range(0));
+ int ret = loop(static_cast<int>(state.range(0)));
benchmark::DoNotOptimize(ret);
benchmark::ClobberMemory();
}
diff --git a/iree/vm/bytecode_module_impl.h b/iree/vm/bytecode_module_impl.h
index 0a74957..e4a282d 100644
--- a/iree/vm/bytecode_module_impl.h
+++ b/iree/vm/bytecode_module_impl.h
@@ -24,15 +24,10 @@
#endif // _MSC_VER
#include "iree/base/api.h"
-#include "iree/vm/builtin_types.h"
-#include "iree/vm/module.h"
-#include "iree/vm/ref.h"
-#include "iree/vm/stack.h"
-#include "iree/vm/type_def.h"
-#include "iree/vm/value.h"
+#include "iree/vm/api.h"
// NOTE: include order matters:
-#include "flatcc/reflection/flatbuffers_common_reader.h"
+#include "iree/base/flatcc.h"
#include "iree/schemas/bytecode_module_def_reader.h"
#include "iree/schemas/bytecode_module_def_verifier.h"
diff --git a/iree/vm/list_test.cc b/iree/vm/list_test.cc
index 779bd04..6baaa88 100644
--- a/iree/vm/list_test.cc
+++ b/iree/vm/list_test.cc
@@ -101,7 +101,7 @@
EXPECT_EQ(5, iree_vm_list_size(list));
for (iree_host_size_t i = 0; i < 5; ++i) {
- iree_vm_value_t value = iree_vm_value_make_i32(i);
+ iree_vm_value_t value = iree_vm_value_make_i32((int32_t)i);
IREE_ASSERT_OK(iree_vm_list_set_value(list, i, &value));
}
@@ -138,7 +138,7 @@
EXPECT_EQ(5, iree_vm_list_size(list));
for (iree_host_size_t i = 0; i < 5; ++i) {
- iree_vm_ref_t ref_a = MakeRef<A>(i);
+ iree_vm_ref_t ref_a = MakeRef<A>((float)i);
IREE_ASSERT_OK(iree_vm_list_set_ref_move(list, i, &ref_a));
}
@@ -173,11 +173,11 @@
EXPECT_EQ(10, iree_vm_list_size(list));
for (iree_host_size_t i = 0; i < 5; ++i) {
- iree_vm_value_t value = iree_vm_value_make_i32(i);
+ iree_vm_value_t value = iree_vm_value_make_i32((int32_t)i);
IREE_ASSERT_OK(iree_vm_list_set_value(list, i, &value));
}
for (iree_host_size_t i = 5; i < 10; ++i) {
- iree_vm_ref_t ref_a = MakeRef<A>(i);
+ iree_vm_ref_t ref_a = MakeRef<A>(static_cast<float>(i));
IREE_ASSERT_OK(iree_vm_list_set_ref_move(list, i, &ref_a));
}
diff --git a/iree/vm/native_module_cc.h b/iree/vm/native_module_cc.h
index fae6d64..32c0d56 100644
--- a/iree/vm/native_module_cc.h
+++ b/iree/vm/native_module_cc.h
@@ -148,7 +148,7 @@
if (out_function) {
out_function->module = module->interface();
out_function->linkage = IREE_VM_FUNCTION_LINKAGE_EXPORT;
- out_function->ordinal = ordinal;
+ out_function->ordinal = static_cast<uint16_t>(ordinal);
}
if (out_name) {
*out_name = dispatch_function.name;
diff --git a/iree/vm/stack.h b/iree/vm/stack.h
index 98fe693..bf59de6 100644
--- a/iree/vm/stack.h
+++ b/iree/vm/stack.h
@@ -87,9 +87,7 @@
// code), etc.
iree_vm_source_offset_t pc;
-#if IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION
- iree_zone_id_t trace_zone;
-#endif // IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION
+ IREE_TRACE(iree_zone_id_t trace_zone;)
} iree_vm_stack_frame_t;
// Returns the implementation-defined frame storage associated with |frame|.
diff --git a/repo_utils.bzl b/repo_utils.bzl
deleted file mode 100644
index 933c00b..0000000
--- a/repo_utils.bzl
+++ /dev/null
@@ -1,29 +0,0 @@
-# Copyright 2019 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# TODO(laurenzo): This is available upstream as of 0.28. Remove when ready.
-# See: https://docs.bazel.build/versions/master/repo/utils.html#maybe
-def maybe(repo_rule, name, **kwargs):
- """Utility function for only adding a repository if it's not already present.
- This is to implement safe repositories.bzl macro documented in
- https://docs.bazel.build/versions/master/skylark/deploying.html#dependencies.
- Args:
- repo_rule: repository rule function.
- name: name of the repository to create.
- **kwargs: remaining arguments that are passed to the repo_rule function.
- Returns:
- Nothing, defines the repository when needed as a side-effect.
- """
- if not native.existing_rule(name):
- repo_rule(name = name, **kwargs)
diff --git a/scripts/get_e2e_artifacts.py b/scripts/get_e2e_artifacts.py
index e62226d..2440ab7 100755
--- a/scripts/get_e2e_artifacts.py
+++ b/scripts/get_e2e_artifacts.py
@@ -162,10 +162,7 @@
if FLAGS.run_test_suites:
# Use bazel test to execute all of the test suites in parallel.
- command = [
- 'bazel', 'test', *test_suites, '--color=yes',
- '--test_arg=--get_saved_model'
- ]
+ command = ['bazel', 'test', *test_suites, '--color=yes']
print(f'Running: `{" ".join(command)}`')
if not FLAGS.dry_run:
subprocess.check_call(command)
diff --git a/scripts/update_e2e_coverage.py b/scripts/update_e2e_coverage.py
index 325cb28..3b5509c 100755
--- a/scripts/update_e2e_coverage.py
+++ b/scripts/update_e2e_coverage.py
@@ -28,13 +28,12 @@
TENSORFLOW_COVERAGE_DIR = 'tensorflow_coverage'
REFERENCE_BACKEND = 'tf'
-# Assumes that tests are expanded for the tf, iree_vmla, iree_llvmjit and
+# Assumes that tests are expanded for the tf, iree_vmla, and
# iree_vulkan backends.
BACKENDS_TO_TITLES = collections.OrderedDict([
('tf', 'tensorflow'),
('tflite', 'tflite'),
('iree_vmla', 'vmla'),
- ('iree_llvmjit', 'llvm-ir'),
('iree_vulkan', 'vulkan-spirv'),
])
@@ -44,7 +43,12 @@
KWS_LINK = f'[Keyword Spotting Streaming]({KWS_LINK})'
COVERAGE_GROUP_TO_TEST_SUITES = {
- 'tf_base_coverage': ['//integrations/tensorflow/e2e:e2e_tests'],
+ 'tf_base_coverage': [
+ '//integrations/tensorflow/e2e:e2e_tests',
+ '//integrations/tensorflow/e2e/math:math_tests',
+ '//integrations/tensorflow/e2e/math:math_dynamic_dims_tests',
+ '//integrations/tensorflow/e2e/math:math_complex_tests',
+ ],
'tf_keras_coverage': [
'//integrations/tensorflow/e2e/keras/layers:layers_tests',
'//integrations/tensorflow/e2e/keras/layers:layers_dynamic_batch_tests',
@@ -82,8 +86,16 @@
}
TEST_SUITES_TO_HEADERS = {
+ # tf_base_coverage
'//integrations/tensorflow/e2e:e2e_tests':
'End to end TensorFlow tests',
+ '//integrations/tensorflow/e2e/math:math_tests':
+ 'End to end tests of tf.math functions with static dimensions',
+ '//integrations/tensorflow/e2e/math:math_dynamic_dims_tests':
+ 'End to end tests of tf.math functions with dynamic dimensions',
+ '//integrations/tensorflow/e2e/math:math_complex_tests':
+ 'End to end tests of tf.math functions with complex numbers',
+ # tf_keras_coverage
'//integrations/tensorflow/e2e/keras/layers:layers_tests':
'End to end tests of tf.keras layers (with default configuration and '
'static batch sizes in inference mode)',
@@ -96,12 +108,14 @@
'//integrations/tensorflow/e2e/keras/layers:layers_training_tests':
'End to end tests of tf.keras layers in training mode (with default'
'configuration and static batch sizes)',
+ # language_and_speech_coverage
'//integrations/tensorflow/e2e:mobile_bert_squad_tests':
'End to end test of MobileBert on SQuAD',
'//integrations/tensorflow/e2e/keras:keyword_spotting_tests':
f'End to end tests of {KWS_LINK} models',
'//integrations/tensorflow/e2e/keras:keyword_spotting_internal_streaming_tests':
f'End to end tests of {KWS_LINK} models in internal streaming mode',
+ # vision_coverage
'//integrations/tensorflow/e2e/keras:imagenet_non_hermetic_tests':
'End to end tests of tf.keras.applications vision models on Imagenet',
'//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests':
@@ -109,18 +123,31 @@
}
TEST_SUITES_TO_NOTES = {
+ '//integrations/tensorflow/e2e/math:math_tests':
+ ('**Note:** To be thorough, these tests use high rank tensors and\n'
+ 'test int dtypes where TensorFlow allows them to be used. Both of\n'
+ 'these choices disproportionately affect TFLite coverage, and\n'
+ 'don\'t represent coverage for simple use cases.\n'),
'//integrations/tensorflow/e2e/keras/layers:layers_tests': (
'**Note:** Layers like `Dropout` are listed as passing in this table,\n'
'but they function similar to identity layers in these tests. **See \n'
'the third table for the coverage of these layers during training.**\n'
'\n'
'These tests also only modify required `tf.keras.layers` arguments.\n'
- 'See the full API tests below for coverage on of non-default '
+ 'See the full API tests below for coverage on of non-default\n'
'layer configurations.'),
}
# Key to use as the name of the rows in the left column for each test in the
# suite.
TEST_SUITE_TO_ROW_ID_KEY = {
+ # tf_base_coverage
+ '//integrations/tensorflow/e2e/math:math_tests':
+ 'functions',
+ '//integrations/tensorflow/e2e/math:math_dynamic_dims_tests':
+ 'functions',
+ '//integrations/tensorflow/e2e/math:math_complex_tests':
+ 'functions',
+ # tf_keras_coverage
'//integrations/tensorflow/e2e/keras/layers:layers_tests':
'layer',
'//integrations/tensorflow/e2e/keras/layers:layers_full_api_tests':
@@ -129,10 +156,12 @@
'layer',
'//integrations/tensorflow/e2e/keras/layers:layers_training_tests':
'layer',
+ # language_and_speech_coverage
'//integrations/tensorflow/e2e/keras:keyword_spotting_tests':
'model',
'//integrations/tensorflow/e2e/keras:keyword_spotting_internal_streaming_tests':
'model',
+ # vision_coverage
'//integrations/tensorflow/e2e/keras:imagenet_non_hermetic_tests':
'model',
'//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests':
@@ -142,6 +171,14 @@
# Some test suites are generated from a single source. This allows us to point
# to the right test file when generating test URLs.
SINGLE_SOURCE_SUITES = {
+ # tf_base_coverage
+ '//integrations/tensorflow/e2e/math:math_tests':
+ 'math_test',
+ '//integrations/tensorflow/e2e/math:math_dynamic_dims_tests':
+ 'math_test',
+ '//integrations/tensorflow/e2e/math:math_complex_tests':
+ 'math_test',
+ # tf_keras_coverage
'//integrations/tensorflow/e2e/keras/layers:layers_tests':
'layers_test',
'//integrations/tensorflow/e2e/keras/layers:layers_full_api_tests':
@@ -150,10 +187,12 @@
'layers_test',
'//integrations/tensorflow/e2e/keras/layers:layers_training_tests':
'layers_test',
+ # language_and_speech_coverage
'//integrations/tensorflow/e2e/keras:keyword_spotting_tests':
'keyword_spotting_streaming_test',
'//integrations/tensorflow/e2e/keras:keyword_spotting_internal_streaming_tests':
'keyword_spotting_streaming_test',
+ # vision_coverage
'//integrations/tensorflow/e2e/keras:imagenet_non_hermetic_tests':
'vision_model_test',
'//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests':
diff --git a/third_party/flatbuffers b/third_party/flatbuffers
deleted file mode 160000
index a5d9d0f..0000000
--- a/third_party/flatbuffers
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit a5d9d0f7d368054fd1691aedf1db4116efcc233e
diff --git a/third_party/half/LICENSE b/third_party/half/LICENSE
new file mode 100644
index 0000000..9e4618b
--- /dev/null
+++ b/third_party/half/LICENSE
@@ -0,0 +1,21 @@
+The MIT License
+
+Copyright (c) 2012-2017 Christian Rau
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
diff --git a/third_party/half/README.txt b/third_party/half/README.txt
new file mode 100644
index 0000000..3a0960c
--- /dev/null
+++ b/third_party/half/README.txt
@@ -0,0 +1,288 @@
+HALF-PRECISION FLOATING POINT LIBRARY (Version 1.12.0)
+------------------------------------------------------
+
+This is a C++ header-only library to provide an IEEE 754 conformant 16-bit
+half-precision floating point type along with corresponding arithmetic
+operators, type conversions and common mathematical functions. It aims for both
+efficiency and ease of use, trying to accurately mimic the behaviour of the
+builtin floating point types at the best performance possible.
+
+
+INSTALLATION AND REQUIREMENTS
+-----------------------------
+
+Comfortably enough, the library consists of just a single header file
+containing all the functionality, which can be directly included by your
+projects, without the necessity to build anything or link to anything.
+
+Whereas this library is fully C++98-compatible, it can profit from certain
+C++11 features. Support for those features is checked automatically at compile
+(or rather preprocessing) time, but can be explicitly enabled or disabled by
+defining the corresponding preprocessor symbols to either 1 or 0 yourself. This
+is useful when the automatic detection fails (for more exotic implementations)
+or when a feature should be explicitly disabled:
+
+ - 'long long' integer type for mathematical functions returning 'long long'
+ results (enabled for VC++ 2003 and newer, gcc and clang, overridable with
+ 'HALF_ENABLE_CPP11_LONG_LONG').
+
+ - Static assertions for extended compile-time checks (enabled for VC++ 2010,
+ gcc 4.3, clang 2.9 and newer, overridable with 'HALF_ENABLE_CPP11_STATIC_ASSERT').
+
+ - Generalized constant expressions (enabled for VC++ 2015, gcc 4.6, clang 3.1
+ and newer, overridable with 'HALF_ENABLE_CPP11_CONSTEXPR').
+
+ - noexcept exception specifications (enabled for VC++ 2015, gcc 4.6, clang 3.0
+ and newer, overridable with 'HALF_ENABLE_CPP11_NOEXCEPT').
+
+ - User-defined literals for half-precision literals to work (enabled for
+ VC++ 2015, gcc 4.7, clang 3.1 and newer, overridable with
+ 'HALF_ENABLE_CPP11_USER_LITERALS').
+
+ - Type traits and template meta-programming features from <type_traits>
+ (enabled for VC++ 2010, libstdc++ 4.3, libc++ and newer, overridable with
+ 'HALF_ENABLE_CPP11_TYPE_TRAITS').
+
+ - Special integer types from <cstdint> (enabled for VC++ 2010, libstdc++ 4.3,
+ libc++ and newer, overridable with 'HALF_ENABLE_CPP11_CSTDINT').
+
+ - Certain C++11 single-precision mathematical functions from <cmath> for
+ an improved implementation of their half-precision counterparts to work
+ (enabled for VC++ 2013, libstdc++ 4.3, libc++ and newer, overridable with
+ 'HALF_ENABLE_CPP11_CMATH').
+
+ - Hash functor 'std::hash' from <functional> (enabled for VC++ 2010,
+ libstdc++ 4.3, libc++ and newer, overridable with 'HALF_ENABLE_CPP11_HASH').
+
+The library has been tested successfully with Visual C++ 2005-2015, gcc 4.4-4.8
+and clang 3.1. Please contact me if you have any problems, suggestions or even
+just success testing it on other platforms.
+
+
+DOCUMENTATION
+-------------
+
+Here follow some general words about the usage of the library and its
+implementation. For a complete documentation of its interface look at the
+corresponding website http://half.sourceforge.net. You may also generate the
+complete developer documentation from the library's only include file's doxygen
+comments, but this is more relevant to developers rather than mere users (for
+reasons described below).
+
+BASIC USAGE
+
+To make use of the library just include its only header file half.hpp, which
+defines all half-precision functionality inside the 'half_float' namespace. The
+actual 16-bit half-precision data type is represented by the 'half' type. This
+type behaves like the builtin floating point types as much as possible,
+supporting the usual arithmetic, comparison and streaming operators, which
+makes its use pretty straight-forward:
+
+ using half_float::half;
+ half a(3.4), b(5);
+ half c = a * b;
+ c += 3;
+ if(c > a)
+ std::cout << c << std::endl;
+
+Additionally the 'half_float' namespace also defines half-precision versions
+for all mathematical functions of the C++ standard library, which can be used
+directly through ADL:
+
+ half a(-3.14159);
+ half s = sin(abs(a));
+ long l = lround(s);
+
+You may also specify explicit half-precision literals, since the library
+provides a user-defined literal inside the 'half_float::literal' namespace,
+which you just need to import (assuming support for C++11 user-defined literals):
+
+ using namespace half_float::literal;
+ half x = 1.0_h;
+
+Furthermore the library provides proper specializations for
+'std::numeric_limits', defining various implementation properties, and
+'std::hash' for hashing half-precision numbers (assuming support for C++11
+'std::hash'). Similar to the corresponding preprocessor symbols from <cmath>
+the library also defines the 'HUGE_VALH' constant and maybe the 'FP_FAST_FMAH'
+symbol.
+
+CONVERSIONS AND ROUNDING
+
+The half is explicitly constructible/convertible from a single-precision float
+argument. Thus it is also explicitly constructible/convertible from any type
+implicitly convertible to float, but constructing it from types like double or
+int will involve the usual warnings arising when implicitly converting those to
+float because of the lost precision. On the one hand those warnings are
+intentional, because converting those types to half necessarily also reduces
+precision. But on the other hand they are raised for explicit conversions from
+those types, when the user knows what he is doing. So if those warnings keep
+bugging you, then you won't get around first explicitly converting to float
+before converting to half, or use the 'half_cast' described below. In addition
+you can also directly assign float values to halfs.
+
+In contrast to the float-to-half conversion, which reduces precision, the
+conversion from half to float (and thus to any other type implicitly
+convertible from float) is implicit, because all values representable with
+half-precision are also representable with single-precision. This way the
+half-to-float conversion behaves similar to the builtin float-to-double
+conversion and all arithmetic expressions involving both half-precision and
+single-precision arguments will be of single-precision type. This way you can
+also directly use the mathematical functions of the C++ standard library,
+though in this case you will invoke the single-precision versions which will
+also return single-precision values, which is (even if maybe performing the
+exact same computation, see below) not as conceptually clean when working in a
+half-precision environment.
+
+The default rounding mode for conversions from float to half uses truncation
+(round toward zero, but mapping overflows to infinity) for rounding values not
+representable exactly in half-precision. This is the fastest rounding possible
+and is usually sufficient. But by redefining the 'HALF_ROUND_STYLE'
+preprocessor symbol (before including half.hpp) this default can be overridden
+with one of the other standard rounding modes using their respective constants
+or the equivalent values of 'std::float_round_style' (it can even be
+synchronized with the underlying single-precision implementation by defining it
+to 'std::numeric_limits<float>::round_style'):
+
+ - 'std::round_indeterminate' or -1 for the fastest rounding (default).
+
+ - 'std::round_toward_zero' or 0 for rounding toward zero.
+
+ - 'std::round_to_nearest' or 1 for rounding to the nearest value.
+
+ - 'std::round_toward_infinity' or 2 for rounding toward positive infinity.
+
+ - 'std::round_toward_neg_infinity' or 3 for rounding toward negative infinity.
+
+In addition to changing the overall default rounding mode one can also use the
+'half_cast'. This converts between half and any built-in arithmetic type using
+a configurable rounding mode (or the default rounding mode if none is
+specified). In addition to a configurable rounding mode, 'half_cast' has
+another big difference to a mere 'static_cast': Any conversions are performed
+directly using the given rounding mode, without any intermediate conversion
+to/from 'float'. This is especially relevant for conversions to integer types,
+which don't necessarily truncate anymore. But also for conversions from
+'double' or 'long double' this may produce more precise results than a
+pre-conversion to 'float' using the single-precision implementation's current
+rounding mode would.
+
+ half a = half_cast<half>(4.2);
+ half b = half_cast<half,std::numeric_limits<float>::round_style>(4.2f);
+ assert( half_cast<int, std::round_to_nearest>( 0.7_h ) == 1 );
+ assert( half_cast<half,std::round_toward_zero>( 4097 ) == 4096.0_h );
+ assert( half_cast<half,std::round_toward_infinity>( 4097 ) == 4100.0_h );
+ assert( half_cast<half,std::round_toward_infinity>( std::numeric_limits<double>::min() ) > 0.0_h );
+
+When using round to nearest (either as default or through 'half_cast') ties are
+by default resolved by rounding them away from zero (and thus equal to the
+behaviour of the 'round' function). But by redefining the
+'HALF_ROUND_TIES_TO_EVEN' preprocessor symbol to 1 (before including half.hpp)
+this default can be changed to the slightly slower but less biased and more
+IEEE-conformant behaviour of rounding half-way cases to the nearest even value.
+
+ #define HALF_ROUND_TIES_TO_EVEN 1
+ #include <half.hpp>
+ ...
+ assert( half_cast<int,std::round_to_nearest>(3.5_h)
+ == half_cast<int,std::round_to_nearest>(4.5_h) );
+
+IMPLEMENTATION
+
+For performance reasons (and ease of implementation) many of the mathematical
+functions provided by the library as well as all arithmetic operations are
+actually carried out in single-precision under the hood, calling to the C++
+standard library implementations of those functions whenever appropriate,
+meaning the arguments are converted to floats and the result back to half. But
+to reduce the conversion overhead as much as possible any temporary values
+inside of lengthy expressions are kept in single-precision as long as possible,
+while still maintaining a strong half-precision type to the outside world. Only
+when finally assigning the value to a half or calling a function that works
+directly on halfs is the actual conversion done (or never, when further
+converting the result to float).
+
+This approach has two implications. First of all you have to treat the
+library's documentation at http://half.sourceforge.net as a simplified version,
+describing the behaviour of the library as if implemented this way. The actual
+argument and return types of functions and operators may involve other internal
+types (feel free to generate the exact developer documentation from the Doxygen
+comments in the library's header file if you really need to). But nevertheless
+the behaviour is exactly like specified in the documentation. The other
+implication is, that in the presence of rounding errors or over-/underflows
+arithmetic expressions may produce different results when compared to
+converting to half-precision after each individual operation:
+
+ half a = std::numeric_limits<half>::max() * 2.0_h / 2.0_h; // a = MAX
+ half b = half(std::numeric_limits<half>::max() * 2.0_h) / 2.0_h; // b = INF
+ assert( a != b );
+
+But this should only be a problem in very few cases. One last word has to be
+said when talking about performance. Even with its efforts in reducing
+conversion overhead as much as possible, the software half-precision
+implementation can most probably not beat the direct use of single-precision
+computations. Usually using actual float values for all computations and
+temporaries and using halfs only for storage is the recommended way. On the
+one hand this somehow makes the provided mathematical functions obsolete
+(especially in light of the implicit conversion from half to float), but
+nevertheless the goal of this library was to provide a complete and
+conceptually clean half-precision implementation, to which the standard
+mathematical functions belong, even if usually not needed.
+
+IEEE CONFORMANCE
+
+The half type uses the standard IEEE representation with 1 sign bit, 5 exponent
+bits and 10 mantissa bits (11 when counting the hidden bit). It supports all
+types of special values, like subnormal values, infinity and NaNs. But there
+are some limitations to the complete conformance to the IEEE 754 standard:
+
+ - The implementation does not differentiate between signalling and quiet
+ NaNs, this means operations on halfs are not specified to trap on
+ signalling NaNs (though they may, see last point).
+
+ - Though arithmetic operations are internally rounded to single-precision
+ using the underlying single-precision implementation's current rounding
+ mode, those values are then converted to half-precision using the default
+ half-precision rounding mode (changed by defining 'HALF_ROUND_STYLE'
+ accordingly). This mixture of rounding modes is also the reason why
+ 'std::numeric_limits<half>::round_style' may actually return
+ 'std::round_indeterminate' when half- and single-precision rounding modes
+ don't match.
+
+ - Because of internal truncation it may also be that certain single-precision
+ NaNs will be wrongly converted to half-precision infinity, though this is
+ very unlikely to happen, since most single-precision implementations don't
+ tend to only set the lowest bits of a NaN mantissa.
+
+ - The implementation does not provide any floating point exceptions, thus
+ arithmetic operations or mathematical functions are not specified to invoke
+ proper floating point exceptions. But due to many functions implemented in
+ single-precision, those may still invoke floating point exceptions of the
+ underlying single-precision implementation.
+
+Some of those points could have been circumvented by controlling the floating
+point environment using <cfenv> or implementing a similar exception mechanism.
+But this would have required excessive runtime checks giving too high an impact
+on performance for something that is rarely ever needed. If you really need to
+rely on proper floating point exceptions, it is recommended to explicitly
+perform computations using the built-in floating point types to be on the safe
+side. In the same way, if you really need to rely on a particular rounding
+behaviour, it is recommended to either use single-precision computations and
+explicitly convert the result to half-precision using 'half_cast' and
+specifying the desired rounding mode, or synchronize the default half-precision
+rounding mode to the rounding mode of the single-precision implementation (most
+likely 'HALF_ROUND_STYLE=1', 'HALF_ROUND_TIES_TO_EVEN=1'). But this is really
+considered an expert-scenario that should be used only when necessary, since
+actually working with half-precision usually comes with a certain
+tolerance/ignorance of exactness considerations and proper rounding comes with
+a certain performance cost.
+
+
+CREDITS AND CONTACT
+-------------------
+
+This library is developed by CHRISTIAN RAU and released under the MIT License
+(see LICENSE.txt). If you have any questions or problems with it, feel free to
+contact me at rauy@users.sourceforge.net.
+
+Additional credit goes to JEROEN VAN DER ZIJP for his paper on "Fast Half Float
+Conversions", whose algorithms have been used in the library for converting
+between half-precision and single-precision values.
diff --git a/third_party/half/UPDATING.md b/third_party/half/UPDATING.md
new file mode 100644
index 0000000..3a8d2fe
--- /dev/null
+++ b/third_party/half/UPDATING.md
@@ -0,0 +1,8 @@
+This project is from: https://sourceforge.net/p/half/code/HEAD/tree/
+
+This can be updated by running:
+```shell
+curl http://svn.code.sf.net/p/half/code/trunk/include/half.hpp > half.hpp
+curl http://svn.code.sf.net/p/half/code/trunk/README.txt > README.txt
+curl http://svn.code.sf.net/p/half/code/trunk/LICENSE.txt > LICENSE.txt
+```
\ No newline at end of file
diff --git a/third_party/half/half.hpp b/third_party/half/half.hpp
new file mode 100644
index 0000000..e9d1aca
--- /dev/null
+++ b/third_party/half/half.hpp
@@ -0,0 +1,4023 @@
+// half - IEEE 754-based half-precision floating point library.
+//
+// Copyright (c) 2012-2017 Christian Rau <rauy@users.sourceforge.net>
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in
+// all copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+// SOFTWARE.
+
+// Version 1.12.0
+
+/// \file
+/// Main header file for half precision functionality.
+
+#ifndef HALF_HALF_HPP
+#define HALF_HALF_HPP
+
+/// Combined gcc version number (major*100 + minor, e.g. 406 for gcc 4.6).
+#define HALF_GNUC_VERSION (__GNUC__*100+__GNUC_MINOR__)
+
+// Check C++11 language features. Every HALF_ENABLE_CPP11_* switch is only set here if it is not already defined, so a user definition made before inclusion takes precedence.
+#if defined(__clang__) //clang
+ #if __has_feature(cxx_static_assert) && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
+ #define HALF_ENABLE_CPP11_STATIC_ASSERT 1
+ #endif
+ #if __has_feature(cxx_constexpr) && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
+ #define HALF_ENABLE_CPP11_CONSTEXPR 1
+ #endif
+ // Do not use NOEXCEPT. Not approved for google3 yet.
+ // #if __has_feature(cxx_noexcept) && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
+ #if 0 // local modification: the original condition is kept in the comment above
+ #define HALF_ENABLE_CPP11_NOEXCEPT 1
+ #endif
+ // Do not use user literals. Not approved for google3 yet.
+ // #if __has_feature(cxx_user_literals) && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
+ #if 0 // local modification: the original condition is kept in the comment above
+ #define HALF_ENABLE_CPP11_USER_LITERALS 1
+ #endif
+ #if (defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L) && !defined(HALF_ENABLE_CPP11_LONG_LONG)
+ #define HALF_ENABLE_CPP11_LONG_LONG 1
+ #endif
+/*#elif defined(__INTEL_COMPILER)
+ //Intel C++ #if __INTEL_COMPILER >= 1100 &&
+ !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) ???????? #define
+ HALF_ENABLE_CPP11_STATIC_ASSERT 1 #endif #if __INTEL_COMPILER >= 1300 &&
+ !defined(HALF_ENABLE_CPP11_CONSTEXPR) ???????? #define
+ HALF_ENABLE_CPP11_CONSTEXPR 1 #endif #if __INTEL_COMPILER >= 1300 &&
+ !defined(HALF_ENABLE_CPP11_NOEXCEPT) ???????? #define
+ HALF_ENABLE_CPP11_NOEXCEPT 1 #endif #if __INTEL_COMPILER >= 1100 &&
+ !defined(HALF_ENABLE_CPP11_LONG_LONG) ???????? #define
+ HALF_ENABLE_CPP11_LONG_LONG 1 #endif*/
+#elif defined(__GNUC__) //gcc
+ #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L
+ #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
+ #define HALF_ENABLE_CPP11_STATIC_ASSERT 1
+ #endif
+ #if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
+ #define HALF_ENABLE_CPP11_CONSTEXPR 1
+ #endif
+ #if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
+ #define HALF_ENABLE_CPP11_NOEXCEPT 1
+ #endif
+ #if HALF_GNUC_VERSION >= 407 && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
+ #define HALF_ENABLE_CPP11_USER_LITERALS 1
+ #endif
+ #if !defined(HALF_ENABLE_CPP11_LONG_LONG)
+ #define HALF_ENABLE_CPP11_LONG_LONG 1
+ #endif
+ #endif // gcc in C++11 (or C++0x) mode
+#elif defined(_MSC_VER) //Visual C++
+#if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
+#define HALF_ENABLE_CPP11_CONSTEXPR 1
+#endif
+#if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
+#define HALF_ENABLE_CPP11_NOEXCEPT 1
+#endif
+#if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
+#define HALF_ENABLE_CPP11_USER_LITERALS 1
+#endif
+#if _MSC_VER >= 1600 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
+#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
+#endif
+#if _MSC_VER >= 1310 && !defined(HALF_ENABLE_CPP11_LONG_LONG)
+#define HALF_ENABLE_CPP11_LONG_LONG 1
+#endif
+#define HALF_POP_WARNINGS 1
+#pragma warning(push)
+#pragma warning(disable : 4099 4127 4146) // struct vs class, constant in if,
+ // negative unsigned. NOTE(review): the matching pop is presumably keyed on HALF_POP_WARNINGS; it is not visible in this part of the header.
+#endif // compiler-specific language feature detection
+
+// Check C++11 library features. As above, each switch is only set if the user has not defined it already.
+#include <utility> // NOTE(review): presumably included so the library identification macros tested below are defined - confirm
+#if defined(_LIBCPP_VERSION) //libc++
+ #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103
+#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS
+#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
+#endif
+#ifndef HALF_ENABLE_CPP11_CSTDINT
+#define HALF_ENABLE_CPP11_CSTDINT 1
+#endif
+#ifndef HALF_ENABLE_CPP11_CMATH
+#define HALF_ENABLE_CPP11_CMATH 1
+#endif
+#ifndef HALF_ENABLE_CPP11_HASH
+#define HALF_ENABLE_CPP11_HASH 1
+#endif
+#endif // libc++ in C++11 (or C++0x) mode
+#elif defined(__GLIBCXX__) //libstdc++
+ #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103
+ #ifdef __clang__ // clang on top of libstdc++: decide by library release date
+#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS)
+#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
+#endif
+#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CSTDINT)
+#define HALF_ENABLE_CPP11_CSTDINT 1
+#endif
+#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CMATH)
+#define HALF_ENABLE_CPP11_CMATH 1
+#endif
+#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_HASH)
+#define HALF_ENABLE_CPP11_HASH 1
+#endif
+#else // not clang: decide by gcc version; note HALF_ENABLE_CPP11_TYPE_TRAITS is not enabled on this branch
+#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CSTDINT)
+#define HALF_ENABLE_CPP11_CSTDINT 1
+#endif
+#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CMATH)
+#define HALF_ENABLE_CPP11_CMATH 1
+#endif
+#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_HASH)
+#define HALF_ENABLE_CPP11_HASH 1
+#endif
+#endif // __clang__
+#endif // libstdc++ in C++11 (or C++0x) mode
+#elif defined(_CPPLIB_VER) //Dinkumware/Visual C++
+ #if _CPPLIB_VER >= 520
+#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS
+#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
+#endif
+#ifndef HALF_ENABLE_CPP11_CSTDINT
+#define HALF_ENABLE_CPP11_CSTDINT 1
+#endif
+#ifndef HALF_ENABLE_CPP11_HASH
+#define HALF_ENABLE_CPP11_HASH 1
+#endif
+#endif // _CPPLIB_VER >= 520
+#if _CPPLIB_VER >= 610
+#ifndef HALF_ENABLE_CPP11_CMATH
+#define HALF_ENABLE_CPP11_CMATH 1
+#endif
+#endif // _CPPLIB_VER >= 610
+#endif // standard library detection
+#undef HALF_GNUC_VERSION // detection helper only; not available past this point
+
+// Support constexpr: HALF_CONSTEXPR expands to `constexpr` or to nothing, HALF_CONSTEXPR_CONST to `constexpr` or plain `const`.
+#if HALF_ENABLE_CPP11_CONSTEXPR
+ #define HALF_CONSTEXPR constexpr
+ #define HALF_CONSTEXPR_CONST constexpr
+#else
+ #define HALF_CONSTEXPR
+ #define HALF_CONSTEXPR_CONST const
+#endif
+
+// Support noexcept: without it HALF_NOEXCEPT expands to nothing and HALF_NOTHROW falls back to the dynamic exception specification `throw()`.
+#if HALF_ENABLE_CPP11_NOEXCEPT
+ #define HALF_NOEXCEPT noexcept
+ #define HALF_NOTHROW noexcept
+#else // NOTE(review): `throw()` was deprecated in C++11 and removed in C++20; this branch is always taken for clang because noexcept detection is disabled above
+ #define HALF_NOEXCEPT
+ #define HALF_NOTHROW throw()
+#endif
+
+#include <algorithm>
+#include <iostream>
+#include <limits>
+#include <climits>
+#include <cmath>
+#include <cstring>
+#if HALF_ENABLE_CPP11_TYPE_TRAITS
+#include <type_traits>
+#endif
+#if HALF_ENABLE_CPP11_CSTDINT
+ #include <cstdint>
+#endif
+#if HALF_ENABLE_CPP11_HASH
+#include <functional>
+#endif
+
+/// Default rounding mode.
+/// This specifies the rounding mode used for all conversions between
+/// [half](\ref half_float::half)s and `float`s as well as for the half_cast()
+/// if not specifying a rounding mode explicitly. It can be redefined (before
+/// including half.hpp) to one of the standard rounding modes using their
+/// respective constants or the equivalent values of `std::float_round_style`:
+///
+/// `std::float_round_style` | value | rounding
+/// ---------------------------------|-------|-------------------------
+/// `std::round_indeterminate` | -1 | fastest (default)
+/// `std::round_toward_zero` | 0 | toward zero
+/// `std::round_to_nearest` | 1 | to nearest
+/// `std::round_toward_infinity` | 2 | toward positive infinity
+/// `std::round_toward_neg_infinity` | 3 | toward negative infinity
+///
+/// By default this is set to `-1` (`std::round_indeterminate`), which uses
+/// truncation (round toward zero, but with overflows set to infinity) and is
+/// the fastest rounding mode possible. It can even be set to
+/// `std::numeric_limits<float>::round_style` to synchronize the rounding mode
+/// with that of the underlying single-precision implementation.
+#ifndef HALF_ROUND_STYLE
+#define HALF_ROUND_STYLE -1 // = std::round_indeterminate
+#endif
+
+/// Tie-breaking behaviour for round to nearest.
+/// This specifies if ties in round to nearest should be resolved by rounding to
+/// the nearest even value. By default this is defined to `0` resulting in the
+/// faster but slightly more biased behaviour of rounding away from zero in
+/// half-way cases (and thus equal to the round() function), but can be
+/// redefined to `1` (before including half.hpp) if more IEEE-conformant
+/// behaviour is needed.
+#ifndef HALF_ROUND_TIES_TO_EVEN
+#define HALF_ROUND_TIES_TO_EVEN 0 // ties away from zero
+#endif
+
+/// Value signaling overflow.
+/// In correspondence with `HUGE_VAL[F|L]` from `<cmath>` this symbol expands to
+/// a positive value signaling the overflow of an operation, in particular it
+/// just evaluates to positive infinity.
+#define HUGE_VALH std::numeric_limits<half_float::half>::infinity()
+
+/// Fast half-precision fma function.
+/// This symbol is only defined if the fma() function generally executes as fast
+/// as, or faster than, a separate half-precision multiplication followed by an
+/// addition. Due to the internal single-precision implementation of all
+/// arithmetic operations, this is in fact always the case.
+#define FP_FAST_FMAH 1
+
+#ifndef FP_ILOGB0 // fallbacks below apply only when no earlier header defined the macro
+ #define FP_ILOGB0 INT_MIN
+#endif
+#ifndef FP_ILOGBNAN
+ #define FP_ILOGBNAN INT_MAX
+#endif
+#ifndef FP_SUBNORMAL // fallback floating point category codes 0..4 follow, each guarded individually
+ #define FP_SUBNORMAL 0
+#endif
+#ifndef FP_ZERO
+ #define FP_ZERO 1
+#endif
+#ifndef FP_NAN
+ #define FP_NAN 2
+#endif
+#ifndef FP_INFINITE
+ #define FP_INFINITE 3
+#endif
+#ifndef FP_NORMAL
+ #define FP_NORMAL 4
+#endif
+
+
+/// Main namespace for half precision functionality.
+/// This namespace contains all the functionality provided by the library.
+namespace half_float
+{
+ class half;
+
+#if HALF_ENABLE_CPP11_USER_LITERALS
+ /// Library-defined half-precision literals.
+ /// Import this namespace to enable half-precision floating point
+ /// literals:
+ /// ~~~~{.cpp}
+ /// using namespace half_float::literal;
+ /// half_float::half h = 4.2_h;
+ /// ~~~~
+ namespace literal {
+ half operator""_h(long double);
+ }
+#endif
+
+ /// \internal
+ /// \brief Implementation details.
+ namespace detail
+ {
+#if HALF_ENABLE_CPP11_TYPE_TRAITS // thin wrappers over the <type_traits> facilities
+ /// Conditional type: forwards to std::conditional.
+ template <bool B, typename T, typename F>
+ struct conditional : std::conditional<B, T, F> {};
+
+ /// Helper for tag dispatching: a compile-time boolean usable as an overload tag.
+ template <bool B>
+ struct bool_type : std::integral_constant<bool, B> {};
+ using std::false_type;
+ using std::true_type;
+
+ /// Type traits for floating point types: forwards to std::is_floating_point.
+ template <typename T>
+ struct is_float : std::is_floating_point<T> {};
+#else // hand-written replacements for builds without <type_traits>
+ /// Conditional type: `type` is \a T by default ...
+ template <bool, typename T, typename>
+ struct conditional {
+ typedef T type;
+ };
+ template <typename T, typename F>
+ struct conditional<false, T, F> { // ... and \a F when the condition is false
+ typedef F type;
+ };
+
+ /// Helper for tag dispatching: an empty tag type per boolean value.
+ template <bool>
+ struct bool_type {};
+ typedef bool_type<true> true_type;
+ typedef bool_type<false> false_type;
+
+ /// Type traits for floating point types: false unless listed below; cv-qualifiers are ignored.
+ template <typename>
+ struct is_float : false_type {};
+ template <typename T>
+ struct is_float<const T> : is_float<T> {};
+ template <typename T>
+ struct is_float<volatile T> : is_float<T> {};
+ template <typename T>
+ struct is_float<const volatile T> : is_float<T> {};
+ template <>
+ struct is_float<float> : true_type {};
+ template <>
+ struct is_float<double> : true_type {};
+ template <>
+ struct is_float<long double> : true_type {};
+#endif // HALF_ENABLE_CPP11_TYPE_TRAITS
+
+ /// Type traits for floating point bits: maps a type to the unsigned integer type named by its `type` member.
+ template <typename T>
+ struct bits {
+ typedef unsigned char type; // generic fallback for types without a specialization
+ };
+ template <typename T>
+ struct bits<const T> : bits<T> {}; // cv-qualified types defer to the unqualified type
+ template <typename T>
+ struct bits<volatile T> : bits<T> {};
+ template <typename T>
+ struct bits<const volatile T> : bits<T> {};
+
+#if HALF_ENABLE_CPP11_CSTDINT // use the <cstdint> least-width types directly
+ /// Unsigned integer of (at least) 16 bits width.
+ typedef std::uint_least16_t uint16;
+
+ /// Unsigned integer of (at least) 32 bits width.
+ template <>
+ struct bits<float> {
+ typedef std::uint_least32_t type;
+ };
+
+ /// Unsigned integer of (at least) 64 bits width.
+ template <>
+ struct bits<double> {
+ typedef std::uint_least64_t type;
+ };
+#else // no <cstdint>: pick suitable builtin types by their digit count
+ /// Unsigned integer of (at least) 16 bits width.
+ typedef unsigned short uint16;
+
+ /// Unsigned integer of (at least) 32 bits width: `unsigned int` if wide enough, else `unsigned long`.
+ template <>
+ struct bits<float>
+ : conditional<std::numeric_limits<unsigned int>::digits >= 32,
+ unsigned int, unsigned long> {};
+
+#if HALF_ENABLE_CPP11_LONG_LONG
+ /// Unsigned integer of (at least) 64 bits width: `unsigned long` if wide enough, else `unsigned long long`.
+ template <>
+ struct bits<double>
+ : conditional<std::numeric_limits<unsigned long>::digits >= 64,
+ unsigned long, unsigned long long> {};
+#else
+ /// Unsigned integer for double bits. NOTE(review): `unsigned long` is used unconditionally here and is only guaranteed 32 bits wide, so "at least 64 bits" may not hold.
+ template <>
+ struct bits<double> {
+ typedef unsigned long type;
+ };
+#endif // HALF_ENABLE_CPP11_LONG_LONG
+#endif // HALF_ENABLE_CPP11_CSTDINT
+
+ /// Tag type for binary construction (empty; carries no data).
+ struct binary_t {};
+
+ /// Tag for binary construction: a constant binary_t value (constexpr where supported, const otherwise).
+ HALF_CONSTEXPR_CONST binary_t binary = binary_t();
+
+ /// Temporary half-precision expression.
+ /// This class represents a half-precision expression which just stores
+ /// a single-precision value internally (its only data member is a float).
+ struct expr {
+ /// Conversion constructor (explicit, so a plain float never converts to expr implicitly).
+ /// \param f single-precision value to convert
+ explicit HALF_CONSTEXPR expr(float f) HALF_NOEXCEPT : value_(f) {}
+
+ /// Conversion to single-precision (implicit; returns the stored value unchanged).
+ /// \return single precision value representing expression value
+ HALF_CONSTEXPR operator float() const HALF_NOEXCEPT { return value_; }
+
+ private:
+ /// Internal expression value stored in single-precision.
+ float value_;
+ };
+
+ /// SFINAE helper for generic half-precision functions.
+ /// This class template has to be specialized for each valid
+ /// combination of argument types to provide a corresponding
+ /// `type` member equivalent to \a T. The primary template has no `type`.
+ /// \tparam T type to return
+ template <typename T, typename, typename = void, typename = void>
+ struct enable {};
+ template <typename T>
+ struct enable<T, half, void, void> {
+ typedef T type; }; // one argument: half or expr
+ template<typename T> struct enable<T,expr,void,void> { typedef T type; };
+ template<typename T> struct enable<T,half,half,void> { typedef T type; }; // two arguments: every half/expr pair
+ template<typename T> struct enable<T,half,expr,void> { typedef T type; };
+ template<typename T> struct enable<T,expr,half,void> { typedef T type; };
+ template<typename T> struct enable<T,expr,expr,void> { typedef T type; };
+ template<typename T> struct enable<T,half,half,half> { typedef T type; }; // three arguments: every half/expr triple
+ template<typename T> struct enable<T,half,half,expr> { typedef T type; };
+ template<typename T> struct enable<T,half,expr,half> { typedef T type; };
+ template<typename T> struct enable<T,half,expr,expr> { typedef T type; };
+ template<typename T> struct enable<T,expr,half,half> { typedef T type; };
+ template<typename T> struct enable<T,expr,half,expr> { typedef T type; };
+ template<typename T> struct enable<T,expr,expr,half> { typedef T type; };
+ template<typename T> struct enable<T,expr,expr,expr> { typedef T type; };
+
+ /// Return type for specialized generic 2-argument half-precision
+ /// functions. The primary template yields `expr` for every valid
+ /// half/expr argument combination (via enable) and has no `type`
+ /// member otherwise; two plain `half` arguments yield `half`.
+ /// \tparam T first argument type
+ /// \tparam U second argument type
+ template <typename T, typename U>
+ struct result : enable<expr, T, U> {};
+ template<> struct result<half,half> { typedef half type; };
+
+ /// \name Classification helpers
+ /// \{
+
+ /// Check for infinity (of either sign), using the best facility available at compile time.
+ /// \tparam T argument type (builtin floating point type)
+ /// \param arg value to query
+ /// \retval true if infinity
+ /// \retval false else
+ template<typename T> bool builtin_isinf(T arg)
+ {
+ #if HALF_ENABLE_CPP11_CMATH // C++11 <cmath> classification function
+ return std::isinf(arg);
+ #elif defined(_MSC_VER) // MSVC CRT: infinite means neither finite nor NaN (checked as double)
+ return !::_finite(static_cast<double>(arg)) &&
+ !::_isnan(static_cast<double>(arg));
+#else // portable fallback: compare against both infinities
+ return arg == std::numeric_limits<T>::infinity() ||
+ arg == -std::numeric_limits<T>::infinity();
+#endif
+ }
+
+ /// Check for NaN, using the best facility available at compile time.
+ /// \tparam T argument type (builtin floating point type)
+ /// \param arg value to query
+ /// \retval true if not a number
+ /// \retval false else
+ template<typename T> bool builtin_isnan(T arg)
+ {
+ #if HALF_ENABLE_CPP11_CMATH // C++11 <cmath> classification function
+ return std::isnan(arg);
+ #elif defined(_MSC_VER) // MSVC CRT: _isnan on the value widened to double
+ return ::_isnan(static_cast<double>(arg)) != 0;
+#else // portable fallback: relies on NaN comparing unequal to itself
+ return arg != arg;
+#endif
+ }
+
+ /// Check sign.
+ /// \tparam T argument type (builtin floating point type)
+ /// \param arg value to query
+ /// \retval true if signbit set
+ /// \retval false else
+ template<typename T> bool builtin_signbit(T arg)
+ {
+ #if HALF_ENABLE_CPP11_CMATH // C++11 <cmath> function
+ return std::signbit(arg);
+ #else // fallback: negative value, or a zero whose reciprocal is negative (i.e. -0); always false for NaN since both comparisons fail
+ return arg < T() || (arg == T() && T(1) / arg < T());
+#endif
+ }
+
+ /// \}
+ /// \name Conversion
+ /// \{
+
+ /// Convert IEEE single-precision to half-precision.
+ /// Credit for this goes to [Jeroen van der
+ /// Zijp](ftp://ftp.fox-toolkit.org/pub/fasthalffloatconversion.pdf).
+ /// \tparam R rounding mode to use, `std::round_indeterminate`
+ /// for fastest rounding \param value single-precision value
+ /// \return binary representation of half-precision value
+ template <std::float_round_style R>
+ uint16 float2half_impl(float value, true_type) {
+ typedef bits<float>::type uint32;
+ uint32 bits; // = *reinterpret_cast<uint32*>(&value);
+ // //violating strict aliasing!
+ std::memcpy(&bits, &value, sizeof(float));
+ /* uint16 hbits = (bits>>16) & 0x8000;
+ bits &= 0x7FFFFFFF;
+ int exp = bits >> 23;
+ if(exp == 255)
+ return hbits | 0x7C00 |
+ (0x3FF&-static_cast<unsigned>((bits&0x7FFFFF)!=0)); if(exp
+ > 142)
+ {
+ if(R ==
+ std::round_toward_infinity) return hbits | 0x7C00 -
+ (hbits>>15); if(R == std::round_toward_neg_infinity) return
+ hbits | 0x7BFF + (hbits>>15); return hbits | 0x7BFF +
+ (R!=std::round_toward_zero);
+ }
+ int g, s;
+ if(exp > 112)
+ {
+ g = (bits>>12) & 1;
+ s = (bits&0xFFF) != 0;
+ hbits |= ((exp-112)<<10) |
+ ((bits>>13)&0x3FF);
+ }
+ else if(exp > 101)
+ {
+ int i = 125 - exp;
+ bits = (bits&0x7FFFFF) |
+ 0x800000; g = (bits>>i) & 1; s = (bits&((1L<<i)-1)) != 0;
+ hbits |= bits >> (i+1);
+ }
+ else
+ {
+ g = 0;
+ s = bits != 0;
+ }
+ if(R == std::round_to_nearest)
+ #if HALF_ROUND_TIES_TO_EVEN
+ hbits += g &
+ (s|hbits); #else hbits += g; #endif else if(R ==
+ std::round_toward_infinity) hbits += ~(hbits>>15) & (s|g);
+ else if(R ==
+ std::round_toward_neg_infinity) hbits += (hbits>>15) &
+ (g|s);
+ */
+ static const uint16 base_table[512] = {
+ 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000,
+ 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000,
+ 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000,
+ 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000,
+ 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000,
+ 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000,
+ 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000,
+ 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000,
+ 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000,
+ 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000,
+ 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000,
+ 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000,
+ 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000,
+ 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000,
+ 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0001, 0x0002,
+ 0x0004, 0x0008, 0x0010, 0x0020, 0x0040, 0x0080, 0x0100,
+ 0x0200, 0x0400, 0x0800, 0x0C00, 0x1000, 0x1400, 0x1800,
+ 0x1C00, 0x2000, 0x2400, 0x2800, 0x2C00, 0x3000, 0x3400,
+ 0x3800, 0x3C00, 0x4000, 0x4400, 0x4800, 0x4C00, 0x5000,
+ 0x5400, 0x5800, 0x5C00, 0x6000, 0x6400, 0x6800, 0x6C00,
+ 0x7000, 0x7400, 0x7800, 0x7C00, 0x7C00, 0x7C00, 0x7C00,
+ 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00,
+ 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00,
+ 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00,
+ 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00,
+ 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00,
+ 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00,
+ 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00,
+ 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00,
+ 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00,
+ 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00,
+ 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00,
+ 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00,
+ 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00,
+ 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00,
+ 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00,
+ 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x8000, 0x8000, 0x8000,
+ 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000,
+ 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000,
+ 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000,
+ 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000,
+ 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000,
+ 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000,
+ 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000,
+ 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000,
+ 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000,
+ 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000,
+ 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000,
+ 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000,
+ 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000,
+ 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000,
+ 0x8000, 0x8000, 0x8001, 0x8002, 0x8004, 0x8008, 0x8010,
+ 0x8020, 0x8040, 0x8080, 0x8100, 0x8200, 0x8400, 0x8800,
+ 0x8C00, 0x9000, 0x9400, 0x9800, 0x9C00, 0xA000, 0xA400,
+ 0xA800, 0xAC00, 0xB000, 0xB400, 0xB800, 0xBC00, 0xC000,
+ 0xC400, 0xC800, 0xCC00, 0xD000, 0xD400, 0xD800, 0xDC00,
+ 0xE000, 0xE400, 0xE800, 0xEC00, 0xF000, 0xF400, 0xF800,
+ 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00,
+ 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00,
+ 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00,
+ 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00,
+ 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00,
+ 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00,
+ 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00,
+ 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00,
+ 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00,
+ 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00,
+ 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00,
+ 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00,
+ 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00,
+ 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00,
+ 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00,
+ 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00,
+ 0xFC00};
+ static const unsigned char shift_table[512] = {
+ 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
+ 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
+ 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
+ 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
+ 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
+ 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
+ 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
+ 24, 24, 24, 24, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15,
+ 14, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13,
+ 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13,
+ 13, 13, 13, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
+ 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
+ 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
+ 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
+ 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
+ 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
+ 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
+ 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
+ 24, 24, 24, 13, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
+ 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
+ 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
+ 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
+ 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
+ 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
+ 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
+ 24, 24, 24, 24, 24, 24, 24, 24, 24, 23, 22, 21, 20, 19,
+ 18, 17, 16, 15, 14, 13, 13, 13, 13, 13, 13, 13, 13, 13,
+ 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13,
+ 13, 13, 13, 13, 13, 13, 13, 24, 24, 24, 24, 24, 24, 24,
+ 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
+ 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
+ 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
+ 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
+ 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
+ 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
+ 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
+ 24, 24, 24, 24, 24, 24, 24, 13};
+ uint16 hbits = base_table[bits >> 23] +
+ static_cast<uint16>((bits & 0x7FFFFF) >>
+ shift_table[bits >> 23]);
+ if (R == std::round_to_nearest)
+ hbits +=
+ (((bits & 0x7FFFFF) >> (shift_table[bits >> 23] - 1)) |
+ (((bits >> 23) & 0xFF) == 102)) &
+ ((hbits & 0x7C00) != 0x7C00)
+#if HALF_ROUND_TIES_TO_EVEN
+ & (((((static_cast<uint32>(1)
+ << (shift_table[bits >> 23] - 1)) -
+ 1) &
+ bits) != 0) |
+ hbits)
+#endif
+ ;
+ else if (R == std::round_toward_zero)
+ hbits -=
+ ((hbits & 0x7FFF) == 0x7C00) & ~shift_table[bits >> 23];
+ else if (R == std::round_toward_infinity)
+ hbits += ((((bits & 0x7FFFFF &
+ ((static_cast<uint32>(1)
+ << (shift_table[bits >> 23])) -
+ 1)) != 0) |
+ (((bits >> 23) <= 102) & ((bits >> 23) != 0))) &
+ (hbits < 0x7C00)) -
+ ((hbits == 0xFC00) & ((bits >> 23) != 511));
+ else if (R == std::round_toward_neg_infinity)
+ hbits +=
+ ((((bits & 0x7FFFFF &
+ ((static_cast<uint32>(1)
+ << (shift_table[bits >> 23])) -
+ 1)) != 0) |
+ (((bits >> 23) <= 358) & ((bits >> 23) != 256))) &
+ (hbits < 0xFC00) & (hbits >> 15)) -
+ ((hbits == 0x7C00) & ((bits >> 23) != 255));
+ return hbits;
+ }
+
+ /// Convert IEEE double-precision to half-precision by direct bit manipulation.
+ /// \tparam R rounding mode to use, `std::round_indeterminate` for fastest rounding
+ /// \param value double-precision value
+ /// \return binary representation of half-precision value
+ template <std::float_round_style R>
+ uint16 float2half_impl(double value, true_type) {
+ typedef bits<float>::type uint32;
+ typedef bits<double>::type uint64;
+ uint64 bits; // filled by the memcpy below instead of *reinterpret_cast<uint64*>(&value),
+ // because the cast would violate strict aliasing
+ std::memcpy(&bits, &value, sizeof(double));
+ uint32 hi = bits >> 32, lo = bits & 0xFFFFFFFF;  // upper / lower 32-bit words of the double
+ uint16 hbits = (hi >> 16) & 0x8000;  // sign: bit 31 of hi moved to bit 15
+ hi &= 0x7FFFFFFF;  // drop the sign bit
+ int exp = hi >> 20;  // biased 11-bit exponent
+ if (exp == 2047)  // all-ones exponent: infinity or NaN
+ return hbits | 0x7C00 |
+ (0x3FF & -static_cast<unsigned>(
+ (bits & 0xFFFFFFFFFFFFF) != 0));  // nonzero 52-bit mantissa sets all 0x3FF (NaN), else infinity
+ if (exp > 1038) {  // magnitude >= 2^16: overflow, result chosen by rounding mode and sign
+ if (R == std::round_toward_infinity)
+ return hbits | 0x7C00 - (hbits >> 15);  // '-' binds before '|': 0x7C00 if positive, 0x7BFF if negative
+ if (R == std::round_toward_neg_infinity)
+ return hbits | 0x7BFF + (hbits >> 15);  // 0x7BFF if positive, 0x7C00 if negative
+ return hbits | 0x7BFF + (R != std::round_toward_zero);  // 0x7BFF for toward-zero, 0x7C00 otherwise
+ }
+ int g, s = lo != 0;  // g: highest discarded mantissa bit; s: nonzero if any lower discarded bit is set
+ if (exp > 1008) {  // exponent field exp - 1008 is at least 1: normalized result
+ g = (hi >> 9) & 1;  // bit just below the 10 kept mantissa bits
+ s |= (hi & 0x1FF) != 0;
+ hbits |= ((exp - 1008) << 10) | ((hi >> 10) & 0x3FF);
+ } else if (exp > 997) {  // exponent field stays 0: subnormal result
+ int i = 1018 - exp;  // 10..20 here; bit i of hi is the first one dropped
+ hi = (hi & 0xFFFFF) | 0x100000;  // upper mantissa bits with the implicit leading 1 made explicit
+ g = (hi >> i) & 1;
+ s |= (hi & ((1L << i) - 1)) != 0;
+ hbits |= hi >> (i + 1);
+ } else {  // too small to contribute any mantissa bit
+ g = 0;
+ s |= hi != 0;  // s ends up nonzero exactly when the input is not zero
+ }
+ if (R == std::round_to_nearest)
+#if HALF_ROUND_TIES_TO_EVEN
+ hbits += g & (s | hbits);  // g is 0/1, so only bit 0 of (s | hbits) matters: exact ties round up only when the result is odd
+#else
+ hbits += g;  // ties round away from zero
+#endif
+ else if (R == std::round_toward_infinity)
+ hbits += ~(hbits >> 15) & (s | g);  // bump magnitude only for positive inputs with discarded bits
+ else if (R == std::round_toward_neg_infinity)
+ hbits += (hbits >> 15) & (g | s);  // bump magnitude only for negative inputs with discarded bits
+ return hbits;  // other rounding modes truncate
+ }
+
+ /// Convert non-IEEE floating point to half-precision (generic path using frexp/ldexp/modf).
+ /// \tparam R rounding mode to use, `std::round_indeterminate` for fastest rounding
+ /// \tparam T source type (builtin floating point type)
+ /// \param value floating point value
+ /// \return binary representation of half-precision value
+ template <std::float_round_style R, typename T>
+ uint16 float2half_impl(T value, ...) {
+ uint16 hbits = static_cast<unsigned>(builtin_signbit(value))
+ << 15;  // sign goes to bit 15
+ if (value == T()) return hbits;  // zero: only the sign bit is set
+ if (builtin_isnan(value)) return hbits | 0x7FFF;
+ if (builtin_isinf(value)) return hbits | 0x7C00;
+ int exp;
+ std::frexp(value, &exp);  // only exp is used: |value| lies in [2^(exp-1), 2^exp)
+ if (exp > 16) {  // |value| >= 2^16: overflow, result chosen by rounding mode and sign
+ if (R == std::round_toward_infinity)
+ return hbits | 0x7C00 - (hbits >> 15);  // '-' binds before '|': 0x7C00 if positive, 0x7BFF if negative
+ else if (R == std::round_toward_neg_infinity)
+ return hbits | 0x7BFF + (hbits >> 15);  // 0x7BFF if positive, 0x7C00 if negative
+ return hbits | 0x7BFF + (R != std::round_toward_zero);  // 0x7BFF for toward-zero, 0x7C00 otherwise
+ }
+ if (exp < -13)  // |value| < 2^-14: exponent field stays 0 (subnormal result)
+ value = std::ldexp(value, 24);  // integer part now counts units of 2^-24
+ else {
+ value = std::ldexp(value, 11 - exp);  // |value| now in [2^10, 2^11): integer part is 0x400 plus 10 mantissa bits
+ hbits |= ((exp + 13) << 10);  // one below the final exponent field; the 0x400 added below carries into it
+ }
+ T ival, frac = std::modf(value, &ival);  // split into integer part and discarded fraction (both keep value's sign)
+ hbits +=
+ static_cast<uint16>(std::abs(static_cast<int>(ival)));
+ if (R == std::round_to_nearest) {
+ frac = std::abs(frac);
+#if HALF_ROUND_TIES_TO_EVEN
+ hbits += (frac > T(0.5)) | ((frac == T(0.5)) & hbits);  // exact ties round up only when bit 0 of hbits is set
+#else
+ hbits += frac >= T(0.5);  // ties round away from zero
+#endif
+ } else if (R == std::round_toward_infinity)
+ hbits += frac > T();  // frac is positive only for positive inputs
+ else if (R == std::round_toward_neg_infinity)
+ hbits += frac < T();  // frac is negative only for negative inputs
+ return hbits;  // other rounding modes truncate
+ }
+
+ /// Convert floating point to half-precision; forwards to a float2half_impl overload.
+ /// \tparam R rounding mode to use, `std::round_indeterminate` for fastest rounding
+ /// \tparam T source type (builtin floating point type)
+ /// \param value floating point value
+ /// \return binary representation of half-precision value
+ template <std::float_round_style R, typename T>
+ uint16 float2half(T value) {
+ return float2half_impl<R>(  // overload is selected by the tag object passed as second argument
+ value,
+ bool_type < std::numeric_limits<T>::is_iec559 &&
+ sizeof(typename bits<T>::type) == sizeof(T) > ());  // tag: T is IEC 559 and bits<T>::type has T's size; presumably bool_type<true> is true_type - confirm
+ }
+
+ /// Convert integer to half-precision floating point.
+ /// \tparam R rounding mode to use, `std::round_indeterminate` for fastest rounding
+ /// \tparam S `true` if value negative, `false` else
+ /// \tparam T type to convert (builtin integer type)
+ /// \param value integral value; it is negated inside the function when S is `true`
+ /// \return binary representation of half-precision value
+ template <std::float_round_style R, bool S, typename T>
+ uint16 int2half_impl(T value) {
+#if HALF_ENABLE_CPP11_STATIC_ASSERT && HALF_ENABLE_CPP11_TYPE_TRAITS
+ static_assert(std::is_integral<T>::value,
+ "int to half conversion only supports builtin "
+ "integer types");
+#endif
+ if (S) value = -value;  // work on the magnitude. NOTE(review): the most negative value of a signed T has no positive counterpart in T - confirm callers never pass it
+ uint16 bits = S << 15;  // sign goes to bit 15
+ if (value > 0xFFFF) {  // magnitude needs more than 16 bits: overflow, result chosen by rounding mode and sign
+ if (R == std::round_toward_infinity)
+ bits |= 0x7C00 - S;  // 0x7C00 if positive, 0x7BFF if negative
+ else if (R == std::round_toward_neg_infinity)
+ bits |= 0x7BFF + S;  // 0x7BFF if positive, 0x7C00 if negative
+ else
+ bits |= 0x7BFF + (R != std::round_toward_zero);  // 0x7BFF for toward-zero, 0x7C00 otherwise
+ } else if (value) {  // zero falls through with only the sign bit set
+ unsigned int m = value, exp = 24;
+ for (; m < 0x400; m <<= 1, --exp)  // shift small values up until bit 0x400 is set
+ ;
+ for (; m > 0x7FF; m >>= 1, ++exp)  // shift large values down to 11 bits; exp - 24 counts the bits dropped
+ ;
+ bits |= (exp << 10) + m;  // '+' lets m's leading 0x400 bit carry into the exponent field
+ if (exp > 24) {  // low bits of value were dropped from m, so rounding applies
+ if (R == std::round_to_nearest)
+ bits +=
+ (value >> (exp - 25)) & 1  // highest dropped bit
+#if HALF_ROUND_TIES_TO_EVEN
+ & (((((1 << (exp - 25)) - 1) & value) != 0) | bits)  // exact ties round up only when bit 0 of bits is set
+#endif
+ ;
+ else if (R == std::round_toward_infinity)
+ bits += ((value & ((1 << (exp - 24)) - 1)) != 0) & !S;  // any dropped bit set and value positive
+ else if (R == std::round_toward_neg_infinity)
+ bits += ((value & ((1 << (exp - 24)) - 1)) != 0) & S;  // any dropped bit set and value negative
+ }
+ }
+ return bits;
+ }
+
+ /// Convert integer to half-precision floating point; forwards to int2half_impl.
+ /// \tparam R rounding mode to use, `std::round_indeterminate` for fastest rounding
+ /// \tparam T type to convert (builtin integer type)
+ /// \param value integral value
+ /// \return binary representation of half-precision value
+ template<std::float_round_style R,typename T> uint16 int2half(T value)
+ {
+ return (value < 0) ? int2half_impl<R, true>(value)  // sign test at run time picks the S = true instantiation
+ : int2half_impl<R, false>(value);
+ }
+
+ /// Convert half-precision to IEEE single-precision.
+ /// Credit for this goes to [Jeroen van der
+ /// Zijp](ftp://ftp.fox-toolkit.org/pub/fasthalffloatconversion.pdf).
+ /// \param value binary representation of half-precision value
+ /// \return single-precision value
+ inline float half2float_impl(uint16 value, float, true_type) {
+ typedef bits<float>::type uint32;
+ /* uint32 bits =
+ static_cast<uint32>(value&0x8000)
+ << 16; int abs = value & 0x7FFF; if(abs)
+ {
+ bits |= 0x38000000 <<
+ static_cast<unsigned>(abs>=0x7C00); for(; abs<0x400;
+ abs<<=1,bits-=0x800000) ; bits += static_cast<uint32>(abs)
+ << 13;
+ }
+ */
+ static const uint32 mantissa_table[2048] = {
+ 0x00000000, 0x33800000, 0x34000000, 0x34400000,
+ 0x34800000, 0x34A00000, 0x34C00000, 0x34E00000,
+ 0x35000000, 0x35100000, 0x35200000, 0x35300000,
+ 0x35400000, 0x35500000, 0x35600000, 0x35700000,
+ 0x35800000, 0x35880000, 0x35900000, 0x35980000,
+ 0x35A00000, 0x35A80000, 0x35B00000, 0x35B80000,
+ 0x35C00000, 0x35C80000, 0x35D00000, 0x35D80000,
+ 0x35E00000, 0x35E80000, 0x35F00000, 0x35F80000,
+ 0x36000000, 0x36040000, 0x36080000, 0x360C0000,
+ 0x36100000, 0x36140000, 0x36180000, 0x361C0000,
+ 0x36200000, 0x36240000, 0x36280000, 0x362C0000,
+ 0x36300000, 0x36340000, 0x36380000, 0x363C0000,
+ 0x36400000, 0x36440000, 0x36480000, 0x364C0000,
+ 0x36500000, 0x36540000, 0x36580000, 0x365C0000,
+ 0x36600000, 0x36640000, 0x36680000, 0x366C0000,
+ 0x36700000, 0x36740000, 0x36780000, 0x367C0000,
+ 0x36800000, 0x36820000, 0x36840000, 0x36860000,
+ 0x36880000, 0x368A0000, 0x368C0000, 0x368E0000,
+ 0x36900000, 0x36920000, 0x36940000, 0x36960000,
+ 0x36980000, 0x369A0000, 0x369C0000, 0x369E0000,
+ 0x36A00000, 0x36A20000, 0x36A40000, 0x36A60000,
+ 0x36A80000, 0x36AA0000, 0x36AC0000, 0x36AE0000,
+ 0x36B00000, 0x36B20000, 0x36B40000, 0x36B60000,
+ 0x36B80000, 0x36BA0000, 0x36BC0000, 0x36BE0000,
+ 0x36C00000, 0x36C20000, 0x36C40000, 0x36C60000,
+ 0x36C80000, 0x36CA0000, 0x36CC0000, 0x36CE0000,
+ 0x36D00000, 0x36D20000, 0x36D40000, 0x36D60000,
+ 0x36D80000, 0x36DA0000, 0x36DC0000, 0x36DE0000,
+ 0x36E00000, 0x36E20000, 0x36E40000, 0x36E60000,
+ 0x36E80000, 0x36EA0000, 0x36EC0000, 0x36EE0000,
+ 0x36F00000, 0x36F20000, 0x36F40000, 0x36F60000,
+ 0x36F80000, 0x36FA0000, 0x36FC0000, 0x36FE0000,
+ 0x37000000, 0x37010000, 0x37020000, 0x37030000,
+ 0x37040000, 0x37050000, 0x37060000, 0x37070000,
+ 0x37080000, 0x37090000, 0x370A0000, 0x370B0000,
+ 0x370C0000, 0x370D0000, 0x370E0000, 0x370F0000,
+ 0x37100000, 0x37110000, 0x37120000, 0x37130000,
+ 0x37140000, 0x37150000, 0x37160000, 0x37170000,
+ 0x37180000, 0x37190000, 0x371A0000, 0x371B0000,
+ 0x371C0000, 0x371D0000, 0x371E0000, 0x371F0000,
+ 0x37200000, 0x37210000, 0x37220000, 0x37230000,
+ 0x37240000, 0x37250000, 0x37260000, 0x37270000,
+ 0x37280000, 0x37290000, 0x372A0000, 0x372B0000,
+ 0x372C0000, 0x372D0000, 0x372E0000, 0x372F0000,
+ 0x37300000, 0x37310000, 0x37320000, 0x37330000,
+ 0x37340000, 0x37350000, 0x37360000, 0x37370000,
+ 0x37380000, 0x37390000, 0x373A0000, 0x373B0000,
+ 0x373C0000, 0x373D0000, 0x373E0000, 0x373F0000,
+ 0x37400000, 0x37410000, 0x37420000, 0x37430000,
+ 0x37440000, 0x37450000, 0x37460000, 0x37470000,
+ 0x37480000, 0x37490000, 0x374A0000, 0x374B0000,
+ 0x374C0000, 0x374D0000, 0x374E0000, 0x374F0000,
+ 0x37500000, 0x37510000, 0x37520000, 0x37530000,
+ 0x37540000, 0x37550000, 0x37560000, 0x37570000,
+ 0x37580000, 0x37590000, 0x375A0000, 0x375B0000,
+ 0x375C0000, 0x375D0000, 0x375E0000, 0x375F0000,
+ 0x37600000, 0x37610000, 0x37620000, 0x37630000,
+ 0x37640000, 0x37650000, 0x37660000, 0x37670000,
+ 0x37680000, 0x37690000, 0x376A0000, 0x376B0000,
+ 0x376C0000, 0x376D0000, 0x376E0000, 0x376F0000,
+ 0x37700000, 0x37710000, 0x37720000, 0x37730000,
+ 0x37740000, 0x37750000, 0x37760000, 0x37770000,
+ 0x37780000, 0x37790000, 0x377A0000, 0x377B0000,
+ 0x377C0000, 0x377D0000, 0x377E0000, 0x377F0000,
+ 0x37800000, 0x37808000, 0x37810000, 0x37818000,
+ 0x37820000, 0x37828000, 0x37830000, 0x37838000,
+ 0x37840000, 0x37848000, 0x37850000, 0x37858000,
+ 0x37860000, 0x37868000, 0x37870000, 0x37878000,
+ 0x37880000, 0x37888000, 0x37890000, 0x37898000,
+ 0x378A0000, 0x378A8000, 0x378B0000, 0x378B8000,
+ 0x378C0000, 0x378C8000, 0x378D0000, 0x378D8000,
+ 0x378E0000, 0x378E8000, 0x378F0000, 0x378F8000,
+ 0x37900000, 0x37908000, 0x37910000, 0x37918000,
+ 0x37920000, 0x37928000, 0x37930000, 0x37938000,
+ 0x37940000, 0x37948000, 0x37950000, 0x37958000,
+ 0x37960000, 0x37968000, 0x37970000, 0x37978000,
+ 0x37980000, 0x37988000, 0x37990000, 0x37998000,
+ 0x379A0000, 0x379A8000, 0x379B0000, 0x379B8000,
+ 0x379C0000, 0x379C8000, 0x379D0000, 0x379D8000,
+ 0x379E0000, 0x379E8000, 0x379F0000, 0x379F8000,
+ 0x37A00000, 0x37A08000, 0x37A10000, 0x37A18000,
+ 0x37A20000, 0x37A28000, 0x37A30000, 0x37A38000,
+ 0x37A40000, 0x37A48000, 0x37A50000, 0x37A58000,
+ 0x37A60000, 0x37A68000, 0x37A70000, 0x37A78000,
+ 0x37A80000, 0x37A88000, 0x37A90000, 0x37A98000,
+ 0x37AA0000, 0x37AA8000, 0x37AB0000, 0x37AB8000,
+ 0x37AC0000, 0x37AC8000, 0x37AD0000, 0x37AD8000,
+ 0x37AE0000, 0x37AE8000, 0x37AF0000, 0x37AF8000,
+ 0x37B00000, 0x37B08000, 0x37B10000, 0x37B18000,
+ 0x37B20000, 0x37B28000, 0x37B30000, 0x37B38000,
+ 0x37B40000, 0x37B48000, 0x37B50000, 0x37B58000,
+ 0x37B60000, 0x37B68000, 0x37B70000, 0x37B78000,
+ 0x37B80000, 0x37B88000, 0x37B90000, 0x37B98000,
+ 0x37BA0000, 0x37BA8000, 0x37BB0000, 0x37BB8000,
+ 0x37BC0000, 0x37BC8000, 0x37BD0000, 0x37BD8000,
+ 0x37BE0000, 0x37BE8000, 0x37BF0000, 0x37BF8000,
+ 0x37C00000, 0x37C08000, 0x37C10000, 0x37C18000,
+ 0x37C20000, 0x37C28000, 0x37C30000, 0x37C38000,
+ 0x37C40000, 0x37C48000, 0x37C50000, 0x37C58000,
+ 0x37C60000, 0x37C68000, 0x37C70000, 0x37C78000,
+ 0x37C80000, 0x37C88000, 0x37C90000, 0x37C98000,
+ 0x37CA0000, 0x37CA8000, 0x37CB0000, 0x37CB8000,
+ 0x37CC0000, 0x37CC8000, 0x37CD0000, 0x37CD8000,
+ 0x37CE0000, 0x37CE8000, 0x37CF0000, 0x37CF8000,
+ 0x37D00000, 0x37D08000, 0x37D10000, 0x37D18000,
+ 0x37D20000, 0x37D28000, 0x37D30000, 0x37D38000,
+ 0x37D40000, 0x37D48000, 0x37D50000, 0x37D58000,
+ 0x37D60000, 0x37D68000, 0x37D70000, 0x37D78000,
+ 0x37D80000, 0x37D88000, 0x37D90000, 0x37D98000,
+ 0x37DA0000, 0x37DA8000, 0x37DB0000, 0x37DB8000,
+ 0x37DC0000, 0x37DC8000, 0x37DD0000, 0x37DD8000,
+ 0x37DE0000, 0x37DE8000, 0x37DF0000, 0x37DF8000,
+ 0x37E00000, 0x37E08000, 0x37E10000, 0x37E18000,
+ 0x37E20000, 0x37E28000, 0x37E30000, 0x37E38000,
+ 0x37E40000, 0x37E48000, 0x37E50000, 0x37E58000,
+ 0x37E60000, 0x37E68000, 0x37E70000, 0x37E78000,
+ 0x37E80000, 0x37E88000, 0x37E90000, 0x37E98000,
+ 0x37EA0000, 0x37EA8000, 0x37EB0000, 0x37EB8000,
+ 0x37EC0000, 0x37EC8000, 0x37ED0000, 0x37ED8000,
+ 0x37EE0000, 0x37EE8000, 0x37EF0000, 0x37EF8000,
+ 0x37F00000, 0x37F08000, 0x37F10000, 0x37F18000,
+ 0x37F20000, 0x37F28000, 0x37F30000, 0x37F38000,
+ 0x37F40000, 0x37F48000, 0x37F50000, 0x37F58000,
+ 0x37F60000, 0x37F68000, 0x37F70000, 0x37F78000,
+ 0x37F80000, 0x37F88000, 0x37F90000, 0x37F98000,
+ 0x37FA0000, 0x37FA8000, 0x37FB0000, 0x37FB8000,
+ 0x37FC0000, 0x37FC8000, 0x37FD0000, 0x37FD8000,
+ 0x37FE0000, 0x37FE8000, 0x37FF0000, 0x37FF8000,
+ 0x38000000, 0x38004000, 0x38008000, 0x3800C000,
+ 0x38010000, 0x38014000, 0x38018000, 0x3801C000,
+ 0x38020000, 0x38024000, 0x38028000, 0x3802C000,
+ 0x38030000, 0x38034000, 0x38038000, 0x3803C000,
+ 0x38040000, 0x38044000, 0x38048000, 0x3804C000,
+ 0x38050000, 0x38054000, 0x38058000, 0x3805C000,
+ 0x38060000, 0x38064000, 0x38068000, 0x3806C000,
+ 0x38070000, 0x38074000, 0x38078000, 0x3807C000,
+ 0x38080000, 0x38084000, 0x38088000, 0x3808C000,
+ 0x38090000, 0x38094000, 0x38098000, 0x3809C000,
+ 0x380A0000, 0x380A4000, 0x380A8000, 0x380AC000,
+ 0x380B0000, 0x380B4000, 0x380B8000, 0x380BC000,
+ 0x380C0000, 0x380C4000, 0x380C8000, 0x380CC000,
+ 0x380D0000, 0x380D4000, 0x380D8000, 0x380DC000,
+ 0x380E0000, 0x380E4000, 0x380E8000, 0x380EC000,
+ 0x380F0000, 0x380F4000, 0x380F8000, 0x380FC000,
+ 0x38100000, 0x38104000, 0x38108000, 0x3810C000,
+ 0x38110000, 0x38114000, 0x38118000, 0x3811C000,
+ 0x38120000, 0x38124000, 0x38128000, 0x3812C000,
+ 0x38130000, 0x38134000, 0x38138000, 0x3813C000,
+ 0x38140000, 0x38144000, 0x38148000, 0x3814C000,
+ 0x38150000, 0x38154000, 0x38158000, 0x3815C000,
+ 0x38160000, 0x38164000, 0x38168000, 0x3816C000,
+ 0x38170000, 0x38174000, 0x38178000, 0x3817C000,
+ 0x38180000, 0x38184000, 0x38188000, 0x3818C000,
+ 0x38190000, 0x38194000, 0x38198000, 0x3819C000,
+ 0x381A0000, 0x381A4000, 0x381A8000, 0x381AC000,
+ 0x381B0000, 0x381B4000, 0x381B8000, 0x381BC000,
+ 0x381C0000, 0x381C4000, 0x381C8000, 0x381CC000,
+ 0x381D0000, 0x381D4000, 0x381D8000, 0x381DC000,
+ 0x381E0000, 0x381E4000, 0x381E8000, 0x381EC000,
+ 0x381F0000, 0x381F4000, 0x381F8000, 0x381FC000,
+ 0x38200000, 0x38204000, 0x38208000, 0x3820C000,
+ 0x38210000, 0x38214000, 0x38218000, 0x3821C000,
+ 0x38220000, 0x38224000, 0x38228000, 0x3822C000,
+ 0x38230000, 0x38234000, 0x38238000, 0x3823C000,
+ 0x38240000, 0x38244000, 0x38248000, 0x3824C000,
+ 0x38250000, 0x38254000, 0x38258000, 0x3825C000,
+ 0x38260000, 0x38264000, 0x38268000, 0x3826C000,
+ 0x38270000, 0x38274000, 0x38278000, 0x3827C000,
+ 0x38280000, 0x38284000, 0x38288000, 0x3828C000,
+ 0x38290000, 0x38294000, 0x38298000, 0x3829C000,
+ 0x382A0000, 0x382A4000, 0x382A8000, 0x382AC000,
+ 0x382B0000, 0x382B4000, 0x382B8000, 0x382BC000,
+ 0x382C0000, 0x382C4000, 0x382C8000, 0x382CC000,
+ 0x382D0000, 0x382D4000, 0x382D8000, 0x382DC000,
+ 0x382E0000, 0x382E4000, 0x382E8000, 0x382EC000,
+ 0x382F0000, 0x382F4000, 0x382F8000, 0x382FC000,
+ 0x38300000, 0x38304000, 0x38308000, 0x3830C000,
+ 0x38310000, 0x38314000, 0x38318000, 0x3831C000,
+ 0x38320000, 0x38324000, 0x38328000, 0x3832C000,
+ 0x38330000, 0x38334000, 0x38338000, 0x3833C000,
+ 0x38340000, 0x38344000, 0x38348000, 0x3834C000,
+ 0x38350000, 0x38354000, 0x38358000, 0x3835C000,
+ 0x38360000, 0x38364000, 0x38368000, 0x3836C000,
+ 0x38370000, 0x38374000, 0x38378000, 0x3837C000,
+ 0x38380000, 0x38384000, 0x38388000, 0x3838C000,
+ 0x38390000, 0x38394000, 0x38398000, 0x3839C000,
+ 0x383A0000, 0x383A4000, 0x383A8000, 0x383AC000,
+ 0x383B0000, 0x383B4000, 0x383B8000, 0x383BC000,
+ 0x383C0000, 0x383C4000, 0x383C8000, 0x383CC000,
+ 0x383D0000, 0x383D4000, 0x383D8000, 0x383DC000,
+ 0x383E0000, 0x383E4000, 0x383E8000, 0x383EC000,
+ 0x383F0000, 0x383F4000, 0x383F8000, 0x383FC000,
+ 0x38400000, 0x38404000, 0x38408000, 0x3840C000,
+ 0x38410000, 0x38414000, 0x38418000, 0x3841C000,
+ 0x38420000, 0x38424000, 0x38428000, 0x3842C000,
+ 0x38430000, 0x38434000, 0x38438000, 0x3843C000,
+ 0x38440000, 0x38444000, 0x38448000, 0x3844C000,
+ 0x38450000, 0x38454000, 0x38458000, 0x3845C000,
+ 0x38460000, 0x38464000, 0x38468000, 0x3846C000,
+ 0x38470000, 0x38474000, 0x38478000, 0x3847C000,
+ 0x38480000, 0x38484000, 0x38488000, 0x3848C000,
+ 0x38490000, 0x38494000, 0x38498000, 0x3849C000,
+ 0x384A0000, 0x384A4000, 0x384A8000, 0x384AC000,
+ 0x384B0000, 0x384B4000, 0x384B8000, 0x384BC000,
+ 0x384C0000, 0x384C4000, 0x384C8000, 0x384CC000,
+ 0x384D0000, 0x384D4000, 0x384D8000, 0x384DC000,
+ 0x384E0000, 0x384E4000, 0x384E8000, 0x384EC000,
+ 0x384F0000, 0x384F4000, 0x384F8000, 0x384FC000,
+ 0x38500000, 0x38504000, 0x38508000, 0x3850C000,
+ 0x38510000, 0x38514000, 0x38518000, 0x3851C000,
+ 0x38520000, 0x38524000, 0x38528000, 0x3852C000,
+ 0x38530000, 0x38534000, 0x38538000, 0x3853C000,
+ 0x38540000, 0x38544000, 0x38548000, 0x3854C000,
+ 0x38550000, 0x38554000, 0x38558000, 0x3855C000,
+ 0x38560000, 0x38564000, 0x38568000, 0x3856C000,
+ 0x38570000, 0x38574000, 0x38578000, 0x3857C000,
+ 0x38580000, 0x38584000, 0x38588000, 0x3858C000,
+ 0x38590000, 0x38594000, 0x38598000, 0x3859C000,
+ 0x385A0000, 0x385A4000, 0x385A8000, 0x385AC000,
+ 0x385B0000, 0x385B4000, 0x385B8000, 0x385BC000,
+ 0x385C0000, 0x385C4000, 0x385C8000, 0x385CC000,
+ 0x385D0000, 0x385D4000, 0x385D8000, 0x385DC000,
+ 0x385E0000, 0x385E4000, 0x385E8000, 0x385EC000,
+ 0x385F0000, 0x385F4000, 0x385F8000, 0x385FC000,
+ 0x38600000, 0x38604000, 0x38608000, 0x3860C000,
+ 0x38610000, 0x38614000, 0x38618000, 0x3861C000,
+ 0x38620000, 0x38624000, 0x38628000, 0x3862C000,
+ 0x38630000, 0x38634000, 0x38638000, 0x3863C000,
+ 0x38640000, 0x38644000, 0x38648000, 0x3864C000,
+ 0x38650000, 0x38654000, 0x38658000, 0x3865C000,
+ 0x38660000, 0x38664000, 0x38668000, 0x3866C000,
+ 0x38670000, 0x38674000, 0x38678000, 0x3867C000,
+ 0x38680000, 0x38684000, 0x38688000, 0x3868C000,
+ 0x38690000, 0x38694000, 0x38698000, 0x3869C000,
+ 0x386A0000, 0x386A4000, 0x386A8000, 0x386AC000,
+ 0x386B0000, 0x386B4000, 0x386B8000, 0x386BC000,
+ 0x386C0000, 0x386C4000, 0x386C8000, 0x386CC000,
+ 0x386D0000, 0x386D4000, 0x386D8000, 0x386DC000,
+ 0x386E0000, 0x386E4000, 0x386E8000, 0x386EC000,
+ 0x386F0000, 0x386F4000, 0x386F8000, 0x386FC000,
+ 0x38700000, 0x38704000, 0x38708000, 0x3870C000,
+ 0x38710000, 0x38714000, 0x38718000, 0x3871C000,
+ 0x38720000, 0x38724000, 0x38728000, 0x3872C000,
+ 0x38730000, 0x38734000, 0x38738000, 0x3873C000,
+ 0x38740000, 0x38744000, 0x38748000, 0x3874C000,
+ 0x38750000, 0x38754000, 0x38758000, 0x3875C000,
+ 0x38760000, 0x38764000, 0x38768000, 0x3876C000,
+ 0x38770000, 0x38774000, 0x38778000, 0x3877C000,
+ 0x38780000, 0x38784000, 0x38788000, 0x3878C000,
+ 0x38790000, 0x38794000, 0x38798000, 0x3879C000,
+ 0x387A0000, 0x387A4000, 0x387A8000, 0x387AC000,
+ 0x387B0000, 0x387B4000, 0x387B8000, 0x387BC000,
+ 0x387C0000, 0x387C4000, 0x387C8000, 0x387CC000,
+ 0x387D0000, 0x387D4000, 0x387D8000, 0x387DC000,
+ 0x387E0000, 0x387E4000, 0x387E8000, 0x387EC000,
+ 0x387F0000, 0x387F4000, 0x387F8000, 0x387FC000,
+ 0x38000000, 0x38002000, 0x38004000, 0x38006000,
+ 0x38008000, 0x3800A000, 0x3800C000, 0x3800E000,
+ 0x38010000, 0x38012000, 0x38014000, 0x38016000,
+ 0x38018000, 0x3801A000, 0x3801C000, 0x3801E000,
+ 0x38020000, 0x38022000, 0x38024000, 0x38026000,
+ 0x38028000, 0x3802A000, 0x3802C000, 0x3802E000,
+ 0x38030000, 0x38032000, 0x38034000, 0x38036000,
+ 0x38038000, 0x3803A000, 0x3803C000, 0x3803E000,
+ 0x38040000, 0x38042000, 0x38044000, 0x38046000,
+ 0x38048000, 0x3804A000, 0x3804C000, 0x3804E000,
+ 0x38050000, 0x38052000, 0x38054000, 0x38056000,
+ 0x38058000, 0x3805A000, 0x3805C000, 0x3805E000,
+ 0x38060000, 0x38062000, 0x38064000, 0x38066000,
+ 0x38068000, 0x3806A000, 0x3806C000, 0x3806E000,
+ 0x38070000, 0x38072000, 0x38074000, 0x38076000,
+ 0x38078000, 0x3807A000, 0x3807C000, 0x3807E000,
+ 0x38080000, 0x38082000, 0x38084000, 0x38086000,
+ 0x38088000, 0x3808A000, 0x3808C000, 0x3808E000,
+ 0x38090000, 0x38092000, 0x38094000, 0x38096000,
+ 0x38098000, 0x3809A000, 0x3809C000, 0x3809E000,
+ 0x380A0000, 0x380A2000, 0x380A4000, 0x380A6000,
+ 0x380A8000, 0x380AA000, 0x380AC000, 0x380AE000,
+ 0x380B0000, 0x380B2000, 0x380B4000, 0x380B6000,
+ 0x380B8000, 0x380BA000, 0x380BC000, 0x380BE000,
+ 0x380C0000, 0x380C2000, 0x380C4000, 0x380C6000,
+ 0x380C8000, 0x380CA000, 0x380CC000, 0x380CE000,
+ 0x380D0000, 0x380D2000, 0x380D4000, 0x380D6000,
+ 0x380D8000, 0x380DA000, 0x380DC000, 0x380DE000,
+ 0x380E0000, 0x380E2000, 0x380E4000, 0x380E6000,
+ 0x380E8000, 0x380EA000, 0x380EC000, 0x380EE000,
+ 0x380F0000, 0x380F2000, 0x380F4000, 0x380F6000,
+ 0x380F8000, 0x380FA000, 0x380FC000, 0x380FE000,
+ 0x38100000, 0x38102000, 0x38104000, 0x38106000,
+ 0x38108000, 0x3810A000, 0x3810C000, 0x3810E000,
+ 0x38110000, 0x38112000, 0x38114000, 0x38116000,
+ 0x38118000, 0x3811A000, 0x3811C000, 0x3811E000,
+ 0x38120000, 0x38122000, 0x38124000, 0x38126000,
+ 0x38128000, 0x3812A000, 0x3812C000, 0x3812E000,
+ 0x38130000, 0x38132000, 0x38134000, 0x38136000,
+ 0x38138000, 0x3813A000, 0x3813C000, 0x3813E000,
+ 0x38140000, 0x38142000, 0x38144000, 0x38146000,
+ 0x38148000, 0x3814A000, 0x3814C000, 0x3814E000,
+ 0x38150000, 0x38152000, 0x38154000, 0x38156000,
+ 0x38158000, 0x3815A000, 0x3815C000, 0x3815E000,
+ 0x38160000, 0x38162000, 0x38164000, 0x38166000,
+ 0x38168000, 0x3816A000, 0x3816C000, 0x3816E000,
+ 0x38170000, 0x38172000, 0x38174000, 0x38176000,
+ 0x38178000, 0x3817A000, 0x3817C000, 0x3817E000,
+ 0x38180000, 0x38182000, 0x38184000, 0x38186000,
+ 0x38188000, 0x3818A000, 0x3818C000, 0x3818E000,
+ 0x38190000, 0x38192000, 0x38194000, 0x38196000,
+ 0x38198000, 0x3819A000, 0x3819C000, 0x3819E000,
+ 0x381A0000, 0x381A2000, 0x381A4000, 0x381A6000,
+ 0x381A8000, 0x381AA000, 0x381AC000, 0x381AE000,
+ 0x381B0000, 0x381B2000, 0x381B4000, 0x381B6000,
+ 0x381B8000, 0x381BA000, 0x381BC000, 0x381BE000,
+ 0x381C0000, 0x381C2000, 0x381C4000, 0x381C6000,
+ 0x381C8000, 0x381CA000, 0x381CC000, 0x381CE000,
+ 0x381D0000, 0x381D2000, 0x381D4000, 0x381D6000,
+ 0x381D8000, 0x381DA000, 0x381DC000, 0x381DE000,
+ 0x381E0000, 0x381E2000, 0x381E4000, 0x381E6000,
+ 0x381E8000, 0x381EA000, 0x381EC000, 0x381EE000,
+ 0x381F0000, 0x381F2000, 0x381F4000, 0x381F6000,
+ 0x381F8000, 0x381FA000, 0x381FC000, 0x381FE000,
+ 0x38200000, 0x38202000, 0x38204000, 0x38206000,
+ 0x38208000, 0x3820A000, 0x3820C000, 0x3820E000,
+ 0x38210000, 0x38212000, 0x38214000, 0x38216000,
+ 0x38218000, 0x3821A000, 0x3821C000, 0x3821E000,
+ 0x38220000, 0x38222000, 0x38224000, 0x38226000,
+ 0x38228000, 0x3822A000, 0x3822C000, 0x3822E000,
+ 0x38230000, 0x38232000, 0x38234000, 0x38236000,
+ 0x38238000, 0x3823A000, 0x3823C000, 0x3823E000,
+ 0x38240000, 0x38242000, 0x38244000, 0x38246000,
+ 0x38248000, 0x3824A000, 0x3824C000, 0x3824E000,
+ 0x38250000, 0x38252000, 0x38254000, 0x38256000,
+ 0x38258000, 0x3825A000, 0x3825C000, 0x3825E000,
+ 0x38260000, 0x38262000, 0x38264000, 0x38266000,
+ 0x38268000, 0x3826A000, 0x3826C000, 0x3826E000,
+ 0x38270000, 0x38272000, 0x38274000, 0x38276000,
+ 0x38278000, 0x3827A000, 0x3827C000, 0x3827E000,
+ 0x38280000, 0x38282000, 0x38284000, 0x38286000,
+ 0x38288000, 0x3828A000, 0x3828C000, 0x3828E000,
+ 0x38290000, 0x38292000, 0x38294000, 0x38296000,
+ 0x38298000, 0x3829A000, 0x3829C000, 0x3829E000,
+ 0x382A0000, 0x382A2000, 0x382A4000, 0x382A6000,
+ 0x382A8000, 0x382AA000, 0x382AC000, 0x382AE000,
+ 0x382B0000, 0x382B2000, 0x382B4000, 0x382B6000,
+ 0x382B8000, 0x382BA000, 0x382BC000, 0x382BE000,
+ 0x382C0000, 0x382C2000, 0x382C4000, 0x382C6000,
+ 0x382C8000, 0x382CA000, 0x382CC000, 0x382CE000,
+ 0x382D0000, 0x382D2000, 0x382D4000, 0x382D6000,
+ 0x382D8000, 0x382DA000, 0x382DC000, 0x382DE000,
+ 0x382E0000, 0x382E2000, 0x382E4000, 0x382E6000,
+ 0x382E8000, 0x382EA000, 0x382EC000, 0x382EE000,
+ 0x382F0000, 0x382F2000, 0x382F4000, 0x382F6000,
+ 0x382F8000, 0x382FA000, 0x382FC000, 0x382FE000,
+ 0x38300000, 0x38302000, 0x38304000, 0x38306000,
+ 0x38308000, 0x3830A000, 0x3830C000, 0x3830E000,
+ 0x38310000, 0x38312000, 0x38314000, 0x38316000,
+ 0x38318000, 0x3831A000, 0x3831C000, 0x3831E000,
+ 0x38320000, 0x38322000, 0x38324000, 0x38326000,
+ 0x38328000, 0x3832A000, 0x3832C000, 0x3832E000,
+ 0x38330000, 0x38332000, 0x38334000, 0x38336000,
+ 0x38338000, 0x3833A000, 0x3833C000, 0x3833E000,
+ 0x38340000, 0x38342000, 0x38344000, 0x38346000,
+ 0x38348000, 0x3834A000, 0x3834C000, 0x3834E000,
+ 0x38350000, 0x38352000, 0x38354000, 0x38356000,
+ 0x38358000, 0x3835A000, 0x3835C000, 0x3835E000,
+ 0x38360000, 0x38362000, 0x38364000, 0x38366000,
+ 0x38368000, 0x3836A000, 0x3836C000, 0x3836E000,
+ 0x38370000, 0x38372000, 0x38374000, 0x38376000,
+ 0x38378000, 0x3837A000, 0x3837C000, 0x3837E000,
+ 0x38380000, 0x38382000, 0x38384000, 0x38386000,
+ 0x38388000, 0x3838A000, 0x3838C000, 0x3838E000,
+ 0x38390000, 0x38392000, 0x38394000, 0x38396000,
+ 0x38398000, 0x3839A000, 0x3839C000, 0x3839E000,
+ 0x383A0000, 0x383A2000, 0x383A4000, 0x383A6000,
+ 0x383A8000, 0x383AA000, 0x383AC000, 0x383AE000,
+ 0x383B0000, 0x383B2000, 0x383B4000, 0x383B6000,
+ 0x383B8000, 0x383BA000, 0x383BC000, 0x383BE000,
+ 0x383C0000, 0x383C2000, 0x383C4000, 0x383C6000,
+ 0x383C8000, 0x383CA000, 0x383CC000, 0x383CE000,
+ 0x383D0000, 0x383D2000, 0x383D4000, 0x383D6000,
+ 0x383D8000, 0x383DA000, 0x383DC000, 0x383DE000,
+ 0x383E0000, 0x383E2000, 0x383E4000, 0x383E6000,
+ 0x383E8000, 0x383EA000, 0x383EC000, 0x383EE000,
+ 0x383F0000, 0x383F2000, 0x383F4000, 0x383F6000,
+ 0x383F8000, 0x383FA000, 0x383FC000, 0x383FE000,
+ 0x38400000, 0x38402000, 0x38404000, 0x38406000,
+ 0x38408000, 0x3840A000, 0x3840C000, 0x3840E000,
+ 0x38410000, 0x38412000, 0x38414000, 0x38416000,
+ 0x38418000, 0x3841A000, 0x3841C000, 0x3841E000,
+ 0x38420000, 0x38422000, 0x38424000, 0x38426000,
+ 0x38428000, 0x3842A000, 0x3842C000, 0x3842E000,
+ 0x38430000, 0x38432000, 0x38434000, 0x38436000,
+ 0x38438000, 0x3843A000, 0x3843C000, 0x3843E000,
+ 0x38440000, 0x38442000, 0x38444000, 0x38446000,
+ 0x38448000, 0x3844A000, 0x3844C000, 0x3844E000,
+ 0x38450000, 0x38452000, 0x38454000, 0x38456000,
+ 0x38458000, 0x3845A000, 0x3845C000, 0x3845E000,
+ 0x38460000, 0x38462000, 0x38464000, 0x38466000,
+ 0x38468000, 0x3846A000, 0x3846C000, 0x3846E000,
+ 0x38470000, 0x38472000, 0x38474000, 0x38476000,
+ 0x38478000, 0x3847A000, 0x3847C000, 0x3847E000,
+ 0x38480000, 0x38482000, 0x38484000, 0x38486000,
+ 0x38488000, 0x3848A000, 0x3848C000, 0x3848E000,
+ 0x38490000, 0x38492000, 0x38494000, 0x38496000,
+ 0x38498000, 0x3849A000, 0x3849C000, 0x3849E000,
+ 0x384A0000, 0x384A2000, 0x384A4000, 0x384A6000,
+ 0x384A8000, 0x384AA000, 0x384AC000, 0x384AE000,
+ 0x384B0000, 0x384B2000, 0x384B4000, 0x384B6000,
+ 0x384B8000, 0x384BA000, 0x384BC000, 0x384BE000,
+ 0x384C0000, 0x384C2000, 0x384C4000, 0x384C6000,
+ 0x384C8000, 0x384CA000, 0x384CC000, 0x384CE000,
+ 0x384D0000, 0x384D2000, 0x384D4000, 0x384D6000,
+ 0x384D8000, 0x384DA000, 0x384DC000, 0x384DE000,
+ 0x384E0000, 0x384E2000, 0x384E4000, 0x384E6000,
+ 0x384E8000, 0x384EA000, 0x384EC000, 0x384EE000,
+ 0x384F0000, 0x384F2000, 0x384F4000, 0x384F6000,
+ 0x384F8000, 0x384FA000, 0x384FC000, 0x384FE000,
+ 0x38500000, 0x38502000, 0x38504000, 0x38506000,
+ 0x38508000, 0x3850A000, 0x3850C000, 0x3850E000,
+ 0x38510000, 0x38512000, 0x38514000, 0x38516000,
+ 0x38518000, 0x3851A000, 0x3851C000, 0x3851E000,
+ 0x38520000, 0x38522000, 0x38524000, 0x38526000,
+ 0x38528000, 0x3852A000, 0x3852C000, 0x3852E000,
+ 0x38530000, 0x38532000, 0x38534000, 0x38536000,
+ 0x38538000, 0x3853A000, 0x3853C000, 0x3853E000,
+ 0x38540000, 0x38542000, 0x38544000, 0x38546000,
+ 0x38548000, 0x3854A000, 0x3854C000, 0x3854E000,
+ 0x38550000, 0x38552000, 0x38554000, 0x38556000,
+ 0x38558000, 0x3855A000, 0x3855C000, 0x3855E000,
+ 0x38560000, 0x38562000, 0x38564000, 0x38566000,
+ 0x38568000, 0x3856A000, 0x3856C000, 0x3856E000,
+ 0x38570000, 0x38572000, 0x38574000, 0x38576000,
+ 0x38578000, 0x3857A000, 0x3857C000, 0x3857E000,
+ 0x38580000, 0x38582000, 0x38584000, 0x38586000,
+ 0x38588000, 0x3858A000, 0x3858C000, 0x3858E000,
+ 0x38590000, 0x38592000, 0x38594000, 0x38596000,
+ 0x38598000, 0x3859A000, 0x3859C000, 0x3859E000,
+ 0x385A0000, 0x385A2000, 0x385A4000, 0x385A6000,
+ 0x385A8000, 0x385AA000, 0x385AC000, 0x385AE000,
+ 0x385B0000, 0x385B2000, 0x385B4000, 0x385B6000,
+ 0x385B8000, 0x385BA000, 0x385BC000, 0x385BE000,
+ 0x385C0000, 0x385C2000, 0x385C4000, 0x385C6000,
+ 0x385C8000, 0x385CA000, 0x385CC000, 0x385CE000,
+ 0x385D0000, 0x385D2000, 0x385D4000, 0x385D6000,
+ 0x385D8000, 0x385DA000, 0x385DC000, 0x385DE000,
+ 0x385E0000, 0x385E2000, 0x385E4000, 0x385E6000,
+ 0x385E8000, 0x385EA000, 0x385EC000, 0x385EE000,
+ 0x385F0000, 0x385F2000, 0x385F4000, 0x385F6000,
+ 0x385F8000, 0x385FA000, 0x385FC000, 0x385FE000,
+ 0x38600000, 0x38602000, 0x38604000, 0x38606000,
+ 0x38608000, 0x3860A000, 0x3860C000, 0x3860E000,
+ 0x38610000, 0x38612000, 0x38614000, 0x38616000,
+ 0x38618000, 0x3861A000, 0x3861C000, 0x3861E000,
+ 0x38620000, 0x38622000, 0x38624000, 0x38626000,
+ 0x38628000, 0x3862A000, 0x3862C000, 0x3862E000,
+ 0x38630000, 0x38632000, 0x38634000, 0x38636000,
+ 0x38638000, 0x3863A000, 0x3863C000, 0x3863E000,
+ 0x38640000, 0x38642000, 0x38644000, 0x38646000,
+ 0x38648000, 0x3864A000, 0x3864C000, 0x3864E000,
+ 0x38650000, 0x38652000, 0x38654000, 0x38656000,
+ 0x38658000, 0x3865A000, 0x3865C000, 0x3865E000,
+ 0x38660000, 0x38662000, 0x38664000, 0x38666000,
+ 0x38668000, 0x3866A000, 0x3866C000, 0x3866E000,
+ 0x38670000, 0x38672000, 0x38674000, 0x38676000,
+ 0x38678000, 0x3867A000, 0x3867C000, 0x3867E000,
+ 0x38680000, 0x38682000, 0x38684000, 0x38686000,
+ 0x38688000, 0x3868A000, 0x3868C000, 0x3868E000,
+ 0x38690000, 0x38692000, 0x38694000, 0x38696000,
+ 0x38698000, 0x3869A000, 0x3869C000, 0x3869E000,
+ 0x386A0000, 0x386A2000, 0x386A4000, 0x386A6000,
+ 0x386A8000, 0x386AA000, 0x386AC000, 0x386AE000,
+ 0x386B0000, 0x386B2000, 0x386B4000, 0x386B6000,
+ 0x386B8000, 0x386BA000, 0x386BC000, 0x386BE000,
+ 0x386C0000, 0x386C2000, 0x386C4000, 0x386C6000,
+ 0x386C8000, 0x386CA000, 0x386CC000, 0x386CE000,
+ 0x386D0000, 0x386D2000, 0x386D4000, 0x386D6000,
+ 0x386D8000, 0x386DA000, 0x386DC000, 0x386DE000,
+ 0x386E0000, 0x386E2000, 0x386E4000, 0x386E6000,
+ 0x386E8000, 0x386EA000, 0x386EC000, 0x386EE000,
+ 0x386F0000, 0x386F2000, 0x386F4000, 0x386F6000,
+ 0x386F8000, 0x386FA000, 0x386FC000, 0x386FE000,
+ 0x38700000, 0x38702000, 0x38704000, 0x38706000,
+ 0x38708000, 0x3870A000, 0x3870C000, 0x3870E000,
+ 0x38710000, 0x38712000, 0x38714000, 0x38716000,
+ 0x38718000, 0x3871A000, 0x3871C000, 0x3871E000,
+ 0x38720000, 0x38722000, 0x38724000, 0x38726000,
+ 0x38728000, 0x3872A000, 0x3872C000, 0x3872E000,
+ 0x38730000, 0x38732000, 0x38734000, 0x38736000,
+ 0x38738000, 0x3873A000, 0x3873C000, 0x3873E000,
+ 0x38740000, 0x38742000, 0x38744000, 0x38746000,
+ 0x38748000, 0x3874A000, 0x3874C000, 0x3874E000,
+ 0x38750000, 0x38752000, 0x38754000, 0x38756000,
+ 0x38758000, 0x3875A000, 0x3875C000, 0x3875E000,
+ 0x38760000, 0x38762000, 0x38764000, 0x38766000,
+ 0x38768000, 0x3876A000, 0x3876C000, 0x3876E000,
+ 0x38770000, 0x38772000, 0x38774000, 0x38776000,
+ 0x38778000, 0x3877A000, 0x3877C000, 0x3877E000,
+ 0x38780000, 0x38782000, 0x38784000, 0x38786000,
+ 0x38788000, 0x3878A000, 0x3878C000, 0x3878E000,
+ 0x38790000, 0x38792000, 0x38794000, 0x38796000,
+ 0x38798000, 0x3879A000, 0x3879C000, 0x3879E000,
+ 0x387A0000, 0x387A2000, 0x387A4000, 0x387A6000,
+ 0x387A8000, 0x387AA000, 0x387AC000, 0x387AE000,
+ 0x387B0000, 0x387B2000, 0x387B4000, 0x387B6000,
+ 0x387B8000, 0x387BA000, 0x387BC000, 0x387BE000,
+ 0x387C0000, 0x387C2000, 0x387C4000, 0x387C6000,
+ 0x387C8000, 0x387CA000, 0x387CC000, 0x387CE000,
+ 0x387D0000, 0x387D2000, 0x387D4000, 0x387D6000,
+ 0x387D8000, 0x387DA000, 0x387DC000, 0x387DE000,
+ 0x387E0000, 0x387E2000, 0x387E4000, 0x387E6000,
+ 0x387E8000, 0x387EA000, 0x387EC000, 0x387EE000,
+ 0x387F0000, 0x387F2000, 0x387F4000, 0x387F6000,
+ 0x387F8000, 0x387FA000, 0x387FC000, 0x387FE000};
+ static const uint32 exponent_table[64] = {
+ 0x00000000, 0x00800000, 0x01000000, 0x01800000,
+ 0x02000000, 0x02800000, 0x03000000, 0x03800000,
+ 0x04000000, 0x04800000, 0x05000000, 0x05800000,
+ 0x06000000, 0x06800000, 0x07000000, 0x07800000,
+ 0x08000000, 0x08800000, 0x09000000, 0x09800000,
+ 0x0A000000, 0x0A800000, 0x0B000000, 0x0B800000,
+ 0x0C000000, 0x0C800000, 0x0D000000, 0x0D800000,
+ 0x0E000000, 0x0E800000, 0x0F000000, 0x47800000,
+ 0x80000000, 0x80800000, 0x81000000, 0x81800000,
+ 0x82000000, 0x82800000, 0x83000000, 0x83800000,
+ 0x84000000, 0x84800000, 0x85000000, 0x85800000,
+ 0x86000000, 0x86800000, 0x87000000, 0x87800000,
+ 0x88000000, 0x88800000, 0x89000000, 0x89800000,
+ 0x8A000000, 0x8A800000, 0x8B000000, 0x8B800000,
+ 0x8C000000, 0x8C800000, 0x8D000000, 0x8D800000,
+ 0x8E000000, 0x8E800000, 0x8F000000, 0xC7800000};
+ static const unsigned short offset_table[64] = {
+ 0, 1024, 1024, 1024, 1024, 1024, 1024, 1024,
+ 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024,
+ 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024,
+ 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024,
+ 0, 1024, 1024, 1024, 1024, 1024, 1024, 1024,
+ 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024,
+ 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024,
+ 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024};
+ uint32 bits = mantissa_table[offset_table[value >> 10] +
+ (value & 0x3FF)] +
+ exponent_table[value >> 10];
+ // return *reinterpret_cast<float*>(&bits);
+ ////violating strict aliasing!
+ float out;
+ std::memcpy(&out, &bits, sizeof(float));
+ return out;
+ }
+
+ /// Convert half-precision to IEEE double-precision.
+ /// Builds the upper 32 bits of the double directly from the half's bits;
+ /// the lower 32 bits of the result are always zero.
+ /// \param value binary representation of half-precision value
+ /// \return double-precision value
+ inline double half2float_impl(uint16 value, double, true_type) {
+ typedef bits<float>::type uint32;
+ typedef bits<double>::type uint64;
+ // Move the sign from bit 15 of the half to bit 31 of the high word.
+ uint32 hi = static_cast<uint32>(value & 0x8000) << 16;
+ int abs = value & 0x7FFF;
+ if (abs) {
+ // Exponent rebias placed at bit 20: 0x3F000000 for finite values,
+ // doubled to 0x7E000000 when the half exponent is all ones (Inf/NaN).
+ hi |= 0x3F000000 << static_cast<unsigned>(abs >= 0x7C00);
+ // Subnormal input: shift the mantissa up until bit 10 is set, lowering
+ // the exponent field (0x100000 == 1 << 20) once per shift.
+ for (; abs < 0x400; abs <<= 1, hi -= 0x100000)
+ ;
+ // Half exponent and mantissa land at bit 10 and above of the high word.
+ hi += static_cast<uint32>(abs) << 10;
+ }
+ uint64 bits = static_cast<uint64>(hi) << 32;
+ // Copy the bit pattern with memcpy; the alternative
+ // *reinterpret_cast<double*>(&bits)
+ // would violate strict aliasing.
+ double out;
+ std::memcpy(&out, &bits, sizeof(double));
+ return out;
+ }
+
+ /// Convert half-precision to non-IEEE floating point.
+ /// Generic fallback that rebuilds the value with std::ldexp instead of
+ /// manipulating the target's bit pattern.
+ /// \tparam T type to convert to (builtin floating point type)
+ /// \param value binary representation of half-precision value
+ /// \return floating point value
+ template <typename T>
+ T half2float_impl(uint16 value, T, ...) {
+ const int magnitude = value & 0x7FFF;
+ T result;
+ if (magnitude <= 0x3FF) {
+ // Zero or subnormal: no implicit leading bit, fixed scale of 2^-24.
+ result = std::ldexp(static_cast<T>(magnitude), -24);
+ } else if (magnitude < 0x7C00) {
+ // Normal number: restore the implicit leading bit (0x400).
+ result = std::ldexp(static_cast<T>((magnitude & 0x3FF) | 0x400),
+ (magnitude >> 10) - 25);
+ } else if (magnitude == 0x7C00) {
+ // Infinity, or the largest finite value if T has no infinity.
+ result = std::numeric_limits<T>::has_infinity
+ ? std::numeric_limits<T>::infinity()
+ : std::numeric_limits<T>::max();
+ } else {
+ // NaN, or a value-initialized T if T has no quiet NaN.
+ result = std::numeric_limits<T>::has_quiet_NaN
+ ? std::numeric_limits<T>::quiet_NaN()
+ : T();
+ }
+ return (value & 0x8000) ? -result : result;
+ }
+
+ /// Convert half-precision to floating point.
+ /// Dispatches to a half2float_impl overload through a tag argument built
+ /// from two compile-time properties of \a T.
+ /// \tparam T type to convert to (builtin floating point type)
+ /// \param value binary representation of half-precision value
+ /// \return floating point value
+ template <typename T>
+ T half2float(uint16 value) {
+ // Tag is bool_type<true> only when T is IEC 559 and bits<T>::type has
+ // exactly the size of T.
+ typedef bool_type<(std::numeric_limits<T>::is_iec559 &&
+ sizeof(typename bits<T>::type) == sizeof(T))>
+ dispatch_tag;
+ return half2float_impl(value, T(), dispatch_tag());
+ }
+
+ /// Convert half-precision floating point to integer.
+ /// \tparam R rounding mode to use, `std::round_indeterminate`
+ /// for fastest rounding
+ /// \tparam E `true` for round to even, `false` for round away from zero
+ /// \tparam T type to convert to (builtin integer type with at least
+ /// 16 bits precision, excluding any implicit sign bits)
+ /// \param value binary representation of half-precision value
+ /// \return integral value; infinities and NaNs give the minimum of \a T
+ /// when the sign bit is set and the maximum otherwise
+ template <std::float_round_style R, bool E, typename T>
+ T half2int_impl(uint16 value) {
+#if HALF_ENABLE_CPP11_STATIC_ASSERT && HALF_ENABLE_CPP11_TYPE_TRAITS
+ static_assert(std::is_integral<T>::value,
+ "half to int conversion only supports builtin "
+ "integer types");
+#endif
+ // Exponent and mantissa bits with the sign stripped.
+ unsigned int e = value & 0x7FFF;
+ // Exponent field all ones: infinity or NaN, saturate according to sign.
+ if (e >= 0x7C00)
+ return (value & 0x8000) ? std::numeric_limits<T>::min()
+ : std::numeric_limits<T>::max();
+ // 0x3800 encodes 0.5, so the magnitude here is below one half.
+ if (e < 0x3800) {
+ // Rounding up gives 1 only for a positive, nonzero input.
+ if (R == std::round_toward_infinity)
+ return T(~(value >> 15) & (e != 0));
+ // Rounding down gives -1 only for a negative, nonzero input.
+ else if (R == std::round_toward_neg_infinity)
+ return -T(value > 0x8000);
+ return T();
+ }
+ // Significand with the implicit leading bit restored.
+ unsigned int m = (value & 0x3FF) | 0x400;
+ // From here on e is the biased exponent; the value is m * 2^(e - 25).
+ e >>= 10;
+ if (e < 25) {
+ // Add half of the discarded range; with E set, subtract one when the
+ // lowest kept bit is 0 so that exact ties fall to the even neighbour.
+ if (R == std::round_to_nearest)
+ m += (1 << (24 - e)) - (~(m >> (25 - e)) & E);
+ // (value >> 15) - 1 is all ones for positive input: bias upward.
+ else if (R == std::round_toward_infinity)
+ m += ((value >> 15) - 1) & ((1 << (25 - e)) - 1U);
+ // -(value >> 15) is all ones for negative input: bias the magnitude up.
+ else if (R == std::round_toward_neg_infinity)
+ m += -(value >> 15) & ((1 << (25 - e)) - 1U);
+ m >>= 25 - e;
+ } else
+ m <<= e - 25;
+ return (value & 0x8000) ? -static_cast<T>(m)
+ : static_cast<T>(m);
+ }
+
+ /// Convert half-precision floating point to integer.
+ /// Forwards to half2int_impl with HALF_ROUND_TIES_TO_EVEN as the
+ /// tie-breaking flag.
+ /// \tparam R rounding mode to use, `std::round_indeterminate`
+ /// for fastest rounding
+ /// \tparam T type to convert to (builtin integer type with at least
+ /// 16 bits precision, excluding any implicit sign bits)
+ /// \param value binary representation of half-precision value
+ /// \return integral value
+ template <std::float_round_style R, typename T>
+ T half2int(uint16 value) {
+ const T converted = half2int_impl<R, HALF_ROUND_TIES_TO_EVEN, T>(value);
+ return converted;
+ }
+
+ /// Convert half-precision floating point to integer using
+ /// round-to-nearest-away-from-zero.
+ /// \tparam T type to convert to (builtin integer type with at least
+ /// 16 bits precision, excluding any implicit sign bits)
+ /// \param value binary representation of half-precision value
+ /// \return integral value
+ template <typename T>
+ T half2int_up(uint16 value) {
+ const T converted = half2int_impl<std::round_to_nearest, false, T>(value);
+ return converted;
+ }
+
+ /// Round half-precision number to nearest integer value.
+ /// \tparam R rounding mode to use, `std::round_indeterminate`
+ /// for fastest rounding
+ /// \tparam E `true` for round to even, `false` for round away from zero
+ /// \param value binary representation of half-precision value
+ /// \return half-precision bits for nearest integral value
+ template <std::float_round_style R, bool E>
+ uint16 round_half_impl(uint16 value) {
+ // Exponent and mantissa bits with the sign stripped.
+ unsigned int e = value & 0x7FFF;
+ uint16 result = value;
+ // 0x3C00 encodes 1.0: magnitudes below one become signed zero or +/-1.
+ if (e < 0x3C00) {
+ result &= 0x8000;
+ // Negating the boolean gives an all-ones mask that selects 0x3C00 (1.0).
+ // 0x3800 encodes 0.5; with E set an exact 0.5 stays at zero.
+ if (R == std::round_to_nearest)
+ result |= 0x3C00U & -(e >= (0x3800 + E));
+ else if (R == std::round_toward_infinity)
+ result |= 0x3C00U & -(~(value >> 15) & (e != 0));
+ else if (R == std::round_toward_neg_infinity)
+ result |= 0x3C00U & -(value > 0x8000);
+ // Biased exponents of 25 and above (0x6400) are returned unchanged.
+ } else if (e < 0x6400) {
+ // e becomes the number of fractional mantissa bits to clear.
+ e = 25 - (e >> 10);
+ unsigned int mask = (1 << e) - 1;
+ // Add half of the fractional range; with E set, subtract one when the
+ // lowest kept bit is 0 so that exact ties fall to the even neighbour.
+ if (R == std::round_to_nearest)
+ result += (1 << (e - 1)) - (~(result >> e) & E);
+ // (value >> 15) - 1 is all ones for positive input.
+ else if (R == std::round_toward_infinity)
+ result += mask & ((value >> 15) - 1);
+ // -(value >> 15) is all ones for negative input.
+ else if (R == std::round_toward_neg_infinity)
+ result += mask & -(value >> 15);
+ result &= ~mask;
+ }
+ return result;
+ }
+
+ /// Round half-precision number to nearest integer value.
+ /// Forwards to round_half_impl with HALF_ROUND_TIES_TO_EVEN as the
+ /// tie-breaking flag.
+ /// \tparam R rounding mode to use, `std::round_indeterminate`
+ /// for fastest rounding
+ /// \param value binary representation of half-precision value
+ /// \return half-precision bits for nearest integral value
+ template <std::float_round_style R>
+ uint16 round_half(uint16 value) {
+ const uint16 rounded = round_half_impl<R, HALF_ROUND_TIES_TO_EVEN>(value);
+ return rounded;
+ }
+
+ /// Round half-precision number to nearest integer value using
+ /// round-to-nearest-away-from-zero.
+ /// \param value binary representation of half-precision value
+ /// \return half-precision bits for nearest integral value
+ inline uint16 round_half_up(uint16 value) {
+ const uint16 rounded = round_half_impl<std::round_to_nearest, false>(value);
+ return rounded;
+ }
+ /// \}
+
+ // Forward declarations: class half, declared after this namespace block,
+ // names each of these types as a friend.
+ struct functions;
+ template<typename> struct unary_specialized;
+ template<typename,typename> struct binary_specialized;
+ template<typename,typename,std::float_round_style> struct half_caster;
+ }
+
+ /// Half-precision floating point type.
+ /// This class implements an IEEE-conformant half-precision floating
+ /// point type with the usual arithmetic operators and conversions. It
+ /// converts implicitly to single-precision, so arithmetic expressions
+ /// and functions with mixed-type operands take the most precise operand
+ /// type. Every compound assignment below widens the stored value to
+ /// float, computes in float, and narrows the result using the
+ /// library's default rounding mode (HALF_ROUND_STYLE).
+ ///
+ /// Under the C++98/03 definition the half type is not a POD type.
+ /// Under C++11's extended definitions it is both a standard layout type
+ /// and a trivially copyable type, so it can be copied with raw binary
+ /// copies in a standard-conformant way.
+ ///
+ /// A note on size: although half represents an IEEE 16-bit value, it is
+ /// not necessarily exactly 16 bits wide. Its only data member is a
+ /// detail::uint16, so on any reasonable implementation the object
+ /// representation will most probably not involve any extra "magic" or
+ /// padding beyond the underlying 16-bit IEEE number, even though the
+ /// standard does not strictly guarantee this. It is exactly 16 bits
+ /// only if the C++ implementation provides an unsigned integer type of
+ /// exactly that width, which is the case on nearly any reasonable
+ /// platform.
+ ///
+ /// So unless your C++ implementation is totally exotic or imposes
+ /// special alignment requirements, it is reasonable to assume that a
+ /// half consists of just the 2 bytes of the underlying IEEE
+ /// representation.
+ class half {
+ friend struct detail::functions;
+ friend struct detail::unary_specialized<half>;
+ friend struct detail::binary_specialized<half, half>;
+ template <typename, typename, std::float_round_style>
+ friend struct detail::half_caster;
+ friend class std::numeric_limits<half>;
+#if HALF_ENABLE_CPP11_HASH
+ friend struct std::hash<half>;
+#endif
+#if HALF_ENABLE_CPP11_USER_LITERALS
+ friend half literal::operator""_h(long double);
+#endif
+
+ public:
+ /// Default constructor.
+ /// Value-initializes the stored bits, so the half starts out as 0.
+ /// This differs from the builtin types' default-initialization
+ /// semantics and may be less efficient than no initialization, but
+ /// it is needed to provide proper value-initialization semantics.
+ HALF_CONSTEXPR half() HALF_NOEXCEPT : data_() {}
+
+ /// Construct from a half expression.
+ /// \param rhs half expression to copy from
+ half(detail::expr rhs)
+ : data_(
+ detail::float2half<round_style>(static_cast<float>(rhs))) {}
+
+ /// Conversion constructor.
+ /// \param rhs float to convert
+ explicit half(float rhs)
+ : data_(detail::float2half<round_style>(rhs)) {}
+
+ /// Conversion to single-precision.
+ /// \return single precision value representing the stored half
+ operator float() const {
+ const float widened = detail::half2float<float>(data_);
+ return widened;
+ }
+
+ /// Assignment operator.
+ /// \param rhs half expression to copy from
+ /// \return reference to this half
+ half &operator=(detail::expr rhs) {
+ const float operand = static_cast<float>(rhs);
+ return *this = operand;
+ }
+
+ /// Arithmetic assignment.
+ /// \tparam T type of concrete half expression
+ /// \param rhs half expression to add
+ /// \return reference to this half
+ template <typename T>
+ typename detail::enable<half &, T>::type operator+=(T rhs) {
+ const float operand = static_cast<float>(rhs);
+ return *this += operand;
+ }
+
+ /// Arithmetic assignment.
+ /// \tparam T type of concrete half expression
+ /// \param rhs half expression to subtract
+ /// \return reference to this half
+ template <typename T>
+ typename detail::enable<half &, T>::type operator-=(T rhs) {
+ const float operand = static_cast<float>(rhs);
+ return *this -= operand;
+ }
+
+ /// Arithmetic assignment.
+ /// \tparam T type of concrete half expression
+ /// \param rhs half expression to multiply with
+ /// \return reference to this half
+ template <typename T>
+ typename detail::enable<half &, T>::type operator*=(T rhs) {
+ const float operand = static_cast<float>(rhs);
+ return *this *= operand;
+ }
+
+ /// Arithmetic assignment.
+ /// \tparam T type of concrete half expression
+ /// \param rhs half expression to divide by
+ /// \return reference to this half
+ template <typename T>
+ typename detail::enable<half &, T>::type operator/=(T rhs) {
+ const float operand = static_cast<float>(rhs);
+ return *this /= operand;
+ }
+
+ /// Assignment operator.
+ /// \param rhs single-precision value to copy from
+ /// \return reference to this half
+ half &operator=(float rhs) {
+ data_ = detail::float2half<round_style>(rhs);
+ return *this;
+ }
+
+ /// Arithmetic assignment.
+ /// \param rhs single-precision value to add
+ /// \return reference to this half
+ half &operator+=(float rhs) {
+ const float sum = detail::half2float<float>(data_) + rhs;
+ data_ = detail::float2half<round_style>(sum);
+ return *this;
+ }
+
+ /// Arithmetic assignment.
+ /// \param rhs single-precision value to subtract
+ /// \return reference to this half
+ half &operator-=(float rhs) {
+ const float difference = detail::half2float<float>(data_) - rhs;
+ data_ = detail::float2half<round_style>(difference);
+ return *this;
+ }
+
+ /// Arithmetic assignment.
+ /// \param rhs single-precision value to multiply with
+ /// \return reference to this half
+ half &operator*=(float rhs) {
+ const float product = detail::half2float<float>(data_) * rhs;
+ data_ = detail::float2half<round_style>(product);
+ return *this;
+ }
+
+ /// Arithmetic assignment.
+ /// \param rhs single-precision value to divide by
+ /// \return reference to this half
+ half &operator/=(float rhs) {
+ const float quotient = detail::half2float<float>(data_) / rhs;
+ data_ = detail::float2half<round_style>(quotient);
+ return *this;
+ }
+
+ /// Prefix increment.
+ /// \return incremented half value
+ half &operator++() {
+ *this += 1.0f;
+ return *this;
+ }
+
+ /// Prefix decrement.
+ /// \return decremented half value
+ half &operator--() {
+ *this -= 1.0f;
+ return *this;
+ }
+
+ /// Postfix increment.
+ /// \return non-incremented half value
+ half operator++(int) {
+ half previous(*this);
+ ++*this;
+ return previous;
+ }
+
+ /// Postfix decrement.
+ /// \return non-decremented half value
+ half operator--(int) {
+ half previous(*this);
+ --*this;
+ return previous;
+ }
+
+ private:
+ /// Rounding mode to use
+ static const std::float_round_style round_style =
+ static_cast<std::float_round_style>(HALF_ROUND_STYLE);
+
+ /// Constructor.
+ /// \param bits binary representation to set half to
+ HALF_CONSTEXPR half(detail::binary_t,
+ detail::uint16 bits) HALF_NOEXCEPT : data_(bits) {
+ }
+
+ /// Internal binary representation
+ detail::uint16 data_;
+ };
+
+#if HALF_ENABLE_CPP11_USER_LITERALS
+ namespace literal
+ {
+ /// Half literal.
+ /// This returns an actual half-precision value, but because of the
+ /// rather involved conversions a half literal unfortunately cannot be
+ /// a constant expression.
+ /// \param value literal value
+ /// \return half with given value (if representable)
+ inline half operator""_h(long double value) {
+ const detail::uint16 encoded =
+ detail::float2half<half::round_style>(value);
+ return half(detail::binary, encoded);
+ }
+ } // namespace literal
+#endif
+
+ namespace detail
+ {
+ /// Wrapper implementing unspecialized half-precision functions.
+ struct functions
+ {
+ /// Addition implementation.
+ /// \param x first operand
+ /// \param y second operand
+ /// \return Half-precision sum stored in single-precision
+ static expr plus(float x, float y) {
+ const float sum = x + y;
+ return expr(sum);
+ }
+
+ /// Subtraction implementation.
+ /// \param x first operand
+ /// \param y second operand
+ /// \return Half-precision difference stored in single-precision
+ static expr minus(float x, float y) {
+ const float difference = x - y;
+ return expr(difference);
+ }
+
+ /// Multiplication implementation.
+ /// \param x first operand
+ /// \param y second operand
+ /// \return Half-precision product stored in single-precision
+ static expr multiplies(float x, float y) {
+ const float product = x * y;
+ return expr(product);
+ }
+
+ /// Division implementation.
+ /// \param x first operand
+ /// \param y second operand
+ /// \return Half-precision quotient stored in single-precision
+ static expr divides(float x, float y) {
+ const float quotient = x / y;
+ return expr(quotient);
+ }
+
+ /// Output implementation.
+ /// Streams the value with the stream's own float formatting.
+ /// \param out stream to write to
+ /// \param arg value to write
+ /// \return reference to stream
+ template <typename charT, typename traits>
+ static std::basic_ostream<charT, traits> &write(
+ std::basic_ostream<charT, traits> &out, float arg) {
+ out << arg;
+ return out;
+ }
+
+ /// Input implementation.
+ /// Extracts a float and assigns it to \a arg; \a arg is left untouched
+ /// when the extraction fails.
+ /// \param in stream to read from
+ /// \param arg half to read into
+ /// \return reference to stream
+ template <typename charT, typename traits>
+ static std::basic_istream<charT, traits> &read(
+ std::basic_istream<charT, traits> &in, half &arg) {
+ float parsed;
+ if (!(in >> parsed)) return in;
+ arg = parsed;
+ return in;
+ }
+
+ /// Modulo implementation.
+ /// \param x first operand
+ /// \param y second operand
+ /// \return Half-precision division remainder stored in
+ /// single-precision
+ static expr fmod(float x, float y) {
+ const float result = std::fmod(x, y);
+ return expr(result);
+ }
+
+ /// Remainder implementation.
+ /// Uses std::remainder when HALF_ENABLE_CPP11_CMATH is set; otherwise
+ /// computes the remainder by hand, rounding the implied quotient to
+ /// the nearest integer.
+ /// \param x first operand
+ /// \param y second operand
+ /// \return Half-precision division remainder stored in
+ /// single-precision
+ static expr remainder(float x, float y) {
+#if HALF_ENABLE_CPP11_CMATH
+ return expr(std::remainder(x, y));
+#else
+ const float quiet_nan = std::numeric_limits<float>::quiet_NaN();
+ if (builtin_isnan(x) || builtin_isnan(y)) return expr(quiet_nan);
+ const bool negative = builtin_signbit(x);
+ const float abs_y = std::fabs(y);
+ float rem = std::fabs(x);
+ // NOTE(review): 65536 and 2^-24 presumably stand in for "infinite" and
+ // "zero", since operands originate from half values; confirm.
+ if (rem >= 65536.0f || abs_y < std::ldexp(1.0f, -24))
+ return expr(quiet_nan);
+ if (abs_y >= 65536.0f) return expr(x);
+ if (rem == abs_y) return expr(negative ? -0.0f : 0.0f);
+ // Reduce into [0, 2|y|), then fold into the interval around zero.
+ rem = std::fmod(rem, abs_y + abs_y);
+ const float half_y = 0.5f * abs_y;
+ if (rem > half_y) {
+ rem -= abs_y;
+ if (rem >= half_y) rem -= abs_y;
+ }
+ return expr(negative ? -rem : rem);
+#endif
+ }
+
+ /// Remainder implementation.
+ /// Uses std::remquo when HALF_ENABLE_CPP11_CMATH is set. The fallback
+ /// computes the remainder by hand and stores the signed low bits of the
+ /// rounded quotient in \a quo; on every early exit (NaN operand, domain
+ /// error, infinite \a y) it stores 0 there.
+ /// \param x first operand
+ /// \param y second operand
+ /// \param quo address to store quotient bits at
+ /// \return Half-precision division remainder stored in
+ /// single-precision
+ static expr remquo(float x, float y, int *quo) {
+#if HALF_ENABLE_CPP11_CMATH
+ return expr(std::remquo(x, y, quo));
+#else
+ // Write a defined quotient first: the early returns below would
+ // otherwise leave *quo untouched and the caller would read an
+ // indeterminate value. Zero is also the correct quotient when y is
+ // infinite and x is returned unchanged.
+ *quo = 0;
+ if (builtin_isnan(x) || builtin_isnan(y))
+ return expr(
+ std::numeric_limits<float>::quiet_NaN());
+ bool sign = builtin_signbit(x),
+ qsign =
+ static_cast<bool>(sign ^ builtin_signbit(y));
+ float ax = std::fabs(x), ay = std::fabs(y);
+ if (ax >= 65536.0f || ay < std::ldexp(1.0f, -24))
+ return expr(
+ std::numeric_limits<float>::quiet_NaN());
+ if (ay >= 65536.0f) return expr(x);
+ if (ax == ay)
+ return *quo = qsign ? -1 : 1,
+ expr(sign ? -0.0f : 0.0f);
+ // Reduce modulo 8|y| so the quotient is accumulated modulo 8.
+ ax = std::fmod(ax, 8.0f * ay);
+ int cquo = 0;
+ if (ax >= 4.0f * ay) {
+ ax -= 4.0f * ay;
+ cquo += 4;
+ }
+ if (ax >= 2.0f * ay) {
+ ax -= 2.0f * ay;
+ cquo += 2;
+ }
+ // Fold the rest into the interval around zero, counting each step.
+ float y2 = 0.5f * ay;
+ if (ax > y2) {
+ ax -= ay;
+ ++cquo;
+ if (ax >= y2) {
+ ax -= ay;
+ ++cquo;
+ }
+ }
+ return *quo = qsign ? -cquo : cquo,
+ expr(sign ? -ax : ax);
+#endif
+ }
+
+ /// Positive difference implementation.
+ /// \param x first operand
+ /// \param y second operand
+ /// \return Positive difference stored in
+ /// single-precision
+ static expr fdim(float x, float y) {
+#if HALF_ENABLE_CPP11_CMATH
+ return expr(std::fdim(x, y));
+#else
+ // Fallback: zero when x does not exceed y, otherwise the difference.
+ if (x <= y) return expr(0.0f);
+ return expr(x - y);
+#endif
+ }
+
+ /// Fused multiply-add implementation.
+ /// std::fma is only used when HALF_ENABLE_CPP11_CMATH is set and
+ /// FP_FAST_FMAF is defined; otherwise the product and sum are written
+ /// as an ordinary single-precision expression.
+ /// \param x first operand
+ /// \param y second operand
+ /// \param z third operand
+ /// \return \a x * \a y + \a z stored in single-precision
+ static expr fma(float x, float y, float z)
+ {
+ #if HALF_ENABLE_CPP11_CMATH && defined(FP_FAST_FMAF)
+ return expr(std::fma(x, y, z));
+ #else
+ return expr(x*y+z);
+ #endif
+ }
+
+ /// Get NaN.
+ /// Builds the half directly from the bit pattern 0x7FFF.
+ /// \return Half-precision quiet NaN
+ static half nanh() {
+ const uint16 quiet_nan_bits = 0x7FFF;
+ return half(binary, quiet_nan_bits);
+ }
+
+ /// Exponential implementation.
+ /// \param arg function argument
+ /// \return function value stored in single-precision
+ static expr exp(float arg) {
+ const float result = std::exp(arg);
+ return expr(result);
+ }
+
+ /// Exponential minus one implementation.
+ /// The fallback evaluates exp(arg) - 1 in double precision.
+ /// \param arg function argument
+ /// \return function value stored in single-precision
+ static expr expm1(float arg) {
+#if HALF_ENABLE_CPP11_CMATH
+ return expr(std::expm1(arg));
+#else
+ const double wide = static_cast<double>(arg);
+ return expr(static_cast<float>(std::exp(wide) - 1.0));
+#endif
+ }
+
+ /// Binary exponential implementation.
+ /// The fallback evaluates exp(arg * ln 2) in double precision.
+ /// \param arg function argument
+ /// \return function value stored in single-precision
+ static expr exp2(float arg) {
+#if HALF_ENABLE_CPP11_CMATH
+ return expr(std::exp2(arg));
+#else
+ const double ln2 = 0.69314718055994530941723212145818;
+ return expr(static_cast<float>(std::exp(arg * ln2)));
+#endif
+ }
+
+ /// Natural logarithm implementation.
+ /// \param arg function argument
+ /// \return function value stored in single-precision
+ static expr log(float arg) {
+ const float result = std::log(arg);
+ return expr(result);
+ }
+
+ /// Common logarithm implementation.
+ /// \param arg function argument
+ /// \return function value stored in single-precision
+ static expr log10(float arg) {
+ const float result = std::log10(arg);
+ return expr(result);
+ }
+
+ /// Logarithm of one plus argument implementation.
+ /// The fallback evaluates log(1 + arg) in double precision.
+ /// \param arg function argument
+ /// \return function value stored in single-precision
+ static expr log1p(float arg) {
+#if HALF_ENABLE_CPP11_CMATH
+ return expr(std::log1p(arg));
+#else
+ const double shifted = 1.0 + arg;
+ return expr(static_cast<float>(std::log(shifted)));
+#endif
+ }
+
+ /// Binary logarithm implementation.
+ /// The fallback multiplies the natural logarithm by 1 / ln 2 in double
+ /// precision.
+ /// \param arg function argument
+ /// \return function value stored in single-precision
+ static expr log2(float arg) {
+#if HALF_ENABLE_CPP11_CMATH
+ return expr(std::log2(arg));
+#else
+ const double inv_ln2 = 1.4426950408889634073599246810019;
+ const double wide = static_cast<double>(arg);
+ return expr(static_cast<float>(std::log(wide) * inv_ln2));
+#endif
+ }
+
+ /// Square root implementation.
+ /// \param arg function argument
+ /// \return function value stored in single-precision
+ static expr sqrt(float arg) {
+ const float result = std::sqrt(arg);
+ return expr(result);
+ }
+
+ /// Cubic root implementation.
+ /// The fallback passes NaN and infinities through unchanged and
+ /// otherwise raises the magnitude to the power 1/3 in double precision,
+ /// restoring the sign afterwards.
+ /// \param arg function argument
+ /// \return function value stored in single-precision
+ static expr cbrt(float arg) {
+#if HALF_ENABLE_CPP11_CMATH
+ return expr(std::cbrt(arg));
+#else
+ if (builtin_isnan(arg) || builtin_isinf(arg)) return expr(arg);
+ if (builtin_signbit(arg))
+ return expr(-static_cast<float>(
+ std::pow(-static_cast<double>(arg), 1.0 / 3.0)));
+ return expr(static_cast<float>(
+ std::pow(static_cast<double>(arg), 1.0 / 3.0)));
+#endif
+ }
+
+ /// Hypotenuse implementation.
+ /// The fallback returns infinity if either argument is infinite and
+ /// otherwise takes the square root of the sum of squares in double
+ /// precision.
+ /// \param x first argument
+ /// \param y second argument
+ /// \return function value stored in single-precision
+ static expr hypot(float x, float y) {
+#if HALF_ENABLE_CPP11_CMATH
+ return expr(std::hypot(x, y));
+#else
+ if (builtin_isinf(x) || builtin_isinf(y))
+ return expr(std::numeric_limits<float>::infinity());
+ return expr(static_cast<float>(std::sqrt(
+ static_cast<double>(x) * x + static_cast<double>(y) * y)));
+#endif
+ }
+
+ /// Power implementation.
+ /// \param base value to exponentiate
+ /// \param exp power to exponentiate to
+ /// \return function value stored in single-precision
+ static expr pow(float base, float exp) {
+ const float result = std::pow(base, exp);
+ return expr(result);
+ }
+
+ /// Sine implementation.
+ /// \param arg function argument
+ /// \return function value stored in single-precision
+ static expr sin(float arg) {
+ const float result = std::sin(arg);
+ return expr(result);
+ }
+
+ /// Cosine implementation.
+ /// \param arg function argument
+ /// \return function value stored in single-precision
+ static expr cos(float arg) {
+ const float result = std::cos(arg);
+ return expr(result);
+ }
+
+ /// Tangent implementation.
+ /// \param arg function argument
+ /// \return function value stored in single-precision
+ static expr tan(float arg) {
+ const float result = std::tan(arg);
+ return expr(result);
+ }
+
+ /// Arc sine implementation.
+ /// \param arg function argument
+ /// \return function value stored in single-precision
+ static expr asin(float arg) {
+ const float result = std::asin(arg);
+ return expr(result);
+ }
+
+ /// Arc cosine implementation.
+ /// \param arg function argument
+ /// \return function value stored in single-precision
+ static expr acos(float arg) {
+ const float result = std::acos(arg);
+ return expr(result);
+ }
+
+ /// Arc tangent implementation.
+ /// \param arg function argument
+ /// \return function value stored in single-precision
+ static expr atan(float arg) {
+ const float result = std::atan(arg);
+ return expr(result);
+ }
+
+ /// Two-argument arc tangent implementation.
+ /// \param x first argument (forwarded as the first argument of std::atan2)
+ /// \param y second argument (forwarded as the second argument of std::atan2)
+ /// \return function value stored in single-precision
+ static expr atan2(float x, float y) {
+ const float result = std::atan2(x, y);
+ return expr(result);
+ }
+
+ /// Hyperbolic sine implementation.
+ /// \param arg function argument
+ /// \return function value stored in single-precision
+ static expr sinh(float arg) {
+ const float result = std::sinh(arg);
+ return expr(result);
+ }
+
+ /// Hyperbolic cosine implementation.
+ /// \param arg function argument
+ /// \return function value stored in single-precision
+ static expr cosh(float arg) {
+ const float result = std::cosh(arg);
+ return expr(result);
+ }
+
+ /// Hyperbolic tangent implementation.
+ /// \param arg function argument
+ /// \return function value stored in single-precision
+ static expr tanh(float arg) {
+ const float result = std::tanh(arg);
+ return expr(result);
+ }
+
+ /// Hyperbolic area sine implementation.
+ /// The fallback returns negative infinity unchanged and otherwise
+ /// evaluates log(arg + sqrt(arg^2 + 1)).
+ /// \param arg function argument
+ /// \return function value stored in single-precision
+ static expr asinh(float arg) {
+#if HALF_ENABLE_CPP11_CMATH
+ return expr(std::asinh(arg));
+#else
+ if (arg == -std::numeric_limits<float>::infinity()) return expr(arg);
+ return expr(static_cast<float>(
+ std::log(arg + std::sqrt(arg * arg + 1.0))));
+#endif
+ }
+
+ /// Hyperbolic area cosine implementation.
+ /// The fallback returns a quiet NaN for arguments below -1 and
+ /// otherwise evaluates log(arg + sqrt(arg^2 - 1)).
+ /// \param arg function argument
+ /// \return function value stored in single-precision
+ static expr acosh(float arg) {
+#if HALF_ENABLE_CPP11_CMATH
+ return expr(std::acosh(arg));
+#else
+ if (arg < -1.0f)
+ return expr(std::numeric_limits<float>::quiet_NaN());
+ return expr(static_cast<float>(
+ std::log(arg + std::sqrt(arg * arg - 1.0))));
+#endif
+ }
+
+ /// Hyperbolic area tangent implementation.
+ /// The fallback evaluates 0.5 * log((1 + arg) / (1 - arg)) in double
+ /// precision.
+ /// \param arg function argument
+ /// \return function value stored in single-precision
+ static expr atanh(float arg) {
+#if HALF_ENABLE_CPP11_CMATH
+ return expr(std::atanh(arg));
+#else
+ const double ratio = (1.0 + arg) / (1.0 - arg);
+ return expr(static_cast<float>(0.5 * std::log(ratio)));
+#endif
+ }
+
+ /// Error function implementation.
+ /// The fallback forwards to an `erf` overload taking a double, which is
+ /// declared outside this function.
+ /// \param arg function argument
+ /// \return function value stored in single-precision
+ static expr erf(float arg) {
+#if HALF_ENABLE_CPP11_CMATH
+ return expr(std::erf(arg));
+#else
+ const double wide = static_cast<double>(arg);
+ return expr(static_cast<float>(erf(wide)));
+#endif
+ }
+
+ /// Complementary error function implementation.
+ /// The fallback evaluates 1 - erf(arg) through an `erf` overload taking
+ /// a double, which is declared outside this function.
+ /// \param arg function argument
+ /// \return function value stored in single-precision
+ static expr erfc(float arg) {
+#if HALF_ENABLE_CPP11_CMATH
+ return expr(std::erfc(arg));
+#else
+ const double wide = static_cast<double>(arg);
+ return expr(static_cast<float>(1.0 - erf(wide)));
+#endif
+ }
+
+ /// Gamma logarithm implementation.
+ /// The fallback returns infinity for infinite arguments and for
+ /// negative integers, applies the reflection formula
+ /// ln(pi) - ln|sin(pi * f)| - lgamma(1 - arg) to other negative
+ /// arguments, and otherwise forwards to an `lgamma` overload taking a
+ /// double, which is declared outside this function.
+ /// \param arg function argument
+ /// \return function value stored in single-precision
+ static expr lgamma(float arg) {
+#if HALF_ENABLE_CPP11_CMATH
+ return expr(std::lgamma(arg));
+#else
+ const float infinity = std::numeric_limits<float>::infinity();
+ if (builtin_isinf(arg)) return expr(infinity);
+ if (arg < 0.0f) {
+ // Split -arg into integral and fractional parts.
+ float whole, frac = std::modf(-arg, &whole);
+ if (frac == 0.0f) return expr(infinity);
+ const double ln_pi = 1.1447298858494001741434273513531;
+ const double pi = 3.1415926535897932384626433832795;
+ return expr(static_cast<float>(
+ ln_pi - std::log(std::abs(std::sin(pi * frac))) -
+ lgamma(1.0 - arg)));
+ }
+ return expr(static_cast<float>(lgamma(static_cast<double>(arg))));
+#endif
+ }
+
+ /// Gamma implementation.
+ /// The fallback returns a signed infinity for a zero argument, a quiet
+ /// NaN for negative integers, uses the reflection formula for other
+ /// negative arguments, passes infinity through, and otherwise
+ /// evaluates exp(lgamma(arg)) through an `lgamma` overload taking a
+ /// double, which is declared outside this function.
+ /// \param arg function argument
+ /// \return function value stored in single-precision
+ static expr tgamma(float arg) {
+#if HALF_ENABLE_CPP11_CMATH
+ return expr(std::tgamma(arg));
+#else
+ if (arg == 0.0f) {
+ // The sign of the zero selects the sign of the infinity.
+ if (builtin_signbit(arg))
+ return expr(-std::numeric_limits<float>::infinity());
+ return expr(std::numeric_limits<float>::infinity());
+ }
+ if (arg < 0.0f) {
+ // Split -arg into integral and fractional parts.
+ float whole, frac = std::modf(-arg, &whole);
+ if (frac == 0.0f)
+ return expr(std::numeric_limits<float>::quiet_NaN());
+ const double pi = 3.1415926535897932384626433832795;
+ double value =
+ pi / (std::sin(pi * frac) * std::exp(lgamma(1.0 - arg)));
+ // The result is negated when the integral part is even.
+ return expr(static_cast<float>(
+ (std::fmod(whole, 2.0f) == 0.0f) ? -value : value));
+ }
+ if (builtin_isinf(arg)) return expr(arg);
+ return expr(static_cast<float>(
+ std::exp(lgamma(static_cast<double>(arg)))));
+#endif
+ }
+
+ /// Floor implementation.
+ /// Rounds the stored bits toward negative infinity.
+ /// \param arg value to round
+ /// \return rounded value
+ static half floor(half arg) {
+ const uint16 rounded =
+ round_half<std::round_toward_neg_infinity>(arg.data_);
+ return half(binary, rounded);
+ }
+
+ /// Ceiling implementation.
+ /// Rounds the stored bits toward positive infinity.
+ /// \param arg value to round
+ /// \return rounded value
+ static half ceil(half arg) {
+ const uint16 rounded =
+ round_half<std::round_toward_infinity>(arg.data_);
+ return half(binary, rounded);
+ }
+
+ /// Truncation implementation.
+ /// Rounds the stored bits toward zero.
+ /// \param arg value to round
+ /// \return rounded value
+ static half trunc(half arg) {
+ const uint16 rounded = round_half<std::round_toward_zero>(arg.data_);
+ return half(binary, rounded);
+ }
+
+ /// Nearest integer implementation.
+ /// Rounds to nearest through round_half_up.
+ /// \param arg value to round
+ /// \return rounded value
+ static half round(half arg) {
+ const uint16 rounded = round_half_up(arg.data_);
+ return half(binary, rounded);
+ }
+
+ /// Nearest integer implementation.
+ /// Converts to long through half2int_up.
+ /// \param arg value to round
+ /// \return rounded value
+ static long lround(half arg) {
+ const long rounded = detail::half2int_up<long>(arg.data_);
+ return rounded;
+ }
+
+ /// Nearest integer implementation using the rounding mode half::round_style.
+ /// \param arg value to round
+ /// \return rounded value as a half
+ static half rint(half arg) {
+ return half(binary,
+ round_half<half::round_style>(arg.data_));
+ }
+
+ /// Nearest integer implementation using the rounding mode half::round_style.
+ /// \param arg value to round
+ /// \return rounded value as a long
+ static long lrint(half arg) {
+ return detail::half2int<half::round_style, long>(
+ arg.data_);
+ }
+
+#if HALF_ENABLE_CPP11_LONG_LONG
+ /// Nearest integer implementation; delegates to detail::half2int_up<long long>.
+ /// \param arg value to round
+ /// \return rounded value as a long long
+ static long long llround(half arg) {
+ return detail::half2int_up<long long>(arg.data_);
+ }
+
+ /// Nearest integer implementation using the rounding mode half::round_style.
+ /// \param arg value to round
+ /// \return rounded value as a long long
+ static long long llrint(half arg) {
+ return detail::half2int<half::round_style, long long>(
+ arg.data_);
+ }
+#endif
+
+ /// Decompression implementation: splits a value into significand and exponent.
+ /// \param arg number to decompress
+ /// \param exp address to store exponent at (set to 0 for zero, infinity and NaN)
+ /// \return normalized significand
+ static half frexp(half arg, int *exp)
+ {
+ int m = arg.data_ & 0x7FFF, e = -14;
+ if (m >= 0x7C00 || !m) return *exp = 0, arg; // infinity, NaN and zero are returned unchanged
+ for (; m < 0x400; m <<= 1, --e) // shift subnormals up until bit 10 is set
+ ;
+ return *exp = e + (m >> 10),
+ half(binary, (arg.data_ & 0x8000) | 0x3800 | // original sign, exponent field forced to 0x3800
+ (m & 0x3FF));
+ }
+
+ /// Decompression implementation: splits a value into integer and fractional part.
+ /// \param arg number to decompress
+ /// \param iptr address to store integer part at
+ /// \return fractional part
+ static half modf(half arg, half *iptr)
+ {
+ unsigned int e = arg.data_ & 0x7FFF; // magnitude bits, sign cleared
+ if (e >= 0x6400) // magnitude at or above 0x6400: the whole value is stored as integer part
+ return *iptr = arg,
+ half(binary,
+ arg.data_ & (0x8000U | -(e > 0x7C00))); // NaN bits pass through, otherwise only the sign is kept
+ if (e < 0x3C00) // magnitude below 0x3C00: integer part is a signed zero, fraction is arg
+ return iptr->data_ = arg.data_ & 0x8000, arg;
+ e >>= 10; // biased exponent field
+ unsigned int mask = (1 << (25 - e)) - 1, // selects the low mantissa bits that hold the fraction
+ m = arg.data_ & mask;
+ iptr->data_ = arg.data_ & ~mask;
+ if (!m) return half(binary, arg.data_ & 0x8000); // no fraction bits set: signed zero
+ for (; m < 0x400; m <<= 1, --e) // renormalize the remaining fraction bits
+ ;
+ return half(binary, static_cast<uint16>(
+ (arg.data_ & 0x8000) |
+ (e << 10) | (m & 0x3FF)));
+ }
+
+ /// Scaling implementation: applies a power-of-two scale directly on the raw bits.
+ /// \param arg number to scale
+ /// \param exp power of two to scale by (any value; clamped internally to +-64)
+ /// \return scaled number, rounded according to half::round_style
+ static half scalbln(half arg, long exp)
+ {
+ unsigned int m = arg.data_ & 0x7FFF;
+ if (m >= 0x7C00 || !m) return arg; // infinity, NaN and zero are returned unchanged
+ exp = (exp > 64) ? 64 : ((exp < -64) ? -64 : exp); // avoids signed overflow below; +-64 already lands in the overflow/underflow branches
+ for (; m < 0x400; m <<= 1, --exp) {} // normalize subnormal input
+ exp += m >> 10;
+ uint16 value = arg.data_ & 0x8000;
+ if (exp > 30) { // overflow: result depends on the rounding mode and the sign
+ if (half::round_style == std::round_toward_zero)
+ value |= 0x7BFF;
+ else if (half::round_style ==
+ std::round_toward_infinity)
+ value |= 0x7C00 - (value >> 15);
+ else if (half::round_style ==
+ std::round_toward_neg_infinity)
+ value |= 0x7BFF + (value >> 15);
+ else
+ value |= 0x7C00;
+ } else if (exp > 0) // result is a normal number
+ value |= (exp << 10) | (m & 0x3FF);
+ else if (exp > -11) { // result is subnormal: shift the mantissa right with rounding
+ m = (m & 0x3FF) | 0x400;
+ if (half::round_style == std::round_to_nearest) {
+ m += 1 << -exp;
+#if HALF_ROUND_TIES_TO_EVEN
+ m -= (m >> (1 - exp)) & 1;
+#endif
+ } else if (half::round_style ==
+ std::round_toward_infinity)
+ m +=
+ ((value >> 15) - 1) & ((1 << (1 - exp)) - 1U);
+ else if (half::round_style ==
+ std::round_toward_neg_infinity)
+ m += -(value >> 15) & ((1 << (1 - exp)) - 1U);
+ value |= m >> (1 - exp);
+ } else if (half::round_style == // complete underflow: signed zero unless a directed mode rounds away
+ std::round_toward_infinity)
+ value -= (value >> 15) - 1;
+ else if (half::round_style ==
+ std::round_toward_neg_infinity)
+ value += value >> 15;
+ return half(binary, value);
+ }
+
+ /// Integer exponent implementation.
+ /// \param arg number to query
+ /// \return unbiased exponent; FP_ILOGB0 for zero, FP_ILOGBNAN for NaN, INT_MAX for infinity
+ static int ilogb(half arg)
+ {
+ int magnitude = arg.data_ & 0x7FFF;
+ if (magnitude == 0) return FP_ILOGB0;
+ if (magnitude > 0x7C00) return FP_ILOGBNAN;
+ if (magnitude == 0x7C00) return INT_MAX;
+ int exponent = (magnitude >> 10) - 15;
+ // Only small subnormals enter the loop: shift up and lower the exponent.
+ while (magnitude < 0x200) {
+ magnitude <<= 1;
+ --exponent;
+ }
+ return exponent;
+ }
+
+ /// Exponent implementation returning the exponent as a half.
+ /// \param arg number to query
+ /// \return floating point exponent
+ static half logb(half arg)
+ {
+ int abs = arg.data_ & 0x7FFF;
+ if (!abs) return half(binary, 0xFC00); // zero: sign bit set plus the infinity pattern 0x7C00
+ if (abs < 0x7C00) {
+ int exp = (abs >> 10) - 15;
+ if (abs < 0x400)
+ for (; abs < 0x200; abs <<= 1, --exp) // subnormals: lower the exponent per leading zero bit
+ ;
+ uint16 bits = (exp < 0) << 15; // sign bit of the result
+ if (exp) {
+ unsigned int m = std::abs(exp) << 6, e = 18; // encode |exp| as a half bit pattern
+ for (; m < 0x400; m <<= 1, --e)
+ ;
+ bits |= (e << 10) + m;
+ }
+ return half(binary, bits);
+ }
+ if (abs > 0x7C00) return arg; // NaN is returned unchanged
+ return half(binary, 0x7C00); // infinity of either sign yields the pattern 0x7C00
+ }
+
+ /// Enumeration implementation: steps one representable value toward a target.
+ /// \param from number to increase/decrease
+ /// \param to direction to enumerate into
+ /// \return next representable number
+ static half nextafter(half from, half to)
+ {
+ uint16 fabs = from.data_ & 0x7FFF, tabs = to.data_ & 0x7FFF;
+ if(fabs > 0x7C00) // from is NaN: returned unchanged
+ return from;
+ if (tabs > 0x7C00 || from.data_ == to.data_ || // to is NaN, identical bits, or both are zeros
+ !(fabs | tabs))
+ return to;
+ if(!fabs) // from is zero: smallest magnitude carrying the sign of to
+ return half(binary, (to.data_ & 0x8000) + 1);
+ bool lt = ((fabs == from.data_) // compare both operands as signed magnitudes
+ ? static_cast<int>(fabs)
+ : -static_cast<int>(fabs)) <
+ ((tabs == to.data_)
+ ? static_cast<int>(tabs)
+ : -static_cast<int>(tabs));
+ return half(binary, // adds +1 or -1 to the raw bits depending on sign and direction
+ from.data_ +
+ (((from.data_ >> 15) ^
+ static_cast<unsigned>(lt))
+ << 1) -
+ 1);
+ }
+
+ /// Enumeration implementation with a long double target.
+ /// \param from number to increase/decrease
+ /// \param to direction to enumerate into
+ /// \return next representable number
+ static half nexttoward(half from, long double to)
+ {
+ if(isnan(from)) // NaN is returned unchanged
+ return from;
+ long double lfrom = static_cast<long double>(from);
+ if(builtin_isnan(to) || lfrom == to) // NaN target or equal values: return to converted to half
+ return half(static_cast<float>(to));
+ if(!(from.data_&0x7FFF)) // from is zero: smallest magnitude carrying the sign of to
+ return half(binary,
+ (static_cast<detail::uint16>(
+ builtin_signbit(to))
+ << 15) +
+ 1);
+ return half( // adds +1 or -1 to the raw bits depending on sign and direction
+ binary,
+ from.data_ +
+ (((from.data_ >> 15) ^
+ static_cast<unsigned>(lfrom < to))
+ << 1) -
+ 1);
+ }
+
+ /// Sign composition: magnitude bits of one operand, sign bit of the other.
+ /// \param x operand providing the magnitude
+ /// \param y operand providing the sign
+ /// \return composed value
+ static half copysign(half x, half y) {
+ return half(binary,
+ (x.data_ & 0x7FFF) | (y.data_ & 0x8000));
+ }
+
+ /// Classification implementation.
+ /// \param arg value to classify
+ /// \return one of FP_ZERO, FP_SUBNORMAL, FP_NORMAL,
+ /// FP_INFINITE or FP_NAN
+ static int fpclassify(half arg)
+ {
+ const unsigned int magnitude = arg.data_ & 0x7FFF;
+ // Walk the magnitude ranges from smallest to largest.
+ if (magnitude == 0) return FP_ZERO;
+ if (magnitude <= 0x3FF) return FP_SUBNORMAL;
+ if (magnitude < 0x7C00) return FP_NORMAL;
+ // Exactly 0x7C00 is infinity; anything above it
+ // is NaN.
+ if (magnitude == 0x7C00) return FP_INFINITE;
+ return FP_NAN;
+ }
+
+ /// Finiteness check on the raw bits.
+ /// \param arg value to classify
+ /// \retval true if neither infinite nor NaN
+ /// \retval false else
+ static bool isfinite(half arg) { return (arg.data_ & 0x7FFF) < 0x7C00; }
+
+ /// Infinity check on the raw bits.
+ /// \param arg value to classify
+ /// \retval true if positive or negative infinity
+ /// \retval false else
+ static bool isinf(half arg) { return (arg.data_ | 0x8000) == 0xFC00; }
+
+ /// NaN check on the raw bits.
+ /// \param arg value to classify
+ /// \retval true if the magnitude bits exceed the infinity pattern
+ /// \retval false else
+ static bool isnan(half arg) { return (arg.data_ | 0x8000) > 0xFC00; }
+
+ /// Normal-number check on the exponent field.
+ /// \param arg value to classify
+ /// \retval true if the exponent field is neither all zeros nor all ones
+ /// \retval false else
+ static bool isnormal(half arg) { const int field = arg.data_ & 0x7C00; return field != 0 && field != 0x7C00; }
+
+ /// Sign bit query.
+ /// \param arg value to check
+ /// \retval true if the sign bit is set
+ /// \retval false if the sign bit is clear
+ static bool signbit(half arg) { return (arg.data_ >> 15) != 0; }
+
+ /// Equality comparison on the raw bits.
+ /// \param x first operand
+ /// \param y second operand
+ /// \retval true if operands equal (zeros of either sign compare equal, NaN never does)
+ /// \retval false else
+ static bool isequal(half x, half y) { return !isnan(x) && (x.data_ == y.data_ || ((x.data_ | y.data_) & 0x7FFF) == 0); }
+
+ /// Inequality comparison on the raw bits.
+ /// \param x first operand
+ /// \param y second operand
+ /// \retval true if operands not equal (always true when \a x is NaN)
+ /// \retval false else
+ static bool isnotequal(half x, half y) { return isnan(x) || (x.data_ != y.data_ && ((x.data_ | y.data_) & 0x7FFF) != 0); }
+
+ /// Greater-than comparison on signed magnitudes.
+ /// \param x first operand
+ /// \param y second operand
+ /// \retval true if \a x > \a y
+ /// \retval false else (including when either operand is NaN)
+ static bool isgreater(half x, half y) {
+ const int xmag = x.data_ & 0x7FFF, ymag = y.data_ & 0x7FFF;
+ if (xmag > 0x7C00 || ymag > 0x7C00) return false;
+ const int lhs = (x.data_ & 0x8000) ? -xmag : xmag;
+ return lhs > ((y.data_ & 0x8000) ? -ymag : ymag);
+ }
+
+ /// Greater-or-equal comparison on signed magnitudes.
+ /// \param x first operand
+ /// \param y second operand
+ /// \retval true if \a x >= \a y
+ /// \retval false else (including when either operand is NaN)
+ static bool isgreaterequal(half x, half y) {
+ const int xmag = x.data_ & 0x7FFF, ymag = y.data_ & 0x7FFF;
+ if (xmag > 0x7C00 || ymag > 0x7C00) return false;
+ const int lhs = (x.data_ & 0x8000) ? -xmag : xmag;
+ return lhs >= ((y.data_ & 0x8000) ? -ymag : ymag);
+ }
+
+ /// Less-than comparison on signed magnitudes.
+ /// \param x first operand
+ /// \param y second operand
+ /// \retval true if \a x < \a y
+ /// \retval false else (including when either operand is NaN)
+ static bool isless(half x, half y) {
+ const int xmag = x.data_ & 0x7FFF, ymag = y.data_ & 0x7FFF;
+ if (xmag > 0x7C00 || ymag > 0x7C00) return false;
+ const int lhs = (x.data_ & 0x8000) ? -xmag : xmag;
+ return lhs < ((y.data_ & 0x8000) ? -ymag : ymag);
+ }
+
+ /// Less-or-equal comparison on signed magnitudes.
+ /// \param x first operand
+ /// \param y second operand
+ /// \retval true if \a x <= \a y
+ /// \retval false else (including when either operand is NaN)
+ static bool islessequal(half x, half y) {
+ const int xmag = x.data_ & 0x7FFF, ymag = y.data_ & 0x7FFF;
+ if (xmag > 0x7C00 || ymag > 0x7C00) return false;
+ const int lhs = (x.data_ & 0x8000) ? -xmag : xmag;
+ return lhs <= ((y.data_ & 0x8000) ? -ymag : ymag);
+ }
+
+ /// Less-or-greater comparison on signed magnitudes.
+ /// \param x first operand
+ /// \param y second operand
+ /// \retval true if either \a x > \a y or \a x < \a y
+ /// \retval false else (including when either operand is NaN)
+ static bool islessgreater(half x, half y) {
+ const int xmag = x.data_ & 0x7FFF, ymag = y.data_ & 0x7FFF;
+ const int lhs = (x.data_ & 0x8000) ? -xmag : xmag;
+ const int rhs = (y.data_ & 0x8000) ? -ymag : ymag;
+ return xmag <= 0x7C00 && ymag <= 0x7C00 &&
+ lhs != rhs;
+ }
+
+ /// Unordered check: tests both operands for NaN on the raw bits.
+ /// \param x first operand
+ /// \param y second operand
+ /// \retval true if at least one operand is NaN
+ /// \retval false else
+ static bool isunordered(half x, half y) { return (x.data_ & 0x7FFF) > 0x7C00 || (y.data_ & 0x7FFF) > 0x7C00; }
+
+ private:
+ static double erf(double arg) { // closed-form approximation of the error function in double precision
+ if (builtin_isinf(arg)) // infinite argument: -1 or +1 depending on the sign
+ return (arg < 0.0) ? -1.0 : 1.0;
+ double x2 = arg * arg, ax2 = 0.147 * x2, // value = sqrt(1 - exp(-x2 * (c + ax2) / (1 + ax2))) with the constant c below
+ value = std::sqrt(
+ 1.0 -
+ std::exp(
+ -x2 *
+ (1.2732395447351626861510701069801 +
+ ax2) /
+ (1.0 + ax2)));
+ return builtin_signbit(arg) ? -value : value; // the result carries the sign of the argument
+ }
+
+ static double lgamma(double arg) { // double-precision helper used by the gamma fallbacks
+ double v = 1.0;
+ for (; arg < 8.0; ++arg) v *= arg; // raise arg to at least 8, collecting the skipped factors in v
+ double w = 1.0 / (arg * arg);
+ return (((((((-0.02955065359477124183006535947712 * // polynomial in w, evaluated in nested (Horner) form
+ w +
+ 0.00641025641025641025641025641026) *
+ w +
+ -0.00191752691752691752691752691753) *
+ w +
+ 8.4175084175084175084175084175084e-4) *
+ w +
+ -5.952380952380952380952380952381e-4) *
+ w +
+ 7.9365079365079365079365079365079e-4) *
+ w +
+ -0.00277777777777777777777777777778) *
+ w +
+ 0.08333333333333333333333333333333) /
+ arg +
+ 0.91893853320467274178032973640562 -
+ std::log(v) - arg + // log(v) removes the factors collected by the loop above
+ (arg - 0.5) * std::log(arg);
+ }
+ };
+
+ /// Wrapper for unary half-precision functions needing specialization for individual argument types; this primary template operates on the raw bits of a half.
+ /// \tparam T argument type
+ template<typename T> struct unary_specialized
+ {
+ /// Negation implementation: flips the sign bit.
+ /// \param arg value to negate
+ /// \return negated value
+ static HALF_CONSTEXPR half negate(half arg) {
+ return half(binary, arg.data_ ^ 0x8000);
+ }
+
+ /// Absolute value implementation: clears the sign bit.
+ /// \param arg function argument
+ /// \return absolute value
+ static half fabs(half arg) {
+ return half(binary, arg.data_ & 0x7FFF);
+ }
+ };
+ template<> struct unary_specialized<expr> // specialization for expr arguments: works on the float value
+ {
+ static HALF_CONSTEXPR expr negate(float arg) { return expr(-arg); } // negation via float arithmetic
+ static expr fabs(float arg) { return expr(std::fabs(arg)); } // absolute value via std::fabs
+ };
+
+ /// Wrapper for binary half-precision functions needing specialization for individual argument types; this primary template works on float operands.
+ /// \tparam T first argument type
+ /// \tparam U second argument type
+ template<typename T,typename U> struct binary_specialized
+ {
+ /// Minimum implementation; the fallback path returns the other operand when one is NaN.
+ /// \param x first operand
+ /// \param y second operand
+ /// \return minimum value
+ static expr fmin(float x, float y)
+ {
+ #if HALF_ENABLE_CPP11_CMATH
+ return expr(std::fmin(x, y));
+ #else
+ if (builtin_isnan(x)) return expr(y);
+ if (builtin_isnan(y)) return expr(x);
+ return expr(std::min(x, y));
+#endif
+ }
+
+ /// Maximum implementation; the fallback path returns the other operand when one is NaN.
+ /// \param x first operand
+ /// \param y second operand
+ /// \return maximum value
+ static expr fmax(float x, float y)
+ {
+ #if HALF_ENABLE_CPP11_CMATH
+ return expr(std::fmax(x, y));
+ #else
+ if (builtin_isnan(x)) return expr(y);
+ if (builtin_isnan(y)) return expr(x);
+ return expr(std::max(x, y));
+#endif
+ }
+ };
+ template<> struct binary_specialized<half,half> // fmin/fmax on the raw bits of two halfs
+ {
+ static half fmin(half x, half y) {
+ const int xmag = x.data_ & 0x7FFF, ymag = y.data_ & 0x7FFF;
+ if (xmag > 0x7C00) return y; // x is NaN
+ if (ymag > 0x7C00) return x; // y is NaN
+ const int xval = (x.data_ & 0x8000) ? -xmag : xmag;
+ const int yval = (y.data_ & 0x8000) ? -ymag : ymag;
+ if (xval > yval) return y;
+ return x;
+ }
+ static half fmax(half x, half y) {
+ const int xmag = x.data_ & 0x7FFF, ymag = y.data_ & 0x7FFF;
+ if (xmag > 0x7C00) return y; // x is NaN
+ if (ymag > 0x7C00) return x; // y is NaN
+ const int xval = (x.data_ & 0x8000) ? -xmag : xmag;
+ const int yval = (y.data_ & 0x8000) ? -ymag : ymag;
+ if (xval < yval) return y;
+ return x;
+ }
+ };
+
+ /// Helper class for half casts.
+ /// The primary template is empty; each valid combination of cast
+ /// arguments is handled by a specialization that defines a
+ /// static `cast` member function.
+ /// \tparam T destination type \tparam U source
+ /// type \tparam R rounding mode to use (defaults to HALF_ROUND_STYLE)
+ template <typename T, typename U,
+ std::float_round_style R =
+ (std::float_round_style)(HALF_ROUND_STYLE)>
+ struct half_caster {};
+ template <typename U, std::float_round_style R>
+ struct half_caster<half, U, R> { // cast from an arithmetic type U to half with rounding mode R
+#if HALF_ENABLE_CPP11_STATIC_ASSERT && HALF_ENABLE_CPP11_TYPE_TRAITS
+ static_assert(
+ std::is_arithmetic<U>::value,
+ "half_cast from non-arithmetic type unsupported");
+#endif
+
+ static half cast(U arg) { // dispatches on is_float<U>
+ return cast_impl(arg, is_float<U>());
+ }
+
+ private:
+ static half cast_impl(U arg, true_type) { // floating-point source: float2half<R>
+ return half(binary, float2half<R>(arg));
+ }
+ static half cast_impl(U arg, false_type) { // any other source: int2half<R>
+ return half(binary, int2half<R>(arg));
+ }
+ };
+ template<typename T,std::float_round_style R> struct half_caster<T,half,R> // cast from half to an arithmetic type T
+ {
+#if HALF_ENABLE_CPP11_STATIC_ASSERT && HALF_ENABLE_CPP11_TYPE_TRAITS
+ static_assert(std::is_arithmetic<T>::value,
+ "half_cast to non-arithmetic type unsupported");
+#endif
+
+ static T cast(half arg) { // dispatches on is_float<T>
+ return cast_impl(arg, is_float<T>());
+ }
+
+ private:
+ static T cast_impl(half arg, true_type) { // floating-point destination: half2float<T>
+ return half2float<T>(arg.data_);
+ }
+ static T cast_impl(half arg, false_type) { // any other destination: half2int with rounding mode R
+ return half2int<R, T>(arg.data_);
+ }
+ };
+ template <typename T, std::float_round_style R>
+ struct half_caster<T, expr, R> { // cast from a half expression to an arithmetic type T
+#if HALF_ENABLE_CPP11_STATIC_ASSERT && HALF_ENABLE_CPP11_TYPE_TRAITS
+ static_assert(std::is_arithmetic<T>::value,
+ "half_cast to non-arithmetic type unsupported");
+#endif
+
+ static T cast(expr arg) { // dispatches on is_float<T>
+ return cast_impl(arg, is_float<T>());
+ }
+
+ private:
+ static T cast_impl(float arg, true_type) { // floating-point destination: plain static_cast from float
+ return static_cast<T>(arg);
+ }
+ static T cast_impl(half arg, false_type) { // other destination: takes a half (presumably converted from the expr - confirm), then half2int
+ return half2int<R, T>(arg.data_);
+ }
+ };
+ template <std::float_round_style R>
+ struct half_caster<half, half, R> { // half to half: identity
+ static half cast(half arg) { return arg; } // returns the argument unchanged
+ };
+ template <std::float_round_style R>
+ struct half_caster<half, expr, R> : half_caster<half, half, R> { // expr to half: reuses the half-to-half cast
+ };
+
+ /// \name Comparison operators
+ /// \{
+
+ /// Comparison for equality; forwards to functions::isequal.
+ /// \param x first operand
+ /// \param y second operand
+ /// \retval true if operands equal
+ /// \retval false else
+ template<typename T,typename U> typename enable<bool,T,U>::type operator==(T x, U y) { return functions::isequal(x, y); }
+
+ /// Comparison for inequality; forwards to functions::isnotequal.
+ /// \param x first operand
+ /// \param y second operand
+ /// \retval true if operands not equal
+ /// \retval false else
+ template<typename T,typename U> typename enable<bool,T,U>::type operator!=(T x, U y) { return functions::isnotequal(x, y); }
+
+ /// Comparison for less than; forwards to functions::isless.
+ /// \param x first operand
+ /// \param y second operand
+ /// \retval true if \a x less than \a y
+ /// \retval false else
+ template<typename T,typename U> typename enable<bool,T,U>::type operator<(T x, U y) { return functions::isless(x, y); }
+
+ /// Comparison for greater than; forwards to functions::isgreater.
+ /// \param x first operand
+ /// \param y second operand
+ /// \retval true if \a x greater than \a y
+ /// \retval false else
+ template<typename T,typename U> typename enable<bool,T,U>::type operator>(T x, U y) { return functions::isgreater(x, y); }
+
+ /// Comparison for less than or equal; forwards to functions::islessequal.
+ /// \param x first operand
+ /// \param y second operand
+ /// \retval true if \a x less than or equal to \a y
+ /// \retval false else
+ template<typename T,typename U> typename enable<bool,T,U>::type operator<=(T x, U y) { return functions::islessequal(x, y); }
+
+ /// Comparison for greater than or equal; forwards to functions::isgreaterequal.
+ /// \param x first operand
+ /// \param y second operand
+ /// \retval true if \a x greater than or equal to \a y
+ /// \retval false else
+ template<typename T,typename U> typename enable<bool,T,U>::type operator>=(T x, U y) { return functions::isgreaterequal(x, y); }
+
+ /// \}
+ /// \name Arithmetic operators
+ /// \{
+
+ /// Add halfs; forwards to functions::plus.
+ /// \param x left operand
+ /// \param y right operand
+ /// \return sum of half expressions
+ template<typename T,typename U> typename enable<expr,T,U>::type operator+(T x, U y) { return functions::plus(x, y); }
+
+ /// Subtract halfs; forwards to functions::minus.
+ /// \param x left operand
+ /// \param y right operand
+ /// \return difference of half expressions
+ template<typename T,typename U> typename enable<expr,T,U>::type operator-(T x, U y) { return functions::minus(x, y); }
+
+ /// Multiply halfs; forwards to functions::multiplies.
+ /// \param x left operand
+ /// \param y right operand
+ /// \return product of half expressions
+ template<typename T,typename U> typename enable<expr,T,U>::type operator*(T x, U y) { return functions::multiplies(x, y); }
+
+ /// Divide halfs; forwards to functions::divides.
+ /// \param x left operand
+ /// \param y right operand
+ /// \return quotient of half expressions
+ template<typename T,typename U> typename enable<expr,T,U>::type operator/(T x, U y) { return functions::divides(x, y); }
+
+ /// Identity (unary plus).
+ /// \param arg operand
+ /// \return unchanged operand
+ template<typename T> HALF_CONSTEXPR typename enable<T,T>::type operator+(T arg) { return arg; }
+
+ /// Negation (unary minus); forwards to unary_specialized<T>::negate.
+ /// \param arg operand
+ /// \return negated operand
+ template<typename T> HALF_CONSTEXPR typename enable<T,T>::type operator-(T arg) { return unary_specialized<T>::negate(arg); }
+
+ /// \}
+ /// \name Input and output
+ /// \{
+
+ /// Output operator; forwards to functions::write.
+ /// \param out output stream to write into
+ /// \param arg half expression to write
+ /// \return reference to output stream
+ template<typename T,typename charT,typename traits> typename enable<std::basic_ostream<charT,traits>&,T>::type
+ operator<<(std::basic_ostream<charT,traits> &out, T arg) { return functions::write(out, arg); }
+
+ /// Input operator; forwards to functions::read.
+ /// \param in input stream to read from
+ /// \param arg half to read into
+ /// \return reference to input stream
+ template<typename charT,typename traits> std::basic_istream<charT,traits>&
+ operator>>(std::basic_istream<charT,traits> &in, half &arg) { return functions::read(in, arg); }
+
+ /// \}
+ /// \name Basic mathematical operations
+ /// \{
+
+ /// Absolute value; explicit overloads forwarding to unary_specialized<>::fabs.
+ /// \param arg operand
+ /// \return absolute value of \a arg
+// template<typename T> typename enable<T,T>::type abs(T arg) { return unary_specialized<T>::fabs(arg); }
+ inline half abs(half arg) { return unary_specialized<half>::fabs(arg); }
+ inline expr abs(expr arg) { return unary_specialized<expr>::fabs(arg); }
+
+ /// Absolute value; explicit overloads forwarding to unary_specialized<>::fabs.
+ /// \param arg operand
+ /// \return absolute value of \a arg
+// template<typename T> typename enable<T,T>::type fabs(T arg) { return unary_specialized<T>::fabs(arg); }
+ inline half fabs(half arg) { return unary_specialized<half>::fabs(arg); }
+ inline expr fabs(expr arg) { return unary_specialized<expr>::fabs(arg); }
+
+ /// Remainder of division; every overload forwards to functions::fmod.
+ /// \param x first operand
+ /// \param y second operand
+ /// \return remainder of floating point division.
+// template<typename T,typename U> typename enable<expr,T,U>::type fmod(T x, U y) { return functions::fmod(x, y); }
+ inline expr fmod(half x, half y) { return functions::fmod(x, y); }
+ inline expr fmod(half x, expr y) { return functions::fmod(x, y); }
+ inline expr fmod(expr x, half y) { return functions::fmod(x, y); }
+ inline expr fmod(expr x, expr y) { return functions::fmod(x, y); }
+
+ /// Remainder of division; every overload forwards to functions::remainder.
+ /// \param x first operand
+ /// \param y second operand
+ /// \return remainder of floating point division.
+ // template<typename T,typename U> typename
+ // enable<expr,T,U>::type remainder(T x, U y) { return
+ // functions::remainder(x, y); }
+ inline expr remainder(half x, half y) {
+ return functions::remainder(x, y);
+ }
+ inline expr remainder(half x, expr y) {
+ return functions::remainder(x, y);
+ }
+ inline expr remainder(expr x, half y) {
+ return functions::remainder(x, y);
+ }
+ inline expr remainder(expr x, expr y) {
+ return functions::remainder(x, y);
+ }
+
+ /// Remainder of division with quotient bits; every overload forwards to functions::remquo.
+ /// \param x first operand
+ /// \param y second operand
+ /// \param quo address to store some bits of quotient at
+ /// \return remainder of floating point division.
+ // template<typename T,typename U> typename
+ // enable<expr,T,U>::type remquo(T x, U y, int *quo) { return
+ // functions::remquo(x, y, quo); }
+ inline expr remquo(half x, half y, int *quo) {
+ return functions::remquo(x, y, quo);
+ }
+ inline expr remquo(half x, expr y, int *quo) {
+ return functions::remquo(x, y, quo);
+ }
+ inline expr remquo(expr x, half y, int *quo) {
+ return functions::remquo(x, y, quo);
+ }
+ inline expr remquo(expr x, expr y, int *quo) {
+ return functions::remquo(x, y, quo);
+ }
+
+ /// Fused multiply add; every overload forwards to functions::fma.
+ /// \param x first operand
+ /// \param y second operand
+ /// \param z third operand
+ /// \return ( \a x * \a y ) + \a z rounded as one operation.
+ // template<typename T,typename U,typename V>
+ //typename enable<expr,T,U,V>::type fma(T x, U y, V z) { return
+ //functions::fma(x, y, z); }
+ inline expr fma(half x, half y, half z) { return functions::fma(x, y, z); }
+ inline expr fma(half x, half y, expr z) { return functions::fma(x, y, z); }
+ inline expr fma(half x, expr y, half z) { return functions::fma(x, y, z); }
+ inline expr fma(half x, expr y, expr z) { return functions::fma(x, y, z); }
+ inline expr fma(expr x, half y, half z) { return functions::fma(x, y, z); }
+ inline expr fma(expr x, half y, expr z) { return functions::fma(x, y, z); }
+ inline expr fma(expr x, expr y, half z) { return functions::fma(x, y, z); }
+ inline expr fma(expr x, expr y, expr z) { return functions::fma(x, y, z); }
+
+ /// Maximum of half expressions; forwards to the matching binary_specialized<>::fmax.
+ /// \param x first operand
+ /// \param y second operand
+ /// \return maximum of operands (a half only when both operands are halfs)
+// template<typename T,typename U> typename result<T,U>::type fmax(T x, U y) { return binary_specialized<T,U>::fmax(x, y); }
+ inline half fmax(half x, half y) { return binary_specialized<half,half>::fmax(x, y); }
+ inline expr fmax(half x, expr y) { return binary_specialized<half,expr>::fmax(x, y); }
+ inline expr fmax(expr x, half y) { return binary_specialized<expr,half>::fmax(x, y); }
+ inline expr fmax(expr x, expr y) { return binary_specialized<expr,expr>::fmax(x, y); }
+
+ /// Minimum of half expressions; forwards to the matching binary_specialized<>::fmin.
+ /// \param x first operand
+ /// \param y second operand
+ /// \return minimum of operands (a half only when both operands are halfs)
+// template<typename T,typename U> typename result<T,U>::type fmin(T x, U y) { return binary_specialized<T,U>::fmin(x, y); }
+ inline half fmin(half x, half y) { return binary_specialized<half,half>::fmin(x, y); }
+ inline expr fmin(half x, expr y) { return binary_specialized<half,expr>::fmin(x, y); }
+ inline expr fmin(expr x, half y) { return binary_specialized<expr,half>::fmin(x, y); }
+ inline expr fmin(expr x, expr y) { return binary_specialized<expr,expr>::fmin(x, y); }
+
+ /// Positive difference; every overload forwards to functions::fdim.
+ /// \param x first operand
+ /// \param y second operand
+ /// \return \a x - \a y or 0 if difference negative
+// template<typename T,typename U> typename enable<expr,T,U>::type fdim(T x, U y) { return functions::fdim(x, y); }
+ inline expr fdim(half x, half y) { return functions::fdim(x, y); }
+ inline expr fdim(half x, expr y) { return functions::fdim(x, y); }
+ inline expr fdim(expr x, half y) { return functions::fdim(x, y); }
+ inline expr fdim(expr x, expr y) { return functions::fdim(x, y); }
+
+ /// Get NaN value; the string argument is unnamed and ignored.
+ /// \return quiet NaN
+ inline half nanh(const char *) { return functions::nanh(); }
+
+ /// \}
+ /// \name Exponential functions
+ /// \{
+
+ /// Exponential function; forwards to functions::exp.
+ /// \param arg function argument
+ /// \return e raised to \a arg
+// template<typename T> typename enable<expr,T>::type exp(T arg) { return functions::exp(arg); }
+ inline expr exp(half arg) { return functions::exp(arg); }
+ inline expr exp(expr arg) { return functions::exp(arg); }
+
+ /// Exponential minus one; forwards to functions::expm1.
+ /// \param arg function argument
+ /// \return e raised to \a arg, minus 1
+// template<typename T> typename enable<expr,T>::type expm1(T arg) { return functions::expm1(arg); }
+ inline expr expm1(half arg) { return functions::expm1(arg); }
+ inline expr expm1(expr arg) { return functions::expm1(arg); }
+
+ /// Binary exponential; forwards to functions::exp2.
+ /// \param arg function argument
+ /// \return 2 raised to \a arg
+// template<typename T> typename enable<expr,T>::type exp2(T arg) { return functions::exp2(arg); }
+ inline expr exp2(half arg) { return functions::exp2(arg); }
+ inline expr exp2(expr arg) { return functions::exp2(arg); }
+
+ /// Natural logarithm; forwards to functions::log.
+ /// \param arg function argument
+ /// \return logarithm of \a arg to base e
+// template<typename T> typename enable<expr,T>::type log(T arg) { return functions::log(arg); }
+ inline expr log(half arg) { return functions::log(arg); }
+ inline expr log(expr arg) { return functions::log(arg); }
+
+ /// Common logarithm; forwards to functions::log10.
+ /// \param arg function argument
+ /// \return logarithm of \a arg to base 10
+// template<typename T> typename enable<expr,T>::type log10(T arg) { return functions::log10(arg); }
+ inline expr log10(half arg) { return functions::log10(arg); }
+ inline expr log10(expr arg) { return functions::log10(arg); }
+
+ /// Natural logarithm of one plus the argument; forwards to functions::log1p.
+ /// \param arg function argument
+ /// \return logarithm of \a arg plus 1 to base e
+// template<typename T> typename enable<expr,T>::type log1p(T arg) { return functions::log1p(arg); }
+ inline expr log1p(half arg) { return functions::log1p(arg); }
+ inline expr log1p(expr arg) { return functions::log1p(arg); }
+
+ /// Binary logarithm; forwards to functions::log2.
+ /// \param arg function argument
+ /// \return logarithm of \a arg to base 2
+// template<typename T> typename enable<expr,T>::type log2(T arg) { return functions::log2(arg); }
+ inline expr log2(half arg) { return functions::log2(arg); }
+ inline expr log2(expr arg) { return functions::log2(arg); }
+
+ /// \}
+ /// \name Power functions
+ /// \{
+
+ /// Square root; forwards to functions::sqrt.
+ /// \param arg function argument
+ /// \return square root of \a arg
+// template<typename T> typename enable<expr,T>::type sqrt(T arg) { return functions::sqrt(arg); }
+ inline expr sqrt(half arg) { return functions::sqrt(arg); }
+ inline expr sqrt(expr arg) { return functions::sqrt(arg); }
+
+ /// Cube root; forwards to functions::cbrt.
+ /// \param arg function argument
+ /// \return cube root of \a arg
+ // template<typename T> typename
+ //enable<expr,T>::type cbrt(T arg) { return
+ // functions::cbrt(arg); }
+ inline expr cbrt(half arg) { return functions::cbrt(arg); }
+ inline expr cbrt(expr arg) { return functions::cbrt(arg); }
+
+ /// Hypotenuse function; every overload forwards to functions::hypot.
+ /// \param x first argument
+ /// \param y second argument
+ /// \return square root of sum of squares without internal over-
+ /// or underflows
+ // template<typename T,typename U> typename
+ //enable<expr,T,U>::type hypot(T x, U y) { return
+ //functions::hypot(x, y); }
+ inline expr hypot(half x, half y) { return functions::hypot(x, y); }
+ inline expr hypot(half x, expr y) { return functions::hypot(x, y); }
+ inline expr hypot(expr x, half y) { return functions::hypot(x, y); }
+ inline expr hypot(expr x, expr y) { return functions::hypot(x, y); }
+
+ /// Power function; every overload forwards to functions::pow.
+ /// \param base first argument
+ /// \param exp second argument
+ /// \return \a base raised to \a exp
+// template<typename T,typename U> typename enable<expr,T,U>::type pow(T base, U exp) { return functions::pow(base, exp); }
+ inline expr pow(half base, half exp) { return functions::pow(base, exp); }
+ inline expr pow(half base, expr exp) { return functions::pow(base, exp); }
+ inline expr pow(expr base, half exp) { return functions::pow(base, exp); }
+ inline expr pow(expr base, expr exp) { return functions::pow(base, exp); }
+
+ /// \}
+ /// \name Trigonometric functions
+ /// \{
+
+ /// Sine function; forwards to functions::sin.
+ /// \param arg function argument
+ /// \return sine value of \a arg
+// template<typename T> typename enable<expr,T>::type sin(T arg) { return functions::sin(arg); }
+ inline expr sin(half arg) { return functions::sin(arg); }
+ inline expr sin(expr arg) { return functions::sin(arg); }
+
+ /// Cosine function; forwards to functions::cos.
+ /// \param arg function argument
+ /// \return cosine value of \a arg
+// template<typename T> typename enable<expr,T>::type cos(T arg) { return functions::cos(arg); }
+ inline expr cos(half arg) { return functions::cos(arg); }
+ inline expr cos(expr arg) { return functions::cos(arg); }
+
+ /// Tangent function; forwards to functions::tan.
+ /// \param arg function argument
+ /// \return tangent value of \a arg
+// template<typename T> typename enable<expr,T>::type tan(T arg) { return functions::tan(arg); }
+ inline expr tan(half arg) { return functions::tan(arg); }
+ inline expr tan(expr arg) { return functions::tan(arg); }
+
+ /// Arc sine; forwards to functions::asin.
+ /// \param arg function argument
+ /// \return arc sine value of \a arg
+// template<typename T> typename enable<expr,T>::type asin(T arg) { return functions::asin(arg); }
+ inline expr asin(half arg) { return functions::asin(arg); }
+ inline expr asin(expr arg) { return functions::asin(arg); }
+
+ /// Arc cosine function; forwards to functions::acos.
+ /// \param arg function argument
+ /// \return arc cosine value of \a arg
+// template<typename T> typename enable<expr,T>::type acos(T arg) { return functions::acos(arg); }
+ inline expr acos(half arg) { return functions::acos(arg); }
+ inline expr acos(expr arg) { return functions::acos(arg); }
+
+ /// Arc tangent function; forwards to functions::atan.
+ /// \param arg function argument
+ /// \return arc tangent value of \a arg
+// template<typename T> typename enable<expr,T>::type atan(T arg) { return functions::atan(arg); }
+ inline expr atan(half arg) { return functions::atan(arg); }
+ inline expr atan(expr arg) { return functions::atan(arg); }
+
+ /// Two-argument arc tangent function; every overload forwards to functions::atan2.
+ /// \param x first argument
+ /// \param y second argument
+ /// \return arc tangent value
+// template<typename T,typename U> typename enable<expr,T,U>::type atan2(T x, U y) { return functions::atan2(x, y); }
+ inline expr atan2(half x, half y) { return functions::atan2(x, y); }
+ inline expr atan2(half x, expr y) { return functions::atan2(x, y); }
+ inline expr atan2(expr x, half y) { return functions::atan2(x, y); }
+ inline expr atan2(expr x, expr y) { return functions::atan2(x, y); }
+
+ /// \}
+ /// \name Hyperbolic functions
+ /// \{
+
+ /// Hyperbolic sine; forwards to functions::sinh.
+ /// \param arg function argument
+ /// \return hyperbolic sine value of \a arg
+// template<typename T> typename enable<expr,T>::type sinh(T arg) { return functions::sinh(arg); }
+ inline expr sinh(half arg) { return functions::sinh(arg); }
+ inline expr sinh(expr arg) { return functions::sinh(arg); }
+
+ /// Hyperbolic cosine.
+ /// \param arg function argument
+ /// \return hyperbolic cosine value of \a arg
+// template<typename T> typename enable<expr,T>::type cosh(T arg) { return functions::cosh(arg); }
+ inline expr cosh(half arg) { return functions::cosh(arg); }
+ inline expr cosh(expr arg) { return functions::cosh(arg); }
+
+ /// Hyperbolic tangent.
+ /// \param arg function argument
+ /// \return hyperbolic tangent value of \a arg
+// template<typename T> typename enable<expr,T>::type tanh(T arg) { return functions::tanh(arg); }
+ inline expr tanh(half arg) { return functions::tanh(arg); }
+ inline expr tanh(expr arg) { return functions::tanh(arg); }
+
+ /// Hyperbolic area sine.
+ /// \param arg function argument
+ /// \return area sine value of \a arg
+// template<typename T> typename enable<expr,T>::type asinh(T arg) { return functions::asinh(arg); }
+ inline expr asinh(half arg) { return functions::asinh(arg); }
+ inline expr asinh(expr arg) { return functions::asinh(arg); }
+
+ /// Hyperbolic area cosine.
+ /// \param arg function argument
+ /// \return area cosine value of \a arg
+// template<typename T> typename enable<expr,T>::type acosh(T arg) { return functions::acosh(arg); }
+ inline expr acosh(half arg) { return functions::acosh(arg); }
+ inline expr acosh(expr arg) { return functions::acosh(arg); }
+
+ /// Hyperbolic area tangent.
+ /// \param arg function argument
+ /// \return area tangent value of \a arg
+// template<typename T> typename enable<expr,T>::type atanh(T arg) { return functions::atanh(arg); }
+ inline expr atanh(half arg) { return functions::atanh(arg); }
+ inline expr atanh(expr arg) { return functions::atanh(arg); }
+
+ /// \}
+ /// \name Error and gamma functions
+ /// \{
+
+ /// Error function.
+ /// \param arg function argument
+ /// \return error function value of \a arg
+// template<typename T> typename enable<expr,T>::type erf(T arg) { return functions::erf(arg); }
+ inline expr erf(half arg) { return functions::erf(arg); }
+ inline expr erf(expr arg) { return functions::erf(arg); }
+
+ /// Complementary error function.
+ /// \param arg function argument
+ /// \return 1 minus error function value of \a arg
+// template<typename T> typename enable<expr,T>::type erfc(T arg) { return functions::erfc(arg); }
+ inline expr erfc(half arg) { return functions::erfc(arg); }
+ inline expr erfc(expr arg) { return functions::erfc(arg); }
+
+ /// Natural logarithm of gamma function.
+ /// \param arg function argument
+ /// \return natural logarithm of gamma function for \a arg
+// template<typename T> typename enable<expr,T>::type lgamma(T arg) { return functions::lgamma(arg); }
+ inline expr lgamma(half arg) { return functions::lgamma(arg); }
+ inline expr lgamma(expr arg) { return functions::lgamma(arg); }
+
+ /// Gamma function.
+ /// \param arg function argument
+ /// \return gamma function value of \a arg
+// template<typename T> typename enable<expr,T>::type tgamma(T arg) { return functions::tgamma(arg); }
+ inline expr tgamma(half arg) { return functions::tgamma(arg); }
+ inline expr tgamma(expr arg) { return functions::tgamma(arg); }
+
+ /// \}
+ /// \name Rounding
+ /// \{
+
+ /// Nearest integer not less than half value.
+ /// \param arg half to round
+ /// \return nearest integer not less than \a arg
+// template<typename T> typename enable<half,T>::type ceil(T arg) { return functions::ceil(arg); }
+ inline half ceil(half arg) { return functions::ceil(arg); }
+ inline half ceil(expr arg) { return functions::ceil(arg); }
+
+ /// Nearest integer not greater than half value.
+ /// \param arg half to round
+ /// \return nearest integer not greater than \a arg
+// template<typename T> typename enable<half,T>::type floor(T arg) { return functions::floor(arg); }
+ inline half floor(half arg) { return functions::floor(arg); }
+ inline half floor(expr arg) { return functions::floor(arg); }
+
+ /// Nearest integer not greater in magnitude than half value.
+ /// \param arg half to round
+ /// \return nearest integer not greater in magnitude than \a arg
+// template<typename T> typename enable<half,T>::type trunc(T arg) { return functions::trunc(arg); }
+ inline half trunc(half arg) { return functions::trunc(arg); }
+ inline half trunc(expr arg) { return functions::trunc(arg); }
+
+ /// Nearest integer.
+ /// \param arg half to round
+ /// \return nearest integer, rounded away from zero in half-way cases
+// template<typename T> typename enable<half,T>::type round(T arg) { return functions::round(arg); }
+ inline half round(half arg) { return functions::round(arg); }
+ inline half round(expr arg) { return functions::round(arg); }
+
+ /// Nearest integer.
+ /// \param arg half to round
+ /// \return nearest integer, rounded away from zero in half-way cases
+// template<typename T> typename enable<long,T>::type lround(T arg) { return functions::lround(arg); }
+ inline long lround(half arg) { return functions::lround(arg); }
+ inline long lround(expr arg) { return functions::lround(arg); }
+
+ /// Nearest integer using half's internal rounding mode.
+ /// \param arg half expression to round
+ /// \return nearest integer using default rounding mode
+ // template<typename T> typename
+ //enable<half,T>::type nearbyint(T arg) { return
+ // functions::nearbyint(arg); }
+ inline half nearbyint(half arg) { return functions::rint(arg); }
+ inline half nearbyint(expr arg) { return functions::rint(arg); }
+
+ /// Nearest integer using half's internal rounding mode.
+ /// \param arg half expression to round
+ /// \return nearest integer using default rounding mode
+ // template<typename T> typename
+ //enable<half,T>::type rint(T arg) { return
+ // functions::rint(arg); }
+ inline half rint(half arg) { return functions::rint(arg); }
+ inline half rint(expr arg) { return functions::rint(arg); }
+
+ /// Nearest integer using half's internal rounding mode.
+ /// \param arg half expression to round
+ /// \return nearest integer using default rounding mode
+ // template<typename T> typename
+ //enable<long,T>::type lrint(T arg) { return
+ // functions::lrint(arg); }
+ inline long lrint(half arg) { return functions::lrint(arg); }
+ inline long lrint(expr arg) { return functions::lrint(arg); }
+ #if HALF_ENABLE_CPP11_LONG_LONG
+ /// Nearest integer.
+ /// \param arg half to round
+ /// \return nearest integer, rounded away from zero in half-way cases
+// template<typename T> typename enable<long long,T>::type llround(T arg) { return functions::llround(arg); }
+ inline long long llround(half arg) { return functions::llround(arg); }
+ inline long long llround(expr arg) { return functions::llround(arg); }
+
+ /// Nearest integer using half's internal rounding mode.
+ /// \param arg half expression to round
+ /// \return nearest integer using default rounding mode
+ // template<typename T> typename enable<long
+ // long,T>::type llrint(T arg) { return functions::llrint(arg);
+ // }
+ inline long long llrint(half arg) {
+ return functions::llrint(arg);
+ }
+ inline long long llrint(expr arg) {
+ return functions::llrint(arg);
+ }
+#endif
+
+ /// \}
+ /// \name Floating point manipulation
+ /// \{
+
+ /// Decompose floating point number.
+ /// \param arg number to decompose
+ /// \param exp address to store exponent at
+ /// \return significand in range [0.5, 1)
+// template<typename T> typename enable<half,T>::type frexp(T arg, int *exp) { return functions::frexp(arg, exp); }
+ inline half frexp(half arg, int *exp) { return functions::frexp(arg, exp); }
+ inline half frexp(expr arg, int *exp) { return functions::frexp(arg, exp); }
+
+ /// Multiply by power of two.
+ /// \param arg number to modify
+ /// \param exp power of two to multiply with
+ /// \return \a arg multiplied by 2 raised to \a exp
+// template<typename T> typename enable<half,T>::type ldexp(T arg, int exp) { return functions::scalbln(arg, exp); }
+ inline half ldexp(half arg, int exp) { return functions::scalbln(arg, exp); }
+ inline half ldexp(expr arg, int exp) { return functions::scalbln(arg, exp); }
+
+ /// Extract integer and fractional parts.
+ /// \param arg number to decompose
+ /// \param iptr address to store integer part at
+ /// \return fractional part
+// template<typename T> typename enable<half,T>::type modf(T arg, half *iptr) { return functions::modf(arg, iptr); }
+ inline half modf(half arg, half *iptr) { return functions::modf(arg, iptr); }
+ inline half modf(expr arg, half *iptr) { return functions::modf(arg, iptr); }
+
+ /// Multiply by power of two.
+ /// \param arg number to modify
+ /// \param exp power of two to multiply with
+ /// \return \a arg multiplied by 2 raised to \a exp
+// template<typename T> typename enable<half,T>::type scalbn(T arg, int exp) { return functions::scalbln(arg, exp); }
+ inline half scalbn(half arg, int exp) { return functions::scalbln(arg, exp); }
+ inline half scalbn(expr arg, int exp) { return functions::scalbln(arg, exp); }
+
+ /// Multiply by power of two.
+ /// \param arg number to modify
+ /// \param exp power of two to multiply with
+ /// \return \a arg multiplied by 2 raised to \a exp
+ // template<typename T> typename
+ //enable<half,T>::type scalbln(T arg, long exp) { return
+ // functions::scalbln(arg, exp); }
+ inline half scalbln(half arg, long exp) {
+ return functions::scalbln(arg, exp);
+ }
+ inline half scalbln(expr arg, long exp) { return functions::scalbln(arg, exp); }
+
+ /// Extract exponent.
+ /// \param arg number to query
+ /// \return floating point exponent
+ /// \retval FP_ILOGB0 for zero
+ /// \retval FP_ILOGBNAN for NaN
+ /// \retval INT_MAX for infinity
+// template<typename T> typename enable<int,T>::type ilogb(T arg) { return functions::ilogb(arg); }
+ inline int ilogb(half arg) { return functions::ilogb(arg); }
+ inline int ilogb(expr arg) { return functions::ilogb(arg); }
+
+ /// Extract exponent.
+ /// \param arg number to query
+ /// \return floating point exponent
+// template<typename T> typename enable<half,T>::type logb(T arg) { return functions::logb(arg); }
+ inline half logb(half arg) { return functions::logb(arg); }
+ inline half logb(expr arg) { return functions::logb(arg); }
+
+ /// Next representable value.
+ /// \param from value to compute next representable value for
+ /// \param to direction towards which to compute next value
+ /// \return next representable value after \a from in direction towards \a to
+// template<typename T,typename U> typename enable<half,T,U>::type nextafter(T from, U to) { return functions::nextafter(from, to); }
+ inline half nextafter(half from, half to) { return functions::nextafter(from, to); }
+ inline half nextafter(half from, expr to) { return functions::nextafter(from, to); }
+ inline half nextafter(expr from, half to) { return functions::nextafter(from, to); }
+ inline half nextafter(expr from, expr to) { return functions::nextafter(from, to); }
+
+ /// Next representable value.
+ /// \param from value to compute next representable value for
+ /// \param to direction towards which to compute next value
+ /// \return next representable value after \a from in direction towards \a to
+// template<typename T> typename enable<half,T>::type nexttoward(T from, long double to) { return functions::nexttoward(from, to); }
+ inline half nexttoward(half from, long double to) { return functions::nexttoward(from, to); }
+ inline half nexttoward(expr from, long double to) { return functions::nexttoward(from, to); }
+
+ /// Take sign.
+ /// \param x value to change sign for
+ /// \param y value to take sign from
+ /// \return value equal to \a x in magnitude and to \a y in sign
+// template<typename T,typename U> typename enable<half,T,U>::type copysign(T x, U y) { return functions::copysign(x, y); }
+ inline half copysign(half x, half y) { return functions::copysign(x, y); }
+ inline half copysign(half x, expr y) { return functions::copysign(x, y); }
+ inline half copysign(expr x, half y) { return functions::copysign(x, y); }
+ inline half copysign(expr x, expr y) { return functions::copysign(x, y); }
+
+ /// \}
+ /// \name Floating point classification
+ /// \{
+
+ /// Classify floating point value.
+ /// \param arg number to classify
+ /// \retval FP_ZERO for positive and negative zero
+ /// \retval FP_SUBNORMAL for subnormal numbers
+ /// \retval FP_INFINITE for positive and negative infinity
+ /// \retval FP_NAN for NaNs
+ /// \retval FP_NORMAL for all other (normal) values
+ // template<typename T> typename
+ //enable<int,T>::type fpclassify(T arg) { return
+ // functions::fpclassify(arg); }
+ inline int fpclassify(half arg) {
+ return functions::fpclassify(arg);
+ }
+ inline int fpclassify(expr arg) {
+ return functions::fpclassify(arg);
+ }
+
+ /// Check if finite number.
+ /// \param arg number to check
+ /// \retval true if neither infinity nor NaN
+ /// \retval false else
+ // template<typename T> typename
+ //enable<bool,T>::type isfinite(T arg) { return
+ // functions::isfinite(arg); }
+ inline bool isfinite(half arg) {
+ return functions::isfinite(arg);
+ }
+ inline bool isfinite(expr arg) {
+ return functions::isfinite(arg);
+ }
+
+ /// Check for infinity.
+ /// \param arg number to check
+ /// \retval true for positive or negative infinity
+ /// \retval false else
+ // template<typename T> typename
+ //enable<bool,T>::type isinf(T arg) { return
+ // functions::isinf(arg); }
+ inline bool isinf(half arg) { return functions::isinf(arg); }
+ inline bool isinf(expr arg) { return functions::isinf(arg); }
+
+ /// Check for NaN.
+ /// \param arg number to check
+ /// \retval true for NaNs
+ /// \retval false else
+ // template<typename T> typename
+ //enable<bool,T>::type isnan(T arg) { return
+ // functions::isnan(arg); }
+ inline bool isnan(half arg) { return functions::isnan(arg); }
+ inline bool isnan(expr arg) { return functions::isnan(arg); }
+
+ /// Check if normal number.
+ /// \param arg number to check
+ /// \retval true if normal number
+ /// \retval false if either subnormal, zero, infinity or NaN
+ // template<typename T> typename
+ //enable<bool,T>::type isnormal(T arg) { return
+ // functions::isnormal(arg); }
+ inline bool isnormal(half arg) {
+ return functions::isnormal(arg);
+ }
+ inline bool isnormal(expr arg) {
+ return functions::isnormal(arg);
+ }
+
+ /// Check sign.
+ /// \param arg number to check
+ /// \retval true for negative number
+ /// \retval false for positive number
+ // template<typename T> typename
+ //enable<bool,T>::type signbit(T arg) { return
+ // functions::signbit(arg); }
+ inline bool signbit(half arg) {
+ return functions::signbit(arg);
+ }
+ inline bool signbit(expr arg) {
+ return functions::signbit(arg);
+ }
+
+ /// \}
+ /// \name Comparison
+ /// \{
+
+ /// Comparison for greater than.
+ /// \param x first operand
+ /// \param y second operand
+ /// \retval true if \a x greater than \a y
+ /// \retval false else
+ // template<typename T,typename U> typename
+ // enable<bool,T,U>::type isgreater(T x, U y) { return
+ // functions::isgreater(x, y); }
+ inline bool isgreater(half x, half y) {
+ return functions::isgreater(x, y);
+ }
+ inline bool isgreater(half x, expr y) {
+ return functions::isgreater(x, y);
+ }
+ inline bool isgreater(expr x, half y) {
+ return functions::isgreater(x, y);
+ }
+ inline bool isgreater(expr x, expr y) {
+ return functions::isgreater(x, y);
+ }
+
+ /// Comparison for greater equal.
+ /// \param x first operand
+ /// \param y second operand
+ /// \retval true if \a x greater equal \a y
+ /// \retval false else
+ // template<typename T,typename U> typename
+ // enable<bool,T,U>::type isgreaterequal(T x, U y) { return
+ // functions::isgreaterequal(x, y); }
+ inline bool isgreaterequal(half x, half y) {
+ return functions::isgreaterequal(x, y);
+ }
+ inline bool isgreaterequal(half x, expr y) {
+ return functions::isgreaterequal(x, y);
+ }
+ inline bool isgreaterequal(expr x, half y) {
+ return functions::isgreaterequal(x, y);
+ }
+ inline bool isgreaterequal(expr x, expr y) {
+ return functions::isgreaterequal(x, y);
+ }
+
+ /// Comparison for less than.
+ /// \param x first operand
+ /// \param y second operand
+ /// \retval true if \a x less than \a y
+ /// \retval false else
+ // template<typename T,typename U> typename
+ // enable<bool,T,U>::type isless(T x, U y) { return
+ // functions::isless(x, y); }
+ inline bool isless(half x, half y) {
+ return functions::isless(x, y);
+ }
+ inline bool isless(half x, expr y) {
+ return functions::isless(x, y);
+ }
+ inline bool isless(expr x, half y) {
+ return functions::isless(x, y);
+ }
+ inline bool isless(expr x, expr y) {
+ return functions::isless(x, y);
+ }
+
+ /// Comparison for less equal.
+ /// \param x first operand
+ /// \param y second operand
+ /// \retval true if \a x less equal \a y
+ /// \retval false else
+ // template<typename T,typename U> typename
+ // enable<bool,T,U>::type islessequal(T x, U y) { return
+ // functions::islessequal(x, y); }
+ inline bool islessequal(half x, half y) {
+ return functions::islessequal(x, y);
+ }
+ inline bool islessequal(half x, expr y) {
+ return functions::islessequal(x, y);
+ }
+ inline bool islessequal(expr x, half y) {
+ return functions::islessequal(x, y);
+ }
+ inline bool islessequal(expr x, expr y) {
+ return functions::islessequal(x, y);
+ }
+
+ /// Comparison for less or greater.
+ /// \param x first operand
+ /// \param y second operand
+ /// \retval true if either less or greater
+ /// \retval false else
+ // template<typename T,typename U> typename
+ // enable<bool,T,U>::type islessgreater(T x, U y) { return
+ // functions::islessgreater(x, y); }
+ inline bool islessgreater(half x, half y) {
+ return functions::islessgreater(x, y);
+ }
+ inline bool islessgreater(half x, expr y) {
+ return functions::islessgreater(x, y);
+ }
+ inline bool islessgreater(expr x, half y) {
+ return functions::islessgreater(x, y);
+ }
+ inline bool islessgreater(expr x, expr y) {
+ return functions::islessgreater(x, y);
+ }
+
+ /// Check if unordered.
+ /// \param x first operand
+ /// \param y second operand
+ /// \retval true if unordered (one or two NaN operands)
+ /// \retval false else
+ // template<typename T,typename U> typename
+ // enable<bool,T,U>::type isunordered(T x, U y) { return
+ // functions::isunordered(x, y); }
+ inline bool isunordered(half x, half y) {
+ return functions::isunordered(x, y);
+ }
+ inline bool isunordered(half x, expr y) {
+ return functions::isunordered(x, y);
+ }
+ inline bool isunordered(expr x, half y) {
+ return functions::isunordered(x, y);
+ }
+ inline bool isunordered(expr x, expr y) {
+ return functions::isunordered(x, y);
+ }
+ /// \}
+ /// \name Casting
+ /// \{
+
+ /// Cast to or from half-precision floating point number.
+ /// This casts between [half](\ref half_float::half) and any
+ /// built-in arithmetic type. The values are converted directly
+ /// using the default rounding mode, without any roundtrip over
+ /// `float` that a `static_cast` would otherwise do.
+ ///
+ /// Using this cast with neither of the two types being a
+ /// [half](\ref half_float::half) or with any of the two types
+ /// not being a built-in arithmetic type (apart from [half](\ref
+ /// half_float::half), of course) results in a compiler error
+ /// and casting between [half](\ref half_float::half)s is just a
+ /// no-op.
+ /// \tparam T destination type (half or built-in arithmetic type)
+ /// \tparam U source type (half or built-in arithmetic type)
+ /// \param arg value to cast
+ /// \return \a arg converted to destination type
+ template <typename T, typename U>
+ T half_cast(U arg) {
+ return half_caster<T, U>::cast(arg);
+ }
+
+ /// Cast to or from half-precision floating point number.
+ /// This casts between [half](\ref half_float::half) and any
+ /// built-in arithmetic type. The values are converted directly
+ /// using the given rounding mode, without any roundtrip over
+ /// `float` that a `static_cast` would otherwise do.
+ ///
+ /// Using this cast with neither of the two types being a [half](\ref half_float::half)
+ /// or with any of the two types not being a built-in arithmetic type (apart from
+ /// [half](\ref half_float::half), of course) results in a compiler error and
+ /// casting between [half](\ref half_float::half)s is just a no-op.
+ /// \tparam T destination type (half or built-in arithmetic type)
+ /// \tparam R rounding mode to use
+ /// \tparam U source type (half or built-in arithmetic type)
+ /// \param arg value to cast
+ /// \return \a arg converted to destination type
+ template <typename T, std::float_round_style R, typename U>
+ T half_cast(U arg) {
+ return half_caster<T, U, R>::cast(arg);
+ }
+ /// \}
+ }
+
+ using detail::operator==;
+ using detail::operator!=;
+ using detail::operator<;
+ using detail::operator>;
+ using detail::operator<=;
+ using detail::operator>=;
+ using detail::operator+;
+ using detail::operator-;
+ using detail::operator*;
+ using detail::operator/;
+ using detail::operator<<;
+ using detail::operator>>;
+
+ using detail::abs;
+ using detail::acos;
+ using detail::acosh;
+ using detail::asin;
+ using detail::asinh;
+ using detail::atan;
+ using detail::atan2;
+ using detail::atanh;
+ using detail::cbrt;
+ using detail::ceil;
+ using detail::cos;
+ using detail::cosh;
+ using detail::erf;
+ using detail::erfc;
+ using detail::exp;
+ using detail::exp2;
+ using detail::expm1;
+ using detail::fabs;
+ using detail::fdim;
+ using detail::floor;
+ using detail::fma;
+ using detail::fmax;
+ using detail::fmin;
+ using detail::fmod;
+ using detail::hypot;
+ using detail::lgamma;
+ using detail::log;
+ using detail::log10;
+ using detail::log1p;
+ using detail::log2;
+ using detail::lrint;
+ using detail::lround;
+ using detail::nanh;
+ using detail::nearbyint;
+ using detail::pow;
+ using detail::remainder;
+ using detail::remquo;
+ using detail::rint;
+ using detail::round;
+ using detail::sin;
+ using detail::sinh;
+ using detail::sqrt;
+ using detail::tan;
+ using detail::tanh;
+ using detail::tgamma;
+ using detail::trunc;
+#if HALF_ENABLE_CPP11_LONG_LONG
+ using detail::llrint;
+ using detail::llround;
+#endif
+ using detail::frexp;
+ using detail::ldexp;
+ using detail::modf;
+ using detail::scalbn;
+ using detail::scalbln;
+ using detail::ilogb;
+ using detail::logb;
+ using detail::nextafter;
+ using detail::nexttoward;
+ using detail::copysign;
+ using detail::fpclassify;
+ using detail::isfinite;
+ using detail::isinf;
+ using detail::isnan;
+ using detail::isnormal;
+ using detail::signbit;
+ using detail::isgreater;
+ using detail::isgreaterequal;
+ using detail::isless;
+ using detail::islessequal;
+ using detail::islessgreater;
+ using detail::isunordered;
+
+ using detail::half_cast;
+}
+
+
+/// Extensions to the C++ standard library.
+namespace std
+{
+/// Numeric limits for half-precision floats.
+/// Because of the underlying single-precision implementation of many
+/// operations, it inherits some properties from `std::numeric_limits<float>`.
+template <>
+class numeric_limits<half_float::half> : public numeric_limits<float> {
+ public:
+ /// Supports signed values.
+ static HALF_CONSTEXPR_CONST bool is_signed = true;
+
+ /// Is not exact.
+ static HALF_CONSTEXPR_CONST bool is_exact = false;
+
+ /// Doesn't provide modulo arithmetic.
+ static HALF_CONSTEXPR_CONST bool is_modulo = false;
+
+ /// IEEE conformant.
+ static HALF_CONSTEXPR_CONST bool is_iec559 = true;
+
+ /// Supports infinity.
+ static HALF_CONSTEXPR_CONST bool has_infinity = true;
+
+ /// Supports quiet NaNs.
+ static HALF_CONSTEXPR_CONST bool has_quiet_NaN = true;
+
+ /// Supports subnormal values.
+ static HALF_CONSTEXPR_CONST float_denorm_style has_denorm = denorm_present;
+
+ /// Rounding mode.
+ /// Due to the mix of internal single-precision computations (using the
+ /// rounding mode of the underlying single-precision implementation) with the
+ /// rounding mode of the single-to-half conversions, the actual rounding mode
+ /// is reported as `std::round_indeterminate` if the default half-precision
+ /// rounding mode doesn't match the single-precision rounding mode.
+ static HALF_CONSTEXPR_CONST float_round_style round_style =
+ (std::numeric_limits<float>::round_style == half_float::half::round_style)
+ ? half_float::half::round_style
+ : round_indeterminate;
+
+ /// Significant digits (in base `radix`).
+ static HALF_CONSTEXPR_CONST int digits = 11;
+
+ /// Significant decimal digits.
+ static HALF_CONSTEXPR_CONST int digits10 = 3;
+
+ /// Required decimal digits to represent all possible values.
+ static HALF_CONSTEXPR_CONST int max_digits10 = 5;
+
+ /// Number base.
+ static HALF_CONSTEXPR_CONST int radix = 2;
+
+ /// One more than smallest exponent.
+ static HALF_CONSTEXPR_CONST int min_exponent = -13;
+
+ /// Smallest normalized representable power of 10.
+ static HALF_CONSTEXPR_CONST int min_exponent10 = -4;
+
+ /// One more than largest exponent.
+ static HALF_CONSTEXPR_CONST int max_exponent = 16;
+
+ /// Largest finitely representable power of 10.
+ static HALF_CONSTEXPR_CONST int max_exponent10 = 4;
+
+ /// Smallest positive normal value.
+ static HALF_CONSTEXPR half_float::half min() HALF_NOTHROW {
+ return half_float::half(half_float::detail::binary, 0x0400);
+ }
+
+ /// Lowest (most negative) finite value.
+ static HALF_CONSTEXPR half_float::half lowest() HALF_NOTHROW {
+ return half_float::half(half_float::detail::binary, 0xFBFF);
+ }
+
+ /// Largest finite value.
+ static HALF_CONSTEXPR half_float::half max() HALF_NOTHROW {
+ return half_float::half(half_float::detail::binary, 0x7BFF);
+ }
+
+ /// Difference between 1 and the next representable value.
+ static HALF_CONSTEXPR half_float::half epsilon() HALF_NOTHROW {
+ return half_float::half(half_float::detail::binary, 0x1400);
+ }
+
+ /// Maximum rounding error: 0.5 when rounding to nearest, 1.0 otherwise.
+ static HALF_CONSTEXPR half_float::half round_error() HALF_NOTHROW {
+ return half_float::half(
+ half_float::detail::binary,
+ (round_style == std::round_to_nearest) ? 0x3800 : 0x3C00);
+ }
+
+ /// Positive infinity.
+ static HALF_CONSTEXPR half_float::half infinity() HALF_NOTHROW {
+ return half_float::half(half_float::detail::binary, 0x7C00);
+ }
+
+ /// Quiet NaN.
+ static HALF_CONSTEXPR half_float::half quiet_NaN() HALF_NOTHROW {
+ return half_float::half(half_float::detail::binary, 0x7FFF);
+ }
+
+ /// Signalling NaN.
+ static HALF_CONSTEXPR half_float::half signaling_NaN() HALF_NOTHROW {
+ return half_float::half(half_float::detail::binary, 0x7DFF);
+ }
+
+ /// Smallest positive subnormal value.
+ static HALF_CONSTEXPR half_float::half denorm_min() HALF_NOTHROW {
+ return half_float::half(half_float::detail::binary, 0x0001);
+ }
+};
+
+#if HALF_ENABLE_CPP11_HASH
+ /// `std::hash` specialization for half-precision floats.
+ /// Only available when C++11 `std::hash` support is detected and enabled.
+ template<> struct hash<half_float::half> //: unary_function<half_float::half,size_t>
+ {
+ /// Argument type accepted by the call operator.
+ typedef half_float::half argument_type;
+
+ /// Result type produced by the call operator.
+ typedef size_t result_type;
+
+ /// Hash a half via its 16-bit pattern; negative zero (0x8000) hashes as 0.
+ /// \param arg half to hash
+ /// \return hash value
+ result_type operator()(argument_type arg) const {
+ const bool negative_zero = (arg.data_ == 0x8000);
+ const unsigned bits = negative_zero ? 0u : static_cast<unsigned>(arg.data_);
+ return hash<half_float::detail::uint16>()(bits);
+ }
+ };
+#endif
+}
+
+
+#undef HALF_CONSTEXPR
+#undef HALF_CONSTEXPR_CONST
+#undef HALF_NOEXCEPT
+#undef HALF_NOTHROW
+#ifdef HALF_POP_WARNINGS
+#pragma warning(pop)
+#undef HALF_POP_WARNINGS
+#endif
+
+#endif
diff --git a/third_party/llvm-bazel b/third_party/llvm-bazel
index 2ad01e4..fe17a7e 160000
--- a/third_party/llvm-bazel
+++ b/third_party/llvm-bazel
@@ -1 +1 @@
-Subproject commit 2ad01e4b485d8753600766d967d1a7358b98ddd3
+Subproject commit fe17a7eff316d5846742cec8ced48bb6c49831db
diff --git a/third_party/llvm-project b/third_party/llvm-project
index c8d73d9..5ce85e6 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit c8d73d939fa4fda9c87b3979225d02d63062bd68
+Subproject commit 5ce85e66358a69e786093756c77fae2e140947c1
diff --git a/third_party/tensorflow b/third_party/tensorflow
index 3ea78d0..218c3a2 160000
--- a/third_party/tensorflow
+++ b/third_party/tensorflow
@@ -1 +1 @@
-Subproject commit 3ea78d0fdd5402809e12d067b81dcd2a43cc8a45
+Subproject commit 218c3a2712bc72d2239299a5eef4ff3c156004ed
diff --git a/third_party/tracy b/third_party/tracy
index d7059ec..d8cb536 160000
--- a/third_party/tracy
+++ b/third_party/tracy
@@ -1 +1 @@
-Subproject commit d7059eca6351546d1f51e248fc75e49dfeee709e
+Subproject commit d8cb536712e876ba956f27f23dbede1c2eccad28