Merge main -> google
* f92602b7 Always emitting framepointers in generated ELFs (#3987)
* 89ef6328 bump Tracy to get Android fixes (#3988)
* 7f989eb2 Disable MLIR crash reproducer on CI in python tests. (#3943)
* 4db3d08c Adding a demonstration of using the flow dialect post-partitioning. (#3701)
* 589dfa7b Remove no-longer-functional flag (#3961)
* 64543172 Fix MacOS builds after hack-and-slash PR (#3962)
* 8fb887e4 Update links to coverage tables (#3956)
* c93facb4 Adding iree_atomic_slist_t type. (#3917)
* bd082ca6 Merge pull request #3874 from google/benvanik-hack-and-slash
* f4f4ea26 Use UnitTestSpec in tf.keras.layers tests (#3935)
* 7b8e9f75 Reverting flatcc to use our own cmake file for cross-compilation.
* 323e1fde Simplify dylib driver switch.
* 55f3de0d Only register VMLA driver in bindings/java/.
* f429916f Fix warning flag on Windows and HAL CTS driver registration. (#3911)
* 7951e228 Drop IREE_DRIVER_MODULES and iree/base:initializer from ModelBuilder.
* 4e111e2f Disable layering_check in iree/hal/drivers/BUILD.
* 4773736d Add package to iree/base/testing/BUILD.
* 7ca321c1 Skipping dylib driver in simple_embedding_test as a hack.
* 692deb59 Overriding the default MLIR -> LLVM module name.
* 513f40e8 Speculative removing nowindows tags (#3615). If there's something that still d..
* cc47813a Removing the broken forward declarations entirely from some codegen code. http..
* fbcad44d Removing _GNU_SOURCE copt for wait_handle.
* c9b10a01 Fixing bad type in hal/api.h (been there for ages!).
* e0c532ec Changing iree::InitializeEnvironment to iree_flags_parse. Preparation for #3814.
* 8886ac07 Removing iree_api_init from the API.
* 132d747c Removing ALWAYSLINK support from cmake.
* 0ed81f6b Removing iree/base/initializer.h.
* 0135343c Changing to an explicit driver registration mechanism. This is required for ha..
* bf091d3c Removing ALWAYSLINK support from external_cc_library.
* d4bb871d Changing iree-tblgen to not require alwayslink.
* 6bc6f90c Removing IREE_COMMON_INCLUDE_DIRS and uses by LLVM/MLIR.
* 3c082aab Removing IREE_COMMON_INCLUDE_DIRS mhlo pollution.
* 2395cb99 Removing emitc usage of IREE_COMMON_INCLUDE_DIRS for now.
* 036bd966 TODOs on future library layout changes.
* c3a13e62 Rearranging iree/vm/ to reduce a public + cc target.
* 493b0e2b Rearranging iree/base build rules. By moving the dynamic_library_test out of t..
* 8fd38bf9 Replacing uses of some absl utilities now in C++14.
* 67863190 Removing unused absl/types/variant.h include.
* 99bd1af5 Replace absl::exchange with iree::exchange to reduce absl/utility dep.
* c1d0ee10 Removing unused PLATFORM_VULKAN_DEPS. It may be needed internally but it shoul..
* 15437f4b Simplifying iree/hal/dylib/ build config.
* e6984a5a Simplifying iree/hal/ build config.
* 10062814 Simplifying iree/hal/vulkan/ build config.
* 827e51b0 Simplifying iree/hal/llvmjit/ build config.
* 9a72f5d1 Simplifying iree/hal/metal/ build config.
* c7a7d726 Simplifying iree/hal/vmla/ build config.
* 90faf21f Adding IREE_TARGET_GUI_LINKOPTS to remove custom linkopts use.
* e5774c30 Remove unused args from flatbuffers_c_library macro.
* 22d16b4d Adding iree/base/flatcc.h to make flatcc easier to include.
* e44dee56 Switching from -DVK_NO_PROTOTYPES to iree/hal/vulkan/vulkan_headers.h.
* 48ca2fe6 Removing build-config setting of _GNU_SOURCE.
* eeb7dde0 Goodbye flatbuffers (well, the C++ ones anyway).
* 9c676a86 Removing all build config/utils related to flatbuffers.
* 49c61213 byte->ubyte in flatbuffer defs.
* c99000a8 Replacing compiler use of VM bytecode module def flatbuffers->flatcc.
* 1bf1e8d7 Replacing runtime use of metal executable flatbuffers->flatcc. Maybe it works?..
* 48aafb89 Replacing runtime use of spirv executable flatbuffers->flatcc.
* 011e9a2d Replacing runtime use of llvmjit executable flatbuffers->flatcc.
* a021062f Replacing runtime use of dylib executable flatbuffers->flatcc.
* 53a05d73 Replacing runtime use of VMLA executable flatbuffers->flatcc.
* 99d30a99 Replacing compiler use of HAL executable flatbuffers->flatcc.
* 6ebd1b0c Removing unused tag field in metal/spirv.
* bc685ed7 Adding flatcc json support and making iree-dump-module use it.
* 94b11c35 Adding include for flatcc to flat_c_libraries.
* 1172cf1f Removing unused iree::schemas::reflection_data.
* c86281af Removing unneeded flatbuffers copts.
* 7f3a7e3a Fixing various type warnings. We don't today have these warnings enabled in ou..
* c17659fc Refining MSVC warning set to the minimum required and documenting.
* b7c92bf4 Cleaning up MSVC warnings and syncing with bazel warnings.
* 94356d3b Removing legacy repo_utils.bzl.
* 0f0d9c82 Prevent bazel-to-cmake from running on iree/base/CMakeLists.txt for now.
* 36225a4d Centralizing -ldl/-lpthread linkopts (as they were in bazel already).
* 31c4dbb9 Documenting iree_copts with a nice big warning.
* e4740a57 Pass android libraries as actual linkopts.
* 85cdd868 Fixing cmake style issues - prefer `if(` not `if (` please.
* bf4069e3 Sorting copts/linkopts so we can override things.
* 28040cd8 Simplifying VMA build integration.
* 479ef30f Replacing use of PROJECT_SOURCE_DIR/PROJECT_BINARY_DIR. Those use the previous..
PiperOrigin-RevId: 344157809
diff --git a/.clang-format b/.clang-format
index 0547576..2899a77 100644
--- a/.clang-format
+++ b/.clang-format
@@ -16,3 +16,13 @@
# mostly LLVM style (for naming/etc) but the Google formatting (because internal
# tooling has... issues).
BasedOnStyle: Google
+
+# Some includes have specific inclusion order requirements. To prevent
+# clang-format from breaking things when bucketing headers into categories we
+# override the behavior here for those well-known headers.
+IncludeCategories:
+ # We override VK_NO_PROTOTYPES so we don't pull in auto-imported symbols and
+ # this must always sort above any header (ours or otherwise) that may
+ # transitively #include <vulkan/vulkan.h>.
+ - Regex: 'iree/hal/vulkan/vulkan_headers.h'
+ Priority: -1
diff --git a/.gitmodules b/.gitmodules
index 45076ab..dcedcea 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -10,9 +10,6 @@
[submodule "third_party/tensorflow"]
path = third_party/tensorflow
url = https://github.com/tensorflow/tensorflow.git
-[submodule "third_party/flatbuffers"]
- path = third_party/flatbuffers
- url = https://github.com/google/flatbuffers.git
[submodule "third_party/vulkan_headers"]
path = third_party/vulkan_headers
url = https://github.com/KhronosGroup/Vulkan-Headers.git
diff --git a/BUILD.bazel b/BUILD.bazel
index 6b3ee5c..f3de1d2 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -22,7 +22,6 @@
# "@bazel_skylib//"
# "@com_google_benchmark//"
# "@com_github_pytorch_cpuinfo//"
-# "@com_github_google_flatbuffers//"
# "@com_github_dvidelabs_flatcc//"
# "@half//"
# "@com_google_googletest//"
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 474109c..3903d75 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -64,7 +64,7 @@
set(IREE_ENABLE_MLIR ON CACHE BOOL "Enable LLVM dependencies if the IREE compiler is build." FORCE)
endif()
-if (${IREE_ENABLE_MLIR})
+if(${IREE_ENABLE_MLIR})
set(IREE_MLIR_DEP_MODE "BUNDLED" CACHE STRING "One of BUNDLED (default), DISABLED, INSTALLED")
endif()
@@ -92,9 +92,9 @@
Vulkan
)
-if( IREE_HAL_DRIVERS_TO_BUILD STREQUAL "all" )
- set( IREE_HAL_DRIVERS_TO_BUILD ${IREE_ALL_HAL_DRIVERS} )
- # For cross compilation towords Android, we don't want LLVM JIT HAL driver.
+if(IREE_HAL_DRIVERS_TO_BUILD STREQUAL "all")
+ set(IREE_HAL_DRIVERS_TO_BUILD ${IREE_ALL_HAL_DRIVERS})
+ # For cross compilation towards Android, we don't want LLVM JIT HAL driver.
if(ANDROID)
list(REMOVE_ITEM IREE_HAL_DRIVERS_TO_BUILD LLVM)
endif()
@@ -121,24 +121,6 @@
set(IREE_HAL_DRIVER_${uppercase_backend} ON CACHE BOOL "" FORCE)
endforeach()
-# Enable HAL driver modules based on options.
-set (IREE_HAL_DRIVER_MODULES "")
-if(${IREE_HAL_DRIVER_DYLIB})
- list(APPEND IREE_HAL_DRIVER_MODULES iree::hal::dylib::dylib_driver_module)
-endif()
-if(${IREE_HAL_DRIVER_LLVM})
- list(APPEND IREE_HAL_DRIVER_MODULES iree::hal::llvmjit::llvmjit_driver_module)
-endif()
-if(${IREE_HAL_DRIVER_METAL})
- list(APPEND IREE_HAL_DRIVER_MODULES iree::hal::metal::metal_driver_module)
-endif()
-if(${IREE_HAL_DRIVER_VMLA})
- list(APPEND IREE_HAL_DRIVER_MODULES iree::hal::vmla::vmla_driver_module)
-endif()
-if(${IREE_HAL_DRIVER_VULKAN})
- list(APPEND IREE_HAL_DRIVER_MODULES iree::hal::vulkan::vulkan_driver_module)
-endif()
-
# List of all target backends to be built by default:
set(IREE_ALL_TARGET_BACKENDS
# TODO(#2645): Add DYLIB-LLVM-AOT when it doesn't require an env var
@@ -231,7 +213,7 @@
option(IREE_ENABLE_TSAN "Enable thread sanitizer" OFF)
option(IREE_ENABLE_CCACHE "Use ccache if installed to speed up rebuilds." OFF)
-if (${IREE_ENABLE_CCACHE})
+if(${IREE_ENABLE_CCACHE})
find_program(CCACHE_PROGRAM ccache)
if(CCACHE_PROGRAM)
set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "${CCACHE_PROGRAM}")
@@ -245,7 +227,6 @@
include(iree_macros)
include(iree_copts)
include(sanitizers)
-include(iree_whole_archive_link)
include(iree_cc_binary)
include(iree_cc_library)
include(iree_cc_test)
@@ -259,7 +240,7 @@
include(iree_check_test)
set(DEFAULT_CMAKE_BUILD_TYPE "Release")
-if (NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES)
+if(NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES)
message(STATUS "No build type selected, default to ${DEFAULT_CMAKE_BUILD_TYPE}")
set(CMAKE_BUILD_TYPE "${DEFAULT_CMAKE_BUILD_TYPE}" CACHE STRING "Build type (default ${DEFAULT_CMAKE_BUILD_TYPE})" FORCE)
endif()
@@ -336,18 +317,29 @@
message(STATUS "Adding bundled LLVM source dependency")
add_iree_mlir_src_dep("third_party/llvm-project")
- # Extend module path to allow submodules to use LLVM and MLIR CMake modules
- list(APPEND CMAKE_MODULE_PATH "${PROJECT_BINARY_DIR}/lib/cmake/mlir")
- list(APPEND CMAKE_MODULE_PATH "${PROJECT_BINARY_DIR}/third_party/llvm-project/llvm/lib/cmake/llvm/")
+ # Extend module path to allow submodules to use LLVM and MLIR CMake modules.
+ list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_BINARY_DIR}/lib/cmake/mlir")
+ list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_BINARY_DIR}/third_party/llvm-project/llvm/lib/cmake/llvm/")
- # Set include directories
+ # Add the bundled include directories for cmake files looking for them.
list(APPEND LLVM_INCLUDE_DIRS
- ${PROJECT_SOURCE_DIR}/third_party/llvm-project/llvm/include
- ${PROJECT_BINARY_DIR}/third_party/llvm-project/llvm/include
+ ${CMAKE_CURRENT_SOURCE_DIR}/third_party/llvm-project/llvm/include
+ ${CMAKE_CURRENT_BINARY_DIR}/third_party/llvm-project/llvm/include
)
list(APPEND MLIR_INCLUDE_DIRS
- ${PROJECT_SOURCE_DIR}/third_party/llvm-project/mlir/include
- ${PROJECT_BINARY_DIR}/third_party/llvm-project/llvm/tools/mlir/include
+ ${CMAKE_CURRENT_SOURCE_DIR}/third_party/llvm-project/mlir/include
+ ${CMAKE_CURRENT_BINARY_DIR}/third_party/llvm-project/llvm/tools/mlir/include
+ )
+
+ # Avoid globally modifying paths by instead adding the include paths to the
+ # rules that really should have them in the first place.
+ target_include_directories(LLVMSupport PUBLIC
+ $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/third_party/llvm-project/llvm/include>
+ $<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}/third_party/llvm-project/llvm/include>
+ )
+ target_include_directories(MLIRSupport PUBLIC
+ $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/third_party/llvm-project/mlir/include>
+ $<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}/third_party/llvm-project/llvm/tools/mlir/include>
)
# Set build option to use MHLO alongside with bundled MLIR
@@ -396,24 +388,19 @@
find_package(PythonLibs 3)
endif()
-list(APPEND CMAKE_MODULE_PATH
- ${CMAKE_CURRENT_LIST_DIR}/third_party/flatbuffers/CMake/
-)
-
include(external_cc_library)
include(flatbuffer_c_library)
-include(flatbuffer_cc_library)
add_subdirectory(build_tools/third_party/flatcc EXCLUDE_FROM_ALL)
add_subdirectory(build_tools/third_party/half EXCLUDE_FROM_ALL)
add_subdirectory(build_tools/third_party/pffft EXCLUDE_FROM_ALL)
add_subdirectory(build_tools/third_party/renderdoc_api EXCLUDE_FROM_ALL)
add_subdirectory(build_tools/third_party/ruy EXCLUDE_FROM_ALL)
+add_subdirectory(build_tools/third_party/vulkan_memory_allocator EXCLUDE_FROM_ALL)
add_subdirectory(third_party/cpuinfo EXCLUDE_FROM_ALL)
add_subdirectory(third_party/googletest EXCLUDE_FROM_ALL)
add_subdirectory(third_party/abseil-cpp EXCLUDE_FROM_ALL)
-add_subdirectory(third_party/flatbuffers EXCLUDE_FROM_ALL)
add_subdirectory(third_party/flatcc EXCLUDE_FROM_ALL)
add_subdirectory(third_party/vulkan_headers EXCLUDE_FROM_ALL)
@@ -428,18 +415,10 @@
iree_get_executable_path(FLATBUFFERS_FLATC_EXECUTABLE flatc)
# Add a custom target to copy the flatc to the binary directory.
- add_custom_target(iree_host_flatc
- COMMAND
- "${CMAKE_COMMAND}" -E copy_if_different
- "${IREE_HOST_BINARY_ROOT}/third_party/flatbuffers/flatc${IREE_HOST_EXECUTABLE_SUFFIX}"
- "${IREE_HOST_BINARY_ROOT}/bin"
- DEPENDS iree_host_build_flatc
- COMMENT "Installing host flatc..."
- )
add_custom_target(iree_host_flatcc_cli
COMMAND
"${CMAKE_COMMAND}" -E copy_if_different
- "${PROJECT_SOURCE_DIR}/third_party/flatcc/bin/flatcc${IREE_HOST_EXECUTABLE_SUFFIX}"
+ "${CMAKE_CURRENT_SOURCE_DIR}/third_party/flatcc/bin/flatcc${IREE_HOST_EXECUTABLE_SUFFIX}"
"${IREE_HOST_BINARY_ROOT}/bin/flatcc_cli${IREE_HOST_EXECUTABLE_SUFFIX}"
DEPENDS iree_host_build_flatcc_cli
COMMENT "Installing host flatcc..."
@@ -538,8 +517,6 @@
add_subdirectory(experimental)
endif()
-# Note: this must be called after all libraries have been declared.
-iree_complete_binary_link_options()
if(${IREE_BUILD_PYTHON_BINDINGS})
iree_complete_py_extension_link_options()
endif()
diff --git a/SUBMODULE_VERSIONS b/SUBMODULE_VERSIONS
index 72f09c0..05f7426 100644
--- a/SUBMODULE_VERSIONS
+++ b/SUBMODULE_VERSIONS
@@ -2,7 +2,6 @@
daff5fead3fbe22c6fc58310ca3f49caf117f185 third_party/benchmark
63b254577ed77a8004a9be6ac707f3dccc4e1fd9 third_party/cpuinfo
4c13807b7d43ff0946b7ffea0ae3aee9e611d778 third_party/dear_imgui
-a5d9d0f7d368054fd1691aedf1db4116efcc233e third_party/flatbuffers
4fb0ff7069bd88ee85902f4d0bb62794e5f6d021 third_party/flatcc
f2fb48c3b3d79a75a88a99fba6576b25d42ec528 third_party/googletest
fe17a7eff316d5846742cec8ced48bb6c49831db third_party/llvm-bazel
@@ -14,7 +13,7 @@
a1390ed39ec77ecfb574bc6fcd5bfc5e3adbdea9 third_party/sdl2
685f86471e9d26b3eb7676695a2e2cefb4551ae9 third_party/spirv_cross
f8bf11a0253a32375c32cad92c841237b96696c0 third_party/spirv_headers
-37cfaab27f11635a240e93550c1c0faf36e07856 third_party/tensorflow
-d7059eca6351546d1f51e248fc75e49dfeee709e third_party/tracy
+218c3a2712bc72d2239299a5eef4ff3c156004ed third_party/tensorflow
+d8cb536712e876ba956f27f23dbede1c2eccad28 third_party/tracy
9bd3f561bcee3f01d22912de10bb07ce4e23d378 third_party/vulkan_headers
3528e2aed3e8808f33e1e7d63eeb1560456a605a third_party/vulkan_memory_allocator
diff --git a/WORKSPACE b/WORKSPACE
index 8530017..75505a5 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -18,7 +18,7 @@
workspace(name = "iree_core")
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
-load(":repo_utils.bzl", "maybe")
+load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe")
###############################################################################
# Bazel rules.
@@ -196,14 +196,6 @@
path = "third_party/googletest",
)
-# Note that TensorFlow provides this as "flatbuffers" which is wrong.
-# It is only used for TFLite and may cause ODR issues if not fixed.
-maybe(
- local_repository,
- name = "com_github_google_flatbuffers",
- path = "third_party/flatbuffers",
-)
-
maybe(
new_local_repository,
name = "com_github_dvidelabs_flatcc",
diff --git a/bindings/java/com/google/iree/native/CMakeLists.txt b/bindings/java/com/google/iree/native/CMakeLists.txt
index e749258..11d1607 100644
--- a/bindings/java/com/google/iree/native/CMakeLists.txt
+++ b/bindings/java/com/google/iree/native/CMakeLists.txt
@@ -16,10 +16,10 @@
NAME
cc_wrappers
SRCS
- "context_wrapper.cc"
- "function_wrapper.cc"
- "instance_wrapper.cc"
- "module_wrapper.cc"
+ "context_wrapper.cc"
+ "function_wrapper.cc"
+ "instance_wrapper.cc"
+ "module_wrapper.cc"
HDRS
"context_wrapper.h"
"function_wrapper.h"
@@ -27,16 +27,15 @@
"module_wrapper.h"
DEPS
iree::base::api
- iree::base::init
+ iree::base::flags
iree::base::logging
iree::base::status
iree::hal::api
- iree::hal::vmla::vmla_driver_module
+ iree::hal::drivers
iree::modules::hal
iree::modules::strings::strings_module
iree::modules::tensorlist::native_module
+ iree::vm
+ iree::vm::cc
iree::vm::bytecode_module
- iree::vm::context
- iree::vm::instance
- iree::vm::ref_cc
)
diff --git a/bindings/java/com/google/iree/native/context_wrapper.h b/bindings/java/com/google/iree/native/context_wrapper.h
index 180a259..40efe28 100644
--- a/bindings/java/com/google/iree/native/context_wrapper.h
+++ b/bindings/java/com/google/iree/native/context_wrapper.h
@@ -23,7 +23,7 @@
#include "iree/base/status.h"
#include "iree/hal/api.h"
#include "iree/modules/hal/hal_module.h"
-#include "iree/vm/context.h"
+#include "iree/vm/api.h"
namespace iree {
namespace java {
diff --git a/bindings/java/com/google/iree/native/function_wrapper.h b/bindings/java/com/google/iree/native/function_wrapper.h
index a8dd7dd..dbce860 100644
--- a/bindings/java/com/google/iree/native/function_wrapper.h
+++ b/bindings/java/com/google/iree/native/function_wrapper.h
@@ -17,7 +17,7 @@
#include <memory>
-#include "iree/vm/module.h"
+#include "iree/vm/api.h"
namespace iree {
namespace java {
diff --git a/bindings/java/com/google/iree/native/instance_wrapper.cc b/bindings/java/com/google/iree/native/instance_wrapper.cc
index cdd9fb0..e9679e6 100644
--- a/bindings/java/com/google/iree/native/instance_wrapper.cc
+++ b/bindings/java/com/google/iree/native/instance_wrapper.cc
@@ -16,7 +16,8 @@
#include <mutex>
-#include "iree/base/init.h"
+#include "iree/base/flags.h"
+#include "iree/hal/vmla/registration/driver_module.h"
#include "iree/modules/hal/hal_module.h"
#include "iree/modules/strings/strings_module.h"
#include "iree/modules/tensorlist/native_module.h"
@@ -27,13 +28,16 @@
namespace {
void SetupVm() {
- // TODO(jennik): Pass flags through from java.
+ // TODO(jennik): Pass flags through from java and use iree_flags_parse.
+ // This checked version will abort()/exit() and that's... not great.
char binname[] = "libiree.so";
char* argv[] = {binname};
char** aargv = argv;
int argc = 1;
- InitializeEnvironment(&argc, &aargv);
+ iree_flags_parse_checked(&argc, &aargv);
+ // TODO(jennik): register all available drivers
+ IREE_CHECK_OK(iree_hal_vmla_driver_module_register());
IREE_CHECK_OK(iree_vm_register_builtin_types());
IREE_CHECK_OK(iree_hal_module_register_types());
IREE_CHECK_OK(iree_tensorlist_module_register_types());
diff --git a/bindings/java/com/google/iree/native/instance_wrapper.h b/bindings/java/com/google/iree/native/instance_wrapper.h
index e97bfc5..66acc0f 100644
--- a/bindings/java/com/google/iree/native/instance_wrapper.h
+++ b/bindings/java/com/google/iree/native/instance_wrapper.h
@@ -16,7 +16,7 @@
#define IREE_BINDINGS_JAVA_COM_GOOGLE_IREE_NATIVE_INSTANCE_WRAPPER_H_
#include "iree/base/status.h"
-#include "iree/vm/instance.h"
+#include "iree/vm/api.h"
namespace iree {
namespace java {
diff --git a/bindings/javatests/com/google/iree/integration_test.cc b/bindings/javatests/com/google/iree/integration_test.cc
index 5c31c34..661878d 100644
--- a/bindings/javatests/com/google/iree/integration_test.cc
+++ b/bindings/javatests/com/google/iree/integration_test.cc
@@ -19,7 +19,7 @@
#include "bindings/java/com/google/iree/native/instance_wrapper.h"
#include "bindings/java/com/google/iree/native/module_wrapper.h"
#include "bindings/javatests/com/google/iree/simple_mul_bytecode_module.h"
-#include "iree/base/init.h"
+#include "iree/base/flags.h"
namespace iree {
namespace java {
diff --git a/bindings/python/build_defs.oss.bzl b/bindings/python/build_defs.oss.bzl
index a01a244..34d87b9 100644
--- a/bindings/python/build_defs.oss.bzl
+++ b/bindings/python/build_defs.oss.bzl
@@ -17,10 +17,8 @@
load("@iree_native_python//:build_defs.bzl", "py_extension")
load("@rules_cc//cc:defs.bzl", "cc_library")
load("@rules_python//python:defs.bzl", "py_binary", "py_library", "py_test")
-load("//iree:build_defs.oss.bzl", _PLATFORM_VULKAN_DEPS = "PLATFORM_VULKAN_DEPS")
NUMPY_DEPS = []
-PLATFORM_VULKAN_DEPS = _PLATFORM_VULKAN_DEPS
PYTHON_HEADERS_DEPS = ["@iree_native_python//:python_headers"]
PYTHON_CPP_EXTRA_DEPS = []
diff --git a/bindings/python/pyiree/rt/BUILD b/bindings/python/pyiree/rt/BUILD
index 4c9d89c..6e64af5 100644
--- a/bindings/python/pyiree/rt/BUILD
+++ b/bindings/python/pyiree/rt/BUILD
@@ -15,7 +15,6 @@
load(
"//bindings/python:build_defs.oss.bzl",
"NUMPY_DEPS",
- "PLATFORM_VULKAN_DEPS",
"PYBIND_COPTS",
"PYBIND_EXTENSION_COPTS",
"PYBIND_FEATURES",
@@ -31,12 +30,6 @@
licenses = ["notice"], # Apache 2.0
)
-DRIVER_DEPS = PLATFORM_VULKAN_DEPS + [
- "//iree/hal/vulkan:vulkan_driver_module",
- "//iree/hal/llvmjit:llvmjit_driver_module",
- "//iree/hal/vmla:vmla_driver_module",
-]
-
iree_py_library(
name = "rt",
srcs = [
@@ -59,10 +52,10 @@
features = PYBIND_FEATURES,
linkstatic = 1,
win_def_file = "export.def",
- deps = DRIVER_DEPS + [
+ deps = [
":rt_library",
"//bindings/python/pyiree/common",
- "//iree/base:initializer",
+ "//iree/hal/drivers",
],
)
@@ -90,10 +83,6 @@
"//iree/modules/tensorlist:native_module",
"//iree/vm",
"//iree/vm:bytecode_module",
- "//iree/vm:invocation",
- "//iree/vm:list",
- "//iree/vm:module",
- "//iree/vm:ref",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
diff --git a/bindings/python/pyiree/rt/CMakeLists.txt b/bindings/python/pyiree/rt/CMakeLists.txt
index 225dccd..7828073 100644
--- a/bindings/python/pyiree/rt/CMakeLists.txt
+++ b/bindings/python/pyiree/rt/CMakeLists.txt
@@ -35,8 +35,6 @@
iree::modules::tensorlist::native_module
iree::vm
iree::vm::bytecode_module
- iree::vm::invocation
- iree::vm::ref
absl::inlined_vector
absl::memory
absl::strings
@@ -56,10 +54,7 @@
::PyExtRtLib
bindings::python::pyiree::common::PyextCommonLib
DEPS
- iree::hal::vulkan::vulkan_driver_module
- iree::hal::llvmjit::llvmjit_driver_module
- iree::hal::vmla::vmla_driver_module
- iree::base::initializer
+ iree::hal::drivers
)
iree_py_library(
diff --git a/bindings/python/pyiree/rt/function_abi.cc b/bindings/python/pyiree/rt/function_abi.cc
index e92adf1..3cc046c 100644
--- a/bindings/python/pyiree/rt/function_abi.cc
+++ b/bindings/python/pyiree/rt/function_abi.cc
@@ -24,8 +24,7 @@
#include "iree/base/signature_mangle.h"
#include "iree/hal/api.h"
#include "iree/modules/hal/hal_module.h"
-#include "iree/vm/list.h"
-#include "iree/vm/ref.h"
+#include "iree/vm/api.h"
namespace iree {
namespace python {
diff --git a/bindings/python/pyiree/rt/initialize_module.cc b/bindings/python/pyiree/rt/initialize_module.cc
index e4e3824..840278a 100644
--- a/bindings/python/pyiree/rt/initialize_module.cc
+++ b/bindings/python/pyiree/rt/initialize_module.cc
@@ -18,13 +18,13 @@
#include "bindings/python/pyiree/rt/hal.h"
#include "bindings/python/pyiree/rt/host_types.h"
#include "bindings/python/pyiree/rt/vm.h"
-#include "iree/base/initializer.h"
+#include "iree/hal/drivers/init.h"
namespace iree {
namespace python {
PYBIND11_MODULE(binding, m) {
- IREE_RUN_MODULE_INITIALIZERS();
+ IREE_CHECK_OK(iree_hal_register_all_available_drivers());
m.doc() = "IREE Binding Backend Helpers";
SetupFunctionAbiBindings(m);
diff --git a/bindings/python/pyiree/rt/vm.cc b/bindings/python/pyiree/rt/vm.cc
index f7e0031..ce1a271 100644
--- a/bindings/python/pyiree/rt/vm.cc
+++ b/bindings/python/pyiree/rt/vm.cc
@@ -24,8 +24,7 @@
#include "iree/modules/hal/hal_module.h"
#include "iree/modules/strings/strings_module.h"
#include "iree/modules/tensorlist/native_module.h"
-#include "iree/vm/invocation.h"
-#include "iree/vm/module.h"
+#include "iree/vm/api.h"
namespace iree {
namespace python {
diff --git a/bindings/python/pyiree/rt/vm.h b/bindings/python/pyiree/rt/vm.h
index e65560c..34302bc 100644
--- a/bindings/python/pyiree/rt/vm.h
+++ b/bindings/python/pyiree/rt/vm.h
@@ -21,7 +21,6 @@
#include "iree/base/api.h"
#include "iree/vm/api.h"
#include "iree/vm/bytecode_module.h"
-#include "iree/vm/list.h"
namespace iree {
namespace python {
diff --git a/build_tools/BUILD b/build_tools/BUILD
new file mode 100644
index 0000000..cdd4ae3
--- /dev/null
+++ b/build_tools/BUILD
@@ -0,0 +1,33 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+# Helper for getting linkopts into libraries in a way we can easily spot and
+# rewrite from bazel-to-cmake and copybara.
+cc_library(
+ name = "default_linkopts",
+ linkopts = select({
+ "//iree:iree_is_msvc": [],
+ "//conditions:default": [
+ # Just include libraries that should be presumed in 2020.
+ "-ldl",
+ "-lpthread",
+ ],
+ }),
+)
diff --git a/build_tools/bazel/iree.bazelrc b/build_tools/bazel/iree.bazelrc
index 7fc9945..58d2b91 100644
--- a/build_tools/bazel/iree.bazelrc
+++ b/build_tools/bazel/iree.bazelrc
@@ -145,7 +145,7 @@
build:generic_clang --copt=-Wthread-safety
build:generic_clang --copt=-Wthread-safety-beta
build:generic_clang --copt=-Wunused-comparison
-build:generic_clang --copt=-Wunused-variable
+build:generic_clang --copt=-Wno-unused-variable
build:generic_clang --copt=-Wvla
# LINT.ThenChange(https://github.com/google/iree/tree/main/build_tools/cmake/iree_copts.cmake:clang_diagnostics)
@@ -274,15 +274,29 @@
# absl forces /W3 in their copts, so we exclude them to avoid D9025
build:_msvc_base --per_file_copt=+external,-com_google_absl@/w
+# Find the source of truth for these in iree_copts.cmake.
+build:_msvc_base --copt=/DWIN32_LEAN_AND_MEAN
+build:_msvc_base --copt=/DNOMINMAX
+build:_msvc_base --copt=/D_USE_MATH_DEFINES
+build:_msvc_base --copt=/D_CRT_SECURE_NO_WARNINGS
+build:_msvc_base --copt=/D_CRT_SECURE_CPP_OVERLOAD_STANDARD_NAMES
+build:_msvc_base --copt=/EHsc
+build:_msvc_base --copt=/bigobj
+
+build:_msvc_base --copt=/wd4200 # nonstandard extension used: zero-sized array in struct/union
+build:_msvc_base --copt=/wd4018 # signed/unsigned mismatch
+build:_msvc_base --copt=/wd4146 # operator applied to unsigned type, result still unsigned
+build:_msvc_base --copt=/wd4244 # possible loss of data
+build:_msvc_base --copt=/wd4267 # initializing: possible loss of data
+build:_msvc_base --copt=/wd4005 # allow: macro redefinition
+build:_msvc_base --copt=/wd4065 # allow: switch statement contains 'default' but no 'case' labels
+build:_msvc_base --copt=/wd4141 # allow: inline used more than once
+build:_msvc_base --copt=/wd4624 # allow: destructor was implicitly defined as deleted
+build:_msvc_base --copt=/wd5105 # macro expansion producing 'defined' has undefined behavior
+
# And some more explicit disables. For some reason the `/w` on external doesn't
# work for these, maybe they come from headers?
-build:_msvc_base --copt=/wd4244 # possible loss of data
-build:_msvc_base --copt=/wd4624 # destructor was implicitly defined as deleted
-build:_msvc_base --copt=/wd4005 # macro redefinition
-build:_msvc_base --copt=/wd4267 # initializing: possible loss of data
-build:_msvc_base --copt=/wd4141 # inline used more than once
# new warning with the standards-compliant preprocessor. winbase itself is not standards-compliant
-build:_msvc_base --copt=/wd5105
build:_msvc_base --per_file_copt=mkl_dnn@/wd4551 # missing argument list
build:_msvc_base --per_file_copt=mkl_dnn@/wd4068 # unknown pragma
build:_msvc_base --per_file_copt=farmhash@/wd4319 # zero extending to T of greater size
@@ -295,16 +309,8 @@
build:_msvc_base --copt=/arch:AVX
# Host and target are the same in windows so don't waste time building both.
build:_msvc_base --distinct_host_configuration=false
-# Avoids incompatible versions of winsock and other badness.
-build:_msvc_base --copt=/DWIN32_LEAN_AND_MEAN
-# Why are min/max macros? No one knows.
-build:_msvc_base --copt=/DNOMINMAX
-# Yay for security warnings. Boo for non-standard.
-build:_msvc_base --copt=/D_CRT_SECURE_NO_WARNINGS
# TensorFlow requires the "monolithic" build mode for now on Windows.
build:_msvc_base --define framework_shared_object=false
-# Necessary for M_* math constants.
-build:_msvc_base --copt=/D_USE_MATH_DEFINES
# Workaround WinGDI.h defining `ERROR`, which conflicts with logging macros.
# Note that IREE and TensorFlow both `#undef ERROR` and define their own
diff --git a/build_tools/bazel/iree_flatcc.bzl b/build_tools/bazel/iree_flatcc.bzl
index e091af6..6d35299 100644
--- a/build_tools/bazel/iree_flatcc.bzl
+++ b/build_tools/bazel/iree_flatcc.bzl
@@ -37,6 +37,9 @@
outs += ["%s_builder.h" % (out_stem)]
if arg == "--verifier":
outs += ["%s_verifier.h" % (out_stem)]
+ if arg == "--json":
+ outs += ["%s_json_parser.h" % (out_stem)]
+ outs += ["%s_json_printer.h" % (out_stem)]
native.genrule(
name = name + "_gen",
diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py b/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py
index 6f477d1..f6dca0b 100644
--- a/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py
+++ b/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py
@@ -37,16 +37,6 @@
def __init__(self, converter):
self.converter = converter
- # TODO(gcmn): Do this in a less hard-coded way
- self.PLATFORM_VULKAN_DEPS = []
- self.PLATFORM_VULKAN_TEST_DEPS = ["//iree/testing:gtest_main"]
- self.FLATBUFFER_SUPPORTS_REFLECTIONS = False
- self.PLATFORM_VULKAN_LOADER_COPTS = []
- self.IREE_DRIVER_MODULES = [
- "//iree/hal/vmla:vmla_driver_module",
- "//iree/hal/vulkan:vulkan_driver_module",
- "//iree/hal/llvmjit:llvmjit_driver_module",
- ]
# ------------------------------------------------------------------------- #
# Conversion utilities, written to reduce boilerplate and allow for reuse #
@@ -109,9 +99,6 @@
else:
return ""
- def _convert_alwayslink_block(self, alwayslink):
- return self._convert_option_block("ALWAYSLINK", alwayslink)
-
def _convert_testonly_block(self, testonly):
return self._convert_option_block("TESTONLY", testonly)
@@ -348,7 +335,7 @@
srcs=None,
data=None,
deps=None,
- alwayslink=False,
+ defines=None,
testonly=False,
linkopts=None,
**kwargs):
@@ -360,7 +347,7 @@
srcs_block = self._convert_srcs_block(srcs)
data_block = self._convert_data_block(data)
deps_block = self._convert_deps_block(deps)
- alwayslink_block = self._convert_alwayslink_block(alwayslink)
+ defines_block = self._convert_string_list_block("DEFINES", defines)
testonly_block = self._convert_testonly_block(testonly)
self.converter.body += (f"iree_cc_library(\n"
@@ -370,7 +357,7 @@
f"{srcs_block}"
f"{data_block}"
f"{deps_block}"
- f"{alwayslink_block}"
+ f"{defines_block}"
f"{testonly_block}"
f" PUBLIC\n)\n\n")
@@ -499,17 +486,6 @@
f"{flatcc_args_block}"
f" PUBLIC\n)\n\n")
- def iree_flatbuffer_cc_library(self, name, srcs, flatc_args=None):
- name_block = self._convert_name_block(name)
- srcs_block = self._convert_srcs_block(srcs)
- flatc_args_block = self._convert_flatc_args_block(flatc_args)
-
- self.converter.body += (f"flatbuffer_cc_library(\n"
- f"{name_block}"
- f"{srcs_block}"
- f"{flatc_args_block}"
- f" PUBLIC\n)\n\n")
-
def gentbl(self,
name,
tblgen,
diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
index c8f1dbf..f20ae3d 100644
--- a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
+++ b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
@@ -16,6 +16,9 @@
# Bazel to CMake target name conversions used by bazel_to_cmake.py.
EXPLICIT_TARGET_MAPPING = {
+ # Internal utilities to emulate various binary/library options.
+ "//build_tools:default_linkopts": [],
+
# absl
"@com_google_absl//absl/flags:flag": ["absl::flags"],
"@com_google_absl//absl/flags:parse": ["absl::flags_parse"],
@@ -57,15 +60,11 @@
"@llvm-project//mlir:MlirOptLib": ["MLIROptLib"],
"@llvm-project//mlir:VectorOps": ["MLIRVector"],
# Vulkan
- # TODO(scotttodd): Set -DVK_NO_PROTOTYPES to COPTS for _no_prototypes.
- # Maybe add a wrapper CMake lib within build_tools/third_party/?
"@iree_vulkan_headers//:vulkan_headers": ["Vulkan::Headers"],
- "@iree_vulkan_headers//:vulkan_headers_no_prototypes": ["Vulkan::Headers"],
# The Bazel target maps to the IMPORTED target defined by FindVulkan().
"@vulkan_sdk//:sdk": ["Vulkan::Vulkan"],
# Misc single targets
"@com_google_benchmark//:benchmark": ["benchmark"],
- "@com_github_google_flatbuffers//:flatbuffers": ["flatbuffers"],
"@com_github_dvidelabs_flatcc//:flatcc": ["flatcc"],
"@com_github_dvidelabs_flatcc//:runtime": ["flatcc::runtime"],
"@com_google_googletest//:gtest": ["gmock", "gtest"],
@@ -73,7 +72,8 @@
"@pffft": ["pffft"],
"@sdl2//:SDL2": ["SDL2-static"],
"@com_github_pytorch_cpuinfo//:cpuinfo": ["cpuinfo"],
- "@half": ["half"],
+ "@half//:half": ["half"],
+ "@vulkan_memory_allocator//:impl_header_only": ["vulkan_memory_allocator"],
}
diff --git a/build_tools/cmake/external_cc_library.cmake b/build_tools/cmake/external_cc_library.cmake
index dd9799b..d9cd38e 100644
--- a/build_tools/cmake/external_cc_library.cmake
+++ b/build_tools/cmake/external_cc_library.cmake
@@ -32,7 +32,6 @@
# DEFINES: List of public defines
# INCLUDES: Include directories to add to dependencies
# LINKOPTS: List of link options
-# ALWAYSLINK: Always link the library into any binary with a direct dep.
# PUBLIC: Add this so that this library will be exported under ${PACKAGE}::
# Also in IDE, target will appear in ${PACKAGE} folder while non PUBLIC will be
# in ${PACKAGE}/internal.
@@ -79,7 +78,7 @@
# )
function(external_cc_library)
cmake_parse_arguments(_RULE
- "PUBLIC;ALWAYSLINK;TESTONLY"
+ "PUBLIC;TESTONLY"
"PACKAGE;NAME;ROOT"
"HDRS;SRCS;COPTS;DEFINES;LINKOPTS;DATA;DEPS;INCLUDES"
${ARGN}
@@ -122,7 +121,8 @@
)
target_include_directories(${_NAME} SYSTEM
PUBLIC
- "$<BUILD_INTERFACE:${IREE_COMMON_INCLUDE_DIRS}>"
+ "$<BUILD_INTERFACE:${IREE_SOURCE_DIR}>"
+ "$<BUILD_INTERFACE:${IREE_BINARY_DIR}>"
"$<BUILD_INTERFACE:${_RULE_INCLUDES}>"
)
target_compile_options(${_NAME}
@@ -134,8 +134,8 @@
PUBLIC
${_RULE_DEPS}
PRIVATE
- ${_RULE_LINKOPTS}
${IREE_DEFAULT_LINKOPTS}
+ ${_RULE_LINKOPTS}
)
target_compile_definitions(${_NAME}
PUBLIC
@@ -143,10 +143,6 @@
)
iree_add_data_dependencies(NAME ${_NAME} DATA ${_RULE_DATA})
- if(DEFINED _RULE_ALWAYSLINK)
- set_property(TARGET ${_NAME} PROPERTY ALWAYSLINK 1)
- endif()
-
# Add all external targets to a folder in the IDE for organization.
if(_RULE_PUBLIC)
set_property(TARGET ${_NAME} PROPERTY FOLDER third_party)
@@ -164,19 +160,20 @@
add_library(${_NAME} INTERFACE)
target_include_directories(${_NAME} SYSTEM
INTERFACE
- "$<BUILD_INTERFACE:${IREE_COMMON_INCLUDE_DIRS}>"
+ "$<BUILD_INTERFACE:${IREE_SOURCE_DIR}>"
+ "$<BUILD_INTERFACE:${IREE_BINARY_DIR}>"
"$<BUILD_INTERFACE:${_RULE_INCLUDES}>"
)
target_compile_options(${_NAME}
INTERFACE
- ${_RULE_COPTS}
${IREE_DEFAULT_COPTS}
+ ${_RULE_COPTS}
)
target_link_libraries(${_NAME}
INTERFACE
- ${_RULE_DEPS}
- ${_RULE_LINKOPTS}
${IREE_DEFAULT_LINKOPTS}
+ ${_RULE_LINKOPTS}
+ ${_RULE_DEPS}
)
iree_add_data_dependencies(NAME ${_NAME} DATA ${_RULE_DATA})
target_compile_definitions(${_NAME}
diff --git a/build_tools/cmake/flatbuffer_c_library.cmake b/build_tools/cmake/flatbuffer_c_library.cmake
index 24ee470..23ade0c 100644
--- a/build_tools/cmake/flatbuffer_c_library.cmake
+++ b/build_tools/cmake/flatbuffer_c_library.cmake
@@ -21,10 +21,6 @@
# Parameters:
# NAME: name of target (see Note)
# SRCS: List of source files for the library
-# DEPS: List of other libraries to be linked in to the binary targets
-# COPTS: List of private compile options
-# DEFINES: List of public defines
-# LINKOPTS: List of link options
# FLATCC_ARGS: List of flatbuffers arguments. Default:
# "--common"
# "--reader"
@@ -39,32 +35,28 @@
#
# flatbuffer_c_library(
# NAME
-# base_schema
+# some_def
# SRCS
-# "a.fbs"
-# )
-# flatbuffer_c_library(
-# NAME
-# other_schemas
-# SRCS
-# "b.fbs"
-# DEPS
-# iree::schemas::base_schema
+# "some_def.fbs"
+# FLATCC_ARGS
+# "--reader"
+# "--builder"
+# "--verifier"
+# "--json"
# PUBLIC
# )
-#
# iree_cc_binary(
# NAME
# main_lib
# ...
# DEPS
-# iree::schemas::other_schemas
+# iree::schemas::some_def
# )
function(flatbuffer_c_library)
cmake_parse_arguments(_RULE
"PUBLIC;TESTONLY"
"NAME"
- "SRCS;COPTS;DEFINES;LINKOPTS;DEPS;FLATCC_ARGS"
+ "SRCS;FLATCC_ARGS"
${ARGN}
)
@@ -95,6 +87,9 @@
list(APPEND _OUTS "${_SRC_FILENAME}_builder.h")
elseif(_ARG STREQUAL "--verifier")
list(APPEND _OUTS "${_SRC_FILENAME}_verifier.h")
+ elseif(_ARG STREQUAL "--json")
+ list(APPEND _OUTS "${_SRC_FILENAME}_json_printer.h")
+ list(APPEND _OUTS "${_SRC_FILENAME}_json_parser.h")
endif()
endforeach()
endforeach()
@@ -125,28 +120,24 @@
${_GEN_TARGET}
DEPENDS
${_OUTS}
- ${_RULE_DEPS}
)
add_library(${_NAME} INTERFACE)
add_dependencies(${_NAME} ${_GEN_TARGET})
target_include_directories(${_NAME} SYSTEM
INTERFACE
- "$<BUILD_INTERFACE:${IREE_COMMON_INCLUDE_DIRS}>"
+ "$<BUILD_INTERFACE:${IREE_SOURCE_DIR}>"
+ "$<BUILD_INTERFACE:${IREE_BINARY_DIR}>"
${CMAKE_CURRENT_BINARY_DIR}
)
target_link_libraries(${_NAME}
INTERFACE
flatcc::runtime
- ${_RULE_LINKOPTS}
${IREE_DEFAULT_LINKOPTS}
)
- target_compile_definitions(${_NAME}
- INTERFACE
- ${_RULE_DEFINES}
- )
target_compile_options(${_NAME}
INTERFACE
+ "-I${IREE_ROOT_DIR}/third_party/flatcc/include/"
"-I${IREE_ROOT_DIR}/third_party/flatcc/include/flatcc/reflection/"
)
diff --git a/build_tools/cmake/flatbuffer_cc_library.cmake b/build_tools/cmake/flatbuffer_cc_library.cmake
deleted file mode 100644
index 52618de..0000000
--- a/build_tools/cmake/flatbuffer_cc_library.cmake
+++ /dev/null
@@ -1,137 +0,0 @@
-# Copyright 2019 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-include(BuildFlatBuffers)
-include(CMakeParseArguments)
-
-# flatbuffer_cc_library()
-#
-# CMake function to imitate Bazel's flatbuffer_cc_library rule.
-#
-# Parameters:
-# NAME: name of target (see Note)
-# SRCS: List of source files for the library
-# DEPS: List of other libraries to be linked in to the binary targets
-# COPTS: List of private compile options
-# DEFINES: List of public defines
-# LINKOPTS: List of link options
-# FLATC_ARGS: List of flattbuffers arguments. Default:
-# "--keep-prefix"
-# "--scoped-enums"
-# "--reflect-names"
-# "--gen-object-api"
-# PUBLIC: Add this so that this library will be exported under iree::
-# Also in IDE, target will appear in IREE folder while non PUBLIC will be in IREE/internal.
-# TESTONLY: When added, this target will only be built if user passes -DIREE_BUILD_TESTS=ON to CMake.
-#
-# Note:
-# By default, flatbuffer_cc_library will always create a library named ${NAME},
-# and alias target iree::${NAME}. The iree:: form should always be used.
-# This is to reduce namespace pollution.
-#
-# flatbuffer_cc_library(
-# NAME
-# base_schema
-# SRCS
-# "a.fbs"
-# )
-# flatbuffer_cc_library(
-# NAME
-# other_schemas
-# SRCS
-# "b.fbs"
-# DEPS
-# iree::schemas::base_schema
-# PUBLIC
-# )
-#
-# iree_cc_binary(
-# NAME
-# main_lib
-# ...
-# DEPS
-# iree::schemas::other_schemas
-# )
-function(flatbuffer_cc_library)
- cmake_parse_arguments(_RULE
- "PUBLIC;TESTONLY"
- "NAME"
- "SRCS;COPTS;DEFINES;LINKOPTS;DEPS;FLATC_ARGS"
- ${ARGN}
- )
-
- if(_RULE_TESTONLY AND NOT IREE_BUILD_TESTS)
- return()
- endif()
-
- # Prefix the library with the package name, so we get: iree_package_name
- iree_package_name(_PACKAGE_NAME)
- set(_NAME "${_PACKAGE_NAME}_${_RULE_NAME}")
-
- if(NOT DEFINED _RULE_FLATC_ARGS)
- set(FLATBUFFERS_FLATC_SCHEMA_EXTRA_ARGS
- # Preserve root-relative include paths in generated code.
- "--keep-prefix"
- # Use C++11 'enum class' for enums.
- "--scoped-enums"
- # Include reflection tables used for dumping debug representations.
- "--reflect-names"
- # Generate FooT types for unpack/pack support. Note that this should only
- # be used in tooling as the code size/runtime overhead is non-trivial.
- "--gen-object-api"
- )
- else()
- set(FLATBUFFERS_FLATC_SCHEMA_EXTRA_ARGS ${_RULE_FLATC_ARGS})
- endif()
-
- set(_GEN_TARGET "${_NAME}_gen")
-
- build_flatbuffers(
- "${_RULE_SRCS}"
- "${IREE_ROOT_DIR}"
- "${_GEN_TARGET}" # custom_target_name
- "${_RULE_DEPS}" # additional_dependencies
- "${CMAKE_CURRENT_BINARY_DIR}" # generated_include_dir
- "${CMAKE_CURRENT_BINARY_DIR}" # binary_schemas_dir
- "" # copy_text_schemas_dir
- )
-
- # Add dependency on flatc explicitly. This is needed for cross-compiling
- # where flatc comes from another CMake invocation for host.
- iree_add_executable_dependencies(${_GEN_TARGET} flatc)
-
- add_library(${_NAME} INTERFACE)
- add_dependencies(${_NAME} ${_GEN_TARGET})
- target_include_directories(${_NAME} SYSTEM
- INTERFACE
- "$<BUILD_INTERFACE:${IREE_COMMON_INCLUDE_DIRS}>"
- ${CMAKE_CURRENT_BINARY_DIR}
- )
- target_link_libraries(${_NAME}
- INTERFACE
- flatbuffers
- ${_RULE_LINKOPTS}
- ${IREE_DEFAULT_LINKOPTS}
- )
- target_compile_definitions(${_NAME}
- INTERFACE
- ${_RULE_DEFINES}
- )
-
- # Alias the iree_package_name library to iree::package::name.
- # This lets us more clearly map to Bazel and makes it possible to
- # disambiguate the underscores in paths vs. the separators.
- iree_package_ns(_PACKAGE_NS)
- add_library(${_PACKAGE_NS}::${_RULE_NAME} ALIAS ${_NAME})
-endfunction()
diff --git a/build_tools/cmake/iree_cc_binary.cmake b/build_tools/cmake/iree_cc_binary.cmake
index 5e1bd0a..168baa8 100644
--- a/build_tools/cmake/iree_cc_binary.cmake
+++ b/build_tools/cmake/iree_cc_binary.cmake
@@ -14,10 +14,6 @@
include(CMakeParseArguments)
-if (NOT DEFINED _IREE_CC_BINARY_NAMES)
- set(_IREE_CC_BINARY_NAMES "")
-endif()
-
# iree_cc_binary()
#
# CMake function to imitate Bazel's cc_binary rule.
@@ -112,7 +108,8 @@
endif()
target_include_directories(${_NAME} SYSTEM
PUBLIC
- ${IREE_COMMON_INCLUDE_DIRS}
+ "$<BUILD_INTERFACE:${IREE_SOURCE_DIR}>"
+ "$<BUILD_INTERFACE:${IREE_BINARY_DIR}>"
)
target_compile_definitions(${_NAME}
PUBLIC
@@ -120,45 +117,33 @@
)
target_compile_options(${_NAME}
PRIVATE
+ ${IREE_DEFAULT_COPTS}
${_RULE_COPTS}
)
target_link_options(${_NAME}
PRIVATE
- ${_RULE_LINKOPTS}
${IREE_DEFAULT_LINKOPTS}
+ ${_RULE_LINKOPTS}
+ )
+
+ # Replace dependencies passed by ::name with iree::package::name
+ iree_package_ns(_PACKAGE_NS)
+ list(TRANSFORM _RULE_DEPS REPLACE "^::" "${_PACKAGE_NS}::")
+
+ target_link_libraries(${_NAME}
+ PUBLIC
+ ${_RULE_DEPS}
)
iree_add_data_dependencies(NAME ${_NAME} DATA ${_RULE_DATA})
- iree_package_ns(_PACKAGE_NS)
- # Replace dependencies passed by ::name with ::iree::package::name
- list(TRANSFORM _RULE_DATA REPLACE "^::" "${_PACKAGE_NS}::")
- list(TRANSFORM _RULE_DEPS REPLACE "^::" "${_PACKAGE_NS}::")
-
# Add all IREE targets to a folder in the IDE for organization.
set_property(TARGET ${_NAME} PROPERTY FOLDER ${IREE_IDE_FOLDER}/binaries)
set_property(TARGET ${_NAME} PROPERTY CXX_STANDARD ${IREE_CXX_STANDARD})
set_property(TARGET ${_NAME} PROPERTY CXX_STANDARD_REQUIRED ON)
- # Defer computing transitive dependencies and calling target_link_libraries()
- # until all libraries have been declared.
- # Track target and deps, use in iree_complete_binary_link_options() later.
- set_property(GLOBAL APPEND PROPERTY _IREE_CC_BINARY_NAMES "${_NAME}")
- set_property(TARGET ${_NAME} PROPERTY DIRECT_DEPS ${_RULE_DEPS})
-
install(TARGETS ${_NAME}
RENAME ${_RULE_NAME}
COMPONENT ${_RULE_NAME}
RUNTIME DESTINATION bin)
endfunction()
-
-# Sets target_link_libraries() on all registered binaries.
-# This must be called after all libraries have been declared.
-function(iree_complete_binary_link_options)
- get_property(_NAMES GLOBAL PROPERTY _IREE_CC_BINARY_NAMES)
-
- foreach(_NAME ${_NAMES})
- get_target_property(_DIRECT_DEPS ${_NAME} DIRECT_DEPS)
- iree_whole_archive_link(${_NAME} ${_DIRECT_DEPS})
- endforeach(_NAME)
-endfunction()
diff --git a/build_tools/cmake/iree_cc_library.cmake b/build_tools/cmake/iree_cc_library.cmake
index e4d5b67..fc7cb5d 100644
--- a/build_tools/cmake/iree_cc_library.cmake
+++ b/build_tools/cmake/iree_cc_library.cmake
@@ -29,12 +29,10 @@
# DEFINES: List of public defines
# INCLUDES: Include directories to add to dependencies
# LINKOPTS: List of link options
-# ALWAYSLINK: Always link the library into any binary with a direct dep.
# PUBLIC: Add this so that this library will be exported under iree::
# Also in IDE, target will appear in IREE folder while non PUBLIC will be in IREE/internal.
# TESTONLY: When added, this target will only be built if user passes -DIREE_BUILD_TESTS=ON to CMake.
# SHARED: If set, will compile to a shared object.
-# WHOLEARCHIVE: If set, links all symbols from "ALWAYSLINK" libraries.
#
# Note:
# By default, iree_cc_library will always create a library named iree_${NAME},
@@ -69,7 +67,7 @@
function(iree_cc_library)
cmake_parse_arguments(
_RULE
- "PUBLIC;ALWAYSLINK;TESTONLY;SHARED;WHOLEARCHIVE"
+ "PUBLIC;TESTONLY;SHARED;WHOLEARCHIVE"
"NAME"
"HDRS;TEXTUAL_HDRS;SRCS;COPTS;DEFINES;LINKOPTS;DATA;DEPS;INCLUDES"
${ARGN}
@@ -79,10 +77,9 @@
return()
endif()
+ # Replace dependencies passed by ::name with iree::package::name
iree_package_ns(_PACKAGE_NS)
- # Replace dependencies passed by ::name with ::iree::package::name
list(TRANSFORM _RULE_DEPS REPLACE "^::" "${_PACKAGE_NS}::")
- list(TRANSFORM _RULE_DATA REPLACE "^::" "${_PACKAGE_NS}::")
# Prefix the library with the package name, so we get: iree_package_name.
iree_package_name(_PACKAGE_NAME)
@@ -105,13 +102,10 @@
endif()
if(NOT _RULE_IS_INTERFACE)
- if (_RULE_SHARED)
+ if(_RULE_SHARED)
add_library(${_NAME} SHARED "")
else()
add_library(${_NAME} STATIC "")
- if (_RULE_WHOLEARCHIVE)
- message(FATAL_ERROR "WHOLEARCHIVE must be set together with SHARED")
- endif()
endif()
target_sources(${_NAME}
@@ -122,7 +116,8 @@
)
target_include_directories(${_NAME} SYSTEM
PUBLIC
- "$<BUILD_INTERFACE:${IREE_COMMON_INCLUDE_DIRS}>"
+ "$<BUILD_INTERFACE:${IREE_SOURCE_DIR}>"
+ "$<BUILD_INTERFACE:${IREE_BINARY_DIR}>"
)
target_include_directories(${_NAME}
PUBLIC
@@ -130,19 +125,16 @@
)
target_compile_options(${_NAME}
PRIVATE
- ${_RULE_COPTS}
${IREE_DEFAULT_COPTS}
+ ${_RULE_COPTS}
)
- if(_RULE_WHOLEARCHIVE)
- iree_whole_archive_link(${_NAME} ${_RULE_DEPS})
- else()
- target_link_libraries(${_NAME} PUBLIC ${_RULE_DEPS})
- endif()
target_link_libraries(${_NAME}
+ PUBLIC
+ ${_RULE_DEPS}
PRIVATE
- ${_RULE_LINKOPTS}
${IREE_DEFAULT_LINKOPTS}
+ ${_RULE_LINKOPTS}
)
iree_add_data_dependencies(NAME ${_NAME} DATA ${_RULE_DATA})
@@ -151,10 +143,6 @@
${_RULE_DEFINES}
)
- if(DEFINED _RULE_ALWAYSLINK)
- set_property(TARGET ${_NAME} PROPERTY ALWAYSLINK 1)
- endif()
-
# Add all IREE targets to a folder in the IDE for organization.
if(_RULE_PUBLIC)
set_property(TARGET ${_NAME} PROPERTY FOLDER ${IREE_IDE_FOLDER})
@@ -172,18 +160,19 @@
add_library(${_NAME} INTERFACE)
target_include_directories(${_NAME} SYSTEM
INTERFACE
- "$<BUILD_INTERFACE:${IREE_COMMON_INCLUDE_DIRS}>"
+ "$<BUILD_INTERFACE:${IREE_SOURCE_DIR}>"
+ "$<BUILD_INTERFACE:${IREE_BINARY_DIR}>"
)
target_compile_options(${_NAME}
INTERFACE
- ${_RULE_COPTS}
${IREE_DEFAULT_COPTS}
+ ${_RULE_COPTS}
)
target_link_libraries(${_NAME}
INTERFACE
+ ${IREE_DEFAULT_LINKOPTS}
${_RULE_DEPS}
${_RULE_LINKOPTS}
- ${IREE_DEFAULT_LINKOPTS}
)
iree_add_data_dependencies(NAME ${_NAME} DATA ${_RULE_DATA})
target_compile_definitions(${_NAME}
diff --git a/build_tools/cmake/iree_cc_test.cmake b/build_tools/cmake/iree_cc_test.cmake
index 4ae27fb..0eac677 100644
--- a/build_tools/cmake/iree_cc_test.cmake
+++ b/build_tools/cmake/iree_cc_test.cmake
@@ -66,11 +66,6 @@
${ARGN}
)
- iree_package_ns(_PACKAGE_NS)
- # Replace dependencies passed by ::name with ::iree::package::name
- list(TRANSFORM _RULE_DATA REPLACE "^::" "${_PACKAGE_NS}::")
- list(TRANSFORM _RULE_DEPS REPLACE "^::" "${_PACKAGE_NS}::")
-
# Prefix the library with the package name, so we get: iree_package_name
iree_package_name(_PACKAGE_NAME)
set(_NAME "${_PACKAGE_NAME}_${_RULE_NAME}")
@@ -82,7 +77,8 @@
)
target_include_directories(${_NAME} SYSTEM
PUBLIC
- ${IREE_COMMON_INCLUDE_DIRS}
+ "$<BUILD_INTERFACE:${IREE_SOURCE_DIR}>"
+ "$<BUILD_INTERFACE:${IREE_BINARY_DIR}>"
)
target_compile_definitions(${_NAME}
PUBLIC
@@ -90,26 +86,32 @@
)
target_compile_options(${_NAME}
PRIVATE
+ ${IREE_DEFAULT_COPTS}
${_RULE_COPTS}
)
target_link_options(${_NAME}
PRIVATE
- ${_RULE_LINKOPTS}
${IREE_DEFAULT_LINKOPTS}
+ ${_RULE_LINKOPTS}
+ )
+
+ # Replace dependencies passed by ::name with iree::package::name
+ iree_package_ns(_PACKAGE_NS)
+ list(TRANSFORM _RULE_DEPS REPLACE "^::" "${_PACKAGE_NS}::")
+
+ target_link_libraries(${_NAME}
+ PUBLIC
+ ${_RULE_DEPS}
)
iree_add_data_dependencies(NAME ${_NAME} DATA ${_RULE_DATA})
+
# Add all IREE targets to a folder in the IDE for organization.
set_property(TARGET ${_NAME} PROPERTY FOLDER ${IREE_IDE_FOLDER}/test)
set_property(TARGET ${_NAME} PROPERTY CXX_STANDARD ${IREE_CXX_STANDARD})
set_property(TARGET ${_NAME} PROPERTY CXX_STANDARD_REQUIRED ON)
- # Defer computing transitive dependencies and calling target_link_libraries()
- # until all libraries have been declared.
- # Track target and deps, use in iree_complete_binary_link_options() later.
list(APPEND _RULE_DEPS "gmock")
- set_property(GLOBAL APPEND PROPERTY _IREE_CC_BINARY_NAMES "${_NAME}")
- set_property(TARGET ${_NAME} PROPERTY DIRECT_DEPS ${_RULE_DEPS})
string(REPLACE "::" "/" _PACKAGE_PATH ${_PACKAGE_NS})
set(_TEST_NAME "${_PACKAGE_PATH}/${_RULE_NAME}")
diff --git a/build_tools/cmake/iree_copts.cmake b/build_tools/cmake/iree_copts.cmake
index 968380c..903b576 100644
--- a/build_tools/cmake/iree_copts.cmake
+++ b/build_tools/cmake/iree_copts.cmake
@@ -21,22 +21,74 @@
# By default Abseil strips string literals on mobile platforms, which means
# we cannot run IREE binaries via command-line with proper options. Turn off
# the stripping.
-# TODO: we might still want to strip when compiling IREE into Android Java apps.
+# TODO(#3814): remove ABSL flags.
if(ANDROID)
add_definitions(-DABSL_FLAGS_STRIP_NAMES=0)
endif()
#-------------------------------------------------------------------------------
-# C++ used within IREE
+# C/C++ options as used within IREE
#-------------------------------------------------------------------------------
+#
+# ██ ██ █████ ██████ ███ ██ ██ ███ ██ ██████
+# ██ ██ ██ ██ ██ ██ ████ ██ ██ ████ ██ ██
+# ██ █ ██ ███████ ██████ ██ ██ ██ ██ ██ ██ ██ ██ ███
+# ██ ███ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██ ██
+# ███ ███ ██ ██ ██ ██ ██ ████ ██ ██ ████ ██████
+#
+# Everything here is added to *every* iree_cc_library/iree_cc_binary/etc.
+# That includes both runtime and compiler components, and these may propagate
+# out to user code interacting with either (such as custom modules).
+#
+# Be extremely judicious in the use of these flags.
+#
+# - Need to disable a warning?
+# Usually these are encountered in compiler-specific code and can be disabled
+# in a compiler-specific way. Only add global warning disables when it's clear
+# that we never want them or that they'll show up in a lot of places.
+#
+# See: https://stackoverflow.com/questions/3378560/how-to-disable-gcc-warnings-for-a-few-lines-of-code
+#
+# - Need to add a linker dependency?
+# First figure out if you *really* need it. If it's only required on specific
+# platforms and in very specific files where clang or msvc are used, prefer
+# autolinking. GCC is stubborn and doesn't have autolinking so additional
+# flags may be required there.
+#
+# See: https://en.wikipedia.org/wiki/Auto-linking
+#
+# - Need to tweak a compilation mode setting (debug/asserts/etc)?
+# Don't do that here, and in general *don't do that at all* unless it's behind
+# a very specific IREE-prefixed cmake flag (like IREE_SIZE_OPTIMIZED).
+# There's no one-size solution when we are dealing with cross-project and
+# cross-compiled binaries - there's no safe way to set global options that
+# won't cause someone to break, and you probably don't really need to
+# change that setting anyway. Follow the rule of least surprise: if the user
+# has CMake's Debug configuration active then don't force things into release
+# mode, etc.
+#
+# - Need to add an include directory?
+# Don't do that here. Always prefer to fully-specify the path from the IREE
+# workspace root when it's known that the compilation will be occurring using
+# the files within the IREE checkout; for example, instead of adding a global
+# include path to third_party/foo/ and #include <foo.h>'ing, just
+# #include "third_party/foo/foo.h". This reduces build configuration, makes it
+# easier for readers to find the files, etc.
+#
+# - Still think you need to add an include directory? (system includes, etc)
+# Don't do that here, either. It's highly doubtful that every single target in
+# all of IREE (both compiler and runtime) on all platforms (both host and
+# cross-compilation targets) needs your special include directory. Add it on
+# the COPTS of the target you are using it in and, ideally, private to that
+# target (used in .c/cc files, not in a .h that leaks the include path
+# requirements to all consumers of the API).
set(IREE_CXX_STANDARD ${CMAKE_CXX_STANDARD})
-set(IREE_ROOT_DIR ${PROJECT_SOURCE_DIR})
-list(APPEND IREE_COMMON_INCLUDE_DIRS
- ${PROJECT_SOURCE_DIR}
- ${PROJECT_BINARY_DIR}
-)
+# TODO(benvanik): fix these names (or remove entirely).
+set(IREE_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR})
+set(IREE_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR})
+set(IREE_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
iree_select_compiler_opts(IREE_DEFAULT_COPTS
CLANG
@@ -97,7 +149,6 @@
"-Wthread-safety"
"-Wthread-safety-beta"
"-Wunused-comparison"
- "-Wunused-variable"
"-Wvla"
# LINT.ThenChange(https://github.com/google/iree/tree/main/build_tools/bazel/iree.bazelrc:clang_diagnostics)
@@ -111,27 +162,145 @@
"-Wno-gnu-label-as-value"
CLANG_OR_GCC
"-Wno-unused-parameter"
+ "-Wno-unused-variable"
"-Wno-undef"
"-fvisibility=hidden"
MSVC_OR_CLANG_CL
+ # Exclude a bunch of rarely-used APIs, such as crypto/DDE/shell.
+ # https://docs.microsoft.com/en-us/windows/win32/winprog/using-the-windows-headers
+ # NOTE: this is not really required anymore for build performance but does
+ # work around some issues that crop up with header version compatibility
+ # (abseil has issues with winsock versions).
"/DWIN32_LEAN_AND_MEAN"
+
+ # Don't allow windows.h to define MIN and MAX and conflict with the STL.
+ # There's no legit use for these macros as any code we are writing ourselves
+ # that we want a MIN/MAX in should be using an IREE-prefixed version
+ # instead: iree_min iree_max
+ # https://stackoverflow.com/a/4914108
+ "/DNOMINMAX"
+
+ # Adds M_PI and other constants to <math.h>/<cmath> (to match non-windows).
+ # https://docs.microsoft.com/en-us/cpp/c-runtime-library/math-constants
"/D_USE_MATH_DEFINES"
- "/wd4624"
- # 'inline': used more than once
- "/wd4141"
- # 'WIN32_LEAN_AND_MEAN': macro redefinition
- "/wd4005"
- "/wd4267"
- "/wd4141"
- "/wd4244"
- "/wd4146"
- "/wd4018"
- "/wd4065"
- # TODO(benvanik): figure out if really required or accidentally enabled.
+
+ # Disable the "deprecation" warnings about CRT functions like strcpy.
+ # Though the secure versions *are* better, they aren't portable and as such
+ # just make cross-platform code annoying. One solution is to reimplement
+ # them in a portable fashion and use those - and that's what we try to do
+ # in certain places where we can get away with it. Other uses, like getenv,
+ # are fine as these are not intended for use in core runtime code that needs
+ # to be secure (friends don't let friends ship entire compiler stacks
+ # embedded inside security sensitive applications anyway :).
+ # https://docs.microsoft.com/en-us/cpp/c-runtime-library/security-features-in-the-crt
+ "/D_CRT_SECURE_NO_WARNINGS"
+
+ # With the above said about the "deprecated" functions; this useful flag
+ # will at least try to use them when possible without any change to user
+ # code. Note however because the new versions use templates they won't be
+ # activated in C code; that's fine.
+ # https://docs.microsoft.com/en-us/cpp/c-runtime-library/secure-template-overloads
+ "/D_CRT_SECURE_CPP_OVERLOAD_STANDARD_NAMES"
+
+ # Configure exception handling for standard C++ behavior.
+ # - /EHs enables C++ catch-style exceptions
+ # - /EHc breaks unwinding across extern C boundaries, dramatically reducing
+ # unwind table size and associated exception handling overhead as the
+ # compiler can assume no exception will ever be thrown within any function
+ # annotated with extern "C".
+ # https://docs.microsoft.com/en-us/cpp/build/reference/eh-exception-handling-model
+ #
+ # TODO(benvanik): figure out if we need /EHs - we don't use exceptions in
+ # the runtime and I'm pretty sure LLVM doesn't use them either.
"/EHsc"
+
+ # Default max section count is 64k, which is woefully inadequate for some of
+ # the insanely bloated tablegen outputs LLVM/MLIR produces. This cranks it
+ # up to 2^32. It's not great that we have to generate/link files like that
+ # but it's better to not get spurious failures during LTCG.
+ # https://docs.microsoft.com/en-us/cpp/build/reference/bigobj-increase-number-of-sections-in-dot-obj-file
"/bigobj"
+
+ # "nonstandard extension used : zero-sized array in struct/union"
+ # This happens with unsized or zero-length arrays at the end of structs,
+ # which is completely valid in C where we do it and get this warning. Shut
+ # it up and rely on the better warnings from clang to catch if we try to
+ # use it where it really matters (on a class that has copy/move ctors, etc).
+ # https://docs.microsoft.com/en-us/cpp/error-messages/compiler-warnings/compiler-warning-levels-2-and-4-c4200
+ "/wd4200"
+
+ # "signed/unsigned mismatch in comparison"
+ # This is along the lines of a generic implicit conversion warning but tends
+ # to crop up in code that implicitly treats unsigned size_t values as if
+ # they were signed values instead of properly using ssize_t. In certain
+ # cases where the comparison being performed may be guarding access to
+ # memory this can cause unexpected behavior ("-1ull < 512ull, great let's
+ # dereference buffer[-1ull]!").
+ # https://docs.microsoft.com/en-us/cpp/error-messages/compiler-warnings/compiler-warning-level-3-c4018
+ #
+ # TODO(#3844): remove this (or make it per-file to iree/compiler, as LLVM
+ # tends to not care about these kinds of things and it crops up there a lot).
+ "/wd4018"
+
+ # Also common in LLVM is mismatching signed/unsigned math. That's even more
+ # dangerous than C4018: almost always these crop up in doing something with
+ # a size_t and a non-size_t value (usually int or something like it) and do
+ # you want out-of-bounds access exploits? Because that's how you get
+ # out-of-bounds access exploits. Before fuzzers took over, finding code and
+ # trying to compile it with this warning forced to be an error was a way to
+ # narrow down the places to look for attack vectors. I lived through the
+ # Microsoft SAL/safe-int code red, and once you get used to using the safe
+ # buffer offset/size manipulation functions it eliminates all kinds of
+ # annoying bugs - as well as potential security issues.
+ #
+ # TODO(#3844): work to remove this class of errors from our code. It's
+ # almost entirely in LLVM related stuff so per-file iree/compiler/... would
+ # be fine.
+ "/wd4146" # operator applied to unsigned type, result still unsigned
+ "/wd4244" # possible loss of data
+ "/wd4267" # initializing: possible loss of data
+
+ # Misc tweaks to better match reasonable clang/gcc behavior:
+ "/wd4005" # allow: macro redefinition
+ "/wd4065" # allow: switch statement contains 'default' but no 'case' labels
+ "/wd4141" # allow: inline used more than once
+ "/wd4624" # allow: destructor was implicitly defined as deleted
+
+ # TODO(benvanik): confirm these are all still required and document:
+ "/wd4146" # operator applied to unsigned type, result still unsigned
+ "/wd4244" # possible loss of data
+ "/wd4267" # initializing: possible loss of data
+ "/wd5105" # allow: macro expansion producing 'defined' has undefined behavior
)
-set(IREE_DEFAULT_LINKOPTS "${ABSL_DEFAULT_LINKOPTS}")
+
+if(NOT ANDROID)
+ iree_select_compiler_opts(_IREE_PTHREADS_LINKOPTS
+ CLANG_OR_GCC
+ "-lpthread"
+ )
+else()
+ # Android provides its own pthreads support with no linking required.
+endif()
+
+iree_select_compiler_opts(IREE_DEFAULT_LINKOPTS
+ ALL
+ # TODO(benvanik): remove the ABSL usage here; we aren't abseil.
+ "${ABSL_DEFAULT_LINKOPTS}"
+ CLANG_OR_GCC
+ # Required by all modern software, effectively:
+ "-ldl"
+ ${_IREE_PTHREADS_LINKOPTS}
+)
+
+# Add to LINKOPTS on a binary to configure it for X/Wayland/Windows/etc
+# depending on the target cross-compilation platform.
+if(${CMAKE_SYSTEM_NAME} STREQUAL "Windows")
+ set(IREE_TARGET_GUI_LINKOPTS "-SUBSYSTEM:WINDOWS")
+else()
+ set(IREE_TARGET_GUI_LINKOPTS "")
+endif()
+
+# TODO(benvanik): remove the ABSL usage here; we aren't abseil.
set(IREE_TEST_COPTS "${ABSL_TEST_COPTS}")
#-------------------------------------------------------------------------------
@@ -207,7 +376,7 @@
set(FLATBUFFERS_BUILD_GRPCTEST OFF CACHE BOOL "" FORCE)
set(FLATBUFFERS_INSTALL OFF CACHE BOOL "" FORCE)
set(FLATBUFFERS_INCLUDE_DIRS
- "${PROJECT_SOURCE_DIR}/third_party/flatbuffers/include/"
+ "${CMAKE_CURRENT_SOURCE_DIR}/third_party/flatbuffers/include/"
)
if(CMAKE_CROSSCOMPILING)
@@ -216,14 +385,6 @@
set(FLATBUFFERS_BUILD_FLATC ON CACHE BOOL "" FORCE)
endif()
-iree_select_compiler_opts(FLATBUFFERS_COPTS
- CLANG
- # Flatbuffers has a bunch of incorrect documentation annotations.
- "-Wno-documentation"
- "-Wno-documentation-unknown-command"
-)
-list(APPEND IREE_DEFAULT_COPTS ${FLATBUFFERS_COPTS})
-
#-------------------------------------------------------------------------------
# Third party: flatcc
#-------------------------------------------------------------------------------
@@ -267,35 +428,12 @@
set(LLVM_USE_LINKER ${IREE_USE_LINKER} CACHE STRING "" FORCE)
endif()
-# TODO: This should go in add_iree_mlir_src_dep at the top level.
-if(IREE_MLIR_DEP_MODE STREQUAL "BUNDLED")
- list(APPEND IREE_COMMON_INCLUDE_DIRS
- ${PROJECT_SOURCE_DIR}/third_party/llvm-project/llvm/include
- ${PROJECT_BINARY_DIR}/third_party/llvm-project/llvm/include
- ${PROJECT_SOURCE_DIR}/third_party/llvm-project/mlir/include
- ${PROJECT_BINARY_DIR}/third_party/llvm-project/llvm/tools/mlir/include
- )
-endif()
-
set(MLIR_TABLEGEN_EXE mlir-tblgen)
# iree-tblgen is not defined using the add_tablegen mechanism as other TableGen
# tools in LLVM.
iree_get_executable_path(IREE_TABLEGEN_EXE iree-tblgen)
#-------------------------------------------------------------------------------
-# Third party: tensorflow
-#-------------------------------------------------------------------------------
-
-list(APPEND IREE_COMMON_INCLUDE_DIRS
- ${PROJECT_SOURCE_DIR}/third_party/tensorflow
- ${PROJECT_SOURCE_DIR}/third_party/tensorflow/tensorflow/compiler/mlir/hlo/include/
- ${PROJECT_BINARY_DIR}/third_party/tensorflow
- ${PROJECT_BINARY_DIR}/third_party/tensorflow/tensorflow/compiler/mlir/hlo/include/
- ${PROJECT_BINARY_DIR}/third_party/tensorflow/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/
- ${PROJECT_BINARY_DIR}/third_party/tensorflow/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms
-)
-
-#-------------------------------------------------------------------------------
# Third party: mlir-emitc
#-------------------------------------------------------------------------------
@@ -304,10 +442,6 @@
set(EMITC_ENABLE_HLO OFF)
set(EMITC_INCLUDE_TESTS OFF)
- list(APPEND IREE_COMMON_INCLUDE_DIRS
- ${PROJECT_SOURCE_DIR}/third_party/mlir-emitc/include
- ${PROJECT_BINARY_DIR}/third_party/mlir-emitc/include
- )
add_definitions(-DIREE_HAVE_EMITC_DIALECT)
endif()
diff --git a/build_tools/cmake/iree_cross_compile.cmake b/build_tools/cmake/iree_cross_compile.cmake
index 3cfa1e9..f749f90 100644
--- a/build_tools/cmake/iree_cross_compile.cmake
+++ b/build_tools/cmake/iree_cross_compile.cmake
@@ -101,7 +101,8 @@
message(STATUS "C++ compiler for ${CONFIG_NAME} build: ${_CONFIG_CXX_COMPILER}")
add_custom_command(OUTPUT ${IREE_${CONFIG_NAME}_BINARY_ROOT}/CMakeCache.txt
- COMMAND "${CMAKE_COMMAND}" "${PROJECT_SOURCE_DIR}" -G "${CMAKE_GENERATOR}"
+ COMMAND "${CMAKE_COMMAND}" "${CMAKE_CURRENT_SOURCE_DIR}"
+ -G "${CMAKE_GENERATOR}"
-DCMAKE_MAKE_PROGRAM="${CMAKE_MAKE_PROGRAM}"
-DCMAKE_BUILD_TYPE="${CMAKE_BUILD_TYPE}"
-DCMAKE_C_COMPILER="${_CONFIG_C_COMPILER}"
@@ -150,7 +151,7 @@
if(NOT _RULE_CONFIG)
set(_RULE_CONFIG "$<CONFIG>")
endif()
- if (CMAKE_GENERATOR MATCHES "Make")
+ if(CMAKE_GENERATOR MATCHES "Make")
# Use special command for Makefiles to support parallelism.
set(${_RULE_CMDVAR}
"$(MAKE)" "-C" "${_RULE_BINDIR}" "${EXECUTABLE_TARGET}" PARENT_SCOPE)
diff --git a/build_tools/cmake/iree_multipy.cmake b/build_tools/cmake/iree_multipy.cmake
index a974e95..f773925 100644
--- a/build_tools/cmake/iree_multipy.cmake
+++ b/build_tools/cmake/iree_multipy.cmake
@@ -344,7 +344,7 @@
# Performs a depth-first search through the dependency graph, appending all
# dependencies of TARGET to the TRANSITIVE_DEPS list.
function(_iree_transitive_dependencies_helper TARGET TRANSITIVE_DEPS)
- if (NOT TARGET "${TARGET}")
+ if(NOT TARGET "${TARGET}")
# Excluded from the project, or invalid name? Just ignore.
return()
endif()
@@ -358,7 +358,7 @@
endif()
set(_RESULT "${${TRANSITIVE_DEPS}}")
- if (${_TARGET_NAME} IN_LIST _RESULT)
+ if(${_TARGET_NAME} IN_LIST _RESULT)
# Already visited, ignore.
return()
endif()
@@ -368,7 +368,7 @@
list(APPEND _RESULT ${_TARGET_NAME})
# Check for non-target identifiers again after resolving the alias.
- if (NOT TARGET ${_TARGET_NAME})
+ if(NOT TARGET ${_TARGET_NAME})
return()
endif()
@@ -408,7 +408,7 @@
foreach(_DEP ${_TRANSITIVE_DEPS})
# Check if _DEP is a library with the ALWAYSLINK property set.
set(_DEP_IS_ALWAYSLINK OFF)
- if (TARGET ${_DEP})
+ if(TARGET ${_DEP})
get_target_property(_DEP_TYPE ${_DEP} TYPE)
if(${_DEP_TYPE} STREQUAL "INTERFACE_LIBRARY")
# Can't be ALWAYSLINK since it's an INTERFACE library.
@@ -428,14 +428,14 @@
# For macOS, also add a `-Wl,-force_load` version of the dep.
if(MSVC)
get_target_property(_ALIASED_TARGET ${_DEP} ALIASED_TARGET)
- if (_ALIASED_TARGET)
+ if(_ALIASED_TARGET)
list(APPEND _ALWAYS_LINK_DEPS "-WHOLEARCHIVE:${_ALIASED_TARGET}")
else()
list(APPEND _ALWAYS_LINK_DEPS "-WHOLEARCHIVE:${_DEP}")
endif()
elseif(APPLE)
get_target_property(_ALIASED_TARGET ${_DEP} ALIASED_TARGET)
- if (_ALIASED_TARGET)
+ if(_ALIASED_TARGET)
list(APPEND _ALWAYS_LINK_DEPS "-Wl,-force_load $<TARGET_FILE:${_ALIASED_TARGET}>")
else()
list(APPEND _ALWAYS_LINK_DEPS "-Wl,-force_load $<TARGET_FILE:${_DEP}>")
diff --git a/build_tools/cmake/iree_setup_toolchain.cmake b/build_tools/cmake/iree_setup_toolchain.cmake
index 48e2463..c3789df 100644
--- a/build_tools/cmake/iree_setup_toolchain.cmake
+++ b/build_tools/cmake/iree_setup_toolchain.cmake
@@ -19,7 +19,7 @@
endfunction()
if(IREE_ENABLE_LLD)
- if (IREE_USE_LINKER)
+ if(IREE_USE_LINKER)
message(FATAL_ERROR "IREE_ENABLE_LLD and IREE_USE_LINKER can't be set at the same time")
endif()
set(IREE_USE_LINKER "lld")
diff --git a/build_tools/cmake/iree_tablegen_doc.cmake b/build_tools/cmake/iree_tablegen_doc.cmake
index cc62fe4..5e93390 100644
--- a/build_tools/cmake/iree_tablegen_doc.cmake
+++ b/build_tools/cmake/iree_tablegen_doc.cmake
@@ -53,7 +53,10 @@
endif()
- set(_INCLUDE_DIRS ${IREE_COMMON_INCLUDE_DIRS})
+ set(_INCLUDE_DIRS
+ "${MLIR_INCLUDE_DIRS}"
+ "${IREE_SOURCE_DIR}"
+ )
list(APPEND _INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR})
list(TRANSFORM _INCLUDE_DIRS PREPEND "-I")
@@ -73,7 +76,7 @@
endwhile()
# Put all dialect docs at one place.
- set(_DOC_DIR ${PROJECT_BINARY_DIR}/doc/Dialects/)
+ set(_DOC_DIR ${CMAKE_CURRENT_BINARY_DIR}/doc/Dialects/)
# Set a target to drive copy.
add_custom_target(${_NAME}_target
${CMAKE_COMMAND} -E make_directory ${_DOC_DIR}
diff --git a/build_tools/cmake/iree_tablegen_library.cmake b/build_tools/cmake/iree_tablegen_library.cmake
index e185e59..6226b16 100644
--- a/build_tools/cmake/iree_tablegen_library.cmake
+++ b/build_tools/cmake/iree_tablegen_library.cmake
@@ -41,7 +41,10 @@
endif()
set(LLVM_TARGET_DEFINITIONS ${_RULE_TD_FILE})
- set(_INCLUDE_DIRS ${IREE_COMMON_INCLUDE_DIRS})
+ set(_INCLUDE_DIRS
+ "${MLIR_INCLUDE_DIRS}"
+ "${IREE_SOURCE_DIR}"
+ )
list(APPEND _INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR})
list(TRANSFORM _INCLUDE_DIRS PREPEND "-I")
set(_OUTPUTS)
diff --git a/build_tools/cmake/iree_whole_archive_link.cmake b/build_tools/cmake/iree_whole_archive_link.cmake
deleted file mode 100644
index e34f900..0000000
--- a/build_tools/cmake/iree_whole_archive_link.cmake
+++ /dev/null
@@ -1,148 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Lists all transitive dependencies of DIRECT_DEPS in TRANSITIVE_DEPS.
-function(_iree_transitive_dependencies DIRECT_DEPS TRANSITIVE_DEPS)
- set(_TRANSITIVE "")
-
- foreach(_DEP ${DIRECT_DEPS})
- _iree_transitive_dependencies_helper(${_DEP} _TRANSITIVE)
- endforeach(_DEP)
-
- set(${TRANSITIVE_DEPS} "${_TRANSITIVE}" PARENT_SCOPE)
-endfunction()
-
-# Recursive helper function for _iree_transitive_dependencies.
-# Performs a depth-first search through the dependency graph, appending all
-# dependencies of TARGET to the TRANSITIVE_DEPS list.
-function(_iree_transitive_dependencies_helper TARGET TRANSITIVE_DEPS)
- if (NOT TARGET "${TARGET}")
- # Excluded from the project, or invalid name? Just ignore.
- return()
- endif()
-
- # Resolve aliases, canonicalize name formatting.
- get_target_property(_ALIASED_TARGET ${TARGET} ALIASED_TARGET)
- if(_ALIASED_TARGET)
- set(_TARGET_NAME ${_ALIASED_TARGET})
- else()
- string(REPLACE "::" "_" _TARGET_NAME ${TARGET})
- endif()
-
- set(_RESULT "${${TRANSITIVE_DEPS}}")
- if (${_TARGET_NAME} IN_LIST _RESULT)
- # Already visited, ignore.
- return()
- endif()
-
- # Append this target to the list. Dependencies of this target will be added
- # (if valid and not already visited) in recursive function calls.
- list(APPEND _RESULT ${_TARGET_NAME})
-
- # Check for non-target identifiers again after resolving the alias.
- if (NOT TARGET ${_TARGET_NAME})
- return()
- endif()
-
- # Get the list of direct dependencies for this target.
- get_target_property(_TARGET_TYPE ${_TARGET_NAME} TYPE)
- if(NOT ${_TARGET_TYPE} STREQUAL "INTERFACE_LIBRARY")
- get_target_property(_TARGET_DEPS ${_TARGET_NAME} LINK_LIBRARIES)
- else()
- get_target_property(_TARGET_DEPS ${_TARGET_NAME} INTERFACE_LINK_LIBRARIES)
- endif()
-
- if(_TARGET_DEPS)
- # Recurse on each dependency.
- foreach(_TARGET_DEP ${_TARGET_DEPS})
- _iree_transitive_dependencies_helper(${_TARGET_DEP} _RESULT)
- endforeach(_TARGET_DEP)
- endif()
-
- # Propagate the augmented list up to the parent scope.
- set(${TRANSITIVE_DEPS} "${_RESULT}" PARENT_SCOPE)
-endfunction()
-
-# Given the ${TARGET} and the libraries it directly depends on in ${ARGN},
-# properly establish the linking relationship by considering ALWAYSLINK
-# in a recursive manner.
-#
-# All symbols from ALWAYSLINK libraries will be included in ${TARGET},
-# regardless of whether they are directly referenced or not.
-function(iree_whole_archive_link TARGET)
- # List all dependencies, including transitive dependencies, then split the
- # dependency list into one for whole archive (ALWAYSLINK) and one for
- # standard linking (which only links in symbols that are directly used).
- _iree_transitive_dependencies("${ARGN}" _TRANSITIVE_DEPS)
- set(_ALWAYS_LINK_DEPS "")
- set(_STANDARD_DEPS "")
- foreach(_DEP ${_TRANSITIVE_DEPS})
- # Check if _DEP is a library with the ALWAYSLINK property set.
- set(_DEP_IS_ALWAYSLINK OFF)
- if (TARGET ${_DEP})
- get_target_property(_DEP_TYPE ${_DEP} TYPE)
- if(${_DEP_TYPE} STREQUAL "INTERFACE_LIBRARY")
- # Can't be ALWAYSLINK since it's an INTERFACE library.
- # We also can't even query for the property, since it isn't allowlisted.
- else()
- get_target_property(_DEP_IS_ALWAYSLINK ${_DEP} ALWAYSLINK)
- endif()
- endif()
-
- # Append to the corresponding list of deps.
- if(_DEP_IS_ALWAYSLINK)
- list(APPEND _ALWAYS_LINK_DEPS ${_DEP})
-
- # For MSVC, also add a `-WHOLEARCHIVE:` version of the dep.
- # CMake treats -WHOLEARCHIVE[:lib] as a link flag and will not actually
- # try to link the library in, so we need the flag *and* the dependency.
- # For macOS, also add a `-Wl,-force_load` version of the dep.
- if(MSVC)
- get_target_property(_ALIASED_TARGET ${_DEP} ALIASED_TARGET)
- if (_ALIASED_TARGET)
- list(APPEND _ALWAYS_LINK_DEPS "-WHOLEARCHIVE:${_ALIASED_TARGET}")
- else()
- list(APPEND _ALWAYS_LINK_DEPS "-WHOLEARCHIVE:${_DEP}")
- endif()
- elseif(APPLE)
- get_target_property(_ALIASED_TARGET ${_DEP} ALIASED_TARGET)
- if (_ALIASED_TARGET)
- list(APPEND _ALWAYS_LINK_DEPS "-Wl,-force_load $<TARGET_FILE:${_ALIASED_TARGET}>")
- else()
- list(APPEND _ALWAYS_LINK_DEPS "-Wl,-force_load $<TARGET_FILE:${_DEP}>")
- endif()
- endif()
- else()
- list(APPEND _STANDARD_DEPS ${_DEP})
- endif()
- endforeach(_DEP)
-
- # Call into target_link_libraries with the lists of deps.
- if(MSVC OR APPLE)
- target_link_libraries(${TARGET}
- PUBLIC
- ${_ALWAYS_LINK_DEPS}
- ${_STANDARD_DEPS}
- )
- else()
- target_link_libraries(${TARGET}
- PUBLIC
- "-Wl,--whole-archive"
- ${_ALWAYS_LINK_DEPS}
- "-Wl,--no-whole-archive"
- ${_STANDARD_DEPS}
- )
- endif()
-endfunction()
-
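Review note: the deleted helper walked the link graph depth-first, skipping names that are not targets and names already visited. A minimal Python sketch of that traversal follows, for readers who want to check the visiting order without running CMake. The `link_libraries` dictionary is a hypothetical stand-in for the `LINK_LIBRARIES` / `INTERFACE_LINK_LIBRARIES` target properties, and alias resolution is left out.

```python
from typing import Dict, List, Sequence


def transitive_dependencies(
    direct_deps: Sequence[str],
    link_libraries: Dict[str, List[str]]) -> List[str]:
  """Lists direct_deps and everything reachable from them, in DFS order."""
  result: List[str] = []

  def visit(target: str) -> None:
    if target not in link_libraries:
      # Not a known target (e.g. a raw linker flag): ignore it.
      return
    if target in result:
      # Already visited, ignore.
      return
    result.append(target)
    for dep in link_libraries[target]:
      visit(dep)

  for dep in direct_deps:
    visit(dep)
  return result
```

Calling `transitive_dependencies(["a"], {"a": ["b", "-lm"], "b": []})` visits `a`, then `b`, and drops `-lm` because it is not a key in the graph.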
diff --git a/build_tools/third_party/flatcc/BUILD.overlay b/build_tools/third_party/flatcc/BUILD.overlay
index 3481995..729d57c 100644
--- a/build_tools/third_party/flatcc/BUILD.overlay
+++ b/build_tools/third_party/flatcc/BUILD.overlay
@@ -14,13 +14,14 @@
package(default_visibility = ["//visibility:public"])
-# NOTE: we exclude JSON parsing/printing to avoid additional dependencies.
cc_library(
name = "runtime",
srcs = [
"config/config.h",
"src/runtime/builder.c",
"src/runtime/emitter.c",
+ "src/runtime/json_parser.c",
+ "src/runtime/json_printer.c",
"src/runtime/refmap.c",
"src/runtime/verifier.c",
],
@@ -45,10 +46,19 @@
"include/flatcc/reflection/flatbuffers_common_builder.h",
"include/flatcc/reflection/flatbuffers_common_reader.h",
] + glob(["include/flatcc/portable/**/*.h"]),
- copts = [
- "-Iexternal/com_github_dvidelabs_flatcc/config/",
- "-Iexternal/com_github_dvidelabs_flatcc/include/",
+ textual_hdrs = [
+ "include/flatcc/flatcc_json_parser.h",
+ "include/flatcc/flatcc_json_printer.h",
],
+ copts = [
+ "-Iexternal/com_github_dvidelabs_flatcc/config/",
+ "-Iexternal/com_github_dvidelabs_flatcc/include/",
+ ] + select({
+ "@bazel_tools//src/conditions:windows": [],
+ "//conditions:default": [
+ "-Wno-implicit-fallthrough",
+ ],
+ }),
includes = [
"include/",
],
diff --git a/build_tools/third_party/flatcc/CMakeLists.txt b/build_tools/third_party/flatcc/CMakeLists.txt
index 82b1cb2..f1ad43e 100644
--- a/build_tools/third_party/flatcc/CMakeLists.txt
+++ b/build_tools/third_party/flatcc/CMakeLists.txt
@@ -27,6 +27,8 @@
SRCS
"src/runtime/builder.c"
"src/runtime/emitter.c"
+ "src/runtime/json_parser.c"
+ "src/runtime/json_printer.c"
"src/runtime/refmap.c"
"src/runtime/verifier.c"
HDRS
@@ -40,6 +42,8 @@
"include/flatcc/flatcc_flatbuffers.h"
"include/flatcc/flatcc_identifier.h"
"include/flatcc/flatcc_iov.h"
+ "include/flatcc/flatcc_json_parser.h"
+ "include/flatcc/flatcc_json_printer.h"
"include/flatcc/flatcc_portable.h"
"include/flatcc/flatcc_prologue.h"
"include/flatcc/flatcc_refmap.h"
diff --git a/build_tools/third_party/tensorflow/tensorflow/compiler/mlir/hlo/CMakeLists.txt b/build_tools/third_party/tensorflow/tensorflow/compiler/mlir/hlo/CMakeLists.txt
index 4b9df5e..4d270b9 100644
--- a/build_tools/third_party/tensorflow/tensorflow/compiler/mlir/hlo/CMakeLists.txt
+++ b/build_tools/third_party/tensorflow/tensorflow/compiler/mlir/hlo/CMakeLists.txt
@@ -12,8 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-set(TF_MLIR_HLO_SRC_ROOT
- "${IREE_ROOT_DIR}/third_party/tensorflow/tensorflow/compiler/mlir/hlo/"
+set(TF_MLIR_HLO_SOURCE_DIR
+ "${IREE_SOURCE_DIR}/third_party/tensorflow/tensorflow/compiler/mlir/hlo/"
+)
+set(TF_MLIR_HLO_BINARY_DIR
+ "${IREE_BINARY_DIR}/third_party/tensorflow/tensorflow/compiler/mlir/hlo/"
)
external_cc_library(
@@ -22,7 +25,7 @@
NAME
mlir_hlo
ROOT
- ${TF_MLIR_HLO_SRC_ROOT}
+ ${TF_MLIR_HLO_SOURCE_DIR}
DEPS
MhloDialect
MhloInferFusibilityOpInterface
@@ -32,5 +35,10 @@
MhloPasses
MhloLhloToLinalg
MLIRMhloUtils
+ INCLUDES
+ "${TF_MLIR_HLO_SOURCE_DIR}/"
+ "${TF_MLIR_HLO_SOURCE_DIR}/include/"
+ "${TF_MLIR_HLO_BINARY_DIR}/"
+ "${TF_MLIR_HLO_BINARY_DIR}/include/"
PUBLIC
)
diff --git a/build_tools/third_party/vulkan_headers/BUILD.overlay b/build_tools/third_party/vulkan_headers/BUILD.overlay
index 3667e68..06aa1ea 100644
--- a/build_tools/third_party/vulkan_headers/BUILD.overlay
+++ b/build_tools/third_party/vulkan_headers/BUILD.overlay
@@ -14,18 +14,6 @@
package(default_visibility = ["//visibility:public"])
-# Exports all headers but defining VK_NO_PROTOTYPES to disable the
-# inclusion of C function prototypes. Useful if dynamically loading
-# all symbols via dlopen/etc.
-# Not all headers are hermetic, so they are just included as textual
-# headers to disable additional validation.
-cc_library(
- name = "vulkan_headers_no_prototypes",
- defines = ["VK_NO_PROTOTYPES"],
- includes = ["include"],
- textual_hdrs = glob(["include/vulkan/*.h"]),
-)
-
# Exports all headers, including C function prototypes. Useful if statically
# linking against the Vulkan SDK.
# Not all headers are hermetic, so they are just included as textual
diff --git a/build_tools/third_party/vulkan_memory_allocator/CMakeLists.txt b/build_tools/third_party/vulkan_memory_allocator/CMakeLists.txt
new file mode 100644
index 0000000..5e8a120
--- /dev/null
+++ b/build_tools/third_party/vulkan_memory_allocator/CMakeLists.txt
@@ -0,0 +1,28 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set(VMA_ROOT "${IREE_ROOT_DIR}/third_party/vulkan_memory_allocator/")
+
+external_cc_library(
+ PACKAGE
+ vulkan_memory_allocator
+ NAME
+ vulkan_memory_allocator
+ ROOT
+ ${VMA_ROOT}
+ HDRS
+ "src/vk_mem_alloc.h"
+ INCLUDES
+ "${VMA_ROOT}/src/"
+)
diff --git a/experimental/ModelBuilder/test/BUILD b/experimental/ModelBuilder/test/BUILD
index f378bea..00cccb2 100644
--- a/experimental/ModelBuilder/test/BUILD
+++ b/experimental/ModelBuilder/test/BUILD
@@ -12,10 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-# Tests for end-to-end IREE support starting from the XLA HLO dialect.
+# Tests for ModelBuilder.
load("//iree:lit_test.bzl", "iree_lit_test_suite")
-load("//iree:build_defs.oss.bzl", "IREE_DRIVER_MODULES", "PLATFORM_VULKAN_DEPS")
package(
default_visibility = ["//visibility:public"],
@@ -109,13 +108,12 @@
deps = [
"//experimental/ModelBuilder",
"//experimental/ModelBuilder:ModelRunner",
- "//iree/base:initializer",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:SPIRVDialect",
"@llvm-project//mlir:mlir_runner_utils",
- ] + PLATFORM_VULKAN_DEPS + IREE_DRIVER_MODULES,
+ ],
)
cc_binary(
@@ -125,7 +123,6 @@
"//experimental/ModelBuilder",
"//experimental/ModelBuilder:ModelRunner",
"//experimental/ModelBuilder:VulkanLaunchWrapper",
- "//iree/base:initializer",
"//iree/compiler/Conversion/CodegenUtils",
"//iree/compiler/Conversion/LinalgToSPIRV",
"@llvm-project//llvm:Support",
@@ -149,7 +146,7 @@
"@llvm-project//mlir:VectorToLLVM",
# mlir_runner_utils with iostream needed for printMemRef atm
"@llvm-project//mlir:mlir_runner_utils",
- ] + PLATFORM_VULKAN_DEPS + IREE_DRIVER_MODULES,
+ ],
)
cc_binary(
@@ -162,7 +159,6 @@
"//experimental/ModelBuilder",
"//experimental/ModelBuilder:ModelRunner",
"//experimental/ModelBuilder:VulkanLaunchWrapper",
- "//iree/base:initializer",
"//iree/compiler/Conversion/CodegenUtils",
"//iree/compiler/Conversion/LinalgToSPIRV",
"@llvm-project//llvm:Support",
@@ -182,7 +178,7 @@
"@llvm-project//mlir:mlir_c_runner_utils",
# mlir_runner_utils with iostream needed for printMemRef atm
"@llvm-project//mlir:mlir_runner_utils",
- ] + PLATFORM_VULKAN_DEPS + IREE_DRIVER_MODULES,
+ ],
)
cc_binary(
@@ -195,7 +191,6 @@
"//experimental/ModelBuilder",
"//experimental/ModelBuilder:ModelRunner",
"//experimental/ModelBuilder:VulkanLaunchWrapper",
- "//iree/base:initializer",
"//iree/compiler/Conversion/CodegenUtils",
"//iree/compiler/Conversion/LinalgToSPIRV",
"@llvm-project//llvm:Support",
@@ -215,7 +210,7 @@
"@llvm-project//mlir:mlir_c_runner_utils",
# mlir_runner_utils with iostream needed for printMemRef atm
"@llvm-project//mlir:mlir_runner_utils",
- ] + PLATFORM_VULKAN_DEPS + IREE_DRIVER_MODULES,
+ ],
)
cc_binary(
diff --git a/experimental/ModelBuilder/test/BenchMatMulVectorGPU.cpp b/experimental/ModelBuilder/test/BenchMatMulVectorGPU.cpp
index 68e7910..18493fa 100644
--- a/experimental/ModelBuilder/test/BenchMatMulVectorGPU.cpp
+++ b/experimental/ModelBuilder/test/BenchMatMulVectorGPU.cpp
@@ -16,7 +16,6 @@
#include "experimental/ModelBuilder/ModelBuilder.h"
#include "experimental/ModelBuilder/ModelRunner.h"
#include "experimental/ModelBuilder/VulkanWrapperPass.h"
-#include "iree/base/initializer.h"
#include "iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.h"
#include "iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/MemorySpace.h"
@@ -541,7 +540,6 @@
}
int main(int argc, char **argv) {
- iree::Initializer::RunInitializers();
// Allow LLVM setup through command line and parse the
// test specific option for a runtime support library.
llvm::InitLLVM y(argc, argv);
diff --git a/experimental/ModelBuilder/test/TestMatMulVulkan.cpp b/experimental/ModelBuilder/test/TestMatMulVulkan.cpp
index a8590ac..c05c296 100644
--- a/experimental/ModelBuilder/test/TestMatMulVulkan.cpp
+++ b/experimental/ModelBuilder/test/TestMatMulVulkan.cpp
@@ -34,7 +34,6 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Parser.h"
-#include "iree/base/initializer.h"
#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
#include "mlir/Pass/PassManager.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
@@ -164,7 +163,6 @@
}
int main(int argc, char **argv) {
- iree::Initializer::RunInitializers();
// Allow LLVM setup through command line and parse the
// test specific option for a runtime support library.
llvm::InitLLVM y(argc, argv);
diff --git a/experimental/ModelBuilder/test/TestSimpleJITVulkan.cpp b/experimental/ModelBuilder/test/TestSimpleJITVulkan.cpp
index 0e3d388..be71ae5 100644
--- a/experimental/ModelBuilder/test/TestSimpleJITVulkan.cpp
+++ b/experimental/ModelBuilder/test/TestSimpleJITVulkan.cpp
@@ -33,7 +33,6 @@
#include "mlir/IR/OperationSupport.h"
#include "mlir/Parser.h"
#include "mlir/ExecutionEngine/RunnerUtils.h"
-#include "iree/base/initializer.h"
static llvm::cl::opt<std::string> vulkanWrapper(
"vulkan-wrapper", llvm::cl::desc("Vulkan wrapper library"),
@@ -118,7 +117,6 @@
}
int main(int argc, char **argv) {
- iree::Initializer::RunInitializers();
// Allow LLVM setup through command line and parse the
// test specific option for a runtime support library.
llvm::InitLLVM y(argc, argv);
diff --git a/experimental/ModelBuilder/test/TestVectorToGPU.cpp b/experimental/ModelBuilder/test/TestVectorToGPU.cpp
index a269ca9..e8d778e 100644
--- a/experimental/ModelBuilder/test/TestVectorToGPU.cpp
+++ b/experimental/ModelBuilder/test/TestVectorToGPU.cpp
@@ -35,7 +35,6 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Parser.h"
-#include "iree/base/initializer.h"
#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
#include "mlir/Pass/PassManager.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
@@ -149,7 +148,6 @@
}
int main(int argc, char **argv) {
- iree::Initializer::RunInitializers();
// Allow LLVM setup through command line and parse the
// test specific option for a runtime support library.
llvm::InitLLVM y(argc, argv);
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/module_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/module_utils.py
index 4ec9d66..fb55501 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/module_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/module_utils.py
@@ -19,6 +19,7 @@
import tempfile
from typing import Any, Callable, Dict, Sequence, Set, Tuple, Type, Union
+from absl import flags
from absl import logging
import numpy as np
from pyiree import rt
@@ -26,6 +27,18 @@
from pyiree.tf.support import tf_utils
import tensorflow.compat.v2 as tf
+flags.DEFINE_bool(
+ "capture_crash_reproducer", True,
+ "Captures MLIR crash reproducers in the artifacts directory for crashes "
+ "and suppresses C++ stack traces.")
+
+FLAGS = flags.FLAGS
+
+
+def _running_bazel_test() -> bool:
+ # Bazel guarantees that TEST_TMPDIR is set when `bazel test` is running.
+ return "TEST_TMPDIR" in os.environ
+
def _setup_mlir_crash_reproducer(
function: Any, # pytype doesn't support arbitrary Callable[*args, **kwargs]
@@ -149,8 +162,13 @@
return _incrementally_lower_compiler_module(compiler_module, backend_info,
artifacts_dir)
- _compile_module = _setup_mlir_crash_reproducer(_compile_module, artifacts_dir,
- backend_info.backend_id)
+ # Avoid the crash reproducer under tests or if the flag is false.
+ # Developers can run tests outside of the test runner (e.g. `bazel run`) to
+ # use the crash reproducer.
+ if (FLAGS.capture_crash_reproducer and not _running_bazel_test()):
+ _compile_module = _setup_mlir_crash_reproducer(_compile_module,
+ artifacts_dir,
+ backend_info.backend_id)
return _compile_module(module, backend_info, exported_names, artifacts_dir)
@@ -188,8 +206,13 @@
return _incrementally_lower_compiler_module(compiler_module, backend_info,
artifacts_dir)
- _compile_module = _setup_mlir_crash_reproducer(_compile_module, artifacts_dir,
- backend_info.backend_id)
+ # Avoid the crash reproducer under tests or if the flag is false.
+ # Developers can run tests outside of the test runner (e.g. `bazel run`) to
+ # use the crash reproducer.
+ if (FLAGS.capture_crash_reproducer and not _running_bazel_test()):
+ _compile_module = _setup_mlir_crash_reproducer(_compile_module,
+ artifacts_dir,
+ backend_info.backend_id)
return _compile_module(saved_model_dir, saved_model_tags, backend_info,
exported_name, artifacts_dir)
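Review note: a minimal, self-contained sketch of the gating logic these hunks add, runnable without absl or pyiree. Only the `TEST_TMPDIR` check is taken from the diff. `maybe_wrap`, `with_reproducer`, and the `capture_crash_reproducer` argument are hypothetical stand-ins for the `FLAGS.capture_crash_reproducer` check and `_setup_mlir_crash_reproducer`.

```python
import os
from typing import Any, Callable, Mapping


def running_bazel_test(environ: Mapping[str, str] = os.environ) -> bool:
  # Mirrors _running_bazel_test in the diff: `bazel test` sets TEST_TMPDIR.
  return "TEST_TMPDIR" in environ


def with_reproducer(function: Callable[..., Any]) -> Callable[..., Any]:
  # Hypothetical stand-in for _setup_mlir_crash_reproducer. It only tags the
  # result so that the wrapping is observable.
  def decorator(*args, **kwargs):
    return ("wrapped", function(*args, **kwargs))

  return decorator


def maybe_wrap(function: Callable[..., Any],
               capture_crash_reproducer: bool,
               environ: Mapping[str, str] = os.environ) -> Callable[..., Any]:
  # Wrap only when the flag is set and we are not under `bazel test`.
  if capture_crash_reproducer and not running_bazel_test(environ):
    return with_reproducer(function)
  return function


def compile_module(name: str) -> str:
  # Placeholder for the real compile step.
  return f"compiled {name}"
```

Passing the environment explicitly keeps the check testable. With an empty mapping and the flag set the function is wrapped; with `TEST_TMPDIR` present or the flag unset, the original function is returned unchanged.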
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
index e2a67b9..313be88 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -50,9 +50,6 @@
"Summarize the inputs and outputs of each module trace logged to disk.")
flags.DEFINE_bool("log_all_traces", False,
"Log all traces to logging.info, even if comparison passes.")
-flags.DEFINE_bool(
- "get_saved_model", False,
- "Creates and stores a SavedModel for the tf.Module class to be tested.")
FLAGS = flags.FLAGS
DEFAULT_INPUT_GENERATOR = tf_utils.uniform
@@ -432,6 +429,7 @@
atol: float = None,
rtol: float = None,
name: str = None,
+ static_signature: Sequence[tf.TensorSpec] = None,
**tf_function_kwargs):
"""Creates a tf.function that can be used to generate unit_tests.
@@ -453,6 +451,10 @@
name:
optional, the name to reference this function with. Must be used if
decorating a lambda.
+ static_signature:
+ optional, a signature with the same structure as 'input_signature'. Used
+ to specify the correct shape for data generation when dynamic dims are
+ provided.
Raises:
ValueError: if 'input_generator' and 'input_args' are both specified.
@@ -481,17 +483,22 @@
global _global_unit_test_configs
if function.__name__ not in _global_unit_test_configs:
+ if static_signature is not None:
+ signature = static_signature
+ else:
+ signature = function.input_signature
+
if input_generator is not None:
# Use the user-specified input_generator.
get_trace_args = lambda: tf_utils.generate_inputs(
- function.input_signature, input_generator)
+ signature, input_generator)
elif input_args is not None:
# Use the user-specified input_args.
get_trace_args = lambda: copy.deepcopy(input_args)
else:
# No user data specification – default to using random uniform data.
get_trace_args = lambda: tf_utils.generate_inputs(
- function.input_signature, DEFAULT_INPUT_GENERATOR)
+ signature, DEFAULT_INPUT_GENERATOR)
_global_unit_test_configs[function.__name__] = dict(
get_trace_args=get_trace_args,
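Review note: a small sketch of why `static_signature` is useful, written without TensorFlow. Shapes are plain tuples in which `None` marks a dynamic dimension. `pick_signature` copies the precedence added in this hunk. `generate_inputs` is a hypothetical stand-in for `tf_utils.generate_inputs`, and its behavior on dynamic shapes (raising `ValueError`) is an assumption made for illustration, not a statement about the real helper.

```python
import random
from typing import List, Optional, Sequence, Tuple

Shape = Tuple[Optional[int], ...]  # None marks a dynamic dimension.


def pick_signature(
    input_signature: Sequence[Shape],
    static_signature: Optional[Sequence[Shape]] = None) -> Sequence[Shape]:
  # Same precedence as the diff: prefer the static signature when given.
  if static_signature is not None:
    return static_signature
  return input_signature


def generate_inputs(signature: Sequence[Shape]) -> List[List[float]]:
  # Hypothetical generator: one flat list of uniform random values per shape.
  inputs = []
  for shape in signature:
    if any(dim is None for dim in shape):
      raise ValueError(f"cannot generate data for dynamic shape {shape}")
    size = 1
    for dim in shape:
      size *= dim
    inputs.append([random.uniform(0.0, 1.0) for _ in range(size)])
  return inputs
```

With `input_signature = [(None, 4)]` alone the sketch cannot size the data, while adding `static_signature = [(2, 4)]` yields one input of 8 values.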
diff --git a/integrations/tensorflow/compiler/CheckNoTF.cpp b/integrations/tensorflow/compiler/CheckNoTF.cpp
index 246f28a..b69a1ac 100644
--- a/integrations/tensorflow/compiler/CheckNoTF.cpp
+++ b/integrations/tensorflow/compiler/CheckNoTF.cpp
@@ -13,13 +13,13 @@
// limitations under the License.
#include "llvm/Support/FormatVariadic.h"
+#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
diff --git a/integrations/tensorflow/compiler/LegalizeTF.cpp b/integrations/tensorflow/compiler/LegalizeTF.cpp
index e453436..edff0d2 100644
--- a/integrations/tensorflow/compiler/LegalizeTF.cpp
+++ b/integrations/tensorflow/compiler/LegalizeTF.cpp
@@ -12,14 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
diff --git a/integrations/tensorflow/e2e/BUILD b/integrations/tensorflow/e2e/BUILD
index d74b597..7fa4ff7 100644
--- a/integrations/tensorflow/e2e/BUILD
+++ b/integrations/tensorflow/e2e/BUILD
@@ -13,8 +13,8 @@
# limitations under the License.
# Test coverage across backends for e2e tests is defined directly in the BUILD
-# files. A coverage table generated from this file can be viewed here:
-# https://google.github.io/iree/tf-e2e-coverage
+# files. Coverage tables generated from this file can be viewed here:
+# https://google.github.io/iree/tensorflow-coverage
# Updates made to test suite names should also be reflected here:
# https://github.com/google/iree/blob/main/scripts/update_e2e_coverage.py
diff --git a/integrations/tensorflow/e2e/keras/BUILD b/integrations/tensorflow/e2e/keras/BUILD
index 468eb79..6ecd6b5 100644
--- a/integrations/tensorflow/e2e/keras/BUILD
+++ b/integrations/tensorflow/e2e/keras/BUILD
@@ -13,8 +13,8 @@
# limitations under the License.
# Test coverage across backends for e2e tests is defined directly in the BUILD
-# files. A coverage table generated from this file can be viewed here:
-# https://google.github.io/iree/tf-e2e-coverage
+# files. Coverage tables generated from this file can be viewed here:
+# https://google.github.io/iree/tensorflow-coverage
# Updates made to test suite names should also be reflected here:
# https://github.com/google/iree/blob/main/scripts/update_e2e_coverage.py
diff --git a/integrations/tensorflow/e2e/keras/layers/BUILD b/integrations/tensorflow/e2e/keras/layers/BUILD
index 75dbeda..759741e 100644
--- a/integrations/tensorflow/e2e/keras/layers/BUILD
+++ b/integrations/tensorflow/e2e/keras/layers/BUILD
@@ -13,8 +13,8 @@
# limitations under the License.
# Test coverage across backends for e2e tests is defined directly in the BUILD
-# files. A coverage table generated from this file can be viewed here:
-# https://google.github.io/iree/tf-e2e-coverage
+# files. Coverage tables generated from this file can be viewed here:
+# https://google.github.io/iree/tensorflow-coverage/tf-keras-coverage
# Updates made to test suite names should also be reflected here:
# https://github.com/google/iree/blob/main/scripts/update_e2e_coverage.py
@@ -87,7 +87,6 @@
"Embedding",
"Flatten",
"GRU",
- "GRUCell",
"GaussianDropout",
"GaussianNoise",
"GlobalAveragePooling1D",
@@ -98,7 +97,6 @@
"GlobalMaxPool3D",
"InputLayer",
"LSTM",
- "LSTMCell",
"Lambda",
"LayerNormalization",
"LeakyReLU",
@@ -120,7 +118,6 @@
"SeparableConv1D",
"SeparableConv2D",
# "SimpleRNN", # TODO(meadowlark): Debug flakiness.
- "SimpleRNNCell",
"Softmax",
"SpatialDropout1D",
"SpatialDropout2D",
@@ -137,14 +134,6 @@
FAILING_STATIC = [
{
- # Wrapping these in a tf.function appears to cause a keras bug.
- "layer": [
- "GRUCell",
- "LSTMCell",
- "SimpleRNNCell",
- ],
- },
- {
# Failing on TFLite
"layer": [
"AveragePooling3D",
@@ -152,6 +141,7 @@
"Conv3D",
"ConvLSTM2D",
"LayerNormalization",
+ "Softmax",
"MaxPool3D",
"ZeroPadding3D",
],
@@ -161,10 +151,11 @@
# Failing on IREE
"layer": [
"ConvLSTM2D",
+ "GRU",
+ "LSTM", # Failing unless 'return_sequences = True'
"LayerNormalization",
"LeakyReLU",
"LocallyConnected2D",
- "Masking",
"MultiHeadAttention",
"UpSampling2D",
],
@@ -186,6 +177,7 @@
# Failing on LLVM and Vulkan
"layer": [
"Lambda",
+ "Masking",
"MaxPool1D",
"MaxPool2D",
"MaxPool3D",
@@ -198,11 +190,11 @@
{
# Failing on Vulkan
"layer": [
+ "Attention",
+ "AdditiveAttention",
"AveragePooling1D",
"AveragePooling2D",
"AveragePooling3D",
- "GRU",
- "LSTM", # TODO(silvasean): Get this test working on Vulkan.
"ThresholdedReLU",
],
"target_backends": "iree_vulkan",
@@ -216,9 +208,9 @@
flags_to_values = {
"reference_backend": "tf",
"layer": LAYERS,
- "dynamic_batch": False,
+ "dynamic_dims": False,
"training": False,
- "test_full_api": False,
+ "test_default_kwargs_only": True,
"target_backends": [
"tf",
"tflite",
@@ -237,12 +229,14 @@
# bazel run integrations/tensorflow/e2e/keras/layers:layers_test_manual -- \
# --list_layers_with_full_api_tests
LAYERS_WITH_FULL_API_TESTS = [
+ "ActivityRegularization",
"AdditiveAttention",
"Attention",
"AveragePooling1D",
"AveragePooling2D",
"AveragePooling3D",
"BatchNormalization",
+ "Concatenate",
"Conv1D",
"Conv1DTranspose",
"Conv2D",
@@ -250,22 +244,20 @@
"Conv3D",
"Conv3DTranspose",
# "ConvLSTM2D", # TODO(meadowlark): Debug flakiness.
+ "Cropping1D",
+ "Cropping2D",
+ "Cropping3D",
"DepthwiseConv2D",
- "GlobalAveragePooling1D",
- "GlobalAveragePooling2D",
- "GlobalAveragePooling3D",
- "GlobalMaxPool1D",
- "GlobalMaxPool2D",
- "GlobalMaxPool3D",
"GRU",
+ "LSTM",
"LocallyConnected1D",
"LocallyConnected2D",
- "LSTM",
"MaxPool1D",
"MaxPool2D",
"MaxPool3D",
"SeparableConv1D",
"SeparableConv2D",
+ "SimpleRNN",
# "SimpleRNN", # TODO(meadowlark): Debug flakiness.
]
@@ -274,8 +266,6 @@
# Failing on TFLite
"layer": [
"AveragePooling3D",
- "Conv1D",
- "Conv2D",
"Conv2DTranspose",
"Conv3D",
"Conv3DTranspose",
@@ -288,6 +278,7 @@
"MaxPool1D",
"MaxPool3D",
"SeparableConv1D", # Failing on Kokoro.
+ "SeparableConv2D",
"SimpleRNN",
],
"target_backends": "tflite",
@@ -295,19 +286,13 @@
{
# Failing on IREE
"layer": [
- "Conv1D",
- "Conv2D",
"Conv2DTranspose",
"Conv3DTranspose",
- "Conv3D",
"ConvLSTM2D",
- "DepthwiseConv2D",
"GRU",
"LocallyConnected1D",
"LocallyConnected2D",
"LSTM",
- "SeparableConv1D",
- "SeparableConv2D",
"SimpleRNN",
],
"target_backends": [
@@ -317,6 +302,10 @@
],
},
{
+ "layer": "Conv3D",
+ "target_backends": "iree_vmla",
+ },
+ {
# Failing on LLVM and Vulkan
"layer": [
"AdditiveAttention",
@@ -343,9 +332,9 @@
flags_to_values = {
"reference_backend": "tf",
"layer": LAYERS_WITH_FULL_API_TESTS,
- "dynamic_batch": False,
+ "dynamic_dims": False,
"training": False,
- "test_full_api": True,
+ "test_default_kwargs_only": False,
"target_backends": [
"tf",
"tflite",
@@ -362,14 +351,6 @@
FAILING_DYNAMIC = [
{
- # Wrapping these in a tf.function appears to cause a keras bug.
- "layer": [
- "GRUCell",
- "LSTMCell",
- "SimpleRNNCell",
- ],
- },
- {
# TFLite does not support dynamic shapes.
"target_backends": "tflite",
},
@@ -380,6 +361,7 @@
"AveragePooling1D",
"AveragePooling2D",
"AveragePooling3D",
+ "BatchNormalization",
"Concatenate",
"Conv1D",
"Conv1DTranspose",
@@ -391,7 +373,9 @@
"Cropping1D",
"Cropping2D",
"Cropping3D",
+ "Dense",
"DepthwiseConv2D",
+ "Dot",
"ELU",
"Flatten",
"GRU",
@@ -431,9 +415,6 @@
"Add",
"Attention",
"Average",
- "BatchNormalization",
- "Dense",
- "Dot",
"GlobalAveragePooling1D",
"GlobalAveragePooling2D",
"GlobalAveragePooling3D",
@@ -459,15 +440,15 @@
]
iree_e2e_cartesian_product_test_suite(
- name = "layers_dynamic_batch_tests",
+ name = "layers_dynamic_dims_tests",
srcs = ["layers_test.py"],
failing_configurations = FAILING_DYNAMIC,
flags_to_values = {
"reference_backend": "tf",
"layer": LAYERS,
- "dynamic_batch": True,
+ "dynamic_dims": True,
"training": False,
- "test_full_api": False,
+ "test_default_kwargs_only": True,
"target_backends": [
"tf",
"tflite",
@@ -491,14 +472,11 @@
# "ConvLSTM2D", # TODO(meadowlark): Debug flakiness.
"Dropout",
"GRU",
- "GRUCell",
"GaussianDropout",
"GaussianNoise",
"LSTM",
- "LSTMCell",
"MultiHeadAttention",
# "SimpleRNN", # TODO(meadowlark): Debug flakiness.
- "SimpleRNNCell",
"SpatialDropout1D",
"SpatialDropout2D",
"SpatialDropout3D",
@@ -506,14 +484,6 @@
FAILING_TRAINING = [
{
- # Wrapping these in a tf.function appears to cause a keras bug.
- "layer": [
- "GRUCell",
- "LSTMCell",
- "SimpleRNNCell",
- ],
- },
- {
# Failing on TFLite:
"layer": [
"AlphaDropout",
@@ -533,6 +503,7 @@
"AdditiveAttention",
"AlphaDropout",
"Attention",
+ "BatchNormalization",
"ConvLSTM2D",
"Dropout",
"GaussianDropout",
@@ -560,9 +531,9 @@
flags_to_values = {
"reference_backend": "tf",
"layer": LAYERS_WITH_TRAINING_BEHAVIOR,
- "dynamic_batch": False,
+ "dynamic_dims": False,
"training": True,
- "test_full_api": False,
+ "test_default_kwargs_only": True,
"target_backends": [
"tf",
"tflite",
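Note on the BUILD changes above: each suite pairs a `flags_to_values` dict with a list of `failing_configurations`. The sketch below shows one way such a suite could be expanded into individual test configurations. It is an illustration only and is not the implementation of `iree_e2e_cartesian_product_test_suite`. The helper names (`expand_configurations`, `matches`, `split_passing_and_failing`) are hypothetical, and the matching rule is an assumption inferred from entries such as the TFLite-only one in `FAILING_DYNAMIC`: a failing entry applies to every value of any flag it does not mention.

```python
import itertools


def _as_list(value):
    """Wraps a scalar flag value in a list so all flags iterate uniformly."""
    return list(value) if isinstance(value, (list, tuple)) else [value]


def expand_configurations(flags_to_values):
    """Returns every combination of flag values as a list of dicts."""
    flags = list(flags_to_values)
    value_lists = [_as_list(flags_to_values[flag]) for flag in flags]
    return [dict(zip(flags, combo)) for combo in itertools.product(*value_lists)]


def matches(config, failing_entry):
    """True if config agrees with every flag the failing entry specifies.

    Flags that the entry omits are treated as wildcards (an assumption).
    """
    return all(config[flag] in _as_list(values)
               for flag, values in failing_entry.items())


def split_passing_and_failing(flags_to_values, failing_configurations):
    """Partitions the cartesian product into passing and failing configs."""
    passing, failing = [], []
    for config in expand_configurations(flags_to_values):
        if any(matches(config, entry) for entry in failing_configurations):
            failing.append(config)
        else:
            passing.append(config)
    return passing, failing


# Small example using flag names that appear in the BUILD file above.
flags_to_values = {
    "layer": ["Dense", "Masking", "Softmax"],
    "dynamic_dims": False,
    "target_backends": ["tf", "tflite", "iree_vulkan"],
}
failing_configurations = [
    {"layer": ["Softmax"], "target_backends": "tflite"},
    {"layer": ["Masking"], "target_backends": ["iree_vulkan"]},
]
passing, failing = split_passing_and_failing(flags_to_values,
                                             failing_configurations)
print(len(passing), len(failing))
```

Under these assumptions, moving a layer name between the lists in `FAILING_STATIC` (as this change does for `Masking`, `GRU` and `LSTM`) only changes which combinations land in the failing partition; the total number of generated configurations stays the same.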
diff --git a/integrations/tensorflow/e2e/keras/layers/layers_test.py b/integrations/tensorflow/e2e/keras/layers/layers_test.py
index 4755450..45a7d95 100644
--- a/integrations/tensorflow/e2e/keras/layers/layers_test.py
+++ b/integrations/tensorflow/e2e/keras/layers/layers_test.py
@@ -16,11 +16,13 @@
import collections
import copy
+import inspect
import os
-from typing import Any, Dict, Sequence, Union
+from typing import Any, Dict, List, Sequence, Tuple, Union
from absl import app
from absl import flags
+from absl import logging
from pyiree.tf.support import tf_test_utils
from pyiree.tf.support import tf_utils
import tensorflow.compat.v2 as tf
@@ -28,516 +30,448 @@
FLAGS = flags.FLAGS
DROPOUT = 0.5
-DIM = 4
-RANK_2_INPUT = [DIM] * 2
-RANK_3_INPUT = [DIM] * 3
-RANK_4_INPUT = [DIM] * 4
+CONV_FILTERS = 2
+CONV_KERNEL_SIZE = 2
+DIM = 3
-CONV_1D_INPUT = [2, 8, 3]
-CONV_2D_INPUT = [2, 8, 8, 3]
-CONV_3D_INPUT = [2, 8, 8, 8, 3]
+# Used for attention layers and recurrent layers.
+RANK_3_SHAPE = [DIM] * 3
+# Highest rank that tf.keras will allow for all layers.
+RANK_5_SHAPE = [DIM] * 5
-# Configs are namedtuples storing keyword arguments and shapes to test a
-# tf.keras.layers.Layer with. They are used in two ways:
-# 1. To directly specify the kwargs and shapes for a layers test.
-# 2. In 'generate_configs', to specify how to change a default config to
-# specify a non-default test. In this case, the overriding Config will
-# exclusively specify the shape of the test if its shape is not None, and
-# the overriding Config will extend/update the kwargs of the default
-# Config.
-Config = collections.namedtuple('Config', ['kwargs', 'shapes'])
-# Use old default API for compatibility with Python 3.6.
-Config.__new__.__defaults__ = (dict(), None)
+UNARY_SIGNATURE_SHAPES = [[RANK_5_SHAPE]]
+BINARY_SIGNATURE_SHAPES = [[RANK_5_SHAPE] * 2]
+TERNARY_SIGNATURE_SHAPES = [[RANK_5_SHAPE] * 3]
+CONV_1D_SIGNATURE_SHAPES = [[[2, 8, 3]]]
+CONV_2D_SIGNATURE_SHAPES = [[[2, 8, 8, 3]]]
+CONV_3D_SIGNATURE_SHAPES = [[[2, 8, 8, 8, 3]]]
-def generate_configs(default_config: Config,
- override_configs: Dict[str, Config]) -> Dict[str, Config]:
- """Generates a dict of 'Config's based off changes to a default Config."""
- configs = {'default': default_config}
- for exported_name, config in override_configs.items():
- shapes = default_config.shapes if config.shapes is None else config.shapes
+RNN_SIGNATURE_SHAPES = [[RANK_3_SHAPE]]
+RNN_KWARGS_TO_VALUES = dict(units=[4],
+ return_sequences=[False, True],
+ stateful=[False, True])
- # Deep copy to avoid inplace mutation of the default.
- kwargs = copy.deepcopy(default_config.kwargs)
- kwargs.update(config.kwargs) # Adds new and overwrites old kwargs.
+POOLING_KWARGS_TO_VALUES = dict(strides=[None, 2],
+ padding=["valid", "same"],
+ data_format=[None, "channels_first"])
+CONV_KWARGS_TO_VALUES = dict(filters=[CONV_FILTERS],
+ kernel_size=[CONV_KERNEL_SIZE],
+ strides=[1, 2],
+ padding=["valid", "same"],
+ data_format=[None, "channels_first"],
+ dilation_rate=[1, 2])
+# Address pooling and conv layers having different default values for
+# 'data_format' for 1D layers.
+POOLING_1D_KWARGS_TO_VALUES = copy.deepcopy(POOLING_KWARGS_TO_VALUES)
+POOLING_1D_KWARGS_TO_VALUES.update(
+ {"data_format": ["channels_last", "channels_first"]})
+CONV_1D_KWARGS_TO_VALUES = copy.deepcopy(CONV_KWARGS_TO_VALUES)
+CONV_1D_KWARGS_TO_VALUES.update(
+ {"data_format": ["channels_last", "channels_first"]})
- configs[exported_name] = Config(kwargs, shapes)
- return configs
-
-
-# A dict mapping tf.keras.layers names to either a single Config (representing
-# the kwargs and shapes to use to test a Layer) or a dict mapping exported_names
-# to Configs. The latter case is usually automatically generated via
-# 'generate_configs', with the 'Config's in 'override_configs' specifying how
-# to modify the 'default_config's kwargs and shapes.
-#
-# Each entry will be normalized to be a dict mapping exported_names to Configs,
-# with a default exported_name of 'default'.
-LAYER_TO_UNITTEST_CONFIGURATIONS = {
- 'Activation':
- Config(dict(activation='relu'), [RANK_2_INPUT]),
- 'ActivityRegularization':
- Config(dict(l1=0.1, l2=0.1), shapes=[RANK_2_INPUT]),
- 'Add':
- Config(shapes=[RANK_2_INPUT, RANK_2_INPUT]),
- 'AdditiveAttention':
- generate_configs(
- default_config=Config(
- shapes=[RANK_3_INPUT, RANK_3_INPUT, RANK_3_INPUT],),
- override_configs={
- 'causal': Config(dict(causal=True)),
- },
- ),
- 'AlphaDropout':
- Config(dict(rate=DROPOUT), [RANK_2_INPUT]),
- 'Attention':
- generate_configs(
- default_config=Config(
- shapes=[RANK_3_INPUT, RANK_3_INPUT, RANK_3_INPUT],),
- override_configs={
- 'causal': Config(dict(causal=True)),
- },
- ),
- 'Average':
- Config(shapes=[RANK_2_INPUT, RANK_2_INPUT]),
- 'AveragePooling1D':
- generate_configs(
- default_config=Config(shapes=[CONV_1D_INPUT]),
- override_configs={
- 'strides': Config(dict(strides=3)),
- 'padding_same': Config(dict(padding='same')),
- 'channels_first': Config(dict(data_format='channels_first')),
- },
- ),
- 'AveragePooling2D':
- generate_configs(
- default_config=Config(shapes=[CONV_2D_INPUT]),
- override_configs={
- 'strides': Config(dict(strides=3)),
- 'padding_same': Config(dict(padding='same')),
- # TF: Default AvgPoolingOp only supports NHWC on device type CPU
- # 'channels_first': Config(dict(data_format='channels_first')),
- },
- ),
- 'AveragePooling3D':
- generate_configs(
- default_config=Config(shapes=[CONV_3D_INPUT]),
- override_configs={
- 'strides': Config(dict(strides=3)),
- 'padding_same': Config(dict(padding='same')),
- 'channels_first': Config(dict(data_format='channels_first')),
- },
- ),
- 'BatchNormalization':
- generate_configs(
- default_config=Config(shapes=[RANK_2_INPUT]),
- override_configs={'renorm': Config(dict(renorm=True))},
- ),
- 'Concatenate':
- Config(shapes=[RANK_4_INPUT, RANK_4_INPUT]),
- 'Conv1D':
- generate_configs(
- default_config=Config(
- kwargs=dict(filters=4, kernel_size=3),
- shapes=[CONV_1D_INPUT],
- ),
- override_configs={
- 'strides': Config(dict(strides=3)),
- 'padding_same': Config(dict(padding='same')),
- # TF: The Conv2D op currently only supports the NHWC tensor
- # format on the CPU.
- # 'channels_first': Config(dict(data_format='channels_first')),
- 'dilation_rate': Config(dict(dilation_rate=3)),
- },
- ),
- 'Conv1DTranspose':
- generate_configs(
- default_config=Config(
- kwargs=dict(filters=4, kernel_size=3),
- shapes=[CONV_1D_INPUT],
- ),
- override_configs={
- 'strides': Config(dict(strides=3)),
- 'padding_same': Config(dict(padding='same')),
- # TF: Conv2DCustomBackpropInputOp only supports NHWC
- # 'channels_first': Config(dict(data_format='channels_first')),
- # TF: Current libxsmm and customized CPU implementations do not
- # yet support dilation rates larger than 1.
- # 'dilation_rate': Config(dict(dilation_rate=3)),
- },
- ),
- 'Conv2D':
- generate_configs(
- default_config=Config(
- kwargs=dict(filters=4, kernel_size=3),
- shapes=[CONV_2D_INPUT],
- ),
- override_configs={
- 'strides': Config(dict(strides=3)),
- 'padding_same': Config(dict(padding='same')),
- # TF: The Conv2D op currently only supports the NHWC tensor
- # format on the CPU.
- # 'channels_first': Config(dict(data_format='channels_first')),
- 'dilation_rate': Config(dict(dilation_rate=3)),
- },
- ),
- 'Conv2DTranspose':
- generate_configs(
- default_config=Config(
- kwargs=dict(filters=4, kernel_size=3),
- shapes=[CONV_2D_INPUT],
- ),
- override_configs={
- 'strides': Config(dict(strides=3)),
- 'padding_same': Config(dict(padding='same')),
- 'channels_first': Config(dict(data_format='channels_first')),
- 'dilation_rate': Config(dict(dilation_rate=3)),
- },
- ),
- 'Conv3D':
- generate_configs(
- default_config=Config(
- kwargs=dict(filters=4, kernel_size=3),
- shapes=[CONV_3D_INPUT],
- ),
- override_configs={
- 'strides': Config(dict(strides=3)),
- 'padding_same': Config(dict(padding='same')),
- # TF: The Conv3D op currently only supports the NHWC tensor
- # format on the CPU.
- # 'channels_first': Config(dict(data_format='channels_first')),
- 'dilation_rate': Config(dict(dilation_rate=3)),
- },
- ),
- 'Conv3DTranspose':
- generate_configs(
- default_config=Config(
- kwargs=dict(filters=4, kernel_size=3),
- shapes=[CONV_3D_INPUT],
- ),
- override_configs={
- 'strides': Config(dict(strides=3)),
- 'padding_same': Config(dict(padding='same')),
- # TF: Conv3DBackpropInputOpV2 only supports NDHWC on the CPU.
- # 'channels_first': Config(dict(data_format='channels_first')),
- 'dilation_rate': Config(dict(dilation_rate=3)),
- },
- ),
- 'ConvLSTM2D':
- generate_configs(
- default_config=Config(
- kwargs=dict(filters=4, kernel_size=3, return_state=True),
- shapes=[CONV_3D_INPUT],
- ),
- override_configs={
- 'strides': Config(dict(strides=3)),
- 'padding': Config(dict(padding='same')),
- 'channels_first': Config(dict(data_format='channels_first')),
- 'dilation_rate': Config(dict(dilation_rate=3)),
- 'go_backwards': Config(dict(go_backwards=True)),
- 'stateful': Config(dict(stateful=True)),
- },
- ),
- 'Cropping1D':
- Config(dict(cropping=2), [CONV_1D_INPUT]),
- 'Cropping2D':
- Config(dict(cropping=2), [CONV_2D_INPUT]),
- 'Cropping3D':
- Config(dict(cropping=2), [CONV_3D_INPUT]),
- 'Dense':
- Config(dict(units=4), [RANK_2_INPUT]),
- 'DepthwiseConv2D':
- generate_configs(
- default_config=Config(
- kwargs=dict(kernel_size=3),
- shapes=[CONV_2D_INPUT],
- ),
- override_configs={
- 'strides': Config(dict(strides=3)),
- 'padding_same': Config(dict(padding='same')),
- 'channels_first': Config(dict(data_format='channels_first')),
- 'depth_multiplier': Config(dict(depth_multiplier=2)),
- 'dilation_rate': Config(dict(dilation_rate=2)),
- },
- ),
- 'Dot':
- Config(dict(axes=(1, 2)), [RANK_3_INPUT, RANK_3_INPUT]),
- 'Dropout':
- Config(dict(rate=DROPOUT), [RANK_3_INPUT]),
- 'ELU':
- Config(shapes=[RANK_2_INPUT]),
- 'Embedding':
- Config(dict(input_dim=4, output_dim=2), [RANK_2_INPUT]),
- 'Flatten':
- Config(shapes=[RANK_2_INPUT]),
- 'GRU':
- generate_configs(
- default_config=Config(
- kwargs=dict(units=4, return_sequences=True),
- shapes=[RANK_3_INPUT],
- ),
- override_configs={
- 'implementation_1': Config(dict(implementation=1)),
- 'go_backwards': Config(dict(go_backwards=True)),
- 'time_major': Config(dict(time_major=True)),
- 'stateful': Config(dict(stateful=True)),
- },
- ),
- 'GRUCell':
- Config(dict(units=4), [RANK_2_INPUT, RANK_2_INPUT]),
- 'GaussianDropout':
- Config(dict(rate=DROPOUT), [RANK_2_INPUT]),
- 'GaussianNoise':
- Config(dict(stddev=1.0), [RANK_2_INPUT]),
- 'GlobalAveragePooling1D':
- generate_configs(
- default_config=Config(shapes=[CONV_1D_INPUT]),
- override_configs={
- 'channels_first': Config(dict(data_format='channels_first')),
- },
- ),
- 'GlobalAveragePooling2D':
- generate_configs(
- default_config=Config(shapes=[CONV_2D_INPUT]),
- override_configs={
- 'channels_first': Config(dict(data_format='channels_first')),
- },
- ),
- 'GlobalAveragePooling3D':
- generate_configs(
- default_config=Config(shapes=[CONV_3D_INPUT]),
- override_configs={
- 'channels_first': Config(dict(data_format='channels_first')),
- },
- ),
- 'GlobalMaxPool1D':
- generate_configs(
- default_config=Config(shapes=[CONV_1D_INPUT]),
- override_configs={
- 'channels_first': Config(dict(data_format='channels_first')),
- },
- ),
- 'GlobalMaxPool2D':
- generate_configs(
- default_config=Config(shapes=[CONV_2D_INPUT]),
- override_configs={
- 'channels_first': Config(dict(data_format='channels_first')),
- },
- ),
- 'GlobalMaxPool3D':
- generate_configs(
- default_config=Config(shapes=[CONV_3D_INPUT]),
- override_configs={
- 'channels_first': Config(dict(data_format='channels_first')),
- },
- ),
- 'InputLayer':
- Config(shapes=[RANK_2_INPUT]),
- 'LSTM':
- generate_configs(
- default_config=Config(
- kwargs=dict(units=4, return_sequences=True),
- shapes=[RANK_3_INPUT],
- ),
- override_configs={
- 'implementation_1': Config(dict(implementation=1)),
- 'go_backwards': Config(dict(go_backwards=True)),
- 'time_major': Config(dict(time_major=True)),
- 'stateful': Config(dict(stateful=True)),
- },
- ),
- 'LSTMCell':
- Config(dict(units=4), [RANK_2_INPUT, RANK_2_INPUT]),
- 'Lambda':
- Config(dict(function=lambda x: x**2), [RANK_2_INPUT]),
- 'LayerNormalization':
- Config(shapes=[RANK_2_INPUT]),
- 'LeakyReLU':
- Config(shapes=[RANK_2_INPUT]),
- 'LocallyConnected1D':
- generate_configs(
- default_config=Config(
- kwargs=dict(filters=4, kernel_size=3),
- shapes=[CONV_1D_INPUT],
- ),
- override_configs={
- 'strides': Config(dict(strides=3)),
- 'padding_same': Config(dict(padding='same', implementation=2)),
- 'channels_first': Config(dict(data_format='channels_first')),
- 'sparse_implementation': Config(dict(implementation=3)),
- },
- ),
- 'LocallyConnected2D':
- generate_configs(
- default_config=Config(
- kwargs=dict(filters=4, kernel_size=3),
- shapes=[CONV_2D_INPUT],
- ),
- override_configs={
- 'strides': Config(dict(strides=3)),
- 'padding_same': Config(dict(padding='same', implementation=2)),
- 'channels_first': Config(dict(data_format='channels_first')),
- 'sparse_implementation': Config(dict(implementation=3)),
- },
- ),
- 'Masking':
- Config(shapes=[RANK_2_INPUT]),
- 'MaxPool1D':
- generate_configs(
- default_config=Config(shapes=[CONV_1D_INPUT]),
- override_configs={
- 'strides': Config(dict(strides=3)),
- 'padding_same': Config(dict(padding='same')),
- 'channels_first': Config(dict(data_format='channels_first')),
- },
- ),
- 'MaxPool2D':
- generate_configs(
- default_config=Config(shapes=[CONV_2D_INPUT]),
- override_configs={
- 'strides': Config(dict(strides=3)),
- 'padding_same': Config(dict(padding='same')),
- # TF: Default MaxPoolingOp only supports NHWC on device type CPU
- # 'channels_first': Config(dict(data_format='channels_first')),
- },
- ),
- 'MaxPool3D':
- generate_configs(
- default_config=Config(shapes=[CONV_3D_INPUT]),
- override_configs={
- 'strides': Config(dict(strides=3)),
- 'padding_same': Config(dict(padding='same')),
- 'channels_first': Config(dict(data_format='channels_first')),
- },
- ),
- 'Maximum':
- Config(shapes=[RANK_2_INPUT, RANK_2_INPUT]),
- 'Minimum':
- Config(shapes=[RANK_2_INPUT, RANK_2_INPUT]),
- 'MultiHeadAttention':
- Config(dict(num_heads=2, key_dim=3), [RANK_3_INPUT, RANK_3_INPUT]),
- 'Multiply':
- Config(shapes=[RANK_2_INPUT, RANK_2_INPUT]),
- 'PReLU':
- Config(shapes=[RANK_2_INPUT]),
- 'Permute':
- Config(dict(dims=(3, 1, 2)), [RANK_4_INPUT]),
- 'ReLU':
- Config(shapes=[RANK_2_INPUT]),
- 'RepeatVector':
- Config(dict(n=3), [RANK_2_INPUT]),
- 'Reshape':
- Config(dict(target_shape=[1, 1, 1] + RANK_3_INPUT[1:]), [RANK_3_INPUT]),
- 'SeparableConv1D':
- generate_configs(
- default_config=Config(
- kwargs=dict(filters=4, kernel_size=3),
- shapes=[CONV_1D_INPUT],
- ),
- override_configs={
- 'strides': Config(dict(strides=3)),
- 'padding_same': Config(dict(padding='same')),
- # TF: Depthwise convolution on CPU is only supported for NHWC
- # format
- # 'channels_first': Config(dict(data_format='channels_first')),
- 'depth_multiplier': Config(dict(depth_multiplier=2)),
- 'dilation_rate': Config(dict(dilation_rate=2)),
- },
- ),
- 'SeparableConv2D':
- generate_configs(
- default_config=Config(
- kwargs=dict(filters=4, kernel_size=3),
- shapes=[CONV_2D_INPUT],
- ),
- override_configs={
- 'strides': Config(dict(strides=3)),
- 'padding_same': Config(dict(padding='same')),
- # TF: Depthwise convolution on CPU is only supported for NHWC
- # format
- # 'channels_first': Config(dict(data_format='channels_first')),
- 'depth_multiplier': Config(dict(depth_multiplier=2)),
- 'dilation_rate': Config(dict(dilation_rate=2)),
- },
- ),
- 'SimpleRNN':
- generate_configs(
- default_config=Config(
- kwargs=dict(units=4, return_sequences=True),
- shapes=[RANK_3_INPUT],
- ),
- override_configs={
- 'go_backwards': Config(dict(go_backwards=True)),
- 'stateful': Config(dict(stateful=True)),
- },
- ),
- 'SimpleRNNCell':
- Config(dict(units=4), [RANK_2_INPUT, RANK_2_INPUT]),
- 'Softmax':
- Config(shapes=[RANK_2_INPUT]),
- 'SpatialDropout1D':
- Config(dict(rate=DROPOUT), [CONV_1D_INPUT]),
- 'SpatialDropout2D':
- Config(dict(rate=DROPOUT), [CONV_2D_INPUT]),
- 'SpatialDropout3D':
- Config(dict(rate=DROPOUT), [CONV_3D_INPUT]),
- 'Subtract':
- Config(shapes=[RANK_2_INPUT, RANK_2_INPUT]),
- 'ThresholdedReLU':
- Config(shapes=[RANK_2_INPUT]),
- 'UpSampling1D':
- Config(shapes=[CONV_1D_INPUT]),
- 'UpSampling2D':
- Config(shapes=[CONV_2D_INPUT]),
- 'UpSampling3D':
- Config(shapes=[CONV_3D_INPUT]),
- 'ZeroPadding1D':
- Config(shapes=[CONV_1D_INPUT]),
- 'ZeroPadding2D':
- Config(shapes=[CONV_2D_INPUT]),
- 'ZeroPadding3D':
- Config(shapes=[CONV_3D_INPUT]),
+# Unsupported by TensorFlow (at least on CPU).
+LAYERS_TO_TF_UNSUPPORTED_NON_DEFAULT_KWARGS = {
+ "AveragePooling2D": ["data_format"],
+ "Conv1D": ["data_format"],
+ "Conv1DTranspose": ["data_format", "dilation_rate"],
+ "Conv2D": ["data_format"],
+ "Conv3D": ["data_format"],
+ "Conv3DTranspose": ["data_format"],
+ "LocallyConnected1D": ["padding"],
+ "LocallyConnected2D": ["padding"],
+ "MaxPool2D": ["data_format"],
}
-# Normalize LAYER_TO_UNITTEST_CONFIGURATIONS
-for key, value in LAYER_TO_UNITTEST_CONFIGURATIONS.items():
- if isinstance(value, Config):
- LAYER_TO_UNITTEST_CONFIGURATIONS[key] = {'default': value}
+# Some layers have kwargs which cannot both have non-default values.
+LAYERS_TO_MUTUALLY_EXCLUSIVE_KWARGS = {
+ "Conv1D": ["strides", "dilation_rate"],
+ "Conv2D": ["strides", "dilation_rate"],
+ "Conv2DTranspose": ["strides", "dilation_rate"],
+ "Conv3D": ["strides", "dilation_rate"],
+ "ConvLSTM2D": ["strides", "dilation_rate"],
+}
+
+
+def get_default_kwargs_values(layer: str) -> Dict[str, Any]:
+ """Gets the default kwargs for a tf.keras.layers layer."""
+ layer_class = getattr(tf.keras.layers, layer)
+ layer_parameters = inspect.signature(layer_class.__init__).parameters
+ kwargs_to_default_values = {
+ kwarg: value.default
+ for kwarg, value in layer_parameters.items()
+ if value.default is not inspect.Parameter.empty
+ }
+ return kwargs_to_default_values
+
+
+def _equal_or_splat_equal(value: Any, sequence: Any) -> bool:
+ """Returns True if value==sequence or value==(every element in sequence)."""
+ if value == sequence:
+ return True
+ elif isinstance(sequence, (list, tuple)):
+ for element in sequence:
+ if not _equal_or_splat_equal(value, element):
+ return False
+ return True
+ return False
+
+
+def get_non_default_kwargs(
+ layer: str, unit_test_spec: tf_test_utils.UnitTestSpec) -> List[str]:
+ """Returns all non-default optional kwargs in unit_test_spec."""
+ kwargs_to_defaults = get_default_kwargs_values(layer)
+ non_default_kwargs = []
+ for kwarg, value in unit_test_spec.kwargs.items():
+ if (kwarg in kwargs_to_defaults and
+ not _equal_or_splat_equal(value, kwargs_to_defaults[kwarg])):
+ non_default_kwargs.append(kwarg)
+ return non_default_kwargs
+
+
+def unsupported_by_tf(layer: str,
+ unit_test_spec: tf_test_utils.UnitTestSpec) -> bool:
+ """True if unit_test_spec specifies tf-unsupported non-default kwargs."""
+ if layer in LAYERS_TO_TF_UNSUPPORTED_NON_DEFAULT_KWARGS:
+ unsupported_kwargs = LAYERS_TO_TF_UNSUPPORTED_NON_DEFAULT_KWARGS[layer]
+ non_default_kwargs = get_non_default_kwargs(layer, unit_test_spec)
+ return any(kwarg in unsupported_kwargs for kwarg in non_default_kwargs)
+ return False
+
+
+def has_mutually_exclusive_kwargs(
+ layer: str, unit_test_spec: tf_test_utils.UnitTestSpec) -> bool:
+ """True if unit_test_spec specifies mutually exclusive non-default kwargs."""
+ if layer in LAYERS_TO_MUTUALLY_EXCLUSIVE_KWARGS:
+ mutually_exclusive_kwargs = LAYERS_TO_MUTUALLY_EXCLUSIVE_KWARGS[layer]
+ non_default_kwargs = get_non_default_kwargs(layer, unit_test_spec)
+ return set(mutually_exclusive_kwargs).issubset(set(non_default_kwargs))
+ return False
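Note on the helpers above: they decide whether a kwarg is "non-default" by reading defaults from the layer's `__init__` signature and comparing with a splat-aware equality, so a scalar such as `strides=1` counts as equal to a default of `(1, 1)`. The following is a minimal standalone sketch of that idea. `FakeConv` is a hypothetical stand-in for a `tf.keras.layers` class so the snippet runs without TensorFlow, and the function names are simplified versions of the ones in the diff, not the committed code.

```python
import inspect


class FakeConv:
    """Hypothetical stand-in for a tf.keras.layers convolution class."""

    def __init__(self, filters, kernel_size, strides=(1, 1), padding="valid",
                 dilation_rate=(1, 1)):
        pass


def default_kwargs(layer_class):
    """Maps each optional __init__ parameter to its default value."""
    parameters = inspect.signature(layer_class.__init__).parameters
    return {
        name: parameter.default
        for name, parameter in parameters.items()
        if parameter.default is not inspect.Parameter.empty
    }


def equal_or_splat_equal(value, sequence):
    """True if value == sequence or value equals every element of sequence."""
    if value == sequence:
        return True
    if isinstance(sequence, (list, tuple)):
        return all(equal_or_splat_equal(value, element) for element in sequence)
    return False


def non_default_kwargs(layer_class, kwargs):
    """Lists the optional kwargs whose values differ from the defaults."""
    defaults = default_kwargs(layer_class)
    return [
        kwarg for kwarg, value in kwargs.items()
        if kwarg in defaults and not equal_or_splat_equal(value, defaults[kwarg])
    ]


spec_kwargs = dict(filters=2, kernel_size=2, strides=1, padding="same",
                   dilation_rate=2)
changed = non_default_kwargs(FakeConv, spec_kwargs)
print(changed)

# Mirrors the mutually-exclusive check: both kwargs must be non-default.
conflict = {"strides", "dilation_rate"}.issubset(changed)
print(conflict)
```

Here `filters` and `kernel_size` are skipped because they are required parameters with no default, and `strides=1` is treated as default because it splat-equals `(1, 1)`, so the spec is not flagged as having mutually exclusive kwargs.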
+
+
+# A dictionary mapping tf.keras.layers names to lists of UnitTestSpecs.
+# Each unit_test_name will have the tf.keras.layer name prepended to it.
+#
+# Each layer is required to have a UnitTestSpec with all-default values for
+# optional kwargs. This allows us to separately test the basic API and the
+# full API.
+LAYERS_TO_UNIT_TEST_SPECS = {
+ "Activation":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(activation=["relu"])),
+ "ActivityRegularization":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(l1=[0.0, 0.1], l2=[0.0, 0.1])),
+ "Add":
+ tf_test_utils.unit_test_specs_from_signatures(BINARY_SIGNATURE_SHAPES),
+ "AdditiveAttention":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=[(RANK_3_SHAPE, RANK_3_SHAPE, RANK_3_SHAPE)],
+ kwargs_to_values=dict(causal=[False, True])),
+ "AlphaDropout":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(rate=[DROPOUT])),
+ "Attention":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=[(RANK_3_SHAPE, RANK_3_SHAPE, RANK_3_SHAPE)],
+ kwargs_to_values=dict(causal=[False, True])),
+ "Average":
+ tf_test_utils.unit_test_specs_from_signatures(BINARY_SIGNATURE_SHAPES),
+ "AveragePooling1D":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_1D_SIGNATURE_SHAPES,
+ kwargs_to_values=POOLING_1D_KWARGS_TO_VALUES),
+ "AveragePooling2D":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_2D_SIGNATURE_SHAPES,
+ kwargs_to_values=POOLING_KWARGS_TO_VALUES),
+ "AveragePooling3D":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_3D_SIGNATURE_SHAPES,
+ kwargs_to_values=POOLING_KWARGS_TO_VALUES),
+ "BatchNormalization":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(renorm=[False, True])),
+ "Concatenate":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(axis=[-1, 0])),
+ "Conv1D":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_1D_SIGNATURE_SHAPES,
+ kwargs_to_values=CONV_1D_KWARGS_TO_VALUES),
+ "Conv1DTranspose":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_1D_SIGNATURE_SHAPES,
+ kwargs_to_values=CONV_KWARGS_TO_VALUES),
+ "Conv2D":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_2D_SIGNATURE_SHAPES,
+ kwargs_to_values=CONV_KWARGS_TO_VALUES),
+ "Conv2DTranspose":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_2D_SIGNATURE_SHAPES,
+ kwargs_to_values=CONV_KWARGS_TO_VALUES),
+ "Conv3D":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_3D_SIGNATURE_SHAPES,
+ kwargs_to_values=CONV_KWARGS_TO_VALUES),
+ "Conv3DTranspose":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_3D_SIGNATURE_SHAPES,
+ kwargs_to_values=CONV_KWARGS_TO_VALUES),
+ "ConvLSTM2D":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_3D_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(filters=[CONV_FILTERS],
+ kernel_size=[CONV_KERNEL_SIZE],
+ return_state=[False, True],
+ strides=[1, 2],
+ dilation_rate=[1, 2],
+ stateful=[False, True])),
+ "Cropping1D":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_1D_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(cropping=[1, (1, 2)])),
+ "Cropping2D":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_2D_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(cropping=[0, ((1, 2), (2, 1))])),
+ "Cropping3D":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_3D_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(cropping=[1, ((1, 2), (2, 1), (1, 0))])),
+ "Dense":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(units=[8])),
+ "DepthwiseConv2D":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_2D_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(kernel_size=[CONV_KERNEL_SIZE],
+ strides=[1, 2],
+ padding=["valid", "same"],
+ dilation_rate=[1, 2],
+ depth_multiplier=[1, 2])),
+ "Dot":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=BINARY_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(axes=[(1, 2)])),
+ "Dropout":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(rate=[DROPOUT])),
+ "ELU":
+ tf_test_utils.unit_test_specs_from_signatures(UNARY_SIGNATURE_SHAPES),
+ "Embedding":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(input_dim=[4], output_dim=[2])),
+ "Flatten":
+ tf_test_utils.unit_test_specs_from_signatures(UNARY_SIGNATURE_SHAPES),
+ "GRU":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=RNN_SIGNATURE_SHAPES,
+ kwargs_to_values=RNN_KWARGS_TO_VALUES),
+ "GaussianDropout":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=RNN_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(rate=[DROPOUT])),
+ "GaussianNoise":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=RNN_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(stddev=[1.0])),
+ "GlobalAveragePooling1D":
+ tf_test_utils.unit_test_specs_from_signatures(CONV_1D_SIGNATURE_SHAPES),
+ "GlobalAveragePooling2D":
+ tf_test_utils.unit_test_specs_from_signatures(CONV_2D_SIGNATURE_SHAPES),
+ "GlobalAveragePooling3D":
+ tf_test_utils.unit_test_specs_from_signatures(CONV_3D_SIGNATURE_SHAPES),
+ "GlobalMaxPool1D":
+ tf_test_utils.unit_test_specs_from_signatures(CONV_1D_SIGNATURE_SHAPES),
+ "GlobalMaxPool2D":
+ tf_test_utils.unit_test_specs_from_signatures(CONV_2D_SIGNATURE_SHAPES),
+ "GlobalMaxPool3D":
+ tf_test_utils.unit_test_specs_from_signatures(CONV_3D_SIGNATURE_SHAPES),
+ "InputLayer":
+ tf_test_utils.unit_test_specs_from_signatures(UNARY_SIGNATURE_SHAPES),
+ "LSTM":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=RNN_SIGNATURE_SHAPES,
+ kwargs_to_values=RNN_KWARGS_TO_VALUES),
+ "Lambda":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(function=[lambda x: x**2])),
+ "LayerNormalization":
+ tf_test_utils.unit_test_specs_from_signatures(UNARY_SIGNATURE_SHAPES),
+ "LeakyReLU":
+ tf_test_utils.unit_test_specs_from_signatures(UNARY_SIGNATURE_SHAPES),
+ "LocallyConnected1D":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_1D_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(filters=[CONV_FILTERS],
+ kernel_size=[CONV_KERNEL_SIZE],
+ strides=[1, 3],
+ padding=["valid", "same"],
+ implementation=[1, 3])),
+ "LocallyConnected2D":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_2D_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(filters=[CONV_FILTERS],
+ kernel_size=[CONV_KERNEL_SIZE],
+ strides=[1, 3],
+ padding=["valid", "same"],
+ implementation=[1, 3])),
+ "Masking":
+ tf_test_utils.unit_test_specs_from_signatures(UNARY_SIGNATURE_SHAPES),
+ "MaxPool1D":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_1D_SIGNATURE_SHAPES,
+ kwargs_to_values=POOLING_1D_KWARGS_TO_VALUES),
+ "MaxPool2D":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_2D_SIGNATURE_SHAPES,
+ kwargs_to_values=POOLING_KWARGS_TO_VALUES),
+ "MaxPool3D":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_3D_SIGNATURE_SHAPES,
+ kwargs_to_values=POOLING_KWARGS_TO_VALUES),
+ "Maximum":
+ tf_test_utils.unit_test_specs_from_signatures(BINARY_SIGNATURE_SHAPES),
+ "Minimum":
+ tf_test_utils.unit_test_specs_from_signatures(BINARY_SIGNATURE_SHAPES),
+ "MultiHeadAttention":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=[(RANK_3_SHAPE, RANK_3_SHAPE)],
+ kwargs_to_values=dict(num_heads=[2], key_dim=[3])),
+ "Multiply":
+ tf_test_utils.unit_test_specs_from_signatures(BINARY_SIGNATURE_SHAPES),
+ "PReLU":
+ tf_test_utils.unit_test_specs_from_signatures(UNARY_SIGNATURE_SHAPES),
+ "Permute":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=UNARY_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(dims=[(3, 1, 4, 2)])),
+ "ReLU":
+ tf_test_utils.unit_test_specs_from_signatures(UNARY_SIGNATURE_SHAPES),
+ "RepeatVector":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=[((2, 2),)], kwargs_to_values=dict(n=[3])),
+ "Reshape":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=[((3, 2, 2, 2),)],
+ kwargs_to_values=dict(target_shape=[(2, 1, 4, 1)])),
+ "SeparableConv1D":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_1D_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(filters=[CONV_FILTERS],
+ kernel_size=[CONV_KERNEL_SIZE],
+ strides=[1, 2],
+ padding=["valid", "same"],
+ dilation_rate=[1, 2],
+ depth_multiplier=[1, 2])),
+ "SeparableConv2D":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_2D_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(filters=[CONV_FILTERS],
+ kernel_size=[CONV_KERNEL_SIZE],
+ strides=[1, 2],
+ padding=["valid", "same"],
+ dilation_rate=[1, 2],
+ depth_multiplier=[1, 2])),
+ "SimpleRNN":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=RNN_SIGNATURE_SHAPES,
+ kwargs_to_values=RNN_KWARGS_TO_VALUES),
+ "Softmax":
+ tf_test_utils.unit_test_specs_from_signatures(UNARY_SIGNATURE_SHAPES),
+ "SpatialDropout1D":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_1D_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(rate=[DROPOUT])),
+ "SpatialDropout2D":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_2D_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(rate=[DROPOUT])),
+ "SpatialDropout3D":
+ tf_test_utils.unit_test_specs_from_signatures(
+ signature_shapes=CONV_3D_SIGNATURE_SHAPES,
+ kwargs_to_values=dict(rate=[DROPOUT])),
+ "Subtract":
+ tf_test_utils.unit_test_specs_from_signatures(BINARY_SIGNATURE_SHAPES),
+ "ThresholdedReLU":
+ tf_test_utils.unit_test_specs_from_signatures(UNARY_SIGNATURE_SHAPES),
+ "UpSampling1D":
+ tf_test_utils.unit_test_specs_from_signatures(CONV_1D_SIGNATURE_SHAPES),
+ "UpSampling2D":
+ tf_test_utils.unit_test_specs_from_signatures(CONV_2D_SIGNATURE_SHAPES),
+ "UpSampling3D":
+ tf_test_utils.unit_test_specs_from_signatures(CONV_3D_SIGNATURE_SHAPES),
+ "ZeroPadding1D":
+ tf_test_utils.unit_test_specs_from_signatures(CONV_1D_SIGNATURE_SHAPES),
+ "ZeroPadding2D":
+ tf_test_utils.unit_test_specs_from_signatures(CONV_2D_SIGNATURE_SHAPES),
+ "ZeroPadding3D":
+ tf_test_utils.unit_test_specs_from_signatures(CONV_3D_SIGNATURE_SHAPES),
+}
+
+for layer, specs in LAYERS_TO_UNIT_TEST_SPECS.items():
+ # Update using 'with_name' to avoid updating shared UnitTestSpecs.
+ specs = [spec.with_name(f"{layer}__{spec.unit_test_name}") for spec in specs]
+ LAYERS_TO_UNIT_TEST_SPECS[layer] = specs
+
+ # Validate that there are not multiple UnitTestSpecs with the same name.
+ seen_unit_test_names = set()
+ for spec in specs:
+ if spec.unit_test_name in seen_unit_test_names:
+ raise ValueError(
+ f"Found multiple UnitTestSpecs with the name '{spec.unit_test_name}'")
+ seen_unit_test_names.add(spec.unit_test_name)
+
+ # Validate that there is one spec that has default values for all unrequired
+ # kwargs.
+ has_default_unrequired_kwargs = False
+ for spec in specs:
+ if not get_non_default_kwargs(layer, spec):
+ has_default_unrequired_kwargs = True
+
+ if not has_default_unrequired_kwargs:
+ raise ValueError(
+ f"The configuration for '{layer}' did not have a UnitTestSpec with all "
+ "default kwargs.")
# Layers that allow specifying the 'dropout' kwarg.
DROPOUT_LAYERS = [
- 'AdditiveAttention', 'Attention', 'ConvLSTM2D', 'GRU', 'GRUCell', 'LSTM',
- 'LSTMCell', 'MultiHeadAttention', 'SimpleRNN', 'SimpleRNNCell'
+ "AdditiveAttention", "Attention", "ConvLSTM2D", "GRU", "LSTM",
+ "MultiHeadAttention", "SimpleRNN"
]
-flags.DEFINE_string('layer', 'Dense',
- f'One of {list(LAYER_TO_UNITTEST_CONFIGURATIONS.keys())}.')
+flags.DEFINE_string("layer", None,
+ f"One of {list(LAYERS_TO_UNIT_TEST_SPECS.keys())}.")
flags.DEFINE_bool(
- 'dynamic_batch', False,
- 'Whether or not to compile the layer with a dynamic batch size.')
-flags.DEFINE_bool('training', False,
- 'Whether or not to compile the layer in training mode.')
+ "dynamic_dims", False,
+ "Whether or not to compile the layer with dynamic dimension sizes.")
+flags.DEFINE_bool("training", False,
+ "Whether or not to compile the layer in training mode.")
flags.DEFINE_bool(
- 'test_full_api', False,
- 'Whether or not to test multiple layer configurations using non-required '
- 'kwargs.')
+ "test_default_kwargs_only", True,
+ "Whether or not to test only the layer configuration that uses default "
+ "values for all non-required kwargs.")
flags.DEFINE_bool(
- 'list_layers_with_full_api_tests', False,
- 'Whether or not to print out all layers with non-default configurations '
- '(and skip running the tests).')
-
-
-def get_configs() -> Dict[str, Config]:
- """Gets the configs that we want to test for FLAGS.layer."""
- configs = LAYER_TO_UNITTEST_CONFIGURATIONS[FLAGS.layer]
- if not FLAGS.test_full_api:
- return {'default': configs['default']}
- return configs # pytype: disable=bad-return-type
+ "list_layers_with_full_api_tests", False,
+ "Whether or not to print out all layers with non-default configurations "
+ "(and skip running the tests).")
def get_input(shape: Sequence[int]) -> tf.keras.layers.Input:
- """Gets the input shape(s) that we want to test."""
- batch_size = None if FLAGS.dynamic_batch else shape[0]
+ """Converts a shape into a tf.keras.Input."""
+ # Most keras layers are only compatible with dynamic batch sizes.
+ batch_size = None if FLAGS.dynamic_dims else shape[0]
return tf.keras.layers.Input(batch_size=batch_size, shape=shape[1:])
@@ -551,56 +485,76 @@
return list(args) if isinstance(args, tuple) else args
-def create_wrapped_keras_layer(config: Config) -> tf.keras.Model:
+def create_wrapped_keras_layer(
+ layer: str, unit_test_spec: tf_test_utils.UnitTestSpec) -> tf.keras.Model:
"""Wraps a keras layer in a model for compilation."""
- layer_class = getattr(tf.keras.layers, FLAGS.layer)
+ layer_class = getattr(tf.keras.layers, layer)
- if FLAGS.training and FLAGS.layer in DROPOUT_LAYERS:
- config.kwargs['dropout'] = DROPOUT
+ kwargs = copy.deepcopy(unit_test_spec.kwargs)
+ if FLAGS.training and layer in DROPOUT_LAYERS:
+ kwargs["dropout"] = DROPOUT
- inputs = keras_input_normalizer([get_input(shape) for shape in config.shapes])
- if FLAGS.layer == 'MultiHeadAttention':
+ if "dtype" not in unit_test_spec.kwargs:
+ kwargs["dtype"] = unit_test_spec.input_signature[0].dtype
+
+ inputs = keras_input_normalizer(
+ [get_input(spec.shape) for spec in unit_test_spec.input_signature])
+ if layer == "MultiHeadAttention":
# TODO(meadowlark): Remove specialization if API changes.
- outputs = layer_class(**config.kwargs)(*inputs)
+ outputs = layer_class(**kwargs)(*inputs)
else:
- outputs = layer_class(**config.kwargs)(inputs)
+ outputs = layer_class(**kwargs)(inputs)
return tf.keras.Model(inputs, outputs)
-def create_tf_function_unit_test(config: Config, exported_name: str,
- model: tf.keras.Model) -> tf.function:
+def create_layer_unit_test(
+ model: tf.keras.Model,
+ unit_test_spec: tf_test_utils.UnitTestSpec) -> tf.function:
"""Wrap the model's __call__ function in a tf.function for testing."""
- input_shapes = config.shapes
- if FLAGS.dynamic_batch:
- input_shapes = [[None] + shape[1:] for shape in input_shapes]
+ static_signature = unit_test_spec.input_signature
- input_signature = [tf.TensorSpec(shape) for shape in input_shapes]
- if len(input_signature) > 1:
- input_signature = [input_signature]
+ dynamic_signature = static_signature
+ if FLAGS.dynamic_dims:
+ dynamic_signature = tf_utils.apply_function(dynamic_signature,
+ tf_utils.make_dims_dynamic)
+
+ if len(static_signature) > 1:
+ static_signature = [static_signature]
+ dynamic_signature = [dynamic_signature]
call = lambda *args: model(keras_arg_wrapper(*args), training=FLAGS.training)
- return tf_test_utils.tf_function_unit_test(input_signature=input_signature,
- name=exported_name)(call)
+ return tf_test_utils.tf_function_unit_test(
+ input_signature=dynamic_signature,
+ static_signature=static_signature,
+ input_generator=unit_test_spec.input_generator,
+ input_args=unit_test_spec.input_args,
+ name=unit_test_spec.unit_test_name)(call)
class KerasLayersModule(tf_test_utils.TestModule):
- @classmethod
- def configure_class(cls):
- """Configure each tf_function_unit_test and define it on the cls."""
- for i, (exported_name, config) in enumerate(get_configs().items()):
- model = create_wrapped_keras_layer(config)
- setattr(cls, exported_name,
- create_tf_function_unit_test(config, exported_name, model))
-
def __init__(self):
super().__init__()
self.models = []
- for i, (exported_name, config) in enumerate(get_configs().items()):
- model = create_wrapped_keras_layer(config)
+ for unit_test_spec in LAYERS_TO_UNIT_TEST_SPECS[FLAGS.layer]:
+ if (FLAGS.test_default_kwargs_only and
+ get_non_default_kwargs(FLAGS.layer, unit_test_spec)):
+ # Skip all UnitTestSpecs with non-default unrequired kwargs.
+ continue
+
+ if (unsupported_by_tf(FLAGS.layer, unit_test_spec) or
+ has_mutually_exclusive_kwargs(FLAGS.layer, unit_test_spec)):
+ # Filter out UnitTestSpecs with kwargs that TensorFlow can't run on
+ # CPU or that are mutually exclusive. This allows us to take a product
+ # like that in CONV_KWARGS_TO_VALUE and filter out the configurations
+ # lacking support for particular layers.
+ continue
+
+ model = create_wrapped_keras_layer(FLAGS.layer, unit_test_spec)
+ # IREE requires that the models are stored on the module instance.
self.models.append(model)
- setattr(self, exported_name,
- create_tf_function_unit_test(config, exported_name, model))
+ layer_unit_test = create_layer_unit_test(model, unit_test_spec)
+ setattr(self, unit_test_spec.unit_test_name, layer_unit_test)
class KerasLayersTest(tf_test_utils.TracedModuleTestCase):
@@ -614,33 +568,36 @@
def main(argv):
del argv # Unused.
- if hasattr(tf, 'enable_v2_behavior'):
+ if hasattr(tf, "enable_v2_behavior"):
tf.enable_v2_behavior()
- if FLAGS.layer not in LAYER_TO_UNITTEST_CONFIGURATIONS:
- raise ValueError(f"Unrecognized layer: '{FLAGS.layer}'.")
-
if FLAGS.list_layers_with_full_api_tests:
- for layer, configs in sorted(LAYER_TO_UNITTEST_CONFIGURATIONS.items()):
- if len(configs) > 1:
+ for layer, unit_test_specs in sorted(LAYERS_TO_UNIT_TEST_SPECS.items()):
+ if len(unit_test_specs) > 1:
print(f' "{layer}",')
return
- # Set up name for saving artifacts.
- dynamic_batch_str = 'dynamic_batch' if FLAGS.dynamic_batch else 'static_batch'
- training_str = 'training' if FLAGS.training else 'non_training'
- full_api_str = 'full_api' if FLAGS.test_full_api else 'default_api'
- settings_str = f'{full_api_str}_{dynamic_batch_str}_{training_str}'
- KerasLayersModule.__name__ = os.path.join('keras_layers', FLAGS.layer,
- settings_str)
+ if FLAGS.layer not in LAYERS_TO_UNIT_TEST_SPECS:
+ raise ValueError(f"Unrecognized layer: '{FLAGS.layer}'")
- # Use the configurations for FLAGS.layer to add the tf.functions we wish
- # to test to the KerasLayersModule, and then generate unittests for each of
- # them.
- KerasLayersModule.configure_class()
+ # Set up name for saving artifacts.
+ dynamic_str = "dynamic" if FLAGS.dynamic_dims else "static"
+ training_str = "training" if FLAGS.training else "non_training"
+ full_api_str = "default_api" if FLAGS.test_default_kwargs_only else "full_api"
+ settings_str = f"{full_api_str}_{dynamic_str}_{training_str}"
+ relative_artifacts_dir = os.path.join("tf", "keras", "layers", FLAGS.layer,
+ settings_str)
+ # The relative artifacts directory path is calculated from the module name.
+ # TODO(meadowlark): provide a better way of overriding this default.
+ KerasLayersModule.__name__ = relative_artifacts_dir
+
+ unit_tests = KerasLayersModule.get_tf_function_unit_tests()
+ logging.info("Testing the following %s functions: %s", len(unit_tests),
+ unit_tests)
+
KerasLayersTest.generate_unit_tests(KerasLayersModule)
tf.test.main()
-if __name__ == '__main__':
+if __name__ == "__main__":
app.run(main)
diff --git a/integrations/tensorflow/e2e/math/BUILD b/integrations/tensorflow/e2e/math/BUILD
index cdeed23..1a1c46d 100644
--- a/integrations/tensorflow/e2e/math/BUILD
+++ b/integrations/tensorflow/e2e/math/BUILD
@@ -13,8 +13,8 @@
# limitations under the License.
# Test coverage across backends for e2e tests is defined directly in the BUILD
-# files. A coverage table generated from this file can be viewed here:
-# https://google.github.io/iree/tf-e2e-coverage
+# files. Coverage tables generated from this file can be viewed here:
+# https://google.github.io/iree/tensorflow-coverage/tf-base-coverage
# Updates made to test suite names should also be reflected here:
# https://github.com/google/iree/blob/main/scripts/update_e2e_coverage.py
diff --git a/integrations/tensorflow/e2e/slim_vision_models/BUILD b/integrations/tensorflow/e2e/slim_vision_models/BUILD
index 881aff3..f871fe2 100644
--- a/integrations/tensorflow/e2e/slim_vision_models/BUILD
+++ b/integrations/tensorflow/e2e/slim_vision_models/BUILD
@@ -13,8 +13,8 @@
# limitations under the License.
# Test coverage across backends for e2e tests is defined directly in the BUILD
-# files. A coverage table generated from this file can be viewed here:
-# https://google.github.io/iree/tf-e2e-coverage
+# files. Coverage tables generated from this file can be viewed here:
+# https://google.github.io/iree/tensorflow-coverage/vision-coverage
# Updates made to test suite names should also be reflected here:
# https://github.com/google/iree/blob/main/scripts/update_e2e_coverage.py
diff --git a/iree/base/BUILD b/iree/base/BUILD
index ead6d94..e6e634d 100644
--- a/iree/base/BUILD
+++ b/iree/base/BUILD
@@ -14,7 +14,7 @@
# Common types and utilities used in the IREE codebase.
-load("//build_tools/embed_data:build_defs.bzl", "cc_embed_data")
+load("//build_tools/bazel:iree_flatcc.bzl", "iree_flatbuffer_c_library")
package(
default_visibility = ["//visibility:public"],
@@ -22,33 +22,55 @@
licenses = ["notice"], # Apache 2.0
)
-cc_library(
- name = "alignment",
- hdrs = ["alignment.h"],
- deps = [
- ":target_platform",
- ],
-)
+#===------------------------------------------------------------------------===#
+# Public API
+#===------------------------------------------------------------------------===#
cc_library(
name = "api",
- srcs = [
- "api.c",
- ],
+ srcs = ["api.c"],
hdrs = ["api.h"],
visibility = ["//visibility:public"],
deps = [
- ":api_hdrs",
- ":target_platform",
+ ":core_headers",
":tracing",
],
)
+#===------------------------------------------------------------------------===#
+# Core headers (platform detection, compiler compat, etc)
+#===------------------------------------------------------------------------===#
+
cc_library(
- name = "api_hdrs",
- hdrs = ["api.h"],
+ name = "core_headers",
+ hdrs = [
+ "alignment.h",
+ "atomics.h",
+ "bitfield.h",
+ "math.h",
+ "memory.h",
+ "target_platform.h",
+ ],
+ deps = [
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/types:span", # bitfield.h
+ ],
)
+cc_test(
+ name = "bitfield_test",
+ srcs = ["bitfield_test.cc"],
+ deps = [
+ ":core_headers",
+ "//iree/testing:gtest",
+ "//iree/testing:gtest_main",
+ ],
+)
+
+#===------------------------------------------------------------------------===#
+# Internal IREE C++ wrappers and utilities
+#===------------------------------------------------------------------------===#
+
cc_library(
name = "arena",
srcs = ["arena.cc"],
@@ -71,26 +93,20 @@
)
cc_library(
- name = "atomics",
- hdrs = ["atomics.h"],
+ name = "atomic_slist",
+ srcs = ["atomic_slist.c"],
+ hdrs = ["atomic_slist.h"],
deps = [
- ":target_platform",
- ],
-)
-
-cc_library(
- name = "bitfield",
- hdrs = ["bitfield.h"],
- deps = [
- "@com_google_absl//absl/types:span",
+ ":core_headers",
+ ":synchronization",
],
)
cc_test(
- name = "bitfield_test",
- srcs = ["bitfield_test.cc"],
+ name = "atomic_slist_test",
+ srcs = ["atomic_slist_test.cc"],
deps = [
- ":bitfield",
+ ":atomic_slist",
"//iree/testing:gtest",
"//iree/testing:gtest_main",
],
@@ -103,61 +119,25 @@
"dynamic_library_win32.cc",
],
hdrs = ["dynamic_library.h"],
- linkopts = select({
- "//iree:iree_is_msvc": [],
- "//conditions:default": [
- "-ldl",
- ],
- }),
deps = [
+ ":core_headers",
":file_path",
":logging",
":status",
- ":target_platform",
":tracing",
+ "//build_tools:default_linkopts",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)
-cc_binary(
- name = "dynamic_library_test_library.so",
- testonly = True,
- srcs = ["dynamic_library_test_library.cc"],
- linkshared = True,
-)
-
-cc_embed_data(
- name = "dynamic_library_test_library",
- testonly = True,
- srcs = [":dynamic_library_test_library.so"],
- cc_file_output = "dynamic_library_test_library_embed.cc",
- cpp_namespace = "iree",
- flatten = True,
- h_file_output = "dynamic_library_test_library_embed.h",
-)
-
-cc_test(
- name = "dynamic_library_test",
- srcs = ["dynamic_library_test.cc"],
- deps = [
- ":dynamic_library",
- ":dynamic_library_test_library",
- ":file_io",
- ":status",
- ":target_platform",
- "//iree/testing:gtest",
- "//iree/testing:gtest_main",
- ],
-)
-
cc_library(
name = "file_io",
hdrs = ["file_io.h"],
deps = [
+ ":core_headers",
":status",
- ":target_platform",
"//iree/base/internal:file_io_internal",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
@@ -233,39 +213,32 @@
)
cc_library(
- name = "flatbuffer_util",
- srcs = ["flatbuffer_util.cc"],
- hdrs = ["flatbuffer_util.h"],
+ name = "flags",
+ srcs = ["flags.cc"],
+ hdrs = ["flags.h"],
deps = [
- ":file_mapping",
- ":memory",
- ":ref_ptr",
- ":status",
- ":tracing",
- "@com_github_google_flatbuffers//:flatbuffers",
- "@com_google_absl//absl/memory",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:optional",
- "@com_google_absl//absl/types:span",
- ],
-)
-
-cc_library(
- name = "init",
- srcs = ["init.cc"],
- hdrs = ["init.h"],
- deps = [
- ":initializer",
+ ":api",
"@com_google_absl//absl/flags:parse",
],
)
cc_library(
- name = "initializer",
- srcs = ["initializer.cc"],
- hdrs = ["initializer.h"],
+ name = "flatcc",
+ hdrs = ["flatcc.h"],
deps = [
- ":target_platform",
+ ":flatcc_dummy",
+ "@com_github_dvidelabs_flatcc//:runtime",
+ ],
+)
+
+iree_flatbuffer_c_library(
+ name = "flatcc_dummy",
+ srcs = ["flatcc.fbs"],
+ flatcc_args = [
+ "--reader",
+ "--builder",
+ "--verifier",
+ "--json",
],
)
@@ -317,24 +290,8 @@
],
hdrs = ["main.h"],
deps = [
+ ":core_headers",
":logging",
- ":target_platform",
- ],
-)
-
-cc_library(
- name = "math",
- hdrs = ["math.h"],
- deps = [
- "@com_google_absl//absl/base:core_headers",
- ],
-)
-
-cc_library(
- name = "memory",
- hdrs = ["memory.h"],
- deps = [
- "@com_google_absl//absl/types:span",
],
)
@@ -404,16 +361,9 @@
name = "synchronization",
srcs = ["synchronization.c"],
hdrs = ["synchronization.h"],
- linkopts = select({
- "//iree:iree_is_msvc": [],
- "//conditions:default": [
- "-lpthread",
- ],
- }),
deps = [
":api",
- ":atomics",
- ":target_platform",
+ ":core_headers",
":tracing",
],
)
@@ -421,9 +371,6 @@
cc_test(
name = "synchronization_benchmark",
srcs = ["synchronization_benchmark.cc"],
- tags = [
- "nowindows", # TODO(#3615) make this link on windows
- ],
deps = [
":synchronization",
"//iree/testing:benchmark_main",
@@ -434,9 +381,6 @@
cc_test(
name = "synchronization_test",
srcs = ["synchronization_test.cc"],
- tags = [
- "nowindows", # TODO(#3615) make this link on windows
- ],
deps = [
":synchronization",
"//iree/testing:gtest",
@@ -445,11 +389,6 @@
)
cc_library(
- name = "target_platform",
- hdrs = ["target_platform.h"],
-)
-
-cc_library(
name = "time",
hdrs = ["time.h"],
deps = [
@@ -477,31 +416,18 @@
"threading_win32.c",
],
hdrs = ["threading.h"],
- copts = [
- "-D_GNU_SOURCE=1",
- ],
- linkopts = select({
- "//iree:iree_is_msvc": [],
- "//conditions:default": [
- "-ldl",
- "-lpthread",
- ],
- }),
deps = [
":api",
- ":atomics",
+ ":core_headers",
":synchronization",
- ":target_platform",
":tracing",
+ "//build_tools:default_linkopts",
],
)
cc_test(
name = "threading_benchmark",
srcs = ["threading_benchmark.cc"],
- tags = [
- "nowindows", # TODO(#3615) make this link on windows
- ],
deps = [
":threading",
"//iree/testing:benchmark_main",
@@ -515,9 +441,6 @@
"threading_impl.h",
"threading_test.cc",
],
- tags = [
- "nowindows", # TODO(#3615) make this link on windows
- ],
deps = [
":synchronization",
":threading",
@@ -539,6 +462,7 @@
srcs = [
"wait_handle.c",
"wait_handle_epoll.c",
+ "wait_handle_impl.h",
"wait_handle_kqueue.c",
"wait_handle_poll.c",
"wait_handle_posix.c",
@@ -546,12 +470,9 @@
"wait_handle_win32.c",
],
hdrs = ["wait_handle.h"],
- copts = [
- "-D_GNU_SOURCE=1",
- ],
deps = [
":api",
- ":target_platform",
+ ":core_headers",
":tracing",
],
)
diff --git a/iree/base/CMakeLists.txt b/iree/base/CMakeLists.txt
index da6a9b0..c45233e 100644
--- a/iree/base/CMakeLists.txt
+++ b/iree/base/CMakeLists.txt
@@ -12,17 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-iree_add_all_subdirs()
+# bazel_to_cmake: DO NOT EDIT (tracing rule move)
-iree_cc_library(
- NAME
- alignment
- HDRS
- "alignment.h"
- DEPS
- ::target_platform
- PUBLIC
-)
+iree_add_all_subdirs()
iree_cc_library(
NAME
@@ -32,20 +24,38 @@
SRCS
"api.c"
DEPS
- ::api_hdrs
- ::target_platform
+ ::core_headers
::tracing
PUBLIC
)
iree_cc_library(
NAME
- api_hdrs
+ core_headers
HDRS
- "api.h"
+ "alignment.h"
+ "atomics.h"
+ "bitfield.h"
+ "math.h"
+ "memory.h"
+ "target_platform.h"
+ DEPS
+ absl::core_headers
+ absl::span
PUBLIC
)
+iree_cc_test(
+ NAME
+ bitfield_test
+ SRCS
+ "bitfield_test.cc"
+ DEPS
+ ::core_headers
+ iree::testing::gtest
+ iree::testing::gtest_main
+)
+
iree_cc_library(
NAME
arena
@@ -73,41 +83,28 @@
iree_cc_library(
NAME
- atomics
+ atomic_slist
HDRS
- "atomics.h"
+ "atomic_slist.h"
+ SRCS
+ "atomic_slist.c"
DEPS
- ::target_platform
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- bitfield
- HDRS
- "bitfield.h"
- DEPS
- absl::span
+ ::core_headers
+ ::synchronization
PUBLIC
)
iree_cc_test(
NAME
- bitfield_test
+ atomic_slist_test
SRCS
- "bitfield_test.cc"
+ "atomic_slist_test.cc"
DEPS
- ::bitfield
- absl::core_headers
+ ::atomic_slist
iree::testing::gtest
iree::testing::gtest_main
)
-iree_select_compiler_opts(_DYNAMIC_LIBRARY_LINKOPTS
- CLANG_OR_GCC
- "-ldl"
-)
-
iree_cc_library(
NAME
dynamic_library
@@ -116,13 +113,11 @@
SRCS
"dynamic_library_posix.cc"
"dynamic_library_win32.cc"
- LINKOPTS
- ${_DYNAMIC_LIBRARY_LINKOPTS}
DEPS
+ ::core_headers
::file_path
::logging
::status
- ::target_platform
::tracing
absl::memory
absl::span
@@ -130,59 +125,14 @@
PUBLIC
)
-# TODO(scotttodd): clean up bazel_to_cmake handling here
-# * this is a cc_binary in Bazel, but `linkshared` fits iree_cc_library better
-# * the output file name is platform-specific, get it with $<TARGET_FILE:>
-iree_cc_library(
- NAME
- dynamic_library_test_library.so
- OUT
- dynamic_library_test_library.so
- SRCS
- "dynamic_library_test_library.cc"
- TESTONLY
- SHARED
-)
-
-iree_cc_embed_data(
- NAME
- dynamic_library_test_library
- GENERATED_SRCS
- "$<TARGET_FILE:iree::base::dynamic_library_test_library.so>"
- CC_FILE_OUTPUT
- "dynamic_library_test_library_embed.cc"
- H_FILE_OUTPUT
- "dynamic_library_test_library_embed.h"
- TESTONLY
- CPP_NAMESPACE
- "iree"
- FLATTEN
- PUBLIC
-)
-
-iree_cc_test(
- NAME
- dynamic_library_test
- SRCS
- "dynamic_library_test.cc"
- DEPS
- ::dynamic_library
- ::dynamic_library_test_library
- ::file_io
- ::status
- ::target_platform
- iree::testing::gtest
- iree::testing::gtest_main
-)
-
iree_cc_library(
NAME
file_io
HDRS
"file_io.h"
DEPS
+ ::core_headers
::status
- ::target_platform
absl::memory
absl::span
absl::strings
@@ -268,47 +218,38 @@
iree_cc_library(
NAME
- flatbuffer_util
+ flags
HDRS
- "flatbuffer_util.h"
+ "flags.h"
SRCS
- "flatbuffer_util.cc"
+ "flags.cc"
DEPS
- ::file_mapping
- ::memory
- ::ref_ptr
- ::status
- ::tracing
- absl::memory
- absl::optional
- absl::span
- absl::strings
- flatbuffers
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- init
- HDRS
- "init.h"
- SRCS
- "init.cc"
- DEPS
+ ::api
absl::flags_parse
- ::initializer
PUBLIC
)
iree_cc_library(
NAME
- initializer
+ flatcc
HDRS
- "initializer.h"
- SRCS
- "initializer.cc"
+ "flatcc.h"
DEPS
- ::target_platform
+ ::flatcc_dummy
+ flatcc::runtime
+ PUBLIC
+)
+
+flatbuffer_c_library(
+ NAME
+ flatcc_dummy
+ SRCS
+ "flatcc.fbs"
+ FLATCC_ARGS
+ "--reader"
+ "--builder"
+ "--verifier"
+ "--json"
PUBLIC
)
@@ -363,28 +304,8 @@
"main_posix.cc"
"main_win32.cc"
DEPS
+ ::core_headers
::logging
- ::target_platform
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- math
- HDRS
- "math.h"
- DEPS
- absl::core_headers
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- memory
- HDRS
- "memory.h"
- DEPS
- absl::span
PUBLIC
)
@@ -459,16 +380,6 @@
iree::testing::gtest_main
)
-if(NOT ANDROID)
- iree_select_compiler_opts(_SYNCHRONIZATION_LINKOPTS
- CLANG_OR_GCC
- "-lpthread"
- )
-else()
- # Android provides its own pthreads support with no linking required.
- set(_SYNCHRONIZATION_LINKOPTS "")
-endif()
-
iree_cc_library(
NAME
synchronization
@@ -476,12 +387,9 @@
"synchronization.h"
SRCS
"synchronization.c"
- LINKOPTS
- ${_SYNCHRONIZATION_LINKOPTS}
DEPS
::api
- ::atomics
- ::target_platform
+ ::core_headers
::tracing
PUBLIC
)
@@ -510,14 +418,6 @@
iree_cc_library(
NAME
- target_platform
- HDRS
- "target_platform.h"
- PUBLIC
-)
-
-iree_cc_library(
- NAME
time
HDRS
"time.h"
@@ -537,20 +437,6 @@
iree::testing::gtest_main
)
-if(NOT ANDROID)
- iree_select_compiler_opts(_THREADING_LINKOPTS
- CLANG_OR_GCC
- "-ldl"
- "-lpthread"
- )
-else()
- iree_select_compiler_opts(_THREADING_LINKOPTS
- CLANG_OR_GCC
- "-ldl"
- # Android provides its own pthreads support with no linking required.
- )
-endif()
-
iree_cc_library(
NAME
threading
@@ -562,15 +448,10 @@
"threading_impl.h"
"threading_pthreads.c"
"threading_win32.c"
- COPTS
- "-D_GNU_SOURCE=1"
- LINKOPTS
- ${_THREADING_LINKOPTS}
DEPS
::api
- ::atomics
+ ::core_headers
::synchronization
- ::target_platform
::tracing
PUBLIC
)
@@ -599,11 +480,7 @@
iree::testing::gtest_main
)
-iree_select_compiler_opts(IREE_LINKOPTS_TRACING
- GCC_OR_CLANG
- -ldl
-)
-
+# TODO(benvanik): redirect to internal/tracing/ or something.
if(${IREE_ENABLE_RUNTIME_TRACING})
iree_cc_library(
NAME
@@ -614,10 +491,8 @@
"${IREE_ROOT_DIR}/third_party/tracy/TracyC.h"
SRCS
"tracing.cc"
- LINKOPTS
- ${IREE_LINKOPTS_TRACING}
DEPS
- ::target_platform
+ ::core_headers
absl::core_headers
DEFINES
# TODO(#2114): Change the mode to 2.
@@ -644,16 +519,15 @@
SRCS
"wait_handle.c"
"wait_handle_epoll.c"
+ "wait_handle_impl.h"
"wait_handle_kqueue.c"
"wait_handle_poll.c"
"wait_handle_posix.c"
"wait_handle_posix.h"
"wait_handle_win32.c"
- COPTS
- "-D_GNU_SOURCE=1"
DEPS
::api
- ::target_platform
+ ::core_headers
::tracing
PUBLIC
)
diff --git a/iree/base/alignment.h b/iree/base/alignment.h
index fe9cfaf..024a334 100644
--- a/iree/base/alignment.h
+++ b/iree/base/alignment.h
@@ -18,12 +18,22 @@
#ifndef IREE_BASE_ALIGNMENT_H_
#define IREE_BASE_ALIGNMENT_H_
+#include <stddef.h>
+
#include "iree/base/target_platform.h"
#ifdef __cplusplus
extern "C" {
#endif
+// https://en.cppreference.com/w/c/types/max_align_t
+#if defined(IREE_PLATFORM_WINDOWS)
+// NOTE: 16 is a specified Microsoft API requirement for some functions.
+#define iree_max_align_t 16
+#else
+#define iree_max_align_t sizeof(long double)
+#endif // IREE_PLATFORM_*
+
// https://en.cppreference.com/w/c/language/_Alignas
// https://en.cppreference.com/w/c/language/_Alignof
#if defined(IREE_COMPILER_MSVC)
diff --git a/iree/base/api.c b/iree/base/api.c
index 809e6d5..47c5d68 100644
--- a/iree/base/api.c
+++ b/iree/base/api.c
@@ -949,11 +949,6 @@
expected_version, actual_version);
}
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_api_init(int* argc,
- char*** argv) {
- return iree_ok_status();
-}
-
//===----------------------------------------------------------------------===//
// iree_time_t and iree_duration_t
//===----------------------------------------------------------------------===//
diff --git a/iree/base/api.h b/iree/base/api.h
index 561af88..3b02694 100644
--- a/iree/base/api.h
+++ b/iree/base/api.h
@@ -737,24 +737,6 @@
iree_api_version_check(iree_api_version_t expected_version,
iree_api_version_t* out_actual_version);
-// Initializes IREE for use within a binary.
-//
-// Specifically, this parses any command line flags and performs module
-// initialization (such as for tracing and dynamic driver registration). If
-// your application is certain it does not need this functionality, this call
-// may be skipped.
-//
-// |argc| and |argv| should contain any command line flags to parse.
-// If there are no flags to parse, nullptr may be passed, but this should still
-// be called so other initialization happens.
-//
-// This should typically be called early in some sort of main() function once,
-// before calling most other API functions. Certain core API functions here
-// such as iree_api_version_check, iree_allocator_malloc, and
-// iree_allocator_free are safe to call before this.
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_api_init(int* argc,
- char*** argv);
-
//===----------------------------------------------------------------------===//
// iree_time_t and iree_duration_t
//===----------------------------------------------------------------------===//
diff --git a/iree/base/atomic_slist.c b/iree/base/atomic_slist.c
new file mode 100644
index 0000000..c495c7e
--- /dev/null
+++ b/iree/base/atomic_slist.c
@@ -0,0 +1,114 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/base/atomic_slist.h"
+
+#include <assert.h>
+
+// TODO(benvanik): add TSAN annotations when switched to atomics:
+// https://github.com/gcc-mirror/gcc/blob/master/libsanitizer/include/sanitizer/tsan_interface_atomic.h
+// https://reviews.llvm.org/D18500
+
+void iree_atomic_slist_initialize(iree_atomic_slist_t* out_list) {
+ memset(out_list, 0, sizeof(*out_list));
+ iree_slim_mutex_initialize(&out_list->mutex);
+}
+
+void iree_atomic_slist_deinitialize(iree_atomic_slist_t* list) {
+ // TODO(benvanik): assert empty.
+ iree_slim_mutex_deinitialize(&list->mutex);
+ memset(list, 0, sizeof(*list));
+}
+
+void iree_atomic_slist_concat(iree_atomic_slist_t* list,
+ iree_atomic_slist_entry_t* head,
+ iree_atomic_slist_entry_t* tail) {
+ if (IREE_UNLIKELY(!head)) return;
+ iree_slim_mutex_lock(&list->mutex);
+ tail->next = list->head;
+ list->head = head;
+ iree_slim_mutex_unlock(&list->mutex);
+}
+
+void iree_atomic_slist_push(iree_atomic_slist_t* list,
+ iree_atomic_slist_entry_t* entry) {
+ iree_slim_mutex_lock(&list->mutex);
+ iree_atomic_slist_push_unsafe(list, entry);
+ iree_slim_mutex_unlock(&list->mutex);
+}
+
+void iree_atomic_slist_push_unsafe(iree_atomic_slist_t* list,
+ iree_atomic_slist_entry_t* entry) {
+ // NOTE: no lock is held here and no atomic operation will be used when this
+ // is actually made atomic.
+ entry->next = list->head;
+ list->head = entry;
+}
+
+iree_atomic_slist_entry_t* iree_atomic_slist_pop(iree_atomic_slist_t* list) {
+ iree_slim_mutex_lock(&list->mutex);
+ iree_atomic_slist_entry_t* entry = list->head;
+ list->head = entry ? entry->next : NULL;
+ iree_slim_mutex_unlock(&list->mutex);
+ return entry;
+}
+
+bool iree_atomic_slist_flush(iree_atomic_slist_t* list,
+ iree_atomic_slist_flush_order_t flush_order,
+ iree_atomic_slist_entry_t** out_head,
+ iree_atomic_slist_entry_t** out_tail) {
+ // Exchange list head with NULL to steal the entire list. The list will be in
+ // the native LIFO order of the slist.
+ iree_slim_mutex_lock(&list->mutex);
+ iree_atomic_slist_entry_t* head = list->head;
+ list->head = NULL;
+ iree_slim_mutex_unlock(&list->mutex);
+ if (!head) return false;
+
+ switch (flush_order) {
+ case IREE_ATOMIC_SLIST_FLUSH_ORDER_APPROXIMATE_LIFO: {
+ // List is already in native LIFO order. If the user wants a tail we have
+ // to scan for it, though, which we really only want to do when required
+ // as it's a linked list pointer walk.
+ *out_head = head;
+ if (out_tail) {
+ iree_atomic_slist_entry_t* p = head;
+ while (p->next) p = p->next;
+ *out_tail = p;
+ }
+ break;
+ }
+ case IREE_ATOMIC_SLIST_FLUSH_ORDER_APPROXIMATE_FIFO: {
+ // Reverse the list in a single scan. |head| is our tail, so scan
+ // forward to find our head. Since we have to walk the whole list anyway
+ // we can cheaply give both the head and tail to the caller.
+ iree_atomic_slist_entry_t* tail = head;
+ if (out_tail) *out_tail = tail;
+ iree_atomic_slist_entry_t* p = head;
+ do {
+ iree_atomic_slist_entry_t* next = p->next;
+ p->next = head;
+ head = p;
+ p = next;
+ } while (p != NULL);
+ tail->next = NULL;
+ *out_head = head;
+ break;
+ }
+ default:
+ return false;
+ }
+
+ return true;
+}
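Reviewer note: the TODO at the top of atomic_slist.c mentions a later switch from the slim mutex to atomics. A minimal sketch of what push and flush might look like with C11 <stdatomic.h> follows. It is an assumption for illustration only, not the IREE implementation; the sketch_* names are hypothetical. Pop is deliberately left out because a lock-free pop needs protection against the ABA problem.

```c
#include <assert.h>
#include <stdatomic.h>
#include <stddef.h>

// Hypothetical stand-ins for iree_atomic_slist_entry_t / iree_atomic_slist_t.
typedef struct sketch_entry_s {
  struct sketch_entry_s* next;
} sketch_entry_t;

typedef struct {
  _Atomic(sketch_entry_t*) head;
} sketch_slist_t;

void sketch_slist_initialize(sketch_slist_t* out_list) {
  atomic_init(&out_list->head, (sketch_entry_t*)NULL);
}

// Pushes |entry| with a compare-exchange retry loop. On failure |expected| is
// refreshed with the current head, so the loop relinks and tries again.
void sketch_slist_push(sketch_slist_t* list, sketch_entry_t* entry) {
  assert(entry);
  sketch_entry_t* expected =
      atomic_load_explicit(&list->head, memory_order_relaxed);
  do {
    entry->next = expected;
  } while (!atomic_compare_exchange_weak_explicit(
      &list->head, &expected, entry, memory_order_release,
      memory_order_relaxed));
}

// Steals the whole list by exchanging the head with NULL. The returned span
// is in the native LIFO order, as in the mutex-based flush above.
sketch_entry_t* sketch_slist_flush(sketch_slist_t* list) {
  return atomic_exchange_explicit(&list->head, (sketch_entry_t*)NULL,
                                  memory_order_acquire);
}
```

The release ordering on push paired with the acquire ordering on flush is one reasonable choice so that a consumer sees the fields written before an entry was published; other orderings may be appropriate depending on the caller.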
diff --git a/iree/base/atomic_slist.h b/iree/base/atomic_slist.h
new file mode 100644
index 0000000..f6fb25b
--- /dev/null
+++ b/iree/base/atomic_slist.h
@@ -0,0 +1,264 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// NOTE: the best kind of synchronization is no synchronization; always try to
+// design your algorithm so that you don't need anything from this file :)
+// See https://travisdowns.github.io/blog/2020/07/06/concurrency-costs.html
+
+#ifndef IREE_BASE_ATOMIC_SLIST_H_
+#define IREE_BASE_ATOMIC_SLIST_H_
+
+#include <stddef.h>
+
+#include "iree/base/alignment.h"
+#include "iree/base/atomics.h"
+#include "iree/base/synchronization.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// The embedded pointer to the next entry in the slist. This points to the
+// internal iree_atomic_slist_entry_t, *not* the user-provided pointer.
+typedef void* iree_atomic_slist_intrusive_ptr_t;
+
+// DO NOT USE: implementation detail.
+typedef struct iree_atomic_slist_entry_s {
+ struct iree_atomic_slist_entry_s* next;
+} iree_atomic_slist_entry_t;
+
+// Lightweight contention-avoiding singly linked list.
+// This models optimistically-ordered LIFO behavior (stack push/pop) using
+// atomic primitives.
+//
+// ***************************************************
+// ******** ONLY APPROXIMATE ORDER GUARANTEES ********
+// ***************************************************
+//
+// This makes it extremely efficient for when only eventual consistency across
+// producers and consumers is required. The most common example is free lists
+// where all that matters is that entries make it into the list and not that
+// they have any particular order between them. Work queues where all tasks
+// within the queue are able to execute in any order like with wavefront-style
+// scheduling can also benefit from this relaxed behavior.
+//
+// If a strict ordering is required this can be used as a primitive to construct
+// a flat-combining data structure where data structure change requests are
+// published to this list and a combiner is chosen to land the published data in
+// an appropriate order:
+// http://people.csail.mit.edu/shanir/publications/Flat%20Combining%20SPAA%2010.pdf
+//
+// There's often still a benefit in unordered scenarios of having LIFO behavior
+// as it promotes cache-friendly small linked lists when there is a small number
+// of producers and consumers (1:1 is the best case), though as the producer and
+// consumer count increases the LIFO behavior can pessimize performance as there
+// is more contention for the list head pointer. Prefer to shard across multiple
+// per-core/thread lists and use techniques like flat-combining for the
+// cross-core/thread aggregation/sequencing.
+//
+// This API is modeled roughly on the Windows SList type:
+// https://docs.microsoft.com/en-us/windows/win32/sync/interlocked-singly-linked-lists
+// which is roughly compatible with the Apple OSAtomic queue:
+// https://developer.apple.com/library/archive/documentation/System/Conceptual/ManPages_iPhoneOS/man3/OSAtomicEnqueue.3.html
+// https://opensource.apple.com/source/libplatform/libplatform-125/include/libkern/OSAtomicQueue.h.auto.html
+//
+// Usage:
+// https://docs.microsoft.com/en-us/windows/win32/sync/using-singly-linked-lists
+//
+// WARNING: this is an extremely sharp pufferfish-esque API. Don't use it. 🐡
+//
+// TODO(benvanik): verify behavior (and worthwhileness) of supporting platform
+// primitives. The benefit of something like OSAtomicEnqueue/Dequeue is that it
+// may have better tooling (TSAN), special intrinsic handling in the compiler,
+// etc. That said, the Windows Interlocked* variants don't seem to. Having a
+// single heavily tested implementation seems more worthwhile than several.
+typedef iree_alignas(iree_max_align_t) struct {
+ // TODO(benvanik): spend some time golfing this. Unblocking myself for now :)
+ iree_slim_mutex_t mutex;
+ iree_atomic_slist_entry_t* head;
+} iree_atomic_slist_t;
+
+// Initializes an slist handle to an empty list.
+// Lists must be flushed to empty and deinitialized when no longer needed with
+// iree_atomic_slist_deinitialize.
+//
+// NOTE: not thread-safe; existing |out_list| contents are discarded.
+void iree_atomic_slist_initialize(iree_atomic_slist_t* out_list);
+
+// Deinitializes an slist.
+// The list must be empty; callers are expected to flush the list from the same
+// thread making this call when it is guaranteed no other thread may be trying
+// to use the list.
+//
+// NOTE: not thread-safe; |list| must not be used by any other thread.
+void iree_atomic_slist_deinitialize(iree_atomic_slist_t* list);
+
+// Concatenates a span of entries into the list in the order they are provided.
+//
+// Example:
+// existing slist: C B A
+// provided span: 1 2 3
+// resulting slist: 1 2 3 C B A
+void iree_atomic_slist_concat(iree_atomic_slist_t* list,
+ iree_atomic_slist_entry_t* head,
+ iree_atomic_slist_entry_t* tail);
+
+// Pushes an entry into the list.
+//
+// existing slist: C B A
+// provided entry: 1
+// resulting slist: 1 C B A
+void iree_atomic_slist_push(iree_atomic_slist_t* list,
+ iree_atomic_slist_entry_t* entry);
+
+// Pushes an entry into the list without using an atomic update.
+// This is useful for when |list| is known to be inaccessible to any other
+// thread, such as when populating a stack-local list prior to sharing it.
+void iree_atomic_slist_push_unsafe(iree_atomic_slist_t* list,
+ iree_atomic_slist_entry_t* entry);
+
+// Pops the most recently pushed entry from the list and returns it.
+// Returns NULL if the list was empty at the time it was queried.
+//
+// existing slist: C B A
+// resulting slist: B A
+// returned entry: C
+iree_atomic_slist_entry_t* iree_atomic_slist_pop(iree_atomic_slist_t* list);
+
+// Defines the approximate order in which a span of flushed entries is returned.
+typedef uint32_t iree_atomic_slist_flush_order_t;
+enum {
+ // |out_head| and |out_tail| will be set to a span of the entries roughly in
+ // the order they were pushed to the list in LIFO (stack) order.
+ //
+ // Example:
+ // slist: C B A
+ // result: C B A (or when contended possibly C A B)
+ IREE_ATOMIC_SLIST_FLUSH_ORDER_APPROXIMATE_LIFO,
+ // |out_head| and |out_tail| will be set to the first and last entries
+ // pushed respectively, turning this LIFO slist into a FIFO queue.
+ //
+ // Example:
+ // slist: C B A
+ // result: A B C (or when contended possibly B A C)
+ IREE_ATOMIC_SLIST_FLUSH_ORDER_APPROXIMATE_FIFO,
+};
+
+// Removes all items from the list and returns them in **APPROXIMATELY** the
+// |flush_order| requested. As there are no order guarantees there may be slight
+// transpositions of entries that were pushed from multiple processors or even
+// interleaved entries within spans of entries pushed with
+// iree_atomic_slist_concat.
+//
+// If |out_tail| is not required it can be omitted and this may avoid the
+// need for the flush to walk the list and touch each entry.
+//
+// Returns true if any items were present and false if the output list is empty.
+// Note that because atomic data structures can race it's possible for there to
+// both be something in the list prior to this call and something in the list
+// after the call and yet the return can still be false.
+bool iree_atomic_slist_flush(iree_atomic_slist_t* list,
+ iree_atomic_slist_flush_order_t flush_order,
+ iree_atomic_slist_entry_t** out_head,
+ iree_atomic_slist_entry_t** out_tail);
+
+//==============================================================================
+// Typed wrapper generator for iree_atomic_slist_t
+//==============================================================================
+
+// Typed and named wrappers for making atomic slists easier to work with.
+//
+// Usage:
+// typedef struct {
+// int some_fields;
+// iree_atomic_slist_intrusive_ptr_t slist_next;
+// int more_fields;
+// } my_type_t;
+// IREE_TYPED_ATOMIC_SLIST_WRAPPER(my_type, my_type_t,
+// offsetof(my_type_t, slist_next));
+//
+// my_type_slist_t list;
+// my_type_slist_initialize(&list);
+// my_type_t* entry = allocate_my_type(123);
+// my_type_slist_push(&list, entry);
+// entry = my_type_slist_pop(&list);
+#define IREE_TYPED_ATOMIC_SLIST_WRAPPER(name, type, next_offset) \
+ static inline iree_atomic_slist_entry_t* name##_slist_entry_from_ptr( \
+ type* entry) { \
+ return entry \
+ ? ((iree_atomic_slist_entry_t*)((uint8_t*)entry + next_offset)) \
+ : NULL; \
+ } \
+ static inline type* name##_slist_entry_to_ptr( \
+ iree_atomic_slist_entry_t* entry) { \
+ return entry ? (type*)(((uint8_t*)entry) - next_offset) : NULL; \
+ } \
+ \
+ static inline type* name##_slist_get_next(type* entry) { \
+ if (!entry) return NULL; \
+ return name##_slist_entry_to_ptr( \
+ ((iree_atomic_slist_entry_t*)((uint8_t*)entry + next_offset))->next); \
+ } \
+ static inline void name##_slist_set_next(type* entry, type* next) { \
+ name##_slist_entry_from_ptr(entry)->next = \
+ name##_slist_entry_from_ptr(next); \
+ } \
+ \
+ typedef iree_alignas(iree_max_align_t) struct { \
+ iree_atomic_slist_t impl; \
+ } name##_slist_t; \
+ \
+ static inline void name##_slist_initialize(name##_slist_t* out_list) { \
+ iree_atomic_slist_initialize(&out_list->impl); \
+ } \
+ static inline void name##_slist_deinitialize(name##_slist_t* list) { \
+ iree_atomic_slist_deinitialize(&list->impl); \
+ } \
+ \
+ static inline void name##_slist_push(name##_slist_t* list, type* entry) { \
+ iree_atomic_slist_push(&list->impl, name##_slist_entry_from_ptr(entry)); \
+ } \
+ static inline void name##_slist_push_unsafe(name##_slist_t* list, \
+ type* entry) { \
+ iree_atomic_slist_push_unsafe(&list->impl, \
+ name##_slist_entry_from_ptr(entry)); \
+ } \
+ static inline void name##_slist_concat(name##_slist_t* list, type* head, \
+ type* tail) { \
+ iree_atomic_slist_concat(&list->impl, name##_slist_entry_from_ptr(head), \
+ name##_slist_entry_from_ptr(tail)); \
+ } \
+ static inline type* name##_slist_pop(name##_slist_t* list) { \
+ return name##_slist_entry_to_ptr(iree_atomic_slist_pop(&list->impl)); \
+ } \
+ \
+ static inline bool name##_slist_flush( \
+ name##_slist_t* list, iree_atomic_slist_flush_order_t flush_order, \
+ type** out_head, type** out_tail) { \
+ iree_atomic_slist_entry_t* head = NULL; \
+ iree_atomic_slist_entry_t* tail = NULL; \
+ if (!iree_atomic_slist_flush(&list->impl, flush_order, &head, \
+ out_tail ? &tail : NULL)) { \
+ return false; /* empty list */ \
+ } \
+ *out_head = name##_slist_entry_to_ptr(head); \
+ if (out_tail) *out_tail = name##_slist_entry_to_ptr(tail); \
+ return true; \
+ }
+
+#ifdef __cplusplus
+} // extern "C"
+#endif
+
+#endif // IREE_BASE_ATOMIC_SLIST_H_
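Reviewer note: IREE_TYPED_ATOMIC_SLIST_WRAPPER converts between a user struct and its embedded link by adding or subtracting the offsetof() of the link field. The conversion can be sketched in isolation as below. All names here (link_t, item_t, item_*) are hypothetical, and the link is embedded as a struct rather than the void* typedef used in the header, which is a simplification made for this sketch.

```c
#include <stddef.h>
#include <stdint.h>

// Hypothetical link type, standing in for iree_atomic_slist_entry_t.
typedef struct link_s {
  struct link_s* next;
} link_t;

// Hypothetical user type with the link at a non-zero offset.
typedef struct {
  int value;
  link_t slist_link;
} item_t;

// User pointer -> embedded link pointer (NULL stays NULL).
link_t* item_to_link(item_t* item) {
  return item ? (link_t*)((uint8_t*)item + offsetof(item_t, slist_link))
              : NULL;
}

// Embedded link pointer -> user pointer (NULL stays NULL).
item_t* item_from_link(link_t* link) {
  return link ? (item_t*)((uint8_t*)link - offsetof(item_t, slist_link))
              : NULL;
}

// Typed accessors mirroring name##_slist_get_next / name##_slist_set_next.
item_t* item_get_next(item_t* item) {
  return item ? item_from_link(item_to_link(item)->next) : NULL;
}

void item_set_next(item_t* item, item_t* next) {
  item_to_link(item)->next = item_to_link(next);
}
```

Keeping the NULL checks in both conversions matters: without them a NULL next pointer would be turned into a small non-NULL address after the offset is applied.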
diff --git a/iree/base/atomic_slist_test.cc b/iree/base/atomic_slist_test.cc
new file mode 100644
index 0000000..8b32f7a
--- /dev/null
+++ b/iree/base/atomic_slist_test.cc
@@ -0,0 +1,191 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/base/atomic_slist.h"
+
+#include "iree/testing/gtest.h"
+
+namespace {
+
+struct dummy_entry_t {
+ // NOTE: we purposefully offset the entry pointer
+ size_t value = 0;
+ iree_atomic_slist_intrusive_ptr_t slist_next = NULL;
+};
+IREE_TYPED_ATOMIC_SLIST_WRAPPER(dummy, dummy_entry_t,
+ offsetof(dummy_entry_t, slist_next));
+
+std::vector<dummy_entry_t> MakeDummySListItems(size_t base_index,
+ size_t count) {
+ std::vector<dummy_entry_t> items(count);
+ for (size_t i = 0; i < count; ++i) {
+ items[i].value = base_index + i;
+ }
+ return items;
+}
+
+TEST(AtomicSList, Lifetime) {
+ iree_atomic_slist_t list; // NOTE: intentionally uninitialized.
+ iree_atomic_slist_initialize(&list);
+ iree_atomic_slist_deinitialize(&list);
+}
+
+TEST(AtomicSList, BasicUsage) {
+ dummy_slist_t list;
+ dummy_slist_initialize(&list);
+
+ // List starts empty.
+ EXPECT_EQ(NULL, dummy_slist_pop(&list));
+
+ // Push some items into the list (LIFO order).
+ // New contents: 5 4 3 2 1 0
+ auto item_storage = MakeDummySListItems(0, 6);
+ for (size_t i = 0; i < item_storage.size(); ++i) {
+ dummy_slist_push(&list, &item_storage[i]);
+ }
+
+ // Now pop them out - they should be in reverse order.
+ // New contents: (empty)
+ for (size_t i = 0; i < item_storage.size(); ++i) {
+ dummy_entry_t* p = dummy_slist_pop(&list);
+ ASSERT_TRUE(p);
+ EXPECT_EQ(item_storage.size() - i - 1, p->value);
+ }
+
+ // List ends empty.
+ EXPECT_EQ(NULL, dummy_slist_pop(&list));
+
+ dummy_slist_deinitialize(&list);
+}
+
+TEST(AtomicSList, Concat) {
+ dummy_slist_t list;
+ dummy_slist_initialize(&list);
+
+ // Push some initial items into the list (LIFO order).
+ // New contents: 1 0
+ auto initial_item_storage = MakeDummySListItems(0, 2);
+ for (size_t i = 0; i < initial_item_storage.size(); ++i) {
+ dummy_slist_push(&list, &initial_item_storage[i]);
+ }
+
+ // Stitch items together modeling what a user may do when building the list
+ // themselves.
+ // Items: 2 3 4
+ auto span_item_storage = MakeDummySListItems(2, 3);
+ for (size_t i = 0; i < span_item_storage.size() - 1; ++i) {
+ dummy_slist_set_next(&span_item_storage[i], &span_item_storage[i + 1]);
+ }
+
+ // Push all of the items to the list at once.
+ // New contents: 2 3 4 1 0
+ dummy_slist_concat(&list, &span_item_storage.front(),
+ &span_item_storage.back());
+
+ // Pop the span items and verify they are in the correct order: we effectively
+ // pushed them such that popping is FIFO (2->4).
+ // New contents: 1 0
+ for (size_t i = 0; i < span_item_storage.size(); ++i) {
+ dummy_entry_t* p = dummy_slist_pop(&list);
+ ASSERT_TRUE(p);
+ EXPECT_EQ(/*base_index=*/2 + i, p->value);
+ }
+
+ // Pop the initial items and ensure they survived.
+ // New contents: (empty)
+ for (size_t i = 0; i < initial_item_storage.size(); ++i) {
+ dummy_entry_t* p = dummy_slist_pop(&list);
+ ASSERT_TRUE(p);
+ EXPECT_EQ(initial_item_storage.size() - i - 1, p->value);
+ }
+
+ dummy_slist_deinitialize(&list);
+}
+
+TEST(AtomicSList, FlushLIFO) {
+ dummy_slist_t list;
+ dummy_slist_initialize(&list);
+
+ // Flushing when empty is ok.
+ dummy_entry_t* head = NULL;
+ dummy_entry_t* tail = NULL;
+ EXPECT_FALSE(dummy_slist_flush(
+ &list, IREE_ATOMIC_SLIST_FLUSH_ORDER_APPROXIMATE_LIFO, &head, &tail));
+
+ // Push items into the list (LIFO order).
+ // New contents: 3 2 1 0
+ auto item_storage = MakeDummySListItems(0, 4);
+ for (size_t i = 0; i < item_storage.size(); ++i) {
+ dummy_slist_push(&list, &item_storage[i]);
+ }
+
+ // Flush in LIFO order and verify empty.
+ // New contents: (empty)
+ EXPECT_TRUE(dummy_slist_flush(
+ &list, IREE_ATOMIC_SLIST_FLUSH_ORDER_APPROXIMATE_LIFO, &head, &tail));
+ EXPECT_EQ(NULL, dummy_slist_pop(&list));
+
+ // Verify LIFO order and list pointer walking.
+ // Note that head and tail are reverse of item storage!
+ EXPECT_EQ(&item_storage.back(), head);
+ EXPECT_EQ(&item_storage.front(), tail);
+ dummy_entry_t* p = head;
+ for (size_t i = 0; i < item_storage.size(); ++i) {
+ ASSERT_TRUE(p);
+ EXPECT_EQ(item_storage.size() - i - 1, p->value);
+ p = dummy_slist_get_next(p);
+ }
+ EXPECT_EQ(NULL, p);
+
+ dummy_slist_deinitialize(&list);
+}
+
+TEST(AtomicSList, FlushFIFO) {
+ dummy_slist_t list;
+ dummy_slist_initialize(&list);
+
+ // Flushing when empty is ok.
+ dummy_entry_t* head = NULL;
+ dummy_entry_t* tail = NULL;
+ EXPECT_FALSE(dummy_slist_flush(
+ &list, IREE_ATOMIC_SLIST_FLUSH_ORDER_APPROXIMATE_FIFO, &head, &tail));
+
+ // Push items into the list (LIFO order).
+ // New contents: 3 2 1 0
+ auto item_storage = MakeDummySListItems(0, 4);
+ for (size_t i = 0; i < item_storage.size(); ++i) {
+ dummy_slist_push(&list, &item_storage[i]);
+ }
+
+ // Flush in FIFO order and verify empty.
+ // New contents: (empty)
+ EXPECT_TRUE(dummy_slist_flush(
+ &list, IREE_ATOMIC_SLIST_FLUSH_ORDER_APPROXIMATE_FIFO, &head, &tail));
+ EXPECT_EQ(NULL, dummy_slist_pop(&list));
+
+ // Verify FIFO order and list pointer walking.
+ EXPECT_EQ(&item_storage.front(), head);
+ EXPECT_EQ(&item_storage.back(), tail);
+ dummy_entry_t* p = head;
+ for (size_t i = 0; i < item_storage.size(); ++i) {
+ ASSERT_TRUE(p);
+ EXPECT_EQ(i, p->value);
+ p = dummy_slist_get_next(p);
+ }
+ EXPECT_EQ(NULL, p);
+
+ dummy_slist_deinitialize(&list);
+}
+
+} // namespace
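Reviewer note: the FlushFIFO expectations above (head is the first item pushed, tail is the last) follow from reversing the stolen LIFO list in place. A standalone sketch of that reversal is shown below. The rev_* names are hypothetical and the loop is written slightly differently from the one in atomic_slist.c, so treat it as an illustration of the technique rather than the code under test.

```c
#include <stddef.h>

// Hypothetical entry type carrying a value for inspection.
typedef struct rev_entry_s {
  struct rev_entry_s* next;
  int value;
} rev_entry_t;

// Reverses a NULL-terminated list in place in a single pass.
// Returns the new head. If |out_tail| is provided it receives the new tail,
// which is simply the old head.
rev_entry_t* sketch_reverse(rev_entry_t* head, rev_entry_t** out_tail) {
  if (out_tail) *out_tail = head;
  rev_entry_t* reversed = NULL;
  while (head) {
    rev_entry_t* next = head->next;
    head->next = reversed;  // Point the current entry back at the prior one.
    reversed = head;
    head = next;
  }
  return reversed;
}
```

Because the walk already touches every entry, returning the tail costs nothing extra, which matches the comment in the FIFO branch of iree_atomic_slist_flush.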
diff --git a/iree/base/flags.cc b/iree/base/flags.cc
new file mode 100644
index 0000000..4edd463
--- /dev/null
+++ b/iree/base/flags.cc
@@ -0,0 +1,54 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/base/flags.h"
+
+#include <stdlib.h>
+#include <string.h>
+
+// TODO(#3814): replace abseil with pretty much anything else.
+#include "absl/flags/parse.h"
+
+iree_status_t iree_flags_parse(int* argc, char*** argv) {
+ if (argc == nullptr || argv == nullptr || *argc == 0) {
+ // No flags; that's fine - in some environments flags aren't supported.
+ return iree_ok_status();
+ }
+
+ auto positional_args = absl::ParseCommandLine(*argc, *argv);
+ if (positional_args.size() < static_cast<size_t>(*argc)) {
+ // Edit the passed argument refs to only include positional args.
+ *argc = static_cast<int>(positional_args.size());
+ for (int i = 0; i < *argc; ++i) {
+ (*argv)[i] = positional_args[i];
+ }
+ (*argv)[*argc] = nullptr;
+ }
+
+ return iree_ok_status();
+}
+
+void iree_flags_parse_checked(int* argc, char*** argv) {
+ iree_status_t status = iree_flags_parse(argc, argv);
+ if (iree_status_is_cancelled(status)) {
+ exit(EXIT_SUCCESS);
+ return;
+ }
+ if (!iree_status_is_ok(status)) {
+ // TODO(#2843): replace C++ logging.
+ iree_status_ignore(status);
+ exit(EXIT_FAILURE);
+ return;
+ }
+}
diff --git a/iree/base/flags.h b/iree/base/flags.h
new file mode 100644
index 0000000..9368f68
--- /dev/null
+++ b/iree/base/flags.h
@@ -0,0 +1,59 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_BASE_FLAGS_H_
+#define IREE_BASE_FLAGS_H_
+
+#include "iree/base/api.h"
+
+//===----------------------------------------------------------------------===//
+// Flag parsing
+//===----------------------------------------------------------------------===//
+
+// Parses flags from the given command line arguments.
+// All flag-style arguments ('--foo', '-f', etc) will be consumed and argc/argv
+// will be updated to contain only the program name (index 0) and any remaining
+// positional arguments.
+//
+// Returns success if all flags were parsed and execution should continue.
+// May return IREE_STATUS_CANCELLED if execution should be cancelled gracefully
+// such as when --help is used.
+//
+// Usage:
+// extern "C" int main(int argc, char** argv) {
+// iree_status_t status = iree_flags_parse(&argc, &argv);
+// if (iree_status_is_cancelled(status)) return 0;
+// if (!iree_status_is_ok(status)) {
+// // TODO(#2843): replace C++ logging.
+// LOG(ERROR) << status;
+// iree_status_ignore(status);
+// return 1;
+// }
+// consume_positional_args(argc, argv);
+// return 0;
+// }
+//
+// Example:
+// argc = 4, argv = ['program', 'abc', '--flag=2', '-p']
+// Results:
+// argc = 2, argv = ['program', 'abc']
+iree_status_t iree_flags_parse(int* argc, char*** argv);
+
+// Parses flags as with iree_flags_parse but will use exit() or abort().
+// WARNING: this is almost always what you want in a command line tool and *never*
+// what you want when embedded in a host process. You don't want to have a flag
+// typo and shut down your entire server/sandbox/Android app/etc.
+void iree_flags_parse_checked(int* argc, char*** argv);
+
+#endif // IREE_BASE_FLAGS_H_
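Reviewer note: the documented contract of iree_flags_parse is that argc/argv end up holding only the program name and the positional arguments. The compaction step can be sketched without abseil as below. The function name sketch_strip_flags is hypothetical, and treating every argument that starts with '-' as a flag is a simplifying assumption; it is not how absl::ParseCommandLine decides what a flag is.

```c
#include <stddef.h>

// Compacts |argv| in place so that it only holds the program name and the
// positional arguments, then updates |argc| to match.
// Assumes the usual C convention that argv has argc + 1 slots.
void sketch_strip_flags(int* argc, char*** argv) {
  if (!argc || !argv || *argc == 0) return;
  int kept = 1;  // Always keep the program name at index 0.
  for (int i = 1; i < *argc; ++i) {
    const char* arg = (*argv)[i];
    if (arg[0] == '-' && arg[1] != '\0') continue;  // Flag-style: consumed.
    (*argv)[kept++] = (*argv)[i];
  }
  (*argv)[kept] = NULL;  // Keep the argv[argc] == NULL invariant.
  *argc = kept;
}
```

With the example from the header comment, argc = 4 and argv = ['program', 'abc', '--flag=2', '-p'] become argc = 2 and argv = ['program', 'abc'], with a NULL terminator stored at index 2.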
diff --git a/iree/base/flatbuffer_util.cc b/iree/base/flatbuffer_util.cc
deleted file mode 100644
index 117aea2..0000000
--- a/iree/base/flatbuffer_util.cc
+++ /dev/null
@@ -1,143 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/base/flatbuffer_util.h"
-
-#include <cerrno>
-#include <cstring>
-
-#include "absl/memory/memory.h"
-#include "iree/base/file_mapping.h"
-#include "iree/base/memory.h"
-#include "iree/base/status.h"
-#include "iree/base/tracing.h"
-
-namespace iree {
-
-FlatBufferFileBase::~FlatBufferFileBase() {
- if (deleter_) {
- deleter_();
- deleter_ = []() {};
- }
-}
-
-Status FlatBufferFileBase::Create(const void* root_ptr,
- std::function<void()> deleter) {
- IREE_TRACE_SCOPE0("FlatBufferFileBase::Create");
-
- root_ptr_ = root_ptr;
- deleter_ = std::move(deleter);
-
- return OkStatus();
-}
-
-Status FlatBufferFileBase::CreateWithBackingBuffer(
- const void* root_ptr, ::flatbuffers::DetachedBuffer backing_buffer) {
- IREE_TRACE_SCOPE0("FlatBufferFileBase::Create");
-
- root_ptr_ = root_ptr;
-
- // Pass along the buffer provided so we keep it alive until the
- // FlatBufferFileBase is destructed.
- auto backing_buffer_baton = IreeMoveToLambda(backing_buffer);
- deleter_ = [backing_buffer_baton]() { (void)backing_buffer_baton.value; };
-
- return OkStatus();
-}
-
-Status FlatBufferFileBase::Wrap(const void* root_ptr) {
- IREE_TRACE_SCOPE0("FlatBufferFileBase::Wrap");
- return Create(root_ptr, []() {});
-}
-
-Status FlatBufferFileBase::FromBuffer(Identifier identifier,
- absl::Span<const uint8_t> buffer_data,
- std::function<void()> deleter,
- size_t root_type_size,
- VerifierFn verifier_fn) {
- IREE_TRACE_SCOPE();
-
- // Sanity check buffer for the minimum size as FlatBuffers doesn't.
- if (buffer_data.size() < 16) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Provided serialized flatbuffer buffer is too small to be legit "
- "at size="
- << buffer_data.size();
- }
-
- // Ensure the buffer has the BIPE magic bytes.
- if (identifier.has_value() && !::flatbuffers::BufferHasIdentifier(
- buffer_data.data(), identifier.value())) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Provided serialized buffer does not contain the expected type; "
- "magic bytes mismatch (expected "
- << identifier.value() << ")";
- }
-
- // Verify the FlatBuffer contains valid offsets and won't try to read out of
- // bounds of the buffer. We inline a bit of VerifyBufferFromStart so this code
- // can stay generic.
- {
- IREE_TRACE_SCOPE0("FlatBufferFileBase::FromBufferVerification");
- ::flatbuffers::Verifier verifier{buffer_data.data(), buffer_data.size()};
- if (!verifier_fn(identifier.value_or(nullptr), &verifier)) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "FlatBuffer failed to verify as expected type; possibly "
- "corrupt input";
- }
- }
-
- // Resolve the root pointer in the buffer.
- // This is GetMutableRoot such that we don't need to know T.
- root_ptr_ = buffer_data.data() +
- ::flatbuffers::EndianScalar(
- *reinterpret_cast<const ::flatbuffers::uoffset_t*>(
- buffer_data.data()));
- if (!root_ptr_) {
- return FailedPreconditionErrorBuilder(IREE_LOC)
- << "Unable to resolve root table";
- }
- deleter_ = std::move(deleter);
-
- return OkStatus();
-}
-
-Status FlatBufferFileBase::WrapBuffer(Identifier identifier,
- absl::Span<const uint8_t> buffer_data,
- size_t root_type_size,
- VerifierFn verifier_fn) {
- IREE_TRACE_SCOPE0("FlatBufferFileBase::WrapBuffer");
- return FromBuffer(
- identifier, buffer_data, []() {}, root_type_size, verifier_fn);
-}
-
-Status FlatBufferFileBase::LoadFile(Identifier identifier, std::string path,
- size_t root_type_size,
- VerifierFn verifier_fn) {
- IREE_TRACE_SCOPE0("FlatBufferFileBase::LoadFile");
-
- IREE_ASSIGN_OR_RETURN(auto file_mapping, FileMapping::OpenRead(path));
- auto buffer_data = file_mapping->data();
-
- auto handle_baton = IreeMoveToLambda(file_mapping);
- return FromBuffer(
- identifier, buffer_data,
- [handle_baton]() {
- // Keeping the mmap handle alive.
- (void)handle_baton.value;
- },
- root_type_size, verifier_fn);
-}
-
-} // namespace iree
diff --git a/iree/base/flatbuffer_util.h b/iree/base/flatbuffer_util.h
deleted file mode 100644
index 0e8c81f..0000000
--- a/iree/base/flatbuffer_util.h
+++ /dev/null
@@ -1,328 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_BASE_FLATBUFFER_UTIL_H_
-#define IREE_BASE_FLATBUFFER_UTIL_H_
-
-#include <cstddef>
-#include <cstdint>
-#include <functional>
-#include <memory>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "absl/strings/string_view.h"
-#include "absl/types/optional.h"
-#include "absl/types/span.h"
-#include "flatbuffers/flatbuffers.h"
-#include "iree/base/memory.h"
-#include "iree/base/ref_ptr.h"
-#include "iree/base/status.h"
-
-namespace iree {
-
-// A helper wrapper that moves the wrapped object on copy.
-// This is particularly handy for capturing unique_ptrs in lambdas.
-// Usage example:
-//
-// std::unique_ptr<Foo> foo_ptr(new Foo());
-// move_on_copy<std::unique_ptr<Foo>> foo(std::move(foo_ptr));
-// auto some_lambda = [bar]() { ... }
-//
-template <typename T>
-struct move_on_copy {
- explicit move_on_copy(T&& t) : value(std::move(t)) {}
-
- move_on_copy(move_on_copy const& other) : value(std::move(other.value)) {}
-
- move_on_copy(move_on_copy&& other) : value(std::move(other.value)) {}
-
- move_on_copy& operator=(move_on_copy const& other) {
- value = std::move(other.value);
- return *this;
- }
-
- move_on_copy& operator=(move_on_copy&& other) {
- value = std::move(other.value);
- return *this;
- }
-
- mutable T value;
-};
-
-// Utility to aid in moving ref_ptr's into closures.
-//
-// Usage:
-// auto baton = MoveToLambda(my_ref);
-// DoSomething([baton] () { baton.value; });
-#define IreeMoveToLambda(p) ::iree::move_on_copy<decltype(p)>(std::move(p))
-
-// Wraps a FlatBuffer String in an absl::string_view.
-// Returns empty-string ("") for nullptr values.
-inline absl::string_view WrapString(const ::flatbuffers::String* value) {
- return value ? absl::string_view{value->data(), value->size()} : "";
-}
-
-// Base type for FlatBufferFile<T>. See below.
-class FlatBufferFileBase : public RefObject<FlatBufferFileBase> {
- public:
- using Identifier = absl::optional<const char*>;
-
- virtual ~FlatBufferFileBase();
-
- protected:
- template <typename T>
- friend class FlatBufferFile;
-
- using VerifierFn = bool (*)(const char* identifier,
- ::flatbuffers::Verifier* verifier);
-
- FlatBufferFileBase() = default;
-
- const void* root_ptr() const { return root_ptr_; }
-
- // Redirections of template static methods on FlatBufferFile so we can put the
- // implementations in a shared compilation unit.
- // See FlatBufferFile<T> for doc comments.
- Status Create(const void* root_ptr, std::function<void()> deleter);
- Status CreateWithBackingBuffer(const void* root_ptr,
- ::flatbuffers::DetachedBuffer backing_buffer);
- Status Wrap(const void* root);
- Status FromBuffer(Identifier identifier,
- absl::Span<const uint8_t> buffer_data,
- std::function<void()> deleter, size_t root_type_size,
- VerifierFn verifier_fn);
- // Initializes from an STL byte based container (string and vector of
- // char/byte should be compatible).
- template <typename Container>
- Status FromContainer(Identifier identifier, Container container,
- size_t root_type_size, VerifierFn verifier_fn);
- Status WrapBuffer(Identifier identifier,
- absl::Span<const uint8_t> buffer_data,
- size_t root_type_size, VerifierFn verifier_fn);
- Status LoadFile(Identifier identifier, std::string path,
- size_t root_type_size, VerifierFn verifier_fn);
-
- private:
- const void* root_ptr_ = nullptr;
- std::function<void()> deleter_;
-};
-
-// Immutable root FlatBuffer type wrapper with support for loading and backing
-// buffer management.
-//
-// Immutable and thread-safe.
-template <typename T>
-class FlatBufferFile final : public FlatBufferFileBase {
- public:
- // Creates a FlatBufferFile from an in-memory root pointer.
- // The provided |deleter| will be called when the FlatBufferFile is destructed
- // and can be used to deallocate/clean up resources.
- //
- // This assumes that the root pointer has already been verified as valid.
- // If verification is required instead use FromBuffer on the original buffer.
- static StatusOr<ref_ptr<FlatBufferFile<T>>> Create(
- const T* root, std::function<void()> deleter);
-
- // Creates a FlatBufferFile from an in-memory root pointer and the detached
- // backing buffer storing it.
- //
- // Example:
- // FlatBufferBuilder fbb;
- // MyTypeBuilder mtb(fbb);
- // fbb.Finish(mtb.Finish());
- // auto my_type = FlatBufferFile<MyType>::CreateWithBackingBuffer(
- // fbb.Release());
- // my_type->foo();
- static StatusOr<ref_ptr<FlatBufferFile<T>>> CreateWithBackingBuffer(
- ::flatbuffers::DetachedBuffer backing_buffer);
-
- // Wraps a caller-owned in-memory root pointer.
- // The provided |root| must remain valid for the lifetime of the returned
- // FlatBufferFile.
- //
- // This assumes that the root pointer has already been verified as valid.
- // If verification is required instead use FromBuffer on the original buffer.
- static StatusOr<ref_ptr<FlatBufferFile<T>>> Wrap(const T* root);
-
- // Creates a FlatBufferFile wrapping an external data buffer with a deleter
- // function that will be called when the FlatBufferFile is destructed.
- static StatusOr<ref_ptr<FlatBufferFile<T>>> FromBuffer(
- Identifier identifier, absl::Span<const uint8_t> buffer_data,
- std::function<void()> deleter);
-
- // Creates a FlatBufferFile from a serialized data buffer.
- // The FlatBufferFile takes ownership of the vector.
- static StatusOr<ref_ptr<FlatBufferFile<T>>> FromBuffer(
- Identifier identifier, std::vector<uint8_t> buffer_data);
-
- // Loads a FlatBufferFile from an external buffer owned by the caller.
- // The buffer must remain valid until the Pipeline is destroyed.
- static StatusOr<ref_ptr<FlatBufferFile<T>>> WrapBuffer(
- Identifier identifier, absl::Span<const uint8_t> buffer_data);
-
- // Loads the FlatBufferFile from a serialized byte-based STL container.
- template <typename Container>
- static StatusOr<ref_ptr<FlatBufferFile<T>>> FromContainer(
- Identifier identifier, Container buffer_data);
-
- // Loads a FlatBufferFile from a serialized string.
- // The FlatBufferFile takes ownership of the string.
- static StatusOr<ref_ptr<FlatBufferFile<T>>> FromString(
- Identifier identifier, std::string buffer_data) {
- return FromContainer(identifier, std::move(buffer_data));
- }
-
- // Loads a FlatBufferFile from a serialized byte vector.
- // The FlatBufferFile takes ownership of the vector.
- static StatusOr<ref_ptr<FlatBufferFile<T>>> FromVector(
- Identifier identifier, std::vector<uint8_t> buffer_data) {
- return FromContainer(identifier, std::move(buffer_data));
- }
-
- // Loads a FlatBufferFile from a serialized file on the file system.
- // This will attempt to mmap the file and is the preferred way of loading as
- // only those pages that contain requested tables will be read.
- static StatusOr<ref_ptr<FlatBufferFile<T>>> LoadFile(Identifier identifier,
- std::string path);
-
- ~FlatBufferFile() override = default;
-
- // Typed root pointer of the file.
- const T* root() const { return reinterpret_cast<const T*>(root_ptr()); }
-
- private:
- FlatBufferFile() = default;
-
- // Conforms to VerifierFn.
- static bool VerifierFnT(const char* identifier,
- ::flatbuffers::Verifier* verifier) {
- return verifier->VerifyBuffer<T>(identifier);
- }
-};
-
-template <typename Container>
-Status FlatBufferFileBase::FromContainer(Identifier identifier,
- Container container,
- size_t root_type_size,
- VerifierFn verifier_fn) {
- static_assert(sizeof(*container.data()) == 1,
- "Expected container of byte sized elements");
- auto buffer_data = absl::MakeConstSpan(
- // Double static_cast through void is safer than reinterpret_cast.
- static_cast<const uint8_t*>(static_cast<const void*>(container.data())),
- container.size());
- // Use a baton to keep the container alive until the FlatBufferFileBase is
- // destroyed.
- auto buffer_data_baton = IreeMoveToLambda(container);
- return FromBuffer(
- identifier, buffer_data,
- [buffer_data_baton]() {
- // Keeping the container alive.
- (void)buffer_data_baton.value;
- },
- root_type_size, verifier_fn);
-}
-
-// static
-template <typename T>
-StatusOr<ref_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::Create(
- const T* root, std::function<void()> deleter) {
- ref_ptr<FlatBufferFile<T>> flat_buffer_file{new FlatBufferFile<T>};
- auto* base_file = static_cast<FlatBufferFileBase*>(flat_buffer_file.get());
- IREE_RETURN_IF_ERROR(base_file->Create(root, std::move(deleter)));
- return std::move(flat_buffer_file);
-}
-
-// static
-template <typename T>
-StatusOr<ref_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::CreateWithBackingBuffer(
- ::flatbuffers::DetachedBuffer backing_buffer) {
- ref_ptr<FlatBufferFile<T>> flat_buffer_file{new FlatBufferFile<T>};
- auto* base_file = static_cast<FlatBufferFileBase*>(flat_buffer_file.get());
- auto* root_ptr = ::flatbuffers::GetRoot<T>(backing_buffer.data());
- IREE_RETURN_IF_ERROR(
- base_file->CreateWithBackingBuffer(root_ptr, std::move(backing_buffer)));
- return std::move(flat_buffer_file);
-}
-
-// static
-template <typename T>
-StatusOr<ref_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::Wrap(const T* root) {
- ref_ptr<FlatBufferFile<T>> flat_buffer_file{new FlatBufferFile<T>};
- auto* base_file = static_cast<FlatBufferFileBase*>(flat_buffer_file.get());
- IREE_RETURN_IF_ERROR(base_file->Wrap(root));
- return std::move(flat_buffer_file);
-}
-
-// static
-template <typename T>
-StatusOr<ref_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::FromBuffer(
- Identifier identifier, absl::Span<const uint8_t> buffer_data,
- std::function<void()> deleter) {
- ref_ptr<FlatBufferFile<T>> flat_buffer_file{new FlatBufferFile<T>};
- auto* base_file = static_cast<FlatBufferFileBase*>(flat_buffer_file.get());
- IREE_RETURN_IF_ERROR(base_file->FromBuffer(
- identifier, buffer_data, std::move(deleter), sizeof(T), VerifierFnT));
- return std::move(flat_buffer_file);
-}
-
-// static
-template <typename T>
-StatusOr<ref_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::FromBuffer(
- Identifier identifier, std::vector<uint8_t> buffer_data) {
- auto* buffer_data_ptr = new decltype(buffer_data);
- (*buffer_data_ptr) = std::move(buffer_data);
- return FromBuffer(identifier, absl::MakeConstSpan(*buffer_data_ptr),
- [buffer_data_ptr]() { delete buffer_data_ptr; });
-}
-
-// static
-template <typename T>
-StatusOr<ref_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::WrapBuffer(
- Identifier identifier, absl::Span<const uint8_t> buffer_data) {
- ref_ptr<FlatBufferFile<T>> flat_buffer_file{new FlatBufferFile<T>};
- auto* base_file = static_cast<FlatBufferFileBase*>(flat_buffer_file.get());
- IREE_RETURN_IF_ERROR(
- base_file->WrapBuffer(identifier, buffer_data, sizeof(T), VerifierFnT));
- return std::move(flat_buffer_file);
-}
-
-// static
-template <typename T>
-template <typename Container>
-StatusOr<ref_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::FromContainer(
- Identifier identifier, Container buffer_data) {
- ref_ptr<FlatBufferFile<T>> flat_buffer_file{new FlatBufferFile<T>};
- auto* base_file = static_cast<FlatBufferFileBase*>(flat_buffer_file.get());
- IREE_RETURN_IF_ERROR(base_file->FromContainer(
- identifier, std::move(buffer_data), sizeof(T), VerifierFnT));
- return std::move(flat_buffer_file);
-}
-
-// static
-template <typename T>
-StatusOr<ref_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::LoadFile(
- Identifier identifier, std::string path) {
- ref_ptr<FlatBufferFile<T>> flat_buffer_file{new FlatBufferFile<T>};
- auto* base_file = static_cast<FlatBufferFileBase*>(flat_buffer_file.get());
- IREE_RETURN_IF_ERROR(
- base_file->LoadFile(identifier, std::move(path), sizeof(T), VerifierFnT));
- return std::move(flat_buffer_file);
-}
-
-} // namespace iree
-
-#endif // IREE_BASE_FLATBUFFER_UTIL_H_
diff --git a/iree/base/flatcc.fbs b/iree/base/flatcc.fbs
new file mode 100644
index 0000000..590afb4
--- /dev/null
+++ b/iree/base/flatcc.fbs
@@ -0,0 +1,30 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+namespace iree;
+
+// HACK: flatcc public API headers are incomplete and some things only exist
+// when pulled in via generated headers. So here we give ourselves something to
+// include that's always available and cheap.
+//
+// Instead of directly including this file, use iree/base/flatcc.h.
+//
+// Normally including any generated file will include the appropriate headers in
+// the required order (as they are non-hermetic), but that requires that we have
+// a generated file. Though most of the API is exposed through the main includes
+// there are various types that only get generated and included by way of the
+// common headers that are not easily included.
+struct __IncludeWorkaround {
+ reserved:int;
+}
diff --git a/iree/base/flatcc.h b/iree/base/flatcc.h
new file mode 100644
index 0000000..0940517
--- /dev/null
+++ b/iree/base/flatcc.h
@@ -0,0 +1,50 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_BASE_FLATCC_H_
+#define IREE_BASE_FLATCC_H_
+
+//===----------------------------------------------------------------------===//
+// flatcc include order fixes
+//===----------------------------------------------------------------------===//
+//
+// This header merely wraps the flatcc headers that are generally useful to
+// include in various places that may not know the specific messages they are
+// working with.
+//
+// If using flatcc prefer to include this file over any hard-to-handle flatcc
+// file such as flatbuffers_common_reader.h or flatbuffers_common_builder.h.
+//
+// NOTE: include order matters here, so keep clang-format from reordering them:
+// clang-format off
+
+#include "flatcc/reflection/flatbuffers_common_reader.h"
+#include "iree/base/flatcc_reader.h"
+
+#include "flatcc/flatcc_verifier.h"
+#include "iree/base/flatcc_verifier.h"
+
+#include "flatcc/flatcc_builder.h"
+#include "flatcc/reflection/flatbuffers_common_builder.h"
+#include "iree/base/flatcc_builder.h"
+
+#include "flatcc/flatcc_json_parser.h"
+#include "iree/base/flatcc_json_parser.h"
+
+#include "flatcc/flatcc_json_printer.h"
+#include "iree/base/flatcc_json_printer.h"
+
+// clang-format on
+
+#endif // IREE_BASE_FLATCC_H_
diff --git a/iree/base/init.cc b/iree/base/init.cc
deleted file mode 100644
index 33729c0..0000000
--- a/iree/base/init.cc
+++ /dev/null
@@ -1,42 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/base/init.h"
-
-#include <string.h>
-
-#include <set>
-
-#include "absl/flags/parse.h"
-#include "iree/base/initializer.h"
-
-namespace iree {
-
-void InitializeEnvironment(int* argc, char*** argv) {
- if (argc != nullptr && argv != nullptr && *argc != 0) {
- auto positional_args = absl::ParseCommandLine(*argc, *argv);
- if (positional_args.size() < *argc) {
- // Edit the passed argument refs to only include positional args.
- *argc = positional_args.size();
- for (int i = 0; i < *argc; ++i) {
- (*argv)[i] = positional_args[i];
- }
- (*argv)[*argc + 1] = nullptr;
- }
- }
-
- IREE_RUN_MODULE_INITIALIZERS();
-}
-
-} // namespace iree
diff --git a/iree/base/init.h b/iree/base/init.h
deleted file mode 100644
index 9c93a93..0000000
--- a/iree/base/init.h
+++ /dev/null
@@ -1,35 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_BASE_INIT_H_
-#define IREE_BASE_INIT_H_
-
-// Initialization happens automatically during InitializeEnvironment(), which
-// should be called early in main(), before other code runs.
-
-namespace iree {
-
-// Initializes the system environment in a binary.
-//
-// This first parses command line flags, then resolves module initializers
-// by calling IREE_RUN_MODULE_INITIALIZERS().
-//
-// 'argc' and 'argv' are the command line flags to parse.
-//
-// This should typically be called early in main(), before other code runs.
-void InitializeEnvironment(int* argc, char*** argv);
-
-} // namespace iree
-
-#endif // IREE_BASE_INIT_H_
diff --git a/iree/base/initializer.cc b/iree/base/initializer.cc
deleted file mode 100644
index 5e1e21e..0000000
--- a/iree/base/initializer.cc
+++ /dev/null
@@ -1,100 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/base/initializer.h"
-
-#include <string.h>
-
-#include <mutex> // NOLINT
-#include <set>
-
-namespace iree {
-
-static Initializer::NameMap* static_name_map = nullptr;
-
-struct Initializer::InitializerData {
- Initializer* initializer_obj;
- std::set<std::string> dependency_names;
-
- InitializerData() : initializer_obj(nullptr) {}
- explicit InitializerData(Initializer* i) : initializer_obj(i) {}
-};
-
-Initializer::DependencyRegisterer::DependencyRegisterer(
- const char* name, Initializer* initializer, const Dependency& dependency) {
- NameMap* name_map = InitializerNameMap();
-
- // Insert 'dependency' into the 'dependency_names' set for 'initializer'.
- InitializerData* initializer_data = &(*name_map)[name];
- initializer_data->dependency_names.insert(dependency.name);
-
- // Ensure that 'dependency' exists in the map.
- InitializerData* dependency_data = &(*name_map)[dependency.name];
- dependency_data->initializer_obj = dependency.initializer;
-}
-
-Initializer::Initializer(const char* name, InitializerFunc function)
- : name_(name), function_(function), done_(false) {
- // Register this Initializer instance (wrapped by an InitializerData) within
- // the static name map.
- NameMap* name_map = InitializerNameMap();
- InitializerData* initializer_data = &(*name_map)[name];
- initializer_data->initializer_obj = this;
-}
-
-void Initializer::RunInitializers() {
- static std::once_flag init_once;
- std::call_once(init_once, &Initializer::InitializeOnceCallback);
-}
-
-void Initializer::InitializeOnceCallback() {
- // Run each registered Initializer, in lexicographic order of their names.
- // Initializer dependencies will be run first as needed.
- NameMap* name_map = InitializerNameMap();
- for (auto& p : *name_map) {
- RunInitializer(&p.second);
- }
-}
-
-void Initializer::Require() {
- NameMap* name_map = InitializerNameMap();
- InitializerData* initializer_data = &(name_map->find(name_)->second);
- RunInitializer(initializer_data);
-}
-
-Initializer::NameMap* Initializer::InitializerNameMap() {
- if (static_name_map == nullptr) {
- static_name_map = new Initializer::NameMap;
- }
- return static_name_map;
-}
-
-void Initializer::RunInitializer(InitializerData* initializer_data) {
- if (initializer_data->initializer_obj->done_) {
- return;
- }
-
- // Run Initializer dependencies first.
- NameMap* name_map = InitializerNameMap();
- for (const auto& dependency_name : initializer_data->dependency_names) {
- auto dep_init = name_map->find(dependency_name);
- RunInitializer(&dep_init->second);
- }
-
- // Finally run the Initializer itself.
- initializer_data->initializer_obj->function_();
- initializer_data->initializer_obj->done_ = true;
-}
-
-} // namespace iree
diff --git a/iree/base/initializer.h b/iree/base/initializer.h
deleted file mode 100644
index b31bc73..0000000
--- a/iree/base/initializer.h
+++ /dev/null
@@ -1,119 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_BASE_INITIALIZER_H_
-#define IREE_BASE_INITIALIZER_H_
-
-#include <map>
-#include <string>
-
-#include "iree/base/target_platform.h"
-
-namespace iree {
-
-// Initializer macros are defined in this files:
-// IREE_DECLARE_MODULE_INITIALIZER(name)
-// IREE_REGISTER_MODULE_INITIALIZER(name, body)
-// IREE_REGISTER_MODULE_INITIALIZER_SEQUENCE(name1, name2)
-// IREE_REQUIRE_MODULE_INITIALIZED(name)
-// IREE_RUN_MODULE_INITIALIZERS()
-// IREE_REQUIRE_MODULE_LINKED(name)
-//
-// These macros allow for arranging pieces of initialization code to be
-// executed at a well-defined time and in a well-defined order.
-
-// A static instance of this class is declared for each piece of initialization
-// code using the initializer macros.
-class Initializer {
- public:
- typedef void (*InitializerFunc)();
-
- Initializer(const char* name, InitializerFunc function);
-
- // Runs all registered initializers that have not yet run.
- // The initializers are invoked in lexicographically increasing order by name,
- // except as necessary to satisfy dependencies.
- //
- // This is normally called by InitializeEnvironment(), so application code
- // typically should not call it directly.
- static void RunInitializers();
-
- // Runs this initializer if it has not yet run, including any dependencies.
- void Require();
-
- struct Dependency {
- Dependency(const char* n, Initializer* i) : name(n), initializer(i) {}
- const char* const name;
- Initializer* const initializer;
- };
-
- // A static instance of this class is declared for each piece of
- // initializer ordering definition.
- struct DependencyRegisterer {
- DependencyRegisterer(const char* name, Initializer* initializer,
- const Dependency& dependency);
- };
-
- struct InitializerData;
- typedef std::map<std::string, InitializerData> NameMap;
-
- private:
- static NameMap* InitializerNameMap();
- static void RunInitializer(InitializerData* initializer_data);
- static void InitializeOnceCallback();
-
- const std::string name_;
- InitializerFunc function_;
- bool done_;
-};
-
-} // namespace iree
-
-#define IREE_DECLARE_MODULE_INITIALIZER(name) \
- extern ::iree::Initializer iree_initializer_##name
-
-#define IREE_REGISTER_MODULE_INITIALIZER(name, body) \
- static void iree_init_##name() { body; } \
- ::iree::Initializer iree_initializer_##name(#name, iree_init_##name)
-
-#define IREE_REGISTER_MODULE_INITIALIZER_SEQUENCE(name1, name2) \
- namespace { \
- static ::iree::Initializer::DependencyRegisterer \
- iree_initializer_dependency_##name1##_##name2( \
- #name2, &iree_initializer_##name2, \
- ::iree::Initializer::Dependency(#name1, &iree_initializer_##name1)); \
- }
-
-#define IREE_REQUIRE_MODULE_INITIALIZED(name) \
- do { \
- IREE_DECLARE_MODULE_INITIALIZER(name); \
- iree_initializer_##name.Require(); \
- } while (0)
-
-#define IREE_RUN_MODULE_INITIALIZERS() \
- do { \
- ::iree::Initializer::RunInitializers(); \
- } while (0)
-
-#if !defined(IREE_COMPILER_MSVC)
-#define IREE_ATTRIBUTE_USED __attribute__((used))
-#else
-#define IREE_ATTRIBUTE_USED
-#endif // IREE_COMPILER_MSVC
-
-#define IREE_REQUIRE_MODULE_LINKED(name) \
- IREE_ATTRIBUTE_USED static ::iree::Initializer* iree_module_ref_##name = \
- &iree_initializer_##name
-
-#endif // IREE_BASE_INITIALIZER_H_
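Reviewer note on the removal above: the deleted Initializer registered itself from the constructors of static objects, which only works if the linker keeps the defining object files (hence IREE_REQUIRE_MODULE_LINKED and the ALWAYSLINK support removed elsewhere in this merge). A minimal sketch of the explicit alternative, assuming a hypothetical `ExplicitRegistry` class; the names are illustrative and are not IREE's actual driver registration API:

```cpp
#include <functional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for an explicit registry; names are illustrative and
// do not correspond to IREE's actual driver registration API.
class ExplicitRegistry {
 public:
  // Records an entry; nothing runs until RunAll() is called.
  void Register(std::string name, std::function<void()> init) {
    entries_.push_back({std::move(name), std::move(init)});
  }

  // Runs entries in the order they were registered and returns their names.
  std::vector<std::string> RunAll() {
    std::vector<std::string> ran;
    for (auto& entry : entries_) {
      entry.init();
      ran.push_back(entry.name);
    }
    return ran;
  }

 private:
  struct Entry {
    std::string name;
    std::function<void()> init;
  };
  std::vector<Entry> entries_;
};
```

In this sketch the binary decides what to register and when, so no link-time tricks are needed. Entries run in registration order, whereas the removed Initializer ran them in lexicographic order by name, subject to declared dependencies.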
diff --git a/iree/base/internal/BUILD b/iree/base/internal/BUILD
index ae3f2fa..2fe99ee 100644
--- a/iree/base/internal/BUILD
+++ b/iree/base/internal/BUILD
@@ -25,8 +25,8 @@
srcs = ["file_handle_win32.cc"],
hdrs = ["file_handle_win32.h"],
deps = [
+ "//iree/base:core_headers",
"//iree/base:status",
- "//iree/base:target_platform",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],
@@ -40,10 +40,10 @@
],
deps = [
":file_handle_win32",
+ "//iree/base:core_headers",
"//iree/base:file_io_hdrs",
"//iree/base:file_path",
"//iree/base:status",
- "//iree/base:target_platform",
"//iree/base:tracing",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
@@ -58,8 +58,8 @@
],
deps = [
":file_handle_win32",
+ "//iree/base:core_headers",
"//iree/base:file_mapping_hdrs",
- "//iree/base:target_platform",
"//iree/base:tracing",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
@@ -88,9 +88,10 @@
deps = [
":ostringstream",
"//iree/base:api",
+ "//iree/base:core_headers",
"//iree/base:logging",
- "//iree/base:target_platform",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/utility",
],
)
diff --git a/iree/base/internal/CMakeLists.txt b/iree/base/internal/CMakeLists.txt
index 973d297..f9975d7 100644
--- a/iree/base/internal/CMakeLists.txt
+++ b/iree/base/internal/CMakeLists.txt
@@ -24,8 +24,8 @@
DEPS
absl::memory
absl::strings
+ iree::base::core_headers
iree::base::status
- iree::base::target_platform
PUBLIC
)
@@ -39,10 +39,10 @@
::file_handle_win32
absl::memory
absl::strings
+ iree::base::core_headers
iree::base::file_io_hdrs
iree::base::file_path
iree::base::status
- iree::base::target_platform
iree::base::tracing
PUBLIC
)
@@ -57,8 +57,8 @@
::file_handle_win32
absl::memory
absl::strings
+ iree::base::core_headers
iree::base::file_mapping_hdrs
- iree::base::target_platform
iree::base::tracing
PUBLIC
)
@@ -88,8 +88,9 @@
::ostringstream
absl::core_headers
absl::strings
+ absl::utility
iree::base::api
+ iree::base::core_headers
iree::base::logging
- iree::base::target_platform
PUBLIC
)
diff --git a/iree/base/internal/file_io_win32.cc b/iree/base/internal/file_io_win32.cc
index 004f0f4..3bf4b0e 100644
--- a/iree/base/internal/file_io_win32.cc
+++ b/iree/base/internal/file_io_win32.cc
@@ -48,7 +48,8 @@
result.resize(file->size());
DWORD bytes_read = 0;
if (::ReadFile(file->handle(), const_cast<char*>(result.data()),
- result.size(), &bytes_read, nullptr) == FALSE) {
+ static_cast<DWORD>(result.size()), &bytes_read,
+ nullptr) == FALSE) {
return Win32ErrorToCanonicalStatusBuilder(GetLastError(), IREE_LOC)
<< "Unable to read file span of " << result.size() << " bytes from '"
<< path << "'";
@@ -63,8 +64,8 @@
Status SetFileContents(const std::string& path, absl::string_view content) {
IREE_TRACE_SCOPE0("file_io::SetFileContents");
IREE_ASSIGN_OR_RETURN(auto file, FileHandle::OpenWrite(std::move(path), 0));
- if (::WriteFile(file->handle(), content.data(), content.size(), NULL, NULL) ==
- FALSE) {
+ if (::WriteFile(file->handle(), content.data(),
+ static_cast<DWORD>(content.size()), NULL, NULL) == FALSE) {
return Win32ErrorToCanonicalStatusBuilder(GetLastError(), IREE_LOC)
<< "Unable to write file span of " << content.size() << " bytes to '"
<< path << "'";
@@ -103,7 +104,8 @@
std::string temp_path(64, '\0');
for (bool retry_query = true; retry_query;) {
- DWORD required_length = ::GetTempPathA(temp_path.size(), &temp_path[0]);
+ DWORD required_length =
+ ::GetTempPathA(static_cast<DWORD>(temp_path.size()), &temp_path[0]);
retry_query = required_length > temp_path.size();
temp_path.resize(required_length);
}
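Reviewer note on the casts in this file: `static_cast<DWORD>` silences the size_t-to-DWORD narrowing warning, but the value is still truncated silently if it does not fit in 32 bits. A minimal sketch of a checked narrowing step, assuming a hypothetical `CheckedNarrowToU32` helper (not part of IREE) and using `std::uint32_t` as a stand-in for the 32-bit `DWORD` so that it builds off Windows:

```cpp
#include <cstddef>
#include <cstdint>
#include <limits>

// Hypothetical helper, not part of IREE: narrows a size_t to a 32-bit value
// (the width of a Win32 DWORD) and reports whether the value fit.
inline bool CheckedNarrowToU32(std::size_t value, std::uint32_t* out) {
  if (value > std::numeric_limits<std::uint32_t>::max()) {
    return false;  // Would truncate; the caller should fail or chunk the I/O.
  }
  *out = static_cast<std::uint32_t>(value);
  return true;
}
```

A caller could return an error status when the helper reports failure instead of issuing a short read or write. Whether that is worth the extra branch here is a judgment call; the plain cast is only a problem for buffers of 4 GiB or more.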
diff --git a/iree/base/internal/statusor.h b/iree/base/internal/statusor.h
index 93338a9..d2ecac7 100644
--- a/iree/base/internal/statusor.h
+++ b/iree/base/internal/statusor.h
@@ -15,6 +15,9 @@
#ifndef IREE_BASE_INTERNAL_STATUSOR_H_
#define IREE_BASE_INTERNAL_STATUSOR_H_
+#include <type_traits>
+#include <utility>
+
#include "absl/base/attributes.h"
#include "absl/utility/utility.h"
#include "iree/base/internal/status.h"
@@ -49,12 +52,11 @@
template <typename T, typename U>
struct IsAmbiguousStatusOrForInitialization
: // Strip const-value refs from type and check again, else false_type.
- public absl::conditional_t<
- std::is_same<absl::remove_cv_t<absl::remove_reference_t<U>>,
- U>::value,
+ public std::conditional_t<
+ std::is_same<std::remove_cv_t<std::remove_reference_t<U>>, U>::value,
std::false_type,
IsAmbiguousStatusOrForInitialization<
- T, absl::remove_cv_t<absl::remove_reference_t<U>>>> {};
+ T, std::remove_cv_t<std::remove_reference_t<U>>>> {};
template <typename T, typename U>
struct IsAmbiguousStatusOrForInitialization<T, StatusOr<U>>
@@ -62,17 +64,17 @@
template <typename T, typename U>
using IsStatusOrDirectInitializationAmbiguous = absl::disjunction<
- std::is_same<StatusOr<T>, absl::remove_cv_t<absl::remove_reference_t<U>>>,
- std::is_same<Status, absl::remove_cv_t<absl::remove_reference_t<U>>>,
- std::is_same<StatusBuilder, absl::remove_cv_t<absl::remove_reference_t<U>>>,
+ std::is_same<StatusOr<T>, std::remove_cv_t<std::remove_reference_t<U>>>,
+ std::is_same<Status, std::remove_cv_t<std::remove_reference_t<U>>>,
+ std::is_same<StatusBuilder, std::remove_cv_t<std::remove_reference_t<U>>>,
std::is_same<absl::in_place_t,
- absl::remove_cv_t<absl::remove_reference_t<U>>>,
+ std::remove_cv_t<std::remove_reference_t<U>>>,
IsAmbiguousStatusOrForInitialization<T, U>>;
template <typename T, typename U>
using IsStatusOrDirectInitializationValid = absl::disjunction<
// The is_same allows nested status ors to ignore this check iff same type.
- std::is_same<T, absl::remove_cv_t<absl::remove_reference_t<U>>>,
+ std::is_same<T, std::remove_cv_t<std::remove_reference_t<U>>>,
absl::negation<IsStatusOrDirectInitializationAmbiguous<T, U>>>;
class Helper {
@@ -327,7 +329,7 @@
// explicit.
template <
typename U,
- absl::enable_if_t<
+ std::enable_if_t<
absl::conjunction<
absl::negation<std::is_same<T, U>>,
std::is_constructible<T, const U&>,
@@ -339,7 +341,7 @@
: Base(static_cast<const typename StatusOr<U>::Base&>(other)) {}
template <
typename U,
- absl::enable_if_t<
+ std::enable_if_t<
absl::conjunction<
absl::negation<std::is_same<T, U>>,
std::is_constructible<T, const U&>,
@@ -352,7 +354,7 @@
template <
typename U,
- absl::enable_if_t<
+ std::enable_if_t<
absl::conjunction<
absl::negation<std::is_same<T, U>>, std::is_constructible<T, U&&>,
std::is_convertible<U&&, T>,
@@ -363,7 +365,7 @@
: Base(static_cast<typename StatusOr<U>::Base&&>(other)) {}
template <
typename U,
- absl::enable_if_t<
+ std::enable_if_t<
absl::conjunction<
absl::negation<std::is_same<T, U>>, std::is_constructible<T, U&&>,
absl::negation<std::is_convertible<U&&, T>>,
@@ -378,7 +380,7 @@
// StatusOr<U>.
template <
typename U,
- absl::enable_if_t<
+ std::enable_if_t<
absl::conjunction<
absl::negation<std::is_same<T, U>>,
std::is_constructible<T, const U&>,
@@ -393,7 +395,7 @@
}
template <
typename U,
- absl::enable_if_t<
+ std::enable_if_t<
absl::conjunction<
absl::negation<std::is_same<T, U>>, std::is_constructible<T, U&&>,
std::is_assignable<T, U&&>,
@@ -445,7 +447,7 @@
// a StatusOr<J>, where J is convertible to T.
template <
typename U = T,
- absl::enable_if_t<
+ std::enable_if_t<
absl::conjunction<
internal_statusor::IsStatusOrDirectInitializationValid<T, U&&>,
std::is_constructible<T, U&&>,
@@ -456,7 +458,7 @@
template <
typename U = T,
- absl::enable_if_t<
+ std::enable_if_t<
absl::conjunction<
internal_statusor::IsStatusOrDirectInitializationValid<T, U&&>,
std::is_constructible<T, U&&>,
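Reviewer note on the `absl::` to `std::` trait swap in this file: the `_t` alias templates used here (`std::conditional_t`, `std::remove_cv_t`, `std::remove_reference_t`, `std::enable_if_t`) were added in C++14, so the change presumes a C++14 or later toolchain. A small sketch of what the repeated `std::remove_cv_t<std::remove_reference_t<U>>` expression computes; the `StripCvRef` alias is a name invented here for illustration and does not appear in statusor.h:

```cpp
#include <type_traits>

// Illustrative alias (not in statusor.h): removes the reference first, then
// top-level const/volatile, mirroring the expression used in the diff.
template <typename U>
using StripCvRef = std::remove_cv_t<std::remove_reference_t<U>>;

static_assert(std::is_same<StripCvRef<const int&>, int>::value, "");
static_assert(std::is_same<StripCvRef<volatile int&&>, int>::value, "");
// Order matters: remove_cv alone leaves `const int&` unchanged because a
// reference type has no top-level cv-qualifiers.
static_assert(std::is_same<std::remove_cv_t<const int&>, const int&>::value,
              "");
// Only top-level qualifiers are stripped; pointee constness is preserved.
static_assert(std::is_same<StripCvRef<const int* const&>, const int*>::value,
              "");
```

The checks are all compile-time, so the sketch passes simply by compiling.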
diff --git a/iree/base/platform_headers.h b/iree/base/platform_headers.h
deleted file mode 100644
index 16ad04b..0000000
--- a/iree/base/platform_headers.h
+++ /dev/null
@@ -1,40 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_BASE_PLATFORM_HEADERS_H_
-#define IREE_BASE_PLATFORM_HEADERS_H_
-
-#include "iree/base/target_platform.h"
-
-#if defined(IREE_PLATFORM_WINDOWS)
-
-#if defined(_MSC_VER)
-// Abseil compatibility: don't include incompatible winsock versions.
-#ifndef WIN32_LEAN_AND_MEAN
-#define WIN32_LEAN_AND_MEAN
-#endif // WIN32_LEAN_AND_MEAN
-// Abseil compatibility: don't define min and max macros.
-#ifndef NOMINMAX
-#define NOMINMAX
-#endif // NOMINMAX
-#endif // _MSC_VER
-
-#include <windows.h>
-
-// WinGDI.h defines `ERROR`, undef to avoid conflict naming.
-#undef ERROR
-
-#endif // IREE_PLATFORM_WINDOWS
-
-#endif // IREE_BASE_PLATFORM_HEADERS_H_
diff --git a/iree/base/synchronization.c b/iree/base/synchronization.c
index ed185955..8bc186e 100644
--- a/iree/base/synchronization.c
+++ b/iree/base/synchronization.c
@@ -95,7 +95,7 @@
#elif defined(IREE_PLATFORM_WINDOWS)
-#pragma comment(lib, "synchronization")
+#pragma comment(lib, "Synchronization.lib")
static inline iree_status_t iree_futex_wait(void* address,
uint32_t expected_value,
diff --git a/iree/base/testing/BUILD b/iree/base/testing/BUILD
new file mode 100644
index 0000000..43b4afe
--- /dev/null
+++ b/iree/base/testing/BUILD
@@ -0,0 +1,52 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//build_tools/embed_data:build_defs.bzl", "cc_embed_data")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_binary(
+ name = "dynamic_library_test_library.so",
+ testonly = True,
+ srcs = ["dynamic_library_test_library.cc"],
+ linkshared = True,
+)
+
+cc_embed_data(
+ name = "dynamic_library_test_library",
+ testonly = True,
+ srcs = [":dynamic_library_test_library.so"],
+ cc_file_output = "dynamic_library_test_library_embed.cc",
+ cpp_namespace = "iree",
+ flatten = True,
+ h_file_output = "dynamic_library_test_library_embed.h",
+)
+
+cc_test(
+ name = "dynamic_library_test",
+ srcs = ["dynamic_library_test.cc"],
+ deps = [
+ ":dynamic_library_test_library",
+ "//iree/base:core_headers",
+ "//iree/base:dynamic_library",
+ "//iree/base:file_io",
+ "//iree/base:status",
+ "//iree/testing:gtest",
+ "//iree/testing:gtest_main",
+ ],
+)
diff --git a/iree/base/testing/CMakeLists.txt b/iree/base/testing/CMakeLists.txt
new file mode 100644
index 0000000..ba88fe8
--- /dev/null
+++ b/iree/base/testing/CMakeLists.txt
@@ -0,0 +1,60 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# bazel_to_cmake: DO NOT EDIT (some scotttodd todos remain)
+
+# TODO(scotttodd): clean up bazel_to_cmake handling here
+# * this is a cc_binary in Bazel, but `linkshared` fits iree_cc_library better
+# * the output file name is platform-specific, get it with $<TARGET_FILE:>
+iree_cc_library(
+ NAME
+ dynamic_library_test_library.so
+ OUT
+ dynamic_library_test_library.so
+ SRCS
+ "dynamic_library_test_library.cc"
+ TESTONLY
+ SHARED
+)
+
+iree_cc_embed_data(
+ NAME
+ dynamic_library_test_library
+ GENERATED_SRCS
+ "$<TARGET_FILE:iree::base::testing::dynamic_library_test_library.so>"
+ CC_FILE_OUTPUT
+ "dynamic_library_test_library_embed.cc"
+ H_FILE_OUTPUT
+ "dynamic_library_test_library_embed.h"
+ TESTONLY
+ CPP_NAMESPACE
+ "iree"
+ FLATTEN
+ PUBLIC
+)
+
+iree_cc_test(
+ NAME
+ dynamic_library_test
+ SRCS
+ "dynamic_library_test.cc"
+ DEPS
+ ::dynamic_library_test_library
+ iree::base::core_headers
+ iree::base::dynamic_library
+ iree::base::file_io
+ iree::base::status
+ iree::testing::gtest
+ iree::testing::gtest_main
+)
diff --git a/iree/base/dynamic_library_test.cc b/iree/base/testing/dynamic_library_test.cc
similarity index 96%
rename from iree/base/dynamic_library_test.cc
rename to iree/base/testing/dynamic_library_test.cc
index f607c22..406a3c5 100644
--- a/iree/base/dynamic_library_test.cc
+++ b/iree/base/testing/dynamic_library_test.cc
@@ -16,10 +16,10 @@
#include <string>
-#include "iree/base/dynamic_library_test_library_embed.h"
#include "iree/base/file_io.h"
#include "iree/base/status.h"
#include "iree/base/target_platform.h"
+#include "iree/base/testing/dynamic_library_test_library_embed.h"
#include "iree/testing/gtest.h"
#include "iree/testing/status_matchers.h"
@@ -64,7 +64,6 @@
TEST_F(DynamicLibraryTest, LoadLibrarySuccess) {
IREE_ASSERT_OK_AND_ASSIGN(auto library,
DynamicLibrary::Load(library_temp_path_.c_str()));
- EXPECT_EQ(library_temp_path_, library->file_name());
}
TEST_F(DynamicLibraryTest, LoadLibraryFailure) {
diff --git a/iree/base/dynamic_library_test_library.cc b/iree/base/testing/dynamic_library_test_library.cc
similarity index 100%
rename from iree/base/dynamic_library_test_library.cc
rename to iree/base/testing/dynamic_library_test_library.cc
diff --git a/iree/base/threading_darwin.c b/iree/base/threading_darwin.c
index d2e6820..c0ba729 100644
--- a/iree/base/threading_darwin.c
+++ b/iree/base/threading_darwin.c
@@ -12,10 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "iree/base/atomics.h"
-#include "iree/base/threading.h"
+// NOTE: must be first to ensure that we can define settings for all includes.
#include "iree/base/threading_impl.h"
-#include "iree/base/tracing.h"
#if defined(IREE_PLATFORM_APPLE)
@@ -25,6 +23,10 @@
#include <pthread.h>
#include <string.h>
+#include "iree/base/atomics.h"
+#include "iree/base/threading.h"
+#include "iree/base/tracing.h"
+
// Useful to see how pthreads is implemented on (old) darwin:
// https://opensource.apple.com/source/Libc/Libc-825.40.1/pthreads/pthread.c.auto.html
diff --git a/iree/base/threading_impl.h b/iree/base/threading_impl.h
index b6794f2..40de48f 100644
--- a/iree/base/threading_impl.h
+++ b/iree/base/threading_impl.h
@@ -15,6 +15,18 @@
#ifndef IREE_BASE_THREADING_IMPL_H_
#define IREE_BASE_THREADING_IMPL_H_
+// Ensure that any posix header we include exposes GNU stuff. Ignored on
+// platforms where we either don't have the GNU stuff or don't have posix
+// headers at all.
+//
+// Note that this does not need to be the same for all compilation units, only
+// those we want to access the non-portable features in. It *must* be defined
+// prior to including any of the files, though, as otherwise header-guards will
+// cause the setting at the time of first inclusion to win.
+//
+// https://stackoverflow.com/a/5583764
+#define _GNU_SOURCE 1
+
#include <assert.h>
#include <errno.h>
#include <stddef.h>
@@ -22,6 +34,7 @@
#include "iree/base/api.h"
#include "iree/base/synchronization.h"
+#include "iree/base/target_platform.h"
#include "iree/base/threading.h"
#ifdef __cplusplus
diff --git a/iree/base/threading_pthreads.c b/iree/base/threading_pthreads.c
index 8726653..8d2f14c 100644
--- a/iree/base/threading_pthreads.c
+++ b/iree/base/threading_pthreads.c
@@ -12,11 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "iree/base/atomics.h"
-#include "iree/base/synchronization.h"
-#include "iree/base/threading.h"
+// NOTE: must be first to ensure that we can define settings for all includes.
#include "iree/base/threading_impl.h"
-#include "iree/base/tracing.h"
#if defined(IREE_PLATFORM_ANDROID) || defined(IREE_PLATFORM_EMSCRIPTEN) || \
defined(IREE_PLATFORM_LINUX)
@@ -31,6 +28,11 @@
#include <time.h>
#include <unistd.h>
+#include "iree/base/atomics.h"
+#include "iree/base/synchronization.h"
+#include "iree/base/threading.h"
+#include "iree/base/tracing.h"
+
// Older glibc doesn't have a gettid wrapper:
// https://stackoverflow.com/a/63494768
#if __GLIBC__ == 2 && __GLIBC_MINOR__ < 30
diff --git a/iree/base/threading_win32.c b/iree/base/threading_win32.c
index b42639b..ff7356b 100644
--- a/iree/base/threading_win32.c
+++ b/iree/base/threading_win32.c
@@ -12,13 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// NOTE: must be first to ensure that we can define settings for all includes.
+#include "iree/base/threading_impl.h"
+
+#if defined(IREE_PLATFORM_WINDOWS)
+
#include "iree/base/atomics.h"
#include "iree/base/threading.h"
#include "iree/base/threading_impl.h"
#include "iree/base/tracing.h"
-#if defined(IREE_PLATFORM_WINDOWS)
-
// Great documentation:
// https://www.microsoftpressstore.com/articles/article.aspx?p=2233328
diff --git a/iree/base/wait_handle_epoll.c b/iree/base/wait_handle_epoll.c
index 1dcb207..1249d99 100644
--- a/iree/base/wait_handle_epoll.c
+++ b/iree/base/wait_handle_epoll.c
@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "iree/base/tracing.h"
-#include "iree/base/wait_handle.h"
-#include "iree/base/wait_handle_posix.h"
+// NOTE: must be first to ensure that we can define settings for all includes.
+#include "iree/base/wait_handle_impl.h"
#if IREE_WAIT_API == IREE_WAIT_API_EPOLL
+#include "iree/base/tracing.h"
+#include "iree/base/wait_handle_posix.h"
+
//===----------------------------------------------------------------------===//
// iree_wait_set_t
//===----------------------------------------------------------------------===//
diff --git a/iree/base/wait_handle_impl.h b/iree/base/wait_handle_impl.h
new file mode 100644
index 0000000..38f82e5
--- /dev/null
+++ b/iree/base/wait_handle_impl.h
@@ -0,0 +1,69 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_BASE_WAIT_HANDLE_IMPL_H_
+#define IREE_BASE_WAIT_HANDLE_IMPL_H_
+
+//===----------------------------------------------------------------------===//
+// Platform overrides
+//===----------------------------------------------------------------------===//
+// NOTE: this must come first prior to any local/system includes!
+
+// Ensure that any posix header we include exposes GNU stuff. Ignored on
+// platforms where we either don't have the GNU stuff or don't have posix
+// headers at all.
+//
+// Note that this does not need to be the same for all compilation units, only
+// those we want to access the non-portable features in. It *must* be defined
+// prior to including any of the files, though, as otherwise header-guards will
+// cause the setting at the time of first inclusion to win.
+//
+// https://stackoverflow.com/a/5583764
+#define _GNU_SOURCE 1
+
+//===----------------------------------------------------------------------===//
+// Active wait API implementation selection (wait_handle_*.c)
+//===----------------------------------------------------------------------===//
+
+#include "iree/base/target_platform.h"
+
+// Priorities are (kqueue|epoll) > ppoll > poll
+#define IREE_WAIT_API_POLL 1
+#define IREE_WAIT_API_PPOLL 2
+#define IREE_WAIT_API_EPOLL 3
+#define IREE_WAIT_API_KQUEUE 4
+
+// NOTE: we could be tighter here, but we today only have win32 or not-win32.
+#if defined(IREE_PLATFORM_WINDOWS)
+#define IREE_WAIT_API 0 // WFMO used in wait_handle_win32.c
+#else
+
+// TODO(benvanik): EPOLL on android/linux/bsd/etc.
+// TODO(benvanik): KQUEUE on mac/ios.
+// KQUEUE is not implemented yet. Use POLL for mac/ios
+#if !defined(IREE_PLATFORM_APPLE) && !defined(__EMSCRIPTEN__)
+#define IREE_WAIT_API IREE_WAIT_API_PPOLL
+#else
+#define IREE_WAIT_API IREE_WAIT_API_POLL
+#endif // insanity
+
+#endif // IREE_PLATFORM_WINDOWS
+
+//===----------------------------------------------------------------------===//
+// Wait handle included with options set
+//===----------------------------------------------------------------------===//
+
+#include "iree/base/wait_handle.h"
+
+#endif // IREE_BASE_WAIT_HANDLE_IMPL_H_
diff --git a/iree/base/wait_handle_kqueue.c b/iree/base/wait_handle_kqueue.c
index 92d1392..5e03f2c 100644
--- a/iree/base/wait_handle_kqueue.c
+++ b/iree/base/wait_handle_kqueue.c
@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "iree/base/tracing.h"
-#include "iree/base/wait_handle.h"
-#include "iree/base/wait_handle_posix.h"
+// NOTE: must be first to ensure that we can define settings for all includes.
+#include "iree/base/wait_handle_impl.h"
#if IREE_WAIT_API == IREE_WAIT_API_KQUEUE
+#include "iree/base/tracing.h"
+#include "iree/base/wait_handle_posix.h"
+
//===----------------------------------------------------------------------===//
// iree_wait_set_t
//===----------------------------------------------------------------------===//
diff --git a/iree/base/wait_handle_poll.c b/iree/base/wait_handle_poll.c
index 64b7b73..2ba2eb1 100644
--- a/iree/base/wait_handle_poll.c
+++ b/iree/base/wait_handle_poll.c
@@ -12,9 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "iree/base/tracing.h"
-#include "iree/base/wait_handle.h"
-#include "iree/base/wait_handle_posix.h"
+// NOTE: must be first to ensure that we can define settings for all includes.
+#include "iree/base/wait_handle_impl.h"
#if IREE_WAIT_API == IREE_WAIT_API_POLL || IREE_WAIT_API == IREE_WAIT_API_PPOLL
@@ -22,6 +21,9 @@
#include <poll.h>
#include <time.h>
+#include "iree/base/tracing.h"
+#include "iree/base/wait_handle_posix.h"
+
//===----------------------------------------------------------------------===//
// Platform utilities
//===----------------------------------------------------------------------===//
diff --git a/iree/base/wait_handle_posix.h b/iree/base/wait_handle_posix.h
index 7078859..63e1ded 100644
--- a/iree/base/wait_handle_posix.h
+++ b/iree/base/wait_handle_posix.h
@@ -12,29 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "iree/base/wait_handle.h"
+// NOTE: must be first to ensure that we can define settings for all includes.
+#include "iree/base/wait_handle_impl.h"
#ifndef IREE_BASE_WAIT_HANDLE_POSIX_H_
#define IREE_BASE_WAIT_HANDLE_POSIX_H_
-// Priorities are (kqueue|epoll) > ppoll > poll
-#define IREE_WAIT_API_POLL 1
-#define IREE_WAIT_API_PPOLL 2
-#define IREE_WAIT_API_EPOLL 3
-#define IREE_WAIT_API_KQUEUE 4
-
// NOTE: we could be tighter here, but we today only have win32 or not-win32.
-#if defined(IREE_PLATFORM_WINDOWS)
-#define IREE_WAIT_API 0 // WFMO used in wait_handle_win32.c
-#else
-
-// TODO(benvanik): EPOLL on android/linux/bsd/etc.
-// TODO(benvanik): KQUEUE on mac/ios.
-#if !defined(OS_IOS) && !defined(__EMSCRIPTEN__)
-#define IREE_WAIT_API IREE_WAIT_API_PPOLL
-#else
-#define IREE_WAIT_API IREE_WAIT_API_POLL
-#endif // insanity
+#if !defined(IREE_PLATFORM_WINDOWS)
#ifdef __cplusplus
extern "C" {
@@ -96,6 +81,6 @@
} // extern "C"
#endif // __cplusplus
-#endif // IREE_PLATFORM_WINDOWS
+#endif // !IREE_PLATFORM_WINDOWS
#endif // IREE_BASE_WAIT_HANDLE_POSIX_H_
diff --git a/iree/base/wait_handle_win32.c b/iree/base/wait_handle_win32.c
index 8f550fd..cf91401 100644
--- a/iree/base/wait_handle_win32.c
+++ b/iree/base/wait_handle_win32.c
@@ -12,11 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "iree/base/tracing.h"
-#include "iree/base/wait_handle.h"
+// NOTE: must be first to ensure that we can define settings for all includes.
+#include "iree/base/wait_handle_impl.h"
#if defined(IREE_PLATFORM_WINDOWS)
+#include "iree/base/tracing.h"
+
//===----------------------------------------------------------------------===//
// Platform utilities
//===----------------------------------------------------------------------===//
diff --git a/iree/build_defs.oss.bzl b/iree/build_defs.oss.bzl
index 4d359e7..33fa984 100644
--- a/iree/build_defs.oss.bzl
+++ b/iree/build_defs.oss.bzl
@@ -14,9 +14,6 @@
"""Common Bazel definitions for IREE."""
-load("@com_github_google_flatbuffers//:build_defs.bzl", "flatbuffer_cc_library")
-load("@rules_cc//cc:defs.bzl", _cc_binary = "cc_binary", _cc_library = "cc_library")
-
# Target to the FileCheck binary.
INTREE_FILECHECK_TARGET = "@llvm-project//llvm:FileCheck"
@@ -40,27 +37,6 @@
"//iree/%s/internal:%s_internal" % (path, basename),
]
-# A platform-sensitive list of dependencies for non-test targets using Vulkan.
-PLATFORM_VULKAN_DEPS = select({
- "//iree/hal/vulkan:native_vk": [],
- "//iree/hal/vulkan:swiftshader_vk": [],
- "//conditions:default": [],
-})
-
-# A platform-sensitive list of dependencies for tests using Vulkan.
-PLATFORM_VULKAN_TEST_DEPS = []
-
-# Driver modules that register themselves at link time.
-IREE_DRIVER_MODULES = [
- "//iree/hal/dylib:dylib_driver_module",
- "//iree/hal/vmla:vmla_driver_module",
- "//iree/hal/vulkan:vulkan_driver_module",
- "//iree/hal/llvmjit:llvmjit_driver_module",
-]
-
-# Aliases to the Starlark cc rules.
-cc_library = _cc_library
-
def iree_build_test(name, targets):
"""Dummy rule to ensure that targets build.
@@ -68,34 +44,6 @@
"""
pass
-# The OSS build currently has issues with generating flatbuffer reflections.
-# It is hard-coded to disabled here (and in iree_flatbuffer_cc_library) until triaged/fixed.
-FLATBUFFER_SUPPORTS_REFLECTIONS = False
-
-def iree_flatbuffer_cc_library(**kwargs):
- """Wrapper for the flatbuffer_cc_library."""
-
- # TODO(laurenzo): The bazel rule for reflections seems broken in OSS
- # builds. Fix it and enable by default.
- flatbuffer_cc_library(
- gen_reflections = False,
- **kwargs
- )
-
-def cc_binary(linkopts = [], **kwargs):
- """Wrapper around low-level cc_binary that adds flags."""
- _cc_binary(
- linkopts = linkopts + select({
- "//iree:iree_is_msvc": [],
- "//conditions:default": [
- # Just include libraries that should be presumed in 2020.
- "-ldl",
- "-lpthread",
- ],
- }),
- **kwargs
- )
-
def iree_cmake_extra_content(content = "", inline = False):
"""Tool for inserting arbitrary content during Bazel->CMake conversion.
diff --git a/iree/compiler/BUILD b/iree/compiler/BUILD
index 01d988d..3d97d17 100644
--- a/iree/compiler/BUILD
+++ b/iree/compiler/BUILD
@@ -17,3 +17,9 @@
features = ["layering_check"],
licenses = ["notice"], # Apache 2.0
)
+
+#===------------------------------------------------------------------------===#
+# Public API
+#===------------------------------------------------------------------------===#
+# TODO(#3817): expose :compiler as the public C API.
+# TODO(#3817): expose :cc as the public C++ wrapper API.
diff --git a/iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.cpp b/iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.cpp
index 448d86e..7e888ed 100644
--- a/iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.cpp
+++ b/iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.cpp
@@ -11,12 +11,14 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
+
+#include "iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.h"
+
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
diff --git a/iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.h b/iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.h
index ddda183..0e0f638 100644
--- a/iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.h
+++ b/iree/compiler/Conversion/CodegenUtils/ForOpCanonicalization.h
@@ -17,8 +17,9 @@
#include <memory>
+#include "mlir/Pass/Pass.h"
+
namespace mlir {
-class FunctionPass;
namespace iree_compiler {
/// An ad-hoc pass to canonicalize selected loop carried dependencies on
diff --git a/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.cpp b/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.cpp
index 44dfac9..452b662 100644
--- a/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.cpp
+++ b/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.cpp
@@ -12,18 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
-#include "iree/compiler/Dialect/HAL/Utils/TypeUtils.h"
+#include "iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h"
+
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/SliceAnalysis.h"
-#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
-#include "mlir/IR/PatternMatch.h"
#define DEBUG_TYPE "workgroup-calculation"
diff --git a/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h b/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h
index 5495c87..4e28779 100644
--- a/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h
+++ b/iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h
@@ -18,33 +18,17 @@
#include <array>
#include <cstdint>
-namespace llvm {
-class StringRef;
-template <typename T>
-class ArrayRef;
-template <typename T>
-class Optional;
-} // namespace llvm
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/HAL/Utils/TypeUtils.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/StringRef.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
namespace mlir {
-class Location;
-class FuncOp;
-struct LogicalResult;
-class PatternRewriter;
-class ConversionPatternRewriter;
-class Value;
-namespace linalg {
-class LinalgOp;
-} // namespace linalg
-
-namespace iree_compiler {
-namespace IREE {
-namespace HAL {
-class InterfaceOp;
-class TensorRewriteAdaptor;
-} // namespace HAL
-} // namespace IREE
-} // namespace iree_compiler
namespace iree_compiler {
/// Generates a function that computes the number of workgroups as
diff --git a/iree/compiler/Conversion/CodegenUtils/MarkerUtils.h b/iree/compiler/Conversion/CodegenUtils/MarkerUtils.h
index b185ae7..5072293 100644
--- a/iree/compiler/Conversion/CodegenUtils/MarkerUtils.h
+++ b/iree/compiler/Conversion/CodegenUtils/MarkerUtils.h
@@ -23,11 +23,10 @@
#define IREE_COMPILER_CONVERSION_CODEGENUTILS_MARKERUTILS_H_
#include "llvm/ADT/ArrayRef.h"
+#include "mlir/IR/Operation.h"
#include "mlir/Support/LLVM.h"
namespace mlir {
-
-class Operation;
namespace iree_compiler {
/// Marker to denote that a linalg operation has been partitioned to
diff --git a/iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h b/iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h
index 6048bea..7c8536c 100644
--- a/iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h
+++ b/iree/compiler/Conversion/CodegenUtils/MatmulCodegenStrategy.h
@@ -24,12 +24,11 @@
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
+#include "mlir/IR/Function.h"
#include "mlir/Support/LLVM.h"
namespace mlir {
-class FuncOp;
-
/// Abstract Transformation class applied in a sequence that also handles state
/// through markers.
struct Transformation {
diff --git a/iree/compiler/Conversion/HLOToHLO/DecomposeHLOClamp.cpp b/iree/compiler/Conversion/HLOToHLO/DecomposeHLOClamp.cpp
index 8a143f5..a8f64c1 100644
--- a/iree/compiler/Conversion/HLOToHLO/DecomposeHLOClamp.cpp
+++ b/iree/compiler/Conversion/HLOToHLO/DecomposeHLOClamp.cpp
@@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
index 7381818..d178038 100644
--- a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
@@ -32,6 +32,8 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
@@ -44,8 +46,6 @@
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp
index 9ae86c7..48707d4 100644
--- a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp
@@ -22,6 +22,8 @@
#include <memory>
#include "iree/compiler/Conversion/HLOToLinalg/Passes.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
@@ -35,8 +37,6 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Conversion/LLVMToLLVM/FastExpConversion.cpp b/iree/compiler/Conversion/LLVMToLLVM/FastExpConversion.cpp
index 749ada4..f4f41d9 100644
--- a/iree/compiler/Conversion/LLVMToLLVM/FastExpConversion.cpp
+++ b/iree/compiler/Conversion/LLVMToLLVM/FastExpConversion.cpp
@@ -31,13 +31,13 @@
LogicalResult matchAndRewrite(LLVM::ExpOp op,
PatternRewriter &rewriter) const override {
- constexpr float ln2Const = 0.693147181;
- constexpr float ln2InvConst = 1.44269504;
+ constexpr float ln2Const = 0.693147181f;
+ constexpr float ln2InvConst = 1.44269504f;
// Least squares polynomial fit computed :
// cValues = np.polyfit(np.linspace(0, math.log(2), 10000), np.exp(x), 4)
- constexpr float cValues[5] = {0.05924867, 0.15514645, 0.50308552,
- 0.99968939, 1.00000721531};
+ constexpr float cValues[5] = {0.05924867f, 0.15514645f, 0.50308552f,
+ 0.99968939f, 1.00000721531f};
auto loc = op.getLoc();
Value x = op.getOperand();
diff --git a/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.h b/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.h
index 793c6f0..f2d7711 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.h
+++ b/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.h
@@ -15,13 +15,11 @@
#include <cstdint>
#include "llvm/ADT/SmallVector.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
namespace mlir {
-class Operation;
-class Value;
-class OpBuilder;
-class Operation;
-
namespace iree_compiler {
enum class TilingLevel {
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/CooperativeMatrixAnalysis.h b/iree/compiler/Conversion/LinalgToSPIRV/CooperativeMatrixAnalysis.h
index f75133e..8ee737a 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/CooperativeMatrixAnalysis.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/CooperativeMatrixAnalysis.h
@@ -23,13 +23,14 @@
// or not.
//
//===----------------------------------------------------------------------===//
+
#ifndef IREE_COMPILER_CONVERSION_LINALGTOSPIRV_COOPERATIVEMATRIXANALYSIS_H_
#define IREE_COMPILER_CONVERSION_LINALGTOSPIRV_COOPERATIVEMATRIXANALYSIS_H_
+
#include "llvm/ADT/DenseSet.h"
+#include "mlir/IR/Operation.h"
namespace mlir {
-class Operation;
-
namespace iree_compiler {
class CooperativeMatrixAnalysis {
@@ -45,6 +46,8 @@
private:
llvm::DenseSet<mlir::Operation *> usesCooperativeMatrix;
};
+
} // namespace iree_compiler
} // namespace mlir
+
#endif // IREE_COMPILER_CONVERSION_LINALGTOSPIRV_COOPERATIVEMATRIXANALYSIS_H_
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h
index fd3c0a2..72ee6d2 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h
@@ -24,28 +24,21 @@
#include <array>
+#include "iree/compiler/Conversion/LinalgToSPIRV/CodeGenOptionUtils.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/FormatVariadic.h"
+#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/IR/Function.h"
#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
namespace mlir {
-class FuncOp;
-struct LogicalResult;
-class Operation;
-class PatternRewriter;
-class ShapedType;
-class Value;
-
-namespace linalg {
-class LinalgDependenceGraph;
-class LinalgOp;
-} // namespace linalg
-namespace iree_compiler {
-struct SPIRVCodegenOptions;
-}
-
namespace iree_compiler {
/// Store the tile sizes to use at different levels of tiling as a vector of
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Utils.h b/iree/compiler/Conversion/LinalgToSPIRV/Utils.h
index ddaecf3..ea1316c 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Utils.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Utils.h
@@ -21,19 +21,19 @@
#define IREE_COMPILER_CONVERSION_LINALGTOSPIRV_UTILS_H_
#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/FoldUtils.h"
namespace mlir {
-class FuncOp;
-class Value;
-class SubViewOp;
-class OperationFolder;
-class OpBuilder;
-struct LogicalResult;
-
namespace iree_compiler {
static constexpr int kNumGPUDims = 3;
+
/// Allocation callback for allocation workgroup local memory.
Optional<Value> allocateWorkgroupMemory(OpBuilder &b, SubViewOp subview,
ArrayRef<Value> boundingSubViewSize,
@@ -63,6 +63,7 @@
/// Updates the workgroup size used for the dispatch region.
LogicalResult updateWorkGroupSize(FuncOp funcOp,
ArrayRef<int64_t> workGroupSize);
+
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/Dialect/Flow/Analysis/Dispatchability.cpp b/iree/compiler/Dialect/Flow/Analysis/Dispatchability.cpp
index 16a05ef..c3dd2f4 100644
--- a/iree/compiler/Dialect/Flow/Analysis/Dispatchability.cpp
+++ b/iree/compiler/Dialect/Flow/Analysis/Dispatchability.cpp
@@ -18,10 +18,10 @@
#include "iree/compiler/Dialect/Flow/Utils/DispatchUtils.h"
#include "llvm/ADT/SetVector.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/SymbolTable.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/ConvertHLOToFlow.cpp b/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/ConvertHLOToFlow.cpp
index 0eb58e5..650f19e 100644
--- a/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/ConvertHLOToFlow.cpp
+++ b/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/ConvertHLOToFlow.cpp
@@ -18,12 +18,12 @@
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp
index 9964eae..049b5c1 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchConfig.cpp
@@ -18,8 +18,8 @@
#include "iree/compiler/Dialect/Flow/Utils/DispatchUtils.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "llvm/Support/Debug.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#define DEBUG_TYPE "iree-detail"
diff --git a/iree/compiler/Dialect/Flow/Transforms/FlattenTuplesInCFG.cpp b/iree/compiler/Dialect/Flow/Transforms/FlattenTuplesInCFG.cpp
index d6bfd26..86f886a 100644
--- a/iree/compiler/Dialect/Flow/Transforms/FlattenTuplesInCFG.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/FlattenTuplesInCFG.cpp
@@ -15,6 +15,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/iterator_range.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
@@ -22,7 +23,6 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/Utils.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp b/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp
index 9565747..0982a12 100644
--- a/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp
@@ -19,6 +19,7 @@
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
@@ -30,7 +31,6 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/Utils.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#define DEBUG_TYPE "iree-dispatch"
diff --git a/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp b/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp
index 039c829..db6856c 100644
--- a/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp
@@ -16,6 +16,9 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
+#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
@@ -24,9 +27,6 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Dialect/Flow/Transforms/IdentifyDispatchRegions.cpp b/iree/compiler/Dialect/Flow/Transforms/IdentifyDispatchRegions.cpp
index d636867..50593ff 100644
--- a/iree/compiler/Dialect/Flow/Transforms/IdentifyDispatchRegions.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/IdentifyDispatchRegions.cpp
@@ -28,6 +28,7 @@
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
@@ -40,7 +41,6 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/Utils.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#define DEBUG_TYPE "iree-dispatch-detail"
diff --git a/iree/compiler/Dialect/Flow/Transforms/LegalizeInputTypes.cpp b/iree/compiler/Dialect/Flow/Transforms/LegalizeInputTypes.cpp
index 0b5d964..3dd2e56 100644
--- a/iree/compiler/Dialect/Flow/Transforms/LegalizeInputTypes.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/LegalizeInputTypes.cpp
@@ -13,6 +13,7 @@
// limitations under the License.
#include "iree/compiler/Dialect/Flow/Conversion/TypeConverter.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
@@ -26,7 +27,6 @@
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Utils.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
index cf093c5..d012ba5 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -18,10 +18,10 @@
#include "iree/compiler/Dialect/Shape/Conversion/Passes.h"
#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
+#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/Passes.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Dialect/Flow/Transforms/PrePostPartitioningConversion.cpp b/iree/compiler/Dialect/Flow/Transforms/PrePostPartitioningConversion.cpp
index f25f67e..d62568e 100644
--- a/iree/compiler/Dialect/Flow/Transforms/PrePostPartitioningConversion.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/PrePostPartitioningConversion.cpp
@@ -19,14 +19,14 @@
#include "iree/compiler/Dialect/Flow/Conversion/TypeConverter.h"
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/Shape/Transforms/Patterns.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Module.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Dialect/Flow/Transforms/RematerializeDispatchConstants.cpp b/iree/compiler/Dialect/Flow/Transforms/RematerializeDispatchConstants.cpp
index fc9d792..431f0e3 100644
--- a/iree/compiler/Dialect/Flow/Transforms/RematerializeDispatchConstants.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/RematerializeDispatchConstants.cpp
@@ -16,6 +16,7 @@
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "llvm/Support/Debug.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
@@ -28,7 +29,6 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/Utils.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Dialect/Flow/Utils/DispatchUtils.cpp b/iree/compiler/Dialect/Flow/Utils/DispatchUtils.cpp
index 506d2d2..eaf4516 100644
--- a/iree/compiler/Dialect/Flow/Utils/DispatchUtils.cpp
+++ b/iree/compiler/Dialect/Flow/Utils/DispatchUtils.cpp
@@ -17,10 +17,10 @@
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
#include "llvm/ADT/SetVector.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Dialect/Flow/Utils/WorkloadUtils.cpp b/iree/compiler/Dialect/Flow/Utils/WorkloadUtils.cpp
index 3b11700..5e19ea9 100644
--- a/iree/compiler/Dialect/Flow/Utils/WorkloadUtils.cpp
+++ b/iree/compiler/Dialect/Flow/Utils/WorkloadUtils.cpp
@@ -20,12 +20,12 @@
#include "iree/compiler/Dialect/Shape/IR/Builders.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/LLVM.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/iree/compiler/Dialect/HAL/IR/HALOps.cpp
index 8c337a0..6b7e466 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOps.cpp
+++ b/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -1392,6 +1392,14 @@
data));
}
+void ExecutableBinaryOp::build(OpBuilder &builder, OperationState &state,
+ uint32_t format, DenseIntElementsAttr data) {
+ ensureTerminator(*state.addRegion(), builder, state.location);
+ state.addAttribute(
+ "format", builder.getIntegerAttr(builder.getIntegerType(32), format));
+ state.addAttribute("data", data);
+}
+
static ParseResult parseExecutableBinaryOp(OpAsmParser &parser,
OperationState *result) {
auto *body = result->addRegion();
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.td b/iree/compiler/Dialect/HAL/IR/HALOps.td
index e294c8d..529cb5f 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -1931,6 +1931,7 @@
let skipDefaultBuilders = 1;
let builders = [
OpBuilderDAG<(ins "uint32_t":$format, "std::vector<uint8_t>":$data)>,
+ OpBuilderDAG<(ins "uint32_t":$format, "DenseIntElementsAttr":$data)>,
];
let extraClassDeclaration = [{
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/BUILD b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/BUILD
index 18e5566..51983ae 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/BUILD
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/BUILD
@@ -39,11 +39,13 @@
],
deps = [
":LinkerTool",
+ "//iree/base:flatcc",
"//iree/compiler/Dialect/HAL/Target",
"//iree/compiler/Dialect/HAL/Target/LLVM:LLVMBaseTarget",
"//iree/compiler/Dialect/HAL/Target/LLVM:LLVMIRPasses",
"//iree/compiler/Dialect/HAL/Target/LLVM:LLVMTargetOptions",
- "//iree/schemas:dylib_executable_def_cc_fbs",
+ "//iree/compiler/Utils",
+ "//iree/schemas:dylib_executable_def_c_fbs",
"@llvm-project//llvm:AArch64AsmParser",
"@llvm-project//llvm:AArch64CodeGen",
"@llvm-project//llvm:ARMAsmParser",
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/CMakeLists.txt
index f144761..0a97912 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/CMakeLists.txt
@@ -37,11 +37,13 @@
LLVMX86CodeGen
MLIRLLVMIR
MLIRTargetLLVMIR
+ iree::base::flatcc
iree::compiler::Dialect::HAL::Target
iree::compiler::Dialect::HAL::Target::LLVM::LLVMBaseTarget
iree::compiler::Dialect::HAL::Target::LLVM::LLVMIRPasses
iree::compiler::Dialect::HAL::Target::LLVM::LLVMTargetOptions
- iree::schemas::dylib_executable_def_cc_fbs
+ iree::compiler::Utils
+ iree::schemas::dylib_executable_def_c_fbs
PUBLIC
)
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LLVMAOTTarget.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LLVMAOTTarget.cpp
index 337ba7c..03224a8 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LLVMAOTTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LLVMAOTTarget.cpp
@@ -20,7 +20,8 @@
#include "iree/compiler/Dialect/HAL/Target/LLVM/LLVMBaseTarget.h"
#include "iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRPasses.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
-#include "iree/schemas/dylib_executable_def_generated.h"
+#include "iree/compiler/Utils/FlatbufferUtils.h"
+#include "iree/schemas/dylib_executable_def_builder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/TargetSelect.h"
@@ -70,24 +71,13 @@
// At this moment we are leaving MLIR LLVM dialect land translating module
// into target independent LLVMIR.
- auto llvmModule =
- mlir::translateModuleToLLVMIR(targetOp.getInnerModule(), context);
+ auto llvmModule = mlir::translateModuleToLLVMIR(targetOp.getInnerModule(),
+ context, libraryName);
if (!llvmModule) {
return targetOp.emitError() << "failed to translate the MLIR LLVM "
"dialect to the native llvm::Module";
}
- // Export all entry points such that they are accessible on the dynamic
- // libraries we generate.
- iree::DyLibExecutableDefT dyLibExecutableDef;
- SmallVector<StringRef, 8> entryPointNames;
- for (auto entryPointOp :
- targetOp.getBlock().getOps<ExecutableEntryPointOp>()) {
- dyLibExecutableDef.entry_points.push_back(
- std::string(entryPointOp.sym_name()));
- entryPointNames.push_back(entryPointOp.sym_name());
- }
-
// Try to grab a linker tool based on the options (and target environment).
auto linkerTool = LinkerTool::getForTarget(targetTriple, options_);
if (!linkerTool) {
@@ -98,6 +88,9 @@
// Configure the module with any code generation options required later by
// linking (such as initializer functions).
+ auto entryPointNames = llvm::to_vector<8>(
+ llvm::map_range(targetOp.getBlock().getOps<ExecutableEntryPointOp>(),
+ [&](auto op) { return op.getName(); }));
if (failed(
linkerTool->configureModule(llvmModule.get(), entryPointNames))) {
return targetOp.emitError()
@@ -152,21 +145,6 @@
<< linkerTool->getToolPath();
}
auto &linkArtifacts = linkArtifactsOr.getValue();
- dyLibExecutableDef.library_embedded =
- linkArtifacts.libraryFile.read().getValueOr(std::vector<int8_t>());
- if (dyLibExecutableDef.library_embedded.empty()) {
- return targetOp.emitError() << "failed to read back dylib temp file at "
- << linkArtifacts.libraryFile.path;
- }
-
- if (options_.debugSymbols && linkArtifacts.debugFile.outputFile) {
- dyLibExecutableDef.debug_database_embedded =
- linkArtifacts.debugFile.read().getValue();
- assert(!dyLibExecutableDef.debug_database_embedded.empty());
- dyLibExecutableDef.debug_database_filename =
- llvm::sys::path::filename(linkArtifacts.debugFile.path).str();
- }
-
if (options_.keepLinkerArtifacts) {
mlir::emitRemark(targetOp.getLoc())
<< "Linker artifacts for " << targetOp.getName() << " preserved:\n"
@@ -174,20 +152,49 @@
linkArtifacts.keepAllFiles();
}
- ::flatbuffers::FlatBufferBuilder fbb;
- auto executableOffset =
- iree::DyLibExecutableDef::Pack(fbb, &dyLibExecutableDef);
- iree::FinishDyLibExecutableDefBuffer(fbb, executableOffset);
- std::vector<uint8_t> bytes;
- bytes.resize(fbb.GetSize());
- std::memcpy(bytes.data(), fbb.GetBufferPointer(), bytes.size());
+ // Embed debug symbols at the end of the flatbuffer by adding them first in
+ // the bottom-up builder.
+ FlatbufferBuilder builder;
+ flatbuffers_uint8_vec_ref_t debugDatabaseRef = 0;
+ flatbuffers_string_ref_t debugDatabaseFilenameRef = 0;
+ if (options_.debugSymbols && linkArtifacts.debugFile.outputFile) {
+ debugDatabaseRef = builder.streamUint8Vec([&](raw_ostream &stream) {
+ return linkArtifacts.debugFile.readInto(stream);
+ });
+ debugDatabaseFilenameRef = builder.createString(
+ llvm::sys::path::filename(linkArtifacts.debugFile.path));
+ }
+
+ // Embed entire dynamic library output.
+ flatbuffers_uint8_vec_ref_t libraryEmbeddedRef =
+ builder.streamUint8Vec([&](raw_ostream &stream) {
+ return linkArtifacts.libraryFile.readInto(stream);
+ });
+ if (!libraryEmbeddedRef) {
+ return targetOp.emitError() << "failed to read back dylib temp file at "
+ << linkArtifacts.libraryFile.path;
+ }
+
+ // Entry point names up front.
+ // TODO(#3580): these won't be needed in the executable_library world.
+ auto entryPointsRef = builder.createStringVec(llvm::map_range(
+ targetOp.getBlock().getOps<ExecutableEntryPointOp>(),
+ [&](ExecutableEntryPointOp op) { return op.getName(); }));
+
+ iree_DyLibExecutableDef_start_as_root(builder);
+ iree_DyLibExecutableDef_entry_points_add(builder, entryPointsRef);
+ iree_DyLibExecutableDef_library_embedded_add(builder, libraryEmbeddedRef);
+ iree_DyLibExecutableDef_debug_database_filename_add(
+ builder, debugDatabaseFilenameRef);
+ iree_DyLibExecutableDef_debug_database_embedded_add(builder,
+ debugDatabaseRef);
+ iree_DyLibExecutableDef_end_as_root(builder);
// Add the binary data to the target executable.
executableBuilder.create<IREE::HAL::ExecutableBinaryOp>(
targetOp.getLoc(),
static_cast<uint32_t>(IREE::HAL::ExecutableFormat::DyLib),
- std::move(bytes));
-
+ builder.getBufferAttr(executableBuilder.getContext()));
return success();
}
};
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LinkerTool.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LinkerTool.cpp
index 083d281..16a8526 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LinkerTool.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LinkerTool.cpp
@@ -67,6 +67,22 @@
return resultBuffer;
}
+bool Artifact::readInto(raw_ostream &targetStream) const {
+ // NOTE: we could make this much more efficient if we read in the file a
+ // chunk at a time and piped it along to targetStream. I couldn't find
+ // anything in LLVM that did this, for some crazy reason, but since we are
+ // dealing with binaries that can be 10+MB here it'd be nice if we could avoid
+ // reading them all into memory.
+ auto fileData = llvm::MemoryBuffer::getFile(path);
+ if (!fileData) {
+ llvm::errs() << "failed to load library output file '" << path << "'\n";
+ return false;
+ }
+ auto sourceBuffer = fileData.get()->getBuffer();
+ targetStream.write(sourceBuffer.data(), sourceBuffer.size());
+ return true;
+}
+
void Artifact::close() { outputFile->os().close(); }
void Artifacts::keepAllFiles() {
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LinkerTool.h b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LinkerTool.h
index 6f3c39a..8cf4484 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LinkerTool.h
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/LinkerTool.h
@@ -50,6 +50,9 @@
// Reads the artifact file contents as bytes.
Optional<std::vector<int8_t>> read() const;
+ // Reads the artifact file and writes it into the given |targetStream|.
+ bool readInto(raw_ostream& targetStream) const;
+
// Closes the ostream of the file while preserving the temporary entry on
// disk. Use this if files need to be modified by external tools that may
// require exclusive access.
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/UnixLinkerTool.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/UnixLinkerTool.cpp
index fcd1986..5e14e8a 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/UnixLinkerTool.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/AOT/internal/UnixLinkerTool.cpp
@@ -36,7 +36,18 @@
LogicalResult configureModule(llvm::Module *llvmModule,
ArrayRef<StringRef> entryPointNames) override {
- // Possibly a no-op in ELF files; needs to be verified.
+ // Enable frame pointers to ensure that stack unwinding works, e.g. in
+ // Tracy. In principle this could also be achieved by enabling unwind
+ // tables, but we tried that and that didn't work in Tracy (which uses
+ // libbacktrace), while enabling frame pointers worked.
+ // https://github.com/google/iree/issues/3957
+ for (auto &func : *llvmModule) {
+ auto attrs = func.getAttributes();
+ attrs = attrs.addAttribute(llvmModule->getContext(),
+ llvm::AttributeList::FunctionIndex,
+ "frame-pointer", "all");
+ func.setAttributes(attrs);
+ }
return success();
}
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/IR/BUILD b/iree/compiler/Dialect/HAL/Target/LLVM/IR/BUILD
index eed06b7..10b20b5 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/IR/BUILD
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/IR/BUILD
@@ -37,12 +37,14 @@
"LLVMIRTarget.h",
],
deps = [
+ "//iree/base:flatcc",
"//iree/compiler/Dialect/HAL/Target",
"//iree/compiler/Dialect/HAL/Target/LLVM:LLVMBaseTarget",
"//iree/compiler/Dialect/HAL/Target/LLVM:LLVMIRPasses",
"//iree/compiler/Dialect/HAL/Target/LLVM:LLVMTargetOptions",
"//iree/compiler/Dialect/Shape/IR",
- "//iree/schemas:llvmir_executable_def_cc_fbs",
+ "//iree/compiler/Utils",
+ "//iree/schemas:llvmir_executable_def_c_fbs",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:TargetLLVMIR",
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/IR/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/LLVM/IR/CMakeLists.txt
index c490eb0..650ce95 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/IR/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/IR/CMakeLists.txt
@@ -29,11 +29,13 @@
LLVMCore
LLVMSupport
MLIRTargetLLVMIR
+ iree::base::flatcc
iree::compiler::Dialect::HAL::Target
iree::compiler::Dialect::HAL::Target::LLVM::LLVMBaseTarget
iree::compiler::Dialect::HAL::Target::LLVM::LLVMIRPasses
iree::compiler::Dialect::HAL::Target::LLVM::LLVMTargetOptions
iree::compiler::Dialect::Shape::IR
- iree::schemas::llvmir_executable_def_cc_fbs
+ iree::compiler::Utils
+ iree::schemas::llvmir_executable_def_c_fbs
PUBLIC
)
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/IR/LLVMIRTarget.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/IR/LLVMIRTarget.cpp
index b18ecbf..9567eae 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/IR/LLVMIRTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/IR/LLVMIRTarget.cpp
@@ -17,7 +17,8 @@
#include "iree/compiler/Dialect/HAL/Target/LLVM/LLVMBaseTarget.h"
#include "iree/compiler/Dialect/HAL/Target/LLVM/LLVMIRPasses.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
-#include "iree/schemas/llvmir_executable_def_generated.h"
+#include "iree/compiler/Utils/FlatbufferUtils.h"
+#include "iree/schemas/llvmir_executable_def_builder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/TargetSelect.h"
@@ -63,15 +64,6 @@
return targetOp.emitError("Failed to translate executable to LLVM IR");
}
- // Create invocation function an populate entry_points.
- iree::LLVMIRExecutableDefT llvmIrExecutableDef;
- auto entryPointOps =
- targetOp.getBlock().getOps<IREE::HAL::ExecutableEntryPointOp>();
- for (auto entryPointOp : entryPointOps) {
- llvmIrExecutableDef.entry_points.push_back(
- std::string(entryPointOp.sym_name()));
- }
-
// LLVMIR opt passes.
auto targetMachine = createTargetMachine(options_);
if (!targetMachine) {
@@ -86,30 +78,27 @@
"Can't build LLVMIR opt passes for ExecutableOp module");
}
- // Serialize LLVM module.
- std::string bufferString;
- llvm::raw_string_ostream ostream(bufferString);
- llvmModule->print(ostream, nullptr);
- ostream.flush();
+ // Serialize LLVM module directly into flatbuffer.
+ FlatbufferBuilder builder;
+ auto bitcodeModuleRef = builder.streamUint8Vec([&](raw_ostream &stream) {
+ llvmModule->print(stream, nullptr);
+ return true;
+ });
- // Creates executable bytes.
- llvmIrExecutableDef.llvmir_module = {bufferString.begin(),
- bufferString.end()};
+ auto entryPointsRef = builder.createStringVec(llvm::map_range(
+ targetOp.getBlock().getOps<ExecutableEntryPointOp>(),
+ [&](ExecutableEntryPointOp op) { return op.sym_name(); }));
- ::flatbuffers::FlatBufferBuilder fbb;
- auto executableOffset =
- iree::LLVMIRExecutableDef::Pack(fbb, &llvmIrExecutableDef);
- iree::FinishLLVMIRExecutableDefBuffer(fbb, executableOffset);
- std::vector<uint8_t> bytes;
- bytes.resize(fbb.GetSize());
- std::memcpy(bytes.data(), fbb.GetBufferPointer(), bytes.size());
+ iree_LLVMIRExecutableDef_start_as_root(builder);
+ iree_LLVMIRExecutableDef_entry_points_add(builder, entryPointsRef);
+ iree_LLVMIRExecutableDef_bitcode_module_add(builder, bitcodeModuleRef);
+ iree_LLVMIRExecutableDef_end_as_root(builder);
// Add the binary data to the target executable.
executableBuilder.create<IREE::HAL::ExecutableBinaryOp>(
targetOp.getLoc(),
static_cast<uint32_t>(IREE::HAL::ExecutableFormat::LLVM),
- std::move(bytes));
-
+ builder.getBufferAttr(executableBuilder.getContext()));
return success();
}
};
diff --git a/iree/compiler/Dialect/HAL/Target/MetalSPIRV/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/MetalSPIRV/CMakeLists.txt
index 60352f0..2934ce9 100644
--- a/iree/compiler/Dialect/HAL/Target/MetalSPIRV/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Target/MetalSPIRV/CMakeLists.txt
@@ -32,10 +32,11 @@
MLIRSPIRV
MLIRSPIRVSerialization
MLIRSupport
- flatbuffers
+ iree::base::flatcc
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::Target::SPIRVCommon
- iree::schemas::metal_executable_def_cc_fbs
+ iree::compiler::Utils
+ iree::schemas::metal_executable_def_c_fbs
PUBLIC
)
@@ -48,8 +49,7 @@
"SPIRVToMSL.cpp"
DEPS
LLVMSupport
+ MLIRSupport
spirv-cross-msl
- INCLUDES
- ${PROJECT_SOURCE_DIR}/third_party
PUBLIC
)
diff --git a/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp b/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp
index 612d93b..58e2172 100644
--- a/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp
@@ -14,12 +14,12 @@
#include "iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.h"
-#include "flatbuffers/flatbuffers.h"
#include "iree/compiler/Conversion/Common/Attributes.h"
#include "iree/compiler/Dialect/HAL/Target/MetalSPIRV/SPIRVToMSL.h"
#include "iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
-#include "iree/schemas/metal_executable_def_generated.h"
+#include "iree/compiler/Utils/FlatbufferUtils.h"
+#include "iree/schemas/metal_executable_def_builder.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
@@ -50,16 +50,6 @@
spirv::getDefaultResourceLimits(context));
}
-// Returns a list of entry point names matching the expected export ordinals.
-static std::vector<std::string> populateEntryPointNames(
- spirv::ModuleOp spvModuleOp) {
- std::vector<std::string> entryPointNames;
- spvModuleOp.walk([&](spirv::EntryPointOp entryPointOp) {
- entryPointNames.push_back(std::string(entryPointOp.fn()));
- });
- return entryPointNames;
-}
-
class MetalSPIRVTargetBackend : public SPIRVTargetBackend {
public:
MetalSPIRVTargetBackend(MetalSPIRVTargetOptions options)
@@ -88,17 +78,18 @@
// The runtime use ordinals instead of names but Metal requires function
// names for constructing pipeline states. Get an ordered list of the entry
// point names.
- std::vector<std::string> entryPoints;
+ SmallVector<StringRef, 8> entryPointNames;
if (auto scheduleAttr = innerModuleOp.getAttrOfType<ArrayAttr>(
iree_compiler::getEntryPointScheduleAttrName())) {
// We have multiple entry points in this module. Make sure the order
// specified in the schedule attribute is respected.
for (Attribute entryPoint : scheduleAttr) {
- entryPoints.emplace_back(
- entryPoint.cast<StringAttr>().getValue().str());
+ entryPointNames.push_back(entryPoint.cast<StringAttr>().getValue());
}
} else {
- entryPoints = populateEntryPointNames(spvModuleOp);
+ spvModuleOp.walk([&](spirv::EntryPointOp entryPointOp) {
+ entryPointNames.push_back(entryPointOp.fn());
+ });
}
// 1. Serialize the spirv::ModuleOp into binary format.
@@ -109,7 +100,7 @@
// 2. Cross compile SPIR-V to MSL source code.
llvm::SmallVector<MetalShader, 2> mslShaders;
- for (const std::string &entryPoint : entryPoints) {
+ for (const auto &entryPoint : entryPointNames) {
llvm::Optional<MetalShader> mslShader = crossCompileSPIRVToMSL(
// We can use ArrayRef here given spvBinary reserves 0 bytes on stack.
llvm::makeArrayRef(spvBinary.data(), spvBinary.size()), entryPoint);
@@ -129,30 +120,32 @@
// to invoke them in C++.
// 4. Pack the MTLLibrary and metadata into a flatbuffer.
- iree::MetalExecutableDefT metalExecutableDef;
- metalExecutableDef.entry_points = entryPoints;
- for (auto &shader : mslShaders) {
- metalExecutableDef.shader_sources.push_back(std::move(shader.source));
- const auto &sizes = shader.threadgroupSize;
- metalExecutableDef.threadgroup_sizes.push_back(
- {sizes.x, sizes.y, sizes.z});
- }
+ FlatbufferBuilder builder;
- // Pack the executable definition and get the bytes with the proper header.
- // The header is used to verify the contents at runtime.
- ::flatbuffers::FlatBufferBuilder fbb;
- auto executableOffset =
- iree::MetalExecutableDef::Pack(fbb, &metalExecutableDef);
- iree::FinishMetalExecutableDefBuffer(fbb, executableOffset);
- std::vector<uint8_t> bytes;
- bytes.resize(fbb.GetSize());
- std::memcpy(bytes.data(), fbb.GetBufferPointer(), bytes.size());
+ auto shaderSourcesRef = builder.createStringVec(llvm::map_range(
+ mslShaders, [&](const MetalShader &shader) { return shader.source; }));
+
+ iree_MetalThreadgroupSize_vec_start(builder);
+ for (auto &shader : mslShaders) {
+ iree_MetalThreadgroupSize_vec_push_create(
+ builder, shader.threadgroupSize.x, shader.threadgroupSize.y,
+ shader.threadgroupSize.z);
+ }
+ auto threadgroupSizesRef = iree_MetalThreadgroupSize_vec_end(builder);
+
+ auto entryPointNamesRef = builder.createStringVec(entryPointNames);
+
+ iree_MetalExecutableDef_start_as_root(builder);
+ iree_MetalExecutableDef_entry_points_add(builder, entryPointNamesRef);
+ iree_MetalExecutableDef_threadgroup_sizes_add(builder, threadgroupSizesRef);
+ iree_MetalExecutableDef_shader_sources_add(builder, shaderSourcesRef);
+ iree_MetalExecutableDef_end_as_root(builder);
// 5. Add the binary data to the target executable.
executableBuilder.create<IREE::HAL::ExecutableBinaryOp>(
targetOp.getLoc(),
static_cast<uint32_t>(IREE::HAL::ExecutableFormat::Metal),
- std::move(bytes));
+ builder.getBufferAttr(executableBuilder.getContext()));
return success();
}
diff --git a/iree/compiler/Dialect/HAL/Target/MetalSPIRV/SPIRVToMSL.cpp b/iree/compiler/Dialect/HAL/Target/MetalSPIRV/SPIRVToMSL.cpp
index 5d89aa3..0d23fd6 100644
--- a/iree/compiler/Dialect/HAL/Target/MetalSPIRV/SPIRVToMSL.cpp
+++ b/iree/compiler/Dialect/HAL/Target/MetalSPIRV/SPIRVToMSL.cpp
@@ -19,7 +19,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
-#include "spirv_cross/spirv_msl.hpp"
+#include "third_party/spirv_cross/spirv_msl.hpp"
#define DEBUG_TYPE "spirv-to-msl"
@@ -32,9 +32,9 @@
using CompilerMSL::CompilerMSL;
MetalShader::ThreadGroupSize getWorkgroupSizeForEntryPoint(
- const std::string& entryName) {
+ StringRef entryName) {
const auto& entryPoint = get_entry_point(
- entryName, spv::ExecutionModel::ExecutionModelGLCompute);
+ entryName.str(), spv::ExecutionModel::ExecutionModelGLCompute);
const auto& workgroupSize = entryPoint.workgroup_size;
// TODO(antiagainst): support specialization constant.
if (workgroupSize.constant != 0) return {0, 0, 0};
@@ -104,13 +104,13 @@
} // namespace
llvm::Optional<MetalShader> crossCompileSPIRVToMSL(
- llvm::ArrayRef<uint32_t> spvBinary, const std::string& entryPoint) {
+ llvm::ArrayRef<uint32_t> spvBinary, StringRef entryPoint) {
SPIRVToMSLCompiler spvCrossCompiler(spvBinary.data(), spvBinary.size());
// All spirv-cross operations work on the current entry point. It should be
// set right after the cross compiler construction.
spvCrossCompiler.set_entry_point(
- entryPoint, spv::ExecutionModel::ExecutionModelGLCompute);
+ entryPoint.str(), spv::ExecutionModel::ExecutionModelGLCompute);
// Explicitly set the argument buffer index for each SPIR-V resource variable.
auto descriptors = spvCrossCompiler.getBufferSetBindingPairs();
diff --git a/iree/compiler/Dialect/HAL/Target/MetalSPIRV/SPIRVToMSL.h b/iree/compiler/Dialect/HAL/Target/MetalSPIRV/SPIRVToMSL.h
index ab805d1..79cb609 100644
--- a/iree/compiler/Dialect/HAL/Target/MetalSPIRV/SPIRVToMSL.h
+++ b/iree/compiler/Dialect/HAL/Target/MetalSPIRV/SPIRVToMSL.h
@@ -21,6 +21,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/StringRef.h"
+#include "mlir/Support/LLVM.h"
namespace mlir {
namespace iree_compiler {
@@ -37,7 +38,7 @@
// Cross compiles SPIR-V into Metal Shading Language source code for the
// compute shader with |entryPoint|. Returns llvm::None on failure.
llvm::Optional<MetalShader> crossCompileSPIRVToMSL(
- llvm::ArrayRef<uint32_t> spvBinary, const std::string& entryPoint);
+ llvm::ArrayRef<uint32_t> spvBinary, StringRef entryPoint);
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/BUILD b/iree/compiler/Dialect/HAL/Target/VMLA/BUILD
index f3f0f08..7276df9 100644
--- a/iree/compiler/Dialect/HAL/Target/VMLA/BUILD
+++ b/iree/compiler/Dialect/HAL/Target/VMLA/BUILD
@@ -37,6 +37,7 @@
"VMLATarget.h",
],
deps = [
+ "//iree/base:flatcc",
"//iree/compiler/Dialect/Flow/IR",
"//iree/compiler/Dialect/HAL/Target",
"//iree/compiler/Dialect/VM/Conversion",
@@ -45,8 +46,8 @@
"//iree/compiler/Dialect/VM/Transforms",
"//iree/compiler/Dialect/VMLA/IR:VMLADialect",
"//iree/compiler/Dialect/VMLA/Transforms",
- "//iree/schemas:vmla_executable_def_cc_fbs",
- "@com_github_google_flatbuffers//:flatbuffers",
+ "//iree/compiler/Utils",
+ "//iree/schemas:vmla_executable_def_c_fbs",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/VMLA/CMakeLists.txt
index 43de52b..8f6194d 100644
--- a/iree/compiler/Dialect/HAL/Target/VMLA/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Target/VMLA/CMakeLists.txt
@@ -30,7 +30,7 @@
MLIRIR
MLIRPass
MLIRSupport
- flatbuffers
+ iree::base::flatcc
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::Target
iree::compiler::Dialect::VM::Conversion
@@ -39,6 +39,7 @@
iree::compiler::Dialect::VM::Transforms
iree::compiler::Dialect::VMLA::IR::VMLADialect
iree::compiler::Dialect::VMLA::Transforms
- iree::schemas::vmla_executable_def_cc_fbs
+ iree::compiler::Utils
+ iree::schemas::vmla_executable_def_c_fbs
PUBLIC
)
diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp b/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp
index c1afd0b..1f40ef5 100644
--- a/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp
@@ -14,7 +14,6 @@
#include "iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.h"
-#include "flatbuffers/flatbuffers.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
#include "iree/compiler/Dialect/VM/Conversion/ConversionTarget.h"
@@ -23,11 +22,10 @@
#include "iree/compiler/Dialect/VM/Transforms/Passes.h"
#include "iree/compiler/Dialect/VMLA/IR/VMLADialect.h"
#include "iree/compiler/Dialect/VMLA/Transforms/Passes.h"
-#include "iree/schemas/vmla_executable_def_generated.h"
-#include "llvm/ADT/ScopeExit.h"
+#include "iree/compiler/Utils/FlatbufferUtils.h"
+#include "iree/schemas/vmla_executable_def_builder.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
-#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OperationSupport.h"
@@ -216,35 +214,30 @@
LogicalResult serializeExecutable(IREE::HAL::ExecutableTargetOp targetOp,
OpBuilder &executableBuilder) override {
- // Serialize the VM module to bytes.
- std::string byteStreamValue;
- llvm::raw_string_ostream byte_stream(byteStreamValue);
+ // Serialize the VM module to bytes directly into a flatbuffer.
+ FlatbufferBuilder builder;
IREE::VM::BytecodeTargetOptions bytecodeOptions;
- if (failed(translateModuleToBytecode(targetOp.getInnerModule(),
- bytecodeOptions, byte_stream))) {
+ auto dataRef = builder.streamUint8Vec([&](raw_ostream &stream) {
+ return succeeded(translateModuleToBytecode(targetOp.getInnerModule(),
+ bytecodeOptions, stream));
+ });
+ if (!dataRef) {
return targetOp.emitError() << "failed to serialize converted VM module";
}
// Pack the executable definition and get the bytes with the proper header.
// The header is used to verify the contents at runtime.
- ::flatbuffers::FlatBufferBuilder fbb;
- iree::VMLAExecutableDefT vmlaExecutableDef;
- vmlaExecutableDef.bytecode_module.resize(byteStreamValue.size());
- std::memcpy(vmlaExecutableDef.bytecode_module.data(),
- byteStreamValue.data(), byteStreamValue.size());
- auto executableOffset =
- iree::VMLAExecutableDef::Pack(fbb, &vmlaExecutableDef);
- iree::FinishVMLAExecutableDefBuffer(fbb, executableOffset);
- std::vector<uint8_t> bytes;
- bytes.resize(fbb.GetSize());
- std::memcpy(bytes.data(), fbb.GetBufferPointer(), bytes.size());
+ iree_VMLAExecutableDef_start_as_root(builder);
+ iree_VMLAExecutableDef_bytecode_module_add(builder, dataRef);
+ iree_VMLAExecutableDef_end_as_root(builder);
// Add the binary data to the target executable.
+ // NOTE: this snapshots the flatbuffer builder data at the time it is called
+ // and future changes will not be observed.
executableBuilder.create<IREE::HAL::ExecutableBinaryOp>(
targetOp.getLoc(),
static_cast<uint32_t>(IREE::HAL::ExecutableFormat::VMLA),
- std::move(bytes));
-
+ builder.getBufferAttr(executableBuilder.getContext()));
return success();
}
diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/BUILD b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/BUILD
index db09599..4062aa6 100644
--- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/BUILD
+++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/BUILD
@@ -37,6 +37,7 @@
"VulkanSPIRVTarget.h",
],
deps = [
+ "//iree/base:flatcc",
"//iree/compiler/Conversion/Common",
"//iree/compiler/Conversion/LinalgToSPIRV",
"//iree/compiler/Conversion/LinalgToSPIRV:CodeGenOptionUtils",
@@ -45,8 +46,8 @@
"//iree/compiler/Dialect/HAL/Target/SPIRVCommon",
"//iree/compiler/Dialect/Vulkan/IR",
"//iree/compiler/Dialect/Vulkan/Utils",
- "//iree/schemas:spirv_executable_def_cc_fbs",
- "@com_github_google_flatbuffers//:flatbuffers",
+ "//iree/compiler/Utils",
+ "//iree/schemas:spirv_executable_def_c_fbs",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Affine",
"@llvm-project//mlir:GPUDialect",
diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/CMakeLists.txt
index 8fba6f5..c230f04 100644
--- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/CMakeLists.txt
@@ -36,7 +36,7 @@
MLIRSPIRVSerialization
MLIRSupport
MLIRVector
- flatbuffers
+ iree::base::flatcc
iree::compiler::Conversion::Common
iree::compiler::Conversion::LinalgToSPIRV
iree::compiler::Conversion::LinalgToSPIRV::CodeGenOptionUtils
@@ -45,6 +45,7 @@
iree::compiler::Dialect::HAL::Target::SPIRVCommon
iree::compiler::Dialect::Vulkan::IR
iree::compiler::Dialect::Vulkan::Utils
- iree::schemas::spirv_executable_def_cc_fbs
+ iree::compiler::Utils
+ iree::schemas::spirv_executable_def_c_fbs
PUBLIC
)
diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
index b678f36..166f979 100644
--- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
@@ -14,7 +14,6 @@
#include "iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.h"
-#include "flatbuffers/flatbuffers.h"
#include "iree/compiler/Conversion/Common/Attributes.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/CodeGenOptionUtils.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
@@ -23,7 +22,8 @@
#include "iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.h"
#include "iree/compiler/Dialect/Vulkan/IR/VulkanDialect.h"
#include "iree/compiler/Dialect/Vulkan/Utils/TargetEnvUtils.h"
-#include "iree/schemas/spirv_executable_def_generated.h"
+#include "iree/compiler/Utils/FlatbufferUtils.h"
+#include "iree/schemas/spirv_executable_def_builder.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -86,16 +86,6 @@
return {};
}
-// Returns a list of entry point names matching the expected export ordinals.
-static std::vector<std::string> populateEntryPointNames(
- spirv::ModuleOp spvModuleOp) {
- std::vector<std::string> entryPointNames;
- spvModuleOp.walk([&](spirv::EntryPointOp entryPointOp) {
- entryPointNames.push_back(std::string(entryPointOp.fn()));
- });
- return entryPointNames;
-}
-
class VulkanSPIRVTargetBackend : public SPIRVTargetBackend {
public:
VulkanSPIRVTargetBackend(VulkanSPIRVTargetOptions options)
@@ -135,53 +125,47 @@
LogicalResult serializeExecutable(IREE::HAL::ExecutableTargetOp targetOp,
OpBuilder &executableBuilder) override {
- iree::SpirVExecutableDefT spirvExecutableDef;
-
ModuleOp innerModuleOp = targetOp.getInnerModule();
auto spvModuleOp = *innerModuleOp.getOps<spirv::ModuleOp>().begin();
+ // Serialize the spirv::ModuleOp into the binary that we will embed in the
+ // final flatbuffer.
+ FlatbufferBuilder builder;
+ SmallVector<uint32_t, 256> spvBinary;
+ if (failed(spirv::serialize(spvModuleOp, spvBinary)) || spvBinary.empty()) {
+ return targetOp.emitError() << "failed to serialize spv.module";
+ }
+ auto spvCodeRef = flatbuffers_uint32_vec_create(builder, spvBinary.data(),
+ spvBinary.size());
+
// The sequencer and runtime use ordinals instead of names. We provide the
// list of entry point names here that are then passed in
// VkShaderModuleCreateInfo.
+ SmallVector<StringRef, 8> entryPointNames;
if (auto scheduleAttr = innerModuleOp.getAttrOfType<ArrayAttr>(
iree_compiler::getEntryPointScheduleAttrName())) {
// We have multiple entry points in this module. Make sure the order
// specified in the schedule attribute is respected.
for (Attribute entryPoint : scheduleAttr) {
- spirvExecutableDef.entry_points.emplace_back(
- entryPoint.cast<StringAttr>().getValue().str());
+ entryPointNames.push_back(entryPoint.cast<StringAttr>().getValue());
}
} else {
- spirvExecutableDef.entry_points = populateEntryPointNames(spvModuleOp);
+ spvModuleOp.walk([&](spirv::EntryPointOp entryPointOp) {
+ entryPointNames.push_back(entryPointOp.fn());
+ });
}
+ auto entryPointsRef = builder.createStringVec(entryPointNames);
- // Serialize the spirv::ModuleOp into the binary that we will embed in the
- // final flatbuffer.
- SmallVector<uint32_t, 256> spvBinary;
- if (failed(spirv::serialize(spvModuleOp, spvBinary))) {
- return targetOp.emitError() << "failed to serialize spv.module";
- }
- spirvExecutableDef.code = {spvBinary.begin(), spvBinary.end()};
- if (spirvExecutableDef.code.empty()) {
- return targetOp.emitError()
- << "failed to translate and serialize SPIR-V executable";
- }
-
- // Pack the executable definition and get the bytes with the proper header.
- // The header is used to verify the contents at runtime.
- ::flatbuffers::FlatBufferBuilder fbb;
- auto executableOffset =
- iree::SpirVExecutableDef::Pack(fbb, &spirvExecutableDef);
- iree::FinishSpirVExecutableDefBuffer(fbb, executableOffset);
- std::vector<uint8_t> bytes;
- bytes.resize(fbb.GetSize());
- std::memcpy(bytes.data(), fbb.GetBufferPointer(), bytes.size());
+ iree_SpirVExecutableDef_start_as_root(builder);
+ iree_SpirVExecutableDef_entry_points_add(builder, entryPointsRef);
+ iree_SpirVExecutableDef_code_add(builder, spvCodeRef);
+ iree_SpirVExecutableDef_end_as_root(builder);
// Add the binary data to the target executable.
executableBuilder.create<IREE::HAL::ExecutableBinaryOp>(
targetOp.getLoc(),
static_cast<uint32_t>(IREE::HAL::ExecutableFormat::SpirV),
- std::move(bytes));
+ builder.getBufferAttr(executableBuilder.getContext()));
return success();
}
diff --git a/iree/compiler/Dialect/HAL/Utils/TypeUtils.cpp b/iree/compiler/Dialect/HAL/Utils/TypeUtils.cpp
index 18dc06c..65e4663 100644
--- a/iree/compiler/Dialect/HAL/Utils/TypeUtils.cpp
+++ b/iree/compiler/Dialect/HAL/Utils/TypeUtils.cpp
@@ -21,13 +21,13 @@
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "llvm/ADT/Optional.h"
#include "llvm/Support/Debug.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Dialect/IREE/Tools/BUILD b/iree/compiler/Dialect/IREE/Tools/BUILD
index f046f8e..0411cff 100644
--- a/iree/compiler/Dialect/IREE/Tools/BUILD
+++ b/iree/compiler/Dialect/IREE/Tools/BUILD
@@ -18,16 +18,9 @@
licenses = ["notice"], # Apache 2.0
)
-cc_library(
- name = "Tools",
+filegroup(
+ name = "GenSrcs",
srcs = [
"StructAttrGen.cpp",
],
- deps = [
- "@llvm-project//llvm:Support",
- "@llvm-project//llvm:TableGen",
- "@llvm-project//mlir:Support",
- "@llvm-project//mlir:TableGen",
- ],
- alwayslink = 1,
)
diff --git a/iree/compiler/Dialect/IREE/Tools/CMakeLists.txt b/iree/compiler/Dialect/IREE/Tools/CMakeLists.txt
index 620c3a1..8b864e5 100644
--- a/iree/compiler/Dialect/IREE/Tools/CMakeLists.txt
+++ b/iree/compiler/Dialect/IREE/Tools/CMakeLists.txt
@@ -13,17 +13,3 @@
# limitations under the License.
iree_add_all_subdirs()
-
-iree_cc_library(
- NAME
- Tools
- SRCS
- "StructAttrGen.cpp"
- DEPS
- LLVMSupport
- LLVMTableGen
- MLIRSupport
- MLIRTableGen
- ALWAYSLINK
- PUBLIC
-)
diff --git a/iree/compiler/Dialect/Shape/Plugins/XLA/XlaHloShapeBuilder.cpp b/iree/compiler/Dialect/Shape/Plugins/XLA/XlaHloShapeBuilder.cpp
index 644f7ce..19d3812 100644
--- a/iree/compiler/Dialect/Shape/Plugins/XLA/XlaHloShapeBuilder.cpp
+++ b/iree/compiler/Dialect/Shape/Plugins/XLA/XlaHloShapeBuilder.cpp
@@ -19,9 +19,9 @@
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/Optional.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Value.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
using llvm::None;
using llvm::Optional;
diff --git a/iree/compiler/Dialect/Shape/Transforms/ConvertHLOToShapeDialectPass.cpp b/iree/compiler/Dialect/Shape/Transforms/ConvertHLOToShapeDialectPass.cpp
index 1b06f2d..d4a8ebe 100644
--- a/iree/compiler/Dialect/Shape/Transforms/ConvertHLOToShapeDialectPass.cpp
+++ b/iree/compiler/Dialect/Shape/Transforms/ConvertHLOToShapeDialectPass.cpp
@@ -18,13 +18,13 @@
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/DialectConversion.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Dialect/Shape/Transforms/Passes.h b/iree/compiler/Dialect/Shape/Transforms/Passes.h
index 21541ac..765a99e 100644
--- a/iree/compiler/Dialect/Shape/Transforms/Passes.h
+++ b/iree/compiler/Dialect/Shape/Transforms/Passes.h
@@ -19,11 +19,9 @@
#include "iree/compiler/Dialect/Shape/IR/ShapeInterface.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
namespace mlir {
-
-class OpPassManager;
-
namespace iree_compiler {
namespace Shape {
diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/BUILD b/iree/compiler/Dialect/VM/Target/Bytecode/BUILD
index 8a1f4db..c525f57 100644
--- a/iree/compiler/Dialect/VM/Target/Bytecode/BUILD
+++ b/iree/compiler/Dialect/VM/Target/Bytecode/BUILD
@@ -25,8 +25,8 @@
"//iree/compiler/Dialect/VM/Analysis",
"//iree/compiler/Dialect/VM/IR",
"//iree/compiler/Dialect/VM/Transforms",
- "//iree/schemas:bytecode_module_def_cc_fbs",
- "@com_github_google_flatbuffers//:flatbuffers",
+ "//iree/compiler/Utils",
+ "//iree/schemas:bytecode_module_def_c_fbs",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp b/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp
index ecfee52f..d63d70f 100644
--- a/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp
+++ b/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp
@@ -16,8 +16,6 @@
#include <algorithm>
-#include "flatbuffers/flatbuffers.h"
-#include "flatbuffers/minireflect.h"
#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
#include "iree/compiler/Dialect/IREE/IR/IREETypes.h"
#include "iree/compiler/Dialect/IREE/Transforms/Passes.h"
@@ -28,7 +26,9 @@
#include "iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.h"
#include "iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.h"
#include "iree/compiler/Dialect/VM/Transforms/Passes.h"
-#include "iree/schemas/bytecode_module_def_generated.h"
+#include "iree/compiler/Utils/FlatbufferUtils.h"
+#include "iree/schemas/bytecode_module_def_builder.h"
+#include "iree/schemas/bytecode_module_def_json_printer.h"
#include "llvm/Support/ErrorHandling.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Diagnostics.h"
@@ -48,10 +48,6 @@
namespace {
-using flatbuffers::FlatBufferBuilder;
-using flatbuffers::Offset;
-using flatbuffers::Vector;
-
struct ModuleCounts {
int importFuncs = 0;
int exportFuncs = 0;
@@ -204,20 +200,6 @@
return success();
}
-// Returns a vector of tables of type T or None if |contents| is empty.
-template <typename T>
-static Optional<Offset<Vector<Offset<T>>>> createOptionalVector(
- const std::vector<Offset<T>> &contents, FlatBufferBuilder &fbb) {
- if (contents.empty()) return llvm::None;
- return fbb.CreateVector(contents);
-}
-template <typename T>
-static Optional<Offset<Vector<T>>> createOptionalVector(
- const std::vector<T> &contents, FlatBufferBuilder &fbb) {
- if (contents.empty()) return llvm::None;
- return fbb.CreateVector(contents);
-}
-
// Encodes a type (or a tuple of nested types) to a calling convention string.
//
// Examples:
@@ -331,89 +313,103 @@
return std::string(s.data(), s.size());
}
-// Populates common fields for FunctionSignatureDefs of all function types.
-static void populateFunctionSignatureDef(FunctionType functionType,
- llvm::DenseMap<Type, int> &typeTable,
- iree::vm::FunctionSignatureDefT &fsd) {
- for (auto type : functionType.getInputs()) {
- if (auto refPtrType = type.dyn_cast<IREE::VM::RefType>()) {
- type = refPtrType.getObjectType();
- }
- fsd.argument_types.push_back(typeTable.lookup(type));
+// Creates a FunctionSignatureDef based on the given function metadata.
+// Some fields are not used on all signature defs and added only when present on
+// the argument objects/attrs.
+static iree_vm_FunctionSignatureDef_ref_t createFunctionSignatureDef(
+ FunctionType functionType, llvm::DenseMap<Type, int> &typeTable,
+ StringRef callingConvention,
+ iree_vm_ReflectionAttrDef_vec_ref_t reflectionAttrsRef,
+ FlatbufferBuilder &fbb) {
+ auto resultTypesRef = fbb.createInt32Vec(
+ llvm::map_range(functionType.getResults(), [&](Type type) {
+ if (auto refPtrType = type.dyn_cast<IREE::VM::RefType>()) {
+ type = refPtrType.getObjectType();
+ }
+ return typeTable.lookup(type);
+ }));
+ auto argumentTypesRef = fbb.createInt32Vec(
+ llvm::map_range(functionType.getInputs(), [&](Type type) {
+ if (auto refPtrType = type.dyn_cast<IREE::VM::RefType>()) {
+ type = refPtrType.getObjectType();
+ }
+ return typeTable.lookup(type);
+ }));
+
+ auto callingConventionRef = fbb.createString(callingConvention);
+
+ // If the signature would be empty then let's avoid writing the empty table.
+ if (!argumentTypesRef && !resultTypesRef && !callingConventionRef &&
+ !reflectionAttrsRef) {
+ return 0;
}
- for (auto type : functionType.getResults()) {
- if (auto refPtrType = type.dyn_cast<IREE::VM::RefType>()) {
- type = refPtrType.getObjectType();
- }
- fsd.result_types.push_back(typeTable.lookup(type));
- }
+
+ iree_vm_FunctionSignatureDef_start(fbb);
+ iree_vm_FunctionSignatureDef_argument_types_add(fbb, argumentTypesRef);
+ iree_vm_FunctionSignatureDef_result_types_add(fbb, resultTypesRef);
+ iree_vm_FunctionSignatureDef_calling_convention_add(fbb,
+ callingConventionRef);
+ iree_vm_FunctionSignatureDef_reflection_attrs_add(fbb, reflectionAttrsRef);
+ return iree_vm_FunctionSignatureDef_end(fbb);
}
// Returns a serialized function signature.
-static Offset<iree::vm::FunctionSignatureDef> makeImportFunctionSignatureDef(
+static iree_vm_FunctionSignatureDef_ref_t makeImportFunctionSignatureDef(
IREE::VM::ImportOp importOp, llvm::DenseMap<Type, int> &typeTable,
- FlatBufferBuilder &fbb) {
- // Common attributes.
- iree::vm::FunctionSignatureDefT fsd;
- populateFunctionSignatureDef(importOp.getType(), typeTable, fsd);
-
+ FlatbufferBuilder &fbb) {
// Generate the signature calling convention string based on types.
auto cconv = makeImportCallingConventionString(importOp);
if (!cconv.hasValue()) return {};
- fsd.calling_convention = cconv.getValue();
-
- return iree::vm::FunctionSignatureDef::Pack(fbb, &fsd);
+ return createFunctionSignatureDef(importOp.getType(), typeTable,
+ cconv.getValue(), /*reflectionAttrsRef=*/0,
+ fbb);
}
// Returns a serialized function signature.
-static Offset<iree::vm::FunctionSignatureDef> makeExportFunctionSignatureDef(
+static iree_vm_FunctionSignatureDef_ref_t makeExportFunctionSignatureDef(
IREE::VM::ExportOp exportOp, IREE::VM::FuncOp funcOp,
- llvm::DenseMap<Type, int> &typeTable, FlatBufferBuilder &fbb) {
- // Common attributes.
- iree::vm::FunctionSignatureDefT fsd;
- populateFunctionSignatureDef(funcOp.getType(), typeTable, fsd);
-
+ llvm::DenseMap<Type, int> &typeTable, FlatbufferBuilder &fbb) {
// Generate the signature calling convention string based on types.
auto cconv = makeCallingConventionString(funcOp);
if (!cconv.hasValue()) return {};
- fsd.calling_convention = cconv.getValue();
-
- return iree::vm::FunctionSignatureDef::Pack(fbb, &fsd);
+ return createFunctionSignatureDef(funcOp.getType(), typeTable,
+ cconv.getValue(), /*reflectionAttrsRef=*/0,
+ fbb);
}
// Returns a serialized function signature.
-static Offset<iree::vm::FunctionSignatureDef> makeInternalFunctionSignatureDef(
+static iree_vm_FunctionSignatureDef_ref_t makeInternalFunctionSignatureDef(
IREE::VM::FuncOp funcOp, llvm::DenseMap<Type, int> &typeTable,
- FlatBufferBuilder &fbb) {
- // Common attributes.
- iree::vm::FunctionSignatureDefT fsd;
- populateFunctionSignatureDef(funcOp.getType(), typeTable, fsd);
-
+ FlatbufferBuilder &fbb) {
// Generate the signature calling convention string based on types.
// TODO(benvanik): only do this on exports. The runtime currently looks on
// internal functions, though, so we have to have it here.
auto cconv = makeCallingConventionString(funcOp);
if (!cconv.hasValue()) return {};
- fsd.calling_convention = cconv.getValue();
// Reflection attributes.
// TODO(benvanik): move these to exports (or remove entirely).
+ iree_vm_ReflectionAttrDef_vec_ref_t reflectionAttrsRef = 0;
if (auto reflectionAttrs =
funcOp.getAttrOfType<DictionaryAttr>("iree.reflection")) {
- llvm::SmallVector<Offset<iree::vm::ReflectionAttrDef>, 4>
- reflectionAttrItems;
+ SmallVector<iree_vm_ReflectionAttrDef_ref_t, 4> reflectionAttrRefs;
for (auto reflectionAttr : reflectionAttrs) {
auto key = reflectionAttr.first.strref();
auto value = reflectionAttr.second.dyn_cast<StringAttr>();
if (!value || key.empty()) continue;
- auto rattr = std::make_unique<iree::vm::ReflectionAttrDefT>();
- rattr->key = key.str();
- rattr->value = value.getValue().str();
- fsd.reflection_attrs.push_back(std::move(rattr));
+ // NOTE: if we actually want to keep these we should dedupe them (as the
+ // keys and likely several of the values are shared across all functions).
+ auto valueRef = fbb.createString(value.getValue());
+ auto keyRef = fbb.createString(key);
+ reflectionAttrRefs.push_back(
+ iree_vm_ReflectionAttrDef_create(fbb, keyRef, valueRef));
}
+ reflectionAttrsRef = iree_vm_ReflectionAttrDef_vec_create(
+ fbb, reflectionAttrRefs.data(), reflectionAttrRefs.size());
}
- return iree::vm::FunctionSignatureDef::Pack(fbb, &fsd);
+ return createFunctionSignatureDef(funcOp.getType(), typeTable,
+ cconv.getValue(), reflectionAttrsRef, fbb);
}
// Builds a complete BytecodeModuleDef FlatBuffer object in |fbb|.
@@ -426,9 +422,9 @@
// has been packed into the top-level table. This results in a messier function
// here during serialization but a much more trivial (and cache-friendly)
// representation at runtime.
-static Offset<iree::vm::BytecodeModuleDef> buildFlatBufferModule(
- BytecodeTargetOptions targetOptions, IREE::VM::ModuleOp moduleOp,
- FlatBufferBuilder &fbb) {
+static LogicalResult buildFlatBufferModule(BytecodeTargetOptions targetOptions,
+ IREE::VM::ModuleOp moduleOp,
+ FlatbufferBuilder &fbb) {
SymbolTable symbolTable(moduleOp);
auto symbolCounts = computeModuleSymbolCounts(moduleOp);
@@ -456,17 +452,24 @@
// Serialize read-only data first so that it ends up at the end of the file.
// This is where large things like parameters live and we don't want that to
// get paged in until it is needed.
- std::vector<Offset<Vector<uint8_t>>> rodataContentOffsets;
- rodataContentOffsets.reserve(rodataOps.size());
- for (auto rodataOp : rodataOps) {
- auto dataOffset =
+ //
+ // NOTE: flatbuffers are built bottom-up; after each rodata we serialize we
+ // move *backward* in the file and prepend the next, meaning that if we
+ // were to serialize all rodata we'd have it in the opposite order as we do
+ // in the IR. Though this isn't required for correctness, enabling file
+ // layout planning by preserving the order in the IR is useful.
+ SmallVector<flatbuffers_uint8_vec_ref_t, 8> rodataContentRefs;
+ rodataContentRefs.reserve(rodataOps.size());
+ for (auto rodataOp : llvm::reverse(rodataOps)) {
+ auto rodataRef =
serializeConstant(rodataOp.getLoc(), rodataOp.value(), fbb);
- if (dataOffset.IsNull()) {
- rodataOp.emitOpError() << "failed to encode";
- return {};
+ if (!rodataRef) {
+ return rodataOp.emitOpError() << "failed to encode";
}
- rodataContentOffsets.push_back(dataOffset);
+ rodataContentRefs.push_back(rodataRef);
}
+ // List of references needs to be swapped forward (we wrote backward).
+ std::reverse(rodataContentRefs.begin(), rodataContentRefs.end());
// Find all types in the module to build the type table.
// Note that we don't emit it yet as we want to keep it near the top of the
@@ -478,28 +481,32 @@
}
// Serialize function bytecode one at a time and then merge at the end.
- std::vector<std::vector<uint8_t>> bytecodeDataParts;
- std::vector<iree::vm::FunctionDescriptor> functionDescriptors;
- bytecodeDataParts.reserve(internalFuncOps.size());
- functionDescriptors.reserve(internalFuncOps.size());
+ SmallVector<std::vector<uint8_t>, 8> bytecodeDataParts;
+ SmallVector<iree_vm_FunctionDescriptor_t, 8> functionDescriptors;
+ bytecodeDataParts.resize(internalFuncOps.size());
+ functionDescriptors.resize(internalFuncOps.size());
size_t totalBytecodeLength = 0;
- for (auto funcOp : internalFuncOps) {
- auto encodedFunction =
- BytecodeEncoder::encodeFunction(funcOp, typeOrdinalMap, symbolTable);
+ for (auto funcOp : llvm::enumerate(internalFuncOps)) {
+ auto encodedFunction = BytecodeEncoder::encodeFunction(
+ funcOp.value(), typeOrdinalMap, symbolTable);
if (!encodedFunction) {
- funcOp.emitError() << "failed to encode function bytecode";
- return {};
+ return funcOp.value().emitError() << "failed to encode function bytecode";
}
- functionDescriptors.push_back(iree::vm::FunctionDescriptor(
- totalBytecodeLength, encodedFunction->bytecodeData.size(),
- encodedFunction->i32RegisterCount, encodedFunction->refRegisterCount));
+ iree_vm_FunctionDescriptor_assign(
+ &functionDescriptors[funcOp.index()], totalBytecodeLength,
+ encodedFunction->bytecodeData.size(), encodedFunction->i32RegisterCount,
+ encodedFunction->refRegisterCount);
totalBytecodeLength += encodedFunction->bytecodeData.size();
- bytecodeDataParts.push_back(std::move(encodedFunction->bytecodeData));
+ bytecodeDataParts[funcOp.index()] =
+ std::move(encodedFunction->bytecodeData);
}
- // TODO(benvanik): compression? deduping?
- uint8_t *bytecodeDataPtr = nullptr;
- auto bytecodeDataOffset = fbb.CreateUninitializedVector<uint8_t>(
- totalBytecodeLength, &bytecodeDataPtr);
+ flatbuffers_uint8_vec_start(fbb);
+ uint8_t *bytecodeDataPtr =
+ flatbuffers_uint8_vec_extend(fbb, totalBytecodeLength);
+ // NOTE: we need to ensure we clear the output data in case we have gaps for
+ // alignment (where otherwise uninitialized memory might sneak in and be bad
+ // for both security and determinism).
+ memset(bytecodeDataPtr, 0, totalBytecodeLength);
size_t currentBytecodeOffset = 0;
for (const auto &it : llvm::enumerate(internalFuncOps)) {
int ordinal = it.index();
@@ -508,102 +515,109 @@
data.size());
currentBytecodeOffset += data.size();
}
+ auto bytecodeDataRef = flatbuffers_uint8_vec_end(fbb);
+
+ // Encode the function descriptors adjacent to the bytecode data; they are
+ // always accessed together. Descriptor 0 is likely within a few hundred bytes
+ // of the referenced bytecode data offset 0, and from there we are at least
+ // able to hope sequential readahead caching helps; if not, at least we
+ // hopefully don't fault on the first function call every time.
+ auto functionDescriptorsRef = iree_vm_FunctionDescriptor_vec_create(
+ fbb, functionDescriptors.data(), functionDescriptors.size());
// Serialize metadata that should be near the front of the file.
- std::vector<Offset<iree::vm::RodataSegmentDef>> rodataSegmentOffsets;
- rodataSegmentOffsets.reserve(rodataOps.size());
- for (auto rodataContentOffset : rodataContentOffsets) {
- iree::vm::RodataSegmentDefBuilder rsd(fbb);
- rsd.add_data(rodataContentOffset);
- rodataSegmentOffsets.push_back(rsd.Finish());
- }
- std::vector<Offset<iree::vm::RwdataSegmentDef>> rwdataSegmentOffsets;
- std::vector<Offset<iree::vm::TypeDef>> typeOffsets;
- typeOffsets.reserve(typeTable.size());
- for (auto &typeDef : typeTable) {
- auto nameOffset = fbb.CreateString(typeDef.full_name);
- iree::vm::TypeDefBuilder tdb(fbb);
- tdb.add_full_name(nameOffset);
- typeOffsets.push_back(tdb.Finish());
- }
- std::vector<Offset<iree::vm::ImportFunctionDef>> importFuncOffsets;
- importFuncOffsets.reserve(importFuncOps.size());
- for (auto importOp : importFuncOps) {
- auto nameOffset = fbb.CreateString(importOp.getName().str());
- auto signatureOffset =
- makeImportFunctionSignatureDef(importOp, typeOrdinalMap, fbb);
- iree::vm::ImportFunctionDefBuilder ifd(fbb);
- ifd.add_full_name(nameOffset);
- ifd.add_signature(signatureOffset);
- importFuncOffsets.push_back(ifd.Finish());
- }
- std::vector<Offset<iree::vm::ExportFunctionDef>> exportFuncOffsets;
- exportFuncOffsets.reserve(exportFuncOps.size());
- for (auto exportOp : exportFuncOps) {
- auto nameOffset = fbb.CreateString(exportOp.export_name().str());
- auto funcOp = symbolTable.lookup<IREE::VM::FuncOp>(exportOp.function_ref());
- auto signatureOffset =
- makeExportFunctionSignatureDef(exportOp, funcOp, typeOrdinalMap, fbb);
- iree::vm::ExportFunctionDefBuilder efd(fbb);
- efd.add_local_name(nameOffset);
- efd.add_signature(signatureOffset);
- efd.add_internal_ordinal(funcOp.ordinal().getValue().getLimitedValue());
- exportFuncOffsets.push_back(efd.Finish());
- }
- std::vector<Offset<iree::vm::InternalFunctionDef>> internalFuncOffsets;
+ auto rodataSegmentRefs = llvm::to_vector<8>(
+ llvm::map_range(rodataContentRefs, [&](auto rodataContentRef) {
+ iree_vm_RodataSegmentDef_start(fbb);
+ iree_vm_RodataSegmentDef_data_add(fbb, rodataContentRef);
+ return iree_vm_RodataSegmentDef_end(fbb);
+ }));
+ SmallVector<iree_vm_RwdataSegmentDef_ref_t, 8> rwdataSegmentRefs;
+ // NOTE: rwdata is currently unused.
+ auto typeRefs =
+ llvm::to_vector<8>(llvm::map_range(typeTable, [&](auto typeDef) {
+ auto fullNameRef = fbb.createString(typeDef.full_name);
+ iree_vm_TypeDef_start(fbb);
+ iree_vm_TypeDef_full_name_add(fbb, fullNameRef);
+ return iree_vm_TypeDef_end(fbb);
+ }));
+ auto importFuncRefs =
+ llvm::to_vector<8>(llvm::map_range(importFuncOps, [&](auto importOp) {
+ auto fullNameRef = fbb.createString(importOp.getName());
+ auto signatureRef =
+ makeImportFunctionSignatureDef(importOp, typeOrdinalMap, fbb);
+ iree_vm_ImportFunctionDef_start(fbb);
+ iree_vm_ImportFunctionDef_full_name_add(fbb, fullNameRef);
+ iree_vm_ImportFunctionDef_signature_add(fbb, signatureRef);
+ return iree_vm_ImportFunctionDef_end(fbb);
+ }));
+ auto exportFuncRefs =
+ llvm::to_vector<8>(llvm::map_range(exportFuncOps, [&](auto exportOp) {
+ auto localNameRef = fbb.createString(exportOp.export_name());
+ auto funcOp =
+ symbolTable.lookup<IREE::VM::FuncOp>(exportOp.function_ref());
+ auto signatureRef = makeExportFunctionSignatureDef(exportOp, funcOp,
+ typeOrdinalMap, fbb);
+ iree_vm_ExportFunctionDef_start(fbb);
+ iree_vm_ExportFunctionDef_local_name_add(fbb, localNameRef);
+ iree_vm_ExportFunctionDef_signature_add(fbb, signatureRef);
+ iree_vm_ExportFunctionDef_internal_ordinal_add(
+ fbb, funcOp.ordinal().getValue().getLimitedValue());
+ return iree_vm_ExportFunctionDef_end(fbb);
+ }));
+ SmallVector<iree_vm_InternalFunctionDef_ref_t, 8> internalFuncRefs;
if (!targetOptions.stripSymbols) {
- internalFuncOffsets.reserve(internalFuncOps.size());
+ internalFuncRefs.reserve(internalFuncOps.size());
for (auto funcOp : internalFuncOps) {
- auto nameOffset = fbb.CreateString(funcOp.getName().str());
- auto signatureOffset =
+ auto localNameRef = fbb.createString(funcOp.getName());
+ auto signatureRef =
makeInternalFunctionSignatureDef(funcOp, typeOrdinalMap, fbb);
- iree::vm::InternalFunctionDefBuilder ifd(fbb);
- ifd.add_local_name(nameOffset);
- ifd.add_signature(signatureOffset);
- internalFuncOffsets.push_back(ifd.Finish());
+ iree_vm_InternalFunctionDef_start(fbb);
+ iree_vm_InternalFunctionDef_local_name_add(fbb, localNameRef);
+ iree_vm_InternalFunctionDef_signature_add(fbb, signatureRef);
+ internalFuncRefs.push_back(iree_vm_InternalFunctionDef_end(fbb));
}
}
- auto functionDescriptorsOffset =
- fbb.CreateVectorOfStructs(functionDescriptors);
- auto rodataSegmentsOffset = createOptionalVector(rodataSegmentOffsets, fbb);
- auto rwdataSegmentsOffset = createOptionalVector(rwdataSegmentOffsets, fbb);
- auto internalFuncsOffset = fbb.CreateVector(internalFuncOffsets);
- auto exportFuncsOffset = fbb.CreateVector(exportFuncOffsets);
- auto importFuncsOffset = createOptionalVector(importFuncOffsets, fbb);
- auto typesOffset = fbb.CreateVector(typeOffsets);
+ // NOTE: we keep the vectors clustered here so that we can hopefully keep the
+ // pages mapped at runtime; vector dereferences in flatbuffers require
+ // touching these structs to get length/etc and as such we don't want to be
+ // gathering from all over the file (with giant rodata chunks and such
+ // in between) just to perform a bounds check and dereference into another part
+ // of the file.
+ auto rodataSegmentsRef = fbb.createOffsetVecDestructive(rodataSegmentRefs);
+ auto rwdataSegmentsRef = fbb.createOffsetVecDestructive(rwdataSegmentRefs);
+ auto internalFuncsRef = fbb.createOffsetVecDestructive(internalFuncRefs);
+ auto exportFuncsOffset = fbb.createOffsetVecDestructive(exportFuncRefs);
+ auto importFuncsRef = fbb.createOffsetVecDestructive(importFuncRefs);
+ auto typesRef = fbb.createOffsetVecDestructive(typeRefs);
- Optional<Offset<iree::vm::ModuleStateDef>> moduleStateDef;
+ iree_vm_ModuleStateDef_ref_t moduleStateDef = 0;
if (symbolCounts.globalBytes || symbolCounts.globalRefs) {
- iree::vm::ModuleStateDefBuilder msd(fbb);
- msd.add_global_bytes_capacity(symbolCounts.globalBytes);
- msd.add_global_ref_count(symbolCounts.globalRefs);
- moduleStateDef = msd.Finish();
+ iree_vm_ModuleStateDef_start(fbb);
+ iree_vm_ModuleStateDef_global_bytes_capacity_add(fbb,
+ symbolCounts.globalBytes);
+ iree_vm_ModuleStateDef_global_ref_count_add(fbb, symbolCounts.globalRefs);
+ moduleStateDef = iree_vm_ModuleStateDef_end(fbb);
}
- auto nameOffset = fbb.CreateString(
- moduleOp.sym_name().empty() ? "module" : moduleOp.sym_name().str());
+ auto moduleNameRef = fbb.createString(
+ moduleOp.sym_name().empty() ? "module" : moduleOp.sym_name());
- iree::vm::BytecodeModuleDefBuilder bmd(fbb);
- bmd.add_name(nameOffset);
- bmd.add_types(typesOffset);
- if (importFuncsOffset) {
- bmd.add_imported_functions(importFuncsOffset.getValue());
- }
- bmd.add_exported_functions(exportFuncsOffset);
- bmd.add_internal_functions(internalFuncsOffset);
- if (moduleStateDef) {
- bmd.add_module_state(moduleStateDef.getValue());
- }
- if (rwdataSegmentsOffset) {
- bmd.add_rwdata_segments(rwdataSegmentsOffset.getValue());
- }
- if (rodataSegmentsOffset) {
- bmd.add_rodata_segments(rodataSegmentsOffset.getValue());
- }
- bmd.add_function_descriptors(functionDescriptorsOffset);
- bmd.add_bytecode_data(bytecodeDataOffset);
- return bmd.Finish();
+ iree_vm_BytecodeModuleDef_start_as_root(fbb);
+ iree_vm_BytecodeModuleDef_name_add(fbb, moduleNameRef);
+ iree_vm_BytecodeModuleDef_types_add(fbb, typesRef);
+ iree_vm_BytecodeModuleDef_imported_functions_add(fbb, importFuncsRef);
+ iree_vm_BytecodeModuleDef_exported_functions_add(fbb, exportFuncsOffset);
+ iree_vm_BytecodeModuleDef_internal_functions_add(fbb, internalFuncsRef);
+ iree_vm_BytecodeModuleDef_module_state_add(fbb, moduleStateDef);
+ iree_vm_BytecodeModuleDef_rodata_segments_add(fbb, rodataSegmentsRef);
+ iree_vm_BytecodeModuleDef_rwdata_segments_add(fbb, rwdataSegmentsRef);
+ iree_vm_BytecodeModuleDef_function_descriptors_add(fbb,
+ functionDescriptorsRef);
+ iree_vm_BytecodeModuleDef_bytecode_data_add(fbb, bytecodeDataRef);
+ iree_vm_BytecodeModuleDef_end_as_root(fbb);
+ return success();
}
LogicalResult translateModuleToBytecode(IREE::VM::ModuleOp moduleOp,
@@ -639,28 +653,30 @@
// the module header in memory. This ensures that when we map the file only
// the first few pages need to be accessed to get the metadata and the rest
// can be large bulk data.
- FlatBufferBuilder fbb;
- auto moduleDef = buildFlatBufferModule(targetOptions, moduleOp, fbb);
- if (moduleDef.IsNull()) {
+ FlatbufferBuilder fbb;
+ if (failed(buildFlatBufferModule(targetOptions, moduleOp, fbb))) {
return moduleOp.emitError()
<< "failed to build FlatBuffer BytecodeModuleDef";
}
- iree::vm::FinishBytecodeModuleDefBuffer(fbb, moduleDef);
- const uint8_t *flatbufferBytes = fbb.GetBufferPointer();
- size_t flatbufferByteSize = fbb.GetSize();
-
switch (targetOptions.outputFormat) {
case BytecodeOutputFormat::kFlatBufferBinary:
- output.write(reinterpret_cast<const char *>(flatbufferBytes),
- flatbufferByteSize);
+ if (failed(fbb.copyToStream(output))) {
+ return moduleOp.emitError()
+ << "failed to copy flatbuffer emitter contents to output stream "
+ "- possibly out of memory";
+ }
break;
case BytecodeOutputFormat::kFlatBufferText: {
- flatbuffers::ToStringVisitor toStringVisitor("\n", false, " ", false);
- flatbuffers::IterateFlatBuffer(flatbufferBytes,
- iree::vm::BytecodeModuleDefTypeTable(),
- &toStringVisitor);
- output << toStringVisitor.s << "\n";
+ if (failed(fbb.printJsonToStream(/*pretty=*/true,
+ /*includeDefaults=*/false,
+ bytecode_module_def_print_json,
+ output))) {
+ return moduleOp.emitError()
+ << "failed to print flatbuffer emitter contents to output "
+ "stream - possibly out of memory, possibly unprintable "
+ "structure";
+ }
break;
}
default:
diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/CMakeLists.txt b/iree/compiler/Dialect/VM/Target/Bytecode/CMakeLists.txt
index 73aeb14..24a13c5 100644
--- a/iree/compiler/Dialect/VM/Target/Bytecode/CMakeLists.txt
+++ b/iree/compiler/Dialect/VM/Target/Bytecode/CMakeLists.txt
@@ -35,12 +35,12 @@
MLIRSupport
MLIRTransforms
MLIRTranslation
- flatbuffers
iree::compiler::Dialect::IREE::IR
iree::compiler::Dialect::IREE::Transforms
iree::compiler::Dialect::VM::Analysis
iree::compiler::Dialect::VM::IR
iree::compiler::Dialect::VM::Transforms
- iree::schemas::bytecode_module_def_cc_fbs
+ iree::compiler::Utils
+ iree::schemas::bytecode_module_def_c_fbs
PUBLIC
)
diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.cpp b/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.cpp
index 180535a..64e1079 100644
--- a/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.cpp
+++ b/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.cpp
@@ -14,7 +14,6 @@
#include "iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.h"
-#include "flatbuffers/flatbuffers.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/StandardTypes.h"
@@ -24,24 +23,16 @@
namespace IREE {
namespace VM {
-namespace {
-
-using flatbuffers::FlatBufferBuilder;
-using flatbuffers::Offset;
-using flatbuffers::Vector;
-
-} // namespace
-
// TODO(benvanik): switch to LLVM's BinaryStreamWriter to handle endianness.
-static Offset<Vector<uint8_t>> serializeConstantI8Array(
- DenseIntElementsAttr attr, FlatBufferBuilder &fbb) {
+static flatbuffers_uint8_vec_ref_t serializeConstantI8Array(
+ DenseIntElementsAttr attr, FlatbufferBuilder &fbb) {
// vm.rodata and other very large constants end up as this; since i8 is i8
// everywhere (endianness doesn't matter when you have one byte :) we can
// directly access the data and memcpy.
- uint8_t *bytePtr = nullptr;
- auto byteVector =
- fbb.CreateUninitializedVector(attr.getNumElements() * 1, &bytePtr);
+ flatbuffers_uint8_vec_start(fbb);
+ uint8_t *bytePtr =
+ flatbuffers_uint8_vec_extend(fbb, attr.getNumElements() * sizeof(int8_t));
if (attr.isSplat()) {
// NOTE: this is a slow path and we should have eliminated it earlier on
// during constant op conversion.
@@ -52,72 +43,72 @@
auto rawData = attr.getRawData();
std::memcpy(bytePtr, rawData.data(), rawData.size());
}
- return byteVector;
+ return flatbuffers_uint8_vec_end(fbb);
}
-static Offset<Vector<uint8_t>> serializeConstantI16Array(
- DenseIntElementsAttr attr, FlatBufferBuilder &fbb) {
- uint8_t *bytePtr = nullptr;
- auto byteVector =
- fbb.CreateUninitializedVector(attr.getNumElements() * 2, &bytePtr);
+static flatbuffers_uint8_vec_ref_t serializeConstantI16Array(
+ DenseIntElementsAttr attr, FlatbufferBuilder &fbb) {
+ flatbuffers_uint8_vec_start(fbb);
+ uint8_t *bytePtr = flatbuffers_uint8_vec_extend(
+ fbb, attr.getNumElements() * sizeof(int16_t));
uint16_t *nativePtr = reinterpret_cast<uint16_t *>(bytePtr);
for (const APInt &value : attr.getIntValues()) {
*(nativePtr++) = value.extractBitsAsZExtValue(16, 0) & UINT16_MAX;
}
- return byteVector;
+ return flatbuffers_uint8_vec_end(fbb);
}
-static Offset<Vector<uint8_t>> serializeConstantI32Array(
- DenseIntElementsAttr attr, FlatBufferBuilder &fbb) {
- uint8_t *bytePtr = nullptr;
- auto byteVector =
- fbb.CreateUninitializedVector(attr.getNumElements() * 4, &bytePtr);
+static flatbuffers_uint8_vec_ref_t serializeConstantI32Array(
+ DenseIntElementsAttr attr, FlatbufferBuilder &fbb) {
+ flatbuffers_uint8_vec_start(fbb);
+ uint8_t *bytePtr = flatbuffers_uint8_vec_extend(
+ fbb, attr.getNumElements() * sizeof(int32_t));
uint32_t *nativePtr = reinterpret_cast<uint32_t *>(bytePtr);
for (const APInt &value : attr.getIntValues()) {
*(nativePtr++) = value.extractBitsAsZExtValue(32, 0) & UINT32_MAX;
}
- return byteVector;
+ return flatbuffers_uint8_vec_end(fbb);
}
-static Offset<Vector<uint8_t>> serializeConstantI64Array(
- DenseIntElementsAttr attr, FlatBufferBuilder &fbb) {
- uint8_t *bytePtr = nullptr;
- auto byteVector =
- fbb.CreateUninitializedVector(attr.getNumElements() * 8, &bytePtr);
+static flatbuffers_uint8_vec_ref_t serializeConstantI64Array(
+ DenseIntElementsAttr attr, FlatbufferBuilder &fbb) {
+ flatbuffers_uint8_vec_start(fbb);
+ uint8_t *bytePtr = flatbuffers_uint8_vec_extend(
+ fbb, attr.getNumElements() * sizeof(int64_t));
uint64_t *nativePtr = reinterpret_cast<uint64_t *>(bytePtr);
for (const APInt &value : attr.getIntValues()) {
*(nativePtr++) = value.extractBitsAsZExtValue(64, 0) & UINT64_MAX;
}
- return byteVector;
+ return flatbuffers_uint8_vec_end(fbb);
}
-static Offset<Vector<uint8_t>> serializeConstantF32Array(
- DenseFPElementsAttr attr, FlatBufferBuilder &fbb) {
- uint8_t *bytePtr = nullptr;
- auto byteVector =
- fbb.CreateUninitializedVector(attr.getNumElements() * 4, &bytePtr);
+static flatbuffers_uint8_vec_ref_t serializeConstantF32Array(
+ DenseFPElementsAttr attr, FlatbufferBuilder &fbb) {
+ flatbuffers_uint8_vec_start(fbb);
+ uint8_t *bytePtr =
+ flatbuffers_uint8_vec_extend(fbb, attr.getNumElements() * sizeof(float));
float *nativePtr = reinterpret_cast<float *>(bytePtr);
for (const APFloat &value : attr.getFloatValues()) {
*(nativePtr++) = value.convertToFloat();
}
- return byteVector;
+ return flatbuffers_uint8_vec_end(fbb);
}
-static Offset<Vector<uint8_t>> serializeConstantF64Array(
- DenseFPElementsAttr attr, FlatBufferBuilder &fbb) {
- uint8_t *bytePtr = nullptr;
- auto byteVector =
- fbb.CreateUninitializedVector(attr.getNumElements() * 8, &bytePtr);
+static flatbuffers_uint8_vec_ref_t serializeConstantF64Array(
+ DenseFPElementsAttr attr, FlatbufferBuilder &fbb) {
+ flatbuffers_uint8_vec_start(fbb);
+ uint8_t *bytePtr =
+ flatbuffers_uint8_vec_extend(fbb, attr.getNumElements() * sizeof(double));
double *nativePtr = reinterpret_cast<double *>(bytePtr);
for (const APFloat &value : attr.getFloatValues()) {
*(nativePtr++) = value.convertToDouble();
}
- return byteVector;
+ return flatbuffers_uint8_vec_end(fbb);
}
-Offset<Vector<uint8_t>> serializeConstant(Location loc,
- ElementsAttr elementsAttr,
- FlatBufferBuilder &fbb) {
+flatbuffers_uint8_vec_ref_t serializeConstant(Location loc,
+ ElementsAttr elementsAttr,
+ FlatbufferBuilder &fbb) {
if (auto attr = elementsAttr.dyn_cast<DenseIntElementsAttr>()) {
switch (attr.getType().getElementTypeBitWidth()) {
case 8:
diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.h b/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.h
index 04d62a3..56471a6 100644
--- a/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.h
+++ b/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.h
@@ -15,7 +15,8 @@
#ifndef IREE_COMPILER_DIALECT_VM_TARGET_BYTECODE_CONSTANTENCODER_H_
#define IREE_COMPILER_DIALECT_VM_TARGET_BYTECODE_CONSTANTENCODER_H_
-#include "flatbuffers/flatbuffers.h"
+#include "iree/compiler/Utils/FlatbufferUtils.h"
+#include "iree/schemas/bytecode_module_def_builder.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Location.h"
@@ -25,9 +26,9 @@
namespace VM {
// Serializes a constant attribute to the FlatBuffer as a binary blob.
-flatbuffers::Offset<flatbuffers::Vector<uint8_t>> serializeConstant(
- Location loc, ElementsAttr elementsAttr,
- flatbuffers::FlatBufferBuilder &fbb);
+flatbuffers_uint8_vec_ref_t serializeConstant(Location loc,
+ ElementsAttr elementsAttr,
+ FlatbufferBuilder &fbb);
} // namespace VM
} // namespace IREE
diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/test/constant_encoding.mlir b/iree/compiler/Dialect/VM/Target/Bytecode/test/constant_encoding.mlir
index 4389809..fc4d1a2 100644
--- a/iree/compiler/Dialect/VM/Target/Bytecode/test/constant_encoding.mlir
+++ b/iree/compiler/Dialect/VM/Target/Bytecode/test/constant_encoding.mlir
@@ -1,20 +1,50 @@
// RUN: iree-translate -split-input-file -iree-vm-ir-to-bytecode-module -iree-vm-bytecode-module-output-format=flatbuffer-text %s | IreeFileCheck %s
-// CHECK: name: "constants"
+// CHECK: "name": "constants"
vm.module @constants {
vm.export @func
vm.func @func() {
vm.return
}
- // CHECK: rodata_segments: [ {
+ // CHECK: "rodata_segments": [{
- // CHECK: data: [ 1, 2, 3 ]
+ // CHECK: "data": [
+ // CHECK-NEXT: 1,
+ // CHECK-NEXT: 2,
+ // CHECK-NEXT: 3
+ // CHECK-NEXT: ]
vm.rodata @dense_i8s dense<[1, 2, 3]> : tensor<3xi8>
- // CHECK: data: [ 0, 0, 128, 63, 0, 0, 0, 64, 0, 0, 64, 64 ]
+ // CHECK: "data": [
+ // CHECK-NEXT: 0,
+ // CHECK-NEXT: 0,
+ // CHECK-NEXT: 128,
+ // CHECK-NEXT: 63,
+ // CHECK-NEXT: 0,
+ // CHECK-NEXT: 0,
+ // CHECK-NEXT: 0,
+ // CHECK-NEXT: 64,
+ // CHECK-NEXT: 0,
+ // CHECK-NEXT: 0,
+ // CHECK-NEXT: 64,
+ // CHECK-NEXT: 64
+ // CHECK-NEXT: ]
vm.rodata @dense_float32s dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<3xf32>
- // CHECK: data: [ 0, 0, 128, 63, 0, 0, 128, 63, 0, 0, 128, 63 ]
+ // CHECK: "data": [
+ // CHECK-NEXT: 0,
+ // CHECK-NEXT: 0,
+ // CHECK-NEXT: 128,
+ // CHECK-NEXT: 63,
+ // CHECK-NEXT: 0,
+ // CHECK-NEXT: 0,
+ // CHECK-NEXT: 128,
+ // CHECK-NEXT: 63,
+ // CHECK-NEXT: 0,
+ // CHECK-NEXT: 0,
+ // CHECK-NEXT: 128,
+ // CHECK-NEXT: 63
+ // CHECK-NEXT: ]
vm.rodata @splat_float32s dense<1.000000e+00> : tensor<3xf32>
}
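
The expected "data" bytes in the checks above are the raw little-endian IEEE-754 encodings of the f32 values. The following is a minimal sketch of that encoding, not the code from this change: `encodeF32LittleEndian` is a hypothetical helper, and it writes bytes in explicit little-endian order, whereas the serializers in ConstantEncoder.cpp store values in host byte order through a `reinterpret_cast` (see the endianness TODO there). It assumes `float` is a 32-bit IEEE-754 type.

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical helper (not part of this change): appends each float as four
// bytes in explicit little-endian order, independent of host byte order.
// Assumes float is a 32-bit IEEE-754 value.
std::vector<uint8_t> encodeF32LittleEndian(const std::vector<float> &values) {
  static_assert(sizeof(float) == sizeof(uint32_t), "expects 32-bit float");
  std::vector<uint8_t> bytes;
  bytes.reserve(values.size() * sizeof(float));
  for (float value : values) {
    uint32_t bits = 0;
    // Bit-exact copy of the float representation without aliasing issues.
    std::memcpy(&bits, &value, sizeof(bits));
    // Emit the least significant byte first.
    for (int shift = 0; shift < 32; shift += 8) {
      bytes.push_back(static_cast<uint8_t>((bits >> shift) & 0xFFu));
    }
  }
  return bytes;
}
```

For the inputs 1.0, 2.0, 3.0 this produces 0, 0, 128, 63, 0, 0, 0, 64, 0, 0, 64, 64, which matches the `@dense_float32s` check lines; three copies of 1.0 give the `@splat_float32s` sequence.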
diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/test/module_encoding_smoke.mlir b/iree/compiler/Dialect/VM/Target/Bytecode/test/module_encoding_smoke.mlir
index 18caab9..6c5e100 100644
--- a/iree/compiler/Dialect/VM/Target/Bytecode/test/module_encoding_smoke.mlir
+++ b/iree/compiler/Dialect/VM/Target/Bytecode/test/module_encoding_smoke.mlir
@@ -1,24 +1,32 @@
// RUN: iree-translate -split-input-file -iree-vm-ir-to-bytecode-module -iree-vm-bytecode-module-output-format=flatbuffer-text %s | IreeFileCheck %s
-// CHECK: name: "simple_module"
+// CHECK: "name": "simple_module"
vm.module @simple_module {
- // CHECK: types: [ {
- // CHECK: full_name: "i32"
+ // CHECK: "types": [{
+ // CHECK: "full_name": "i32"
- // CHECK: exported_functions:
- // CHECK: local_name: "func"
+ // CHECK: "exported_functions":
+ // CHECK: "local_name": "func"
vm.export @func
- // CHECK: internal_functions:
- // CHECK: local_name: "func"
+ // CHECK: "internal_functions":
+ // CHECK: "local_name": "func"
vm.func @func(%arg0 : i32) -> i32 {
vm.return %arg0 : i32
}
- // CHECK: function_descriptors:
- // CHECK-NEXT: bytecode_offset: 0
- // CHECK-NEXT: bytecode_length: 5
- // CHECK-NEXT: i32_register_count: 1
- // CHECK-NEXT: ref_register_count: 0
- // CHECK: bytecode_data: [ 84, 1, 0, 0, 0 ]
+ // CHECK: "function_descriptors":
+ // CHECK-NEXT: {
+ // CHECK-NEXT: "bytecode_offset": 0
+ // CHECK-NEXT: "bytecode_length": 5
+ // CHECK-NEXT: "i32_register_count": 1
+ // CHECK-NEXT: "ref_register_count": 0
+ // CHECK-NEXT: }
+ // CHECK: "bytecode_data": [
+ // CHECK-NEXT: 84,
+ // CHECK-NEXT: 1,
+ // CHECK-NEXT: 0,
+ // CHECK-NEXT: 0,
+ // CHECK-NEXT: 0
+ // CHECK-NEXT: ]
}
diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/test/reflection_attrs.mlir b/iree/compiler/Dialect/VM/Target/Bytecode/test/reflection_attrs.mlir
index 30c0d8b..0c69043 100644
--- a/iree/compiler/Dialect/VM/Target/Bytecode/test/reflection_attrs.mlir
+++ b/iree/compiler/Dialect/VM/Target/Bytecode/test/reflection_attrs.mlir
@@ -3,12 +3,12 @@
// CHECK-LABEL: simple_module
vm.module @simple_module {
vm.export @func
- // CHECK: internal_functions:
- // CHECK: reflection_attrs:
- // CHECK: key: "f"
- // CHECK: value: "FOOBAR"
- // CHECK: key: "fv"
- // CHECK: value: "INFINITY"
+ // CHECK: "internal_functions":
+ // CHECK: "reflection_attrs":
+ // CHECK: "key": "f"
+ // CHECK: "value": "FOOBAR"
+ // CHECK: "key": "fv"
+ // CHECK: "value": "INFINITY"
vm.func @func(%arg0 : i32) -> i32
attributes { iree.reflection = { f = "FOOBAR", fv = "INFINITY" } }
{
diff --git a/iree/compiler/Dialect/VM/Tools/BUILD b/iree/compiler/Dialect/VM/Tools/BUILD
index b31b9f9..1b0c741 100644
--- a/iree/compiler/Dialect/VM/Tools/BUILD
+++ b/iree/compiler/Dialect/VM/Tools/BUILD
@@ -18,17 +18,10 @@
licenses = ["notice"], # Apache 2.0
)
-cc_library(
- name = "Tools",
+filegroup(
+ name = "GenSrcs",
srcs = [
"VMOpEncoderGen.cpp",
"VMOpTableGen.cpp",
],
- deps = [
- "@llvm-project//llvm:Support",
- "@llvm-project//llvm:TableGen",
- "@llvm-project//mlir:Support",
- "@llvm-project//mlir:TableGen",
- ],
- alwayslink = 1,
)
diff --git a/iree/compiler/Dialect/VM/Tools/CMakeLists.txt b/iree/compiler/Dialect/VM/Tools/CMakeLists.txt
index d13afd5..8b864e5 100644
--- a/iree/compiler/Dialect/VM/Tools/CMakeLists.txt
+++ b/iree/compiler/Dialect/VM/Tools/CMakeLists.txt
@@ -13,18 +13,3 @@
# limitations under the License.
iree_add_all_subdirs()
-
-iree_cc_library(
- NAME
- Tools
- SRCS
- "VMOpEncoderGen.cpp"
- "VMOpTableGen.cpp"
- DEPS
- LLVMSupport
- LLVMTableGen
- MLIRSupport
- MLIRTableGen
- ALWAYSLINK
- PUBLIC
-)
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertConvOps.cpp b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertConvOps.cpp
index f5bc7c1..3640212 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertConvOps.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertConvOps.cpp
@@ -21,6 +21,7 @@
#include "iree/compiler/Dialect/VMLA/IR/VMLATypes.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/BitVector.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
@@ -31,7 +32,6 @@
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Transforms/DialectConversion.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
index 93b73f6..c3d2a91 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
@@ -24,6 +24,8 @@
#include "iree/compiler/Dialect/VMLA/IR/VMLAOps.h"
#include "iree/compiler/Dialect/VMLA/IR/VMLATypes.h"
#include "llvm/ADT/STLExtras.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
@@ -34,8 +36,6 @@
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Transforms/DialectConversion.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertReductionOps.cpp b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertReductionOps.cpp
index 4e7eb20..60fb6ff 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertReductionOps.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertReductionOps.cpp
@@ -19,6 +19,7 @@
#include "iree/compiler/Dialect/VMLA/IR/VMLADialect.h"
#include "iree/compiler/Dialect/VMLA/IR/VMLAOps.h"
#include "iree/compiler/Dialect/VMLA/IR/VMLATypes.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
@@ -29,7 +30,6 @@
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Transforms/DialectConversion.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Dialect/VMLA/Transforms/Conversion.cpp b/iree/compiler/Dialect/VMLA/Transforms/Conversion.cpp
index 5c5ce94..45d36f7 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/Conversion.cpp
+++ b/iree/compiler/Dialect/VMLA/Transforms/Conversion.cpp
@@ -25,13 +25,13 @@
#include "iree/compiler/Dialect/VMLA/IR/VMLATypes.h"
#include "iree/compiler/Dialect/VMLA/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/DialectConversion.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Dialect/VMLA/Transforms/Passes.cpp b/iree/compiler/Dialect/VMLA/Transforms/Passes.cpp
index d29b463..8061fca 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/VMLA/Transforms/Passes.cpp
@@ -18,9 +18,9 @@
#include "iree/compiler/Dialect/IREE/Transforms/Passes.h"
#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
+#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/Passes.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp b/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp
index fef7f26..3097f96 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp
+++ b/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp
@@ -21,6 +21,8 @@
#include "iree/compiler/Dialect/VMLA/IR/VMLATypes.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/BitVector.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
@@ -34,8 +36,6 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Dialect/VMLA/Transforms/UnrollReductions.cpp b/iree/compiler/Dialect/VMLA/Transforms/UnrollReductions.cpp
index 289ce23..2305342 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/UnrollReductions.cpp
+++ b/iree/compiler/Dialect/VMLA/Transforms/UnrollReductions.cpp
@@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Translation/test/smoketest.mlir b/iree/compiler/Translation/test/smoketest.mlir
index 69a8d97..3b19456 100644
--- a/iree/compiler/Translation/test/smoketest.mlir
+++ b/iree/compiler/Translation/test/smoketest.mlir
@@ -1,30 +1,36 @@
// RUN: iree-translate -split-input-file -iree-hal-target-backends=vmla -iree-mlir-to-vm-bytecode-module -iree-vm-bytecode-module-output-format=flatbuffer-text %s | IreeFileCheck %s
-// CHECK-LABEL: name: "simple_module"
+// CHECK-LABEL: "name": "simple_module"
module @simple_module {
-// CHECK: exported_functions:
-// CHECK: local_name: "func"
+// CHECK: "exported_functions":
+// CHECK: "local_name": "func"
-// CHECK: internal_functions:
-// CHECK: local_name: "func"
+// CHECK: "internal_functions":
+// CHECK: "local_name": "func"
func @func(%arg0 : i32) -> i32 attributes { iree.module.export } {
return %arg0 : i32
}
-// CHECK: function_descriptors:
-// CHECK-NEXT: bytecode_offset: 0
-// CHECK-NEXT: bytecode_length: 5
-// CHECK-NEXT: i32_register_count: 1
-// CHECK-NEXT: ref_register_count: 0
-// CHECK: bytecode_data: [ 84, 1, 0, 0,
+// CHECK: "function_descriptors":
+// CHECK-NEXT: {
+// CHECK-NEXT: "bytecode_offset": 0
+// CHECK-NEXT: "bytecode_length": 5
+// CHECK-NEXT: "i32_register_count": 1
+// CHECK-NEXT: "ref_register_count": 0
+// CHECK-NEXT: }
+// CHECK: "bytecode_data": [
+// CHECK-NEXT: 84,
+// CHECK-NEXT: 1,
+// CHECK-NEXT: 0,
+// CHECK-NEXT: 0,
}
// -----
-// CHECK-LABEL: name: "do_not_optimize_module"
+// CHECK-LABEL: "name": "do_not_optimize_module"
module @do_not_optimize_module {
-// CHECK: exported_functions:
-// CHECK: local_name: "add"
+// CHECK: "exported_functions":
+// CHECK: "local_name": "add"
func @add() -> i32 attributes { iree.module.export } {
%c1 = constant 1 : i32
%unf_c1 = iree.do_not_optimize(%c1) : i32
@@ -36,12 +42,12 @@
// -----
-// CHECK-LABEL: name: "hal_usage"
+// CHECK-LABEL: "name": "hal_usage"
module @hal_usage {
-// CHECK: imported_functions:
-// CHECK: full_name: "hal.command_buffer.dispatch"
-// CHECK: exported_functions:
-// CHECK: local_name: "hloElementwiseOps"
+// CHECK: "imported_functions":
+// CHECK: "full_name": "hal.command_buffer.dispatch"
+// CHECK: "exported_functions":
+// CHECK: "local_name": "hloElementwiseOps"
func @hloElementwiseOps(%arg0 : tensor<4xf32>) -> tensor<4xf32> attributes {iree.module.export} {
%0 = mhlo.add %arg0, %arg0 : tensor<4xf32>
%1 = mhlo.subtract %0, %arg0 : tensor<4xf32>
diff --git a/iree/compiler/Utils/BUILD b/iree/compiler/Utils/BUILD
index ee02d94..3f6aa07 100644
--- a/iree/compiler/Utils/BUILD
+++ b/iree/compiler/Utils/BUILD
@@ -23,13 +23,16 @@
cc_library(
name = "Utils",
srcs = [
+ "FlatbufferUtils.cpp",
"GraphUtils.cpp",
],
hdrs = [
+ "FlatbufferUtils.h",
"GraphUtils.h",
"PatternUtils.h",
],
deps = [
+ "//iree/base:flatcc",
"//iree/compiler/Dialect/IREE/IR",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
diff --git a/iree/compiler/Utils/CMakeLists.txt b/iree/compiler/Utils/CMakeLists.txt
index 6e7286e..6db54f5 100644
--- a/iree/compiler/Utils/CMakeLists.txt
+++ b/iree/compiler/Utils/CMakeLists.txt
@@ -18,9 +18,11 @@
NAME
Utils
HDRS
+ "FlatbufferUtils.h"
"GraphUtils.h"
"PatternUtils.h"
SRCS
+ "FlatbufferUtils.cpp"
"GraphUtils.cpp"
DEPS
LLVMSupport
@@ -30,6 +32,7 @@
MLIRSupport
MLIRTransformUtils
MLIRTransforms
+ iree::base::flatcc
iree::compiler::Dialect::IREE::IR
tensorflow::mlir_hlo
PUBLIC
diff --git a/iree/compiler/Utils/FlatbufferUtils.cpp b/iree/compiler/Utils/FlatbufferUtils.cpp
new file mode 100644
index 0000000..487a053
--- /dev/null
+++ b/iree/compiler/Utils/FlatbufferUtils.cpp
@@ -0,0 +1,127 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Utils/FlatbufferUtils.h"
+
+#include <vector>
+
+#include "mlir/IR/StandardTypes.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Combines all pages of the flatbuffer builder into a single contiguous byte
+// buffer and returns the result.
+//
+// NOTE: this is an alloc/copy. We need to have a single contiguous buffer to
+// pass into the elements factory function and the data we have in the
+// builder is paged. If we end up with a custom attribute type for this that
+// does not support storage uniquing then we can directly allocate and copy
+// the pages into the buffer without the extra copy.
+static SmallVector<uint8_t, 32> cloneBufferIntoContiguousBytes(
+ FlatbufferBuilder &fbb) {
+ size_t packedSize = flatcc_builder_get_buffer_size(fbb);
+ SmallVector<uint8_t, 32> packedData(packedSize);
+ void *result =
+ flatcc_builder_copy_buffer(fbb, packedData.data(), packedData.size());
+ assert(result && "flatcc_emitter_t impl failed (non-default?)");
+ return packedData;
+}
+
+FlatbufferBuilder::FlatbufferBuilder() { flatcc_builder_init(&builder); }
+
+FlatbufferBuilder::~FlatbufferBuilder() { flatcc_builder_clear(&builder); }
+
+flatbuffers_uint8_vec_ref_t FlatbufferBuilder::streamUint8Vec(
+ std::function<bool(raw_ostream &stream)> fn) {
+ flatbuffers_uint8_vec_start(*this);
+ raw_flatbuffer_uint8_vec_ostream stream(*this);
+ if (!fn(stream)) {
+ return 0;
+ }
+ stream.flush();
+ return flatbuffers_uint8_vec_end(*this);
+}
+
+DenseIntElementsAttr FlatbufferBuilder::getBufferAttr(MLIRContext *context) {
+ // We require direct access to the flatbuffer bytes so we can pass them to
+ // the attribute constructor (which needs to inspect them all for uniquing).
+ auto bufferData = cloneBufferIntoContiguousBytes(*this);
+
+ // NOTE: ew. OpaqueAttr may be better? It does equality checks but won't try
+ // to unique and would let us get a mutable buffer out.
+ return DenseIntElementsAttr::get(
+ VectorType::get({static_cast<int64_t>(bufferData.size())},
+ IntegerType::get(8, context)),
+ std::move(bufferData));
+}
+
+LogicalResult FlatbufferBuilder::copyToStream(llvm::raw_ostream &output) {
+ // NOTE: expected to be the default emitter.
+ auto *E = reinterpret_cast<flatcc_emitter_t *>(
+ flatcc_builder_get_emit_context(*this));
+
+ if (!E->front) {
+ return failure();
+ }
+ if (E->front == E->back) {
+ output.write(reinterpret_cast<char *>(E->front_cursor), E->used);
+ return success();
+ }
+ size_t len = FLATCC_EMITTER_PAGE_SIZE - E->front_left;
+ output.write(reinterpret_cast<char *>(E->front_cursor), len);
+ flatcc_emitter_page_t *p = E->front->next;
+ while (p != E->back) {
+ output.write(reinterpret_cast<char *>(p->page), FLATCC_EMITTER_PAGE_SIZE);
+ p = p->next;
+ }
+ output.write(reinterpret_cast<char *>(p->page),
+ FLATCC_EMITTER_PAGE_SIZE - E->back_left);
+ return success();
+}
+
+LogicalResult FlatbufferBuilder::printJsonToStream(
+ bool pretty, bool includeDefaults, print_json_fn_t print_json_fn,
+ llvm::raw_ostream &output) {
+ // The printer requires direct access to the flatbuffer bytes so clone here.
+ auto bufferData = cloneBufferIntoContiguousBytes(*this);
+
+ flatcc_json_printer_t printer;
+ flatcc_json_printer_init_dynamic_buffer(&printer, /*buffer_size=*/0);
+ flatcc_json_printer_set_indent(&printer, pretty ? 2 : 0);
+ flatcc_json_printer_set_skip_default(&printer, !includeDefaults);
+ flatcc_json_printer_set_force_default(&printer, includeDefaults);
+
+ // Print into the dynamically-resizing buffer. May fail if OOM.
+ int rv =
+ print_json_fn(&printer, reinterpret_cast<const char *>(bufferData.data()),
+ bufferData.size());
+ if (rv == -1) {
+ flatcc_json_printer_clear(&printer);
+ return failure();
+ }
+
+ // Take the buffer from the printer; note that it is 0 terminated and can be
+ // used directly as a cstr if needed.
+ size_t outputSize = 0;
+ char *outputBytes = reinterpret_cast<char *>(
+ flatcc_json_printer_finalize_dynamic_buffer(&printer, &outputSize));
+ output.write(outputBytes, outputSize);
+ free(outputBytes);
+
+ return success();
+}
+
+} // namespace iree_compiler
+} // namespace mlir
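A minimal sketch of the idea behind copyToStream above: walk a chain of fixed-size pages and write each one to the output stream in order, so the full buffer is never coalesced into one allocation. The Page struct, kPageSize, FillPage, WritePagesToStream, and PagesToString are hypothetical names for illustration only; flatcc's real emitter keeps a circular page list with front/back cursors, which this simplified chain does not reproduce.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <ostream>
#include <sstream>
#include <string>

// Hypothetical page size and page type; flatcc's real emitter pages differ.
constexpr std::size_t kPageSize = 8;

struct Page {
  char bytes[kPageSize] = {};
  std::size_t used = 0;  // Number of valid bytes at the start of |bytes|.
  Page *next = nullptr;  // Next page in emission order, or nullptr.
};

// Copies at most kPageSize bytes of |text| into |page| (demo helper).
inline void FillPage(Page &page, const std::string &text) {
  page.used = text.size() < kPageSize ? text.size() : kPageSize;
  std::memcpy(page.bytes, text.data(), page.used);
}

// Writes each page's valid bytes to |output| in order without ever building a
// contiguous copy of the whole buffer. Returns false for an empty chain or if
// the stream reports an error.
inline bool WritePagesToStream(const Page *head, std::ostream &output) {
  if (!head) return false;
  for (const Page *page = head; page != nullptr; page = page->next) {
    assert(page->used <= kPageSize && "page overfilled");
    output.write(page->bytes, static_cast<std::streamsize>(page->used));
  }
  return output.good();
}

// Collects the pages into a string; for inspection in demos and tests only,
// since it reintroduces the contiguous copy the streaming path avoids.
inline std::string PagesToString(const Page *head) {
  std::ostringstream stream;
  WritePagesToStream(head, stream);
  return stream.str();
}
```

The trade-off mirrors the one described for copyToStream: streaming page by page costs one write call per page, and in exchange it avoids the single large allocation that getBufferAttr and printJsonToStream pay for when they clone the builder contents.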
diff --git a/iree/compiler/Utils/FlatbufferUtils.h b/iree/compiler/Utils/FlatbufferUtils.h
new file mode 100644
index 0000000..301634d
--- /dev/null
+++ b/iree/compiler/Utils/FlatbufferUtils.h
@@ -0,0 +1,184 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_UTILS_FLATBUFFERUTILS_H_
+#define IREE_COMPILER_UTILS_FLATBUFFERUTILS_H_
+
+#include <functional>
+
+#include "llvm/Support/raw_ostream.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+
+// NOTE: order matters here as some of the LLVM includes conflict.
+#include "iree/base/flatcc.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+// RAII wrapper for flatcc_builder_t; pass to functions requiring a builder.
+//
+// Usage:
+// FlatbufferBuilder builder;
+// // NOTE: flatbuffers are built bottom-up so we first generate our [uint8]:
+// auto dataRef = builder.streamUint8Vec(...);
+// // ... and then start the table that references it:
+// my_type_start_as_root(builder);
+// my_type_uint8_vec_field_add(builder, dataRef);
+// my_type_end_as_root(builder);
+// // ... and finally capture the results as an mlir::Attribute.
+// auto attr = builder.getBufferAttr(mlirContext);
+class FlatbufferBuilder {
+ public:
+ FlatbufferBuilder();
+ ~FlatbufferBuilder();
+
+ operator flatcc_builder_t *() { return &builder; }
+
+ // Creates a string with the given string contents (including zeros).
+ flatbuffers_string_ref_t createString(StringRef value) {
+ if (value.empty()) return 0;
+ return flatbuffers_string_create(*this, value.data(), value.size());
+ }
+
+ // Creates a string vector containing all strings in the given range.
+ template <typename RangeTy>
+ flatbuffers_string_vec_ref_t createStringVec(RangeTy &&Range) {
+ auto stringRefs =
+ llvm::to_vector<8>(llvm::map_range(Range, [&](StringRef value) {
+ return flatbuffers_string_create(*this, value.data(), value.size());
+ }));
+ if (stringRefs.empty()) return 0;
+ return flatbuffers_string_vec_create(*this, stringRefs.data(),
+ stringRefs.size());
+ }
+
+ // Creates an offset vector with the given values. The source values will not
+ // be modified.
+ flatbuffers_vec_ref_t createOffsetVec(ArrayRef<flatcc_builder_ref_t> values) {
+ if (values.empty()) return 0;
+ return flatcc_builder_create_offset_vector(*this, values.data(),
+ values.size());
+ }
+
+ // Creates an offset vector with the given values.
+ // Unlike createOffsetVec this will destroy the input values array during
+ // serialization but be much faster.
+ flatbuffers_vec_ref_t createOffsetVecDestructive(
+ SmallVectorImpl<flatcc_builder_ref_t> &values) {
+ if (values.empty()) return 0;
+ return flatcc_builder_create_offset_vector_direct(*this, values.data(),
+ values.size());
+ }
+
+ // Creates an [int32] vec with the contents of the given range.
+ template <typename RangeTy>
+ flatbuffers_int32_vec_ref_t createInt32Vec(RangeTy &&Range) {
+ if (llvm::empty(Range)) return 0;
+ flatbuffers_int32_vec_start(*this);
+ for (int32_t v : Range) {
+ flatbuffers_int32_vec_push_create(*this, v);
+ }
+ return flatbuffers_int32_vec_end(*this);
+ }
+
+ // Provides a raw_ostream that |fn| can use to directly stream into a [uint8]
+ // in the flatbuffer builder.
+ //
+ // Usage:
+ // auto ref = builder.streamUint8Vec([&](llvm::raw_ostream &stream) {
+ // stream << "foo";
+ // return true;
+ // });
+ // ...
+ // my_type_uint8_vec_field_add(builder, ref); // use vec reference
+ // ...
+ flatbuffers_uint8_vec_ref_t streamUint8Vec(
+ std::function<bool(raw_ostream &stream)> fn);
+
+ // Captures the current contents of the flatbuffer builder and returns them
+ // as a shaped `vector<SIZExi8>` dense attr. The builder is left unmodified.
+ DenseIntElementsAttr getBufferAttr(MLIRContext *context);
+
+ // Copies the current contents of the flatbuffer builder to the target output
+ // stream. The builder is left unmodified.
+ //
+ // This avoids the large allocation that would otherwise be needed to stitch
+ // together all of the pages that were allocated in the emitter as the
+ // flatbuffer was constructed; here we can just walk over each page and
+ // write it out in order without any allocations.
+ LogicalResult copyToStream(llvm::raw_ostream &output);
+
+ using print_json_fn_t = int (*)(flatcc_json_printer_t *ctx, const char *buf,
+ size_t bufsiz);
+
+ // Prints the flatbuffer in its canonical JSON format to the given stream.
+ // The builder is left unmodified.
+ //
+ // |pretty| enables newlines and indentation; somewhat useful for lit testing
+ // (as large byte buffers end up with a byte per line!).
+ //
+ // |includeDefaults| forces all values to be printed, including those that
+ // would be omitted from the binary format because they hold the default
+ // value (0, etc).
+ //
+ // NOTE: JSON representations will also differ structurally from the binary
+ // format as reused tables are printed wherever they are used as opposed to
+ // referencing the same bytes; meaning that this can't be used to verify that
+ // we are correctly memoizing strings/structures/etc.
+ LogicalResult printJsonToStream(bool pretty, bool includeDefaults,
+ print_json_fn_t print_json_fn,
+ llvm::raw_ostream &output);
+
+ private:
+ flatcc_builder_t builder;
+};
+
+// Allows streaming bytes directly into a flatbuffer `[uint8]` field.
+// The ostream is constructed unbuffered and routes all writes into pages
+// allocated by the flatbuffer builder as we grow the output.
+//
+// Usage:
+// flatbuffers_uint8_vec_start(builder);
+// raw_flatbuffer_uint8_vec_ostream stream(builder);
+// stream << "foo";
+// stream.flush(); // *********** IMPORTANT ***********
+// flatbuffers_uint8_vec_ref_t ref = flatbuffers_uint8_vec_end(builder);
+class raw_flatbuffer_uint8_vec_ostream : public llvm::raw_ostream {
+ public:
+ explicit raw_flatbuffer_uint8_vec_ostream(flatcc_builder_t *builder)
+ : raw_ostream(/*unbuffered=*/true), builder(builder) {}
+
+ ~raw_flatbuffer_uint8_vec_ostream() override { flush(); }
+
+ private:
+ void write_impl(const char *Ptr, size_t Size) override {
+ flatbuffers_uint8_vec_append(builder,
+ reinterpret_cast<const uint8_t *>(Ptr), Size);
+ }
+
+ uint64_t current_pos() const override {
+ return tell() - GetNumBytesInBuffer();
+ }
+
+ flatcc_builder_t *builder;
+};
+
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_UTILS_FLATBUFFERUTILS_H_
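The stream adapter in this header forwards every write into a growing byte vector owned by the builder. As a rough standard-library analogue (not the LLVM or flatcc API), the same shape can be sketched with a std::streambuf that appends to a std::vector. ByteVectorStreambuf is a hypothetical name, and std::vector stands in for the builder's paged [uint8] storage.

```cpp
#include <cassert>
#include <cstdint>
#include <ostream>
#include <streambuf>
#include <vector>

// Hypothetical analogue of a stream that appends directly into a byte vector.
// No put area is configured, so every write reaches overflow() or xsputn()
// immediately; this corresponds to an unbuffered stream.
class ByteVectorStreambuf : public std::streambuf {
 public:
  explicit ByteVectorStreambuf(std::vector<std::uint8_t> *sink) : sink_(sink) {
    assert(sink_ != nullptr && "sink must outlive the streambuf");
  }

 protected:
  // Called for single-character writes.
  int_type overflow(int_type ch) override {
    if (!traits_type::eq_int_type(ch, traits_type::eof())) {
      sink_->push_back(
          static_cast<std::uint8_t>(traits_type::to_char_type(ch)));
    }
    return traits_type::not_eof(ch);
  }

  // Called for bulk writes; appends the whole span in one go.
  std::streamsize xsputn(const char *data, std::streamsize size) override {
    const auto *begin = reinterpret_cast<const std::uint8_t *>(data);
    sink_->insert(sink_->end(), begin, begin + size);
    return size;
  }

 private:
  std::vector<std::uint8_t> *sink_;
};
```

Wrapping the streambuf in a std::ostream gives callers the usual operator<< interface while the bytes land directly in the target vector, which is the same division of labor as streamUint8Vec handing a raw_ostream to its callback.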
diff --git a/iree/hal/BUILD b/iree/hal/BUILD
index 39c8c8f..fbe7237 100644
--- a/iree/hal/BUILD
+++ b/iree/hal/BUILD
@@ -22,18 +22,10 @@
licenses = ["notice"], # Apache 2.0
)
-cc_library(
- name = "allocator",
- srcs = ["allocator.cc"],
- hdrs = ["allocator.h"],
- deps = [
- ":buffer",
- "//iree/base:ref_ptr",
- "//iree/base:status",
- "//iree/base:tracing",
- "@com_google_absl//absl/types:span",
- ],
-)
+#===------------------------------------------------------------------------===#
+# Public API
+#===------------------------------------------------------------------------===#
+# TODO(benvanik): rename to :hal
cc_library(
name = "api",
@@ -44,16 +36,11 @@
],
visibility = ["//visibility:public"],
deps = [
- ":api_hdrs",
- ":buffer",
- ":command_buffer",
- ":device",
- ":driver",
":driver_registry",
+ ":hal",
":heap_buffer",
- ":semaphore",
"//iree/base:api",
- "//iree/base:memory",
+ "//iree/base:core_headers",
"//iree/base:ref_ptr",
"//iree/base:tracing",
"//iree/hal/host:host_local_allocator",
@@ -64,20 +51,12 @@
],
)
-cc_library(
- name = "api_hdrs",
- hdrs = ["api.h"],
- deps = [
- "//iree/base:api_hdrs",
- ],
-)
-
cc_test(
name = "api_string_util_test",
srcs = ["api_string_util_test.cc"],
deps = [
":api",
- "//iree/base:memory",
+ "//iree/base:core_headers",
"//iree/base:status",
"//iree/testing:gtest",
"//iree/testing:gtest_main",
@@ -86,18 +65,52 @@
],
)
+#===------------------------------------------------------------------------===#
+# Implementation
+#===------------------------------------------------------------------------===#
+# TODO(benvanik): rename to :cc and expose via an api_cc.h.
+
cc_library(
- name = "buffer",
- srcs = ["buffer.cc"],
- hdrs = ["buffer.h"],
+ name = "hal",
+ srcs = [
+ "allocator.cc",
+ "buffer.cc",
+ "command_buffer.cc",
+ "deferred_buffer.cc",
+ "executable_cache.cc",
+ ],
+ hdrs = [
+ "allocator.h",
+ "buffer.h",
+ "command_buffer.h",
+ "command_queue.h",
+ "debug_capture_manager.h",
+ "deferred_buffer.h",
+ "descriptor_set.h",
+ "descriptor_set_layout.h",
+ "device.h",
+ "device_info.h",
+ "device_placement.h",
+ "driver.h",
+ "event.h",
+ "executable.h",
+ "executable_cache.h",
+ "executable_format.h",
+ "executable_layout.h",
+ "executable_spec.h",
+ "resource.h",
+ "semaphore.h",
+ "stack_trace.h",
+ ],
deps = [
- ":resource",
- "//iree/base:bitfield",
+ "//iree/base:core_headers",
"//iree/base:logging",
+ "//iree/base:ref_ptr",
"//iree/base:status",
+ "//iree/base:time",
+ "//iree/base:tracing",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
- "@com_google_absl//absl/types:variant",
],
)
@@ -108,7 +121,7 @@
"buffer_test.cc",
],
deps = [
- ":buffer",
+ ":hal",
":heap_buffer",
"//iree/base:status",
"//iree/testing:gtest",
@@ -117,72 +130,11 @@
],
)
-cc_library(
- name = "command_buffer",
- srcs = ["command_buffer.cc"],
- hdrs = ["command_buffer.h"],
- deps = [
- ":buffer",
- ":descriptor_set",
- ":event",
- ":executable",
- ":executable_layout",
- ":resource",
- "//iree/base:bitfield",
- "//iree/base:status",
- ],
-)
-
-cc_library(
- name = "command_buffer_validation",
- srcs = ["command_buffer_validation.cc"],
- hdrs = ["command_buffer_validation.h"],
- deps = [
- ":allocator",
- ":command_buffer",
- "//iree/base:logging",
- "//iree/base:status",
- "@com_google_absl//absl/strings",
- ],
-)
-
-cc_library(
- name = "command_queue",
- hdrs = ["command_queue.h"],
- deps = [
- ":command_buffer",
- ":semaphore",
- "//iree/base:bitfield",
- "//iree/base:status",
- "//iree/base:time",
- "@com_google_absl//absl/types:span",
- ],
-)
-
-cc_library(
- name = "debug_capture_manager",
- hdrs = ["debug_capture_manager.h"],
- deps = [
- "//iree/base:status",
- ],
-)
-
-cc_library(
- name = "deferred_buffer",
- srcs = ["deferred_buffer.cc"],
- hdrs = ["deferred_buffer.h"],
- deps = [
- ":allocator",
- ":buffer",
- "//iree/base:status",
- ],
-)
-
cc_test(
name = "deferred_buffer_test",
srcs = ["deferred_buffer_test.cc"],
deps = [
- ":deferred_buffer",
+ ":hal",
":heap_buffer",
"//iree/hal/testing:mock_allocator",
"//iree/testing:gtest",
@@ -191,68 +143,34 @@
],
)
-cc_library(
- name = "descriptor_set",
- hdrs = ["descriptor_set.h"],
- deps = [
- ":buffer",
- ":resource",
- "@com_google_absl//absl/strings",
- ],
-)
+#===------------------------------------------------------------------------===#
+# Debugging utilities and tools
+#===------------------------------------------------------------------------===#
cc_library(
- name = "descriptor_set_layout",
- hdrs = ["descriptor_set_layout.h"],
+ name = "command_buffer_validation",
+ srcs = ["command_buffer_validation.cc"],
+ hdrs = ["command_buffer_validation.h"],
deps = [
- ":buffer",
- ":resource",
- ],
-)
-
-cc_library(
- name = "device",
- hdrs = ["device.h"],
- deps = [
- ":allocator",
- ":buffer",
- ":command_queue",
- ":descriptor_set",
- ":descriptor_set_layout",
- ":device_info",
- ":event",
- ":executable_cache",
- ":executable_layout",
- ":semaphore",
- "//iree/base:ref_ptr",
+ ":hal",
+ "//iree/base:logging",
"//iree/base:status",
- "//iree/base:target_platform",
- "//iree/base:time",
- ],
-)
-
-cc_library(
- name = "device_info",
- hdrs = ["device_info.h"],
- deps = [
- "//iree/base:bitfield",
"@com_google_absl//absl/strings",
],
)
+#===------------------------------------------------------------------------===#
+# Internal device management and driver registry
+#===------------------------------------------------------------------------===#
+# TODO(benvanik): port these to C and merge into main API.
+
cc_library(
name = "device_manager",
srcs = ["device_manager.cc"],
hdrs = ["device_manager.h"],
deps = [
- ":allocator",
- ":buffer",
- ":command_queue",
- ":device",
- ":device_placement",
- ":executable_format",
+ ":hal",
":heap_buffer",
- ":semaphore",
"//iree/base:status",
"//iree/base:time",
"//iree/base:tracing",
@@ -262,29 +180,11 @@
)
cc_library(
- name = "device_placement",
- hdrs = ["device_placement.h"],
-)
-
-cc_library(
- name = "driver",
- hdrs = ["driver.h"],
- deps = [
- ":debug_capture_manager",
- ":device",
- ":device_info",
- "//iree/base:ref_ptr",
- "//iree/base:status",
- ],
-)
-
-cc_library(
name = "driver_registry",
srcs = ["driver_registry.cc"],
hdrs = ["driver_registry.h"],
deps = [
- ":driver",
- "//iree/base:initializer",
+ ":hal",
"//iree/base:ref_ptr",
"//iree/base:status",
"@com_google_absl//absl/base:core_headers",
@@ -292,87 +192,18 @@
],
)
-cc_library(
- name = "event",
- hdrs = ["event.h"],
- deps = [
- ":resource",
- ],
-)
-
-cc_library(
- name = "executable",
- hdrs = ["executable.h"],
- deps = [":resource"],
-)
-
-cc_library(
- name = "executable_cache",
- srcs = ["executable_cache.cc"],
- hdrs = ["executable_cache.h"],
- deps = [
- ":executable",
- ":executable_format",
- ":executable_layout",
- ":executable_spec",
- "//iree/base:bitfield",
- "//iree/base:ref_ptr",
- "//iree/base:status",
- ],
-)
-
-cc_library(
- name = "executable_format",
- hdrs = ["executable_format.h"],
-)
-
-cc_library(
- name = "executable_layout",
- hdrs = ["executable_layout.h"],
- deps = [":resource"],
-)
-
-cc_library(
- name = "executable_spec",
- hdrs = ["executable_spec.h"],
- deps = [
- ":executable_format",
- "@com_google_absl//absl/types:span",
- ],
-)
+#===------------------------------------------------------------------------===#
+# Internal implementation details
+#===------------------------------------------------------------------------===#
cc_library(
name = "heap_buffer",
srcs = ["heap_buffer.cc"],
hdrs = ["heap_buffer.h"],
deps = [
- ":allocator",
- ":buffer",
+ ":hal",
"//iree/base:status",
"//iree/base:tracing",
"//iree/hal/host:host_buffer",
],
)
-
-cc_library(
- name = "resource",
- hdrs = ["resource.h"],
- deps = [
- "//iree/base:ref_ptr",
- ],
-)
-
-cc_library(
- name = "semaphore",
- hdrs = ["semaphore.h"],
- deps = [
- ":resource",
- "//iree/base:status",
- "//iree/base:time",
- ],
-)
-
-cc_library(
- name = "stack_trace",
- hdrs = ["stack_trace.h"],
-)
diff --git a/iree/hal/CMakeLists.txt b/iree/hal/CMakeLists.txt
index 3804321..9d88801 100644
--- a/iree/hal/CMakeLists.txt
+++ b/iree/hal/CMakeLists.txt
@@ -16,22 +16,6 @@
iree_cc_library(
NAME
- allocator
- HDRS
- "allocator.h"
- SRCS
- "allocator.cc"
- DEPS
- ::buffer
- absl::span
- iree::base::ref_ptr
- iree::base::status
- iree::base::tracing
- PUBLIC
-)
-
-iree_cc_library(
- NAME
api
HDRS
"api.h"
@@ -39,35 +23,20 @@
SRCS
"api.cc"
DEPS
- ::api_hdrs
- ::buffer
- ::command_buffer
- ::device
- ::driver
::driver_registry
+ ::hal
::heap_buffer
- ::semaphore
absl::inlined_vector
absl::span
absl::strings
iree::base::api
- iree::base::memory
+ iree::base::core_headers
iree::base::ref_ptr
iree::base::tracing
iree::hal::host::host_local_allocator
PUBLIC
)
-iree_cc_library(
- NAME
- api_hdrs
- HDRS
- "api.h"
- DEPS
- iree::base::api_hdrs
- PUBLIC
-)
-
iree_cc_test(
NAME
api_string_util_test
@@ -77,7 +46,7 @@
::api
absl::inlined_vector
absl::strings
- iree::base::memory
+ iree::base::core_headers
iree::base::status
iree::testing::gtest
iree::testing::gtest_main
@@ -85,19 +54,44 @@
iree_cc_library(
NAME
- buffer
+ hal
HDRS
+ "allocator.h"
"buffer.h"
+ "command_buffer.h"
+ "command_queue.h"
+ "debug_capture_manager.h"
+ "deferred_buffer.h"
+ "descriptor_set.h"
+ "descriptor_set_layout.h"
+ "device.h"
+ "device_info.h"
+ "device_placement.h"
+ "driver.h"
+ "event.h"
+ "executable.h"
+ "executable_cache.h"
+ "executable_format.h"
+ "executable_layout.h"
+ "executable_spec.h"
+ "resource.h"
+ "semaphore.h"
+ "stack_trace.h"
SRCS
+ "allocator.cc"
"buffer.cc"
+ "command_buffer.cc"
+ "deferred_buffer.cc"
+ "executable_cache.cc"
DEPS
- ::resource
absl::span
absl::strings
- absl::variant
- iree::base::bitfield
+ iree::base::core_headers
iree::base::logging
+ iree::base::ref_ptr
iree::base::status
+ iree::base::time
+ iree::base::tracing
PUBLIC
)
@@ -108,7 +102,7 @@
"buffer_mapping_test.cc"
"buffer_test.cc"
DEPS
- ::buffer
+ ::hal
::heap_buffer
absl::span
iree::base::status
@@ -116,23 +110,18 @@
iree::testing::gtest_main
)
-iree_cc_library(
+iree_cc_test(
NAME
- command_buffer
- HDRS
- "command_buffer.h"
+ deferred_buffer_test
SRCS
- "command_buffer.cc"
+ "deferred_buffer_test.cc"
DEPS
- ::buffer
- ::descriptor_set
- ::event
- ::executable
- ::executable_layout
- ::resource
- iree::base::bitfield
- iree::base::status
- PUBLIC
+ ::hal
+ ::heap_buffer
+ absl::memory
+ iree::hal::testing::mock_allocator
+ iree::testing::gtest
+ iree::testing::gtest_main
)
iree_cc_library(
@@ -143,8 +132,7 @@
SRCS
"command_buffer_validation.cc"
DEPS
- ::allocator
- ::command_buffer
+ ::hal
absl::strings
iree::base::logging
iree::base::status
@@ -153,130 +141,14 @@
iree_cc_library(
NAME
- command_queue
- HDRS
- "command_queue.h"
- DEPS
- ::command_buffer
- ::semaphore
- absl::span
- iree::base::bitfield
- iree::base::status
- iree::base::time
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- debug_capture_manager
- HDRS
- "debug_capture_manager.h"
- DEPS
- iree::base::status
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- deferred_buffer
- HDRS
- "deferred_buffer.h"
- SRCS
- "deferred_buffer.cc"
- DEPS
- ::allocator
- ::buffer
- iree::base::status
- PUBLIC
-)
-
-iree_cc_test(
- NAME
- deferred_buffer_test
- SRCS
- "deferred_buffer_test.cc"
- DEPS
- ::deferred_buffer
- ::heap_buffer
- absl::memory
- iree::hal::testing::mock_allocator
- iree::testing::gtest
- iree::testing::gtest_main
-)
-
-iree_cc_library(
- NAME
- descriptor_set
- HDRS
- "descriptor_set.h"
- DEPS
- ::buffer
- ::resource
- absl::strings
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- descriptor_set_layout
- HDRS
- "descriptor_set_layout.h"
- DEPS
- ::buffer
- ::resource
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- device
- HDRS
- "device.h"
- DEPS
- ::allocator
- ::buffer
- ::command_queue
- ::descriptor_set
- ::descriptor_set_layout
- ::device_info
- ::event
- ::executable_cache
- ::executable_layout
- ::semaphore
- iree::base::ref_ptr
- iree::base::status
- iree::base::target_platform
- iree::base::time
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- device_info
- HDRS
- "device_info.h"
- DEPS
- absl::strings
- iree::base::bitfield
- PUBLIC
-)
-
-iree_cc_library(
- NAME
device_manager
HDRS
"device_manager.h"
SRCS
"device_manager.cc"
DEPS
- ::allocator
- ::buffer
- ::command_queue
- ::device
- ::device_placement
- ::executable_format
+ ::hal
::heap_buffer
- ::semaphore
absl::span
absl::synchronization
iree::base::status
@@ -287,38 +159,15 @@
iree_cc_library(
NAME
- device_placement
- HDRS
- "device_placement.h"
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- driver
- HDRS
- "driver.h"
- DEPS
- ::debug_capture_manager
- ::device
- ::device_info
- iree::base::ref_ptr
- iree::base::status
- PUBLIC
-)
-
-iree_cc_library(
- NAME
driver_registry
HDRS
"driver_registry.h"
SRCS
"driver_registry.cc"
DEPS
- ::driver
+ ::hal
absl::core_headers
absl::synchronization
- iree::base::initializer
iree::base::ref_ptr
iree::base::status
PUBLIC
@@ -326,113 +175,15 @@
iree_cc_library(
NAME
- event
- HDRS
- "event.h"
- DEPS
- ::resource
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- executable
- HDRS
- "executable.h"
- DEPS
- ::resource
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- executable_cache
- HDRS
- "executable_cache.h"
- SRCS
- "executable_cache.cc"
- DEPS
- ::executable
- ::executable_format
- ::executable_layout
- ::executable_spec
- iree::base::bitfield
- iree::base::ref_ptr
- iree::base::status
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- executable_format
- HDRS
- "executable_format.h"
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- executable_layout
- HDRS
- "executable_layout.h"
- DEPS
- ::resource
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- executable_spec
- HDRS
- "executable_spec.h"
- DEPS
- ::executable_format
- absl::span
- PUBLIC
-)
-
-iree_cc_library(
- NAME
heap_buffer
HDRS
"heap_buffer.h"
SRCS
"heap_buffer.cc"
DEPS
- ::allocator
- ::buffer
+ ::hal
iree::base::status
iree::base::tracing
iree::hal::host::host_buffer
PUBLIC
)
-
-iree_cc_library(
- NAME
- resource
- HDRS
- "resource.h"
- DEPS
- iree::base::ref_ptr
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- semaphore
- HDRS
- "semaphore.h"
- DEPS
- ::resource
- iree::base::status
- iree::base::time
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- stack_trace
- HDRS
- "stack_trace.h"
- PUBLIC
-)
diff --git a/iree/hal/api.cc b/iree/hal/api.cc
index 7ec77af..956ac7b 100644
--- a/iree/hal/api.cc
+++ b/iree/hal/api.cc
@@ -429,7 +429,7 @@
break;
default: {
// Treat any unknown format as binary.
- n = 2 * element_size;
+ n = 2 * (int)element_size;
if (buffer && buffer_capacity > n) {
iree_hal_bytes_to_hex_string(data.data, buffer, element_size);
buffer[n] = 0;
@@ -1707,19 +1707,19 @@
// We need to allocate storage to marshal in the semaphores. Ideally we'd
// change the C++ API to make this 1:1 with a reinterpret_cast, however that
// makes the C API more difficult. Bleh.
- int total_semaphore_count = 0;
- for (int i = 0; i < batch_count; ++i) {
+ iree_host_size_t total_semaphore_count = 0;
+ for (iree_host_size_t i = 0; i < batch_count; ++i) {
total_semaphore_count += batches[i].wait_semaphores.count;
total_semaphore_count += batches[i].signal_semaphores.count;
}
absl::InlinedVector<SemaphoreValue, 4> semaphore_values(
total_semaphore_count);
absl::InlinedVector<SubmissionBatch, 2> dst_batches(batch_count);
- int base_semaphore_index = 0;
- for (int i = 0; i < batch_count; ++i) {
+ iree_host_size_t base_semaphore_index = 0;
+ for (iree_host_size_t i = 0; i < batch_count; ++i) {
const auto& src_batch = batches[i];
auto& dst_batch = dst_batches[i];
- for (int j = 0; j < src_batch.wait_semaphores.count; ++j) {
+ for (iree_host_size_t j = 0; j < src_batch.wait_semaphores.count; ++j) {
semaphore_values[base_semaphore_index + j] = {
reinterpret_cast<Semaphore*>(src_batch.wait_semaphores.semaphores[j]),
src_batch.wait_semaphores.payload_values[j]};
@@ -1731,7 +1731,7 @@
dst_batch.command_buffers =
iree::ReinterpretSpan<CommandBuffer*>(absl::MakeConstSpan(
src_batch.command_buffers, src_batch.command_buffer_count));
- for (int j = 0; j < src_batch.signal_semaphores.count; ++j) {
+ for (iree_host_size_t j = 0; j < src_batch.signal_semaphores.count; ++j) {
semaphore_values[base_semaphore_index + j] = {
reinterpret_cast<Semaphore*>(
src_batch.signal_semaphores.semaphores[j]),
@@ -1901,6 +1901,11 @@
}
*out_driver_count = available_drivers.size();
+ *out_driver_names = NULL;
+ if (available_drivers.empty()) {
+ return iree_ok_status();
+ }
+
iree_string_view_t* driver_name_storage = nullptr;
IREE_RETURN_IF_ERROR(iree_allocator_malloc(
allocator,
diff --git a/iree/hal/api.h b/iree/hal/api.h
index 05fd0f4..4e0295f 100644
--- a/iree/hal/api.h
+++ b/iree/hal/api.h
@@ -598,7 +598,7 @@
// with real device allocators and will likely incur a copy if used.
IREE_API_EXPORT iree_status_t IREE_API_CALL
iree_hal_allocator_create_host_local(iree_allocator_t allocator,
- iree_hal_allocator** out_allocator);
+ iree_hal_allocator_t** out_allocator);
// Retains the given |allocator| for the caller.
IREE_API_EXPORT void IREE_API_CALL
diff --git a/iree/hal/buffer.cc b/iree/hal/buffer.cc
index 13c756f..f798952 100644
--- a/iree/hal/buffer.cc
+++ b/iree/hal/buffer.cc
@@ -22,7 +22,6 @@
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
-#include "absl/types/variant.h"
#include "iree/base/status.h"
namespace iree {
diff --git a/iree/hal/buffer.h b/iree/hal/buffer.h
index 3db906b..a7e8c01 100644
--- a/iree/hal/buffer.h
+++ b/iree/hal/buffer.h
@@ -59,7 +59,6 @@
#include <utility>
#include "absl/types/span.h"
-#include "absl/types/variant.h"
#include "iree/base/bitfield.h"
#include "iree/base/logging.h"
#include "iree/base/status.h"
diff --git a/iree/hal/cts/BUILD b/iree/hal/cts/BUILD
index cbee0cf..b915ee1 100644
--- a/iree/hal/cts/BUILD
+++ b/iree/hal/cts/BUILD
@@ -14,8 +14,6 @@
# Conformance Test Suite (CTS) for HAL implementations.
-load("//iree:build_defs.oss.bzl", "PLATFORM_VULKAN_TEST_DEPS")
-
package(
default_visibility = ["//visibility:public"],
features = ["layering_check"],
@@ -33,14 +31,9 @@
deps = [
"//iree/base:status",
"//iree/hal:driver_registry",
+ "//iree/hal/drivers",
"//iree/testing:gtest",
-
- # HAL driver modules.
- "//iree/hal/dylib:dylib_driver_module", # build-cleaner: keep
- "//iree/hal/llvmjit:llvmjit_driver_module", # build-cleaner: keep
- "//iree/hal/vmla:vmla_driver_module", # build-cleaner: keep
- "//iree/hal/vulkan:vulkan_driver_module", # build-cleaner: keep
- ] + PLATFORM_VULKAN_TEST_DEPS,
+ ],
)
cc_test(
diff --git a/iree/hal/cts/CMakeLists.txt b/iree/hal/cts/CMakeLists.txt
index 1271a0c..f48aee9 100644
--- a/iree/hal/cts/CMakeLists.txt
+++ b/iree/hal/cts/CMakeLists.txt
@@ -14,8 +14,6 @@
iree_add_all_subdirs()
-# bazel_to_cmake: DO NOT EDIT, IREE_HAL_DRIVER_MODULES is custom logic
-
iree_cc_library(
NAME
cts_test_base
@@ -26,9 +24,8 @@
DEPS
iree::base::status
iree::hal::driver_registry
+ iree::hal::drivers
iree::testing::gtest
- iree::testing::gtest_main
- ${IREE_HAL_DRIVER_MODULES}
TESTONLY
PUBLIC
)
@@ -43,6 +40,7 @@
iree::base::status
iree::hal::driver_registry
iree::testing::gtest
+ iree::testing::gtest_main
)
iree_cc_test(
@@ -55,6 +53,7 @@
iree::base::status
iree::hal::driver_registry
iree::testing::gtest
+ iree::testing::gtest_main
)
iree_cc_test(
@@ -67,6 +66,7 @@
iree::base::status
iree::hal::driver_registry
iree::testing::gtest
+ iree::testing::gtest_main
)
iree_cc_test(
@@ -79,6 +79,7 @@
iree::base::status
iree::hal::driver_registry
iree::testing::gtest
+ iree::testing::gtest_main
)
iree_cc_test(
@@ -90,6 +91,7 @@
::cts_test_base
iree::hal::driver_registry
iree::testing::gtest
+ iree::testing::gtest_main
)
iree_cc_test(
@@ -101,4 +103,5 @@
::cts_test_base
iree::hal::driver_registry
iree::testing::gtest
+ iree::testing::gtest_main
)
diff --git a/iree/hal/cts/allocator_test.cc b/iree/hal/cts/allocator_test.cc
index 1608c62..f8e20b3 100644
--- a/iree/hal/cts/allocator_test.cc
+++ b/iree/hal/cts/allocator_test.cc
@@ -77,10 +77,10 @@
allocator_->CanUseBufferLike(allocator_, memory_type, usage, usage));
}
-INSTANTIATE_TEST_SUITE_P(AllDrivers, AllocatorTest,
- ::testing::ValuesIn(DriverRegistry::shared_registry()
- ->EnumerateAvailableDrivers()),
- GenerateTestName());
+INSTANTIATE_TEST_SUITE_P(
+ AllDrivers, AllocatorTest,
+ ::testing::ValuesIn(CtsTestBase::EnumerateAvailableDrivers()),
+ GenerateTestName());
} // namespace cts
} // namespace hal
diff --git a/iree/hal/cts/buffer_test.cc b/iree/hal/cts/buffer_test.cc
index 26eecc7..dc63dd0 100644
--- a/iree/hal/cts/buffer_test.cc
+++ b/iree/hal/cts/buffer_test.cc
@@ -375,10 +375,10 @@
EXPECT_THAT(actual_data, ElementsAre(0, 1, 0xCC, 0xDD, 4, 5, 6));
}
-INSTANTIATE_TEST_SUITE_P(AllDrivers, BufferTest,
- ::testing::ValuesIn(DriverRegistry::shared_registry()
- ->EnumerateAvailableDrivers()),
- GenerateTestName());
+INSTANTIATE_TEST_SUITE_P(
+ AllDrivers, BufferTest,
+ ::testing::ValuesIn(CtsTestBase::EnumerateAvailableDrivers()),
+ GenerateTestName());
} // namespace cts
} // namespace hal
diff --git a/iree/hal/cts/command_buffer_test.cc b/iree/hal/cts/command_buffer_test.cc
index 5b6e4d6..234e755 100644
--- a/iree/hal/cts/command_buffer_test.cc
+++ b/iree/hal/cts/command_buffer_test.cc
@@ -237,10 +237,10 @@
// TODO(scotttodd): UpdateBuffer, Dispatch, Sync, etc.
-INSTANTIATE_TEST_SUITE_P(AllDrivers, CommandBufferTest,
- ::testing::ValuesIn(DriverRegistry::shared_registry()
- ->EnumerateAvailableDrivers()),
- GenerateTestName());
+INSTANTIATE_TEST_SUITE_P(
+ AllDrivers, CommandBufferTest,
+ ::testing::ValuesIn(CtsTestBase::EnumerateAvailableDrivers()),
+ GenerateTestName());
} // namespace cts
} // namespace hal
diff --git a/iree/hal/cts/command_queue_test.cc b/iree/hal/cts/command_queue_test.cc
index 411cd58..84feaae 100644
--- a/iree/hal/cts/command_queue_test.cc
+++ b/iree/hal/cts/command_queue_test.cc
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include <mutex>
+
#include "iree/base/status.h"
#include "iree/hal/cts/cts_test_base.h"
#include "iree/hal/driver_registry.h"
@@ -142,6 +144,10 @@
}
std::vector<std::string> GetSupportedDrivers() {
+ static std::once_flag register_once;
+ std::call_once(register_once, [] {
+ IREE_CHECK_OK(iree_hal_register_all_available_drivers());
+ });
auto drivers = DriverRegistry::shared_registry()->EnumerateAvailableDrivers();
auto it = drivers.begin();
while (it != drivers.end()) {
diff --git a/iree/hal/cts/cts_test_base.h b/iree/hal/cts/cts_test_base.h
index 5d1fde9..fc2a94a 100644
--- a/iree/hal/cts/cts_test_base.h
+++ b/iree/hal/cts/cts_test_base.h
@@ -16,10 +16,12 @@
#define IREE_HAL_CTS_CTS_TEST_BASE_H_
#include <map>
+#include <mutex>
#include <set>
#include "iree/base/status.h"
#include "iree/hal/driver_registry.h"
+#include "iree/hal/drivers/init.h"
#include "iree/testing/gtest.h"
#include "iree/testing/status_matchers.h"
@@ -29,6 +31,15 @@
// Common setup for tests parameterized across all registered drivers.
class CtsTestBase : public ::testing::TestWithParam<std::string> {
+ public:
+ static std::vector<std::string> EnumerateAvailableDrivers() {
+ static std::once_flag register_once;
+ std::call_once(register_once, [] {
+ IREE_CHECK_OK(iree_hal_register_all_available_drivers());
+ });
+ return DriverRegistry::shared_registry()->EnumerateAvailableDrivers();
+ }
+
protected:
// Per-test-suite set-up. This is called before the first test in this test
// suite. We use it to set up drivers that must be reused between test cases
diff --git a/iree/hal/cts/driver_test.cc b/iree/hal/cts/driver_test.cc
index 9b85ccb..2fea7d8 100644
--- a/iree/hal/cts/driver_test.cc
+++ b/iree/hal/cts/driver_test.cc
@@ -30,16 +30,16 @@
TEST_P(DriverTest, EnumerateAndCreateAvailableDevices) {
IREE_ASSERT_OK_AND_ASSIGN(auto devices, driver_->EnumerateAvailableDevices());
- for (int i = 0; i < devices.size(); ++i) {
+ for (iree_host_size_t i = 0; i < devices.size(); ++i) {
IREE_ASSERT_OK_AND_ASSIGN(auto device, driver_->CreateDevice(devices[i]));
IREE_LOG(INFO) << "Device #" << i << " details:\n" << device->DebugString();
}
}
-INSTANTIATE_TEST_SUITE_P(AllDrivers, DriverTest,
- ::testing::ValuesIn(DriverRegistry::shared_registry()
- ->EnumerateAvailableDrivers()),
- GenerateTestName());
+INSTANTIATE_TEST_SUITE_P(
+ AllDrivers, DriverTest,
+ ::testing::ValuesIn(CtsTestBase::EnumerateAvailableDrivers()),
+ GenerateTestName());
} // namespace cts
} // namespace hal
diff --git a/iree/hal/cts/semaphore_test.cc b/iree/hal/cts/semaphore_test.cc
index c4299ca..727596e 100644
--- a/iree/hal/cts/semaphore_test.cc
+++ b/iree/hal/cts/semaphore_test.cc
@@ -141,10 +141,10 @@
thread.join();
}
-INSTANTIATE_TEST_SUITE_P(AllDrivers, SemaphoreTest,
- ::testing::ValuesIn(DriverRegistry::shared_registry()
- ->EnumerateAvailableDrivers()),
- GenerateTestName());
+INSTANTIATE_TEST_SUITE_P(
+ AllDrivers, SemaphoreTest,
+ ::testing::ValuesIn(CtsTestBase::EnumerateAvailableDrivers()),
+ GenerateTestName());
} // namespace cts
} // namespace hal
diff --git a/iree/hal/driver_registry.cc b/iree/hal/driver_registry.cc
index 81322ea..79fd907 100644
--- a/iree/hal/driver_registry.cc
+++ b/iree/hal/driver_registry.cc
@@ -82,6 +82,3 @@
} // namespace hal
} // namespace iree
-
-IREE_REGISTER_MODULE_INITIALIZER(
- iree_hal, ::iree::hal::DriverRegistry::shared_registry());
diff --git a/iree/hal/driver_registry.h b/iree/hal/driver_registry.h
index 8d182ec..f0675c2 100644
--- a/iree/hal/driver_registry.h
+++ b/iree/hal/driver_registry.h
@@ -20,7 +20,6 @@
#include "absl/base/thread_annotations.h"
#include "absl/synchronization/mutex.h"
-#include "iree/base/initializer.h"
#include "iree/base/ref_ptr.h"
#include "iree/base/status.h"
#include "iree/hal/driver.h"
@@ -78,7 +77,4 @@
} // namespace hal
} // namespace iree
-IREE_DECLARE_MODULE_INITIALIZER(iree_hal);
-IREE_REQUIRE_MODULE_LINKED(iree_hal);
-
#endif // IREE_HAL_DRIVER_REGISTRY_H_
diff --git a/iree/hal/drivers/BUILD b/iree/hal/drivers/BUILD
new file mode 100644
index 0000000..8c04a1d
--- /dev/null
+++ b/iree/hal/drivers/BUILD
@@ -0,0 +1,35 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["-layering_check"], # buildozer: disable=no-layering-check, allow indirect headers
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "drivers",
+ srcs = ["init.c"],
+ hdrs = ["init.h"],
+ deps = [
+ "//iree/base:api",
+ "//iree/base:tracing",
+ ] + [
+ # TODO(*): select() and only pull in based on build configuration.
+ "//iree/hal/dylib/registration",
+ "//iree/hal/llvmjit/registration",
+ "//iree/hal/vmla/registration",
+ "//iree/hal/vulkan/registration/google_internal:registration",
+ ],
+)
diff --git a/iree/hal/drivers/CMakeLists.txt b/iree/hal/drivers/CMakeLists.txt
new file mode 100644
index 0000000..388f199
--- /dev/null
+++ b/iree/hal/drivers/CMakeLists.txt
@@ -0,0 +1,46 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# bazel_to_cmake: DO NOT EDIT (custom configuration vars)
+
+set(IREE_HAL_DRIVER_MODULES)
+if(${IREE_HAL_DRIVER_DYLIB})
+ list(APPEND IREE_HAL_DRIVER_MODULES iree::hal::dylib::registration)
+endif()
+if(${IREE_HAL_DRIVER_LLVM})
+ list(APPEND IREE_HAL_DRIVER_MODULES iree::hal::llvmjit::registration)
+endif()
+if(${IREE_HAL_DRIVER_METAL})
+ list(APPEND IREE_HAL_DRIVER_MODULES iree::hal::metal::registration)
+endif()
+if(${IREE_HAL_DRIVER_VMLA})
+ list(APPEND IREE_HAL_DRIVER_MODULES iree::hal::vmla::registration)
+endif()
+if(${IREE_HAL_DRIVER_VULKAN})
+ list(APPEND IREE_HAL_DRIVER_MODULES iree::hal::vulkan::registration)
+endif()
+
+iree_cc_library(
+ NAME
+ drivers
+ HDRS
+ "init.h"
+ SRCS
+ "init.c"
+ DEPS
+ iree::base::api
+ iree::base::tracing
+ ${IREE_HAL_DRIVER_MODULES}
+ PUBLIC
+)
diff --git a/iree/hal/drivers/init.c b/iree/hal/drivers/init.c
new file mode 100644
index 0000000..38482f4
--- /dev/null
+++ b/iree/hal/drivers/init.c
@@ -0,0 +1,69 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/hal/drivers/init.h"
+
+#include "iree/base/tracing.h"
+
+#if defined(IREE_HAL_HAVE_DYLIB_DRIVER_MODULE)
+#include "iree/hal/dylib/registration/driver_module.h"
+#endif // IREE_HAL_HAVE_DYLIB_DRIVER_MODULE
+
+#if defined(IREE_HAL_HAVE_LLVMJIT_DRIVER_MODULE)
+#include "iree/hal/llvmjit/registration/driver_module.h"
+#endif // IREE_HAL_HAVE_LLVMJIT_DRIVER_MODULE
+
+#if defined(IREE_HAL_HAVE_METAL_DRIVER_MODULE)
+#include "iree/hal/metal/registration/driver_module.h"
+#endif // IREE_HAL_HAVE_METAL_DRIVER_MODULE
+
+#if defined(IREE_HAL_HAVE_VMLA_DRIVER_MODULE)
+#include "iree/hal/vmla/registration/driver_module.h"
+#endif // IREE_HAL_HAVE_VMLA_DRIVER_MODULE
+
+#if defined(IREE_HAL_HAVE_VULKAN_DRIVER_MODULE)
+#include "iree/hal/vulkan/registration/google_internal/driver_module.h"
+#endif // IREE_HAL_HAVE_VULKAN_DRIVER_MODULE
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_register_all_available_drivers() {
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+#if defined(IREE_HAL_HAVE_DYLIB_DRIVER_MODULE)
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(z0,
+ iree_hal_dylib_driver_module_register());
+#endif // IREE_HAL_HAVE_DYLIB_DRIVER_MODULE
+
+#if defined(IREE_HAL_HAVE_LLVMJIT_DRIVER_MODULE)
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(z0,
+ iree_hal_llvmjit_driver_module_register());
+#endif // IREE_HAL_HAVE_LLVMJIT_DRIVER_MODULE
+
+#if defined(IREE_HAL_HAVE_METAL_DRIVER_MODULE)
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(z0,
+ iree_hal_metal_driver_module_register());
+#endif // IREE_HAL_HAVE_METAL_DRIVER_MODULE
+
+#if defined(IREE_HAL_HAVE_VMLA_DRIVER_MODULE)
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, iree_hal_vmla_driver_module_register());
+#endif // IREE_HAL_HAVE_VMLA_DRIVER_MODULE
+
+#if defined(IREE_HAL_HAVE_VULKAN_DRIVER_MODULE)
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(z0,
+ iree_hal_vulkan_driver_module_register());
+#endif // IREE_HAL_HAVE_VULKAN_DRIVER_MODULE
+
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
diff --git a/iree/hal/drivers/init.h b/iree/hal/drivers/init.h
new file mode 100644
index 0000000..6ab9c40
--- /dev/null
+++ b/iree/hal/drivers/init.h
@@ -0,0 +1,38 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_DRIVERS_INIT_H_
+#define IREE_HAL_DRIVERS_INIT_H_
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+#include "iree/base/api.h"
+
+// Registers all drivers that were linked into the current binary based on the
+// build configuration. Note that there may be no drivers available.
+//
+// This only registers IREE core drivers (those under iree/hal/). User-provided
+// drivers must be directly registered or directly created, though a user could
+// create their own user_register_all_available_drivers() that calls this as
+// well as registering their drivers.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_register_all_available_drivers();
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_HAL_DRIVERS_INIT_H_
diff --git a/iree/hal/dylib/BUILD b/iree/hal/dylib/BUILD
index 9c19b61..49a5711 100644
--- a/iree/hal/dylib/BUILD
+++ b/iree/hal/dylib/BUILD
@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-# HAL implementation for executing functions provided by dynamic libraries.
-
load("//iree:build_defs.oss.bzl", "iree_cmake_extra_content")
package(
@@ -31,70 +29,32 @@
)
cc_library(
- name = "dylib_device",
- srcs = ["dylib_device.cc"],
- hdrs = ["dylib_device.h"],
- deps = [
- ":dylib_executable_cache",
- "//iree/base:tracing",
- "//iree/hal/host:host_local_device",
+ name = "dylib",
+ srcs = [
+ "dylib_device.cc",
+ "dylib_driver.cc",
+ "dylib_executable.cc",
+ "dylib_executable_cache.cc",
],
-)
-
-cc_library(
- name = "dylib_driver",
- srcs = ["dylib_driver.cc"],
- hdrs = ["dylib_driver.h"],
- deps = [
- ":dylib_device",
- "//iree/hal:device_info",
- "//iree/hal:driver",
- "//iree/hal/host/serial:serial_scheduling_model",
+ hdrs = [
+ "dylib_device.h",
+ "dylib_driver.h",
+ "dylib_executable.h",
+ "dylib_executable_cache.h",
],
-)
-
-cc_library(
- name = "dylib_driver_module",
- srcs = ["dylib_driver_module.cc"],
- deps = [
- ":dylib_driver",
- "//iree/base:init",
- "//iree/base:status",
- "//iree/hal:driver_registry",
- ],
- alwayslink = 1,
-)
-
-cc_library(
- name = "dylib_executable",
- srcs = ["dylib_executable.cc"],
- hdrs = ["dylib_executable.h"],
deps = [
"//iree/base:dynamic_library",
"//iree/base:file_io",
"//iree/base:file_path",
+ "//iree/base:flatcc",
"//iree/base:status",
"//iree/base:tracing",
- "//iree/hal:executable",
- "//iree/hal:executable_spec",
+ "//iree/hal",
"//iree/hal/host:host_executable",
- "//iree/schemas:dylib_executable_def_cc_fbs",
- "@com_github_google_flatbuffers//:flatbuffers",
+ "//iree/hal/host:host_local_device",
+ "//iree/hal/host/serial:serial_scheduling_model",
+ "//iree/schemas:dylib_executable_def_c_fbs",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/types:span",
],
)
-
-cc_library(
- name = "dylib_executable_cache",
- srcs = ["dylib_executable_cache.cc"],
- hdrs = ["dylib_executable_cache.h"],
- deps = [
- ":dylib_executable",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:executable",
- "//iree/hal:executable_cache",
- "//iree/hal:executable_format",
- ],
-)
diff --git a/iree/hal/dylib/CMakeLists.txt b/iree/hal/dylib/CMakeLists.txt
index b057837..03d278d 100644
--- a/iree/hal/dylib/CMakeLists.txt
+++ b/iree/hal/dylib/CMakeLists.txt
@@ -20,83 +20,30 @@
iree_cc_library(
NAME
- dylib_device
+ dylib
HDRS
"dylib_device.h"
+ "dylib_driver.h"
+ "dylib_executable.h"
+ "dylib_executable_cache.h"
SRCS
"dylib_device.cc"
- DEPS
- ::dylib_executable_cache
- iree::base::tracing
- iree::hal::host::host_local_device
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- dylib_driver
- HDRS
- "dylib_driver.h"
- SRCS
"dylib_driver.cc"
- DEPS
- ::dylib_device
- iree::hal::device_info
- iree::hal::driver
- iree::hal::host::serial::serial_scheduling_model
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- dylib_driver_module
- SRCS
- "dylib_driver_module.cc"
- DEPS
- ::dylib_driver
- iree::base::init
- iree::base::status
- iree::hal::driver_registry
- ALWAYSLINK
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- dylib_executable
- HDRS
- "dylib_executable.h"
- SRCS
"dylib_executable.cc"
+ "dylib_executable_cache.cc"
DEPS
absl::inlined_vector
absl::span
- flatbuffers
iree::base::dynamic_library
iree::base::file_io
iree::base::file_path
+ iree::base::flatcc
iree::base::status
iree::base::tracing
- iree::hal::executable
- iree::hal::executable_spec
+ iree::hal
iree::hal::host::host_executable
- iree::schemas::dylib_executable_def_cc_fbs
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- dylib_executable_cache
- HDRS
- "dylib_executable_cache.h"
- SRCS
- "dylib_executable_cache.cc"
- DEPS
- ::dylib_executable
- iree::base::status
- iree::base::tracing
- iree::hal::executable
- iree::hal::executable_cache
- iree::hal::executable_format
+ iree::hal::host::host_local_device
+ iree::hal::host::serial::serial_scheduling_model
+ iree::schemas::dylib_executable_def_c_fbs
PUBLIC
)
diff --git a/iree/hal/dylib/dylib_executable.cc b/iree/hal/dylib/dylib_executable.cc
index 3d3ff01..92a8ea7 100644
--- a/iree/hal/dylib/dylib_executable.cc
+++ b/iree/hal/dylib/dylib_executable.cc
@@ -14,10 +14,62 @@
#include "iree/hal/dylib/dylib_executable.h"
-#include "flatbuffers/flatbuffers.h"
#include "iree/base/file_io.h"
#include "iree/base/file_path.h"
-#include "iree/schemas/dylib_executable_def_generated.h"
+
+// flatcc schemas:
+#include "iree/base/flatcc.h"
+#include "iree/schemas/dylib_executable_def_reader.h"
+#include "iree/schemas/dylib_executable_def_verifier.h"
+
+// NOTE: starting to port this to C.
+
+// Verifies the structure of the flatbuffer so that we can avoid doing so during
+// runtime. There are still some conditions we must be aware of (such as omitted
+// names on functions with internal linkage), however we shouldn't need to
+// bounds check anything within the flatbuffer after this succeeds.
+static iree_status_t iree_hal_dylib_executable_flatbuffer_verify(
+ iree_const_byte_span_t flatbuffer_data) {
+ if (!flatbuffer_data.data || flatbuffer_data.data_length < 16) {
+ return iree_make_status(
+ IREE_STATUS_INVALID_ARGUMENT,
+ "flatbuffer data is not present or less than 16 bytes (%zu total)",
+ flatbuffer_data.data_length);
+ }
+
+ // Run flatcc generated verification. This ensures all pointers are in-bounds
+ // and that we can safely walk the file, but not that the actual contents of
+ // the flatbuffer meet our expectations.
+ int verify_ret = iree_DyLibExecutableDef_verify_as_root(
+ flatbuffer_data.data, flatbuffer_data.data_length);
+ if (verify_ret != flatcc_verify_ok) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "flatbuffer verification failed: %s",
+ flatcc_verify_error_string(verify_ret));
+ }
+
+ iree_DyLibExecutableDef_table_t executable_def =
+ iree_DyLibExecutableDef_as_root(flatbuffer_data.data);
+
+ flatbuffers_string_vec_t entry_points_vec =
+ iree_DyLibExecutableDef_entry_points_get(executable_def);
+ size_t entry_point_count = flatbuffers_string_vec_len(entry_points_vec);
+ for (size_t i = 0; i < entry_point_count; ++i) {
+ if (!flatbuffers_string_len(
+ flatbuffers_string_vec_at(entry_points_vec, i))) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "executable entry point %zu has no name", i);
+ }
+ }
+
+ if (!flatbuffers_uint8_vec_len(
+ iree_DyLibExecutableDef_library_embedded_get(executable_def))) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "executable library_embedded is missing/empty");
+ }
+
+ return iree_ok_status();
+}
namespace iree {
namespace hal {
@@ -49,21 +101,22 @@
Status DyLibExecutable::Initialize(ExecutableSpec spec) {
IREE_TRACE_SCOPE0("DyLibExecutable::Initialize");
- auto dylib_executable_def =
- ::flatbuffers::GetRoot<DyLibExecutableDef>(spec.executable_data.data());
-
- if (!dylib_executable_def->entry_points() ||
- dylib_executable_def->entry_points()->size() == 0) {
- return InvalidArgumentErrorBuilder(IREE_LOC) << "No entry points defined";
- }
- if (!dylib_executable_def->library_embedded() ||
- dylib_executable_def->library_embedded()->size() == 0) {
- return InvalidArgumentErrorBuilder(IREE_LOC) << "No embedded library";
- }
+ // Verify and fetch the executable flatbuffer wrapper.
+ iree_const_byte_span_t executable_data = iree_make_const_byte_span(
+ spec.executable_data.data(), spec.executable_data.size());
+ IREE_RETURN_IF_ERROR(
+ iree_hal_dylib_executable_flatbuffer_verify(executable_data));
+ iree_DyLibExecutableDef_table_t executable_def =
+ iree_DyLibExecutableDef_as_root(executable_data.data);
// Write the embedded library out to a temp file, since all of the dynamic
// library APIs work with files. We could instead use in-memory files on
// platforms where that is convenient.
+ //
+ // TODO(#3845): use dlopen on an fd with either dlopen(/proc/self/fd/NN),
+ // fdlopen, or android_dlopen_ext to avoid needing to write the file to disk.
+ // Can fall back to memfd_create + dlopen where available, and fall back
+ // from that to disk (maybe just windows/mac).
std::string base_name = "dylib_executable";
IREE_ASSIGN_OR_RETURN(auto library_temp_path,
file_io::GetTempFile(base_name));
@@ -77,46 +130,51 @@
library_temp_path += ".so";
#endif
- absl::string_view embedded_library_data(
- reinterpret_cast<const char*>(
- dylib_executable_def->library_embedded()->data()),
- dylib_executable_def->library_embedded()->size());
- IREE_RETURN_IF_ERROR(
- file_io::SetFileContents(library_temp_path, embedded_library_data));
+ flatbuffers_uint8_vec_t embedded_library_vec =
+ iree_DyLibExecutableDef_library_embedded_get(executable_def);
+ IREE_RETURN_IF_ERROR(file_io::SetFileContents(
+ library_temp_path,
+ absl::string_view(reinterpret_cast<const char*>(embedded_library_vec),
+ flatbuffers_uint8_vec_len(embedded_library_vec))));
IREE_ASSIGN_OR_RETURN(executable_library_,
DynamicLibrary::Load(library_temp_path.c_str()));
- if (dylib_executable_def->debug_database_filename() &&
- dylib_executable_def->debug_database_embedded()) {
+ flatbuffers_string_t debug_database_filename =
+ iree_DyLibExecutableDef_debug_database_filename_get(executable_def);
+ flatbuffers_uint8_vec_t debug_database_embedded_vec =
+ iree_DyLibExecutableDef_debug_database_embedded_get(executable_def);
+ if (flatbuffers_string_len(debug_database_filename) &&
+ flatbuffers_uint8_vec_len(debug_database_embedded_vec)) {
IREE_TRACE_SCOPE0("DyLibExecutable::AttachDebugDatabase");
- absl::string_view debug_database_filename(
- dylib_executable_def->debug_database_filename()->data(),
- dylib_executable_def->debug_database_filename()->size());
- absl::string_view debug_database_data(
- reinterpret_cast<const char*>(
- dylib_executable_def->debug_database_embedded()->data()),
- dylib_executable_def->debug_database_embedded()->size());
auto debug_database_path = file_path::JoinPaths(
- file_path::DirectoryName(library_temp_path), debug_database_filename);
+ file_path::DirectoryName(library_temp_path),
+ absl::string_view(debug_database_filename,
+ flatbuffers_string_len(debug_database_filename)));
temp_file_paths_.push_back(debug_database_path);
- IREE_IGNORE_ERROR(
- file_io::SetFileContents(debug_database_path, debug_database_data));
+ IREE_IGNORE_ERROR(file_io::SetFileContents(
+ debug_database_path,
+ absl::string_view(
+ reinterpret_cast<const char*>(debug_database_embedded_vec),
+ flatbuffers_uint8_vec_len(debug_database_embedded_vec))));
executable_library_->AttachDebugDatabase(debug_database_path.c_str());
}
- const auto& entry_points = *dylib_executable_def->entry_points();
- entry_functions_.resize(entry_points.size());
- IREE_TRACE(entry_names_.resize(entry_points.size()));
- for (int i = 0; i < entry_functions_.size(); ++i) {
- void* symbol = executable_library_->GetSymbol(entry_points[i]->c_str());
+ flatbuffers_string_vec_t entry_points =
+ iree_DyLibExecutableDef_entry_points_get(executable_def);
+ entry_functions_.resize(flatbuffers_string_vec_len(entry_points));
+ IREE_TRACE(entry_names_.resize(flatbuffers_string_vec_len(entry_points)));
+ for (size_t i = 0; i < entry_functions_.size(); ++i) {
+ flatbuffers_string_t entry_point =
+ flatbuffers_string_vec_at(entry_points, i);
+ void* symbol = executable_library_->GetSymbol(entry_point);
if (!symbol) {
return NotFoundErrorBuilder(IREE_LOC)
- << "Could not find symbol: " << entry_points[i];
+ << "Could not find symbol: " << entry_point;
}
entry_functions_[i] = symbol;
- IREE_TRACE(entry_names_[i] = entry_points[i]->c_str());
+ IREE_TRACE(entry_names_[i] = entry_point);
}
return OkStatus();
diff --git a/iree/hal/dylib/registration/BUILD b/iree/hal/dylib/registration/BUILD
new file mode 100644
index 0000000..ec475fa
--- /dev/null
+++ b/iree/hal/dylib/registration/BUILD
@@ -0,0 +1,46 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//iree:build_defs.oss.bzl", "iree_cmake_extra_content")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_cmake_extra_content(
+ content = """
+if(NOT ${IREE_HAL_DRIVER_DYLIB})
+ return()
+endif()
+""",
+ inline = True,
+)
+
+cc_library(
+ name = "registration",
+ srcs = ["driver_module.cc"],
+ hdrs = ["driver_module.h"],
+ defines = [
+ "IREE_HAL_HAVE_DYLIB_DRIVER_MODULE=1",
+ ],
+ deps = [
+ "//iree/base:flags",
+ "//iree/base:status",
+ "//iree/hal:api",
+ "//iree/hal:driver_registry",
+ "//iree/hal/dylib",
+ ],
+)
diff --git a/iree/hal/dylib/registration/CMakeLists.txt b/iree/hal/dylib/registration/CMakeLists.txt
new file mode 100644
index 0000000..07fd286
--- /dev/null
+++ b/iree/hal/dylib/registration/CMakeLists.txt
@@ -0,0 +1,37 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_add_all_subdirs()
+
+if(NOT ${IREE_HAL_DRIVER_DYLIB})
+ return()
+endif()
+
+iree_cc_library(
+ NAME
+ registration
+ HDRS
+ "driver_module.h"
+ SRCS
+ "driver_module.cc"
+ DEPS
+ iree::base::flags
+ iree::base::status
+ iree::hal::api
+ iree::hal::driver_registry
+ iree::hal::dylib
+ DEFINES
+ "IREE_HAL_HAVE_DYLIB_DRIVER_MODULE=1"
+ PUBLIC
+)
diff --git a/iree/hal/dylib/dylib_driver_module.cc b/iree/hal/dylib/registration/driver_module.cc
similarity index 74%
rename from iree/hal/dylib/dylib_driver_module.cc
rename to iree/hal/dylib/registration/driver_module.cc
index d71d3fa..b9a920f 100644
--- a/iree/hal/dylib/dylib_driver_module.cc
+++ b/iree/hal/dylib/registration/driver_module.cc
@@ -12,9 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include <memory>
+#include "iree/hal/dylib/registration/driver_module.h"
-#include "iree/base/init.h"
#include "iree/base/status.h"
#include "iree/hal/driver_registry.h"
#include "iree/hal/dylib/dylib_driver.h"
@@ -31,8 +30,8 @@
} // namespace hal
} // namespace iree
-IREE_REGISTER_MODULE_INITIALIZER(iree_hal_dylib_driver, {
- IREE_QCHECK_OK(::iree::hal::DriverRegistry::shared_registry()->Register(
- "dylib", ::iree::hal::dylib::CreateDyLibDriver));
-});
-IREE_REGISTER_MODULE_INITIALIZER_SEQUENCE(iree_hal, iree_hal_dylib_driver);
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_dylib_driver_module_register() {
+ return ::iree::hal::DriverRegistry::shared_registry()->Register(
+ "dylib", ::iree::hal::dylib::CreateDyLibDriver);
+}
diff --git a/iree/hal/dylib/registration/driver_module.h b/iree/hal/dylib/registration/driver_module.h
new file mode 100644
index 0000000..230a69d
--- /dev/null
+++ b/iree/hal/dylib/registration/driver_module.h
@@ -0,0 +1,33 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_DYLIB_REGISTRATION_DRIVER_MODULE_H_
+#define IREE_HAL_DYLIB_REGISTRATION_DRIVER_MODULE_H_
+
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// DEPRECATED: this entire driver will be removed soon.
+// TODO(#3580): remove this entire driver w/ iree_hal_executable_library_t.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_dylib_driver_module_register();
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_HAL_DYLIB_REGISTRATION_DRIVER_MODULE_H_
diff --git a/iree/hal/host/BUILD b/iree/hal/host/BUILD
index a26fa0f..a968eed 100644
--- a/iree/hal/host/BUILD
+++ b/iree/hal/host/BUILD
@@ -28,7 +28,7 @@
deps = [
"//iree/base:status",
"//iree/base:tracing",
- "//iree/hal:semaphore",
+ "//iree/hal",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/synchronization",
@@ -56,7 +56,7 @@
"//iree/base:logging",
"//iree/base:status",
"//iree/base:tracing",
- "//iree/hal:buffer",
+ "//iree/hal",
],
)
@@ -65,8 +65,7 @@
srcs = ["host_descriptor_set.cc"],
hdrs = ["host_descriptor_set.h"],
deps = [
- "//iree/hal:descriptor_set",
- "//iree/hal:descriptor_set_layout",
+ "//iree/hal",
"@com_google_absl//absl/container:inlined_vector",
],
)
@@ -76,8 +75,7 @@
hdrs = ["host_executable.h"],
deps = [
"//iree/base:status",
- "//iree/hal:descriptor_set",
- "//iree/hal:executable",
+ "//iree/hal",
],
)
@@ -86,9 +84,8 @@
srcs = ["host_executable_layout.cc"],
hdrs = ["host_executable_layout.h"],
deps = [
- "//iree/base:memory",
- "//iree/hal:descriptor_set_layout",
- "//iree/hal:executable_layout",
+ "//iree/base:core_headers",
+ "//iree/hal",
"@com_google_absl//absl/container:inlined_vector",
],
)
@@ -101,8 +98,7 @@
":host_buffer",
"//iree/base:status",
"//iree/base:tracing",
- "//iree/hal:allocator",
- "//iree/hal:buffer",
+ "//iree/hal",
],
)
@@ -115,12 +111,11 @@
":host_executable_layout",
":host_local_allocator",
":scheduling_model",
- "//iree/base:memory",
+ "//iree/base:core_headers",
"//iree/base:status",
"//iree/base:tracing",
- "//iree/hal:allocator",
+ "//iree/hal",
"//iree/hal:command_buffer_validation",
- "//iree/hal:device",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:span",
@@ -136,7 +131,7 @@
"//iree/base:intrusive_list",
"//iree/base:status",
"//iree/base:tracing",
- "//iree/hal:command_buffer",
+ "//iree/hal",
],
)
@@ -145,7 +140,7 @@
srcs = ["nop_event.cc"],
hdrs = ["nop_event.h"],
deps = [
- "//iree/hal:event",
+ "//iree/hal",
],
)
@@ -153,6 +148,6 @@
name = "scheduling_model",
hdrs = ["scheduling_model.h"],
deps = [
- "//iree/hal:command_queue",
+ "//iree/hal",
],
)
diff --git a/iree/hal/host/CMakeLists.txt b/iree/hal/host/CMakeLists.txt
index 953a777..5c926c6 100644
--- a/iree/hal/host/CMakeLists.txt
+++ b/iree/hal/host/CMakeLists.txt
@@ -28,7 +28,7 @@
absl::synchronization
iree::base::status
iree::base::tracing
- iree::hal::semaphore
+ iree::hal
PUBLIC
)
@@ -56,7 +56,7 @@
iree::base::logging
iree::base::status
iree::base::tracing
- iree::hal::buffer
+ iree::hal
PUBLIC
)
@@ -69,8 +69,7 @@
"host_descriptor_set.cc"
DEPS
absl::inlined_vector
- iree::hal::descriptor_set
- iree::hal::descriptor_set_layout
+ iree::hal
PUBLIC
)
@@ -81,8 +80,7 @@
"host_executable.h"
DEPS
iree::base::status
- iree::hal::descriptor_set
- iree::hal::executable
+ iree::hal
PUBLIC
)
@@ -95,9 +93,8 @@
"host_executable_layout.cc"
DEPS
absl::inlined_vector
- iree::base::memory
- iree::hal::descriptor_set_layout
- iree::hal::executable_layout
+ iree::base::core_headers
+ iree::hal
PUBLIC
)
@@ -112,8 +109,7 @@
::host_buffer
iree::base::status
iree::base::tracing
- iree::hal::allocator
- iree::hal::buffer
+ iree::hal
PUBLIC
)
@@ -132,12 +128,11 @@
absl::core_headers
absl::memory
absl::span
- iree::base::memory
+ iree::base::core_headers
iree::base::status
iree::base::tracing
- iree::hal::allocator
+ iree::hal
iree::hal::command_buffer_validation
- iree::hal::device
PUBLIC
)
@@ -153,7 +148,7 @@
iree::base::intrusive_list
iree::base::status
iree::base::tracing
- iree::hal::command_buffer
+ iree::hal
PUBLIC
)
@@ -165,7 +160,7 @@
SRCS
"nop_event.cc"
DEPS
- iree::hal::event
+ iree::hal
PUBLIC
)
@@ -175,6 +170,6 @@
HDRS
"scheduling_model.h"
DEPS
- iree::hal::command_queue
+ iree::hal
PUBLIC
)
diff --git a/iree/hal/host/host_executable.h b/iree/hal/host/host_executable.h
index 4885885..cb2e6c9 100644
--- a/iree/hal/host/host_executable.h
+++ b/iree/hal/host/host_executable.h
@@ -41,7 +41,7 @@
// Grid parameters shared for all tiles within a dispatch.
struct DispatchParams {
// Entry point within the executable.
- int32_t entry_point = 0;
+ size_t entry_point = 0;
// Total workgroup XYZ count for the grid.
std::array<uint32_t, 3> workgroup_count;
diff --git a/iree/hal/host/serial/BUILD b/iree/hal/host/serial/BUILD
index 6bba64a..9228d98 100644
--- a/iree/hal/host/serial/BUILD
+++ b/iree/hal/host/serial/BUILD
@@ -28,8 +28,7 @@
deps = [
"//iree/base:status",
"//iree/base:tracing",
- "//iree/hal:command_queue",
- "//iree/hal:semaphore",
+ "//iree/hal",
"//iree/hal/host/serial:serial_submission_queue",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/synchronization",
@@ -43,7 +42,7 @@
":async_command_queue",
"//iree/base:status",
"//iree/base:time",
- "//iree/hal:command_queue",
+ "//iree/hal",
"//iree/hal/host/serial:serial_submission_queue",
"//iree/hal/testing:mock_command_buffer",
"//iree/hal/testing:mock_command_queue",
@@ -60,7 +59,7 @@
deps = [
"//iree/base:status",
"//iree/base:tracing",
- "//iree/hal:command_buffer",
+ "//iree/hal",
"//iree/hal/host:host_descriptor_set",
"//iree/hal/host:host_executable",
"//iree/hal/host:host_executable_layout",
@@ -76,7 +75,7 @@
":async_command_queue",
":serial_command_processor",
":serial_submission_queue",
- "//iree/base:memory",
+ "//iree/base:core_headers",
"//iree/base:status",
"//iree/base:tracing",
"//iree/hal/host:condvar_semaphore",
@@ -95,7 +94,7 @@
"//iree/base:intrusive_list",
"//iree/base:status",
"//iree/base:tracing",
- "//iree/hal:command_queue",
+ "//iree/hal",
"//iree/hal/host:condvar_semaphore",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:inlined_vector",
diff --git a/iree/hal/host/serial/CMakeLists.txt b/iree/hal/host/serial/CMakeLists.txt
index 1077854..86b7da5 100644
--- a/iree/hal/host/serial/CMakeLists.txt
+++ b/iree/hal/host/serial/CMakeLists.txt
@@ -26,9 +26,8 @@
absl::synchronization
iree::base::status
iree::base::tracing
- iree::hal::command_queue
+ iree::hal
iree::hal::host::serial::serial_submission_queue
- iree::hal::semaphore
PUBLIC
)
@@ -42,7 +41,7 @@
absl::memory
iree::base::status
iree::base::time
- iree::hal::command_queue
+ iree::hal
iree::hal::host::serial::serial_submission_queue
iree::hal::testing::mock_command_buffer
iree::hal::testing::mock_command_queue
@@ -61,7 +60,7 @@
absl::inlined_vector
iree::base::status
iree::base::tracing
- iree::hal::command_buffer
+ iree::hal
iree::hal::host::host_descriptor_set
iree::hal::host::host_executable
iree::hal::host::host_executable_layout
@@ -80,7 +79,7 @@
::serial_command_processor
::serial_submission_queue
absl::inlined_vector
- iree::base::memory
+ iree::base::core_headers
iree::base::status
iree::base::tracing
iree::hal::host::condvar_semaphore
@@ -104,7 +103,7 @@
iree::base::intrusive_list
iree::base::status
iree::base::tracing
- iree::hal::command_queue
+ iree::hal
iree::hal::host::condvar_semaphore
PUBLIC
)
diff --git a/iree/hal/llvmjit/BUILD b/iree/hal/llvmjit/BUILD
index 2e9d275..d7f5535 100644
--- a/iree/hal/llvmjit/BUILD
+++ b/iree/hal/llvmjit/BUILD
@@ -31,75 +31,33 @@
)
cc_library(
- name = "llvmjit_device",
- srcs = ["llvmjit_device.cc"],
- hdrs = ["llvmjit_device.h"],
- deps = [
- ":llvmjit_executable_cache",
- "//iree/base:tracing",
- "//iree/hal/host:host_local_device",
+ name = "llvmjit",
+ srcs = [
+ "llvmjit_device.cc",
+ "llvmjit_driver.cc",
+ "llvmjit_executable.cc",
+ "llvmjit_executable_cache.cc",
],
-)
-
-cc_library(
- name = "llvmjit_driver",
- srcs = ["llvmjit_driver.cc"],
- hdrs = ["llvmjit_driver.h"],
- deps = [
- ":llvmjit_device",
- "//iree/hal:device_info",
- "//iree/hal:driver",
- "//iree/hal/host/serial:serial_scheduling_model",
- "@llvm-project//llvm:ExecutionEngine",
+ hdrs = [
+ "llvmjit_device.h",
+ "llvmjit_driver.h",
+ "llvmjit_executable.h",
+ "llvmjit_executable_cache.h",
],
-)
-
-cc_library(
- name = "llvmjit_driver_module",
- srcs = ["llvmjit_driver_module.cc"],
deps = [
- ":llvmjit_driver",
- "//iree/base:init",
- "//iree/base:status",
- "//iree/hal:driver_registry",
- "@llvm-project//llvm:Support",
- #TODO(ataei): Link with native target dep.
- "@llvm-project//llvm:X86CodeGen",
- ],
- alwayslink = 1,
-)
-
-cc_library(
- name = "llvmjit_executable",
- srcs = ["llvmjit_executable.cc"],
- hdrs = ["llvmjit_executable.h"],
- deps = [
+ "//iree/base:flatcc",
"//iree/base:status",
"//iree/base:tracing",
- "//iree/hal:buffer",
- "//iree/hal:executable",
- "//iree/hal:executable_spec",
+ "//iree/hal",
"//iree/hal/host:host_executable",
- "//iree/schemas:llvmir_executable_def_cc_fbs",
- "@com_github_google_flatbuffers//:flatbuffers",
+ "//iree/hal/host:host_local_device",
+ "//iree/hal/host/serial:serial_scheduling_model",
+ "//iree/schemas:llvmir_executable_def_c_fbs",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:AsmParser",
"@llvm-project//llvm:Core",
+ "@llvm-project//llvm:ExecutionEngine",
"@llvm-project//llvm:OrcJIT",
"@llvm-project//llvm:Support",
],
)
-
-cc_library(
- name = "llvmjit_executable_cache",
- srcs = ["llvmjit_executable_cache.cc"],
- hdrs = ["llvmjit_executable_cache.h"],
- deps = [
- ":llvmjit_executable",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:executable",
- "//iree/hal:executable_cache",
- "//iree/hal:executable_format",
- ],
-)
diff --git a/iree/hal/llvmjit/CMakeLists.txt b/iree/hal/llvmjit/CMakeLists.txt
index 6faf869..ea303d1 100644
--- a/iree/hal/llvmjit/CMakeLists.txt
+++ b/iree/hal/llvmjit/CMakeLists.txt
@@ -20,87 +20,31 @@
iree_cc_library(
NAME
- llvmjit_device
+ llvmjit
HDRS
"llvmjit_device.h"
+ "llvmjit_driver.h"
+ "llvmjit_executable.h"
+ "llvmjit_executable_cache.h"
SRCS
"llvmjit_device.cc"
- DEPS
- ::llvmjit_executable_cache
- iree::base::tracing
- iree::hal::host::host_local_device
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- llvmjit_driver
- HDRS
- "llvmjit_driver.h"
- SRCS
"llvmjit_driver.cc"
- DEPS
- ::llvmjit_device
- LLVMExecutionEngine
- iree::hal::device_info
- iree::hal::driver
- iree::hal::host::serial::serial_scheduling_model
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- llvmjit_driver_module
- SRCS
- "llvmjit_driver_module.cc"
- DEPS
- ::llvmjit_driver
- LLVMSupport
- LLVMX86CodeGen
- iree::base::init
- iree::base::status
- iree::hal::driver_registry
- ALWAYSLINK
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- llvmjit_executable
- HDRS
- "llvmjit_executable.h"
- SRCS
"llvmjit_executable.cc"
+ "llvmjit_executable_cache.cc"
DEPS
LLVMAsmParser
LLVMCore
+ LLVMExecutionEngine
LLVMOrcJIT
LLVMSupport
absl::span
- flatbuffers
+ iree::base::flatcc
iree::base::status
iree::base::tracing
- iree::hal::buffer
- iree::hal::executable
- iree::hal::executable_spec
+ iree::hal
iree::hal::host::host_executable
- iree::schemas::llvmir_executable_def_cc_fbs
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- llvmjit_executable_cache
- HDRS
- "llvmjit_executable_cache.h"
- SRCS
- "llvmjit_executable_cache.cc"
- DEPS
- ::llvmjit_executable
- iree::base::status
- iree::base::tracing
- iree::hal::executable
- iree::hal::executable_cache
- iree::hal::executable_format
+ iree::hal::host::host_local_device
+ iree::hal::host::serial::serial_scheduling_model
+ iree::schemas::llvmir_executable_def_c_fbs
PUBLIC
)
diff --git a/iree/hal/llvmjit/llvmjit_executable.cc b/iree/hal/llvmjit/llvmjit_executable.cc
index 391a217..5dd1a82 100644
--- a/iree/hal/llvmjit/llvmjit_executable.cc
+++ b/iree/hal/llvmjit/llvmjit_executable.cc
@@ -17,11 +17,9 @@
#include <iostream>
#include <memory>
-#include "flatbuffers/flatbuffers.h"
#include "iree/base/tracing.h"
#include "iree/hal/buffer.h"
#include "iree/hal/executable.h"
-#include "iree/schemas/llvmir_executable_def_generated.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/AsmParser/Parser.h"
@@ -31,6 +29,60 @@
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SourceMgr.h"
+// flatcc schemas:
+#include "iree/base/flatcc.h"
+#include "iree/schemas/llvmir_executable_def_reader.h"
+#include "iree/schemas/llvmir_executable_def_verifier.h"
+
+// NOTE: starting to port this to C.
+
+// Verifies the structure of the flatbuffer so that we can avoid doing so during
+// runtime. There are still some conditions we must be aware of (such as omitted
+// names on functions with internal linkage), however we shouldn't need to
+// bounds check anything within the flatbuffer after this succeeds.
+static iree_status_t iree_hal_llvmir_executable_flatbuffer_verify(
+ iree_const_byte_span_t flatbuffer_data) {
+ if (!flatbuffer_data.data || flatbuffer_data.data_length < 16) {
+ return iree_make_status(
+ IREE_STATUS_INVALID_ARGUMENT,
+ "flatbuffer data is not present or less than 16 bytes (%zu total)",
+ flatbuffer_data.data_length);
+ }
+
+ // Run flatcc generated verification. This ensures all pointers are in-bounds
+ // and that we can safely walk the file, but not that the actual contents of
+ // the flatbuffer meet our expectations.
+ int verify_ret = iree_LLVMIRExecutableDef_verify_as_root(
+ flatbuffer_data.data, flatbuffer_data.data_length);
+ if (verify_ret != flatcc_verify_ok) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "flatbuffer verification failed: %s",
+ flatcc_verify_error_string(verify_ret));
+ }
+
+ iree_LLVMIRExecutableDef_table_t executable_def =
+ iree_LLVMIRExecutableDef_as_root(flatbuffer_data.data);
+
+ flatbuffers_string_vec_t entry_points_vec =
+ iree_LLVMIRExecutableDef_entry_points_get(executable_def);
+ size_t entry_point_count = flatbuffers_string_vec_len(entry_points_vec);
+ for (size_t i = 0; i < entry_point_count; ++i) {
+ if (!flatbuffers_string_len(
+ flatbuffers_string_vec_at(entry_points_vec, i))) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "executable entry point %zu has no name", i);
+ }
+ }
+
+ if (!flatbuffers_uint8_vec_len(
+ iree_LLVMIRExecutableDef_bitcode_module_get(executable_def))) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "executable bitcode_module is missing/empty");
+ }
+
+ return iree_ok_status();
+}
+
namespace iree {
namespace hal {
namespace llvmjit {
@@ -40,13 +92,20 @@
ExecutableSpec spec, bool allow_aliasing_data) {
IREE_TRACE_SCOPE0("LLVMJITExecutable::Load");
- auto module_def =
- ::flatbuffers::GetRoot<LLVMIRExecutableDef>(spec.executable_data.data());
- auto data =
- reinterpret_cast<const char*>(module_def->llvmir_module()->data());
- const int size = module_def->llvmir_module()->size();
+ // Verify and fetch the executable flatbuffer wrapper.
+ iree_const_byte_span_t executable_data = iree_make_const_byte_span(
+ spec.executable_data.data(), spec.executable_data.size());
+ IREE_RETURN_IF_ERROR(
+ iree_hal_llvmir_executable_flatbuffer_verify(executable_data));
+ iree_LLVMIRExecutableDef_table_t executable_def =
+ iree_LLVMIRExecutableDef_as_root(executable_data.data);
+
+ flatbuffers_uint8_vec_t bitcode_module_vec =
+ iree_LLVMIRExecutableDef_bitcode_module_get(executable_def);
auto mem_buffer = llvm::MemoryBuffer::getMemBufferCopy(
- llvm::StringRef(data, size), "llvm-ir");
+ llvm::StringRef(reinterpret_cast<const char*>(bitcode_module_vec),
+ flatbuffers_uint8_vec_len(bitcode_module_vec)),
+ "llvm-ir");
auto llvm_context = std::make_unique<llvm::LLVMContext>();
llvm::SMDiagnostic sm_diagnostic;
auto module = llvm::parseAssembly(*mem_buffer, sm_diagnostic, *llvm_context);
@@ -55,7 +114,6 @@
<< "Can't parse LLVMIR Module: " << sm_diagnostic.getMessage().str();
}
auto dataLayout = module->getDataLayout();
- const auto entry_points = module_def->entry_points();
llvm::orc::ThreadSafeModule thread_safe_module(std::move(module),
std::move(llvm_context));
auto ll_jit = llvm::cantFail(llvm::orc::LLJITBuilder().create());
@@ -67,29 +125,35 @@
<< llvm::toString(std::move(err));
}
- auto dylib_serarch_generator =
+ auto llvmjit_serarch_generator =
llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(
dataLayout.getGlobalPrefix());
- if (!dylib_serarch_generator) {
+ if (!llvmjit_serarch_generator) {
return UnavailableErrorBuilder(IREE_LOC)
<< "Can't resolve symbols in current process: "
- << llvm::toString(dylib_serarch_generator.takeError());
+ << llvm::toString(llvmjit_serarch_generator.takeError());
}
- auto& main_jitdylib = ll_jit->getMainJITDylib();
- main_jitdylib.addGenerator(std::move(dylib_serarch_generator.get()));
+ auto& main_jitllvmjit = ll_jit->getMainJITDylib();
+ main_jitllvmjit.addGenerator(std::move(llvmjit_serarch_generator.get()));
auto executable =
make_ref<LLVMJITExecutable>(spec, std::move(ll_jit), allow_aliasing_data);
- for (const auto func_name : *entry_points) {
- auto func_symbol = executable->ll_jit_->lookup(func_name->str());
+ flatbuffers_string_vec_t entry_points =
+ iree_LLVMIRExecutableDef_entry_points_get(executable_def);
+ executable->symbols_.resize(flatbuffers_string_vec_len(entry_points));
+ for (size_t i = 0; i < flatbuffers_string_vec_len(entry_points); ++i) {
+ flatbuffers_string_t entry_point =
+ flatbuffers_string_vec_at(entry_points, i);
+ auto func_symbol = executable->ll_jit_->lookup(
+ llvm::StringRef(entry_point, flatbuffers_string_len(entry_point)));
if (!func_symbol) {
return NotFoundErrorBuilder(IREE_LOC)
- << "Can't JIT compile function '" << func_name
+ << "Can't JIT compile function '" << entry_point
<< "': " << llvm::toString(func_symbol.takeError());
}
- executable->symbols_.push_back(func_symbol.get());
+ executable->symbols_[i] = func_symbol.get();
}
return executable;
diff --git a/iree/hal/llvmjit/llvmjit_executable.h b/iree/hal/llvmjit/llvmjit_executable.h
index f672381..6c5b3bc 100644
--- a/iree/hal/llvmjit/llvmjit_executable.h
+++ b/iree/hal/llvmjit/llvmjit_executable.h
@@ -20,7 +20,6 @@
#include "iree/base/status.h"
#include "iree/hal/executable_spec.h"
#include "iree/hal/host/host_executable.h"
-#include "iree/schemas/llvmir_executable_def_generated.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ExecutionEngine/Orc/LLJIT.h"
diff --git a/iree/hal/llvmjit/registration/BUILD b/iree/hal/llvmjit/registration/BUILD
new file mode 100644
index 0000000..b484138
--- /dev/null
+++ b/iree/hal/llvmjit/registration/BUILD
@@ -0,0 +1,54 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//iree:build_defs.oss.bzl", "iree_cmake_extra_content")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_cmake_extra_content(
+ content = """
+if(${IREE_HAL_DRIVER_LLVM})
+""",
+ inline = True,
+)
+
+cc_library(
+ name = "registration",
+ srcs = ["driver_module.cc"],
+ hdrs = ["driver_module.h"],
+ defines = [
+ "IREE_HAL_HAVE_LLVMJIT_DRIVER_MODULE=1",
+ ],
+ deps = [
+ "//iree/base:flags",
+ "//iree/base:status",
+ "//iree/hal:api",
+ "//iree/hal:driver_registry",
+ "//iree/hal/llvmjit",
+ "@llvm-project//llvm:Support",
+ # TODO(ataei): Link with native target dep.
+ "@llvm-project//llvm:X86CodeGen",
+ ],
+)
+
+iree_cmake_extra_content(
+ content = """
+endif()
+""",
+ inline = True,
+)
diff --git a/iree/hal/llvmjit/registration/CMakeLists.txt b/iree/hal/llvmjit/registration/CMakeLists.txt
new file mode 100644
index 0000000..af7d688
--- /dev/null
+++ b/iree/hal/llvmjit/registration/CMakeLists.txt
@@ -0,0 +1,39 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_add_all_subdirs()
+
+if(${IREE_HAL_DRIVER_LLVM})
+
+iree_cc_library(
+ NAME
+ registration
+ HDRS
+ "driver_module.h"
+ SRCS
+ "driver_module.cc"
+ DEPS
+ LLVMSupport
+ LLVMX86CodeGen
+ iree::base::flags
+ iree::base::status
+ iree::hal::api
+ iree::hal::driver_registry
+ iree::hal::llvmjit
+ DEFINES
+ "IREE_HAL_HAVE_LLVMJIT_DRIVER_MODULE=1"
+ PUBLIC
+)
+
+endif()
diff --git a/iree/hal/llvmjit/llvmjit_driver_module.cc b/iree/hal/llvmjit/registration/driver_module.cc
similarity index 76%
rename from iree/hal/llvmjit/llvmjit_driver_module.cc
rename to iree/hal/llvmjit/registration/driver_module.cc
index db1837f..544467c 100644
--- a/iree/hal/llvmjit/llvmjit_driver_module.cc
+++ b/iree/hal/llvmjit/registration/driver_module.cc
@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include <memory>
+#include "iree/hal/llvmjit/registration/driver_module.h"
-#include "iree/base/init.h"
+#include "iree/base/flags.h"
#include "iree/base/status.h"
#include "iree/hal/driver_registry.h"
#include "iree/hal/llvmjit/llvmjit_driver.h"
@@ -34,8 +34,8 @@
} // namespace hal
} // namespace iree
-IREE_REGISTER_MODULE_INITIALIZER(iree_hal_llvm_driver, {
- IREE_QCHECK_OK(::iree::hal::DriverRegistry::shared_registry()->Register(
- "llvm", ::iree::hal::llvmjit::CreateLLVMJITDriver));
-});
-IREE_REGISTER_MODULE_INITIALIZER_SEQUENCE(iree_hal, iree_hal_llvm_driver);
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_llvmjit_driver_module_register() {
+ return ::iree::hal::DriverRegistry::shared_registry()->Register(
+ "llvm", ::iree::hal::llvmjit::CreateLLVMJITDriver);
+}
diff --git a/iree/hal/llvmjit/registration/driver_module.h b/iree/hal/llvmjit/registration/driver_module.h
new file mode 100644
index 0000000..f7a7c85
--- /dev/null
+++ b/iree/hal/llvmjit/registration/driver_module.h
@@ -0,0 +1,31 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_LLVMJIT_REGISTRATION_DRIVER_MODULE_H_
+#define IREE_HAL_LLVMJIT_REGISTRATION_DRIVER_MODULE_H_
+
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_llvmjit_driver_module_register();
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_HAL_LLVMJIT_REGISTRATION_DRIVER_MODULE_H_
diff --git a/iree/hal/metal/CMakeLists.txt b/iree/hal/metal/CMakeLists.txt
index 7cdb1f5..b8b5f0e 100644
--- a/iree/hal/metal/CMakeLists.txt
+++ b/iree/hal/metal/CMakeLists.txt
@@ -16,215 +16,51 @@
return()
endif()
+iree_add_all_subdirs()
+
iree_cc_library(
NAME
- metal_capture_manager
+ metal
HDRS
+ "metal_buffer.h"
"metal_capture_manager.h"
+ "metal_command_buffer.h"
+ "metal_command_queue.h"
+ "metal_device.h"
+ "metal_direct_allocator.h"
+ "metal_driver.h"
+ "metal_kernel_library.h"
+ "metal_pipeline_argument_buffer.h"
+ "metal_pipeline_cache.h"
+ "metal_shared_event.h"
SRCS
+ "metal_buffer.mm"
"metal_capture_manager.mm"
+ "metal_command_buffer.mm"
+ "metal_command_queue.mm"
+ "metal_device.mm"
+ "metal_direct_allocator.mm"
+ "metal_driver.mm"
+ "metal_kernel_library.mm"
+ "metal_pipeline_argument_buffer.cc"
+ "metal_pipeline_cache.mm"
+ "metal_shared_event.mm"
DEPS
+ absl::flat_hash_map
+ absl::inlined_vector
+ absl::memory
+ absl::span
+ absl::strings
+ iree::base::flatcc
iree::base::file_io
iree::base::logging
iree::base::status
- iree::base::tracing
- iree::hal::debug_capture_manager
- LINKOPTS
- "-framework Metal"
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- metal_command_buffer
- HDRS
- "metal_command_buffer.h"
- SRCS
- "metal_command_buffer.mm"
- DEPS
- ::metal_kernel_library
- ::metal_pipeline_argument_buffer
- absl::flat_hash_map
- absl::inlined_vector
- iree::base::status
- iree::base::tracing
- iree::hal::command_buffer
- LINKOPTS
- "-framework Metal"
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- metal_command_queue
- HDRS
- "metal_command_queue.h"
- SRCS
- "metal_command_queue.mm"
- DEPS
- ::metal_command_buffer
- ::metal_shared_event
- iree::base::status
iree::base::time
iree::base::tracing
- iree::hal::command_queue
- LINKOPTS
- "-framework Metal"
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- metal_device
- HDRS
- "metal_device.h"
- SRCS
- "metal_device.mm"
- DEPS
- ::metal_capture_manager
- ::metal_command_buffer
- ::metal_command_queue
- ::metal_direct_allocator
- ::metal_pipeline_argument_buffer
- ::metal_pipeline_cache
- ::metal_shared_event
- absl::strings
- absl::span
- iree::base::status
- iree::base::time
- iree::base::tracing
- iree::hal::allocator
- iree::hal::command_queue
- iree::hal::device
- iree::hal::driver
- iree::hal::semaphore
- LINKOPTS
- "-framework Metal"
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- metal_direct_allocator
- HDRS
- "metal_buffer.h"
- "metal_direct_allocator.h"
- SRCS
- "metal_buffer.mm"
- "metal_direct_allocator.mm"
- DEPS
- absl::memory
- iree::base::logging
- iree::base::status
- iree::base::tracing
- iree::hal::allocator
- LINKOPTS
- "-framework Metal"
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- metal_driver
- HDRS
- "metal_driver.h"
- SRCS
- "metal_driver.mm"
- DEPS
- ::metal_capture_manager
- ::metal_device
- iree::base::status
- iree::base::tracing
- iree::hal::device_info
- iree::hal::driver
- LINKOPTS
- "-framework Metal"
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- metal_driver_module
- SRCS
- "metal_driver_module.cc"
- DEPS
- ::metal_driver
- absl::flags
- iree::base::init
- iree::base::status
- iree::hal::driver_registry
- ALWAYSLINK
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- metal_kernel_library
- HDRS
- "metal_kernel_library.h"
- SRCS
- "metal_kernel_library.mm"
- DEPS
- absl::inlined_vector
- iree::base::memory
- iree::base::status
- iree::base::tracing
- iree::hal::executable
- iree::hal::executable_spec
- iree::schemas::metal_executable_def_cc_fbs
+ iree::hal
+ iree::schemas::metal_executable_def_c_fbs
LINKOPTS
"-framework Foundation"
"-framework Metal"
PUBLIC
)
-
-iree_cc_library(
- NAME
- metal_pipeline_argument_buffer
- HDRS
- "metal_pipeline_argument_buffer.h"
- SRCS
- "metal_pipeline_argument_buffer.cc"
- DEPS
- absl::inlined_vector
- absl::span
- iree::hal::descriptor_set_layout
- iree::hal::executable_layout
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- metal_pipeline_cache
- HDRS
- "metal_pipeline_cache.h"
- SRCS
- "metal_pipeline_cache.mm"
- DEPS
- ::metal_kernel_library
- flatbuffers
- iree::base::status
- iree::base::tracing
- iree::hal::executable
- iree::hal::executable_cache
- iree::hal::executable_format
- iree::schemas::metal_executable_def_cc_fbs
- LINKOPTS
- "-framework Metal"
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- metal_shared_event
- HDRS
- "metal_shared_event.h"
- SRCS
- "metal_shared_event.mm"
- DEPS
- iree::base::tracing
- iree::hal::semaphore
- LINKOPTS
- "-framework Metal"
- PUBLIC
-)
diff --git a/iree/hal/metal/metal_command_buffer.mm b/iree/hal/metal/metal_command_buffer.mm
index 34ab86f..b40369f 100644
--- a/iree/hal/metal/metal_command_buffer.mm
+++ b/iree/hal/metal/metal_command_buffer.mm
@@ -368,11 +368,11 @@
}
IREE_DVLOG(2) << "Dispatch workgroup count: (" << workgroups[0] << ", " << workgroups[1] << ", "
- << workgroups[2] << "), workgroup size: (" << workgroup_size.x() << ", "
- << workgroup_size.y() << ", " << workgroup_size.z() << ")";
- [compute_encoder dispatchThreadgroups:MTLSizeMake(workgroups[0], workgroups[1], workgroups[2])
- threadsPerThreadgroup:MTLSizeMake(workgroup_size.x(), workgroup_size.y(),
- workgroup_size.z())];
+ << workgroups[2] << "), workgroup size: (" << workgroup_size.x << ", "
+ << workgroup_size.y << ", " << workgroup_size.z << ")";
+ [compute_encoder
+ dispatchThreadgroups:MTLSizeMake(workgroups[0], workgroups[1], workgroups[2])
+ threadsPerThreadgroup:MTLSizeMake(workgroup_size.x, workgroup_size.y, workgroup_size.z)];
return OkStatus();
}
diff --git a/iree/hal/metal/metal_kernel_library.h b/iree/hal/metal/metal_kernel_library.h
index 2fbf2f4..d68bc4e 100644
--- a/iree/hal/metal/metal_kernel_library.h
+++ b/iree/hal/metal/metal_kernel_library.h
@@ -24,7 +24,11 @@
#include "iree/hal/executable.h"
#include "iree/hal/executable_cache.h"
#include "iree/hal/executable_spec.h"
-#include "iree/schemas/metal_executable_def_generated.h"
+
+// flatcc schemas:
+#include "iree/base/flatcc.h"
+#include "iree/schemas/metal_executable_def_reader.h"
+#include "iree/schemas/metal_executable_def_verifier.h"
namespace iree {
namespace hal {
@@ -41,7 +45,7 @@
public:
static StatusOr<ref_ptr<MetalKernelLibrary>> Create(
id<MTLDevice> device, ExecutableCachingModeBitfield mode,
- const MetalExecutableDef& metal_executable_def);
+ const ExecutableSpec& spec);
~MetalKernelLibrary() override;
bool supports_debugging() const override { return false; }
@@ -50,7 +54,7 @@
StatusOr<id<MTLFunction>> GetKernelForEntryPoint(int ordinal) const;
// Returns the threadgroup size for the entry point with the given |ordinal|.
- StatusOr<MetalThreadgroupSize> GetThreadgroupSizeForEntryPoint(
+ StatusOr<iree_MetalThreadgroupSize_t> GetThreadgroupSizeForEntryPoint(
int ordinal) const;
// Returns the pipeline state object for the entry point with the given
@@ -61,7 +65,7 @@
private:
struct KernelObjects {
id<MTLFunction> function;
- MetalThreadgroupSize threadgroup_size;
+ iree_MetalThreadgroupSize_t threadgroup_size;
// Baked pipeline state object.
id<MTLComputePipelineState> pipeline_state;
};
@@ -69,17 +73,13 @@
// Creates a MetalKernelLibrary assuming all Metal objects are already
// retained before passing in.
MetalKernelLibrary(id<MTLDevice> device,
- absl::InlinedVector<id<MTLLibrary>, 1> libraries,
- absl::InlinedVector<KernelObjects, 1> kernel_objects,
- std::string tag);
-
- // Tag coming from Metal executable FlatBuffer.
- std::string tag_;
+ absl::InlinedVector<id<MTLLibrary>, 4> libraries,
+ absl::InlinedVector<KernelObjects, 4> kernel_objects);
id<MTLDevice> device_;
- absl::InlinedVector<id<MTLLibrary>, 1> libraries_;
- absl::InlinedVector<KernelObjects, 1> kernel_objects_;
+ absl::InlinedVector<id<MTLLibrary>, 4> libraries_;
+ absl::InlinedVector<KernelObjects, 4> kernel_objects_;
};
} // namespace metal
diff --git a/iree/hal/metal/metal_kernel_library.mm b/iree/hal/metal/metal_kernel_library.mm
index 5e60ebe..8ea7e6a 100644
--- a/iree/hal/metal/metal_kernel_library.mm
+++ b/iree/hal/metal/metal_kernel_library.mm
@@ -18,42 +18,99 @@
#include "iree/base/status.h"
#include "iree/base/tracing.h"
+// NOTE: starting to port this to ObjC.
+
+// Verifies the structure of the flatbuffer so that we can avoid doing so during
+// runtime. There are still some conditions we must be aware of (such as omitted
+// names on functions with internal linkage), however we shouldn't need to
+// bounds check anything within the flatbuffer after this succeeds.
+static iree_status_t iree_hal_metal_executable_flatbuffer_verify(
+ iree_const_byte_span_t flatbuffer_data) {
+ if (!flatbuffer_data.data || flatbuffer_data.data_length < 16) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "flatbuffer data is not present or less than 16 bytes (%zu total)",
+ flatbuffer_data.data_length);
+ }
+
+ // Run flatcc generated verification. This ensures all pointers are in-bounds
+ // and that we can safely walk the file, but not that the actual contents of
+ // the flatbuffer meet our expectations.
+ int verify_ret =
+ iree_MetalExecutableDef_verify_as_root(flatbuffer_data.data, flatbuffer_data.data_length);
+ if (verify_ret != flatcc_verify_ok) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "flatbuffer verification failed: %s",
+ flatcc_verify_error_string(verify_ret));
+ }
+
+ iree_MetalExecutableDef_table_t executable_def =
+ iree_MetalExecutableDef_as_root(flatbuffer_data.data);
+
+ flatbuffers_string_vec_t entry_points_vec =
+ iree_MetalExecutableDef_entry_points_get(executable_def);
+ size_t entry_point_count = flatbuffers_string_vec_len(entry_points_vec);
+ for (size_t i = 0; i < entry_point_count; ++i) {
+ if (!flatbuffers_string_len(flatbuffers_string_vec_at(entry_points_vec, i))) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "executable entry point %zu has no name", i);
+ }
+ }
+
+ iree_MetalThreadgroupSize_vec_t threadgroup_sizes_vec =
+ iree_MetalExecutableDef_threadgroup_sizes(executable_def);
+ size_t threadgroup_size_count = iree_MetalThreadgroupSize_vec_len(threadgroup_sizes_vec);
+ if (!threadgroup_size_count) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "no threadgroup sizes present");
+ }
+
+ flatbuffers_string_vec_t shader_sources_vec =
+ iree_MetalExecutableDef_shader_sources_get(executable_def);
+ size_t shader_source_count = flatbuffers_string_vec_len(shader_sources_vec);
+ for (size_t i = 0; i < shader_source_count; ++i) {
+ if (!flatbuffers_string_len(flatbuffers_string_vec_at(shader_sources_vec, i))) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "executable shader source %zu is empty",
+ i);
+ }
+ }
+
+ if (entry_point_count != threadgroup_size_count || entry_point_count != shader_source_count) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "mismatch among the numbers of entry points (%zu), thread group sizes "
+ "(%zu), and source strings (%zu)",
+ entry_point_count, threadgroup_size_count, shader_source_count);
+ }
+
+ return iree_ok_status();
+}
+
namespace iree {
namespace hal {
namespace metal {
// static
-StatusOr<ref_ptr<MetalKernelLibrary>> MetalKernelLibrary::Create(
- id<MTLDevice> device, ExecutableCachingModeBitfield mode,
- const MetalExecutableDef& metal_executable_def) {
+StatusOr<ref_ptr<MetalKernelLibrary>> MetalKernelLibrary::Create(id<MTLDevice> device,
+ ExecutableCachingModeBitfield mode,
+ const ExecutableSpec& spec) {
IREE_TRACE_SCOPE0("MetalKernelLibrary::Create");
- if (!metal_executable_def.entry_points() || metal_executable_def.entry_points()->size() == 0) {
- return InvalidArgumentErrorBuilder(IREE_LOC) << "No entry points defined";
- }
- if (!metal_executable_def.threadgroup_sizes() ||
- metal_executable_def.threadgroup_sizes()->size() == 0) {
- return InvalidArgumentErrorBuilder(IREE_LOC) << "No threadgroup sizes present";
- }
- if (!metal_executable_def.shader_sources() ||
- metal_executable_def.shader_sources()->size() == 0) {
- return InvalidArgumentErrorBuilder(IREE_LOC) << "No MSL source string present";
- }
- const auto& entry_points = *metal_executable_def.entry_points();
- const auto& threadgroup_sizes = *metal_executable_def.threadgroup_sizes();
- const auto& msl_sources = *metal_executable_def.shader_sources();
+ // Verify and fetch the executable flatbuffer wrapper.
+ iree_const_byte_span_t executable_data =
+ iree_make_const_byte_span(spec.executable_data.data(), spec.executable_data.size());
+ IREE_RETURN_IF_ERROR(iree_hal_metal_executable_flatbuffer_verify(executable_data));
+ iree_MetalExecutableDef_table_t executable_def =
+ iree_MetalExecutableDef_as_root(executable_data.data);
- if (entry_points.size() != threadgroup_sizes.size() ||
- entry_points.size() != msl_sources.size()) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Mismatch among the numbers of entry points, thread group sizes, and source strings";
- }
+ flatbuffers_string_vec_t entry_points_vec =
+ iree_MetalExecutableDef_entry_points_get(executable_def);
+ iree_MetalThreadgroupSize_vec_t threadgroup_sizes_vec =
+ iree_MetalExecutableDef_threadgroup_sizes(executable_def);
+ flatbuffers_string_vec_t shader_sources_vec =
+ iree_MetalExecutableDef_shader_sources_get(executable_def);
// Compile each MSL source string into a MTLLibrary and get the MTLFunction for the entry point to
// build the pipeline state object.
- absl::InlinedVector<id<MTLLibrary>, 1> libraries;
- absl::InlinedVector<KernelObjects, 1> kernel_objects;
+ absl::InlinedVector<id<MTLLibrary>, 4> libraries;
+ absl::InlinedVector<KernelObjects, 4> kernel_objects;
MTLCompileOptions* msl_compile_options = [MTLCompileOptions new];
msl_compile_options.languageVersion = MTLLanguageVersion2_0;
@@ -71,12 +128,15 @@
// debugging purposes but bad for performance. Enable offline compilation and make that as the
// default.
- for (int i = 0; i < msl_sources.size(); ++i) {
+ for (size_t entry_ordinal = 0; entry_ordinal < flatbuffers_string_vec_len(shader_sources_vec);
+ ++entry_ordinal) {
+ flatbuffers_string_t entry_point = flatbuffers_string_vec_at(entry_points_vec, entry_ordinal);
@autoreleasepool {
NSError* error = nil;
- NSString* shader_source = [NSString stringWithCString:msl_sources[i]->c_str()
- encoding:[NSString defaultCStringEncoding]];
+ NSString* shader_source =
+ [NSString stringWithCString:flatbuffers_string_vec_at(shader_sources_vec, entry_ordinal)
+ encoding:[NSString defaultCStringEncoding]];
id<MTLLibrary> library = [device newLibraryWithSource:shader_source
options:msl_compile_options
error:&error];
@@ -89,17 +149,17 @@
}
libraries.push_back(library);
- NSString* entry_point = [NSString stringWithCString:entry_points[i]->c_str()
- encoding:[NSString defaultCStringEncoding]];
- id<MTLFunction> function = [library newFunctionWithName:entry_point];
+ id<MTLFunction> function = [library
+ newFunctionWithName:[NSString stringWithCString:entry_point
+ encoding:[NSString defaultCStringEncoding]]];
if (!function) {
NSLog(@"Failed to create MTLFunction");
#ifndef NDEBUG
NSLog(@"Original MSL source: %@", shader_source);
#endif
return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Cannot find entry point '" << entry_points[i] << "' in shader source index "
- << i;
+ << "Cannot find entry point '" << entry_point << "' in shader source index "
+ << entry_ordinal;
}
id<MTLComputePipelineState> pso = [device newComputePipelineStateWithFunction:function
@@ -112,21 +172,19 @@
return InvalidArgumentErrorBuilder(IREE_LOC) << "Invalid MSL source";
}
- kernel_objects.push_back(KernelObjects{function, *threadgroup_sizes[i], pso});
+ kernel_objects.push_back(
+ KernelObjects{function, {static_cast<uint32_t>(iree_MetalThreadgroupSize__size())}, pso});
}
}
- std::string tag = metal_executable_def.tag() ? metal_executable_def.tag()->str() : "";
- return assign_ref(new MetalKernelLibrary([device retain], std::move(libraries),
- std::move(kernel_objects), std::move(tag)));
+ return assign_ref(
+ new MetalKernelLibrary([device retain], std::move(libraries), std::move(kernel_objects)));
}
MetalKernelLibrary::MetalKernelLibrary(id<MTLDevice> device,
- absl::InlinedVector<id<MTLLibrary>, 1> libraries,
- absl::InlinedVector<KernelObjects, 1> kernel_objects,
- std::string tag)
- : tag_(std::move(tag)),
- device_(device),
+ absl::InlinedVector<id<MTLLibrary>, 4> libraries,
+ absl::InlinedVector<KernelObjects, 4> kernel_objects)
+ : device_(device),
libraries_(std::move(libraries)),
kernel_objects_(std::move(kernel_objects)) {}
@@ -146,7 +204,7 @@
return kernel_objects_[ordinal].function;
}
-StatusOr<MetalThreadgroupSize> MetalKernelLibrary::GetThreadgroupSizeForEntryPoint(
+StatusOr<iree_MetalThreadgroupSize_t> MetalKernelLibrary::GetThreadgroupSizeForEntryPoint(
int ordinal) const {
if (ordinal < 0 || ordinal >= kernel_objects_.size()) {
return OutOfRangeErrorBuilder(IREE_LOC) << "Invalid entry point ordinal: " << ordinal;
diff --git a/iree/hal/metal/metal_pipeline_cache.mm b/iree/hal/metal/metal_pipeline_cache.mm
index 2c44b0b..1e98770 100644
--- a/iree/hal/metal/metal_pipeline_cache.mm
+++ b/iree/hal/metal/metal_pipeline_cache.mm
@@ -14,12 +14,10 @@
#include "iree/hal/metal/metal_pipeline_cache.h"
-#include "flatbuffers/flatbuffers.h"
#include "iree/base/status.h"
#include "iree/base/tracing.h"
#include "iree/hal/executable_format.h"
#include "iree/hal/metal/metal_kernel_library.h"
-#include "iree/schemas/metal_executable_def_generated.h"
namespace iree {
namespace hal {
@@ -37,19 +35,9 @@
ExecutableLayout* executable_layout, ExecutableCachingModeBitfield mode,
const ExecutableSpec& spec) {
IREE_TRACE_SCOPE0("MetalPipelineCache::PrepareExecutable");
- if (spec.executable_data.size() <= 4 ||
- !MetalExecutableDefBufferHasIdentifier(spec.executable_data.data())) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Supplied executable data does not contain a MetalExecutableDef";
- }
-
- // Get the Metal executable def flatbuffer.
- const auto& metal_executable_def =
- *::flatbuffers::GetRoot<MetalExecutableDef>(spec.executable_data.data());
// Create the Metal library (which may itself own many pipeline states).
- IREE_ASSIGN_OR_RETURN(auto executable,
- MetalKernelLibrary::Create(metal_device_, mode, metal_executable_def));
+ IREE_ASSIGN_OR_RETURN(auto executable, MetalKernelLibrary::Create(metal_device_, mode, spec));
return executable;
}
diff --git a/iree/hal/metal/registration/CMakeLists.txt b/iree/hal/metal/registration/CMakeLists.txt
new file mode 100644
index 0000000..9a5e57c
--- /dev/null
+++ b/iree/hal/metal/registration/CMakeLists.txt
@@ -0,0 +1,38 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_add_all_subdirs()
+
+if(${IREE_HAL_DRIVER_METAL})
+
+iree_cc_library(
+ NAME
+ registration
+ HDRS
+ "driver_module.h"
+ SRCS
+ "driver_module.cc"
+ DEPS
+ absl::flags
+ iree::base::flags
+ iree::base::status
+ iree::hal::api
+ iree::hal::driver_registry
+ iree::hal::metal
+ DEFINES
+ "IREE_HAL_HAVE_METAL_DRIVER_MODULE=1"
+ PUBLIC
+)
+
+endif()
diff --git a/iree/hal/metal/metal_driver_module.cc b/iree/hal/metal/registration/driver_module.cc
similarity index 78%
rename from iree/hal/metal/metal_driver_module.cc
rename to iree/hal/metal/registration/driver_module.cc
index 48308ae..2b445f3 100644
--- a/iree/hal/metal/metal_driver_module.cc
+++ b/iree/hal/metal/registration/driver_module.cc
@@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include <memory>
+#include "iree/hal/metal/registration/driver_module.h"
#include "absl/flags/flag.h"
-#include "iree/base/init.h"
+#include "iree/base/flags.h"
#include "iree/base/status.h"
#include "iree/hal/driver_registry.h"
#include "iree/hal/metal/metal_driver.h"
@@ -28,7 +28,6 @@
namespace iree {
namespace hal {
namespace metal {
-namespace {
StatusOr<ref_ptr<Driver>> CreateMetalDriver() {
MetalDriverOptions options;
@@ -37,13 +36,12 @@
return MetalDriver::Create(options);
}
-} // namespace
} // namespace metal
} // namespace hal
} // namespace iree
-IREE_REGISTER_MODULE_INITIALIZER(iree_hal_metal_driver, {
- IREE_QCHECK_OK(::iree::hal::DriverRegistry::shared_registry()->Register(
- "metal", ::iree::hal::metal::CreateMetalDriver));
-});
-IREE_REGISTER_MODULE_INITIALIZER_SEQUENCE(iree_hal, iree_hal_metal_driver);
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_metal_driver_module_register() {
+ return ::iree::hal::DriverRegistry::shared_registry()->Register(
+ "metal", ::iree::hal::metal::CreateMetalDriver);
+}
diff --git a/iree/hal/metal/registration/driver_module.h b/iree/hal/metal/registration/driver_module.h
new file mode 100644
index 0000000..1b5c9c0
--- /dev/null
+++ b/iree/hal/metal/registration/driver_module.h
@@ -0,0 +1,31 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_METAL_REGISTRATION_DRIVER_MODULE_H_
+#define IREE_HAL_METAL_REGISTRATION_DRIVER_MODULE_H_
+
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_metal_driver_module_register();
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_HAL_METAL_REGISTRATION_DRIVER_MODULE_H_
diff --git a/iree/hal/testing/BUILD b/iree/hal/testing/BUILD
index d69e68c..e66768f 100644
--- a/iree/hal/testing/BUILD
+++ b/iree/hal/testing/BUILD
@@ -25,7 +25,7 @@
testonly = True,
hdrs = ["mock_allocator.h"],
deps = [
- "//iree/hal:allocator",
+ "//iree/hal",
"//iree/testing:gtest",
],
)
@@ -35,7 +35,7 @@
testonly = True,
hdrs = ["mock_command_buffer.h"],
deps = [
- "//iree/hal:command_buffer",
+ "//iree/hal",
"//iree/testing:gtest",
],
)
@@ -45,7 +45,7 @@
testonly = True,
hdrs = ["mock_command_queue.h"],
deps = [
- "//iree/hal:command_queue",
+ "//iree/hal",
"//iree/testing:gtest",
],
)
diff --git a/iree/hal/testing/CMakeLists.txt b/iree/hal/testing/CMakeLists.txt
index e925403..53fa5a8 100644
--- a/iree/hal/testing/CMakeLists.txt
+++ b/iree/hal/testing/CMakeLists.txt
@@ -20,7 +20,7 @@
HDRS
"mock_allocator.h"
DEPS
- iree::hal::allocator
+ iree::hal
iree::testing::gtest
TESTONLY
PUBLIC
@@ -32,7 +32,7 @@
HDRS
"mock_command_buffer.h"
DEPS
- iree::hal::command_buffer
+ iree::hal
iree::testing::gtest
TESTONLY
PUBLIC
@@ -44,7 +44,7 @@
HDRS
"mock_command_queue.h"
DEPS
- iree::hal::command_queue
+ iree::hal
iree::testing::gtest
TESTONLY
PUBLIC
diff --git a/iree/hal/vmla/BUILD b/iree/hal/vmla/BUILD
index d029bbf..cc7204d 100644
--- a/iree/hal/vmla/BUILD
+++ b/iree/hal/vmla/BUILD
@@ -57,7 +57,7 @@
srcs = ["op_kernels_test.cc"],
deps = [
":op_kernels",
- "//iree/base:memory",
+ "//iree/base:core_headers",
"//iree/testing:gtest",
"//iree/testing:gtest_main",
"@com_google_absl//absl/container:inlined_vector",
@@ -65,108 +65,55 @@
)
cc_library(
- name = "vmla_cache",
- srcs = ["vmla_cache.cc"],
- hdrs = ["vmla_cache.h"],
- deps = [
- ":vmla_executable",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:allocator",
- "//iree/hal:executable",
- "//iree/hal:executable_cache",
- "//iree/hal:executable_format",
- "//iree/vm:instance",
- "//iree/vm:module",
- ],
-)
-
-cc_library(
- name = "vmla_device",
- srcs = ["vmla_device.cc"],
- hdrs = ["vmla_device.h"],
- deps = [
- ":vmla_cache",
- "//iree/base:memory",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:command_queue",
- "//iree/hal:device",
- "//iree/hal/host:host_local_device",
- "//iree/vm:instance",
- "//iree/vm:module",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/memory",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:span",
- ],
-)
-
-cc_library(
- name = "vmla_driver",
- srcs = ["vmla_driver.cc"],
- hdrs = ["vmla_driver.h"],
- deps = [
- ":vmla_device",
- ":vmla_module",
- "//iree/base:tracing",
- "//iree/hal:device_info",
- "//iree/hal:driver",
- "//iree/hal/host/serial:serial_scheduling_model",
- "//iree/vm:instance",
- "//iree/vm:module",
- ],
-)
-
-cc_library(
- name = "vmla_driver_module",
- srcs = ["vmla_driver_module.cc"],
- deps = [
- ":vmla_driver",
- "//iree/base:init",
- "//iree/base:status",
- "//iree/hal:driver_registry",
- ],
- alwayslink = 1,
-)
-
-cc_library(
- name = "vmla_executable",
- srcs = ["vmla_executable.cc"],
- hdrs = ["vmla_executable.h"],
- deps = [
- ":vmla_module",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:executable",
- "//iree/hal:executable_spec",
- "//iree/hal/host:host_buffer",
- "//iree/hal/host:host_executable",
- "//iree/schemas:vmla_executable_def_cc_fbs",
- "//iree/vm:bytecode_module",
- "//iree/vm:context",
- "//iree/vm:instance",
- "//iree/vm:invocation",
- "//iree/vm:list",
- "//iree/vm:module",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/types:span",
- ],
-)
-
-cc_library(
- name = "vmla_module",
- srcs = ["vmla_module.cc"],
- hdrs = ["vmla_module.h"],
+ name = "op_module",
+ srcs = ["op_module.cc"],
+ hdrs = ["op_module.h"],
deps = [
":op_kernels",
"//iree/base:api",
- "//iree/base:memory",
+ "//iree/base:core_headers",
"//iree/base:ref_ptr",
"//iree/base:status",
"//iree/base:tracing",
"//iree/vm",
- "//iree/vm:native_module_cc",
+ "//iree/vm:cc",
+ "@com_google_absl//absl/types:span",
+ ],
+)
+
+cc_library(
+ name = "vmla",
+ srcs = [
+ "vmla_cache.cc",
+ "vmla_device.cc",
+ "vmla_driver.cc",
+ "vmla_executable.cc",
+ ],
+ hdrs = [
+ "vmla_cache.h",
+ "vmla_device.h",
+ "vmla_driver.h",
+ "vmla_executable.h",
+ ],
+ deps = [
+ ":op_module",
+ "//iree/base:api",
+ "//iree/base:core_headers",
+ "//iree/base:flatcc",
+ "//iree/base:ref_ptr",
+ "//iree/base:status",
+ "//iree/base:tracing",
+ "//iree/hal",
+ "//iree/hal/host:host_buffer",
+ "//iree/hal/host:host_executable",
+ "//iree/hal/host:host_local_device",
+ "//iree/hal/host/serial:serial_scheduling_model",
+ "//iree/schemas:vmla_executable_def_c_fbs",
+ "//iree/vm",
+ "//iree/vm:bytecode_module",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)
diff --git a/iree/hal/vmla/CMakeLists.txt b/iree/hal/vmla/CMakeLists.txt
index eb22436..b25a0dc 100644
--- a/iree/hal/vmla/CMakeLists.txt
+++ b/iree/hal/vmla/CMakeLists.txt
@@ -47,131 +47,63 @@
DEPS
::op_kernels
absl::inlined_vector
- iree::base::memory
+ iree::base::core_headers
iree::testing::gtest
iree::testing::gtest_main
)
iree_cc_library(
NAME
- vmla_cache
+ op_module
HDRS
- "vmla_cache.h"
+ "op_module.h"
SRCS
- "vmla_cache.cc"
- DEPS
- ::vmla_executable
- iree::base::status
- iree::base::tracing
- iree::hal::allocator
- iree::hal::executable
- iree::hal::executable_cache
- iree::hal::executable_format
- iree::vm::instance
- iree::vm::module
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- vmla_device
- HDRS
- "vmla_device.h"
- SRCS
- "vmla_device.cc"
- DEPS
- ::vmla_cache
- absl::inlined_vector
- absl::memory
- absl::span
- absl::strings
- iree::base::memory
- iree::base::status
- iree::base::tracing
- iree::hal::command_queue
- iree::hal::device
- iree::hal::host::host_local_device
- iree::vm::instance
- iree::vm::module
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- vmla_driver
- HDRS
- "vmla_driver.h"
- SRCS
- "vmla_driver.cc"
- DEPS
- ::vmla_device
- ::vmla_module
- iree::base::tracing
- iree::hal::device_info
- iree::hal::driver
- iree::hal::host::serial::serial_scheduling_model
- iree::vm::instance
- iree::vm::module
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- vmla_driver_module
- SRCS
- "vmla_driver_module.cc"
- DEPS
- ::vmla_driver
- iree::base::init
- iree::base::status
- iree::hal::driver_registry
- ALWAYSLINK
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- vmla_executable
- HDRS
- "vmla_executable.h"
- SRCS
- "vmla_executable.cc"
- DEPS
- ::vmla_module
- absl::inlined_vector
- absl::span
- iree::base::status
- iree::base::tracing
- iree::hal::executable
- iree::hal::executable_spec
- iree::hal::host::host_buffer
- iree::hal::host::host_executable
- iree::schemas::vmla_executable_def_cc_fbs
- iree::vm::bytecode_module
- iree::vm::context
- iree::vm::instance
- iree::vm::invocation
- iree::vm::list
- iree::vm::module
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- vmla_module
- HDRS
- "vmla_module.h"
- SRCS
- "vmla_module.cc"
+ "op_module.cc"
DEPS
::op_kernels
absl::span
iree::base::api
- iree::base::memory
+ iree::base::core_headers
iree::base::ref_ptr
iree::base::status
iree::base::tracing
iree::vm
- iree::vm::native_module_cc
+ iree::vm::cc
+ PUBLIC
+)
+
+iree_cc_library(
+ NAME
+ vmla
+ HDRS
+ "vmla_cache.h"
+ "vmla_device.h"
+ "vmla_driver.h"
+ "vmla_executable.h"
+ SRCS
+ "vmla_cache.cc"
+ "vmla_device.cc"
+ "vmla_driver.cc"
+ "vmla_executable.cc"
+ DEPS
+ ::op_module
+ absl::inlined_vector
+ absl::memory
+ absl::span
+ absl::strings
+ iree::base::api
+ iree::base::core_headers
+ iree::base::flatcc
+ iree::base::ref_ptr
+ iree::base::status
+ iree::base::tracing
+ iree::hal
+ iree::hal::host::host_buffer
+ iree::hal::host::host_executable
+ iree::hal::host::host_local_device
+ iree::hal::host::serial::serial_scheduling_model
+ iree::schemas::vmla_executable_def_c_fbs
+ iree::vm
+ iree::vm::bytecode_module
PUBLIC
)
diff --git a/iree/hal/vmla/op_kernels_generic.h b/iree/hal/vmla/op_kernels_generic.h
index a84e2cf..f3ec435 100644
--- a/iree/hal/vmla/op_kernels_generic.h
+++ b/iree/hal/vmla/op_kernels_generic.h
@@ -96,7 +96,7 @@
size_t element_size) {
absl::InlinedVector<size_t, 6> strides(shape.size());
strides.back() = element_size;
- for (int i = shape.size() - 2; i >= 0; --i) {
+ for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i) {
strides[i] = strides[i + 1] * shape[i + 1];
}
return strides;
@@ -110,7 +110,7 @@
absl::Span<const int32_t> dst_indices,
absl::Span<const int32_t> lengths) {
if (lengths.size() > 1) {
- for (int i = 0; i < lengths[0]; ++i) {
+ for (int32_t i = 0; i < lengths[0]; ++i) {
size_t src_offset = src_strides[0] * (src_indices[0] + i);
size_t dst_offset = dst_strides[0] * (dst_indices[0] + i);
CopyRegion(src_buffer.subspan(src_offset), src_strides.subspan(1),
@@ -1067,12 +1067,12 @@
absl::Span<const int> src_indices,
ShapeSpan src_shape, T init_value,
ShapeSpan window_dimensions, T* dst_value) {
- int rank = src_shape.size();
+ size_t rank = src_shape.size();
absl::InlinedVector<int, 8> window_indices(rank, 0);
auto getSrcValue = [&]() -> T {
- int flat_idx = 0;
- for (int i = 0; i < rank; ++i) {
- int idx = src_indices[i] + window_indices[i];
+ size_t flat_idx = 0;
+ for (size_t i = 0; i < rank; ++i) {
+ size_t idx = src_indices[i] + window_indices[i];
if (idx < 0 || idx >= src_shape[i]) return init_value;
flat_idx = flat_idx * src_shape[i] + idx;
}
@@ -1080,7 +1080,7 @@
};
*dst_value = init_value;
- for (int i = 0, e = GetElementCount(window_dimensions); i < e; ++i) {
+ for (size_t i = 0, e = GetElementCount(window_dimensions); i < e; ++i) {
KernelImpl()(dst_value, getSrcValue());
IncrementShapeIndex(absl::MakeSpan(window_indices), window_dimensions);
}
@@ -1092,11 +1092,11 @@
ShapeSpan src_shape, ShapeSpan dst_shape,
ShapeSpan window_dimensions, ShapeSpan strides,
ShapeSpan pad_low) {
- int rank = src_shape.size();
+ size_t rank = src_shape.size();
absl::InlinedVector<int, 8> src_indices(rank, 0);
absl::InlinedVector<int, 8> dst_indices(rank, 0);
- for (int i = 0, e = GetElementCount(dst_shape); i < e; ++i) {
- for (int j = 0; j < rank; ++j) {
+ for (size_t i = 0, e = GetElementCount(dst_shape); i < e; ++i) {
+ for (size_t j = 0; j < rank; ++j) {
src_indices[j] = dst_indices[j] * strides[j] - pad_low[j];
}
ComputePoolingWindow<T, KernelImpl>(src_buffer, src_indices, src_shape,
diff --git a/iree/hal/vmla/op_kernels_test.cc b/iree/hal/vmla/op_kernels_test.cc
index 877457b..feb1dee 100644
--- a/iree/hal/vmla/op_kernels_test.cc
+++ b/iree/hal/vmla/op_kernels_test.cc
@@ -39,7 +39,7 @@
}
template <typename T>
-std::vector<T> MakeIota(int size) {
+std::vector<T> MakeIota(size_t size) {
std::vector<T> v(size);
std::iota(v.begin(), v.end(), static_cast<T>(1));
return v;
@@ -390,7 +390,7 @@
absl::MakeSpan(dst_buffer),
dimension, src_shape, dst_shape));
- for (int i = 0; i < dst_buffer.size(); ++i) {
+ for (size_t i = 0; i < dst_buffer.size(); ++i) {
EXPECT_NEAR(expected_dst[i], dst_buffer[i], kEpsilon);
}
}
@@ -409,7 +409,7 @@
absl::MakeSpan(dst_buffer),
dimension, src_shape, dst_shape));
- for (int i = 0; i < dst_buffer.size(); ++i) {
+ for (size_t i = 0; i < dst_buffer.size(); ++i) {
EXPECT_NEAR(expected_dst[i], dst_buffer[i], kEpsilon);
}
}
@@ -421,8 +421,8 @@
Shape strides = {1, 2, 3, 1};
Shape pad_low = {0, 0, 0, 0};
std::vector<int> src_buffer = MakeIota<int>(GetShapeElementCount(src_shape));
- std::vector<int> init_buffer(1, 0.0f);
- std::vector<int> dst_buffer(GetShapeElementCount(dst_shape), 0.0f);
+ std::vector<int> init_buffer(1, 0);
+ std::vector<int> dst_buffer(GetShapeElementCount(dst_shape), 0);
std::vector<int> expected_dst = {9, 12, 21, 24};
IREE_EXPECT_OK(PoolingMax::Execute<int>(
@@ -442,8 +442,8 @@
Shape strides = {1, 1};
Shape pad_low = {1, 1};
std::vector<int> src_buffer = MakeIota<int>(GetShapeElementCount(src_shape));
- std::vector<int> init_buffer(1, 100.0);
- std::vector<int> dst_buffer(GetShapeElementCount(dst_shape), 0.0f);
+ std::vector<int> init_buffer(1, 100);
+ std::vector<int> dst_buffer(GetShapeElementCount(dst_shape), 0);
std::vector<int> expected_dst = {1, 1, 2, 1, 1, 2};
IREE_EXPECT_OK(PoolingMin::Execute<int>(
@@ -467,7 +467,7 @@
IREE_EXPECT_OK(PoolingSum::Execute<float>(
src_buffer, init_buffer, absl::MakeSpan(dst_buffer), src_shape, dst_shape,
window_sizes, strides, pad_low));
- for (int i = 0; i < dst_buffer.size(); ++i) {
+ for (size_t i = 0; i < dst_buffer.size(); ++i) {
EXPECT_NEAR(expected_dst[i], dst_buffer[i], kEpsilon);
}
}
@@ -485,10 +485,10 @@
std::vector<float> filter_buffer(GetShapeElementCount(filter_shape));
std::vector<float> expected_dst = {1310, 1466, 1622, 1778,
2090, 2246, 2402, 2558};
- for (int i = 0; i < GetShapeElementCount(input_shape); ++i) {
- input_buffer[i] = i + 1;
+ for (size_t i = 0; i < GetShapeElementCount(input_shape); ++i) {
+ input_buffer[i] = static_cast<float>(i + 1);
if (i < GetShapeElementCount(filter_shape)) {
- filter_buffer[i] = i + 1;
+ filter_buffer[i] = static_cast<float>(i + 1);
}
}
std::vector<float> dst_buffer(GetShapeElementCount(dst_shape), 0.0f);
@@ -498,7 +498,7 @@
absl::MakeSpan(dst_buffer), dst_shape, strides, pad_h, pad_w,
lhs_dilation, rhs_dilation, 1));
- for (int i = 0; i < dst_buffer.size(); ++i) {
+ for (size_t i = 0; i < dst_buffer.size(); ++i) {
EXPECT_NEAR(expected_dst[i], dst_buffer[i], kEpsilon);
}
}
@@ -518,7 +518,7 @@
1124, 1196, 1346, 1424, 1256, 1340, 1502, 1592, 1388, 1484, 1658,
1760, 1520, 1628, 1814, 1928, 1784, 1916, 2126, 2264, 1916, 2060,
2282, 2432, 2048, 2204, 2438, 2600, 2180, 2348, 2594, 2768};
- for (int i = 0; i < GetShapeElementCount(input_shape); ++i) {
+ for (size_t i = 0; i < GetShapeElementCount(input_shape); ++i) {
input_buffer[i] = i + 1;
if (i < GetShapeElementCount(filter_shape)) {
filter_buffer[i] = i + 1;
@@ -531,7 +531,7 @@
absl::MakeSpan(dst_buffer), dst_shape, strides, pad_h, pad_w,
lhs_dilation, rhs_dilation, 2));
- for (int i = 0; i < dst_buffer.size(); ++i) {
+ for (size_t i = 0; i < dst_buffer.size(); ++i) {
EXPECT_NEAR(expected_dst[i], dst_buffer[i], kEpsilon);
}
}
diff --git a/iree/hal/vmla/vmla_module.cc b/iree/hal/vmla/op_module.cc
similarity index 99%
rename from iree/hal/vmla/vmla_module.cc
rename to iree/hal/vmla/op_module.cc
index 76b9fce..ff7ead3 100644
--- a/iree/hal/vmla/vmla_module.cc
+++ b/iree/hal/vmla/op_module.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "iree/hal/vmla/vmla_module.h"
+#include "iree/hal/vmla/op_module.h"
#include <cstdint>
@@ -129,8 +129,8 @@
constexpr int Interface::kMaxBindings;
void Interface::Reset() {
- for (int i = 0; i < bindings_.size(); ++i) {
- for (int j = 0; j < bindings_[i].size(); ++j) {
+ for (size_t i = 0; i < bindings_.size(); ++i) {
+ for (size_t j = 0; j < bindings_[i].size(); ++j) {
bindings_[i][j] = {};
}
}
@@ -150,7 +150,7 @@
<< "Constant value overflow; have " << values.size()
<< " but max is " << kMaxConstants;
}
- for (int i = 0; i < values.size(); ++i) {
+ for (size_t i = 0; i < values.size(); ++i) {
constants_[i] = values[i];
}
return OkStatus();
diff --git a/iree/hal/vmla/vmla_module.h b/iree/hal/vmla/op_module.h
similarity index 97%
rename from iree/hal/vmla/vmla_module.h
rename to iree/hal/vmla/op_module.h
index 9bb3860..f94a671 100644
--- a/iree/hal/vmla/vmla_module.h
+++ b/iree/hal/vmla/op_module.h
@@ -16,8 +16,8 @@
// linked into the same library, because of this we can avoid the C shims and
// directly use C++ types.
-#ifndef IREE_HAL_VMLA_VMLA_MODULE_H_
-#define IREE_HAL_VMLA_VMLA_MODULE_H_
+#ifndef IREE_HAL_VMLA_OP_MODULE_H_
+#define IREE_HAL_VMLA_OP_MODULE_H_
#include <cstdint>
@@ -136,4 +136,4 @@
IREE_VM_DECLARE_TYPE_ADAPTERS(Buffer, iree::hal::vmla::Buffer);
IREE_VM_DECLARE_TYPE_ADAPTERS(Interface, iree::hal::vmla::Interface);
-#endif // IREE_HAL_VMLA_VMLA_MODULE_H_
+#endif // IREE_HAL_VMLA_OP_MODULE_H_
diff --git a/iree/hal/vmla/registration/BUILD b/iree/hal/vmla/registration/BUILD
new file mode 100644
index 0000000..c37ffce
--- /dev/null
+++ b/iree/hal/vmla/registration/BUILD
@@ -0,0 +1,51 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//iree:build_defs.oss.bzl", "iree_cmake_extra_content")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_cmake_extra_content(
+ content = """
+if(${IREE_HAL_DRIVER_VMLA})
+""",
+ inline = True,
+)
+
+cc_library(
+ name = "registration",
+ srcs = ["driver_module.cc"],
+ hdrs = ["driver_module.h"],
+ defines = [
+ "IREE_HAL_HAVE_VMLA_DRIVER_MODULE=1",
+ ],
+ deps = [
+ "//iree/base:flags",
+ "//iree/base:status",
+ "//iree/hal:api",
+ "//iree/hal:driver_registry",
+ "//iree/hal/vmla",
+ ],
+)
+
+iree_cmake_extra_content(
+ content = """
+endif()
+""",
+ inline = True,
+)
diff --git a/iree/hal/vmla/registration/CMakeLists.txt b/iree/hal/vmla/registration/CMakeLists.txt
new file mode 100644
index 0000000..04aee07
--- /dev/null
+++ b/iree/hal/vmla/registration/CMakeLists.txt
@@ -0,0 +1,37 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_add_all_subdirs()
+
+if(${IREE_HAL_DRIVER_VMLA})
+
+iree_cc_library(
+ NAME
+ registration
+ HDRS
+ "driver_module.h"
+ SRCS
+ "driver_module.cc"
+ DEPS
+ iree::base::flags
+ iree::base::status
+ iree::hal::api
+ iree::hal::driver_registry
+ iree::hal::vmla
+ DEFINES
+ "IREE_HAL_HAVE_VMLA_DRIVER_MODULE=1"
+ PUBLIC
+)
+
+endif()
diff --git a/iree/hal/vmla/vmla_driver_module.cc b/iree/hal/vmla/registration/driver_module.cc
similarity index 63%
rename from iree/hal/vmla/vmla_driver_module.cc
rename to iree/hal/vmla/registration/driver_module.cc
index b051db6..94877a1 100644
--- a/iree/hal/vmla/vmla_driver_module.cc
+++ b/iree/hal/vmla/registration/driver_module.cc
@@ -1,4 +1,4 @@
-// Copyright 2019 Google LLC
+// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,9 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include <memory>
+#include "iree/hal/vmla/registration/driver_module.h"
-#include "iree/base/init.h"
#include "iree/base/status.h"
#include "iree/hal/driver_registry.h"
#include "iree/hal/vmla/vmla_driver.h"
@@ -22,17 +21,17 @@
namespace iree {
namespace hal {
namespace vmla {
-namespace {
-StatusOr<ref_ptr<Driver>> CreateVMLADriver() { return VMLADriver::Create(); }
+static StatusOr<ref_ptr<Driver>> CreateVMLADriver() {
+ return VMLADriver::Create();
+}
-} // namespace
} // namespace vmla
} // namespace hal
} // namespace iree
-IREE_REGISTER_MODULE_INITIALIZER(iree_hal_vmla_driver, {
- IREE_QCHECK_OK(::iree::hal::DriverRegistry::shared_registry()->Register(
- "vmla", ::iree::hal::vmla::CreateVMLADriver));
-});
-IREE_REGISTER_MODULE_INITIALIZER_SEQUENCE(iree_hal, iree_hal_vmla_driver);
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_vmla_driver_module_register() {
+ return ::iree::hal::DriverRegistry::shared_registry()->Register(
+ "vmla", ::iree::hal::vmla::CreateVMLADriver);
+}
diff --git a/iree/hal/vmla/registration/driver_module.h b/iree/hal/vmla/registration/driver_module.h
new file mode 100644
index 0000000..ec7ff1a
--- /dev/null
+++ b/iree/hal/vmla/registration/driver_module.h
@@ -0,0 +1,33 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_VMLA_REGISTRATION_DRIVER_MODULE_H_
+#define IREE_HAL_VMLA_REGISTRATION_DRIVER_MODULE_H_
+
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// DEPRECATED: this entire driver will be removed soon.
+// TODO(#3580): remove this entire driver w/ iree_hal_executable_library_t.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_vmla_driver_module_register();
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_HAL_VMLA_REGISTRATION_DRIVER_MODULE_H_
diff --git a/iree/hal/vmla/vmla_cache.h b/iree/hal/vmla/vmla_cache.h
index 62650ab..560d4bd 100644
--- a/iree/hal/vmla/vmla_cache.h
+++ b/iree/hal/vmla/vmla_cache.h
@@ -18,8 +18,7 @@
#include "iree/hal/allocator.h"
#include "iree/hal/executable.h"
#include "iree/hal/executable_cache.h"
-#include "iree/vm/instance.h"
-#include "iree/vm/module.h"
+#include "iree/vm/api.h"
namespace iree {
namespace hal {
diff --git a/iree/hal/vmla/vmla_device.h b/iree/hal/vmla/vmla_device.h
index 78f23e6..3d3f3aa 100644
--- a/iree/hal/vmla/vmla_device.h
+++ b/iree/hal/vmla/vmla_device.h
@@ -17,8 +17,7 @@
#include "iree/base/memory.h"
#include "iree/hal/host/host_local_device.h"
-#include "iree/vm/instance.h"
-#include "iree/vm/module.h"
+#include "iree/vm/api.h"
namespace iree {
namespace hal {
diff --git a/iree/hal/vmla/vmla_driver.cc b/iree/hal/vmla/vmla_driver.cc
index c518efa..54696b1 100644
--- a/iree/hal/vmla/vmla_driver.cc
+++ b/iree/hal/vmla/vmla_driver.cc
@@ -19,9 +19,8 @@
#include "iree/base/tracing.h"
#include "iree/hal/device_info.h"
#include "iree/hal/host/serial/serial_scheduling_model.h"
+#include "iree/hal/vmla/op_module.h"
#include "iree/hal/vmla/vmla_device.h"
-#include "iree/hal/vmla/vmla_module.h"
-#include "iree/vm/module.h"
namespace iree {
namespace hal {
diff --git a/iree/hal/vmla/vmla_driver.h b/iree/hal/vmla/vmla_driver.h
index 3d8cb37..c701f4d 100644
--- a/iree/hal/vmla/vmla_driver.h
+++ b/iree/hal/vmla/vmla_driver.h
@@ -16,8 +16,7 @@
#define IREE_HAL_VMLA_VMLA_DRIVER_H_
#include "iree/hal/driver.h"
-#include "iree/vm/instance.h"
-#include "iree/vm/module.h"
+#include "iree/vm/api.h"
namespace iree {
namespace hal {
diff --git a/iree/hal/vmla/vmla_executable.cc b/iree/hal/vmla/vmla_executable.cc
index 4cc447a..da0cff2 100644
--- a/iree/hal/vmla/vmla_executable.cc
+++ b/iree/hal/vmla/vmla_executable.cc
@@ -17,12 +17,53 @@
#include "iree/base/status.h"
#include "iree/base/tracing.h"
#include "iree/hal/host/host_buffer.h"
-#include "iree/hal/vmla/vmla_module.h"
-#include "iree/schemas/vmla_executable_def_generated.h"
+#include "iree/hal/vmla/op_module.h"
#include "iree/vm/bytecode_module.h"
-#include "iree/vm/invocation.h"
-#include "iree/vm/list.h"
-#include "iree/vm/module.h"
+
+// flatcc schemas:
+#include "iree/base/flatcc.h"
+#include "iree/schemas/vmla_executable_def_reader.h"
+#include "iree/schemas/vmla_executable_def_verifier.h"
+
+// NOTE: starting to port this to C.
+
+// Verifies the structure of the flatbuffer so that we can avoid doing so during
+// runtime. There are still some conditions we must be aware of (such as omitted
+// names on functions with internal linkage), however we shouldn't need to
+// bounds check anything within the flatbuffer after this succeeds.
+static iree_status_t iree_hal_vmla_executable_flatbuffer_verify(
+ iree_const_byte_span_t flatbuffer_data) {
+ if (!flatbuffer_data.data || flatbuffer_data.data_length < 16) {
+ return iree_make_status(
+ IREE_STATUS_INVALID_ARGUMENT,
+ "flatbuffer data is not present or less than 16 bytes (%zu total)",
+ flatbuffer_data.data_length);
+ }
+
+ // Run flatcc generated verification. This ensures all pointers are in-bounds
+ // and that we can safely walk the file, but not that the actual contents of
+ // the flatbuffer meet our expectations.
+ int verify_ret = iree_VMLAExecutableDef_verify_as_root(
+ flatbuffer_data.data, flatbuffer_data.data_length);
+ if (verify_ret != flatcc_verify_ok) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "flatbuffer verification failed: %s",
+ flatcc_verify_error_string(verify_ret));
+ }
+
+ iree_VMLAExecutableDef_table_t executable_def =
+ iree_VMLAExecutableDef_as_root(flatbuffer_data.data);
+
+ if (flatbuffers_uint8_vec_len(
+ iree_VMLAExecutableDef_bytecode_module_get(executable_def)) == 0) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "executable bytecode_module is missing/empty");
+ }
+
+ // NOTE: we don't check the actual bytecode module contents here; it's opaque
+ // to us and passed on to the VM.
+ return iree_ok_status();
+}
namespace iree {
namespace hal {
@@ -61,34 +102,28 @@
iree_vm_module_t* vmla_module) {
IREE_TRACE_SCOPE0("VMLAExecutable::Initialize");
- if (spec_.executable_data.size() < 16) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Flatbuffer data is not present or less than 16 bytes";
- } else if (!iree::VMLAExecutableDefBufferHasIdentifier(
- spec_.executable_data.data())) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Flatbuffer data does not have bytecode module identifier";
- }
-
- const auto* executable_def = ::flatbuffers::GetRoot<iree::VMLAExecutableDef>(
- spec_.executable_data.data());
- if (!executable_def || !executable_def->bytecode_module()) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Failed getting root from flatbuffer data";
- }
+ // Verify and fetch the executable flatbuffer wrapper.
+ iree_const_byte_span_t executable_data = iree_make_const_byte_span(
+ spec_.executable_data.data(), spec_.executable_data.size());
+ IREE_RETURN_IF_ERROR(
+ iree_hal_vmla_executable_flatbuffer_verify(executable_data));
+ iree_VMLAExecutableDef_table_t executable_def =
+ iree_VMLAExecutableDef_as_root(executable_data.data);
// Load bytecode module from the executable spec.
+ flatbuffers_uint8_vec_t bytecode_module_vec =
+ iree_VMLAExecutableDef_bytecode_module_get(executable_def);
+ iree_const_byte_span_t bytecode_module_data = iree_make_const_byte_span(
+ bytecode_module_vec, flatbuffers_uint8_vec_len(bytecode_module_vec));
iree_vm_module_t* bytecode_module = nullptr;
IREE_RETURN_IF_ERROR(iree_vm_bytecode_module_create(
- iree_const_byte_span_t{reinterpret_cast<const uint8_t*>(
- executable_def->bytecode_module()->data()),
- executable_def->bytecode_module()->size()},
- iree_allocator_null(), iree_allocator_system(), &bytecode_module))
+ bytecode_module_data, iree_allocator_null(), iree_allocator_system(),
+ &bytecode_module))
<< "Failed to load executable bytecode module";
entry_functions_.resize(
iree_vm_module_signature(bytecode_module).export_function_count);
- for (int i = 0; i < entry_functions_.size(); ++i) {
+ for (size_t i = 0; i < entry_functions_.size(); ++i) {
IREE_RETURN_IF_ERROR(iree_vm_module_lookup_function_by_ordinal(
bytecode_module, IREE_VM_FUNCTION_LINKAGE_EXPORT, i,
&entry_functions_[i], nullptr));
@@ -134,7 +169,7 @@
auto* interface = &dispatch_state->interface;
IREE_RETURN_IF_ERROR(interface->SetConstants(params.push_constants->values));
- for (int set_ordinal = 0; set_ordinal < params.set_bindings.size();
+ for (size_t set_ordinal = 0; set_ordinal < params.set_bindings.size();
++set_ordinal) {
for (const auto& binding : params.set_bindings[set_ordinal]) {
// TODO(benvanik): plumb binding directly into VMLA to avoid this.
@@ -167,7 +202,7 @@
/*element_type=*/nullptr,
/*interface*/ 1 + /*workgroup_xyz[3]*/ 3, &input_list));
iree_vm_list_push_ref_retain(input_list, &dispatch_state->interface_ref);
- for (int i = 0; i < workgroup_xyz.size(); ++i) {
+ for (size_t i = 0; i < workgroup_xyz.size(); ++i) {
iree_vm_value_t value = iree_vm_value_make_i32(workgroup_xyz[i]);
iree_vm_list_push_value(input_list, &value);
}
diff --git a/iree/hal/vmla/vmla_executable.h b/iree/hal/vmla/vmla_executable.h
index 0e79b36..7eb3c27 100644
--- a/iree/hal/vmla/vmla_executable.h
+++ b/iree/hal/vmla/vmla_executable.h
@@ -22,9 +22,7 @@
#include "iree/base/status.h"
#include "iree/hal/executable_spec.h"
#include "iree/hal/host/host_executable.h"
-#include "iree/vm/context.h"
-#include "iree/vm/instance.h"
-#include "iree/vm/module.h"
+#include "iree/vm/api.h"
namespace iree {
namespace hal {
diff --git a/iree/hal/vulkan/BUILD b/iree/hal/vulkan/BUILD
index a00ccef..7c44f47 100644
--- a/iree/hal/vulkan/BUILD
+++ b/iree/hal/vulkan/BUILD
@@ -14,7 +14,7 @@
# HAL implementation using Vulkan and (likely) SPIR-V executables.
-load("//iree:build_defs.oss.bzl", "PLATFORM_VULKAN_TEST_DEPS")
+load("//iree:build_defs.oss.bzl", "iree_cmake_extra_content")
package(
default_visibility = ["//visibility:public"],
@@ -22,20 +22,12 @@
licenses = ["notice"], # Apache 2.0
)
-# --define=IREE_VK=native to use the native Vulkan drivers (and real hardware).
-config_setting(
- name = "native_vk",
- values = {
- "define": "IREE_VK=native",
- },
-)
-
-# --define=IREE_VK=swiftshader to use SwiftShader.
-config_setting(
- name = "swiftshader_vk",
- values = {
- "define": "IREE_VK=swiftshader",
- },
+iree_cmake_extra_content(
+ content = """
+if(NOT ${IREE_HAL_DRIVER_VULKAN})
+ return()
+endif()
+""",
)
cc_library(
@@ -44,360 +36,66 @@
hdrs = ["api.h"],
visibility = ["//visibility:public"],
deps = [
- ":dynamic_symbols",
- ":extensibility_util",
- ":vulkan_device",
- ":vulkan_driver",
+ ":utils",
+ ":vulkan",
"//iree/base:api",
"//iree/base:tracing",
"//iree/hal:api",
- "@iree_vulkan_headers//:vulkan_headers_no_prototypes",
],
)
cc_library(
- name = "debug_reporter",
- srcs = ["debug_reporter.cc"],
- hdrs = ["debug_reporter.h"],
- deps = [
- ":dynamic_symbols",
- ":status_util",
- "//iree/base:status",
- "//iree/base:tracing",
- "@iree_vulkan_headers//:vulkan_headers_no_prototypes",
+ name = "utils",
+ srcs = [
+ "debug_reporter.cc",
+ "dynamic_symbols.cc",
+ "extensibility_util.cc",
+ "renderdoc_capture_manager.cc",
+ "status_util.cc",
+ "timepoint_util.cc",
],
-)
-
-cc_library(
- name = "descriptor_pool_cache",
- srcs = ["descriptor_pool_cache.cc"],
- hdrs = ["descriptor_pool_cache.h"],
- deps = [
- ":dynamic_symbols",
- ":handle_util",
- ":status_util",
- "//iree/base:ref_ptr",
- "//iree/base:tracing",
- "@com_google_absl//absl/container:inlined_vector",
- ],
-)
-
-cc_library(
- name = "descriptor_set_arena",
- srcs = ["descriptor_set_arena.cc"],
- hdrs = ["descriptor_set_arena.h"],
- deps = [
- ":descriptor_pool_cache",
- ":pipeline_executable",
- ":status_util",
- ":vma_allocator",
- "//iree/base:alignment",
- "//iree/base:arena",
- "//iree/base:math",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:command_buffer",
- ],
-)
-
-cc_library(
- name = "direct_command_buffer",
- srcs = ["direct_command_buffer.cc"],
- hdrs = ["direct_command_buffer.h"],
- deps = [
- ":descriptor_pool_cache",
- ":descriptor_set_arena",
- ":dynamic_symbols",
- ":handle_util",
- ":native_descriptor_set",
- ":native_event",
- ":pipeline_executable",
- ":pipeline_executable_layout",
- ":status_util",
- ":vma_allocator",
- "//iree/base:math",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:command_buffer",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/synchronization",
- "@iree_vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
-
-cc_library(
- name = "direct_command_queue",
- srcs = ["direct_command_queue.cc"],
- hdrs = ["direct_command_queue.h"],
- deps = [
- ":direct_command_buffer",
- ":dynamic_symbols",
- ":handle_util",
- ":native_timeline_semaphore",
- ":status_util",
- "//iree/base:api",
- "//iree/base:arena",
- "//iree/base:memory",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:command_queue",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/synchronization",
- "@iree_vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
-
-cc_library(
- name = "dynamic_symbols",
- srcs = ["dynamic_symbols.cc"],
hdrs = [
+ "debug_reporter.h",
"dynamic_symbol_tables.h",
"dynamic_symbols.h",
+ "extensibility_util.h",
+ "handle_util.h",
+ "renderdoc_capture_manager.h",
+ "status_util.h",
+ "timepoint_util.h",
+ "vulkan_headers.h",
],
deps = [
+ "//iree/base:core_headers",
"//iree/base:dynamic_library",
- "//iree/base:ref_ptr",
- "//iree/base:status",
- "//iree/base:target_platform",
- "//iree/base:tracing",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/memory",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:span",
- "@iree_vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
-
-cc_test(
- name = "dynamic_symbols_test",
- srcs = ["dynamic_symbols_test.cc"],
- tags = ["driver=vulkan"],
- deps = [
- ":status_util",
- ":dynamic_symbols",
- "//iree/testing:gtest",
- "//iree/testing:gtest_main",
- ] + PLATFORM_VULKAN_TEST_DEPS,
-)
-
-cc_library(
- name = "emulated_timeline_semaphore",
- srcs = ["emulated_timeline_semaphore.cc"],
- hdrs = ["emulated_timeline_semaphore.h"],
- deps = [
- ":dynamic_symbols",
- ":handle_util",
- ":status_util",
- ":timepoint_util",
- "//iree/base:api",
"//iree/base:intrusive_list",
+ "//iree/base:logging",
"//iree/base:ref_ptr",
"//iree/base:status",
"//iree/base:time",
"//iree/base:tracing",
- "//iree/hal:semaphore",
+ "//iree/hal",
"@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
- "@com_google_absl//absl/utility",
- "@iree_vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
-
-cc_library(
- name = "extensibility_util",
- srcs = ["extensibility_util.cc"],
- hdrs = ["extensibility_util.h"],
- deps = [
- ":dynamic_symbols",
- ":status_util",
- "//iree/base:memory",
- "//iree/base:status",
- "//iree/base:tracing",
"@com_google_absl//absl/types:span",
- "@iree_vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
-
-cc_library(
- name = "handle_util",
- hdrs = ["handle_util.h"],
- deps = [
- ":dynamic_symbols",
- ":extensibility_util",
- "//iree/base:ref_ptr",
- "@com_google_absl//absl/synchronization",
- "@com_google_absl//absl/utility",
- "@iree_vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
-
-cc_library(
- name = "native_event",
- srcs = ["native_event.cc"],
- hdrs = ["native_event.h"],
- deps = [
- ":handle_util",
- "//iree/hal:event",
- "@iree_vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
-
-cc_library(
- name = "native_descriptor_set",
- srcs = ["native_descriptor_set.cc"],
- hdrs = ["native_descriptor_set.h"],
- deps = [
- ":handle_util",
- "//iree/hal:descriptor_set",
- "@iree_vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
-
-cc_library(
- name = "native_timeline_semaphore",
- srcs = ["native_timeline_semaphore.cc"],
- hdrs = ["native_timeline_semaphore.h"],
- deps = [
- ":handle_util",
- ":status_util",
- "//iree/base:tracing",
- "//iree/hal:semaphore",
- "@com_google_absl//absl/synchronization",
- "@iree_vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
-
-cc_library(
- name = "pipeline_cache",
- srcs = ["pipeline_cache.cc"],
- hdrs = ["pipeline_cache.h"],
- deps = [
- ":handle_util",
- ":pipeline_executable",
- ":status_util",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:executable",
- "//iree/hal:executable_cache",
- "//iree/hal:executable_format",
- "//iree/schemas:spirv_executable_def_cc_fbs",
- "@com_github_google_flatbuffers//:flatbuffers",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/synchronization",
- "@iree_vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
-
-cc_library(
- name = "pipeline_executable",
- srcs = ["pipeline_executable.cc"],
- hdrs = ["pipeline_executable.h"],
- deps = [
- ":handle_util",
- ":native_descriptor_set",
- ":pipeline_executable_layout",
- ":status_util",
- "//iree/base:memory",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:executable",
- "//iree/hal:executable_cache",
- "//iree/hal:executable_layout",
- "//iree/hal:executable_spec",
- "//iree/schemas:spirv_executable_def_cc_fbs",
- "@com_google_absl//absl/container:inlined_vector",
- "@iree_vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
-
-cc_library(
- name = "pipeline_executable_layout",
- srcs = ["pipeline_executable_layout.cc"],
- hdrs = ["pipeline_executable_layout.h"],
- deps = [
- ":handle_util",
- "//iree/hal:descriptor_set_layout",
- "//iree/hal:executable_layout",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/types:span",
- "@iree_vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
-
-cc_library(
- name = "renderdoc_capture_manager",
- srcs = ["renderdoc_capture_manager.cc"],
- hdrs = ["renderdoc_capture_manager.h"],
- deps = [
- "//iree/base:dynamic_library",
- "//iree/base:logging",
- "//iree/base:status",
- "//iree/base:target_platform",
- "//iree/base:tracing",
- "//iree/hal:debug_capture_manager",
- "@com_google_absl//absl/types:span",
+ "@iree_vulkan_headers//:vulkan_headers",
"@renderdoc_api//:renderdoc_app",
],
)
-cc_library(
- name = "serializing_command_queue",
- srcs = ["serializing_command_queue.cc"],
- hdrs = ["serializing_command_queue.h"],
- deps = [
- ":direct_command_buffer",
- ":dynamic_symbols",
- ":emulated_timeline_semaphore",
- ":handle_util",
- ":status_util",
- ":timepoint_util",
- "//iree/base:api",
- "//iree/base:intrusive_list",
- "//iree/base:memory",
- "//iree/base:ref_ptr",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:command_buffer",
- "//iree/hal:command_queue",
- "//iree/hal:semaphore",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/synchronization",
- "@com_google_absl//absl/types:span",
- ],
-)
-
-cc_library(
- name = "status_util",
- srcs = ["status_util.cc"],
- hdrs = ["status_util.h"],
- deps = [
- "//iree/base:status",
- "@iree_vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
-
-cc_library(
- name = "timepoint_util",
- srcs = ["timepoint_util.cc"],
- hdrs = ["timepoint_util.h"],
- deps = [
- ":dynamic_symbols",
- ":handle_util",
- ":status_util",
- "//iree/base:intrusive_list",
- "//iree/base:ref_ptr",
- "//iree/base:status",
- "//iree/base:time",
- "//iree/base:tracing",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/synchronization",
- "@com_google_absl//absl/utility",
- "@iree_vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
+# TODO(suderman): Re-enable once testing solution is found.
+# cc_test(
+# name = "dynamic_symbols_test",
+# srcs = ["dynamic_symbols_test.cc"],
+# tags = ["driver=vulkan"],
+# deps = [
+# ":utils",
+# "//iree/testing:gtest",
+# "//iree/testing:gtest_main",
+# ],
+# )
cc_library(
name = "vma_allocator",
@@ -411,106 +109,74 @@
"vma_allocator.h",
"vma_buffer.h",
],
- copts = [
- # Only needed in the implementation cc and not by external users.
- "-DVMA_STATIC_VULKAN_FUNCTIONS=0",
- ] + select({
- "//iree:iree_is_msvc": [],
- "//conditions:default": [
- "-Wno-thread-safety-attributes", # External code.
- ],
- }),
deps = [
- ":dynamic_symbols",
- ":handle_util",
- ":status_util",
+ ":utils",
+ "//iree/base:core_headers",
"//iree/base:logging",
"//iree/base:status",
"//iree/base:tracing",
- "//iree/hal:allocator",
- "//iree/hal:buffer",
+ "//iree/hal",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/synchronization",
- "@iree_vulkan_headers//:vulkan_headers_no_prototypes",
"@vulkan_memory_allocator//:impl_header_only",
],
)
cc_library(
- name = "vulkan_device",
- srcs = ["vulkan_device.cc"],
- hdrs = ["vulkan_device.h"],
+ name = "vulkan",
+ srcs = [
+ "descriptor_pool_cache.cc",
+ "descriptor_set_arena.cc",
+ "direct_command_buffer.cc",
+ "direct_command_queue.cc",
+ "emulated_timeline_semaphore.cc",
+ "native_descriptor_set.cc",
+ "native_event.cc",
+ "native_timeline_semaphore.cc",
+ "pipeline_cache.cc",
+ "pipeline_executable.cc",
+ "pipeline_executable_layout.cc",
+ "serializing_command_queue.cc",
+ "vulkan_device.cc",
+ "vulkan_driver.cc",
+ ],
+ hdrs = [
+ "descriptor_pool_cache.h",
+ "descriptor_set_arena.h",
+ "direct_command_buffer.h",
+ "direct_command_queue.h",
+ "emulated_timeline_semaphore.h",
+ "native_descriptor_set.h",
+ "native_event.h",
+ "native_timeline_semaphore.h",
+ "pipeline_cache.h",
+ "pipeline_executable.h",
+ "pipeline_executable_layout.h",
+ "serializing_command_queue.h",
+ "vulkan_device.h",
+ "vulkan_driver.h",
+ ],
deps = [
- ":descriptor_pool_cache",
- ":direct_command_buffer",
- ":direct_command_queue",
- ":dynamic_symbols",
- ":emulated_timeline_semaphore",
- ":extensibility_util",
- ":handle_util",
- ":native_descriptor_set",
- ":native_event",
- ":native_timeline_semaphore",
- ":pipeline_cache",
- ":pipeline_executable_layout",
- ":serializing_command_queue",
- ":status_util",
+ ":utils",
":vma_allocator",
- "//iree/base:math",
- "//iree/base:memory",
+ "//iree/base:api",
+ "//iree/base:arena",
+ "//iree/base:core_headers",
+ "//iree/base:flatcc",
+ "//iree/base:intrusive_list",
+ "//iree/base:ref_ptr",
"//iree/base:status",
"//iree/base:time",
"//iree/base:tracing",
- "//iree/hal:allocator",
+ "//iree/hal",
"//iree/hal:command_buffer_validation",
- "//iree/hal:command_queue",
- "//iree/hal:debug_capture_manager",
- "//iree/hal:device",
- "//iree/hal:driver",
- "//iree/hal:semaphore",
+ "//iree/schemas:spirv_executable_def_c_fbs",
+ "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
- "@iree_vulkan_headers//:vulkan_headers_no_prototypes",
],
)
-
-cc_library(
- name = "vulkan_driver",
- srcs = ["vulkan_driver.cc"],
- hdrs = ["vulkan_driver.h"],
- deps = [
- ":debug_reporter",
- ":dynamic_symbols",
- ":extensibility_util",
- ":renderdoc_capture_manager",
- ":status_util",
- ":vulkan_device",
- "//iree/base:memory",
- "//iree/base:status",
- "//iree/base:target_platform",
- "//iree/base:tracing",
- "//iree/hal:device_info",
- "//iree/hal:driver",
- "@com_google_absl//absl/container:inlined_vector",
- "@iree_vulkan_headers//:vulkan_headers_no_prototypes",
- ],
-)
-
-cc_library(
- name = "vulkan_driver_module",
- srcs = ["vulkan_driver_module.cc"],
- deps = [
- ":dynamic_symbols",
- ":vulkan_driver",
- "//iree/base:init",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/hal:driver_registry",
- "@com_google_absl//absl/flags:flag",
- ],
- alwayslink = 1,
-)
diff --git a/iree/hal/vulkan/CMakeLists.txt b/iree/hal/vulkan/CMakeLists.txt
index 19e3685..3bdeeac 100644
--- a/iree/hal/vulkan/CMakeLists.txt
+++ b/iree/hal/vulkan/CMakeLists.txt
@@ -16,9 +16,7 @@
return()
endif()
-set(VMA_SRC_ROOT
- "${IREE_ROOT_DIR}/third_party/vulkan_memory_allocator/src/"
-)
+iree_add_all_subdirs()
iree_cc_library(
NAME
@@ -27,159 +25,52 @@
"api.h"
SRCS
"api.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
DEPS
+ ::utils
+ ::vulkan
iree::base::api
iree::base::tracing
iree::hal::api
- iree::hal::vulkan::dynamic_symbols
- iree::hal::vulkan::extensibility_util
- iree::hal::vulkan::vulkan_device
- iree::hal::vulkan::vulkan_driver
- Vulkan::Headers
PUBLIC
)
iree_cc_library(
NAME
- debug_reporter
+ utils
HDRS
"debug_reporter.h"
- SRCS
- "debug_reporter.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
- DEPS
- iree::base::status
- iree::base::tracing
- iree::hal::vulkan::dynamic_symbols
- iree::hal::vulkan::status_util
- Vulkan::Headers
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- descriptor_pool_cache
- HDRS
- "descriptor_pool_cache.h"
- SRCS
- "descriptor_pool_cache.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
- DEPS
- absl::inlined_vector
- iree::base::ref_ptr
- iree::base::tracing
- iree::hal::vulkan::dynamic_symbols
- iree::hal::vulkan::handle_util
- iree::hal::vulkan::status_util
- Vulkan::Headers
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- descriptor_set_arena
- HDRS
- "descriptor_set_arena.h"
- SRCS
- "descriptor_set_arena.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
- DEPS
- iree::base::alignment
- iree::base::arena
- iree::base::math
- iree::base::status
- iree::base::tracing
- iree::hal::command_buffer
- iree::hal::vulkan::descriptor_pool_cache
- iree::hal::vulkan::pipeline_executable
- iree::hal::vulkan::status_util
- iree::hal::vulkan::vma_allocator
- Vulkan::Headers
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- direct_command_buffer
- HDRS
- "direct_command_buffer.h"
- SRCS
- "direct_command_buffer.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
- DEPS
- absl::core_headers
- absl::inlined_vector
- absl::synchronization
- iree::base::math
- iree::base::status
- iree::base::tracing
- iree::hal::command_buffer
- iree::hal::vulkan::descriptor_pool_cache
- iree::hal::vulkan::descriptor_set_arena
- iree::hal::vulkan::dynamic_symbols
- iree::hal::vulkan::handle_util
- iree::hal::vulkan::native_event
- iree::hal::vulkan::pipeline_executable
- iree::hal::vulkan::status_util
- iree::hal::vulkan::vma_allocator
- Vulkan::Headers
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- direct_command_queue
- HDRS
- "direct_command_queue.h"
- SRCS
- "direct_command_queue.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
- DEPS
- absl::core_headers
- absl::synchronization
- iree::base::arena
- iree::base::memory
- iree::base::status
- iree::base::time
- iree::base::tracing
- iree::hal::command_queue
- iree::hal::vulkan::direct_command_buffer
- iree::hal::vulkan::dynamic_symbols
- iree::hal::vulkan::handle_util
- iree::hal::vulkan::native_timeline_semaphore
- iree::hal::vulkan::status_util
- Vulkan::Headers
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- dynamic_symbols
- HDRS
"dynamic_symbol_tables.h"
"dynamic_symbols.h"
+ "extensibility_util.h"
+ "handle_util.h"
+ "renderdoc_capture_manager.h"
+ "status_util.h"
+ "timepoint_util.h"
+ "vulkan_headers.h"
SRCS
+ "debug_reporter.cc"
"dynamic_symbols.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
+ "extensibility_util.cc"
+ "renderdoc_capture_manager.cc"
+ "status_util.cc"
+ "timepoint_util.cc"
DEPS
+ Vulkan::Headers
absl::core_headers
absl::memory
absl::span
absl::strings
+ absl::synchronization
+ iree::base::core_headers
iree::base::dynamic_library
+ iree::base::intrusive_list
+ iree::base::logging
iree::base::ref_ptr
iree::base::status
- iree::base::target_platform
+ iree::base::time
iree::base::tracing
- Vulkan::Headers
+ iree::hal
+ renderdoc_api::renderdoc_app
PUBLIC
)
@@ -188,278 +79,12 @@
dynamic_symbols_test
SRCS
"dynamic_symbols_test.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
DEPS
- iree::hal::vulkan::status_util
- iree::hal::vulkan::dynamic_symbols
+ ::utils
iree::testing::gtest
iree::testing::gtest_main
LABELS
- "nokokoro"
-)
-
-iree_cc_library(
- NAME
- emulated_timeline_semaphore
- HDRS
- "emulated_timeline_semaphore.h"
- SRCS
- "emulated_timeline_semaphore.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
- DEPS
- ::handle_util
- ::status_util
- ::timepoint_util
- absl::inlined_vector
- absl::synchronization
- iree::base::intrusive_list
- iree::base::status
- iree::base::time
- iree::base::tracing
- iree::hal::semaphore
- Vulkan::Headers
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- extensibility_util
- HDRS
- "extensibility_util.h"
- SRCS
- "extensibility_util.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
- DEPS
- absl::span
- iree::base::memory
- iree::base::status
- iree::base::tracing
- iree::hal::vulkan::dynamic_symbols
- iree::hal::vulkan::status_util
- Vulkan::Headers
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- handle_util
- HDRS
- "handle_util.h"
- COPTS
- "-DVK_NO_PROTOTYPES"
- DEPS
- absl::inlined_vector
- absl::synchronization
- absl::utility
- iree::base::ref_ptr
- iree::hal::command_queue
- iree::hal::vulkan::dynamic_symbols
- iree::hal::vulkan::extensibility_util
- Vulkan::Headers
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- native_descriptor_set
- HDRS
- native_descriptor_set.h
- SRCS
- native_descriptor_set.cc
- COPTS
- "-DVK_NO_PROTOTYPES"
- DEPS
- ::handle_util
- iree::hal::descriptor_set
- Vulkan::Headers
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- native_event
- HDRS
- "native_event.h"
- SRCS
- "native_event.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
- DEPS
- ::handle_util
- iree::hal::event
- Vulkan::Headers
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- native_timeline_semaphore
- HDRS
- "native_timeline_semaphore.h"
- SRCS
- "native_timeline_semaphore.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
- DEPS
- ::handle_util
- ::status_util
- iree::base::tracing
- iree::hal::semaphore
- Vulkan::Headers
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- pipeline_cache
- HDRS
- "pipeline_cache.h"
- SRCS
- "pipeline_cache.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
- DEPS
- absl::inlined_vector
- absl::synchronization
- flatbuffers
- iree::base::status
- iree::base::tracing
- iree::hal::executable
- iree::hal::executable_cache
- iree::hal::executable_format
- iree::hal::vulkan::handle_util
- iree::hal::vulkan::pipeline_executable
- iree::hal::vulkan::status_util
- iree::schemas::spirv_executable_def_cc_fbs
- Vulkan::Headers
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- pipeline_executable
- HDRS
- "pipeline_executable.h"
- SRCS
- "pipeline_executable.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
- DEPS
- ::handle_util
- ::native_descriptor_set
- ::pipeline_executable_layout
- ::status_util
- absl::inlined_vector
- iree::base::memory
- iree::base::status
- iree::base::tracing
- iree::hal::executable
- iree::hal::executable_cache
- iree::hal::executable_layout
- iree::hal::executable_spec
- iree::schemas::spirv_executable_def_cc_fbs
- Vulkan::Headers
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- pipeline_executable_layout
- HDRS
- "pipeline_executable_layout.h"
- SRCS
- "pipeline_executable_layout.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
- DEPS
- absl::inlined_vector
- absl::span
- iree::hal::descriptor_set_layout
- iree::hal::executable_layout
- iree::hal::vulkan::handle_util
- Vulkan::Headers
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- renderdoc_capture_manager
- HDRS
- "renderdoc_capture_manager.h"
- SRCS
- "renderdoc_capture_manager.cc"
- DEPS
- iree::base::dynamic_library
- iree::base::logging
- iree::base::status
- iree::base::target_platform
- iree::base::tracing
- iree::hal::debug_capture_manager
- renderdoc_api::renderdoc_app
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- serializing_command_queue
- HDRS
- "serializing_command_queue.h"
- SRCS
- "serializing_command_queue.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
- DEPS
- ::direct_command_buffer
- ::emulated_timeline_semaphore
- ::handle_util
- ::status_util
- ::timepoint_util
- absl::inlined_vector
- absl::synchronization
- iree::base::status
- iree::base::tracing
- iree::hal::command_queue
- Vulkan::Headers
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- status_util
- HDRS
- "status_util.h"
- SRCS
- "status_util.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
- DEPS
- iree::base::status
- Vulkan::Headers
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- timepoint_util
- HDRS
- "timepoint_util.h"
- SRCS
- "timepoint_util.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
- DEPS
- ::handle_util
- absl::synchronization
- iree::base::intrusive_list
- iree::base::ref_ptr
- iree::base::status
- iree::base::time
- iree::base::tracing
- Vulkan::Headers
- PUBLIC
+ "driver=vulkan"
)
iree_cc_library(
@@ -473,120 +98,73 @@
"internal_vk_mem_alloc.h"
"vma_allocator.cc"
"vma_buffer.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
- # Only needed in the implementation cc and not by external users.
- "-DVMA_STATIC_VULKAN_FUNCTIONS=0"
- # Silence some warnings.
- "-Wno-nullability-completeness"
- # Note: IREE_DEFAULT_COPTS sets -Wthread-safety-analysis, so we can't
- # disable here without switching off of iree_cc_library or refactoring
- # "-Wno-thread-safety-analysis"
- INCLUDES
- ${VMA_SRC_ROOT}
DEPS
+ ::utils
absl::flat_hash_map
absl::memory
absl::synchronization
+ iree::base::core_headers
iree::base::logging
iree::base::status
iree::base::tracing
- iree::hal::allocator
- iree::hal::buffer
- iree::hal::vulkan::dynamic_symbols
- iree::hal::vulkan::handle_util
- iree::hal::vulkan::status_util
- Vulkan::Headers
+ iree::hal
+ vulkan_memory_allocator
PUBLIC
)
iree_cc_library(
NAME
- vulkan_device
+ vulkan
HDRS
+ "descriptor_pool_cache.h"
+ "descriptor_set_arena.h"
+ "direct_command_buffer.h"
+ "direct_command_queue.h"
+ "emulated_timeline_semaphore.h"
+ "native_descriptor_set.h"
+ "native_event.h"
+ "native_timeline_semaphore.h"
+ "pipeline_cache.h"
+ "pipeline_executable.h"
+ "pipeline_executable_layout.h"
+ "serializing_command_queue.h"
"vulkan_device.h"
+ "vulkan_driver.h"
SRCS
+ "descriptor_pool_cache.cc"
+ "descriptor_set_arena.cc"
+ "direct_command_buffer.cc"
+ "direct_command_queue.cc"
+ "emulated_timeline_semaphore.cc"
+ "native_descriptor_set.cc"
+ "native_event.cc"
+ "native_timeline_semaphore.cc"
+ "pipeline_cache.cc"
+ "pipeline_executable.cc"
+ "pipeline_executable_layout.cc"
+ "serializing_command_queue.cc"
"vulkan_device.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
+ "vulkan_driver.cc"
DEPS
- ::descriptor_pool_cache
- ::direct_command_buffer
- ::direct_command_queue
- ::dynamic_symbols
- ::emulated_timeline_semaphore
- ::extensibility_util
- ::handle_util
- ::native_descriptor_set
- ::native_timeline_semaphore
- ::pipeline_cache
- ::pipeline_executable_layout
- ::serializing_command_queue
- ::status_util
+ ::utils
::vma_allocator
+ absl::core_headers
absl::inlined_vector
absl::memory
+ absl::span
absl::strings
absl::synchronization
- absl::span
- iree::base::math
- iree::base::memory
+ iree::base::api
+ iree::base::arena
+ iree::base::core_headers
+ iree::base::flatcc
+ iree::base::intrusive_list
+ iree::base::ref_ptr
iree::base::status
iree::base::time
iree::base::tracing
- iree::hal::allocator
+ iree::hal
iree::hal::command_buffer_validation
- iree::hal::command_queue
- iree::hal::debug_capture_manager
- iree::hal::device
- iree::hal::driver
- iree::hal::semaphore
- Vulkan::Headers
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- vulkan_driver
- HDRS
- "vulkan_driver.h"
- SRCS
- "vulkan_driver.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
- DEPS
- absl::inlined_vector
- iree::base::memory
- iree::base::status
- iree::base::target_platform
- iree::base::tracing
- iree::hal::device_info
- iree::hal::driver
- iree::hal::vulkan::debug_reporter
- iree::hal::vulkan::dynamic_symbols
- iree::hal::vulkan::extensibility_util
- iree::hal::vulkan::renderdoc_capture_manager
- iree::hal::vulkan::status_util
- iree::hal::vulkan::vulkan_device
- Vulkan::Headers
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- vulkan_driver_module
- SRCS
- "vulkan_driver_module.cc"
- COPTS
- "-DVK_NO_PROTOTYPES"
- DEPS
- absl::flags
- iree::base::init
- iree::base::status
- iree::base::tracing
- iree::hal::driver_registry
- iree::hal::vulkan::dynamic_symbols
- iree::hal::vulkan::vulkan_driver
- ALWAYSLINK
+ iree::schemas::spirv_executable_def_c_fbs
PUBLIC
)
diff --git a/iree/hal/vulkan/api.h b/iree/hal/vulkan/api.h
index c3abe0c..6ad56a4 100644
--- a/iree/hal/vulkan/api.h
+++ b/iree/hal/vulkan/api.h
@@ -17,7 +17,7 @@
#ifndef IREE_HAL_VULKAN_API_H_
#define IREE_HAL_VULKAN_API_H_
-#include <vulkan/vulkan.h>
+#include "iree/hal/vulkan/vulkan_headers.h"
#include "iree/base/api.h"
#include "iree/hal/api.h"
diff --git a/iree/hal/vulkan/debug_reporter.h b/iree/hal/vulkan/debug_reporter.h
index ffc594b..edcdbb3 100644
--- a/iree/hal/vulkan/debug_reporter.h
+++ b/iree/hal/vulkan/debug_reporter.h
@@ -15,7 +15,7 @@
#ifndef IREE_HAL_VULKAN_DEBUG_REPORTER_H_
#define IREE_HAL_VULKAN_DEBUG_REPORTER_H_
-#include <vulkan/vulkan.h>
+#include "iree/hal/vulkan/vulkan_headers.h"
#include "iree/base/status.h"
#include "iree/hal/vulkan/dynamic_symbols.h"
diff --git a/iree/hal/vulkan/descriptor_pool_cache.cc b/iree/hal/vulkan/descriptor_pool_cache.cc
index 32723f5..6feea16 100644
--- a/iree/hal/vulkan/descriptor_pool_cache.cc
+++ b/iree/hal/vulkan/descriptor_pool_cache.cc
@@ -65,7 +65,7 @@
std::array<VkDescriptorPoolSize, 1> pool_sizes;
pool_sizes[0].type = descriptor_type;
pool_sizes[0].descriptorCount = max_descriptor_count;
- create_info.poolSizeCount = pool_sizes.size();
+ create_info.poolSizeCount = static_cast<uint32_t>(pool_sizes.size());
create_info.pPoolSizes = pool_sizes.data();
DescriptorPool descriptor_pool;
diff --git a/iree/hal/vulkan/descriptor_set_arena.cc b/iree/hal/vulkan/descriptor_set_arena.cc
index 97b8b86..238384f 100644
--- a/iree/hal/vulkan/descriptor_set_arena.cc
+++ b/iree/hal/vulkan/descriptor_set_arena.cc
@@ -131,7 +131,7 @@
// Pick a bucket based on the number of descriptors required.
// NOTE: right now we are 1:1 with bindings.
- int required_descriptor_count = bindings.size() * 1;
+ int required_descriptor_count = static_cast<int>(bindings.size() * 1);
int max_descriptor_count =
std::max(8, RoundUpToNearestPow2(required_descriptor_count));
int bucket = TrailingZeros(max_descriptor_count >> 3);
@@ -194,7 +194,8 @@
// descriptor sets we will need and what buffers they will point to (without
// doing just as much work as actually recording the buffer to try to find
// out).
- syms().vkUpdateDescriptorSets(*logical_device_, write_infos.size(),
+ syms().vkUpdateDescriptorSets(*logical_device_,
+ static_cast<uint32_t>(write_infos.size()),
write_infos.data(), 0, nullptr);
// Bind the descriptor set.
@@ -219,7 +220,8 @@
// command buffer and prevent the need for our own pooling mechanisms.
syms().vkCmdPushDescriptorSetKHR(
command_buffer, VK_PIPELINE_BIND_POINT_COMPUTE,
- executable_layout->handle(), set, write_infos.size(), write_infos.data());
+ executable_layout->handle(), set,
+ static_cast<uint32_t>(write_infos.size()), write_infos.data());
return OkStatus();
}
diff --git a/iree/hal/vulkan/direct_command_buffer.cc b/iree/hal/vulkan/direct_command_buffer.cc
index 0222c73..b5a9aab 100644
--- a/iree/hal/vulkan/direct_command_buffer.cc
+++ b/iree/hal/vulkan/direct_command_buffer.cc
@@ -230,8 +230,10 @@
syms()->vkCmdPipelineBarrier(
command_buffer_, ConvertPipelineStageFlags(source_stage_mask),
ConvertPipelineStageFlags(target_stage_mask), /*dependencyFlags=*/0,
- memory_barrier_infos.size(), memory_barrier_infos.data(),
- buffer_barrier_infos.size(), buffer_barrier_infos.data(), 0, nullptr);
+ static_cast<uint32_t>(memory_barrier_infos.size()),
+ memory_barrier_infos.data(),
+ static_cast<uint32_t>(buffer_barrier_infos.size()),
+ buffer_barrier_infos.data(), 0, nullptr);
return OkStatus();
}
@@ -296,12 +298,14 @@
info.size = buffer_barrier.length;
}
- syms()->vkCmdWaitEvents(
- command_buffer_, event_handles.size(), event_handles.data(),
- ConvertPipelineStageFlags(source_stage_mask),
- ConvertPipelineStageFlags(target_stage_mask), memory_barrier_infos.size(),
- memory_barrier_infos.data(), buffer_barrier_infos.size(),
- buffer_barrier_infos.data(), 0, nullptr);
+ syms()->vkCmdWaitEvents(command_buffer_, event_handles.size(),
+ event_handles.data(),
+ ConvertPipelineStageFlags(source_stage_mask),
+ ConvertPipelineStageFlags(target_stage_mask),
+ static_cast<uint32_t>(memory_barrier_infos.size()),
+ memory_barrier_infos.data(),
+ static_cast<uint32_t>(buffer_barrier_infos.size()),
+ buffer_barrier_infos.data(), 0, nullptr);
return OkStatus();
}
@@ -385,8 +389,9 @@
syms()->vkCmdPushConstants(
command_buffer_, device_executable_layout->handle(),
- VK_SHADER_STAGE_COMPUTE_BIT, offset * sizeof(uint32_t),
- values.size() * sizeof(uint32_t), values.data());
+ VK_SHADER_STAGE_COMPUTE_BIT,
+ static_cast<uint32_t>(offset * sizeof(uint32_t)),
+ static_cast<uint32_t>(values.size() * sizeof(uint32_t)), values.data());
return OkStatus();
}
@@ -424,8 +429,9 @@
device_descriptor_set->handle()};
syms()->vkCmdBindDescriptorSets(
command_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE,
- device_executable_layout->handle(), set, descriptor_sets.size(),
- descriptor_sets.data(), dynamic_offsets_i32.size(),
+ device_executable_layout->handle(), set,
+ static_cast<uint32_t>(descriptor_sets.size()), descriptor_sets.data(),
+ static_cast<uint32_t>(dynamic_offsets_i32.size()),
dynamic_offsets_i32.data());
return OkStatus();
diff --git a/iree/hal/vulkan/direct_command_buffer.h b/iree/hal/vulkan/direct_command_buffer.h
index 20f637d..8808c6e 100644
--- a/iree/hal/vulkan/direct_command_buffer.h
+++ b/iree/hal/vulkan/direct_command_buffer.h
@@ -15,7 +15,7 @@
#ifndef IREE_HAL_VULKAN_DIRECT_COMMAND_BUFFER_H_
#define IREE_HAL_VULKAN_DIRECT_COMMAND_BUFFER_H_
-#include <vulkan/vulkan.h>
+#include "iree/hal/vulkan/vulkan_headers.h"
#include "iree/hal/command_buffer.h"
#include "iree/hal/vulkan/descriptor_pool_cache.h"
diff --git a/iree/hal/vulkan/direct_command_queue.cc b/iree/hal/vulkan/direct_command_queue.cc
index 7523b09..71e8e33 100644
--- a/iree/hal/vulkan/direct_command_queue.cc
+++ b/iree/hal/vulkan/direct_command_queue.cc
@@ -87,21 +87,25 @@
submit_info->sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
submit_info->pNext = timeline_submit_info;
- submit_info->waitSemaphoreCount = wait_semaphore_handles.size();
+ submit_info->waitSemaphoreCount =
+ static_cast<uint32_t>(wait_semaphore_handles.size());
submit_info->pWaitSemaphores = wait_semaphore_handles.data();
submit_info->pWaitDstStageMask = wait_dst_stage_masks.data();
- submit_info->commandBufferCount = command_buffer_handles.size();
+ submit_info->commandBufferCount =
+ static_cast<uint32_t>(command_buffer_handles.size());
submit_info->pCommandBuffers = command_buffer_handles.data();
- submit_info->signalSemaphoreCount = signal_semaphore_handles.size();
+ submit_info->signalSemaphoreCount =
+ static_cast<uint32_t>(signal_semaphore_handles.size());
submit_info->pSignalSemaphores = signal_semaphore_handles.data();
timeline_submit_info->sType =
VK_STRUCTURE_TYPE_TIMELINE_SEMAPHORE_SUBMIT_INFO;
timeline_submit_info->pNext = nullptr;
- timeline_submit_info->waitSemaphoreValueCount = wait_semaphore_values.size();
+ timeline_submit_info->waitSemaphoreValueCount =
+ static_cast<uint32_t>(wait_semaphore_values.size());
timeline_submit_info->pWaitSemaphoreValues = wait_semaphore_values.data();
timeline_submit_info->signalSemaphoreValueCount =
- signal_semaphore_values.size();
+ static_cast<uint32_t>(signal_semaphore_values.size());
timeline_submit_info->pSignalSemaphoreValues = signal_semaphore_values.data();
return OkStatus();
@@ -125,7 +129,8 @@
{
absl::MutexLock lock(&queue_mutex_);
VK_RETURN_IF_ERROR(syms()->vkQueueSubmit(
- queue_, submit_infos.size(), submit_infos.data(), VK_NULL_HANDLE));
+ queue_, static_cast<uint32_t>(submit_infos.size()), submit_infos.data(),
+ VK_NULL_HANDLE));
}
return OkStatus();
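Note on the casts above: the `static_cast<uint32_t>(...)` calls silence narrowing warnings, but they do not check that the container size fits in 32 bits. A minimal sketch of a checked variant follows. `CheckedCount`, `FakeCreateInfo`, and `MakeCreateInfo` are hypothetical names used only for illustration; they are not part of IREE or Vulkan.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical helper (not part of IREE): narrows a container size to the
// uint32_t count that Vulkan structs and calls expect. In debug builds it
// asserts that the value fits instead of truncating silently.
inline uint32_t CheckedCount(size_t size) {
  assert(size <= std::numeric_limits<uint32_t>::max());
  return static_cast<uint32_t>(size);
}

// Stand-in for a Vulkan create-info struct with a 32-bit count field.
struct FakeCreateInfo {
  uint32_t poolSizeCount = 0;
  const int* pPoolSizes = nullptr;
};

// Fills the count/pointer pair from a vector, as the diff does by hand.
inline FakeCreateInfo MakeCreateInfo(const std::vector<int>& pool_sizes) {
  FakeCreateInfo info;
  info.poolSizeCount = CheckedCount(pool_sizes.size());
  info.pPoolSizes = pool_sizes.data();
  return info;
}
```

Whether the extra assert is worth it depends on how large these vectors can realistically get; for small inlined vectors the plain cast used in the diff is probably sufficient.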
diff --git a/iree/hal/vulkan/direct_command_queue.h b/iree/hal/vulkan/direct_command_queue.h
index 1921bfa..b0d3611 100644
--- a/iree/hal/vulkan/direct_command_queue.h
+++ b/iree/hal/vulkan/direct_command_queue.h
@@ -15,7 +15,7 @@
#ifndef IREE_HAL_VULKAN_DIRECT_COMMAND_QUEUE_H_
#define IREE_HAL_VULKAN_DIRECT_COMMAND_QUEUE_H_
-#include <vulkan/vulkan.h>
+#include "iree/hal/vulkan/vulkan_headers.h"
#include <cstdint>
#include <string>
diff --git a/iree/hal/vulkan/dynamic_symbols.h b/iree/hal/vulkan/dynamic_symbols.h
index c7651d2..8e74875 100644
--- a/iree/hal/vulkan/dynamic_symbols.h
+++ b/iree/hal/vulkan/dynamic_symbols.h
@@ -15,7 +15,7 @@
#ifndef IREE_HAL_VULKAN_DYNAMIC_SYMBOLS_H_
#define IREE_HAL_VULKAN_DYNAMIC_SYMBOLS_H_
-#include <vulkan/vulkan.h>
+#include "iree/hal/vulkan/vulkan_headers.h"
#include <cstdint>
#include <functional>
diff --git a/iree/hal/vulkan/emulated_timeline_semaphore.cc b/iree/hal/vulkan/emulated_timeline_semaphore.cc
index fb34bcc..475aa33 100644
--- a/iree/hal/vulkan/emulated_timeline_semaphore.cc
+++ b/iree/hal/vulkan/emulated_timeline_semaphore.cc
@@ -16,7 +16,6 @@
#include "absl/container/inlined_vector.h"
#include "absl/synchronization/mutex.h"
-#include "absl/utility/utility.h"
#include "iree/base/time.h"
#include "iree/base/tracing.h"
#include "iree/hal/vulkan/dynamic_symbols.h"
diff --git a/iree/hal/vulkan/emulated_timeline_semaphore.h b/iree/hal/vulkan/emulated_timeline_semaphore.h
index 4280d33..ccf0029 100644
--- a/iree/hal/vulkan/emulated_timeline_semaphore.h
+++ b/iree/hal/vulkan/emulated_timeline_semaphore.h
@@ -15,7 +15,7 @@
#ifndef IREE_HAL_VULKAN_ENUMLATED_TIMELINE_SEMAPHORE_H_
#define IREE_HAL_VULKAN_ENUMLATED_TIMELINE_SEMAPHORE_H_
-#include <vulkan/vulkan.h>
+#include "iree/hal/vulkan/vulkan_headers.h"
#include <atomic>
#include <vector>
diff --git a/iree/hal/vulkan/extensibility_util.h b/iree/hal/vulkan/extensibility_util.h
index b39b2fc..5aa6998 100644
--- a/iree/hal/vulkan/extensibility_util.h
+++ b/iree/hal/vulkan/extensibility_util.h
@@ -17,7 +17,7 @@
#ifndef IREE_HAL_VULKAN_EXTENSIBILITY_UTIL_H_
#define IREE_HAL_VULKAN_EXTENSIBILITY_UTIL_H_
-#include <vulkan/vulkan.h>
+#include "iree/hal/vulkan/vulkan_headers.h"
#include <vector>
diff --git a/iree/hal/vulkan/handle_util.h b/iree/hal/vulkan/handle_util.h
index 147acaa..efe463af 100644
--- a/iree/hal/vulkan/handle_util.h
+++ b/iree/hal/vulkan/handle_util.h
@@ -24,11 +24,11 @@
#ifndef IREE_HAL_VULKAN_HANDLE_UTIL_H_
#define IREE_HAL_VULKAN_HANDLE_UTIL_H_
-#include <vulkan/vulkan.h>
+#include "iree/hal/vulkan/vulkan_headers.h"
#include "absl/synchronization/mutex.h"
-#include "absl/utility/utility.h"
#include "iree/base/ref_ptr.h"
+#include "iree/base/status.h"
#include "iree/hal/vulkan/dynamic_symbols.h"
#include "iree/hal/vulkan/extensibility_util.h"
@@ -50,7 +50,7 @@
VkDeviceHandle(const VkDeviceHandle&) = delete;
VkDeviceHandle& operator=(const VkDeviceHandle&) = delete;
VkDeviceHandle(VkDeviceHandle&& other) noexcept
- : value_(absl::exchange(other.value_,
+ : value_(iree::exchange(other.value_,
static_cast<VkDevice>(VK_NULL_HANDLE))),
syms_(std::move(other.syms_)),
enabled_extensions_(other.enabled_extensions_),
@@ -94,7 +94,7 @@
VkCommandPoolHandle& operator=(const VkCommandPoolHandle&) = delete;
VkCommandPoolHandle(VkCommandPoolHandle&& other) noexcept
: logical_device_(std::move(other.logical_device_)),
- value_(absl::exchange(other.value_,
+ value_(iree::exchange(other.value_,
static_cast<VkCommandPool>(VK_NULL_HANDLE))) {}
VkCommandPoolHandle& operator=(VkCommandPoolHandle&& other) {
std::swap(logical_device_, other.logical_device_);
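Note on the `exchange` change above: the move constructors take the handle from `other` and leave a null handle behind, so the moved-from wrapper does not destroy the resource a second time. A minimal sketch of the idiom follows, assuming `std::exchange` (C++14) as a stand-in for `iree::exchange`. The `Handle` class is hypothetical, and 0 plays the role of `VK_NULL_HANDLE`.

```cpp
#include <utility>

// Hypothetical move-only wrapper around an integer handle.
// 0 stands in for VK_NULL_HANDLE.
class Handle {
 public:
  explicit Handle(int value) : value_(value) {}
  Handle(const Handle&) = delete;
  Handle& operator=(const Handle&) = delete;

  // Takes ownership of other's value and resets other to the null value
  // in a single expression.
  Handle(Handle&& other) noexcept : value_(std::exchange(other.value_, 0)) {}

  int value() const { return value_; }

 private:
  int value_;
};
```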
diff --git a/iree/hal/vulkan/internal_vk_mem_alloc.cc b/iree/hal/vulkan/internal_vk_mem_alloc.cc
index 95747a7..3899f7f 100644
--- a/iree/hal/vulkan/internal_vk_mem_alloc.cc
+++ b/iree/hal/vulkan/internal_vk_mem_alloc.cc
@@ -68,9 +68,7 @@
};
#define VMA_RW_MUTEX AbslVmaRWMutex
-#define VMA_DYNAMIC_VULKAN_FUNCTIONS 0
-
#define VMA_IMPLEMENTATION
-#include "vk_mem_alloc.h"
+#include "iree/hal/vulkan/internal_vk_mem_alloc.h"
-#endif
+#endif // !VULKAN_MEMORY_ALLOCATOR_EXTERNAL_IMPL
diff --git a/iree/hal/vulkan/internal_vk_mem_alloc.h b/iree/hal/vulkan/internal_vk_mem_alloc.h
index 591ce7e..4541e9f 100644
--- a/iree/hal/vulkan/internal_vk_mem_alloc.h
+++ b/iree/hal/vulkan/internal_vk_mem_alloc.h
@@ -15,6 +15,17 @@
#ifndef IREE_HAL_VULKAN_INTERNAL_VK_MEM_ALLOC_H_
#define IREE_HAL_VULKAN_INTERNAL_VK_MEM_ALLOC_H_
-#include "vk_mem_alloc.h"
+// NOTE: ensure our vulkan headers are used (as we define VK_NO_PROTOTYPES).
+#include "iree/hal/vulkan/vulkan_headers.h"
+// Force all Vulkan calls to go through an indirect pVulkanFunctions interface.
+// https://gpuopen-librariesandsdks.github.io/VulkanMemoryAllocator/html/configuration.html
+#define VMA_STATIC_VULKAN_FUNCTIONS 0
+
+// Prevent VMA from querying for dynamic functions we may not have provided.
+// We want to be able to print nice errors or decide whether something is ok
+// to be omitted and not have VMA poking around where it shouldn't.
+#define VMA_DYNAMIC_VULKAN_FUNCTIONS 0
+
+#include "vk_mem_alloc.h"
#endif // IREE_HAL_VULKAN_INTERNAL_VK_MEM_ALLOC_H_
diff --git a/iree/hal/vulkan/native_descriptor_set.h b/iree/hal/vulkan/native_descriptor_set.h
index 9bcb6f1..b251c3f 100644
--- a/iree/hal/vulkan/native_descriptor_set.h
+++ b/iree/hal/vulkan/native_descriptor_set.h
@@ -15,7 +15,7 @@
#ifndef IREE_HAL_VULKAN_NATIVE_DESCRIPTOR_SET_H_
#define IREE_HAL_VULKAN_NATIVE_DESCRIPTOR_SET_H_
-#include <vulkan/vulkan.h>
+#include "iree/hal/vulkan/vulkan_headers.h"
#include "iree/hal/descriptor_set.h"
#include "iree/hal/vulkan/handle_util.h"
diff --git a/iree/hal/vulkan/native_event.h b/iree/hal/vulkan/native_event.h
index 691ef6d..5db3f5d 100644
--- a/iree/hal/vulkan/native_event.h
+++ b/iree/hal/vulkan/native_event.h
@@ -15,7 +15,7 @@
#ifndef IREE_HAL_VULKAN_NATIVE_EVENT_H_
#define IREE_HAL_VULKAN_NATIVE_EVENT_H_
-#include <vulkan/vulkan.h>
+#include "iree/hal/vulkan/vulkan_headers.h"
#include "iree/hal/event.h"
#include "iree/hal/vulkan/handle_util.h"
diff --git a/iree/hal/vulkan/native_timeline_semaphore.h b/iree/hal/vulkan/native_timeline_semaphore.h
index a5d3a93..0a22510 100644
--- a/iree/hal/vulkan/native_timeline_semaphore.h
+++ b/iree/hal/vulkan/native_timeline_semaphore.h
@@ -15,7 +15,7 @@
#ifndef IREE_HAL_VULKAN_NATIVE_TIMELINE_SEMAPHORE_H_
#define IREE_HAL_VULKAN_NATIVE_TIMELINE_SEMAPHORE_H_
-#include <vulkan/vulkan.h>
+#include "iree/hal/vulkan/vulkan_headers.h"
#include "absl/synchronization/mutex.h"
#include "iree/hal/semaphore.h"
diff --git a/iree/hal/vulkan/pipeline_cache.cc b/iree/hal/vulkan/pipeline_cache.cc
index dd59678..5404cf9 100644
--- a/iree/hal/vulkan/pipeline_cache.cc
+++ b/iree/hal/vulkan/pipeline_cache.cc
@@ -14,13 +14,10 @@
#include "iree/hal/vulkan/pipeline_cache.h"
-#include "absl/synchronization/mutex.h"
-#include "flatbuffers/flatbuffers.h"
#include "iree/base/status.h"
#include "iree/base/tracing.h"
#include "iree/hal/executable_format.h"
#include "iree/hal/vulkan/status_util.h"
-#include "iree/schemas/spirv_executable_def_generated.h"
namespace iree {
namespace hal {
@@ -39,15 +36,6 @@
ExecutableLayout* executable_layout, ExecutableCachingModeBitfield mode,
const ExecutableSpec& spec) {
IREE_TRACE_SCOPE0("PipelineCache::PrepareExecutable");
- if (spec.executable_data.size() <= 4 ||
- !SpirVExecutableDefBufferHasIdentifier(spec.executable_data.data())) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Supplied executable data does not contain a SpirVExecutableDef";
- }
-
- // Get the SPIR-V executable def flatbuffer.
- const auto& spirv_executable_def =
- *::flatbuffers::GetRoot<SpirVExecutableDef>(spec.executable_data.data());
// Create the executable (which may itself own many pipelines).
IREE_ASSIGN_OR_RETURN(
@@ -56,7 +44,7 @@
add_ref(logical_device_),
/*pipeline_cache=*/VK_NULL_HANDLE,
static_cast<PipelineExecutableLayout*>(executable_layout), mode,
- spirv_executable_def));
+ spec));
return executable;
}
diff --git a/iree/hal/vulkan/pipeline_cache.h b/iree/hal/vulkan/pipeline_cache.h
index 143d8b6..75b9353 100644
--- a/iree/hal/vulkan/pipeline_cache.h
+++ b/iree/hal/vulkan/pipeline_cache.h
@@ -15,14 +15,13 @@
#ifndef IREE_HAL_VULKAN_PIPELINE_CACHE_H_
#define IREE_HAL_VULKAN_PIPELINE_CACHE_H_
-#include <vulkan/vulkan.h>
+#include "iree/hal/vulkan/vulkan_headers.h"
#include "absl/container/inlined_vector.h"
#include "iree/hal/executable.h"
#include "iree/hal/executable_cache.h"
#include "iree/hal/vulkan/handle_util.h"
#include "iree/hal/vulkan/pipeline_executable.h"
-#include "iree/schemas/spirv_executable_def_generated.h"
namespace iree {
namespace hal {
diff --git a/iree/hal/vulkan/pipeline_executable.cc b/iree/hal/vulkan/pipeline_executable.cc
index 17a3d9b..b954f24 100644
--- a/iree/hal/vulkan/pipeline_executable.cc
+++ b/iree/hal/vulkan/pipeline_executable.cc
@@ -20,45 +20,74 @@
#include "iree/base/tracing.h"
#include "iree/hal/vulkan/status_util.h"
+// flatcc schemas:
+#include "iree/base/flatcc.h"
+#include "iree/schemas/spirv_executable_def_reader.h"
+#include "iree/schemas/spirv_executable_def_verifier.h"
+
+// NOTE: starting to port this to C.
+
+// Verifies the structure of the flatbuffer so that we can avoid doing so during
+// runtime. There are still some conditions we must be aware of (such as omitted
+// names on functions with internal linkage), however we shouldn't need to
+// bounds check anything within the flatbuffer after this succeeds.
+static iree_status_t iree_hal_spirv_executable_flatbuffer_verify(
+ iree_const_byte_span_t flatbuffer_data) {
+ if (!flatbuffer_data.data || flatbuffer_data.data_length < 16) {
+ return iree_make_status(
+ IREE_STATUS_INVALID_ARGUMENT,
+ "flatbuffer data is not present or less than 16 bytes (%zu total)",
+ flatbuffer_data.data_length);
+ }
+
+ // Run flatcc generated verification. This ensures all pointers are in-bounds
+ // and that we can safely walk the file, but not that the actual contents of
+ // the flatbuffer meet our expectations.
+ int verify_ret = iree_SpirVExecutableDef_verify_as_root(
+ flatbuffer_data.data, flatbuffer_data.data_length);
+ if (verify_ret != flatcc_verify_ok) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "flatbuffer verification failed: %s",
+ flatcc_verify_error_string(verify_ret));
+ }
+
+ iree_SpirVExecutableDef_table_t executable_def =
+ iree_SpirVExecutableDef_as_root(flatbuffer_data.data);
+
+ flatbuffers_string_vec_t entry_points_vec =
+ iree_SpirVExecutableDef_entry_points_get(executable_def);
+ size_t entry_point_count = flatbuffers_string_vec_len(entry_points_vec);
+ for (size_t i = 0; i < entry_point_count; ++i) {
+ if (!flatbuffers_string_len(
+ flatbuffers_string_vec_at(entry_points_vec, i))) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "executable entry point %zu has no name", i);
+ }
+ }
+
+ if (flatbuffers_uint32_vec_len(
+ iree_SpirVExecutableDef_code_get(executable_def)) == 0) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "executable SPIR-V code is missing/empty");
+ }
+
+ // TODO(benvanik): pull PopulateSpecializationInfo from history and update.
+ // For now the compiler isn't generating them, and we don't use them.
+ if (iree_SpirVExecutableDef_specialization_info_is_present(executable_def)) {
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+ "executable uses SPIR-V specialization constants; "
+ "they need to be revived");
+ }
+
+ return iree_ok_status();
+}
+
namespace iree {
namespace hal {
namespace vulkan {
namespace {
-// Generates the baked specialization constant data based on the flatbuffer.
-// We only support uint32_t right now so this is easy.
-// Note that the returned vectors are referenced by pointers in |out_info| and
-// must remain valid until the info is no longer in use.
-std::pair<std::vector<VkSpecializationMapEntry>, std::vector<uint8_t>>
-PopulateSpecializationInfo(const VkSpecializationInfoDef* info_def) {
- int entry_count =
- info_def && info_def->map_entries() ? info_def->map_entries()->size() : 0;
- if (!entry_count) {
- return {};
- }
-
- std::vector<VkSpecializationMapEntry> entries;
- entries.reserve(entry_count);
- std::vector<uint8_t> data;
- data.resize(entry_count * sizeof(uint32_t));
-
- uint32_t offset = 0;
- for (const auto* entry_def : *info_def->map_entries()) {
- if (!entry_def) continue;
- entries.push_back({});
- auto& entry = entries.back();
- entry.constantID = entry_def->constant_id();
- entry.offset = offset;
- entry.size = sizeof(uint32_t);
- uint32_t value = entry_def->uint32_value();
- std::memcpy(data.data() + offset, &value, sizeof(value));
- offset += entry.size;
- }
-
- return {std::move(entries), std::move(data)};
-}
-
class VkShaderModuleHandle : public RefObject<VkShaderModuleHandle> {
public:
explicit VkShaderModuleHandle(const ref_ptr<VkDeviceHandle>& logical_device)
@@ -99,47 +128,40 @@
StatusOr<ref_ptr<PipelineExecutable>> PipelineExecutable::Create(
ref_ptr<VkDeviceHandle> logical_device, VkPipelineCache pipeline_cache,
PipelineExecutableLayout* executable_layout,
- ExecutableCachingModeBitfield mode,
- const SpirVExecutableDef& spirv_executable_def) {
+ ExecutableCachingModeBitfield mode, const ExecutableSpec& spec) {
IREE_TRACE_SCOPE0("PipelineExecutable::Create");
const auto& syms = logical_device->syms();
- if (!spirv_executable_def.entry_points() ||
- spirv_executable_def.entry_points()->size() == 0) {
- return InvalidArgumentErrorBuilder(IREE_LOC) << "No entry points defined";
- }
- if (!spirv_executable_def.code()) {
- return InvalidArgumentErrorBuilder(IREE_LOC) << "No SPIR-V code present";
- }
- const auto& code = *spirv_executable_def.code();
+
+ // Verify and fetch the executable flatbuffer wrapper.
+ iree_const_byte_span_t executable_data = iree_make_const_byte_span(
+ spec.executable_data.data(), spec.executable_data.size());
+ IREE_RETURN_IF_ERROR(
+ iree_hal_spirv_executable_flatbuffer_verify(executable_data));
+ iree_SpirVExecutableDef_table_t executable_def =
+ iree_SpirVExecutableDef_as_root(executable_data.data);
// Create the shader module.
VkShaderModuleCreateInfo shader_module_create_info;
shader_module_create_info.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
shader_module_create_info.pNext = nullptr;
shader_module_create_info.flags = 0;
- shader_module_create_info.codeSize = code.size() * sizeof(uint32_t);
- shader_module_create_info.pCode = code.data();
+ flatbuffers_uint32_vec_t code_vec =
+ iree_SpirVExecutableDef_code_get(executable_def);
+ shader_module_create_info.codeSize =
+ flatbuffers_uint32_vec_len(code_vec) * sizeof(uint32_t);
+ shader_module_create_info.pCode = code_vec;
VkShaderModuleHandle shader_module(add_ref(logical_device));
VK_RETURN_IF_ERROR(syms->vkCreateShaderModule(
*logical_device, &shader_module_create_info, logical_device->allocator(),
shader_module.mutable_value()));
- // Specialization info is currently constant against all entry points.
- std::vector<VkSpecializationMapEntry> spec_entries;
- std::vector<uint8_t> spec_data;
- std::tie(spec_entries, spec_data) =
- PopulateSpecializationInfo(spirv_executable_def.specialization_info());
- VkSpecializationInfo specialization_info;
- specialization_info.mapEntryCount = spec_entries.size();
- specialization_info.pMapEntries = spec_entries.data();
- specialization_info.dataSize = spec_data.size();
- specialization_info.pData = spec_data.data();
-
// Create pipelines for each entry point.
- const auto& entry_points = *spirv_executable_def.entry_points();
+ flatbuffers_string_vec_t entry_points_vec =
+ iree_SpirVExecutableDef_entry_points_get(executable_def);
absl::InlinedVector<VkComputePipelineCreateInfo, 1> pipeline_create_infos;
- pipeline_create_infos.resize(entry_points.size());
- for (int entry_ordinal = 0; entry_ordinal < entry_points.size();
+ pipeline_create_infos.resize(flatbuffers_string_vec_len(entry_points_vec));
+ for (size_t entry_ordinal = 0;
+ entry_ordinal < flatbuffers_string_vec_len(entry_points_vec);
++entry_ordinal) {
auto& pipeline_create_info = pipeline_create_infos[entry_ordinal];
pipeline_create_info.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
@@ -163,26 +185,25 @@
stage_create_info.flags = 0;
stage_create_info.stage = VK_SHADER_STAGE_COMPUTE_BIT;
stage_create_info.module = shader_module;
- stage_create_info.pName = entry_points[entry_ordinal]->c_str();
- stage_create_info.pSpecializationInfo = &specialization_info;
+ stage_create_info.pName =
+ flatbuffers_string_vec_at(entry_points_vec, entry_ordinal);
+ stage_create_info.pSpecializationInfo = NULL;
}
absl::InlinedVector<VkPipeline, 1> pipelines;
- pipelines.resize(entry_points.size());
+ pipelines.resize(flatbuffers_string_vec_len(entry_points_vec));
// Some ICDs appear to leak in here, out of our control.
// Warning: leak checks remain disabled if an error is returned.
IREE_DISABLE_LEAK_CHECKS();
VK_RETURN_IF_ERROR(syms->vkCreateComputePipelines(
- *logical_device, pipeline_cache, pipeline_create_infos.size(),
+ *logical_device, pipeline_cache,
+ static_cast<uint32_t>(pipeline_create_infos.size()),
pipeline_create_infos.data(), logical_device->allocator(),
pipelines.data()));
IREE_ENABLE_LEAK_CHECKS();
- auto executable = make_ref<PipelineExecutable>(std::move(logical_device),
- std::move(pipelines));
- executable->tag_ =
- spirv_executable_def.tag() ? spirv_executable_def.tag()->str() : "";
- return executable;
+ return make_ref<PipelineExecutable>(std::move(logical_device),
+ std::move(pipelines));
}
PipelineExecutable::PipelineExecutable(
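Note on the verification flow in pipeline_executable.cc: the new code validates the whole buffer once, then reads fields without further bounds checks. The sketch below shows the same verify-then-read shape on a made-up format. The wire layout (a `uint32_t` word count followed by that many words), `VerifyResult`, and `VerifyBlob` are assumptions for illustration only; they are unrelated to the flatcc-generated verifier.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>

// Hypothetical result codes for the sketch.
enum class VerifyResult { kOk, kTooSmall, kEmpty, kTruncated };

// Verifies a made-up blob: a uint32_t word count followed by that many
// uint32_t words. After kOk, callers may read the words unchecked.
inline VerifyResult VerifyBlob(const uint8_t* data, size_t length) {
  if (!data || length < sizeof(uint32_t)) return VerifyResult::kTooSmall;
  uint32_t word_count = 0;
  std::memcpy(&word_count, data, sizeof(word_count));
  // An unsigned length can never be negative, so "missing/empty" has to be
  // tested with == 0 rather than < 0.
  if (word_count == 0) return VerifyResult::kEmpty;
  // Compare in units of words so the check cannot overflow.
  if ((length - sizeof(uint32_t)) / sizeof(uint32_t) < word_count) {
    return VerifyResult::kTruncated;
  }
  return VerifyResult::kOk;
}
```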
diff --git a/iree/hal/vulkan/pipeline_executable.h b/iree/hal/vulkan/pipeline_executable.h
index 839b3dd..25f5955 100644
--- a/iree/hal/vulkan/pipeline_executable.h
+++ b/iree/hal/vulkan/pipeline_executable.h
@@ -15,7 +15,7 @@
#ifndef IREE_HAL_VULKAN_PIPELINE_EXECUTABLE_H_
#define IREE_HAL_VULKAN_PIPELINE_EXECUTABLE_H_
-#include <vulkan/vulkan.h>
+#include "iree/hal/vulkan/vulkan_headers.h"
#include <vector>
@@ -28,7 +28,6 @@
#include "iree/hal/vulkan/handle_util.h"
#include "iree/hal/vulkan/native_descriptor_set.h"
#include "iree/hal/vulkan/pipeline_executable_layout.h"
-#include "iree/schemas/spirv_executable_def_generated.h"
namespace iree {
namespace hal {
@@ -39,8 +38,7 @@
static StatusOr<ref_ptr<PipelineExecutable>> Create(
ref_ptr<VkDeviceHandle> logical_device, VkPipelineCache pipeline_cache,
PipelineExecutableLayout* executable_layout,
- ExecutableCachingModeBitfield mode,
- const SpirVExecutableDef& spirv_executable_def);
+ ExecutableCachingModeBitfield mode, const ExecutableSpec& spec);
PipelineExecutable(ref_ptr<VkDeviceHandle> logical_device,
absl::InlinedVector<VkPipeline, 1> pipelines);
@@ -56,7 +54,6 @@
private:
ref_ptr<VkDeviceHandle> logical_device_;
- std::string tag_;
// One pipeline per entry point.
absl::InlinedVector<VkPipeline, 1> pipelines_;
diff --git a/iree/hal/vulkan/pipeline_executable_layout.h b/iree/hal/vulkan/pipeline_executable_layout.h
index 00b7ec7..fd52f2b 100644
--- a/iree/hal/vulkan/pipeline_executable_layout.h
+++ b/iree/hal/vulkan/pipeline_executable_layout.h
@@ -15,7 +15,7 @@
#ifndef IREE_HAL_VULKAN_PIPELINE_EXECUTABLE_LAYOUT_H_
#define IREE_HAL_VULKAN_PIPELINE_EXECUTABLE_LAYOUT_H_
-#include <vulkan/vulkan.h>
+#include "iree/hal/vulkan/vulkan_headers.h"
#include "absl/container/inlined_vector.h"
#include "absl/types/span.h"
diff --git a/iree/hal/vulkan/registration/BUILD b/iree/hal/vulkan/registration/BUILD
new file mode 100644
index 0000000..c5b81f4
--- /dev/null
+++ b/iree/hal/vulkan/registration/BUILD
@@ -0,0 +1,54 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//iree:build_defs.oss.bzl", "iree_cmake_extra_content")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_cmake_extra_content(
+ content = """
+if(${IREE_HAL_DRIVER_VULKAN})
+""",
+ inline = True,
+)
+
+cc_library(
+ name = "registration",
+ srcs = ["driver_module.cc"],
+ hdrs = ["driver_module.h"],
+ defines = [
+ "IREE_HAL_HAVE_VULKAN_DRIVER_MODULE=1",
+ ],
+ deps = [
+ "//iree/base:flags",
+ "//iree/base:status",
+ "//iree/base:tracing",
+ "//iree/hal:api",
+ "//iree/hal:driver_registry",
+ "//iree/hal/vulkan",
+ "//iree/hal/vulkan:utils",
+ "@com_google_absl//absl/flags:flag",
+ ],
+)
+
+iree_cmake_extra_content(
+ content = """
+endif()
+""",
+ inline = True,
+)
diff --git a/iree/hal/vulkan/registration/CMakeLists.txt b/iree/hal/vulkan/registration/CMakeLists.txt
new file mode 100644
index 0000000..4cc0fa1
--- /dev/null
+++ b/iree/hal/vulkan/registration/CMakeLists.txt
@@ -0,0 +1,40 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_add_all_subdirs()
+
+if(${IREE_HAL_DRIVER_VULKAN})
+
+iree_cc_library(
+ NAME
+ registration
+ HDRS
+ "driver_module.h"
+ SRCS
+ "driver_module.cc"
+ DEPS
+ absl::flags
+ iree::base::flags
+ iree::base::status
+ iree::base::tracing
+ iree::hal::api
+ iree::hal::driver_registry
+ iree::hal::vulkan
+ iree::hal::vulkan::utils
+ DEFINES
+ "IREE_HAL_HAVE_VULKAN_DRIVER_MODULE=1"
+ PUBLIC
+)
+
+endif()
diff --git a/iree/hal/vulkan/vulkan_driver_module.cc b/iree/hal/vulkan/registration/driver_module.cc
similarity index 93%
rename from iree/hal/vulkan/vulkan_driver_module.cc
rename to iree/hal/vulkan/registration/driver_module.cc
index 9bd1cd0..9495f28 100644
--- a/iree/hal/vulkan/vulkan_driver_module.cc
+++ b/iree/hal/vulkan/registration/driver_module.cc
@@ -1,4 +1,4 @@
-// Copyright 2019 Google LLC
+// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include <memory>
+#include "iree/hal/vulkan/registration/driver_module.h"
#include "absl/flags/flag.h"
-#include "iree/base/init.h"
+#include "iree/base/flags.h"
#include "iree/base/status.h"
#include "iree/base/tracing.h"
#include "iree/hal/driver_registry.h"
@@ -132,8 +132,8 @@
} // namespace hal
} // namespace iree
-IREE_REGISTER_MODULE_INITIALIZER(iree_hal_vulkan_driver, {
- IREE_QCHECK_OK(::iree::hal::DriverRegistry::shared_registry()->Register(
- "vulkan", ::iree::hal::vulkan::CreateVulkanDriver));
-});
-IREE_REGISTER_MODULE_INITIALIZER_SEQUENCE(iree_hal, iree_hal_vulkan_driver);
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_vulkan_driver_module_register() {
+ return ::iree::hal::DriverRegistry::shared_registry()->Register(
+ "vulkan", ::iree::hal::vulkan::CreateVulkanDriver);
+}
diff --git a/iree/hal/vulkan/registration/driver_module.h b/iree/hal/vulkan/registration/driver_module.h
new file mode 100644
index 0000000..42f4d04
--- /dev/null
+++ b/iree/hal/vulkan/registration/driver_module.h
@@ -0,0 +1,31 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_VULKAN_REGISTRATION_DRIVER_MODULE_H_
+#define IREE_HAL_VULKAN_REGISTRATION_DRIVER_MODULE_H_
+
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_vulkan_driver_module_register();
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_HAL_VULKAN_REGISTRATION_DRIVER_MODULE_H_
diff --git a/iree/hal/vulkan/renderdoc_capture_manager.cc b/iree/hal/vulkan/renderdoc_capture_manager.cc
index 039e9c0..b8b06ce 100644
--- a/iree/hal/vulkan/renderdoc_capture_manager.cc
+++ b/iree/hal/vulkan/renderdoc_capture_manager.cc
@@ -19,8 +19,7 @@
#include "iree/base/target_platform.h"
#include "iree/base/tracing.h"
-#if defined(IREE_PLATFORM_WINDOWS)
-#else
+#if !defined(IREE_PLATFORM_WINDOWS)
#include <dlfcn.h>
#endif // IREE_PLATFORM_WINDOWS
diff --git a/iree/hal/vulkan/serializing_command_queue.cc b/iree/hal/vulkan/serializing_command_queue.cc
index 1fc19fa..f4c3ba8 100644
--- a/iree/hal/vulkan/serializing_command_queue.cc
+++ b/iree/hal/vulkan/serializing_command_queue.cc
@@ -134,20 +134,20 @@
arena->AllocateSpan<VkSemaphore>(wait_semaphores.size());
auto wait_dst_stage_masks =
arena->AllocateSpan<VkPipelineStageFlags>(wait_semaphores.size());
- for (int i = 0, e = wait_semaphores.size(); i < e; ++i) {
+ for (size_t i = 0, e = wait_semaphores.size(); i < e; ++i) {
wait_semaphore_handles[i] = wait_semaphores[i];
wait_dst_stage_masks[i] = dst_stage_mask;
}
auto signal_semaphore_handles =
arena->AllocateSpan<VkSemaphore>(signal_semaphores.size());
- for (int i = 0, e = signal_semaphores.size(); i < e; ++i) {
+ for (size_t i = 0, e = signal_semaphores.size(); i < e; ++i) {
signal_semaphore_handles[i] = signal_semaphores[i];
}
auto command_buffer_handles =
arena->AllocateSpan<VkCommandBuffer>(command_buffers.size());
- for (int i = 0, e = command_buffers.size(); i < e; ++i) {
+ for (size_t i = 0, e = command_buffers.size(); i < e; ++i) {
const auto& command_buffer = command_buffers[i];
auto* direct_command_buffer =
static_cast<DirectCommandBuffer*>(command_buffer->impl());
@@ -156,12 +156,15 @@
submit_info->sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
submit_info->pNext = nullptr;
- submit_info->waitSemaphoreCount = wait_semaphore_handles.size();
+ submit_info->waitSemaphoreCount =
+ static_cast<uint32_t>(wait_semaphore_handles.size());
submit_info->pWaitSemaphores = wait_semaphore_handles.data();
submit_info->pWaitDstStageMask = wait_dst_stage_masks.data();
- submit_info->commandBufferCount = command_buffer_handles.size();
+ submit_info->commandBufferCount =
+ static_cast<uint32_t>(command_buffer_handles.size());
submit_info->pCommandBuffers = command_buffer_handles.data();
- submit_info->signalSemaphoreCount = signal_semaphore_handles.size();
+ submit_info->signalSemaphoreCount =
+ static_cast<uint32_t>(signal_semaphore_handles.size());
submit_info->pSignalSemaphores = signal_semaphore_handles.data();
}
@@ -188,7 +191,7 @@
IREE_DVLOG(2) << "SerializingCommandQueue::Submit";
absl::MutexLock lock(&mutex_);
- for (int i = 0; i < batches.size(); ++i) {
+ for (size_t i = 0; i < batches.size(); ++i) {
// Grab a fence for this submission first. This will be used to check the
// progress of emulated timeline semaphores later.
IREE_ASSIGN_OR_RETURN(auto fence, fence_pool_->Acquire());
@@ -282,13 +285,13 @@
if (submit_infos.empty()) return false;
auto infos = arena.AllocateSpan<VkSubmitInfo>(submit_infos.size());
- for (int i = 0, e = submit_infos.size(); i < e; ++i) {
+ for (size_t i = 0, e = submit_infos.size(); i < e; ++i) {
infos[i] = submit_infos[i];
}
// Note: We might be able to batch the submission but it involves non-trivial
// fence handling. We can handle that if really needed.
- for (int i = 0, e = submit_infos.size(); i < e; ++i) {
+ for (size_t i = 0, e = submit_infos.size(); i < e; ++i) {
VK_RETURN_IF_ERROR(syms()->vkQueueSubmit(
queue_, /*submitCount=*/1, &submit_infos[i], submit_fences[i]));
}
@@ -353,9 +356,9 @@
fences.reserve(pending_fences_.size());
for (const auto& fence : pending_fences_) fences.push_back(fence->value());
- VkResult result =
- syms()->vkWaitForFences(*logical_device_, fences.size(), fences.data(),
- /*waitAll=*/VK_TRUE, timeout_ns);
+ VkResult result = syms()->vkWaitForFences(
+ *logical_device_, static_cast<uint32_t>(fences.size()), fences.data(),
+ /*waitAll=*/VK_TRUE, timeout_ns);
switch (result) {
case VK_SUCCESS:
@@ -394,7 +397,8 @@
fences.reserve(pending_fences_.size());
for (const auto& fence : pending_fences_) fences.push_back(fence->value());
- syms()->vkWaitForFences(*logical_device_, fences.size(), fences.data(),
+ syms()->vkWaitForFences(*logical_device_,
+ static_cast<uint32_t>(fences.size()), fences.data(),
/*waitAll=*/VK_TRUE, /*timeout=*/UINT64_MAX);
// Clear the list. Fences will be automatically returned back to the queue
// after refcount reaches 0.
diff --git a/iree/hal/vulkan/serializing_command_queue.h b/iree/hal/vulkan/serializing_command_queue.h
index 3437eaf..a955a3c 100644
--- a/iree/hal/vulkan/serializing_command_queue.h
+++ b/iree/hal/vulkan/serializing_command_queue.h
@@ -15,7 +15,7 @@
#ifndef IREE_HAL_VULKAN_SERIALIZING_COMMAND_QUEUE_H_
#define IREE_HAL_VULKAN_SERIALIZING_COMMAND_QUEUE_H_
-#include <vulkan/vulkan.h>
+#include "iree/hal/vulkan/vulkan_headers.h"
#include <memory>
#include <string>
diff --git a/iree/hal/vulkan/status_util.h b/iree/hal/vulkan/status_util.h
index 2a866b5..798111c 100644
--- a/iree/hal/vulkan/status_util.h
+++ b/iree/hal/vulkan/status_util.h
@@ -15,7 +15,7 @@
#ifndef IREE_HAL_VULKAN_STATUS_UTIL_H_
#define IREE_HAL_VULKAN_STATUS_UTIL_H_
-#include <vulkan/vulkan.h>
+#include "iree/hal/vulkan/vulkan_headers.h"
#include "iree/base/status.h"
diff --git a/iree/hal/vulkan/timepoint_util.cc b/iree/hal/vulkan/timepoint_util.cc
index a096d4d..c14ea26 100644
--- a/iree/hal/vulkan/timepoint_util.cc
+++ b/iree/hal/vulkan/timepoint_util.cc
@@ -17,7 +17,6 @@
#include <memory>
#include "absl/synchronization/mutex.h"
-#include "absl/utility/utility.h"
#include "iree/base/time.h"
#include "iree/base/tracing.h"
#include "iree/hal/vulkan/dynamic_symbols.h"
diff --git a/iree/hal/vulkan/timepoint_util.h b/iree/hal/vulkan/timepoint_util.h
index f162753..5fe35e4 100644
--- a/iree/hal/vulkan/timepoint_util.h
+++ b/iree/hal/vulkan/timepoint_util.h
@@ -15,7 +15,7 @@
#ifndef IREE_HAL_VULKAN_TIMEPOINT_UTIL_H_
#define IREE_HAL_VULKAN_TIMEPOINT_UTIL_H_
-#include <vulkan/vulkan.h>
+#include "iree/hal/vulkan/vulkan_headers.h"
#include <atomic>
#include <vector>
diff --git a/iree/hal/vulkan/vma_allocator.h b/iree/hal/vulkan/vma_allocator.h
index ccd677e..6acf935 100644
--- a/iree/hal/vulkan/vma_allocator.h
+++ b/iree/hal/vulkan/vma_allocator.h
@@ -15,7 +15,7 @@
#ifndef IREE_HAL_VULKAN_VMA_ALLOCATOR_H_
#define IREE_HAL_VULKAN_VMA_ALLOCATOR_H_
-#include <vulkan/vulkan.h>
+#include "iree/hal/vulkan/vulkan_headers.h"
#include <memory>
diff --git a/iree/hal/vulkan/vma_buffer.h b/iree/hal/vulkan/vma_buffer.h
index b768f71..dddb55b 100644
--- a/iree/hal/vulkan/vma_buffer.h
+++ b/iree/hal/vulkan/vma_buffer.h
@@ -15,10 +15,10 @@
#ifndef IREE_HAL_VULKAN_VMA_BUFFER_H_
#define IREE_HAL_VULKAN_VMA_BUFFER_H_
-#include <vulkan/vulkan.h>
+#include "iree/hal/vulkan/vulkan_headers.h"
#include "iree/hal/buffer.h"
-#include "vk_mem_alloc.h"
+#include "iree/hal/vulkan/internal_vk_mem_alloc.h"
namespace iree {
namespace hal {
diff --git a/iree/hal/vulkan/vulkan_device.cc b/iree/hal/vulkan/vulkan_device.cc
index a975b3a..a5e2cd0 100644
--- a/iree/hal/vulkan/vulkan_device.cc
+++ b/iree/hal/vulkan/vulkan_device.cc
@@ -173,7 +173,7 @@
uint64_t compute_queue_count = CountOnes64(compute_queue_set.queue_indices);
for (uint32_t i = 0; i < compute_queue_count; ++i) {
- if (!(compute_queue_set.queue_indices & (1 << i))) continue;
+ if (!(compute_queue_set.queue_indices & (1ull << i))) continue;
VkQueue queue = VK_NULL_HANDLE;
syms->vkGetDeviceQueue(*logical_device,
@@ -195,7 +195,7 @@
uint64_t transfer_queue_count = CountOnes64(transfer_queue_set.queue_indices);
for (uint32_t i = 0; i < transfer_queue_count; ++i) {
- if (!(transfer_queue_set.queue_indices & (1 << i))) continue;
+ if (!(transfer_queue_set.queue_indices & (1ull << i))) continue;
VkQueue queue = VK_NULL_HANDLE;
syms->vkGetDeviceQueue(*logical_device,
@@ -367,7 +367,7 @@
QueueSet compute_queue_set = {};
compute_queue_set.queue_family_index = queue_family_info.dispatch_index;
for (uint32_t i = 0; i < queue_family_info.dispatch_queue_count; ++i) {
- compute_queue_set.queue_indices |= 1 << i;
+ compute_queue_set.queue_indices |= 1ull << i;
}
QueueSet transfer_queue_set = {};
transfer_queue_set.queue_family_index = queue_family_info.transfer_index;
@@ -377,7 +377,7 @@
base_queue_index = queue_family_info.dispatch_index;
}
for (uint32_t i = 0; i < queue_family_info.transfer_queue_count; ++i) {
- transfer_queue_set.queue_indices |= 1 << (i + base_queue_index);
+ transfer_queue_set.queue_indices |= 1ull << (i + base_queue_index);
}
// Emulate timeline semaphores if associated functions are not defined.
diff --git a/iree/hal/vulkan/vulkan_device.h b/iree/hal/vulkan/vulkan_device.h
index 456bc19..fbc10d3 100644
--- a/iree/hal/vulkan/vulkan_device.h
+++ b/iree/hal/vulkan/vulkan_device.h
@@ -15,7 +15,7 @@
#ifndef IREE_HAL_VULKAN_VULKAN_DEVICE_H_
#define IREE_HAL_VULKAN_VULKAN_DEVICE_H_
-#include <vulkan/vulkan.h>
+#include "iree/hal/vulkan/vulkan_headers.h"
#include <functional>
#include <memory>
diff --git a/iree/hal/vulkan/vulkan_driver.cc b/iree/hal/vulkan/vulkan_driver.cc
index 9f76455..0e1c7b5 100644
--- a/iree/hal/vulkan/vulkan_driver.cc
+++ b/iree/hal/vulkan/vulkan_driver.cc
@@ -114,9 +114,11 @@
create_info.pNext = nullptr;
create_info.flags = 0;
create_info.pApplicationInfo = &app_info;
- create_info.enabledLayerCount = enabled_layer_names.size();
+ create_info.enabledLayerCount =
+ static_cast<uint32_t>(enabled_layer_names.size());
create_info.ppEnabledLayerNames = enabled_layer_names.data();
- create_info.enabledExtensionCount = enabled_extension_names.size();
+ create_info.enabledExtensionCount =
+ static_cast<uint32_t>(enabled_extension_names.size());
create_info.ppEnabledExtensionNames = enabled_extension_names.data();
// If we have the debug_utils extension then we can chain a one-shot messenger
diff --git a/iree/hal/vulkan/vulkan_driver.h b/iree/hal/vulkan/vulkan_driver.h
index 9c8bcce..ed3e462 100644
--- a/iree/hal/vulkan/vulkan_driver.h
+++ b/iree/hal/vulkan/vulkan_driver.h
@@ -15,7 +15,7 @@
#ifndef IREE_HAL_VULKAN_VULKAN_DRIVER_H_
#define IREE_HAL_VULKAN_VULKAN_DRIVER_H_
-#include <vulkan/vulkan.h>
+#include "iree/hal/vulkan/vulkan_headers.h"
#include <memory>
#include <vector>
diff --git a/iree/hal/vulkan/vulkan_headers.h b/iree/hal/vulkan/vulkan_headers.h
new file mode 100644
index 0000000..9acafd1
--- /dev/null
+++ b/iree/hal/vulkan/vulkan_headers.h
@@ -0,0 +1,46 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_VULKAN_VULKAN_HEADERS_H_
+#define IREE_HAL_VULKAN_VULKAN_HEADERS_H_
+
+// We exclusively use Vulkan via queried function pointers. To ensure that there
+// are no accidental calls to the linker-loaded implicit functions we just
+// compile them all out.
+//
+// Code under iree/hal/vulkan/ *MUST NOT* directly include vulkan.h or any
+// header that includes it without VK_NO_PROTOTYPES first being defined. This
+// means that this iree/hal/vulkan/vulkan_headers.h file must usually be
+// included first in all files using it.
+//
+// From there, use iree/hal/vulkan/dynamic_symbols.h to plumb the dynamically
+// resolved symbols to any code that may need to make Vulkan calls. See that
+// header for more information: in general we keep our required set of symbols
+// minimal to limit binary size, runtime memory, and link time, so symbols are
+// only added as needed.
+//
+// Other non-core code can choose not to disable the prototypes, though that
+// is not recommended for anything beyond samples.
+//
+// There are several reasons to dynamically link against Vulkan, such as
+// supporting platforms without Vulkan or with differing Vulkan versions where
+// not all symbols may be available.
+//
+// See this article for more information:
+// https://djang86.blogspot.com/2019/01/what-is-vknoprototypes.html
+#define VK_NO_PROTOTYPES 1
+
+#include <vulkan/vulkan.h>
+
+#endif // IREE_HAL_VULKAN_VULKAN_HEADERS_H_
diff --git a/iree/modules/check/BUILD b/iree/modules/check/BUILD
index cfaedff..4686574 100644
--- a/iree/modules/check/BUILD
+++ b/iree/modules/check/BUILD
@@ -12,13 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-load(
- "//iree:build_defs.oss.bzl",
- "IREE_DRIVER_MODULES",
- "PLATFORM_VULKAN_DEPS",
- iree_cc_binary = "cc_binary",
-)
-
package(
default_visibility = ["//visibility:public"],
features = ["layering_check"],
@@ -34,38 +27,39 @@
"//iree/base:logging",
"//iree/base:status",
"//iree/hal:api",
- "//iree/hal/vmla:vmla_driver_module",
+ "//iree/hal/vmla/registration",
"//iree/modules/hal",
"//iree/testing:gtest",
"//iree/testing:gtest_main",
"//iree/vm",
"//iree/vm:bytecode_module",
- "//iree/vm:ref_cc",
+ "//iree/vm:cc",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)
-iree_cc_binary(
+cc_binary(
name = "iree-check-module",
testonly = True,
srcs = ["iree-check-module-main.cc"],
deps = [
":native_module",
- "@com_google_absl//absl/flags:flag",
- "@com_google_absl//absl/strings",
"//iree/base:api",
+ "//iree/base:core_headers",
"//iree/base:file_io",
- "//iree/base:init",
+ "//iree/base:flags",
"//iree/base:status",
- "//iree/base:target_platform",
"//iree/base:tracing",
+ "//iree/hal/drivers",
"//iree/modules/hal",
"//iree/testing:gtest",
"//iree/tools/utils:vm_util",
"//iree/vm:bytecode_module",
- ] + PLATFORM_VULKAN_DEPS + IREE_DRIVER_MODULES,
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/strings",
+ ],
)
cc_library(
@@ -80,7 +74,7 @@
"//iree/modules/hal",
"//iree/testing:gtest",
"//iree/vm",
- "//iree/vm:native_module_cc",
+ "//iree/vm:cc",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
],
diff --git a/iree/modules/check/CMakeLists.txt b/iree/modules/check/CMakeLists.txt
index 56dea0e..3d8dc84 100644
--- a/iree/modules/check/CMakeLists.txt
+++ b/iree/modules/check/CMakeLists.txt
@@ -29,13 +29,13 @@
iree::base::logging
iree::base::status
iree::hal::api
- iree::hal::vmla::vmla_driver_module
+ iree::hal::vmla::registration
iree::modules::hal
iree::testing::gtest
iree::testing::gtest_main
iree::vm
iree::vm::bytecode_module
- iree::vm::ref_cc
+ iree::vm::cc
)
endif()
@@ -51,16 +51,16 @@
absl::flags
absl::strings
iree::base::api
+ iree::base::core_headers
iree::base::file_io
- iree::base::init
+ iree::base::flags
iree::base::status
- iree::base::target_platform
iree::base::tracing
+ iree::hal::drivers
iree::modules::hal
iree::testing::gtest
iree::tools::utils::vm_util
iree::vm::bytecode_module
- ${IREE_HAL_DRIVER_MODULES}
TESTONLY
)
@@ -80,7 +80,7 @@
iree::modules::hal
iree::testing::gtest
iree::vm
- iree::vm::native_module_cc
+ iree::vm::cc
TESTONLY
PUBLIC
)
diff --git a/iree/modules/check/check_test.cc b/iree/modules/check/check_test.cc
index 4c30063..7086085 100644
--- a/iree/modules/check/check_test.cc
+++ b/iree/modules/check/check_test.cc
@@ -21,6 +21,7 @@
#include "iree/base/logging.h"
#include "iree/base/status.h"
#include "iree/hal/api.h"
+#include "iree/hal/vmla/registration/driver_module.h"
#include "iree/modules/check/native_module.h"
#include "iree/modules/hal/hal_module.h"
#include "iree/testing/gtest.h"
@@ -35,6 +36,7 @@
class CheckTest : public ::testing::Test {
protected:
static void SetUpTestSuite() {
+ IREE_CHECK_OK(iree_hal_vmla_driver_module_register());
// TODO(benvanik): move to instance-based registration.
IREE_ASSERT_OK(iree_hal_module_register_types());
@@ -414,8 +416,8 @@
TEST_F(CheckTest, ExpectAlmostEqNearIdenticalBufferSuccess) {
vm::ref<iree_hal_buffer_view_t> lhs;
vm::ref<iree_hal_buffer_view_t> rhs;
- float lhs_contents[] = {1.0, 1.99999, 0.00001, 4};
- float rhs_contents[] = {1.00001, 2.0, 0, 4};
+ float lhs_contents[] = {1.0f, 1.99999f, 0.00001f, 4.0f};
+ float rhs_contents[] = {1.00001f, 2.0f, 0.0f, 4.0f};
int32_t shape[] = {4};
ASSERT_NO_FATAL_FAILURE(CreateFloat32BufferView(lhs_contents, shape, &lhs));
ASSERT_NO_FATAL_FAILURE(CreateFloat32BufferView(rhs_contents, shape, &rhs));
diff --git a/iree/modules/check/iree-check-module-main.cc b/iree/modules/check/iree-check-module-main.cc
index 7e3687f..35f57dc 100644
--- a/iree/modules/check/iree-check-module-main.cc
+++ b/iree/modules/check/iree-check-module-main.cc
@@ -19,10 +19,11 @@
#include "absl/strings/string_view.h"
#include "iree/base/api.h"
#include "iree/base/file_io.h"
-#include "iree/base/init.h"
+#include "iree/base/flags.h"
#include "iree/base/status.h"
#include "iree/base/target_platform.h"
#include "iree/base/tracing.h"
+#include "iree/hal/drivers/init.h"
#include "iree/modules/check/native_module.h"
#include "iree/modules/hal/hal_module.h"
#include "iree/testing/gtest.h"
@@ -35,7 +36,7 @@
#if defined(IREE_PLATFORM_WINDOWS)
#include <fcntl.h>
#include <io.h>
-#define IREE_FORCE_BINARY_STDIN() setmode(_fileno(stdin), O_BINARY)
+#define IREE_FORCE_BINARY_STDIN() _setmode(_fileno(stdin), O_BINARY)
#else
#define IREE_FORCE_BINARY_STDIN()
#endif // IREE_PLATFORM_WINDOWS
@@ -111,8 +112,8 @@
std::array<iree_vm_module_t*, 3> modules = {hal_module, check_module,
input_module};
auto module_signature = iree_vm_module_signature(input_module);
- for (int ordinal = 0; ordinal < module_signature.export_function_count;
- ++ordinal) {
+ for (iree_host_size_t ordinal = 0;
+ ordinal < module_signature.export_function_count; ++ordinal) {
iree_vm_function_t function;
iree_string_view_t export_name_sv;
IREE_RETURN_IF_ERROR(iree_vm_module_lookup_function_by_ordinal(
@@ -173,7 +174,8 @@
} // namespace
extern "C" int main(int argc, char** argv) {
- InitializeEnvironment(&argc, &argv);
+ iree_flags_parse_checked(&argc, &argv);
+ IREE_CHECK_OK(iree_hal_register_all_available_drivers());
::testing::InitGoogleTest(&argc, argv);
IREE_FORCE_BINARY_STDIN();
diff --git a/iree/modules/check/native_module.cc b/iree/modules/check/native_module.cc
index f35968f..5686a1f 100644
--- a/iree/modules/check/native_module.cc
+++ b/iree/modules/check/native_module.cc
@@ -61,7 +61,7 @@
template <typename T>
Status ExpectAllTrue(iree_byte_span_t bytes) {
- EXPECT_THAT(AbslSpan<T>(bytes), Each(Not(0)));
+ EXPECT_THAT(AbslSpan<T>(bytes), Each(Not(T(0))));
return OkStatus();
}
diff --git a/iree/modules/hal/BUILD b/iree/modules/hal/BUILD
index bdea2da..6780425 100644
--- a/iree/modules/hal/BUILD
+++ b/iree/modules/hal/BUILD
@@ -25,10 +25,10 @@
deps = [
"//iree/base:api",
"//iree/base:tracing",
+ "//iree/hal",
"//iree/hal:api",
- "//iree/hal:device",
"//iree/vm",
- "//iree/vm:native_module_cc",
+ "//iree/vm:cc",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
diff --git a/iree/modules/hal/CMakeLists.txt b/iree/modules/hal/CMakeLists.txt
index 21fc736..636a96f 100644
--- a/iree/modules/hal/CMakeLists.txt
+++ b/iree/modules/hal/CMakeLists.txt
@@ -28,9 +28,9 @@
absl::span
iree::base::api
iree::base::tracing
+ iree::hal
iree::hal::api
- iree::hal::device
iree::vm
- iree::vm::native_module_cc
+ iree::vm::cc
PUBLIC
)
diff --git a/iree/modules/hal/hal_module.cc b/iree/modules/hal/hal_module.cc
index be8a730..9725e35 100644
--- a/iree/modules/hal/hal_module.cc
+++ b/iree/modules/hal/hal_module.cc
@@ -193,7 +193,7 @@
size_t buffer_length = source->data.data_length;
if (length == -1) {
- length = buffer_length;
+ length = static_cast<size_t>(buffer_length);
}
if (length < 0 || offset < 0 || offset > buffer_length ||
offset + length > buffer_length) {
diff --git a/iree/modules/strings/BUILD b/iree/modules/strings/BUILD
index 3a8dcf3..fcb1a40 100644
--- a/iree/modules/strings/BUILD
+++ b/iree/modules/strings/BUILD
@@ -23,10 +23,7 @@
"//iree/hal:api",
"//iree/modules/hal",
"//iree/vm",
- "//iree/vm:module",
- "//iree/vm:native_module_cc",
- "//iree/vm:ref",
- "//iree/vm:stack",
+ "//iree/vm:cc",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
@@ -43,13 +40,13 @@
"//iree/base:api",
"//iree/base:logging",
"//iree/hal:api",
- "//iree/hal/vmla:vmla_driver_module",
+ "//iree/hal/vmla/registration",
"//iree/modules/hal",
"//iree/testing:gtest",
"//iree/testing:gtest_main",
"//iree/vm",
"//iree/vm:bytecode_module",
- "//iree/vm:ref_cc",
+ "//iree/vm:cc",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@com_google_benchmark//:benchmark",
diff --git a/iree/modules/strings/CMakeLists.txt b/iree/modules/strings/CMakeLists.txt
index b65fa0b..5aac9ba 100644
--- a/iree/modules/strings/CMakeLists.txt
+++ b/iree/modules/strings/CMakeLists.txt
@@ -34,10 +34,7 @@
iree::hal::api
iree::modules::hal
iree::vm
- iree::vm::module
- iree::vm::native_module_cc
- iree::vm::ref
- iree::vm::stack
+ iree::vm::cc
PUBLIC
)
@@ -55,13 +52,13 @@
iree::base::api
iree::base::logging
iree::hal::api
- iree::hal::vmla::vmla_driver_module
+ iree::hal::vmla::registration
iree::modules::hal
iree::testing::gtest
iree::testing::gtest_main
iree::vm
iree::vm::bytecode_module
- iree::vm::ref_cc
+ iree::vm::cc
)
iree_bytecode_module(
diff --git a/iree/modules/strings/api.cc b/iree/modules/strings/api.cc
index 34c3fb2..ee364ec 100644
--- a/iree/modules/strings/api.cc
+++ b/iree/modules/strings/api.cc
@@ -25,10 +25,8 @@
#include "iree/modules/strings/api.h"
#include "iree/modules/strings/api_detail.h"
#include "iree/modules/strings/strings_module.h"
-#include "iree/vm/module.h"
+#include "iree/vm/api.h"
#include "iree/vm/native_module_cc.h"
-#include "iree/vm/ref.h"
-#include "iree/vm/stack.h"
extern "C" iree_status_t strings_string_create(iree_string_view_t value,
iree_allocator_t allocator,
@@ -141,7 +139,7 @@
if (!tensor || !rank) {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT);
}
- *rank = tensor->rank;
+ *rank = static_cast<int32_t>(tensor->rank);
return iree_ok_status();
}
diff --git a/iree/modules/strings/strings_module.cc b/iree/modules/strings/strings_module.cc
index 2d43cef..c8ea45f 100644
--- a/iree/modules/strings/strings_module.cc
+++ b/iree/modules/strings/strings_module.cc
@@ -112,7 +112,7 @@
std::string str;
str.reserve(string_length);
StringTensorToStringHelper(str_tensor->values, str_tensor->shape,
- str_tensor->rank, &str);
+ static_cast<int32_t>(str_tensor->rank), &str);
IREE_RETURN_IF_ERROR(strings_string_create(
iree_make_cstring_view(str.c_str()), allocator_, &new_string));
diff --git a/iree/modules/strings/strings_module_test.cc b/iree/modules/strings/strings_module_test.cc
index 63cb269..68414ad 100644
--- a/iree/modules/strings/strings_module_test.cc
+++ b/iree/modules/strings/strings_module_test.cc
@@ -21,6 +21,7 @@
#include "iree/base/api.h"
#include "iree/base/logging.h"
#include "iree/hal/api.h"
+#include "iree/hal/vmla/registration/driver_module.h"
#include "iree/modules/hal/hal_module.h"
#include "iree/modules/strings/api.h"
#include "iree/modules/strings/api_detail.h"
@@ -40,6 +41,10 @@
class StringsModuleTest : public ::testing::Test {
protected:
+ static void SetUpTestSuite() {
+ IREE_CHECK_OK(iree_hal_vmla_driver_module_register());
+ }
+
virtual void SetUp() {
IREE_CHECK_OK(iree_vm_instance_create(iree_allocator_system(), &instance_));
@@ -301,7 +306,7 @@
std::vector<iree_string_view_t> out_strings(expected.size());
IREE_ASSERT_OK(strings_string_tensor_get_elements(
output_tensor, out_strings.data(), out_strings.size(), 0));
- for (int i = 0; i < expected.size(); i++) {
+ for (iree_host_size_t i = 0; i < expected.size(); i++) {
EXPECT_EQ(iree_string_view_compare(out_strings[i], expected[i]), 0)
<< "Expected: " << expected[i].data << " found "
<< out_strings[i].data;
diff --git a/iree/modules/tensorlist/BUILD b/iree/modules/tensorlist/BUILD
index 4cdfb5d..74deecd 100644
--- a/iree/modules/tensorlist/BUILD
+++ b/iree/modules/tensorlist/BUILD
@@ -36,13 +36,13 @@
"//iree/base:api",
"//iree/base:logging",
"//iree/hal:api",
- "//iree/hal/vmla:vmla_driver_module",
+ "//iree/hal/vmla/registration",
"//iree/modules/hal",
"//iree/testing:gtest",
"//iree/testing:gtest_main",
"//iree/vm",
"//iree/vm:bytecode_module",
- "//iree/vm:ref_cc",
+ "//iree/vm:cc",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
@@ -61,7 +61,7 @@
"//iree/hal:api",
"//iree/modules/hal",
"//iree/vm",
- "//iree/vm:native_module_cc",
+ "//iree/vm:cc",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/types:span",
],
diff --git a/iree/modules/tensorlist/CMakeLists.txt b/iree/modules/tensorlist/CMakeLists.txt
index 2eb6691..1092601 100644
--- a/iree/modules/tensorlist/CMakeLists.txt
+++ b/iree/modules/tensorlist/CMakeLists.txt
@@ -41,13 +41,13 @@
iree::base::api
iree::base::logging
iree::hal::api
- iree::hal::vmla::vmla_driver_module
+ iree::hal::vmla::registration
iree::modules::hal
iree::testing::gtest
iree::testing::gtest_main
iree::vm
iree::vm::bytecode_module
- iree::vm::ref_cc
+ iree::vm::cc
)
iree_cc_library(
@@ -66,6 +66,6 @@
iree::hal::api
iree::modules::hal
iree::vm
- iree::vm::native_module_cc
+ iree::vm::cc
PUBLIC
)
diff --git a/iree/modules/tensorlist/tensorlist_test.cc b/iree/modules/tensorlist/tensorlist_test.cc
index 6eab915..3293d57 100644
--- a/iree/modules/tensorlist/tensorlist_test.cc
+++ b/iree/modules/tensorlist/tensorlist_test.cc
@@ -21,6 +21,7 @@
#include "iree/base/api.h"
#include "iree/base/logging.h"
#include "iree/hal/api.h"
+#include "iree/hal/vmla/registration/driver_module.h"
#include "iree/modules/hal/hal_module.h"
#include "iree/modules/tensorlist/native_module.h"
#include "iree/modules/tensorlist/tensorlist_test_module.h"
@@ -36,6 +37,10 @@
class TensorListModulesTest : public ::testing::Test {
protected:
+ static void SetUpTestSuite() {
+ IREE_CHECK_OK(iree_hal_vmla_driver_module_register());
+ }
+
virtual void SetUp() {
IREE_CHECK_OK(iree_vm_instance_create(iree_allocator_system(), &instance_));
diff --git a/iree/samples/custom_modules/BUILD b/iree/samples/custom_modules/BUILD
index b7c91e6..c980e20 100644
--- a/iree/samples/custom_modules/BUILD
+++ b/iree/samples/custom_modules/BUILD
@@ -46,13 +46,13 @@
"//iree/base:api",
"//iree/base:logging",
"//iree/hal:api",
- "//iree/hal/vmla:vmla_driver_module",
+ "//iree/hal/vmla/registration",
"//iree/modules/hal",
"//iree/testing:gtest",
"//iree/testing:gtest_main",
"//iree/vm",
"//iree/vm:bytecode_module",
- "//iree/vm:ref_cc",
+ "//iree/vm:cc",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/strings",
],
@@ -67,6 +67,6 @@
"//iree/hal:api",
"//iree/modules/hal",
"//iree/vm",
- "//iree/vm:native_module_cc",
+ "//iree/vm:cc",
],
)
diff --git a/iree/samples/custom_modules/CMakeLists.txt b/iree/samples/custom_modules/CMakeLists.txt
index 4a28aec..f7942f7 100644
--- a/iree/samples/custom_modules/CMakeLists.txt
+++ b/iree/samples/custom_modules/CMakeLists.txt
@@ -45,13 +45,13 @@
iree::base::api
iree::base::logging
iree::hal::api
- iree::hal::vmla::vmla_driver_module
+ iree::hal::vmla::registration
iree::modules::hal
iree::testing::gtest
iree::testing::gtest_main
iree::vm
iree::vm::bytecode_module
- iree::vm::ref_cc
+ iree::vm::cc
)
iree_cc_library(
@@ -66,6 +66,6 @@
iree::hal::api
iree::modules::hal
iree::vm
- iree::vm::native_module_cc
+ iree::vm::cc
PUBLIC
)
diff --git a/iree/samples/custom_modules/custom_modules_test.cc b/iree/samples/custom_modules/custom_modules_test.cc
index e975a01..18ebd30 100644
--- a/iree/samples/custom_modules/custom_modules_test.cc
+++ b/iree/samples/custom_modules/custom_modules_test.cc
@@ -19,6 +19,7 @@
#include "iree/base/api.h"
#include "iree/base/logging.h"
#include "iree/hal/api.h"
+#include "iree/hal/vmla/registration/driver_module.h"
#include "iree/modules/hal/hal_module.h"
#include "iree/samples/custom_modules/custom_modules_test_module.h"
#include "iree/samples/custom_modules/native_module.h"
@@ -32,6 +33,10 @@
class CustomModulesTest : public ::testing::Test {
protected:
+ static void SetUpTestSuite() {
+ IREE_CHECK_OK(iree_hal_vmla_driver_module_register());
+ }
+
virtual void SetUp() {
IREE_CHECK_OK(iree_vm_instance_create(iree_allocator_system(), &instance_));
diff --git a/iree/samples/custom_modules/dialect/BUILD b/iree/samples/custom_modules/dialect/BUILD
index c5a49ea..fa0bc78 100644
--- a/iree/samples/custom_modules/dialect/BUILD
+++ b/iree/samples/custom_modules/dialect/BUILD
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-load("//iree:build_defs.oss.bzl", iree_cc_binary = "cc_binary")
load("//build_tools/embed_data:build_defs.bzl", "cc_embed_data")
load("//build_tools/bazel:tblgen.bzl", "gentbl")
@@ -89,7 +88,7 @@
],
)
-iree_cc_binary(
+cc_binary(
name = "custom-opt",
srcs = ["custom-opt-main.cc"],
deps = [
@@ -103,7 +102,7 @@
],
)
-iree_cc_binary(
+cc_binary(
name = "custom-translate",
srcs = ["custom-translate-main.cc"],
deps = [
diff --git a/iree/samples/emitc_modules/CMakeLists.txt b/iree/samples/emitc_modules/CMakeLists.txt
index be48f3f..c699032 100644
--- a/iree/samples/emitc_modules/CMakeLists.txt
+++ b/iree/samples/emitc_modules/CMakeLists.txt
@@ -36,12 +36,7 @@
DEPS
::add_module_cc
iree::base::api
- iree::vm::c_funcs
- iree::vm::context
- iree::vm::instance
- iree::vm::native_module
- iree::vm::ref
- iree::vm::stack
+ iree::vm
PUBLIC
)
@@ -56,10 +51,7 @@
iree::base::status
iree::testing::gtest
iree::testing::gtest_main
- iree::vm::context
- iree::vm::instance
- iree::vm::invocation
- iree::vm::list
- iree::vm::ref_cc
+ iree::vm
+ iree::vm::cc
)
endif()
diff --git a/iree/samples/emitc_modules/add_module_test.cc b/iree/samples/emitc_modules/add_module_test.cc
index 619fefa..bb739ae 100644
--- a/iree/samples/emitc_modules/add_module_test.cc
+++ b/iree/samples/emitc_modules/add_module_test.cc
@@ -17,10 +17,7 @@
#include "iree/base/status.h"
#include "iree/testing/gtest.h"
#include "iree/testing/status_matchers.h"
-#include "iree/vm/context.h"
-#include "iree/vm/instance.h"
-#include "iree/vm/invocation.h"
-#include "iree/vm/list.h"
+#include "iree/vm/api.h"
#include "iree/vm/ref_cc.h"
namespace iree {
diff --git a/iree/samples/emitc_modules/add_module_test.h b/iree/samples/emitc_modules/add_module_test.h
index 07f61c6..c6223d2 100644
--- a/iree/samples/emitc_modules/add_module_test.h
+++ b/iree/samples/emitc_modules/add_module_test.h
@@ -12,11 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "iree/vm/context.h"
-#include "iree/vm/instance.h"
-#include "iree/vm/native_module.h"
-#include "iree/vm/ref.h"
-#include "iree/vm/stack.h"
+#include "iree/vm/api.h"
// This would be generated together with the functions in the header
#include "iree/samples/emitc_modules/add_module.module"
diff --git a/iree/samples/simple_embedding/BUILD b/iree/samples/simple_embedding/BUILD
index ad2e027..db314ba 100644
--- a/iree/samples/simple_embedding/BUILD
+++ b/iree/samples/simple_embedding/BUILD
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-load("//iree:build_defs.oss.bzl", "PLATFORM_VULKAN_TEST_DEPS", "iree_cmake_extra_content")
load("//iree/tools:compilation.bzl", "iree_bytecode_module")
package(
@@ -21,15 +20,6 @@
licenses = ["notice"], # Apache 2.0
)
-iree_cmake_extra_content(
- content = """
-if(NOT ${IREE_TARGET_BACKEND_VMLA} OR NOT ${IREE_TARGET_BACKEND_VULKAN-SPIRV}
- OR NOT ${IREE_HAL_DRIVER_VMLA} OR NOT ${IREE_HAL_DRIVER_VULKAN})
- return()
-endif()
-""",
-)
-
iree_bytecode_module(
name = "simple_embedding_test_bytecode_module",
src = "simple_embedding_test.mlir",
@@ -48,25 +38,20 @@
# For AddressSanitizer when using Vulkan + a local Nvidia GPU
"//iree/tools:sanitizer_suppressions.txt",
],
- tags = ["driver=vulkan"],
deps = [
":simple_embedding_test_bytecode_module_cc",
"//iree/base:api",
"//iree/base:logging",
"//iree/hal:api",
+ "//iree/hal/drivers",
"//iree/modules/hal",
"//iree/testing:gtest",
"//iree/testing:gtest_main",
"//iree/vm",
"//iree/vm:bytecode_module",
- "//iree/vm:ref_cc",
+ "//iree/vm:cc",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
-
- # These are the drivers we support running with and can produce
- # executables for from the source MLIR.
- "//iree/hal/vmla:vmla_driver_module", # build-cleaner: keep
- "//iree/hal/vulkan:vulkan_driver_module", # build-cleaner: keep
- ] + PLATFORM_VULKAN_TEST_DEPS,
+ ],
)
diff --git a/iree/samples/simple_embedding/CMakeLists.txt b/iree/samples/simple_embedding/CMakeLists.txt
index 1ceb8a0..fb0b717 100644
--- a/iree/samples/simple_embedding/CMakeLists.txt
+++ b/iree/samples/simple_embedding/CMakeLists.txt
@@ -12,11 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-if(NOT ${IREE_TARGET_BACKEND_VMLA} OR NOT ${IREE_TARGET_BACKEND_VULKAN-SPIRV}
- OR NOT ${IREE_HAL_DRIVER_VMLA} OR NOT ${IREE_HAL_DRIVER_VULKAN})
- return()
-endif()
-
iree_add_all_subdirs()
iree_bytecode_module(
@@ -48,14 +43,11 @@
iree::base::api
iree::base::logging
iree::hal::api
- iree::hal::vmla::vmla_driver_module
- iree::hal::vulkan::vulkan_driver_module
+ iree::hal::drivers
iree::modules::hal
iree::testing::gtest
iree::testing::gtest_main
iree::vm
iree::vm::bytecode_module
- iree::vm::ref_cc
- LABELS
- "driver=vulkan"
+ iree::vm::cc
)
diff --git a/iree/samples/simple_embedding/simple_embedding_test.cc b/iree/samples/simple_embedding/simple_embedding_test.cc
index aaea632..bc2c8cd 100644
--- a/iree/samples/simple_embedding/simple_embedding_test.cc
+++ b/iree/samples/simple_embedding/simple_embedding_test.cc
@@ -19,6 +19,7 @@
#include "iree/base/api.h"
#include "iree/base/logging.h"
#include "iree/hal/api.h"
+#include "iree/hal/drivers/init.h"
#include "iree/modules/hal/hal_module.h"
#include "iree/testing/gtest.h"
#include "iree/testing/status_matchers.h"
@@ -44,15 +45,25 @@
// Builds a list of tests to run based on the linked in driver modules.
std::vector<TestParams> GetAvailableDriverTestParams() {
+ IREE_CHECK_OK(iree_hal_register_all_available_drivers());
+
std::vector<TestParams> all_test_params;
iree_string_view_t* driver_names = nullptr;
iree_host_size_t driver_count = 0;
IREE_CHECK_OK(iree_hal_driver_registry_query_available_drivers(
iree_allocator_system(), &driver_names, &driver_count));
- for (int i = 0; i < driver_count; ++i) {
+ for (iree_host_size_t i = 0; i < driver_count; ++i) {
TestParams test_params;
test_params.driver_name =
std::string(driver_names[i].data, driver_names[i].size);
+
+ // TODO(#3843): this whole file stopped being useful a long time ago as a
+ // "simple" embedded test. This is a hack to work around its bustedness.
+ if (test_params.driver_name == "dylib" ||
+ test_params.driver_name == "llvm") {
+ continue;
+ }
+
all_test_params.push_back(std::move(test_params));
}
iree_allocator_free(iree_allocator_system(), driver_names);
diff --git a/iree/samples/vulkan/BUILD b/iree/samples/vulkan/BUILD
index 89c927f..f005ede 100644
--- a/iree/samples/vulkan/BUILD
+++ b/iree/samples/vulkan/BUILD
@@ -48,11 +48,12 @@
deps = [
":simple_mul_bytecode_module_cc",
"//iree/base:main",
+ "//iree/hal/vulkan/registration",
"//iree/modules/hal",
"//iree/testing/vulkan:vulkan_gui_util",
"//iree/vm",
"//iree/vm:bytecode_module",
- "//iree/vm:ref_cc",
+ "//iree/vm:cc",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/types:span",
],
diff --git a/iree/samples/vulkan/CMakeLists.txt b/iree/samples/vulkan/CMakeLists.txt
index c3de407..4c7f665 100644
--- a/iree/samples/vulkan/CMakeLists.txt
+++ b/iree/samples/vulkan/CMakeLists.txt
@@ -39,12 +39,6 @@
return()
endif()
-if(${CMAKE_HOST_SYSTEM_NAME} STREQUAL "Windows")
- set(_GUI_LINKOPTS "-SUBSYSTEM:WINDOWS")
-else()
- set(_GUI_LINKOPTS "")
-endif()
-
iree_cc_binary(
NAME
vulkan_inference_gui
@@ -53,12 +47,13 @@
DEPS
absl::base
iree::base::main
+ iree::hal::vulkan::registration
iree::modules::hal
iree::samples::vulkan::simple_mul_bytecode_module_cc
iree::testing::vulkan::vulkan_gui_util
iree::vm
iree::vm::bytecode_module
- iree::vm::ref_cc
+ iree::vm::cc
LINKOPTS
- "${_GUI_LINKOPTS}"
+ "${IREE_TARGET_GUI_LINKOPTS}"
)
diff --git a/iree/samples/vulkan/vulkan_inference_gui.cc b/iree/samples/vulkan/vulkan_inference_gui.cc
index 6183580..c1fb686 100644
--- a/iree/samples/vulkan/vulkan_inference_gui.cc
+++ b/iree/samples/vulkan/vulkan_inference_gui.cc
@@ -23,6 +23,7 @@
#include "iree/base/api.h"
#include "iree/hal/api.h"
#include "iree/hal/vulkan/api.h"
+#include "iree/hal/vulkan/registration/driver_module.h"
#include "iree/modules/hal/hal_module.h"
#include "iree/vm/api.h"
#include "iree/vm/bytecode_module.h"
@@ -93,10 +94,11 @@
std::vector<const char*> layers = GetInstanceLayers(iree_vulkan_features);
std::vector<const char*> extensions =
GetInstanceExtensions(window, iree_vulkan_features);
- SetupVulkan(iree_vulkan_features, layers.data(), layers.size(),
- extensions.data(), extensions.size(), g_Allocator, &g_Instance,
- &g_QueueFamily, &g_PhysicalDevice, &g_Queue, &g_Device,
- &g_DescriptorPool);
+ SetupVulkan(iree_vulkan_features, layers.data(),
+ static_cast<uint32_t>(layers.size()), extensions.data(),
+ static_cast<uint32_t>(extensions.size()), g_Allocator,
+ &g_Instance, &g_QueueFamily, &g_PhysicalDevice, &g_Queue,
+ &g_Device, &g_DescriptorPool);
// Create Window Surface
VkSurfaceKHR surface;
@@ -174,9 +176,6 @@
// --------------------------------------------------------------------------
// Setup IREE.
- // This call to |iree_api_init| is not technically required, but it is
- // included for completeness.
- IREE_CHECK_OK(iree_api_init(&argc, &argv));
// Check API version.
iree_api_version_t actual_version;
@@ -188,7 +187,8 @@
IREE_LOG(FATAL) << "Unsupported runtime API version " << actual_version;
}
- // Register HAL module types.
+ // Register HAL drivers and VM module types.
+ IREE_CHECK_OK(iree_hal_vulkan_driver_module_register());
IREE_CHECK_OK(iree_hal_module_register_types());
// Create a runtime Instance.
diff --git a/iree/schemas/BUILD b/iree/schemas/BUILD
index abe8fa7..98fd357 100644
--- a/iree/schemas/BUILD
+++ b/iree/schemas/BUILD
@@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-load("//iree:build_defs.oss.bzl", "FLATBUFFER_SUPPORTS_REFLECTIONS", "iree_build_test", "iree_flatbuffer_cc_library")
+load("//iree:build_defs.oss.bzl", "iree_build_test")
load("//build_tools/bazel:iree_flatcc.bzl", "iree_flatbuffer_c_library")
-load("//build_tools/embed_data:build_defs.bzl", "cc_embed_data")
package(
default_visibility = ["//visibility:public"],
@@ -22,89 +21,57 @@
licenses = ["notice"], # Apache 2.0
)
-FLATC_ARGS = [
- # Preserve workspace-relative include paths in generated code.
- "--keep-prefix",
- # Use C++11 'enum class' for enums.
- "--scoped-enums",
- # Include reflection tables used for dumping debug representations.
- "--reflect-names",
- # Generate FooT types for unpack/pack support. Note that this should only
- # be used in tooling as the code size/runtime overhead is non-trivial.
- "--gen-object-api",
+FLATCC_ARGS = [
+ "--reader",
+ "--builder",
+ "--verifier",
+ "--json",
]
iree_flatbuffer_c_library(
name = "bytecode_module_def_c_fbs",
srcs = ["bytecode_module_def.fbs"],
- flatcc_args = [
- "--reader",
- "--builder",
- "--verifier",
- ],
+ flatcc_args = FLATCC_ARGS,
)
-iree_flatbuffer_cc_library(
- name = "bytecode_module_def_cc_fbs",
- srcs = ["bytecode_module_def.fbs"],
- flatc_args = FLATC_ARGS,
-)
-
-iree_flatbuffer_cc_library(
- name = "dylib_executable_def_cc_fbs",
+iree_flatbuffer_c_library(
+ name = "dylib_executable_def_c_fbs",
srcs = ["dylib_executable_def.fbs"],
- flatc_args = FLATC_ARGS,
+ flatcc_args = FLATCC_ARGS,
)
-iree_flatbuffer_cc_library(
- name = "llvmir_executable_def_cc_fbs",
+iree_flatbuffer_c_library(
+ name = "llvmir_executable_def_c_fbs",
srcs = ["llvmir_executable_def.fbs"],
- flatc_args = FLATC_ARGS,
+ flatcc_args = FLATCC_ARGS,
)
-iree_flatbuffer_cc_library(
- name = "metal_executable_def_cc_fbs",
+iree_flatbuffer_c_library(
+ name = "metal_executable_def_c_fbs",
srcs = ["metal_executable_def.fbs"],
- flatc_args = FLATC_ARGS,
+ flatcc_args = FLATCC_ARGS,
)
-iree_flatbuffer_cc_library(
- name = "spirv_executable_def_cc_fbs",
+iree_flatbuffer_c_library(
+ name = "spirv_executable_def_c_fbs",
srcs = ["spirv_executable_def.fbs"],
- flatc_args = FLATC_ARGS,
+ flatcc_args = FLATCC_ARGS,
)
-iree_flatbuffer_cc_library(
- name = "vmla_executable_def_cc_fbs",
+iree_flatbuffer_c_library(
+ name = "vmla_executable_def_c_fbs",
srcs = ["vmla_executable_def.fbs"],
- flatc_args = FLATC_ARGS,
+ flatcc_args = FLATCC_ARGS,
)
iree_build_test(
name = "schema_build_test",
targets = [
- ":bytecode_module_def_cc_fbs",
- ":dylib_executable_def_cc_fbs",
- ":llvmir_executable_def_cc_fbs",
- ":metal_executable_def_cc_fbs",
- ":spirv_executable_def_cc_fbs",
- ":vmla_executable_def_cc_fbs",
+ ":bytecode_module_def_c_fbs",
+ ":dylib_executable_def_c_fbs",
+ ":llvmir_executable_def_c_fbs",
+ ":metal_executable_def_c_fbs",
+ ":spirv_executable_def_c_fbs",
+ ":vmla_executable_def_c_fbs",
],
)
-
-REFLECTION_SRCS = [] if not FLATBUFFER_SUPPORTS_REFLECTIONS else [
- "bytecode_module_def.bfbs",
- "dylib_executable_def.bfbs",
- "llvmir_executable_def.bfbs",
- "metal_executable_def.bfbs",
- "spirv_executable_def.bfbs",
- "vmla_executable_def.bfbs",
-]
-
-cc_embed_data(
- name = "reflection_data",
- srcs = REFLECTION_SRCS,
- cc_file_output = "reflection_data.cc",
- cpp_namespace = "iree::schemas",
- h_file_output = "reflection_data.h",
-)
diff --git a/iree/schemas/CMakeLists.txt b/iree/schemas/CMakeLists.txt
index e62b2e2..68d435c 100644
--- a/iree/schemas/CMakeLists.txt
+++ b/iree/schemas/CMakeLists.txt
@@ -23,95 +23,71 @@
"--reader"
"--builder"
"--verifier"
+ "--json"
PUBLIC
)
-flatbuffer_cc_library(
+flatbuffer_c_library(
NAME
- bytecode_module_def_cc_fbs
- SRCS
- "bytecode_module_def.fbs"
- FLATC_ARGS
- "--keep-prefix"
- "--scoped-enums"
- "--reflect-names"
- "--gen-object-api"
- PUBLIC
-)
-
-flatbuffer_cc_library(
- NAME
- dylib_executable_def_cc_fbs
+ dylib_executable_def_c_fbs
SRCS
"dylib_executable_def.fbs"
- FLATC_ARGS
- "--keep-prefix"
- "--scoped-enums"
- "--reflect-names"
- "--gen-object-api"
+ FLATCC_ARGS
+ "--reader"
+ "--builder"
+ "--verifier"
+ "--json"
PUBLIC
)
-flatbuffer_cc_library(
+flatbuffer_c_library(
NAME
- llvmir_executable_def_cc_fbs
+ llvmir_executable_def_c_fbs
SRCS
"llvmir_executable_def.fbs"
- FLATC_ARGS
- "--keep-prefix"
- "--scoped-enums"
- "--reflect-names"
- "--gen-object-api"
+ FLATCC_ARGS
+ "--reader"
+ "--builder"
+ "--verifier"
+ "--json"
PUBLIC
)
-flatbuffer_cc_library(
+flatbuffer_c_library(
NAME
- metal_executable_def_cc_fbs
+ metal_executable_def_c_fbs
SRCS
"metal_executable_def.fbs"
- FLATC_ARGS
- "--keep-prefix"
- "--scoped-enums"
- "--reflect-names"
- "--gen-object-api"
+ FLATCC_ARGS
+ "--reader"
+ "--builder"
+ "--verifier"
+ "--json"
PUBLIC
)
-flatbuffer_cc_library(
+flatbuffer_c_library(
NAME
- spirv_executable_def_cc_fbs
+ spirv_executable_def_c_fbs
SRCS
"spirv_executable_def.fbs"
- FLATC_ARGS
- "--keep-prefix"
- "--scoped-enums"
- "--reflect-names"
- "--gen-object-api"
+ FLATCC_ARGS
+ "--reader"
+ "--builder"
+ "--verifier"
+ "--json"
PUBLIC
)
-flatbuffer_cc_library(
+flatbuffer_c_library(
NAME
- vmla_executable_def_cc_fbs
+ vmla_executable_def_c_fbs
SRCS
"vmla_executable_def.fbs"
- FLATC_ARGS
- "--keep-prefix"
- "--scoped-enums"
- "--reflect-names"
- "--gen-object-api"
- PUBLIC
-)
-
-iree_cc_embed_data(
- NAME
- reflection_data
- CC_FILE_OUTPUT
- "reflection_data.cc"
- H_FILE_OUTPUT
- "reflection_data.h"
- CPP_NAMESPACE
- "iree::schemas"
+ FLATCC_ARGS
+ "--reader"
+ "--builder"
+ "--verifier"
+ "--json"
PUBLIC
)
diff --git a/iree/schemas/dylib_executable_def.fbs b/iree/schemas/dylib_executable_def.fbs
index 5690096..8c30493 100644
--- a/iree/schemas/dylib_executable_def.fbs
+++ b/iree/schemas/dylib_executable_def.fbs
@@ -26,10 +26,10 @@
// An embedded (as opposed to external) dynamic library file.
// TODO(scotttodd): List of embedded files?
// TODO(scotttodd): Format of files, platform information (x86/arm/etc.)
- library_embedded:[byte];
+ library_embedded:[ubyte];
debug_database_filename:string;
- debug_database_embedded:[byte];
+ debug_database_embedded:[ubyte];
// TODO(scotttodd): Relative file path from this flatbuffer file
}
diff --git a/iree/schemas/llvmir_executable_def.fbs b/iree/schemas/llvmir_executable_def.fbs
index 77c0b61..5897900 100644
--- a/iree/schemas/llvmir_executable_def.fbs
+++ b/iree/schemas/llvmir_executable_def.fbs
@@ -25,7 +25,7 @@
// A map of entry points to string names with the same order as in the executable op.
entry_points:[string];
// A serialized llvm::Module object.
- llvmir_module:[byte];
+ bitcode_module:[ubyte];
}
root_type LLVMIRExecutableDef;
diff --git a/iree/schemas/metal_executable_def.fbs b/iree/schemas/metal_executable_def.fbs
index 7ad4b29..50f0eff 100644
--- a/iree/schemas/metal_executable_def.fbs
+++ b/iree/schemas/metal_executable_def.fbs
@@ -29,10 +29,6 @@
// This information is used to create MTLLibrary, MTLFunction and pipeline
// state objects.
table MetalExecutableDef {
- // Reserved implementation-specific value that can be passed from compiler
- // to runtime.
- tag:string;
-
// A map of entry point ordinals to string names as used in the shader
// library.
entry_points:[string];
diff --git a/iree/schemas/spirv_executable_def.fbs b/iree/schemas/spirv_executable_def.fbs
index 71cf330..250ea84 100644
--- a/iree/schemas/spirv_executable_def.fbs
+++ b/iree/schemas/spirv_executable_def.fbs
@@ -46,10 +46,6 @@
// This information is used to create the VkShaderModule, VkPipelineLayout, and
// any required VkDescriptorSetLayouts.
table SpirVExecutableDef {
- // Reserved implementation-specific value that can be passed from compiler to
- // runtime.
- tag:string;
-
// A map of entry point ordinals to string names as used in the shader module.
entry_points:[string];
diff --git a/iree/schemas/vmla_executable_def.fbs b/iree/schemas/vmla_executable_def.fbs
index 455b506..e85972c 100644
--- a/iree/schemas/vmla_executable_def.fbs
+++ b/iree/schemas/vmla_executable_def.fbs
@@ -27,7 +27,7 @@
// We embed the entire flatbuffer contents opaquely to allow for easier
// manipulation of the files (such that we can just slice out the bytes and
// dump them to a file, etc).
- bytecode_module:[byte];
+ bytecode_module:[ubyte];
}
root_type VMLAExecutableDef;
diff --git a/iree/test/e2e/hackability/BUILD b/iree/test/e2e/hackability/BUILD
new file mode 100644
index 0000000..5964b42
--- /dev/null
+++ b/iree/test/e2e/hackability/BUILD
@@ -0,0 +1,35 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Tests for end-to-end IREE support of specific features to prevent regression.
+# These should focus on support by IREE itself, not for issues with specific runner tools. Place
+# those tests in https://github.com/google/iree/tree/main/iree/tools/test/
+
+load("//iree:lit_test.bzl", "iree_lit_test_suite")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_lit_test_suite(
+ name = "lit",
+ srcs = glob(["*.mlir"]),
+ data = [
+ "//iree/tools:IreeFileCheck",
+ "//iree/tools:iree-run-mlir",
+ ],
+ tags = ["hostonly"],
+)
diff --git a/iree/test/e2e/hackability/CMakeLists.txt b/iree/test/e2e/hackability/CMakeLists.txt
new file mode 100644
index 0000000..e2b4157
--- /dev/null
+++ b/iree/test/e2e/hackability/CMakeLists.txt
@@ -0,0 +1,28 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_add_all_subdirs()
+
+file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir)
+iree_lit_test_suite(
+ NAME
+ lit
+ SRCS
+ "${_GLOB_X_MLIR}"
+ DATA
+ iree::tools::IreeFileCheck
+ iree::tools::iree-run-mlir
+ LABELS
+ "hostonly"
+)
diff --git a/iree/test/e2e/hackability/flow_partitioned.mlir b/iree/test/e2e/hackability/flow_partitioned.mlir
new file mode 100644
index 0000000..19ba210
--- /dev/null
+++ b/iree/test/e2e/hackability/flow_partitioned.mlir
@@ -0,0 +1,21 @@
+// RUN: iree-run-mlir -export-all -iree-hal-target-backends=vmla %s | IreeFileCheck %s
+// RUN: [[ $IREE_LLVMJIT_DISABLE == 1 ]] || (iree-run-mlir -export-all -iree-hal-target-backends=llvm-ir %s | IreeFileCheck %s)
+
+flow.executable @ex0 {
+ flow.dispatch.entry @dispatch0 attributes {workload = 4 : index}
+ module {
+ func @dispatch0(%arg0: tensor<4xf32>) -> tensor<4xf32> {
+ %0 = mhlo.add %arg0, %arg0 : tensor<4xf32>
+ return %0 : tensor<4xf32>
+ }
+ }
+}
+
+// CHECK-LABEL: EXEC @staticShapedFn
+func @staticShapedFn() -> tensor<4xf32> {
+ %input = iree.unfoldable_constant dense<[-1.0, 2.0, -3.0, 4.0]> : tensor<4xf32>
+ %workload = constant 4 : index
+ %0 = flow.dispatch @ex0::@dispatch0[%workload : index](%input) : (tensor<4xf32>) -> tensor<4xf32>
+ return %0 : tensor<4xf32>
+}
+// CHECK: 4xf32=-2 4 -6 8
diff --git a/iree/testing/BUILD b/iree/testing/BUILD
index 0ad9324..f31a4d5 100644
--- a/iree/testing/BUILD
+++ b/iree/testing/BUILD
@@ -25,7 +25,7 @@
testonly = True,
srcs = ["benchmark_main.cc"],
deps = [
- "//iree/base:init",
+ "//iree/base:flags",
"@com_google_benchmark//:benchmark",
],
)
@@ -52,7 +52,7 @@
tags = ["keep_dep"],
deps = [
":gtest",
- "//iree/base:init",
+ "//iree/base:flags",
"@com_google_googletest//:gtest",
],
)
diff --git a/iree/testing/CMakeLists.txt b/iree/testing/CMakeLists.txt
index bdb2c6f..c790757 100644
--- a/iree/testing/CMakeLists.txt
+++ b/iree/testing/CMakeLists.txt
@@ -21,7 +21,7 @@
"benchmark_main.cc"
DEPS
benchmark
- iree::base::init
+ iree::base::flags
TESTONLY
PUBLIC
)
@@ -51,7 +51,7 @@
::gtest
gmock
gtest
- iree::base::init
+ iree::base::flags
TESTONLY
PUBLIC
)
diff --git a/iree/testing/benchmark_main.cc b/iree/testing/benchmark_main.cc
index 50df16d..690f456 100644
--- a/iree/testing/benchmark_main.cc
+++ b/iree/testing/benchmark_main.cc
@@ -13,13 +13,13 @@
// limitations under the License.
#include "benchmark/benchmark.h"
-#include "iree/base/init.h"
+#include "iree/base/flags.h"
namespace iree {
extern "C" int main(int argc, char** argv) {
::benchmark::Initialize(&argc, argv);
- iree::InitializeEnvironment(&argc, &argv);
+ iree_flags_parse_checked(&argc, &argv);
::benchmark::RunSpecifiedBenchmarks();
return 0;
}
diff --git a/iree/testing/gtest_main.cc b/iree/testing/gtest_main.cc
index 6fc7e78..d665aa8 100644
--- a/iree/testing/gtest_main.cc
+++ b/iree/testing/gtest_main.cc
@@ -13,10 +13,10 @@
// limitations under the License.
#include "gtest/gtest.h"
-#include "iree/base/init.h"
+#include "iree/base/flags.h"
extern "C" int main(int argc, char** argv) {
- ::iree::InitializeEnvironment(&argc, &argv);
+ iree_flags_parse_checked(&argc, &argv);
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
diff --git a/iree/testing/vulkan/BUILD b/iree/testing/vulkan/BUILD
index 4e235aa..3f40467 100644
--- a/iree/testing/vulkan/BUILD
+++ b/iree/testing/vulkan/BUILD
@@ -12,11 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-load(
- "//iree:build_defs.oss.bzl",
- "PLATFORM_VULKAN_DEPS",
-)
-
package(
default_visibility = ["//visibility:public"],
features = ["layering_check"],
@@ -37,7 +32,7 @@
"//iree/hal/vulkan:api",
"@dear_imgui",
"@dear_imgui//:imgui_sdl_vulkan",
- "@iree_vulkan_headers//:vulkan_headers_no_prototypes",
+ "@iree_vulkan_headers//:vulkan_headers",
"@sdl2//:SDL2",
"@vulkan_sdk//:sdk",
],
@@ -58,17 +53,17 @@
],
deps = [
":vulkan_gui_util",
- "//iree/base:init",
+ "//iree/base:flags",
"//iree/base:main",
"//iree/base:status",
"//iree/base:tracing",
+ "//iree/hal/vulkan/registration",
"//iree/modules/hal",
- "//iree/hal/vulkan:vulkan_driver_module",
"//iree/tools/utils:vm_util",
"//iree/vm",
"//iree/vm:bytecode_module",
- "//iree/vm:ref_cc",
+ "//iree/vm:cc",
"@com_google_absl//absl/flags:flag",
"@sdl2//:SDL2",
- ] + PLATFORM_VULKAN_DEPS,
+ ],
)
diff --git a/iree/testing/vulkan/CMakeLists.txt b/iree/testing/vulkan/CMakeLists.txt
index 6ceb5dc..a4e6a03 100644
--- a/iree/testing/vulkan/CMakeLists.txt
+++ b/iree/testing/vulkan/CMakeLists.txt
@@ -45,12 +45,6 @@
Vulkan::Vulkan
)
-if(${CMAKE_HOST_SYSTEM_NAME} STREQUAL "Windows")
- set(_GUI_LINKOPTS "-SUBSYSTEM:WINDOWS")
-else()
- set(_GUI_LINKOPTS "")
-endif()
-
iree_cc_binary(
NAME
iree-run-module-vulkan-gui
@@ -60,15 +54,15 @@
::vulkan_gui_util
absl::flags
iree::base::file_io
- iree::base::init
+ iree::base::flags
iree::base::main
iree::base::status
iree::base::tracing
- iree::hal::vulkan::vulkan_driver_module
+ iree::hal::vulkan::registration
iree::modules::hal
iree::tools::utils::vm_util
iree::vm
iree::vm::bytecode_module
LINKOPTS
- "${_GUI_LINKOPTS}"
+ "${IREE_TARGET_GUI_LINKOPTS}"
)
diff --git a/iree/testing/vulkan/iree-run-module-vulkan-gui-main.cc b/iree/testing/vulkan/iree-run-module-vulkan-gui-main.cc
index 4ee342e..eec2dfe 100644
--- a/iree/testing/vulkan/iree-run-module-vulkan-gui-main.cc
+++ b/iree/testing/vulkan/iree-run-module-vulkan-gui-main.cc
@@ -20,9 +20,10 @@
// Other dependencies (helpers, etc.)
#include "absl/flags/flag.h"
#include "iree/base/file_io.h"
-#include "iree/base/init.h"
+#include "iree/base/flags.h"
#include "iree/base/main.h"
#include "iree/base/status.h"
+#include "iree/hal/vulkan/registration/driver_module.h"
#include "iree/modules/hal/hal_module.h"
#include "iree/tools/utils/vm_util.h"
#include "iree/vm/api.h"
@@ -146,7 +147,8 @@
} // namespace iree
int iree::IreeMain(int argc, char** argv) {
- iree::InitializeEnvironment(&argc, &argv);
+ iree_flags_parse_checked(&argc, &argv);
+ IREE_CHECK_OK(iree_hal_vulkan_driver_module_register());
// --------------------------------------------------------------------------
// Create a window.
@@ -248,9 +250,6 @@
// --------------------------------------------------------------------------
// Setup IREE.
- // This call to |iree_api_init| is not technically required, but it is
- // included for completeness.
- IREE_CHECK_OK(iree_api_init(&argc, &argv));
// Check API version.
iree_api_version_t actual_version;
diff --git a/iree/testing/vulkan/vulkan_gui_util.cc b/iree/testing/vulkan/vulkan_gui_util.cc
index 5bb7c15..4c93bb3 100644
--- a/iree/testing/vulkan/vulkan_gui_util.cc
+++ b/iree/testing/vulkan/vulkan_gui_util.cc
@@ -194,12 +194,13 @@
VkQueueFamilyProperties* queues = (VkQueueFamilyProperties*)malloc(
sizeof(VkQueueFamilyProperties) * count);
vkGetPhysicalDeviceQueueFamilyProperties(*physical_device, &count, queues);
- for (uint32_t i = 0; i < count; i++)
+ for (uint32_t i = 0; i < count; i++) {
if (queues[i].queueFlags &
(VK_QUEUE_GRAPHICS_BIT | VK_QUEUE_COMPUTE_BIT)) {
*queue_family_index = i;
break;
}
+ }
free(queues);
IM_ASSERT(*queue_family_index != (uint32_t)-1);
}
@@ -218,7 +219,8 @@
create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
create_info.queueCreateInfoCount = 1;
create_info.pQueueCreateInfos = &queue_info;
- create_info.enabledExtensionCount = device_extensions.size();
+ create_info.enabledExtensionCount =
+ static_cast<uint32_t>(device_extensions.size());
create_info.ppEnabledExtensionNames = device_extensions.data();
err = vkCreateDevice(*physical_device, &create_info, allocator, device);
check_vk_result(err);
diff --git a/iree/testing/vulkan/vulkan_gui_util.h b/iree/testing/vulkan/vulkan_gui_util.h
index ea097cd..c43cc5d 100644
--- a/iree/testing/vulkan/vulkan_gui_util.h
+++ b/iree/testing/vulkan/vulkan_gui_util.h
@@ -12,11 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// IREE's Vulkan HAL is built with VK_NO_PROTOTYPES so Vulkan can be loaded
-// dynamically. Using utilities defined in this header means to link against
-// the Vulkan SDK statically, so we want prototypes to be included.
-#undef VK_NO_PROTOTYPES
-
#include <SDL.h>
#include <SDL_vulkan.h>
#include <vulkan/vulkan.h>
diff --git a/iree/tools/BUILD b/iree/tools/BUILD
index 27fc2fb..fb581b1 100644
--- a/iree/tools/BUILD
+++ b/iree/tools/BUILD
@@ -15,13 +15,6 @@
# Misc tools used to optimize, translate, and evaluate IREE.
# Compiler tooling, like the compiler, is not designed to run on device and is tagged as "hostonly".
-load(
- "//iree:build_defs.oss.bzl",
- "IREE_DRIVER_MODULES",
- "PLATFORM_VULKAN_DEPS",
- iree_cc_binary = "cc_binary",
-)
-
package(
default_visibility = ["//visibility:public"],
features = ["layering_check"],
@@ -33,36 +26,34 @@
"sanitizer_suppressions.txt",
])
-iree_cc_binary(
+cc_binary(
name = "iree-benchmark-module",
testonly = True,
srcs = ["iree-benchmark-module-main.cc"],
deps = [
+ "//iree/base:file_io",
+ "//iree/base:flags",
+ "//iree/base:status",
+ "//iree/base:tracing",
+ "//iree/hal/drivers",
+ "//iree/modules/hal",
+ "//iree/tools/utils:vm_util",
+ "//iree/vm",
+ "//iree/vm:bytecode_module",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/flags:parse",
"@com_google_absl//absl/flags:usage",
"@com_google_absl//absl/strings",
"@com_google_benchmark//:benchmark",
- "//iree/base:init",
- "//iree/base:file_io",
- "//iree/base:status",
- "//iree/base:tracing",
- "//iree/modules/hal",
- "//iree/tools/utils:vm_util",
- "//iree/vm",
- "//iree/vm:bytecode_module",
- ] + PLATFORM_VULKAN_DEPS + IREE_DRIVER_MODULES,
+ ],
)
-iree_cc_binary(
+cc_binary(
name = "iree-dump-module",
srcs = ["iree-dump-module-main.cc"],
deps = [
- "//iree/base:file_io", # build-cleaner: keep
- "//iree/base:flatbuffer_util",
- "//iree/base:init",
- "//iree/schemas:bytecode_module_def_cc_fbs",
- "@com_github_google_flatbuffers//:flatbuffers",
+ "//iree/base:file_mapping",
+ "//iree/schemas:bytecode_module_def_c_fbs",
],
)
@@ -211,7 +202,7 @@
],
)
-iree_cc_binary(
+cc_binary(
name = "iree-opt",
tags = ["hostonly"],
deps = [
@@ -219,17 +210,16 @@
],
)
-iree_cc_binary(
+cc_binary(
name = "iree-run-mlir",
srcs = ["iree-run-mlir-main.cc"],
tags = ["hostonly"],
deps = [
":init_passes_and_dialects",
":init_targets",
- "@com_google_absl//absl/flags:flag",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:span",
"//iree/base:api",
+ "//iree/base:flags",
+ "//iree/base:status",
"//iree/base:tracing",
"//iree/compiler/Dialect/Flow/Transforms",
"//iree/compiler/Dialect/HAL/Transforms",
@@ -237,45 +227,52 @@
"//iree/compiler/Dialect/VM/Target:init_targets",
"//iree/compiler/Dialect/VM/Target/Bytecode",
"//iree/compiler/Dialect/VM/Transforms",
+ "//iree/compiler/Translation:IREEVM",
"//iree/hal:api",
+ "//iree/hal/drivers",
"//iree/modules/hal",
"//iree/tools/utils:vm_util",
"//iree/vm",
"//iree/vm:bytecode_module",
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
- "//iree/base:init",
- "//iree/base:status",
- "//iree/compiler/Translation:IREEVM",
- ] + PLATFORM_VULKAN_DEPS + IREE_DRIVER_MODULES,
+ ],
)
-iree_cc_binary(
+cc_binary(
name = "iree-run-module",
srcs = ["iree-run-module-main.cc"],
deps = [
- "@com_google_absl//absl/flags:flag",
- "@com_google_absl//absl/strings",
"//iree/base:file_io",
- "//iree/base:init",
+ "//iree/base:flags",
"//iree/base:status",
"//iree/base:tracing",
+ "//iree/hal/drivers",
"//iree/modules/hal",
"//iree/tools/utils:vm_util",
"//iree/vm",
"//iree/vm:bytecode_module",
- ] + PLATFORM_VULKAN_DEPS + IREE_DRIVER_MODULES,
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/strings",
+ ],
)
-iree_cc_binary(
+cc_binary(
name = "iree-tblgen",
+ srcs = [
+ "//iree/compiler/Dialect/IREE/Tools:GenSrcs",
+ "//iree/compiler/Dialect/VM/Tools:GenSrcs",
+ ],
tags = ["hostonly"],
deps = [
- "//iree/compiler/Dialect/IREE/Tools",
- "//iree/compiler/Dialect/VM/Tools",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//llvm:TableGen",
"@llvm-project//mlir:MlirTableGenMain",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TableGen",
@@ -305,7 +302,7 @@
],
)
-iree_cc_binary(
+cc_binary(
name = "iree-translate",
tags = ["hostonly"],
deps = [
diff --git a/iree/tools/CMakeLists.txt b/iree/tools/CMakeLists.txt
index 68f5d20..de1c9b9 100644
--- a/iree/tools/CMakeLists.txt
+++ b/iree/tools/CMakeLists.txt
@@ -68,15 +68,15 @@
absl::flags_usage
absl::strings
benchmark
- iree::base::init
+ iree::base::flags
iree::base::file_io
iree::base::status
iree::base::tracing
+ iree::hal::drivers
iree::modules::hal
iree::tools::utils::vm_util
iree::vm
iree::vm::bytecode_module
- ${IREE_HAL_DRIVER_MODULES}
TESTONLY
)
@@ -88,11 +88,9 @@
SRCS
"iree-dump-module-main.cc"
DEPS
- flatbuffers
- iree::base::file_io
- iree::base::flatbuffer_util
- iree::base::init
- iree::schemas::bytecode_module_def_cc_fbs
+ flatcc::runtime
+ iree::base::file_mapping
+ iree::schemas::bytecode_module_def_c_fbs
)
iree_cc_binary(
@@ -106,14 +104,14 @@
absl::flags
absl::strings
iree::base::file_io
- iree::base::init
+ iree::base::flags
iree::base::status
iree::base::tracing
+ iree::hal::drivers
iree::modules::hal
iree::tools::utils::vm_util
iree::vm
iree::vm::bytecode_module
- ${IREE_HAL_DRIVER_MODULES}
)
if(${IREE_ENABLE_MLIR})
@@ -122,13 +120,14 @@
iree-tblgen
SRCS
"${IREE_ROOT_DIR}/third_party/llvm-project/mlir/tools/mlir-tblgen/mlir-tblgen.cpp"
+ "${IREE_SOURCE_DIR}/iree/compiler/Dialect/IREE/Tools/StructAttrGen.cpp"
+ "${IREE_SOURCE_DIR}/iree/compiler/Dialect/VM/Tools/VMOpEncoderGen.cpp"
+ "${IREE_SOURCE_DIR}/iree/compiler/Dialect/VM/Tools/VMOpTableGen.cpp"
DEPS
+ LLVMSupport
+ LLVMTableGen
MLIRSupport
MLIRTableGen
- iree::compiler::Dialect::IREE::Tools
- iree::compiler::Dialect::VM::Tools
- LINKOPTS
- "-lpthread"
HOSTONLY
)
endif()
@@ -328,8 +327,7 @@
absl::span
absl::strings
iree::base::api
- iree::base::init
- iree::base::source_location
+ iree::base::flags
iree::base::status
iree::base::tracing
iree::compiler::Dialect::Flow::Transforms
@@ -340,11 +338,11 @@
iree::compiler::Dialect::VM::Transforms
iree::compiler::Translation::IREEVM
iree::hal::api
+ iree::hal::drivers
iree::modules::hal
iree::tools::utils::vm_util
iree::vm
iree::vm::bytecode_module
- ${IREE_HAL_DRIVER_MODULES}
HOSTONLY
)
endif(${IREE_BUILD_COMPILER})
diff --git a/iree/tools/android/run_module_app/CMakeLists.txt b/iree/tools/android/run_module_app/CMakeLists.txt
index 42c2f23..2c90920 100644
--- a/iree/tools/android/run_module_app/CMakeLists.txt
+++ b/iree/tools/android/run_module_app/CMakeLists.txt
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-if (NOT ANDROID)
+if(NOT ANDROID)
return()
endif()
@@ -40,15 +40,14 @@
DEPS
::android_native_app_glue
absl::strings
- iree::base::initializer
iree::base::status
+ iree::hal::drivers
iree::modules::hal
iree::tools::utils::vm_util
iree::vm
- ${IREE_HAL_DRIVER_MODULES}
LINKOPTS
- android
- log
+ "-landroid"
+ "-llog"
SHARED
WHOLEARCHIVE
)
diff --git a/iree/tools/android/run_module_app/src/main.cc b/iree/tools/android/run_module_app/src/main.cc
index 7250ab3..b123626 100644
--- a/iree/tools/android/run_module_app/src/main.cc
+++ b/iree/tools/android/run_module_app/src/main.cc
@@ -20,8 +20,8 @@
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
-#include "iree/base/initializer.h"
#include "iree/base/status.h"
+#include "iree/hal/drivers/init.h"
#include "iree/modules/hal/hal_module.h"
#include "iree/tools/utils/vm_util.h"
#include "iree/vm/api.h"
@@ -166,7 +166,7 @@
// trigger the workload.
std::this_thread::sleep_for(std::chrono::seconds(2));
- IREE_RUN_MODULE_INITIALIZERS();
+ IREE_CHECK_OK(iree_hal_register_all_available_drivers());
ModuleLoader loader(app);
StatusOr<IreeModuleInvocation> invocation = loader.LoadModuleInvocation();
diff --git a/iree/tools/init_xla_dialects.h b/iree/tools/init_xla_dialects.h
index 04c5a97..cae6b31 100644
--- a/iree/tools/init_xla_dialects.h
+++ b/iree/tools/init_xla_dialects.h
@@ -18,10 +18,10 @@
#ifndef IREE_TOOLS_INIT_XLA_DIALECTS_H_
#define IREE_TOOLS_INIT_XLA_DIALECTS_H_
+#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir/IR/Dialect.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
namespace mlir {
diff --git a/iree/tools/iree-benchmark-module-main.cc b/iree/tools/iree-benchmark-module-main.cc
index 235895e..e6afebd 100644
--- a/iree/tools/iree-benchmark-module-main.cc
+++ b/iree/tools/iree-benchmark-module-main.cc
@@ -18,9 +18,10 @@
#include "absl/strings/string_view.h"
#include "benchmark/benchmark.h"
#include "iree/base/file_io.h"
-#include "iree/base/init.h"
+#include "iree/base/flags.h"
#include "iree/base/status.h"
#include "iree/base/tracing.h"
+#include "iree/hal/drivers/init.h"
#include "iree/modules/hal/hal_module.h"
#include "iree/tools/utils/vm_util.h"
#include "iree/vm/api.h"
@@ -223,7 +224,7 @@
iree_vm_function_t function;
iree_vm_module_signature_t signature =
input_module_->signature(input_module_->self);
- for (int i = 0; i < signature.export_function_count; ++i) {
+ for (iree_host_size_t i = 0; i < signature.export_function_count; ++i) {
iree_string_view_t name;
IREE_CHECK_OK(input_module_->get_function(input_module_->self,
IREE_VM_FUNCTION_LINKAGE_EXPORT,
@@ -295,7 +296,8 @@
absl::flags_internal::UsageFlagsAction::kHandleUsage,
absl::flags_internal::OnUndefinedFlag::kIgnoreUndefined);
::benchmark::Initialize(&argc, argv);
- iree::InitializeEnvironment(&argc, &argv);
+ iree_flags_parse_checked(&argc, &argv);
+ IREE_CHECK_OK(iree_hal_register_all_available_drivers());
iree::IREEBenchmark iree_benchmark;
auto status = iree_benchmark.Register();
diff --git a/iree/tools/iree-dump-module-main.cc b/iree/tools/iree-dump-module-main.cc
index 2fd0242..5ed60ad 100644
--- a/iree/tools/iree-dump-module-main.cc
+++ b/iree/tools/iree-dump-module-main.cc
@@ -16,163 +16,36 @@
#include <string>
#include <utility>
-#include "flatbuffers/flatbuffers.h"
-#include "flatbuffers/minireflect.h"
-#include "flatbuffers/reflection.h"
-#include "flatbuffers/util.h"
-#include "iree/base/flatbuffer_util.h"
-#include "iree/base/init.h"
-#include "iree/schemas/bytecode_module_def_generated.h"
+#include "iree/base/file_mapping.h"
+#include "iree/schemas/bytecode_module_def_json_printer.h"
-namespace {
-
-using ::flatbuffers::ElementaryType;
-using ::flatbuffers::NumToString;
-using ::flatbuffers::String;
-using ::flatbuffers::TypeTable;
-
-struct TruncatingToStringVisitor : public ::flatbuffers::IterationVisitor {
- std::string s;
- std::string d;
-
- bool is_truncating_vector = false;
- int vector_depth = 0;
-
- explicit TruncatingToStringVisitor(std::string delimiter)
- : d(std::move(delimiter)) {}
-
- void StartSequence() override {
- if (is_truncating_vector) return;
- s += "{";
- s += d;
- }
- void EndSequence() override {
- if (is_truncating_vector) return;
- s += d;
- s += "}";
- }
- void Field(size_t field_idx, size_t set_idx, ElementaryType type,
- bool is_vector, const TypeTable* type_table, const char* name,
- const uint8_t* val) override {
- if (is_truncating_vector) return;
- if (!val) return;
- if (set_idx) {
- s += ",";
- s += d;
- }
- if (name) {
- s += name;
- s += ": ";
- }
- }
- template <typename T>
- void Named(T x, const char* name) {
- if (name) {
- s += name;
- } else {
- s += NumToString(x);
- }
- }
- void UType(uint8_t x, const char* name) override {
- if (is_truncating_vector) return;
- Named(x, name);
- }
- void Bool(bool x) override {
- if (is_truncating_vector) return;
- s += x ? "true" : "false";
- }
- void Char(int8_t x, const char* name) override {
- if (is_truncating_vector) return;
- Named(x, name);
- }
- void UChar(uint8_t x, const char* name) override {
- if (is_truncating_vector) return;
- Named(x, name);
- }
- void Short(int16_t x, const char* name) override {
- if (is_truncating_vector) return;
- Named(x, name);
- }
- void UShort(uint16_t x, const char* name) override {
- if (is_truncating_vector) return;
- Named(x, name);
- }
- void Int(int32_t x, const char* name) override {
- if (is_truncating_vector) return;
- Named(x, name);
- }
- void UInt(uint32_t x, const char* name) override {
- if (is_truncating_vector) return;
- Named(x, name);
- }
- void Long(int64_t x) override {
- if (is_truncating_vector) return;
- s += NumToString(x);
- }
- void ULong(uint64_t x) override {
- if (is_truncating_vector) return;
- s += NumToString(x);
- }
- void Float(float x) override {
- if (is_truncating_vector) return;
- s += NumToString(x);
- }
- void Double(double x) override {
- if (is_truncating_vector) return;
- s += NumToString(x);
- }
- void String(const struct String* str) override {
- if (is_truncating_vector) return;
- ::flatbuffers::EscapeString(str->c_str(), str->size(), &s, true, false);
- }
- void Unknown(const uint8_t*) override {
- if (is_truncating_vector) return;
- s += "(?)";
- }
- void StartVector() override {
- ++vector_depth;
- if (is_truncating_vector) return;
- s += "[ ";
- }
- void EndVector() override {
- --vector_depth;
- if (vector_depth == 0) {
- is_truncating_vector = false;
- }
- if (is_truncating_vector) return;
- s += " ]";
- }
- void Element(size_t i, ElementaryType type, const TypeTable* type_table,
- const uint8_t* val) override {
- if (is_truncating_vector) return;
- if (i > 1024) {
- if (!is_truncating_vector) {
- s += ", ...";
- is_truncating_vector = true;
- }
- } else if (i) {
- s += ", ";
- }
- }
-};
-
-} // namespace
-
+// Today we just print to JSON. We could do something more useful (size
+// analysis, etc.), but JSON should be enough.
+//
+// We could also move all of this into iree-translate (mlir -> vmfb -> json),
+// though having a tiny little tool not reliant on LLVM is nice (can run this
+// on a device).
extern "C" int main(int argc, char** argv) {
- iree::InitializeEnvironment(&argc, &argv);
-
if (argc < 2) {
- std::cerr << "Syntax: iree-dump-module filename\n";
+ std::cerr << "Syntax: iree-dump-module module.vmfb > module.json\n";
return 1;
}
- std::string module_path = argv[1];
- auto module_fb = iree::FlatBufferFile<iree::vm::BytecodeModuleDef>::LoadFile(
- iree::vm::BytecodeModuleDefIdentifier(), module_path)
- .value();
- TruncatingToStringVisitor tos_visitor("\n");
- auto object = reinterpret_cast<const uint8_t*>(module_fb->root());
- flatbuffers::IterateObject(object, module_fb->root()->MiniReflectTypeTable(),
- &tos_visitor);
- std::cout << tos_visitor.s << std::endl;
+ auto module_mapping_or = iree::FileMapping::OpenRead(argv[1]);
+ if (!module_mapping_or.ok()) {
+ std::cerr << module_mapping_or.status();
+ return 1;
+ }
+ auto module_mapping = std::move(module_mapping_or.value());
+ auto module_buffer = module_mapping->data();
+
+ // Print direct to stdout.
+ flatcc_json_printer_t printer;
+ flatcc_json_printer_init(&printer, /*fp=*/nullptr);
+ flatcc_json_printer_set_skip_default(&printer, true);
+ bytecode_module_def_print_json(
+ &printer, reinterpret_cast<const char*>(module_buffer.data()),
+ module_buffer.size());
+ flatcc_json_printer_clear(&printer);
+
return 0;
}
diff --git a/iree/tools/iree-run-mlir-main.cc b/iree/tools/iree-run-mlir-main.cc
index b0fb141..2e7cc17 100644
--- a/iree/tools/iree-run-mlir-main.cc
+++ b/iree/tools/iree-run-mlir-main.cc
@@ -44,7 +44,7 @@
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "iree/base/api.h"
-#include "iree/base/init.h"
+#include "iree/base/flags.h"
#include "iree/base/status.h"
#include "iree/base/tracing.h"
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
@@ -56,6 +56,7 @@
#include "iree/compiler/Dialect/VM/Transforms/Passes.h"
#include "iree/compiler/Translation/IREEVM.h"
#include "iree/hal/api.h"
+#include "iree/hal/drivers/init.h"
#include "iree/modules/hal/hal_module.h"
#include "iree/tools/init_dialects.h"
#include "iree/tools/init_targets.h"
@@ -165,7 +166,7 @@
iree_host_size_t driver_count = 0;
IREE_RETURN_IF_ERROR(iree_hal_driver_registry_query_available_drivers(
iree_allocator_system(), &driver_names, &driver_count));
- for (int i = 0; i < driver_count; ++i) {
+ for (iree_host_size_t i = 0; i < driver_count; ++i) {
target_backends.push_back(
std::string(driver_names[i].data, driver_names[i].size));
}
@@ -378,7 +379,8 @@
Status evaluate_status = OkStatus();
auto module_signature = iree_vm_module_signature(bytecode_module);
- for (int i = 0; i < module_signature.export_function_count; ++i) {
+ for (iree_host_size_t i = 0; i < module_signature.export_function_count;
+ ++i) {
evaluate_status = run_function(i);
if (!evaluate_status.ok()) {
break;
@@ -521,7 +523,8 @@
}
argc_absl += run_args_flag.size();
char** argv_absl_ptr = argv_absl.data();
- iree::InitializeEnvironment(&argc_absl, &argv_absl_ptr);
+ iree_flags_parse_checked(&argc_absl, &argv_absl_ptr);
+ IREE_CHECK_OK(iree_hal_register_all_available_drivers());
auto status = RunFile(input_file_flag, registry);
if (!status.ok()) {
diff --git a/iree/tools/iree-run-module-main.cc b/iree/tools/iree-run-module-main.cc
index 66c2fcc..dec0f6a 100644
--- a/iree/tools/iree-run-module-main.cc
+++ b/iree/tools/iree-run-module-main.cc
@@ -17,9 +17,10 @@
#include "absl/flags/flag.h"
#include "absl/strings/string_view.h"
#include "iree/base/file_io.h"
-#include "iree/base/init.h"
+#include "iree/base/flags.h"
#include "iree/base/status.h"
#include "iree/base/tracing.h"
+#include "iree/hal/drivers/init.h"
#include "iree/modules/hal/hal_module.h"
#include "iree/tools/utils/vm_util.h"
#include "iree/vm/api.h"
@@ -156,7 +157,8 @@
} // namespace
extern "C" int main(int argc, char** argv) {
- iree::InitializeEnvironment(&argc, &argv);
+ iree_flags_parse_checked(&argc, &argv);
+ IREE_CHECK_OK(iree_hal_register_all_available_drivers());
IREE_CHECK_OK(Run());
return 0;
}
diff --git a/iree/tools/utils/BUILD b/iree/tools/utils/BUILD
index 4aa8c6b..6d630f7 100644
--- a/iree/tools/utils/BUILD
+++ b/iree/tools/utils/BUILD
@@ -31,7 +31,7 @@
"//iree/modules/hal",
"//iree/vm",
"//iree/vm:bytecode_module",
- "//iree/vm:ref_cc",
+ "//iree/vm:cc",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
@@ -44,7 +44,7 @@
":vm_util",
"//iree/base:api",
"//iree/hal:api",
- "//iree/hal/vmla:vmla_driver_module",
+ "//iree/hal/vmla/registration",
"//iree/modules/hal",
"//iree/testing:gtest",
"//iree/testing:gtest_main",
diff --git a/iree/tools/utils/CMakeLists.txt b/iree/tools/utils/CMakeLists.txt
index b0a9f7b..c684006 100644
--- a/iree/tools/utils/CMakeLists.txt
+++ b/iree/tools/utils/CMakeLists.txt
@@ -31,7 +31,7 @@
iree::modules::hal
iree::vm
iree::vm::bytecode_module
- iree::vm::ref_cc
+ iree::vm::cc
PUBLIC
)
@@ -45,7 +45,7 @@
absl::strings
iree::base::api
iree::hal::api
- iree::hal::vmla::vmla_driver_module
+ iree::hal::vmla::registration
iree::modules::hal
iree::testing::gtest
iree::testing::gtest_main
diff --git a/iree/tools/utils/vm_util_test.cc b/iree/tools/utils/vm_util_test.cc
index 17f924b..5fd34d8 100644
--- a/iree/tools/utils/vm_util_test.cc
+++ b/iree/tools/utils/vm_util_test.cc
@@ -19,6 +19,7 @@
#include "absl/strings/str_cat.h"
#include "iree/base/api.h"
#include "iree/hal/api.h"
+#include "iree/hal/vmla/registration/driver_module.h"
#include "iree/modules/hal/hal_module.h"
#include "iree/testing/gtest.h"
#include "iree/testing/status_matchers.h"
@@ -29,6 +30,10 @@
class VmUtilTest : public ::testing::Test {
protected:
+ static void SetUpTestSuite() {
+ IREE_CHECK_OK(iree_hal_vmla_driver_module_register());
+ }
+
virtual void SetUp() {
IREE_ASSERT_OK(iree_hal_module_register_types());
IREE_ASSERT_OK(CreateDevice("vmla", &device_));
diff --git a/iree/vm/BUILD b/iree/vm/BUILD
index b5a73ca..f69b517 100644
--- a/iree/vm/BUILD
+++ b/iree/vm/BUILD
@@ -1,4 +1,16 @@
-# Bytecode VM.
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
load("//iree/tools:compilation.bzl", "iree_bytecode_module")
load("//build_tools/bazel:tblgen.bzl", "gentbl")
@@ -9,36 +21,157 @@
licenses = ["notice"], # Apache 2.0
)
+#===------------------------------------------------------------------------===#
+# Public API
+#===------------------------------------------------------------------------===#
+
cc_library(
- name = "builtin_types",
- srcs = ["builtin_types.c"],
- hdrs = ["builtin_types.h"],
+ name = "vm",
+ hdrs = [
+ "api.h",
+ ],
deps = [
- ":list",
- ":ref",
+ ":impl",
+ "//iree/base:api",
+ ],
+)
+
+# TODO(benvanik): make these srcs and only expose an api_cc.h.
+cc_library(
+ name = "cc",
+ hdrs = [
+ "module_abi_packing.h",
+ "native_module_cc.h",
+ "ref_cc.h",
+ ],
+ deps = [
+ ":vm",
+ "//iree/base:api",
+ "//iree/base:ref_ptr",
+ "//iree/base:status",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
+ "@com_google_absl//absl/types:span",
+ ],
+)
+
+#===------------------------------------------------------------------------===#
+# Implementation
+#===------------------------------------------------------------------------===#
+
+cc_library(
+ name = "impl",
+ srcs = [
+ "builtin_types.c",
+ "context.c",
+ "instance.c",
+ "invocation.c",
+ "list.c",
+ "module.c",
+ "native_module.c",
+ "ref.c",
+ "stack.c",
+ ],
+ hdrs = [
+ "builtin_types.h",
+ "context.h",
+ "instance.h",
+ "invocation.h",
+ "list.h",
+ "module.h",
+ "native_module.h",
+ "ref.h",
+ "stack.h",
+ "type_def.h",
+ "value.h",
+ ],
+ deps = [
+ "//iree/base:api",
+ "//iree/base:core_headers",
+ "//iree/base:tracing",
+ ],
+)
+
+cc_test(
+ name = "list_test",
+ srcs = ["list_test.cc"],
+ deps = [
+ ":impl",
+ "//iree/base:api",
+ "//iree/base:ref_ptr",
+ "//iree/testing:gtest",
+ "//iree/testing:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "native_module_test",
+ srcs = ["native_module_test.cc"],
+ deps = [
+ ":cc",
+ ":impl",
+ ":native_module_test_hdrs",
+ "//iree/base:api",
+ "//iree/base:status",
+ "//iree/testing:gtest",
+ "//iree/testing:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "native_module_test_hdrs",
+ hdrs = [
+ "native_module_test.h",
+ ],
+ deps = [
+ ":impl",
"//iree/base:api",
],
)
cc_test(
- name = "bytecode_dispatch_test",
- srcs = ["bytecode_dispatch_test.cc"],
+ name = "native_module_benchmark",
+ srcs = ["native_module_benchmark.cc"],
deps = [
- ":builtin_types",
- ":bytecode_module",
- ":context",
- ":instance",
- ":invocation",
- ":module",
+ ":impl",
+ ":native_module_test_hdrs",
+ "//iree/base:api",
"//iree/base:logging",
- "//iree/base:status",
- "//iree/testing:gtest",
- "//iree/testing:gtest_main",
- "//iree/vm/test:all_bytecode_modules_cc",
- "@com_google_absl//absl/strings",
+ "//iree/testing:benchmark_main",
+ "@com_google_benchmark//:benchmark",
],
)
+cc_test(
+ name = "ref_test",
+ srcs = ["ref_test.cc"],
+ deps = [
+ ":impl",
+ "//iree/base:api",
+ "//iree/base:ref_ptr",
+ "//iree/testing:gtest",
+ "//iree/testing:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "stack_test",
+ srcs = ["stack_test.cc"],
+ deps = [
+ ":impl",
+ "//iree/base:api",
+ "//iree/base:ref_ptr",
+ "//iree/testing:gtest",
+ "//iree/testing:gtest_main",
+ ],
+)
+
+#===------------------------------------------------------------------------===#
+# Bytecode interpreter module
+#===------------------------------------------------------------------------===#
+
cc_library(
name = "bytecode_module",
srcs = [
@@ -52,19 +185,50 @@
"bytecode_module.h",
],
deps = [
- ":builtin_types",
":bytecode_op_table_gen",
- ":list",
- ":module",
- ":ref",
- ":stack",
- ":type_def",
- ":value",
- "//iree/base:alignment",
+ ":vm",
"//iree/base:api",
+ "//iree/base:core_headers",
+ "//iree/base:flatcc",
"//iree/base:tracing",
"//iree/schemas:bytecode_module_def_c_fbs",
- "@com_github_dvidelabs_flatcc//:runtime",
+ ],
+)
+
+# TODO(benvanik): see if we can remove this; not good to have this dep.
+gentbl(
+ name = "bytecode_op_table_gen",
+ tbl_outs = [
+ ("-gen-iree-vm-op-table-defs", "bytecode_op_table.h"),
+ ],
+ tblgen = "//iree/tools:iree-tblgen",
+ td_file = "//iree/compiler/Dialect/VM/IR:VMOps.td",
+ td_srcs = [
+ "//iree/compiler/Dialect/IREE/IR:td_files",
+ "//iree/compiler/Dialect/VM/IR:td_files",
+ "@llvm-project//mlir:OpBaseTdFiles",
+ "@llvm-project//mlir:include/mlir/IR/SymbolInterfaces.td",
+ "@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td",
+ "@llvm-project//mlir:include/mlir/Interfaces/ControlFlowInterfaces.td",
+ "@llvm-project//mlir:SideEffectTdFiles",
+ ],
+)
+
+cc_test(
+ name = "bytecode_module_test",
+ srcs = [
+ "bytecode_dispatch_test.cc",
+ "bytecode_module_test.cc",
+ ],
+ deps = [
+ ":bytecode_module",
+ ":vm",
+ "//iree/base:logging",
+ "//iree/base:status",
+ "//iree/testing:gtest",
+ "//iree/testing:gtest_main",
+ "//iree/vm/test:all_bytecode_modules_cc",
+ "@com_google_absl//absl/strings",
],
)
@@ -74,11 +238,7 @@
deps = [
":bytecode_module",
":bytecode_module_benchmark_module_cc",
- ":context",
- ":instance",
- ":module",
- ":native_module",
- ":stack",
+ ":vm",
"//iree/base:api",
"//iree/base:logging",
"//iree/testing:benchmark_main",
@@ -113,33 +273,9 @@
flags = ["-iree-vm-ir-to-bytecode-module"],
)
-cc_test(
- name = "bytecode_module_test",
- srcs = ["bytecode_module_test.cc"],
- deps = [
- ":bytecode_module",
- "//iree/testing:gtest",
- "//iree/testing:gtest_main",
- ],
-)
-
-gentbl(
- name = "bytecode_op_table_gen",
- tbl_outs = [
- ("-gen-iree-vm-op-table-defs", "bytecode_op_table.h"),
- ],
- tblgen = "//iree/tools:iree-tblgen",
- td_file = "//iree/compiler/Dialect/VM/IR:VMOps.td",
- td_srcs = [
- "//iree/compiler/Dialect/IREE/IR:td_files",
- "//iree/compiler/Dialect/VM/IR:td_files",
- "@llvm-project//mlir:OpBaseTdFiles",
- "@llvm-project//mlir:include/mlir/IR/SymbolInterfaces.td",
- "@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td",
- "@llvm-project//mlir:include/mlir/Interfaces/ControlFlowInterfaces.td",
- "@llvm-project//mlir:SideEffectTdFiles",
- ],
-)
+#===------------------------------------------------------------------------===#
+# Emit-C modules
+#===------------------------------------------------------------------------===#
cc_library(
name = "c_funcs",
@@ -147,254 +283,3 @@
"c_funcs.h",
],
)
-
-cc_library(
- name = "context",
- srcs = ["context.c"],
- hdrs = ["context.h"],
- deps = [
- ":instance",
- ":module",
- ":stack",
- "//iree/base:api",
- "//iree/base:atomics",
- "//iree/base:tracing",
- ],
-)
-
-cc_library(
- name = "instance",
- srcs = ["instance.c"],
- hdrs = ["instance.h"],
- deps = [
- ":builtin_types",
- "//iree/base:api",
- "//iree/base:atomics",
- "//iree/base:tracing",
- ],
-)
-
-cc_library(
- name = "invocation",
- srcs = ["invocation.c"],
- hdrs = ["invocation.h"],
- deps = [
- ":context",
- ":list",
- ":module",
- "//iree/base:api",
- "//iree/base:tracing",
- ],
-)
-
-cc_library(
- name = "list",
- srcs = ["list.c"],
- hdrs = ["list.h"],
- deps = [
- ":ref",
- ":type_def",
- ":value",
- "//iree/base:alignment",
- "//iree/base:api",
- ],
-)
-
-cc_test(
- name = "list_test",
- srcs = ["list_test.cc"],
- deps = [
- ":builtin_types",
- ":list",
- "//iree/base:api",
- "//iree/base:ref_ptr",
- "//iree/testing:gtest",
- "//iree/testing:gtest_main",
- ],
-)
-
-cc_library(
- name = "module",
- srcs = ["module.c"],
- hdrs = ["module.h"],
- deps = [
- ":ref",
- "//iree/base:alignment",
- "//iree/base:api",
- "//iree/base:atomics",
- "//iree/base:tracing",
- ],
-)
-
-cc_library(
- name = "native_module",
- srcs = ["native_module.c"],
- hdrs = ["native_module.h"],
- deps = [
- ":module",
- ":stack",
- "//iree/base:api",
- ],
-)
-
-cc_test(
- name = "native_module_benchmark",
- srcs = ["native_module_benchmark.cc"],
- deps = [
- ":module",
- ":native_module",
- ":native_module_test_hdrs",
- ":stack",
- "//iree/base:api",
- "//iree/base:logging",
- "//iree/testing:benchmark_main",
- "@com_google_benchmark//:benchmark",
- ],
-)
-
-cc_test(
- name = "native_module_test",
- srcs = ["native_module_test.cc"],
- deps = [
- ":context",
- ":instance",
- ":invocation",
- ":list",
- ":native_module_test_hdrs",
- ":ref_cc",
- "//iree/base:api",
- "//iree/base:status",
- "//iree/testing:gtest",
- "//iree/testing:gtest_main",
- ],
-)
-
-cc_library(
- name = "native_module_test_hdrs",
- hdrs = [
- "native_module_test.h",
- ],
- deps = [
- ":context",
- ":instance",
- ":native_module",
- ":ref",
- ":stack",
- "//iree/base:api",
- ],
-)
-
-cc_library(
- name = "native_module_cc",
- hdrs = [
- "module_abi_packing.h",
- "native_module_cc.h",
- ],
- deps = [
- ":builtin_types",
- ":module",
- ":ref",
- ":ref_cc",
- ":stack",
- "//iree/base:api",
- "//iree/base:ref_ptr",
- "//iree/base:status",
- "@com_google_absl//absl/container:inlined_vector",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:optional",
- "@com_google_absl//absl/types:span",
- ],
-)
-
-cc_library(
- name = "ref",
- srcs = ["ref.c"],
- hdrs = ["ref.h"],
- deps = [
- "//iree/base:api",
- "//iree/base:atomics",
- ],
-)
-
-cc_test(
- name = "ref_test",
- srcs = ["ref_test.cc"],
- deps = [
- ":ref",
- "//iree/base:api",
- "//iree/base:ref_ptr",
- "//iree/testing:gtest",
- "//iree/testing:gtest_main",
- ],
-)
-
-cc_library(
- name = "ref_cc",
- hdrs = ["ref_cc.h"],
- deps = [
- ":ref",
- "//iree/base:api",
- "@com_google_absl//absl/base:core_headers",
- ],
-)
-
-cc_library(
- name = "stack",
- srcs = ["stack.c"],
- hdrs = ["stack.h"],
- deps = [
- ":module",
- ":ref",
- "//iree/base:alignment",
- "//iree/base:api",
- "//iree/base:tracing",
- ],
-)
-
-cc_test(
- name = "stack_test",
- srcs = ["stack_test.cc"],
- deps = [
- ":ref",
- ":stack",
- "//iree/base:api",
- "//iree/base:ref_ptr",
- "//iree/testing:gtest",
- "//iree/testing:gtest_main",
- ],
-)
-
-cc_library(
- name = "type_def",
- hdrs = ["type_def.h"],
- deps = [
- ":ref",
- ":value",
- ],
-)
-
-cc_library(
- name = "value",
- hdrs = ["value.h"],
-)
-
-cc_library(
- name = "vm",
- hdrs = [
- "api.h",
- ],
- deps = [
- ":builtin_types",
- ":context",
- ":instance",
- ":invocation",
- ":list",
- ":module",
- ":native_module",
- ":ref",
- ":stack",
- ":type_def",
- ":value",
- "//iree/base:api",
- ],
-)
diff --git a/iree/vm/CMakeLists.txt b/iree/vm/CMakeLists.txt
index 48d6544..2546732 100644
--- a/iree/vm/CMakeLists.txt
+++ b/iree/vm/CMakeLists.txt
@@ -16,36 +16,144 @@
iree_cc_library(
NAME
- builtin_types
+ vm
+ HDRS
+ "api.h"
+ DEPS
+ ::impl
+ iree::base::api
+ PUBLIC
+)
+
+iree_cc_library(
+ NAME
+ cc
+ HDRS
+ "module_abi_packing.h"
+ "native_module_cc.h"
+ "ref_cc.h"
+ DEPS
+ ::vm
+ absl::core_headers
+ absl::inlined_vector
+ absl::optional
+ absl::span
+ absl::strings
+ iree::base::api
+ iree::base::ref_ptr
+ iree::base::status
+ PUBLIC
+)
+
+iree_cc_library(
+ NAME
+ impl
HDRS
"builtin_types.h"
+ "context.h"
+ "instance.h"
+ "invocation.h"
+ "list.h"
+ "module.h"
+ "native_module.h"
+ "ref.h"
+ "stack.h"
+ "type_def.h"
+ "value.h"
SRCS
"builtin_types.c"
+ "context.c"
+ "instance.c"
+ "invocation.c"
+ "list.c"
+ "module.c"
+ "native_module.c"
+ "ref.c"
+ "stack.c"
DEPS
- ::list
- ::ref
+ iree::base::api
+ iree::base::core_headers
+ iree::base::tracing
+ PUBLIC
+)
+
+iree_cc_test(
+ NAME
+ list_test
+ SRCS
+ "list_test.cc"
+ DEPS
+ ::impl
+ iree::base::api
+ iree::base::ref_ptr
+ iree::testing::gtest
+ iree::testing::gtest_main
+)
+
+iree_cc_test(
+ NAME
+ native_module_test
+ SRCS
+ "native_module_test.cc"
+ DEPS
+ ::cc
+ ::impl
+ ::native_module_test_hdrs
+ iree::base::api
+ iree::base::status
+ iree::testing::gtest
+ iree::testing::gtest_main
+)
+
+iree_cc_library(
+ NAME
+ native_module_test_hdrs
+ HDRS
+ "native_module_test.h"
+ DEPS
+ ::impl
iree::base::api
PUBLIC
)
iree_cc_test(
NAME
- bytecode_dispatch_test
+ native_module_benchmark
SRCS
- "bytecode_dispatch_test.cc"
+ "native_module_benchmark.cc"
DEPS
- ::builtin_types
- ::bytecode_module
- ::context
- ::instance
- ::invocation
- ::module
- absl::strings
+ ::impl
+ ::native_module_test_hdrs
+ benchmark
+ iree::base::api
iree::base::logging
- iree::base::status
+ iree::testing::benchmark_main
+)
+
+iree_cc_test(
+ NAME
+ ref_test
+ SRCS
+ "ref_test.cc"
+ DEPS
+ ::impl
+ iree::base::api
+ iree::base::ref_ptr
iree::testing::gtest
iree::testing::gtest_main
- iree::vm::test::all_bytecode_modules_cc
+)
+
+iree_cc_test(
+ NAME
+ stack_test
+ SRCS
+ "stack_test.cc"
+ DEPS
+ ::impl
+ iree::base::api
+ iree::base::ref_ptr
+ iree::testing::gtest
+ iree::testing::gtest_main
)
iree_cc_library(
@@ -60,21 +168,43 @@
"bytecode_module_impl.h"
"bytecode_op_table.h"
DEPS
- ::builtin_types
- ::list
- ::module
- ::ref
- ::stack
- ::type_def
- ::value
- flatcc::runtime
- iree::base::alignment
+ ::vm
iree::base::api
+ iree::base::core_headers
+ iree::base::flatcc
iree::base::tracing
iree::schemas::bytecode_module_def_c_fbs
PUBLIC
)
+iree_tablegen_library(
+ NAME
+ bytecode_op_table_gen
+ TD_FILE
+ "${IREE_ROOT_DIR}/iree/compiler/Dialect/VM/IR/VMOps.td"
+ OUTS
+ -gen-iree-vm-op-table-defs bytecode_op_table.h
+ TBLGEN
+ IREE
+)
+
+iree_cc_test(
+ NAME
+ bytecode_module_test
+ SRCS
+ "bytecode_dispatch_test.cc"
+ "bytecode_module_test.cc"
+ DEPS
+ ::bytecode_module
+ ::vm
+ absl::strings
+ iree::base::logging
+ iree::base::status
+ iree::testing::gtest
+ iree::testing::gtest_main
+ iree::vm::test::all_bytecode_modules_cc
+)
+
iree_cc_test(
NAME
bytecode_module_benchmark
@@ -83,11 +213,7 @@
DEPS
::bytecode_module
::bytecode_module_benchmark_module_cc
- ::context
- ::instance
- ::module
- ::native_module
- ::stack
+ ::vm
absl::inlined_vector
absl::strings
benchmark
@@ -132,28 +258,6 @@
PUBLIC
)
-iree_cc_test(
- NAME
- bytecode_module_test
- SRCS
- "bytecode_module_test.cc"
- DEPS
- ::bytecode_module
- iree::testing::gtest
- iree::testing::gtest_main
-)
-
-iree_tablegen_library(
- NAME
- bytecode_op_table_gen
- TD_FILE
- "${IREE_ROOT_DIR}/iree/compiler/Dialect/VM/IR/VMOps.td"
- OUTS
- -gen-iree-vm-op-table-defs bytecode_op_table.h
- TBLGEN
- IREE
-)
-
iree_cc_library(
NAME
c_funcs
@@ -161,290 +265,3 @@
"c_funcs.h"
PUBLIC
)
-
-iree_cc_library(
- NAME
- context
- HDRS
- "context.h"
- SRCS
- "context.c"
- DEPS
- ::instance
- ::module
- ::stack
- iree::base::api
- iree::base::atomics
- iree::base::tracing
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- instance
- HDRS
- "instance.h"
- SRCS
- "instance.c"
- DEPS
- ::builtin_types
- iree::base::api
- iree::base::atomics
- iree::base::tracing
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- invocation
- HDRS
- "invocation.h"
- SRCS
- "invocation.c"
- DEPS
- ::context
- ::list
- ::module
- iree::base::api
- iree::base::tracing
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- list
- HDRS
- "list.h"
- SRCS
- "list.c"
- DEPS
- ::ref
- ::type_def
- ::value
- iree::base::alignment
- iree::base::api
- PUBLIC
-)
-
-iree_cc_test(
- NAME
- list_test
- SRCS
- "list_test.cc"
- DEPS
- ::builtin_types
- ::list
- iree::base::api
- iree::base::ref_ptr
- iree::testing::gtest
- iree::testing::gtest_main
-)
-
-iree_cc_library(
- NAME
- module
- HDRS
- "module.h"
- SRCS
- "module.c"
- DEPS
- ::ref
- iree::base::alignment
- iree::base::api
- iree::base::atomics
- iree::base::tracing
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- native_module
- HDRS
- "native_module.h"
- SRCS
- "native_module.c"
- DEPS
- ::module
- ::stack
- iree::base::api
- PUBLIC
-)
-
-iree_cc_test(
- NAME
- native_module_benchmark
- SRCS
- "native_module_benchmark.cc"
- DEPS
- ::module
- ::native_module
- ::native_module_test_hdrs
- ::stack
- benchmark
- iree::base::api
- iree::base::logging
- iree::testing::benchmark_main
-)
-
-iree_cc_test(
- NAME
- native_module_test
- SRCS
- "native_module_test.cc"
- DEPS
- ::context
- ::instance
- ::invocation
- ::list
- ::native_module_test_hdrs
- ::ref_cc
- iree::base::api
- iree::base::status
- iree::testing::gtest
- iree::testing::gtest_main
-)
-
-iree_cc_library(
- NAME
- native_module_test_hdrs
- HDRS
- "native_module_test.h"
- DEPS
- ::context
- ::instance
- ::native_module
- ::ref
- ::stack
- iree::base::api
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- native_module_cc
- HDRS
- "module_abi_packing.h"
- "native_module_cc.h"
- DEPS
- ::builtin_types
- ::module
- ::ref
- ::ref_cc
- ::stack
- absl::inlined_vector
- absl::optional
- absl::span
- absl::strings
- iree::base::api
- iree::base::ref_ptr
- iree::base::status
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- ref
- HDRS
- "ref.h"
- SRCS
- "ref.c"
- DEPS
- iree::base::api
- iree::base::atomics
- PUBLIC
-)
-
-iree_cc_test(
- NAME
- ref_test
- SRCS
- "ref_test.cc"
- DEPS
- ::ref
- iree::base::api
- iree::base::ref_ptr
- iree::testing::gtest
- iree::testing::gtest_main
-)
-
-iree_cc_library(
- NAME
- ref_cc
- HDRS
- "ref_cc.h"
- DEPS
- ::ref
- absl::core_headers
- iree::base::api
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- stack
- HDRS
- "stack.h"
- SRCS
- "stack.c"
- DEPS
- ::module
- ::ref
- iree::base::alignment
- iree::base::api
- iree::base::tracing
- PUBLIC
-)
-
-iree_cc_test(
- NAME
- stack_test
- SRCS
- "stack_test.cc"
- DEPS
- ::ref
- ::stack
- iree::base::api
- iree::base::ref_ptr
- iree::testing::gtest
- iree::testing::gtest_main
-)
-
-iree_cc_library(
- NAME
- type_def
- HDRS
- "type_def.h"
- DEPS
- ::ref
- ::value
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- value
- HDRS
- "value.h"
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- vm
- HDRS
- "api.h"
- DEPS
- ::builtin_types
- ::context
- ::instance
- ::invocation
- ::list
- ::module
- ::native_module
- ::ref
- ::stack
- ::type_def
- ::value
- iree::base::api
- PUBLIC
-)
diff --git a/iree/vm/bytecode_dispatch.c b/iree/vm/bytecode_dispatch.c
index bb5fbbf..bb5241f 100644
--- a/iree/vm/bytecode_dispatch.c
+++ b/iree/vm/bytecode_dispatch.c
@@ -15,8 +15,8 @@
#include <string.h>
#include "iree/base/tracing.h"
+#include "iree/vm/api.h"
#include "iree/vm/bytecode_dispatch_util.h"
-#include "iree/vm/list.h"
//===----------------------------------------------------------------------===//
// Math utilities, kept here to limit dependencies
diff --git a/iree/vm/bytecode_dispatch_test.cc b/iree/vm/bytecode_dispatch_test.cc
index caca978..b87dfc5 100644
--- a/iree/vm/bytecode_dispatch_test.cc
+++ b/iree/vm/bytecode_dispatch_test.cc
@@ -23,12 +23,8 @@
#include "iree/base/logging.h"
#include "iree/base/status.h"
#include "iree/testing/gtest.h"
-#include "iree/vm/builtin_types.h"
+#include "iree/vm/api.h"
#include "iree/vm/bytecode_module.h"
-#include "iree/vm/context.h"
-#include "iree/vm/instance.h"
-#include "iree/vm/invocation.h"
-#include "iree/vm/module.h"
// Compiled module embedded here to avoid file IO:
#include "iree/vm/test/all_bytecode_modules.h"
diff --git a/iree/vm/bytecode_module.c b/iree/vm/bytecode_module.c
index caa89d3..52a5a36 100644
--- a/iree/vm/bytecode_module.c
+++ b/iree/vm/bytecode_module.c
@@ -17,9 +17,8 @@
#include "iree/base/alignment.h"
#include "iree/base/api.h"
#include "iree/base/tracing.h"
+#include "iree/vm/api.h"
#include "iree/vm/bytecode_module_impl.h"
-#include "iree/vm/ref.h"
-#include "iree/vm/stack.h"
// Perform an strcmp between a flatbuffers string and an IREE string view.
static bool iree_vm_flatbuffer_strcmp(flatbuffers_string_t lhs,
@@ -175,10 +174,6 @@
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"imports[%zu] missing full_name", i);
}
- if (!iree_vm_ImportFunctionDef_signature(import_def)) {
- return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
- "imports[%zu] missing signature", i);
- }
}
for (size_t i = 0; i < iree_vm_ExportFunctionDef_vec_len(exported_functions);
@@ -195,10 +190,6 @@
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"exports[%zu] missing local_name", i);
}
- if (!iree_vm_ExportFunctionDef_signature(export_def)) {
- return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
- "exports[%zu] missing signature", i);
- }
iree_host_size_t internal_ordinal =
iree_vm_ExportFunctionDef_internal_ordinal(export_def);
if (internal_ordinal >=
@@ -221,10 +212,6 @@
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"functions[%zu] missing body", i);
}
- if (!iree_vm_InternalFunctionDef_signature(function_def)) {
- return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
- "functions[%zu] missing signature", i);
- }
iree_vm_FunctionDescriptor_struct_t function_descriptor =
iree_vm_FunctionDescriptor_vec_at(function_descriptors, i);
@@ -398,11 +385,17 @@
iree_vm_InternalFunctionDef_table_t function_def =
iree_vm_InternalFunctionDef_vec_at(internal_functions, ordinal);
- iree_vm_FunctionSignatureDef_table_t signature =
+ iree_vm_FunctionSignatureDef_table_t signature_def =
iree_vm_InternalFunctionDef_signature(function_def);
+ if (!signature_def) {
+ return iree_make_status(
+ IREE_STATUS_NOT_FOUND,
+ "reflection attribute at index %zu not found; no signature", index);
+ }
iree_vm_ReflectionAttrDef_vec_t reflection_attrs =
- iree_vm_FunctionSignatureDef_reflection_attrs(signature);
- if (index >= iree_vm_ReflectionAttrDef_vec_len(reflection_attrs)) {
+ iree_vm_FunctionSignatureDef_reflection_attrs(signature_def);
+ if (!reflection_attrs ||
+ index >= iree_vm_ReflectionAttrDef_vec_len(reflection_attrs)) {
return iree_make_status(IREE_STATUS_NOT_FOUND,
"reflection attribute at index %zu not found",
index);
@@ -502,14 +495,15 @@
static iree_host_size_t iree_vm_bytecode_module_layout_state(
iree_vm_BytecodeModuleDef_table_t module_def,
iree_vm_bytecode_module_state_t* state) {
- iree_vm_ModuleStateDef_table_t module_state =
+ iree_vm_ModuleStateDef_table_t module_state_def =
iree_vm_BytecodeModuleDef_module_state(module_def);
iree_host_size_t rwdata_storage_capacity = 0;
iree_host_size_t global_ref_count = 0;
- if (module_state) {
+ if (module_state_def) {
rwdata_storage_capacity =
- iree_vm_ModuleStateDef_global_bytes_capacity(module_state);
- global_ref_count = iree_vm_ModuleStateDef_global_ref_count(module_state);
+ iree_vm_ModuleStateDef_global_bytes_capacity(module_state_def);
+ global_ref_count =
+ iree_vm_ModuleStateDef_global_ref_count(module_state_def);
}
iree_host_size_t rodata_ref_count = iree_vm_RodataSegmentDef_vec_len(
iree_vm_BytecodeModuleDef_rodata_segments(module_def));
@@ -694,9 +688,12 @@
iree_vm_BytecodeModuleDef_internal_functions(module->def);
iree_vm_InternalFunctionDef_table_t function_def =
iree_vm_InternalFunctionDef_vec_at(internal_functions, function.ordinal);
+ iree_vm_FunctionSignatureDef_table_t signature_def =
+ iree_vm_InternalFunctionDef_signature(function_def);
flatbuffers_string_t calling_convention =
- iree_vm_FunctionSignatureDef_calling_convention(
- iree_vm_InternalFunctionDef_signature(function_def));
+ signature_def
+ ? iree_vm_FunctionSignatureDef_calling_convention(signature_def)
+ : 0;
iree_vm_function_signature_t signature;
memset(&signature, 0, sizeof(signature));
signature.calling_convention.data = calling_convention;
diff --git a/iree/vm/bytecode_module.h b/iree/vm/bytecode_module.h
index 2db39f5..8f694fc 100644
--- a/iree/vm/bytecode_module.h
+++ b/iree/vm/bytecode_module.h
@@ -18,7 +18,7 @@
#include <stdint.h>
#include "iree/base/api.h"
-#include "iree/vm/module.h"
+#include "iree/vm/api.h"
#ifdef __cplusplus
extern "C" {
diff --git a/iree/vm/bytecode_module_benchmark.cc b/iree/vm/bytecode_module_benchmark.cc
index 2b88ba1..529ff3a 100644
--- a/iree/vm/bytecode_module_benchmark.cc
+++ b/iree/vm/bytecode_module_benchmark.cc
@@ -19,13 +19,9 @@
#include "benchmark/benchmark.h"
#include "iree/base/api.h"
#include "iree/base/logging.h"
+#include "iree/vm/api.h"
#include "iree/vm/bytecode_module.h"
#include "iree/vm/bytecode_module_benchmark_module.h"
-#include "iree/vm/context.h"
-#include "iree/vm/instance.h"
-#include "iree/vm/module.h"
-#include "iree/vm/native_module.h"
-#include "iree/vm/stack.h"
namespace {
@@ -82,7 +78,7 @@
static iree_status_t RunFunction(benchmark::State& state,
absl::string_view function_name,
absl::Span<const int32_t> i32_args,
- int result_count, int batch_size = 1) {
+ int result_count, int64_t batch_size = 1) {
iree_vm_instance_t* instance = NULL;
IREE_CHECK_OK(iree_vm_instance_create(iree_allocator_system(), &instance));
@@ -294,7 +290,7 @@
return i;
};
while (state.KeepRunningBatch(state.range(0))) {
- int ret = loop(state.range(0));
+ int ret = loop(static_cast<int>(state.range(0)));
benchmark::DoNotOptimize(ret);
benchmark::ClobberMemory();
}
diff --git a/iree/vm/bytecode_module_impl.h b/iree/vm/bytecode_module_impl.h
index 0a74957..e4a282d 100644
--- a/iree/vm/bytecode_module_impl.h
+++ b/iree/vm/bytecode_module_impl.h
@@ -24,15 +24,10 @@
#endif // _MSC_VER
#include "iree/base/api.h"
-#include "iree/vm/builtin_types.h"
-#include "iree/vm/module.h"
-#include "iree/vm/ref.h"
-#include "iree/vm/stack.h"
-#include "iree/vm/type_def.h"
-#include "iree/vm/value.h"
+#include "iree/vm/api.h"
// NOTE: include order matters:
-#include "flatcc/reflection/flatbuffers_common_reader.h"
+#include "iree/base/flatcc.h"
#include "iree/schemas/bytecode_module_def_reader.h"
#include "iree/schemas/bytecode_module_def_verifier.h"
diff --git a/iree/vm/list_test.cc b/iree/vm/list_test.cc
index 779bd04..6baaa88 100644
--- a/iree/vm/list_test.cc
+++ b/iree/vm/list_test.cc
@@ -101,7 +101,7 @@
EXPECT_EQ(5, iree_vm_list_size(list));
for (iree_host_size_t i = 0; i < 5; ++i) {
- iree_vm_value_t value = iree_vm_value_make_i32(i);
+ iree_vm_value_t value = iree_vm_value_make_i32((int32_t)i);
IREE_ASSERT_OK(iree_vm_list_set_value(list, i, &value));
}
@@ -138,7 +138,7 @@
EXPECT_EQ(5, iree_vm_list_size(list));
for (iree_host_size_t i = 0; i < 5; ++i) {
- iree_vm_ref_t ref_a = MakeRef<A>(i);
+ iree_vm_ref_t ref_a = MakeRef<A>((float)i);
IREE_ASSERT_OK(iree_vm_list_set_ref_move(list, i, &ref_a));
}
@@ -173,11 +173,11 @@
EXPECT_EQ(10, iree_vm_list_size(list));
for (iree_host_size_t i = 0; i < 5; ++i) {
- iree_vm_value_t value = iree_vm_value_make_i32(i);
+ iree_vm_value_t value = iree_vm_value_make_i32((int32_t)i);
IREE_ASSERT_OK(iree_vm_list_set_value(list, i, &value));
}
for (iree_host_size_t i = 5; i < 10; ++i) {
- iree_vm_ref_t ref_a = MakeRef<A>(i);
+ iree_vm_ref_t ref_a = MakeRef<A>(static_cast<float>(i));
IREE_ASSERT_OK(iree_vm_list_set_ref_move(list, i, &ref_a));
}
diff --git a/iree/vm/native_module_cc.h b/iree/vm/native_module_cc.h
index fae6d64..32c0d56 100644
--- a/iree/vm/native_module_cc.h
+++ b/iree/vm/native_module_cc.h
@@ -148,7 +148,7 @@
if (out_function) {
out_function->module = module->interface();
out_function->linkage = IREE_VM_FUNCTION_LINKAGE_EXPORT;
- out_function->ordinal = ordinal;
+ out_function->ordinal = static_cast<uint16_t>(ordinal);
}
if (out_name) {
*out_name = dispatch_function.name;
diff --git a/repo_utils.bzl b/repo_utils.bzl
deleted file mode 100644
index 933c00b..0000000
--- a/repo_utils.bzl
+++ /dev/null
@@ -1,29 +0,0 @@
-# Copyright 2019 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# TODO(laurenzo): This is available upstream as of 0.28. Remove when ready.
-# See: https://docs.bazel.build/versions/master/repo/utils.html#maybe
-def maybe(repo_rule, name, **kwargs):
- """Utility function for only adding a repository if it's not already present.
- This is to implement safe repositories.bzl macro documented in
- https://docs.bazel.build/versions/master/skylark/deploying.html#dependencies.
- Args:
- repo_rule: repository rule function.
- name: name of the repository to create.
- **kwargs: remaining arguments that are passed to the repo_rule function.
- Returns:
- Nothing, defines the repository when needed as a side-effect.
- """
- if not native.existing_rule(name):
- repo_rule(name = name, **kwargs)
diff --git a/scripts/get_e2e_artifacts.py b/scripts/get_e2e_artifacts.py
index e62226d..2440ab7 100755
--- a/scripts/get_e2e_artifacts.py
+++ b/scripts/get_e2e_artifacts.py
@@ -162,10 +162,7 @@
if FLAGS.run_test_suites:
# Use bazel test to execute all of the test suites in parallel.
- command = [
- 'bazel', 'test', *test_suites, '--color=yes',
- '--test_arg=--get_saved_model'
- ]
+ command = ['bazel', 'test', *test_suites, '--color=yes']
print(f'Running: `{" ".join(command)}`')
if not FLAGS.dry_run:
subprocess.check_call(command)
diff --git a/third_party/flatbuffers b/third_party/flatbuffers
deleted file mode 160000
index a5d9d0f..0000000
--- a/third_party/flatbuffers
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit a5d9d0f7d368054fd1691aedf1db4116efcc233e
diff --git a/third_party/tensorflow b/third_party/tensorflow
index 37cfaab..218c3a2 160000
--- a/third_party/tensorflow
+++ b/third_party/tensorflow
@@ -1 +1 @@
-Subproject commit 37cfaab27f11635a240e93550c1c0faf36e07856
+Subproject commit 218c3a2712bc72d2239299a5eef4ff3c156004ed
diff --git a/third_party/tracy b/third_party/tracy
index d7059ec..d8cb536 160000
--- a/third_party/tracy
+++ b/third_party/tracy
@@ -1 +1 @@
-Subproject commit d7059eca6351546d1f51e248fc75e49dfeee709e
+Subproject commit d8cb536712e876ba956f27f23dbede1c2eccad28