Merge pull request #2281 from ScottTodd:docs-iree-opt
PiperOrigin-RevId: 318352808
diff --git a/.bazelrc b/.bazelrc
index 3487c50..8d422f7 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -29,7 +29,7 @@
test --test_output=errors
###############################################################################
-# Options for "generic_clang" builds: these options should generally apply to
+# Options for "generic_clang" builds: these options should generally apply to
# either clang or gcc and are curated based on need.
###############################################################################
@@ -43,6 +43,10 @@
build:generic_clang --copt=-Wno-invalid-offsetof
build:generic_clang --copt=-Wno-unused-function
+# Enable warnings we do care about.
+build:generic_clang --copt=-Wimplicit-fallthrough
+build:generic_clang --copt=-Wthread-safety-analysis
+
# C++14 standard version is required.
build:generic_clang --cxxopt=-std=c++14 --host_cxxopt=-std=c++14
@@ -190,4 +194,3 @@
# The user.bazelrc file is not checked in but available for local mods.
# Always keep this at the end of the file so that user flags override.
try-import %workspace%/user.bazelrc
-
diff --git a/.github/workflows/synchronize_submodules.yml b/.github/workflows/synchronize_submodules.yml
index a27e7ed..f56af6b 100644
--- a/.github/workflows/synchronize_submodules.yml
+++ b/.github/workflows/synchronize_submodules.yml
@@ -41,7 +41,7 @@
- name: Committing updates
if: env.has_diff == 'true'
run: |
- git config --local user.email "noreply+action@github.com"
+ git config --local user.email "iree-github-actions-bot@google.com"
git config --local user.name "Submodule Synchronize Action"
git commit -am "Synchronize submodules"
- name: Pushing changes
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 501cd0a..004bed5 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -33,6 +33,7 @@
# Project component configuration
#-------------------------------------------------------------------------------
+# LINT.IfChange(iree_options)
option(IREE_ENABLE_RUNTIME_TRACING "Enables instrumented runtime tracing." OFF)
option(IREE_ENABLE_MLIR "Enables MLIR/LLVM dependencies." ON)
option(IREE_ENABLE_EMITC "Enables MLIR EmitC dependencies." OFF)
@@ -50,6 +51,10 @@
CACHE STRING "Semicolon-separated list of HAL drivers to build, or \"all\"." FORCE)
set(IREE_TARGET_BACKENDS_TO_BUILD "all"
CACHE STRING "Semicolon-separated list of target backends to build, or \"all\"." FORCE)
+# LINT.ThenChange(
+# https://github.com/google/iree/tree/master/build_tools/cmake/iree_cross_compile.cmake:iree_cross_compile_options,
+# https://github.com/google/iree/tree/master/build_tools/cmake/iree_cross_compile.cmake:iree_cross_compile_invoke
+# )
if(${IREE_BUILD_SAMPLES} OR ${IREE_BUILD_EXPERIMENTAL})
set(IREE_BUILD_COMPILER ON CACHE BOOL "Build the IREE compiler for sample projects." FORCE)
@@ -136,6 +141,44 @@
)
#-------------------------------------------------------------------------------
+# Cross compiling configuration
+#-------------------------------------------------------------------------------
+
+if(CMAKE_CROSSCOMPILING)
+ if(CMAKE_HOST_SYSTEM_NAME STREQUAL Windows)
+ message(FATAL_ERROR "Cross compilation with Windows host system is not supported yet")
+ endif()
+
+ message(STATUS "Detected cross compilation mode; configuring IREE on host...")
+
+ # C/C++ compilers for host compilation.
+ # Note: we need to explicitly set this because IREE does not work well with
+ # GCC at the moment: https://github.com/google/iree/issues/1269
+ set(IREE_HOST_C_COMPILER "$ENV{IREE_HOST_C_COMPILER}" CACHE FILEPATH "C compiler for host compilation")
+ set(IREE_HOST_CXX_COMPILER "$ENV{IREE_HOST_CXX_COMPILER}" CACHE FILEPATH "C++ compiler for host compilation")
+
+ # Master configuration for the binary directory containing all artifacts
+ # compiled for host.
+ if(NOT IREE_HOST_BINARY_ROOT)
+ set(IREE_HOST_BINARY_ROOT "${CMAKE_CURRENT_BINARY_DIR}/host" CACHE FILEPATH "directory containing host artifacts")
+ endif()
+
+ set(IREE_HOST_BUILD_COMPILER ON) # For iree-translate
+ set(IREE_HOST_ENABLE_LLVM ON) # For iree-tblgen
+
+ # Set the host build directory for LLVM to our directory. Otherwise it will
+ # follow its own convention.
+ set(LLVM_NATIVE_BUILD "${IREE_HOST_BINARY_ROOT}/third_party/llvm-project/llvm")
+
+ include(iree_cross_compile)
+
+ # Use another CMake invocation to configure a build for host.
+ iree_create_configuration(HOST)
+
+ message(STATUS "Done configuring IREE on host in ${IREE_HOST_BINARY_ROOT}")
+endif()
+
+#-------------------------------------------------------------------------------
# IREE utility definitions
#-------------------------------------------------------------------------------
@@ -291,6 +334,24 @@
add_subdirectory(build_tools/third_party/renderdoc_api EXCLUDE_FROM_ALL)
add_subdirectory(build_tools/third_party/vulkan_extensionlayer EXCLUDE_FROM_ALL)
+if(CMAKE_CROSSCOMPILING)
+ # We need flatc to generate some source code. When cross-compiling, we need
+ # to make sure the flatc binary is configured for the host environment.
+ iree_declare_host_excutable(flatc BUILDONLY)
+
+ # Set the FLATBUFFERS_FLATC_EXECUTABLE. It controls where to find the flatc
+ # binary in BuildFlatBuffers().
+ iree_get_executable_path(FLATBUFFERS_FLATC_EXECUTABLE flatc)
+
+ # Add a custom target to copy the flatc binary to the host bin directory.
+ add_custom_target(iree_host_flatc
+ COMMAND "${CMAKE_COMMAND}" -E copy_if_different
+ "${IREE_HOST_BINARY_ROOT}/third_party/flatbuffers/flatc" "${IREE_HOST_BINARY_ROOT}/bin"
+ DEPENDS iree_host_build_flatc
+ COMMENT "Installing host flatc..."
+ )
+endif()
+
if(${IREE_BUILD_COMPILER})
add_subdirectory(build_tools/third_party/tensorflow/tensorflow/compiler/mlir/xla EXCLUDE_FROM_ALL)
endif()
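
For orientation, here is a minimal sketch of the cache settings the new cross-compiling block consumes; the toolchain file and compiler paths are illustrative placeholders, not values prescribed by this change.

```cmake
# Hypothetical cache entries for a cross build, e.g. preloaded via `cmake -C`.
# All paths are placeholders.
set(CMAKE_TOOLCHAIN_FILE "/path/to/target.toolchain.cmake" CACHE FILEPATH "")

# Compilers for the nested host configuration (clang is recommended because of
# the GCC issues referenced in the comment above); these default to the
# IREE_HOST_C_COMPILER / IREE_HOST_CXX_COMPILER environment variables.
set(IREE_HOST_C_COMPILER "/usr/bin/clang" CACHE FILEPATH "")
set(IREE_HOST_CXX_COMPILER "/usr/bin/clang++" CACHE FILEPATH "")

# Optional: place host artifacts somewhere other than <build>/host.
set(IREE_HOST_BINARY_ROOT "/tmp/iree-host-tools" CACHE FILEPATH "")
```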
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index ef46ddc..117179d 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -50,6 +50,11 @@
write these as you would a helpful commit message. Please also keep PRs small
(focused on a single issue) to streamline review and ease later culprit-finding.
+As part of a migration to make the project GitHub-first, our default branch is
+currently called `google` and all PRs should be directed there. This is an
+intermediate state. See
+https://groups.google.com/d/msg/iree-discuss/F07vsG9Ah4o/uAIusKO-BQAJ
+
Our documentation on
[repository management](https://github.com/google/iree/blob/master/docs/repository_management.md)
has more information on some of the oddities in our repository setup and
diff --git a/README.md b/README.md
index c6df57e..84eccc6 100644
--- a/README.md
+++ b/README.md
@@ -112,10 +112,10 @@
CI System | Build System | Platform | Component | Status
:-------: | :----------: | :------: | :-------------: | :----:
-Kokoro | Bazel | Linux | Core | [](https://storage.googleapis.com/iree-oss-build-badges/linux/bazel/core/build_result.html)
-Kokoro | Bazel | Linux | Bindings | [](https://storage.googleapis.com/iree-oss-build-badges/linux/bazel/bindings/build_result.html)
-Kokoro | Bazel | Linux | Integrations | [](https://storage.googleapis.com/iree-oss-build-badges/linux/bazel/integrations/build_result.html)
-Kokoro | CMake | Linux | Core + Bindings | [](https://storage.googleapis.com/iree-oss-build-badges/linux/cmake/build_result.html)
+Kokoro | Bazel | Linux | Core | [](https://storage.googleapis.com/iree-oss-build-badges/linux/bazel/core/google_result.html)
+Kokoro | Bazel | Linux | Bindings | [](https://storage.googleapis.com/iree-oss-build-badges/linux/bazel/bindings/google_result.html)
+Kokoro | Bazel | Linux | Integrations | [](https://storage.googleapis.com/iree-oss-build-badges/linux/bazel/integrations/google_result.html)
+Kokoro | CMake | Linux | Core + Bindings | [](https://storage.googleapis.com/iree-oss-build-badges/linux/cmake/google_result.html)
## License
diff --git a/SUBMODULE_VERSIONS b/SUBMODULE_VERSIONS
index bc5a8a1..19c9321 100644
--- a/SUBMODULE_VERSIONS
+++ b/SUBMODULE_VERSIONS
@@ -3,7 +3,7 @@
4c13807b7d43ff0946b7ffea0ae3aee9e611d778 third_party/dear_imgui
a5d9d0f7d368054fd1691aedf1db4116efcc233e third_party/flatbuffers
f2fb48c3b3d79a75a88a99fba6576b25d42ec528 third_party/googletest
-7e825abd5704ce28b166f9463d4bd304348fd2a9 third_party/llvm-project
+9fb7e98db5aaef617878a127b663efa4d01aa834 third_party/llvm-project
17b12a4481daa150e2d1ea3ada086b551b856707 third_party/marl
67f3ccebee84f3488b46a8d3ac005178c52ff264 third_party/mlir-emitc
80d452484c5409444b0ec19383faa84bb7a4d351 third_party/pybind11
@@ -11,7 +11,7 @@
b73f111094da3e380a1774b56b15f16c90ae8e23 third_party/sdl2
f8bf11a0253a32375c32cad92c841237b96696c0 third_party/spirv_headers
57eb48aed36160c4876bc8310d9ca84d42ee9e2a third_party/swiftshader
-b00a7808a7b29a78762b54e29aac87a77254b4b6 third_party/tensorflow
+f74654ac7b314a212b1df6687c2f99800084e97f third_party/tensorflow
864d86e8b6d21449474db5e9313dbff90aa9c24f third_party/tracy
8a457f8552d8d47ce3a96ed80a714ff6396f8ad8 third_party/vulkan_extensionlayer
9bd3f561bcee3f01d22912de10bb07ce4e23d378 third_party/vulkan_headers
diff --git a/bindings/python/pyiree/rt/function_abi.cc b/bindings/python/pyiree/rt/function_abi.cc
index c97b3c8..6b8c01d 100644
--- a/bindings/python/pyiree/rt/function_abi.cc
+++ b/bindings/python/pyiree/rt/function_abi.cc
@@ -164,6 +164,48 @@
}
}
+void PackScalar(const RawSignatureParser::Description& desc, py::handle py_arg,
+ VmVariantList& f_args) {
+ iree_vm_value value;
+ value.type = IREE_VM_VALUE_TYPE_I32;
+ switch (desc.scalar.type) {
+ case AbiConstants::ScalarType::kUint8:
+ case AbiConstants::ScalarType::kUint16:
+ case AbiConstants::ScalarType::kUint32: {
+ value.i32 = py_arg.cast<int32_t>();
+ break;
+ }
+ case AbiConstants::ScalarType::kSint8:
+ case AbiConstants::ScalarType::kSint16:
+ case AbiConstants::ScalarType::kSint32: {
+ value.i32 = py_arg.cast<int32_t>();
+ break;
+ }
+ default:
+ throw RaisePyError(PyExc_NotImplementedError, "Unsupported scalar type");
+ }
+ CheckApiStatus(iree_vm_variant_list_append_value(f_args.raw_ptr(), value),
+ "Could not pack scalar argument");
+}
+
+py::object UnpackScalar(const RawSignatureParser::Description& desc,
+ iree_vm_variant_t& f_result) {
+ switch (desc.scalar.type) {
+ case AbiConstants::ScalarType::kUint8:
+ case AbiConstants::ScalarType::kUint16:
+ case AbiConstants::ScalarType::kUint32: {
+ return py::int_(static_cast<uint32_t>(f_result.i32));
+ }
+ case AbiConstants::ScalarType::kSint8:
+ case AbiConstants::ScalarType::kSint16:
+ case AbiConstants::ScalarType::kSint32: {
+ return py::int_(f_result.i32);
+ }
+ default:
+ throw RaisePyError(PyExc_NotImplementedError, "Unsupported scalar type");
+ }
+}
+
} // namespace
//------------------------------------------------------------------------------
@@ -236,6 +278,9 @@
throw RaisePyError(PyExc_NotImplementedError,
"Ref objects not yet supported");
break;
+ case RawSignatureParser::Type::kScalar:
+ PackScalar(desc, py_args[i], f_args);
+ break;
default:
throw RaisePyError(PyExc_NotImplementedError,
"Unsupported argument type");
@@ -294,9 +339,12 @@
throw RaisePyError(PyExc_NotImplementedError,
"Ref objects not yet supported");
break;
+ case RawSignatureParser::Type::kScalar:
+ py_results[i] = UnpackScalar(desc, *f_result);
+ break;
default:
throw RaisePyError(PyExc_NotImplementedError,
- "Unsupported argument type");
+ "Unsupported result type");
}
}
}
@@ -358,9 +406,11 @@
throw RaisePyError(PyExc_NotImplementedError,
"Ref objects not yet supported");
break;
+ case RawSignatureParser::Type::kScalar:
+ break;
default:
throw RaisePyError(PyExc_NotImplementedError,
- "Unsupported argument type");
+ "Unsupported allocation argument type");
}
}
}
diff --git a/bindings/python/pyiree/rt/vm_test.py b/bindings/python/pyiree/rt/vm_test.py
index ed7e66f..6b633ce 100644
--- a/bindings/python/pyiree/rt/vm_test.py
+++ b/bindings/python/pyiree/rt/vm_test.py
@@ -21,6 +21,19 @@
from pyiree import rt
+def create_add_scalar_module():
+ ctx = compiler.Context()
+ input_module = ctx.parse_asm("""
+ func @add_scalar(%arg0: i32, %arg1: i32) -> i32 attributes { iree.module.export } {
+ %0 = addi %arg0, %arg1 : i32
+ return %0 : i32
+ }
+ """)
+ binary = input_module.compile()
+ m = rt.VmModule.from_flatbuffer(binary)
+ return m
+
+
def create_simple_static_mul_module():
ctx = compiler.Context()
input_module = ctx.parse_asm("""
@@ -103,6 +116,26 @@
context = rt.VmContext(instance, modules=[self.hal_module, m])
print(context)
+ def test_add_scalar(self):
+ m = create_add_scalar_module()
+ instance = rt.VmInstance()
+ context = rt.VmContext(instance, modules=[self.hal_module, m])
+ f = m.lookup_function("add_scalar")
+ abi = context.create_function_abi(self.device, self.htf, f)
+ print("INVOKING:", abi)
+ inputs = abi.raw_pack_inputs((5, 6))
+ print("INPUTS:", inputs)
+ allocated_results = abi.allocate_results(inputs, static_alloc=False)
+ print("ALLOCATED RESULTS:", allocated_results)
+ print("--- INVOKE:")
+ context.invoke(f, inputs, allocated_results)
+ print("--- DONE.")
+ results = abi.raw_unpack_results(allocated_results)
+ print("RESULTS:", results)
+ self.assertEqual(results[0], 11)
+
def test_synchronous_dynamic_shape_invoke_function(self):
m = create_simple_dynamic_abs_module()
instance = rt.VmInstance()
diff --git a/build_tools/bazel/third_party_import/llvm-project/overlay/llvm/BUILD.bazel b/build_tools/bazel/third_party_import/llvm-project/overlay/llvm/BUILD.bazel
index 50b71a3..75a5326 100644
--- a/build_tools/bazel/third_party_import/llvm-project/overlay/llvm/BUILD.bazel
+++ b/build_tools/bazel/third_party_import/llvm-project/overlay/llvm/BUILD.bazel
@@ -561,6 +561,7 @@
name = "common_target_td_sources",
srcs = glob([
"include/llvm/CodeGen/*.td",
+ "include/llvm/Frontend/Directive/*.td",
"include/llvm/IR/Intrinsics*.td",
"include/llvm/TableGen/*.td",
"include/llvm/Target/*.td",
@@ -666,6 +667,17 @@
],
)
+gentbl(
+ name = "omp_gen",
+ tbl_outs = [("--gen-directive-decls", "include/llvm/Frontend/OpenMP/OMP.h.inc")],
+ tblgen = ":llvm-tblgen",
+ td_file = "include/llvm/Frontend/OpenMP/OMP.td",
+ td_srcs = glob([
+ "include/llvm/Frontend/OpenMP/*.td",
+ "include/llvm/Frontend/Directive/*.td",
+ ]),
+)
+
########################## Begin generated content ##########################
cc_library(
name = "AArch64AsmParser",
@@ -698,6 +710,7 @@
"lib/Target/AArch64/*.c",
"lib/Target/AArch64/*.cpp",
"lib/Target/AArch64/*.inc",
+ "lib/Target/AArch64/GISel/*.cpp",
]),
hdrs = glob([
"include/llvm/Target/AArch64/*.h",
@@ -1382,14 +1395,20 @@
cc_library(
name = "Analysis",
- srcs = glob([
- "lib/Analysis/*.c",
- "lib/Analysis/*.cpp",
- "lib/Analysis/*.inc",
- "include/llvm/Transforms/Utils/Local.h",
- "include/llvm/Transforms/Scalar.h",
- "lib/Analysis/*.h",
- ]),
+ srcs = glob(
+ [
+ "lib/Analysis/*.c",
+ "lib/Analysis/*.cpp",
+ "lib/Analysis/*.inc",
+ "include/llvm/Transforms/Utils/Local.h",
+ "include/llvm/Transforms/Scalar.h",
+ "lib/Analysis/*.h",
+ ],
+ exclude = [
+ "lib/Analysis/MLInlineAdvisor.cpp",
+ "lib/Analysis/ReleaseModeModelRunner.cpp",
+ ],
+ ),
hdrs = glob([
"include/llvm/Analysis/*.h",
"include/llvm/Analysis/*.def",
@@ -2052,6 +2071,7 @@
":Support",
":TransformUtils",
":config",
+ ":omp_gen",
],
)
diff --git a/build_tools/bazel/third_party_import/llvm-project/overlay/mlir/BUILD.bazel b/build_tools/bazel/third_party_import/llvm-project/overlay/mlir/BUILD.bazel
index a0d1066..ba9b580 100644
--- a/build_tools/bazel/third_party_import/llvm-project/overlay/mlir/BUILD.bazel
+++ b/build_tools/bazel/third_party_import/llvm-project/overlay/mlir/BUILD.bazel
@@ -282,7 +282,7 @@
deps = [
":AVX512IncGen",
":IR",
- ":SideEffects",
+ ":SideEffectInterfaces",
":VectorOps",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:Support",
@@ -305,9 +305,9 @@
":IR",
":LLVMAVX512",
":LLVMDialect",
- ":LLVMTransforms",
":Pass",
":StandardOps",
+ ":StandardToLLVM",
":Support",
":Transforms",
":VectorOps",
@@ -489,7 +489,7 @@
":EDSC",
":IR",
":LoopLikeInterface",
- ":SideEffects",
+ ":SideEffectInterfaces",
":StandardOps",
":Support",
"@llvm-project//llvm:Support",
@@ -571,7 +571,7 @@
)
cc_library(
- name = "AffineToStandardTransforms",
+ name = "AffineToStandard",
srcs = glob([
"lib/Conversion/AffineToStandard/*.cpp",
"lib/Conversion/AffineToStandard/*.h",
@@ -591,6 +591,11 @@
],
)
+alias(
+ name = "AffineToStandardTransforms",
+ actual = "AffineToStandard",
+)
+
# SDBM dialect only contains attribute components that can be constructed given
# a dialect object, so whenever it is used it must also be registered. Therefore
# we don't split out the registration library for it.
@@ -631,7 +636,7 @@
":IR",
":LoopLikeInterface",
":SCFIncGen",
- ":SideEffects",
+ ":SideEffectInterfaces",
":StandardOps",
":Support",
"@llvm-project//llvm:Support",
@@ -719,7 +724,7 @@
":InferTypeOpInterface",
":MLIRShapeCanonicalizationIncGen",
":ShapeOpsIncGen",
- ":SideEffects",
+ ":SideEffectInterfaces",
":Support",
"@llvm-project//llvm:Support",
],
@@ -833,7 +838,7 @@
":ControlFlowInterfaces",
":EDSC",
":IR",
- ":SideEffects",
+ ":SideEffectInterfaces",
":StandardOpsIncGen",
":Support",
":ViewLikeInterface",
@@ -895,7 +900,7 @@
":DialectUtils",
":EDSC",
":IR",
- ":SideEffects",
+ ":SideEffectInterfaces",
":StandardOps",
":Support",
":VectorOpsIncGen",
@@ -1070,7 +1075,7 @@
":ControlFlowInterfaces",
":IR",
":LLVMOpsIncGen",
- ":SideEffects",
+ ":SideEffectInterfaces",
":Support",
"@llvm-project//llvm:AsmParser",
"@llvm-project//llvm:BitReader",
@@ -1193,7 +1198,7 @@
":GPUOpsIncGen",
":IR",
":LLVMDialect",
- ":SideEffects",
+ ":SideEffectInterfaces",
":StandardOps",
":Support",
],
@@ -1271,8 +1276,8 @@
":GPUDialect",
":IR",
":LLVMDialect",
- ":LLVMTransforms",
":StandardOps",
+ ":StandardToLLVM",
"@llvm-project//llvm:Support",
],
)
@@ -1311,9 +1316,9 @@
":GPUToNVVMGen",
":GPUTransforms",
":IR",
- ":LLVMTransforms",
":NVVMDialect",
":Pass",
+ ":StandardToLLVM",
":Transforms",
"@llvm-project//llvm:Support",
],
@@ -1333,10 +1338,10 @@
":ConversionPassIncGen",
":GPUDialect",
":LLVMDialect",
- ":LLVMTransforms",
":Pass",
":ROCDLDialect",
":StandardOps",
+ ":StandardToLLVM",
":Transforms",
":VectorOps",
],
@@ -1375,9 +1380,9 @@
":GPUDialect",
":GPUToROCDLTGen",
":GPUTransforms",
- ":LLVMTransforms",
":Pass",
":ROCDLDialect",
+ ":StandardToLLVM",
":Transforms",
":VectorOps",
":VectorToLLVM",
@@ -1475,7 +1480,7 @@
":SCFDialect",
":SPIRVDialect",
":SPIRVLowering",
- ":StandardToSPIRVConversions",
+ ":StandardToSPIRVTransforms",
":Support",
":Transforms",
],
@@ -1496,12 +1501,13 @@
":ConversionPassIncGen",
":IR",
":LLVMDialect",
- ":LLVMTransforms",
":Pass",
":SPIRVDialect",
":StandardOps",
+ ":StandardToLLVM",
":Support",
":Transforms",
+ "@llvm-project//llvm:Support",
],
)
@@ -1574,7 +1580,7 @@
":IR",
":LLVMDialect",
":NVVMOpsIncGen",
- ":SideEffects",
+ ":SideEffectInterfaces",
":StandardOps",
":Support",
"@llvm-project//llvm:AsmParser",
@@ -1646,7 +1652,7 @@
":IR",
":LLVMDialect",
":ROCDLOpsIncGen",
- ":SideEffects",
+ ":SideEffectInterfaces",
":StandardOps",
":Support",
"@llvm-project//llvm:AsmParser",
@@ -1894,7 +1900,7 @@
":SPIRVOpsIncGen",
":SPIRVSerializationGen",
":SPIRVTargetAndABIStructGen",
- ":SideEffects",
+ ":SideEffectInterfaces",
":Support",
":Transforms",
"@llvm-project//llvm:Support",
@@ -1947,7 +1953,7 @@
)
cc_library(
- name = "StandardToSPIRVConversions",
+ name = "StandardToSPIRVTransforms",
srcs = glob([
"lib/Conversion/StandardToSPIRV/*.cpp",
"lib/Conversion/StandardToSPIRV/*.h",
@@ -1968,10 +1974,16 @@
":StandardOps",
":Support",
":Transforms",
+ ":VectorOps",
"@llvm-project//llvm:Support",
],
)
+alias(
+ name = "StandardToSPIRVConversions",
+ actual = "StandardToSPIRVTransforms",
+)
+
cc_library(
name = "SPIRVSerialization",
srcs = glob(
@@ -2033,7 +2045,7 @@
":IR",
":LoopLikeInterface",
":SCFDialect",
- ":SideEffects",
+ ":SideEffectInterfaces",
":StandardOps",
":Support",
"@llvm-project//llvm:Support",
@@ -2152,7 +2164,7 @@
":LoopLikeInterface",
":Pass",
":SCFDialect",
- ":SideEffects",
+ ":SideEffectInterfaces",
":StandardOps",
":Support",
":TransformUtils",
@@ -2182,7 +2194,7 @@
includes = ["include"],
deps = [
":Affine",
- ":AffineToStandardTransforms",
+ ":AffineToStandard",
":ConversionPassIncGen",
":GPUDialect",
":GPUTransforms",
@@ -2222,7 +2234,7 @@
)
cc_library(
- name = "CFGTransforms",
+ name = "SCFToStandard",
srcs = [
"lib/Conversion/PassDetail.h",
"lib/Conversion/SCFToStandard/SCFToStandard.cpp",
@@ -2244,8 +2256,13 @@
],
)
+alias(
+ name = "CFGTransforms",
+ actual = "SCFToStandard",
+)
+
cc_library(
- name = "LLVMTransforms",
+ name = "StandardToLLVM",
srcs = [
"lib/Conversion/PassDetail.h",
"lib/Conversion/StandardToLLVM/StandardToLLVM.cpp",
@@ -2269,6 +2286,11 @@
],
)
+alias(
+ name = "LLVMTransforms",
+ actual = "StandardToLLVM",
+)
+
gentbl(
name = "CallOpInterfacesIncGen",
strip_include_prefix = "include",
@@ -2401,7 +2423,7 @@
)
cc_library(
- name = "SideEffects",
+ name = "SideEffectInterfaces",
srcs = [
"lib/Interfaces/SideEffectInterfaces.cpp",
],
@@ -2417,6 +2439,11 @@
],
)
+alias(
+ name = "SideEffects",
+ actual = "SideEffectInterfaces",
+)
+
cc_library(
name = "Analysis",
srcs = glob(
@@ -2627,7 +2654,6 @@
":GPUTransforms",
":IR",
":LLVMDialect",
- ":LLVMTransforms",
":LinalgToLLVM",
":LinalgToSPIRV",
":LinalgToStandard",
@@ -2639,7 +2665,8 @@
":ShapeToStandard",
":ShapeTransforms",
":StandardOpsTransforms",
- ":StandardToSPIRVConversions",
+ ":StandardToLLVM",
+ ":StandardToSPIRVTransforms",
":Support",
":Transforms",
":VectorToLLVM",
@@ -2699,7 +2726,6 @@
":Affine",
":AffinePassIncGen",
":AffineTransforms",
- ":CFGTransforms",
":ConversionPassIncGen",
":GPUDialect",
":GPUPassIncGen",
@@ -2714,7 +2740,6 @@
":LLVMDialect",
":LLVMIRTransforms",
":LLVMPassIncGen",
- ":LLVMTransforms",
":LinalgOps",
":LinalgPassIncGen",
":LinalgToLLVM",
@@ -2729,6 +2754,7 @@
":ROCDLDialect",
":SCFDialect",
":SCFToGPUPass",
+ ":SCFToStandard",
":SCFTransforms",
":SDBM",
":SPIRVDialect",
@@ -2743,7 +2769,8 @@
":StandardOps",
":StandardOpsTransforms",
":StandardOpsTransformsPassIncGen",
- ":StandardToSPIRVConversions",
+ ":StandardToLLVM",
+ ":StandardToSPIRVTransforms",
":Transforms",
":TransformsPassIncGen",
":VectorOps",
@@ -2809,13 +2836,13 @@
includes = ["include"],
deps = [
":AllPassesAndDialectsNoRegistration",
- ":CFGTransforms",
":ExecutionEngine",
":ExecutionEngineUtils",
":IR",
":LLVMDialect",
":Parser",
":Pass",
+ ":SCFToStandard",
":Support",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:OrcJIT",
@@ -2885,7 +2912,7 @@
":IR",
":Pass",
":SPIRVDialect",
- ":SideEffects",
+ ":SideEffectInterfaces",
":StandardOps",
":Support",
"@llvm-project//llvm:Support",
@@ -2918,10 +2945,10 @@
":GPUTransforms",
":IR",
":LLVMDialect",
- ":LLVMTransforms",
":MlirJitRunner",
":NVVMDialect",
":Pass",
+ ":StandardToLLVM",
":TargetNVVMIR",
":Transforms",
"//devtools/build/runtime:get_runfiles_dir",
@@ -2945,11 +2972,11 @@
":GPUToSPIRVTransforms",
":GPUToVulkanTransforms",
":GPUTransforms",
- ":LLVMTransforms",
":MlirJitRunner",
":Pass",
":SPIRVDialect",
- ":StandardToSPIRVConversions",
+ ":StandardToLLVM",
+ ":StandardToSPIRVTransforms",
"@llvm-project//llvm:Support",
],
)
@@ -3159,7 +3186,7 @@
":Pass",
":QuantOpsIncGen",
":QuantPassIncGen",
- ":SideEffects",
+ ":SideEffectInterfaces",
":StandardOps",
"@llvm-project//llvm:Support",
],
@@ -3294,18 +3321,18 @@
]),
includes = ["include"],
deps = [
- ":AffineToStandardTransforms",
+ ":AffineToStandard",
":Analysis",
- ":CFGTransforms",
":ConversionPassIncGen",
":EDSC",
":IR",
":LLVMDialect",
- ":LLVMTransforms",
":LinalgOps",
":LinalgTransforms",
":Pass",
+ ":SCFToStandard",
":StandardOps",
+ ":StandardToLLVM",
":Support",
":Transforms",
":VectorToLLVM",
@@ -3385,7 +3412,7 @@
":LinalgOpsIncGen",
":LinalgStructuredOpsIncGen",
":Parser",
- ":SideEffects",
+ ":SideEffectInterfaces",
":StandardOps",
":Support",
":ViewLikeInterface",
@@ -3431,21 +3458,21 @@
includes = ["include"],
deps = [
":Affine",
- ":AffineToStandardTransforms",
+ ":AffineToStandard",
":Analysis",
- ":CFGTransforms",
":DialectUtils",
":EDSC",
":IR",
":LLVMDialect",
- ":LLVMTransforms",
":LinalgOps",
":LinalgPassIncGen",
":LinalgStructuredOpsIncGen",
":Pass",
":SCFDialect",
+ ":SCFToStandard",
":SCFTransforms",
":StandardOps",
+ ":StandardToLLVM",
":Support",
":TransformUtils",
":Transforms",
@@ -3537,9 +3564,9 @@
":EDSC",
":IR",
":LLVMDialect",
- ":LLVMTransforms",
":Pass",
":StandardOps",
+ ":StandardToLLVM",
":Support",
":Transforms",
":VectorOps",
@@ -3564,10 +3591,10 @@
":EDSC",
":IR",
":LLVMDialect",
- ":LLVMTransforms",
":Pass",
":SCFDialect",
":StandardOps",
+ ":StandardToLLVM",
":Support",
":Transforms",
":VectorOps",
diff --git a/build_tools/cmake/flatbuffer_cc_library.cmake b/build_tools/cmake/flatbuffer_cc_library.cmake
index 6ad2995..febf234 100644
--- a/build_tools/cmake/flatbuffer_cc_library.cmake
+++ b/build_tools/cmake/flatbuffer_cc_library.cmake
@@ -95,18 +95,24 @@
set(FLATBUFFERS_FLATC_SCHEMA_EXTRA_ARGS ${_RULE_FLATC_ARGS})
endif()
+ set(_GEN_TARGET "${_NAME}_gen")
+
build_flatbuffers(
"${_RULE_SRCS}"
"${IREE_ROOT_DIR}"
- "${_NAME}_gen" # custom_target_name
- "${_RULE_DEPS}" # additional_dependencies
+ "${_GEN_TARGET}" # custom_target_name
+ "${_RULE_DEPS}" # additional_dependencies
"${CMAKE_CURRENT_BINARY_DIR}" # generated_include_dir
"${CMAKE_CURRENT_BINARY_DIR}" # binary_schemas_dir
"" # copy_text_schemas_dir
)
+ # Add a dependency on flatc explicitly. This is needed for cross-compiling,
+ # where flatc comes from another CMake invocation for the host.
+ iree_add_executable_dependencies(${_GEN_TARGET} flatc)
+
add_library(${_NAME} INTERFACE)
- add_dependencies(${_NAME} ${_NAME}_gen)
+ add_dependencies(${_NAME} ${_GEN_TARGET})
target_include_directories(${_NAME}
INTERFACE
"$<BUILD_INTERFACE:${IREE_COMMON_INCLUDE_DIRS}>"
diff --git a/build_tools/cmake/iree_bytecode_module.cmake b/build_tools/cmake/iree_bytecode_module.cmake
index f8002ec..64f8fd0 100644
--- a/build_tools/cmake/iree_bytecode_module.cmake
+++ b/build_tools/cmake/iree_bytecode_module.cmake
@@ -56,23 +56,24 @@
if(DEFINED _RULE_TRANSLATE_TOOL)
set(_TRANSLATE_TOOL ${_RULE_TRANSLATE_TOOL})
else()
- set(_TRANSLATE_TOOL "iree_tools_iree-translate")
+ set(_TRANSLATE_TOOL "iree-translate")
endif()
- # Resolve the executable binary path from the target name.
- set(_TRANSLATE_TOOL_EXECUTABLE $<TARGET_FILE:${_TRANSLATE_TOOL}>)
+ iree_get_executable_path(_TRANSLATE_TOOL_EXECUTABLE ${_TRANSLATE_TOOL})
set(_ARGS "${_FLAGS}")
list(APPEND _ARGS "${CMAKE_CURRENT_SOURCE_DIR}/${_RULE_SRC}")
list(APPEND _ARGS "-o")
list(APPEND _ARGS "${_RULE_NAME}.module")
+ # Depend on the binary instead of the target here, given that we might not
+ # have the target in this CMake invocation when cross-compiling.
add_custom_command(
OUTPUT "${_RULE_NAME}.module"
COMMAND ${_TRANSLATE_TOOL_EXECUTABLE} ${_ARGS}
# Changes to either the translation tool or the input source should
# trigger rebuilding.
- DEPENDS ${_TRANSLATE_TOOL} ${_RULE_SRC}
+ DEPENDS ${_TRANSLATE_TOOL_EXECUTABLE} ${_RULE_SRC}
)
if(_RULE_TESTONLY)
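
A brief usage sketch for this rule follows; the module name and source file are hypothetical, and TRANSLATE_TOOL is only needed when overriding the default.

```cmake
# Illustrative only; NAME and SRC values are placeholders.
iree_bytecode_module(
  NAME
    simple_module
  SRC
    "simple.mlir"
  # Optional override. The default is now the plain "iree-translate" name,
  # resolved through iree_get_executable_path() so it can point at a
  # host-built binary when cross-compiling.
  TRANSLATE_TOOL
    iree-translate
)
```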
diff --git a/build_tools/cmake/iree_cc_binary.cmake b/build_tools/cmake/iree_cc_binary.cmake
index b4d6eff..6b3653a 100644
--- a/build_tools/cmake/iree_cc_binary.cmake
+++ b/build_tools/cmake/iree_cc_binary.cmake
@@ -30,6 +30,8 @@
# COPTS: List of private compile options
# DEFINES: List of public defines
# LINKOPTS: List of link options
+# TESTONLY: for testing; won't compile when tests are disabled
+# HOSTONLY: host only; compile using host toolchain when cross-compiling
#
# Note:
# By default, iree_cc_binary will always create a binary named iree_${NAME}.
@@ -58,7 +60,7 @@
function(iree_cc_binary)
cmake_parse_arguments(
_RULE
- "TESTONLY"
+ "HOSTONLY;TESTONLY"
"NAME;OUT"
"SRCS;COPTS;DEFINES;LINKOPTS;DATA;DEPS"
${ARGN}
@@ -68,6 +70,14 @@
return()
endif()
+ if(_RULE_HOSTONLY AND CMAKE_CROSSCOMPILING)
+ # The binary is marked as host only. We need to declare the rules for
+ # generating it under the host configuration so that when cross-compiling
+ # towards a target we can still have this binary.
+ iree_declare_host_excutable(${_RULE_NAME})
+ return()
+ endif()
+
# Prefix the library with the package name, so we get: iree_package_name
iree_package_name(_PACKAGE_NAME)
set(_NAME "${_PACKAGE_NAME}_${_RULE_NAME}")
@@ -126,6 +136,11 @@
# Track target and deps, use in iree_complete_binary_link_options() later.
set_property(GLOBAL APPEND PROPERTY _IREE_CC_BINARY_NAMES "${_NAME}")
set_property(TARGET ${_NAME} PROPERTY DIRECT_DEPS ${_RULE_DEPS})
+
+ install(TARGETS ${_NAME}
+ RENAME ${_RULE_NAME}
+ COMPONENT ${_RULE_NAME}
+ RUNTIME DESTINATION bin)
endfunction()
# Lists all transitive dependencies of DIRECT_DEPS in TRANSITIVE_DEPS.
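
A minimal sketch of the new HOSTONLY switch in use; the tool name and source file are hypothetical.

```cmake
# A code generator that must run on the build machine. When cross-compiling,
# this declares host build/install rules (iree_declare_host_excutable) instead
# of compiling the tool for the target.
iree_cc_binary(
  NAME
    my_generator          # hypothetical
  SRCS
    "my_generator.cc"     # hypothetical
  HOSTONLY
)
```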
diff --git a/build_tools/cmake/iree_cc_embed_data.cmake b/build_tools/cmake/iree_cc_embed_data.cmake
index d3644ed..7eeac23 100644
--- a/build_tools/cmake/iree_cc_embed_data.cmake
+++ b/build_tools/cmake/iree_cc_embed_data.cmake
@@ -79,10 +79,12 @@
list(APPEND _ARGS "${SRC}")
endforeach(SRC)
+ iree_get_executable_path(_EXE_PATH generate_cc_embed_data)
+
add_custom_command(
OUTPUT "${_RULE_H_FILE_OUTPUT}" "${_RULE_CC_FILE_OUTPUT}"
- COMMAND generate_cc_embed_data ${_ARGS}
- DEPENDS generate_cc_embed_data ${_RULE_SRCS} ${_RULE_GENERATED_SRCS}
+ COMMAND ${_EXE_PATH} ${_ARGS}
+ DEPENDS ${_EXE_PATH} ${_RULE_SRCS} ${_RULE_GENERATED_SRCS}
)
if(_RULE_TESTONLY)
diff --git a/build_tools/cmake/iree_copts.cmake b/build_tools/cmake/iree_copts.cmake
index cec163a..542536b 100644
--- a/build_tools/cmake/iree_copts.cmake
+++ b/build_tools/cmake/iree_copts.cmake
@@ -12,8 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+#-------------------------------------------------------------------------------
+# Abseil configuration
+#-------------------------------------------------------------------------------
+
include(AbseilConfigureCopts)
+# By default Abseil strips string literals on mobile platforms, which means
+# we cannot run IREE binaries via command-line with proper options. Turn off
+# the stripping.
+# TODO: we might still want to strip when compiling IREE into Android Java apps.
+if(CMAKE_CROSSCOMPILING AND "${CMAKE_SYSTEM_NAME}" MATCHES "Android")
+ add_definitions(-DABSL_FLAGS_STRIP_NAMES=0)
+endif()
+
#-------------------------------------------------------------------------------
# C++ used within IREE
#-------------------------------------------------------------------------------
@@ -42,6 +54,9 @@
"-Wno-gnu-label-as-value"
"-Wno-unused-local-typedef"
"-Wno-gnu-zero-variadic-macro-arguments"
+ # Enable some warnings
+ "-Wimplicit-fallthrough"
+ "-Wthread-safety-analysis"
CLANG_OR_GCC
"-Wno-unused-parameter"
"-Wno-undef"
@@ -89,13 +104,19 @@
#-------------------------------------------------------------------------------
set(FLATBUFFERS_BUILD_TESTS OFF CACHE BOOL "" FORCE)
-set(FLATBUFFERS_INSTALL OFF CACHE BOOL "" FORCE)
-set(FLATBUFFERS_BUILD_FLATC ON CACHE BOOL "" FORCE)
set(FLATBUFFERS_BUILD_FLATHASH OFF CACHE BOOL "" FORCE)
set(FLATBUFFERS_BUILD_GRPCTEST OFF CACHE BOOL "" FORCE)
+set(FLATBUFFERS_INSTALL OFF CACHE BOOL "" FORCE)
set(FLATBUFFERS_INCLUDE_DIRS
"${PROJECT_SOURCE_DIR}/third_party/flatbuffers/include/"
)
+
+if(CMAKE_CROSSCOMPILING)
+ set(FLATBUFFERS_BUILD_FLATC OFF CACHE BOOL "" FORCE)
+else()
+ set(FLATBUFFERS_BUILD_FLATC ON CACHE BOOL "" FORCE)
+endif()
+
iree_select_compiler_opts(FLATBUFFERS_COPTS
CLANG
# Flatbuffers has a bunch of incorrect documentation annotations.
@@ -148,7 +169,9 @@
endif()
set(MLIR_TABLEGEN_EXE mlir-tblgen)
-set(IREE_TABLEGEN_EXE iree-tblgen)
+# iree-tblgen is not defined using the add_tablegen mechanism, unlike other
+# TableGen tools in LLVM.
+iree_get_executable_path(IREE_TABLEGEN_EXE iree-tblgen)
#-------------------------------------------------------------------------------
# Third party: tensorflow
diff --git a/build_tools/cmake/iree_cross_compile.cmake b/build_tools/cmake/iree_cross_compile.cmake
new file mode 100644
index 0000000..2568abd
--- /dev/null
+++ b/build_tools/cmake/iree_cross_compile.cmake
@@ -0,0 +1,234 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+include(iree_macros)
+
+# iree_create_configuration
+#
+# Creates custom commands and targets for an IREE configuration. An IREE
+# configuration means a new IREE CMake invocation with its own set of
+# parameters.
+#
+# This function defines a custom target, `iree_configure_${CONFIG_NAME}`,
+# to drive the generation of a new IREE configuration's `CMakeCache.txt`
+# file. Callers can then depend on either the `CMakeCache.txt` file or the
+# `iree_configure_${CONFIG_NAME}` target to make sure the configuration
+# is invoked as a dependency.
+#
+# This function is typically useful when cross-compiling towards another
+# architecture. For example, when cross-compiling towards Android, we need
+# to have certain tools first compiled on the host so that we can use them
+# to programmatically generate some source code to be compiled together
+# with other checked-in source code. Those host tools will be generated
+# by another CMake invocation configured by this function.
+#
+# Supported CMake options:
+# - IREE_<CONFIG_NAME>_BINARY_ROOT: the root directory containing IREE build
+# artifacts for the given `CONFIG_NAME`. If not specified by the caller, this is
+# set to a directory named `CONFIG_NAME` under the current CMake binary
+# directory.
+# - IREE_<CONFIG_NAME>_C_COMPILER: C compiler for the given `CONFIG_NAME`.
+# This must be defined by the caller.
+# - IREE_<CONFIG_NAME>_CXX_COMPILER: C++ compiler for the given `CONFIG_NAME`.
+# This must be defined by the caller.
+# - IREE_<CONFIG_NAME>_<option>: switch for the given `option` specifically for
+# `CONFIG_NAME`. If missing, defaults to OFF for bool options and to
+# IREE_<option> for non-bool variables.
+function(iree_create_configuration CONFIG_NAME)
+ # Set IREE_${CONFIG_NAME}_BINARY_ROOT if missing.
+ if(NOT DEFINED IREE_${CONFIG_NAME}_BINARY_ROOT)
+ set(IREE_${CONFIG_NAME}_BINARY_ROOT "${CMAKE_CURRENT_BINARY_DIR}/${CONFIG_NAME}")
+ set(IREE_${CONFIG_NAME}_BINARY_ROOT ${IREE_${CONFIG_NAME}_BINARY_ROOT} PARENT_SCOPE)
+ message(STATUS "Setting ${CONFIG_NAME} build directory to ${IREE_${CONFIG_NAME}_BINARY_ROOT}")
+ endif()
+
+ set(_CONFIG_BINARY_ROOT ${IREE_${CONFIG_NAME}_BINARY_ROOT})
+
+ set(_CONFIG_C_COMPILER ${IREE_${CONFIG_NAME}_C_COMPILER})
+ set(_CONFIG_CXX_COMPILER ${IREE_${CONFIG_NAME}_CXX_COMPILER})
+
+ # Check that the compilers are specified by the caller.
+ if("${_CONFIG_C_COMPILER}" STREQUAL "")
+ message(FATAL_ERROR "Must define IREE_${CONFIG_NAME}_C_COMPILER for \"${CONFIG_NAME}\" configuration build")
+ endif()
+ if("${_CONFIG_CXX_COMPILER}" STREQUAL "")
+ message(FATAL_ERROR "Must define IREE_${CONFIG_NAME}_CXX_COMPILER for \"${CONFIG_NAME}\" configuration build")
+ endif()
+
+ add_custom_command(OUTPUT ${_CONFIG_BINARY_ROOT}
+ COMMAND ${CMAKE_COMMAND} -E make_directory ${_CONFIG_BINARY_ROOT}
+ COMMENT "Creating ${_CONFIG_BINARY_ROOT}...")
+
+ # Give it a custom target so we can drive the generation manually
+ # when useful.
+ add_custom_target(iree_prepare_${CONFIG_NAME}_dir DEPENDS ${_CONFIG_BINARY_ROOT})
+
+ # LINT.IfChange(iree_cross_compile_options)
+ iree_to_bool(_CONFIG_ENABLE_RUNTIME_TRACING "${IREE_${CONFIG_NAME}_ENABLE_RUNTIME_TRACING}")
+ iree_to_bool(_CONFIG_ENABLE_MLIR "${IREE_${CONFIG_NAME}_ENABLE_MLIR}")
+ iree_to_bool(_CONFIG_ENABLE_EMITC "${IREE_${CONFIG_NAME}_ENABLE_EMITC}")
+
+ iree_to_bool(_CONFIG_BUILD_COMPILER "${IREE_${CONFIG_NAME}_BUILD_COMPILER}")
+ iree_to_bool(_CONFIG_BUILD_TESTS "${IREE_${CONFIG_NAME}_BUILD_TESTS}")
+ iree_to_bool(_CONFIG_BUILD_DOCS "${IREE_${CONFIG_NAME}_BUILD_DOCS}")
+ iree_to_bool(_CONFIG_BUILD_SAMPLES "${IREE_${CONFIG_NAME}_BUILD_SAMPLES}")
+ iree_to_bool(_CONFIG_BUILD_DEBUGGER "${IREE_${CONFIG_NAME}_BUILD_DEBUGGER}")
+ iree_to_bool(_CONFIG_BUILD_PYTHON_BINDINGS "${IREE_${CONFIG_NAME}_BUILD_PYTHON_BINDINGS}")
+ iree_to_bool(_CONFIG_BUILD_EXPERIMENTAL "${IREE_${CONFIG_NAME}_BUILD_EXPERIMENTAL}")
+
+ # Escape semicolons in the targets list so that CMake doesn't expand them to
+ # spaces.
+ string(REPLACE ";" "$<SEMICOLON>" _CONFIG_HAL_DRIVERS_TO_BUILD "${IREE_HAL_DRIVERS_TO_BUILD}")
+ string(REPLACE ";" "$<SEMICOLON>" _CONFIG_TARGET_BACKENDS_TO_BUILD "${IREE_TARGET_BACKENDS_TO_BUILD}")
+ # LINT.ThenChange(
+ # https://github.com/google/iree/tree/master/CMakeLists.txt:iree_options,
+ # https://github.com/google/iree/tree/master/build_tools/cmake/iree_cross_compile.cmake:iree_cross_compile_invoke
+ # )
+
+ message(STATUS "C compiler for ${CONFIG_NAME} build: ${_CONFIG_C_COMPILER}")
+ message(STATUS "C++ compiler for ${CONFIG_NAME} build: ${_CONFIG_CXX_COMPILER}")
+
+ add_custom_command(OUTPUT ${IREE_${CONFIG_NAME}_BINARY_ROOT}/CMakeCache.txt
+ COMMAND "${CMAKE_COMMAND}" "${PROJECT_SOURCE_DIR}" -G "${CMAKE_GENERATOR}"
+ -DCMAKE_MAKE_PROGRAM="${CMAKE_MAKE_PROGRAM}"
+ -DCMAKE_BUILD_TYPE="${CMAKE_BUILD_TYPE}"
+ -DCMAKE_C_COMPILER="${_CONFIG_C_COMPILER}"
+ -DCMAKE_CXX_COMPILER="${_CONFIG_CXX_COMPILER}"
+ # LINT.IfChange(iree_cross_compile_invoke)
+ -DIREE_ENABLE_RUNTIME_TRACING=${_CONFIG_ENABLE_RUNTIME_TRACING}
+ -DIREE_ENABLE_MLIR=${_CONFIG_ENABLE_MLIR}
+ -DIREE_ENABLE_EMITC=${_CONFIG_ENABLE_EMITC}
+ -DIREE_BUILD_COMPILER=${_CONFIG_BUILD_COMPILER}
+ -DIREE_BUILD_TESTS=${_CONFIG_BUILD_TESTS}
+ -DIREE_BUILD_DOCS=${_CONFIG_BUILD_DOCS}
+ -DIREE_BUILD_SAMPLES=${_CONFIG_BUILD_SAMPLES}
+ -DIREE_BUILD_DEBUGGER=${_CONFIG_BUILD_DEBUGGER}
+ -DIREE_BUILD_PYTHON_BINDINGS=${_CONFIG_BUILD_PYTHON_BINDINGS}
+ -DIREE_BUILD_EXPERIMENTAL=${_CONFIG_BUILD_EXPERIMENTAL}
+ # LINT.ThenChange(
+ # https://github.com/google/iree/tree/master/CMakeLists.txt:iree_options,
+ # https://github.com/google/iree/tree/master/build_tools/cmake/iree_cross_compile.cmake:iree_cross_compile_options,
+ # )
+ -DIREE_HAL_DRIVERS_TO_BUILD="${_CONFIG_HAL_DRIVERS_TO_BUILD}"
+ -DIREE_TARGET_BACKENDS_TO_BUILD="${_CONFIG_TARGET_BACKENDS_TO_BUILD}"
+ WORKING_DIRECTORY ${_CONFIG_BINARY_ROOT}
+ DEPENDS iree_prepare_${CONFIG_NAME}_dir
+ COMMENT "Configuring IREE for ${CONFIG_NAME} build...")
+
+ add_custom_target(iree_configure_${CONFIG_NAME} DEPENDS ${_CONFIG_BINARY_ROOT}/CMakeCache.txt)
+endfunction()
+
+# iree_get_build_command
+#
+# Gets the CMake build command for the given `TARGET`.
+#
+# Parameters:
+# TARGET: the target to build.
+# BINDIR: root binary directory containing CMakeCache.txt.
+# CMDVAR: variable name for receiving the build command.
+# CONFIG: the CMake configuration to build (defaults to "$<CONFIG>").
+function(iree_get_build_command TARGET)
+ cmake_parse_arguments(_RULE "" "BINDIR;CMDVAR;CONFIG" "" ${ARGN})
+ if(NOT _RULE_CONFIG)
+ set(_RULE_CONFIG "$<CONFIG>")
+ endif()
+ if (CMAKE_GENERATOR MATCHES "Make")
+ # Use special command for Makefiles to support parallelism.
+ set(${_RULE_CMDVAR}
+ "$(MAKE)" "-C" "${_RULE_BINDIR}" "${TARGET}" PARENT_SCOPE)
+ else()
+ set(${_RULE_CMDVAR}
+ "${CMAKE_COMMAND}" --build ${_RULE_BINDIR}
+ --target ${TARGET}
+ --config ${_RULE_CONFIG} PARENT_SCOPE)
+ endif()
+endfunction()
+
+# iree_host_install
+#
+# Defines custom commands and targets for installing the given `TARGET` under
+# the host configuration. The custom install target will be named
+# `iree_host_install_${TARGET}`.
+#
+# Precondition:
+# iree_create_configuration(HOST) is invoked previously.
+#
+# Parameters:
+# COMPONENT: installation component; used for filtering installation targets.
+# PREFIX: the root installation path prefix.
+# DEPENDS: additional dependencies for the installation.
+function(iree_host_install TARGET)
+ cmake_parse_arguments(_RULE "" "TARGET;COMPONENT;PREFIX" "DEPENDS" ${ARGN})
+ if(_RULE_COMPONENT)
+ set(_COMPONENT_OPTION -DCMAKE_INSTALL_COMPONENT="${_RULE_COMPONENT}")
+ endif()
+ if(_RULE_PREFIX)
+ set(_PREFIX_OPTION -DCMAKE_INSTALL_PREFIX="${_RULE_PREFIX}")
+ endif()
+
+ iree_get_executable_path(_OUTPUT_PATH ${TARGET})
+
+ add_custom_command(
+ OUTPUT ${_OUTPUT_PATH}
+ DEPENDS ${_RULE_DEPENDS}
+ COMMAND "${CMAKE_COMMAND}" ${_COMPONENT_OPTION} ${_PREFIX_OPTION}
+ -P "${IREE_HOST_BINARY_ROOT}/cmake_install.cmake"
+ USES_TERMINAL)
+
+ # Give it a custom target so we can drive the installation manually
+ # when useful.
+ add_custom_target(iree_host_install_${TARGET} DEPENDS ${_OUTPUT_PATH})
+endfunction()
+
+# iree_declare_host_excutable
+#
+# Generates custom commands and targets for building and installing a tool on
+# the host for cross-compilation.
+#
+# Precondition:
+# iree_create_configuration(HOST) is invoked previously.
+#
+# Parameters:
+# TARGET: the target to build on host.
+# BUILDONLY: only generates commands for building the target.
+# DEPENDS: any additional dependencies for the target.
+function(iree_declare_host_excutable TARGET)
+ cmake_parse_arguments(_RULE "BUILDONLY" "" "DEPENDS" ${ARGN})
+
+ iree_get_executable_path(_OUTPUT_PATH ${TARGET})
+
+ iree_get_build_command(${TARGET}
+ BINDIR ${IREE_HOST_BINARY_ROOT}
+ CMDVAR build_cmd)
+
+ add_custom_target(iree_host_build_${TARGET}
+ COMMAND ${build_cmd}
+ DEPENDS iree_configure_HOST ${_RULE_DEPENDS}
+ WORKING_DIRECTORY "${IREE_HOST_BINARY_ROOT}"
+ COMMENT "Building host ${TARGET}..."
+ USES_TERMINAL)
+
+ if(_RULE_BUILDONLY)
+ return()
+ endif()
+
+ iree_host_install(${TARGET}
+ COMPONENT ${TARGET}
+ PREFIX ${IREE_HOST_BINARY_ROOT}
+ DEPENDS iree_host_build_${TARGET})
+
+ # Note that this is not defined when BUILDONLY is set, so callers can define
+ # iree_host_${TARGET} to point to another installation path, which allows
+ # more flexibility.
+ add_custom_target(iree_host_${TARGET} DEPENDS "${_OUTPUT_PATH}")
+endfunction()
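
As a hedged illustration of how another configuration could be driven with these helpers, assuming an `ANDROID` configuration name and placeholder compiler paths (this change itself only creates the `HOST` configuration):

```cmake
# Compilers must be defined for any configuration created this way.
set(IREE_ANDROID_C_COMPILER "/path/to/ndk/clang")      # placeholder
set(IREE_ANDROID_CXX_COMPILER "/path/to/ndk/clang++")  # placeholder

# Configures a nested CMake build under ${CMAKE_CURRENT_BINARY_DIR}/ANDROID
# and defines the iree_configure_ANDROID custom target.
iree_create_configuration(ANDROID)
```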
diff --git a/build_tools/cmake/iree_macros.cmake b/build_tools/cmake/iree_macros.cmake
index 4929146..7b27392 100644
--- a/build_tools/cmake/iree_macros.cmake
+++ b/build_tools/cmake/iree_macros.cmake
@@ -25,6 +25,22 @@
endif()
#-------------------------------------------------------------------------------
+# General utilities
+#-------------------------------------------------------------------------------
+
+# iree_to_bool
+#
+# Sets `variable` to `ON` if `value` is true and `OFF` otherwise.
+function(iree_to_bool VARIABLE VALUE)
+ if(VALUE)
+ set(${VARIABLE} "ON" PARENT_SCOPE)
+ else()
+ set(${VARIABLE} "OFF" PARENT_SCOPE)
+ endif()
+endfunction()
+
+
+#-------------------------------------------------------------------------------
# Packages and Paths
#-------------------------------------------------------------------------------
@@ -72,6 +88,28 @@
set(${PACKAGE_DIR} ${_PACKAGE_DIR} PARENT_SCOPE)
endfunction()
+# iree_get_executable_path
+#
+# Gets the path to an executable in a cross-compilation-aware way. This
+# should be used when accessing binaries that are used as part of the build,
+# such as for generating files used for later build steps.
+#
+# Parameters:
+# - OUTPUT_PATH_VAR: variable name for receiving the path to the built target.
+# - TARGET: the executable target to locate.
+function(iree_get_executable_path OUTPUT_PATH_VAR TARGET)
+ if(CMAKE_CROSSCOMPILING)
+ # The target is defined in the CMake invocation for the host. We don't have
+ # access to the target here, so we rely on the path instead.
+ set(_OUTPUT_PATH "${IREE_HOST_BINARY_ROOT}/bin/${TARGET}")
+ set(${OUTPUT_PATH_VAR} "${_OUTPUT_PATH}" PARENT_SCOPE)
+ else()
+ # The target is defined in this CMake invocation. We can query the location
+ # directly from CMake.
+ set(${OUTPUT_PATH_VAR} "$<TARGET_FILE:${TARGET}>" PARENT_SCOPE)
+ endif()
+endfunction()
+
#-------------------------------------------------------------------------------
# select()-like Evaluation
#-------------------------------------------------------------------------------
@@ -169,3 +207,20 @@
endif()
endforeach()
endfunction()
+
+# iree_add_executable_dependencies
+#
+# Adds a dependency on a target in a cross-compilation-aware way. This should
+# be used for depending on targets that are used as part of the build, such
+# as for generating files used in later build steps.
+#
+# Parameters:
+# TARGET: the target that takes on the dependency
+# DEPENDENCY: the dependency to append to the target
+function(iree_add_executable_dependencies TARGET DEPENDENCY)
+ if(CMAKE_CROSSCOMPILING)
+ add_dependencies(${TARGET} iree_host_${DEPENDENCY})
+ else()
+ add_dependencies(${TARGET} ${DEPENDENCY})
+ endif()
+endfunction()
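
A short sketch combining the two new helpers; `my_gen_tool`, the generated file, and `my_generated_inc` are hypothetical names.

```cmake
# Resolve the generator in a cross-compilation-aware way: an installed host
# binary path when cross-compiling, or $<TARGET_FILE:...> otherwise.
iree_get_executable_path(_GEN_TOOL my_gen_tool)

add_custom_command(
  OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/generated.inc"
  COMMAND ${_GEN_TOOL} -o "${CMAKE_CURRENT_BINARY_DIR}/generated.inc"
  DEPENDS ${_GEN_TOOL}
)
add_custom_target(my_generated_inc
  DEPENDS "${CMAKE_CURRENT_BINARY_DIR}/generated.inc"
)

# Depends on iree_host_my_gen_tool when cross-compiling, or on the my_gen_tool
# target in a normal build.
iree_add_executable_dependencies(my_generated_inc my_gen_tool)
```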
diff --git a/build_tools/embed_data/CMakeLists.txt b/build_tools/embed_data/CMakeLists.txt
index ec07934..4efad40 100644
--- a/build_tools/embed_data/CMakeLists.txt
+++ b/build_tools/embed_data/CMakeLists.txt
@@ -12,13 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-add_executable(generate_cc_embed_data)
-target_sources(generate_cc_embed_data PRIVATE generate_cc_embed_data.cc)
-set_target_properties(generate_cc_embed_data PROPERTIES OUTPUT_NAME generate_cc_embed_data)
+if(CMAKE_CROSSCOMPILING)
+ iree_declare_host_excutable(generate_cc_embed_data)
+else()
+ add_executable(generate_cc_embed_data)
+ target_sources(generate_cc_embed_data PRIVATE generate_cc_embed_data.cc)
+ set_target_properties(generate_cc_embed_data PROPERTIES OUTPUT_NAME generate_cc_embed_data)
-target_link_libraries(generate_cc_embed_data
- absl::flags
- absl::flags_parse
- absl::strings
- absl::time
-)
+ target_link_libraries(generate_cc_embed_data
+ absl::flags
+ absl::flags_parse
+ absl::strings
+ absl::time
+ )
+ install(TARGETS generate_cc_embed_data
+ COMPONENT generate_cc_embed_data
+ RUNTIME DESTINATION bin)
+endif()
diff --git a/docs/repository_management.md b/docs/repository_management.md
index 3cd2618..1d6a342 100644
--- a/docs/repository_management.md
+++ b/docs/repository_management.md
@@ -7,6 +7,13 @@
transparency. If any of these things are particularly troublesome or painful for
your workflow, please reach out to us so we can prioritize a fix.
+NOTE: We are currently in the process of migrating our repository to be
+GitHub-first and hide the merging complexity in a separate `google` feature
+branch so that standard development workflows don't have to bear the cost for
+every contribution. During this part of the migration period, please direct PRs
+to the `google` branch (which will be marked as the default branch). See
+https://groups.google.com/d/msg/iree-discuss/F07vsG9Ah4o/uAIusKO-BQAJ.
+
## Dependencies
As a project which brings together compiler, runtime and graphics systems,
@@ -141,9 +148,6 @@
this prior to running just to make sure that their git view of the submodule
state is consistent.
-TODO(laurenzo): Add a GitHub hook to auto-commit submodule updates on
-`SUBMODULE_VERSIONS` file changes.
-
#### Updating TensorFlow and LLVM versions
WARNING: These scripts have not been updated to reflect the new tooling to
diff --git a/experimental/ModelBuilder/test/BUILD b/experimental/ModelBuilder/test/BUILD
index 39b6475..7f6d34d 100644
--- a/experimental/ModelBuilder/test/BUILD
+++ b/experimental/ModelBuilder/test/BUILD
@@ -165,6 +165,7 @@
"//experimental/ModelBuilder:ModelRunner",
"//experimental/ModelBuilder:VulkanLaunchWrapper",
"//iree/base:initializer",
+ "//iree/compiler/Conversion/CodegenUtils",
"//iree/compiler/Conversion/LinalgToSPIRV",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects",
diff --git a/experimental/ModelBuilder/test/TestVectorToGPU.cpp b/experimental/ModelBuilder/test/TestVectorToGPU.cpp
index 547e825..72902ef 100644
--- a/experimental/ModelBuilder/test/TestVectorToGPU.cpp
+++ b/experimental/ModelBuilder/test/TestVectorToGPU.cpp
@@ -47,7 +47,7 @@
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Passes.h"
-#include "iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h"
+#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
using namespace mlir; // NOLINT
using namespace mlir::edsc; // NOLINT
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/BUILD b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/BUILD
index d23a095..1f980ae 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/BUILD
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/BUILD
@@ -108,6 +108,7 @@
"@llvm-project//mlir:IR",
"@org_tensorflow//tensorflow/cc/saved_model:loader_lite",
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
+ "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tf_saved_model_passes",
"@org_tensorflow//tensorflow/core:core_cpu",
],
)
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/register_tensorflow.cc b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/register_tensorflow.cc
index c9a4da4..1dcc0c2 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/register_tensorflow.cc
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/register_tensorflow.cc
@@ -63,7 +63,7 @@
absl::MakeSpan(mutable_exported_names));
if (!module_or.status().ok()) {
std::stringstream msg;
- msg << "Failed to convert saved model to MLIR'" << saved_model_dir
+ msg << "Failed to convert saved model to MLIR '" << saved_model_dir
<< "': " << module_or.status();
throw RaisePyError(PyExc_RuntimeError, msg.str().c_str());
}
@@ -93,7 +93,7 @@
context_bundle->mlir_context());
if (!module_or.status().ok()) {
std::stringstream msg;
- msg << "Failed to convert saved model to MLIR'" << saved_model_dir
+ msg << "Failed to convert saved model to MLIR '" << saved_model_dir
<< "': " << module_or.status();
throw RaisePyError(PyExc_RuntimeError, msg.str().c_str());
}
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/BUILD b/integrations/tensorflow/bindings/python/pyiree/tf/support/BUILD
index a536513..c211cf0 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/BUILD
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/BUILD
@@ -16,6 +16,7 @@
"//bindings/python:build_defs.oss.bzl",
"INTREE_TENSORFLOW_PY_DEPS",
"iree_py_library",
+ "iree_py_test",
)
package(
@@ -35,3 +36,15 @@
"//bindings/python/pyiree/rt",
],
)
+
+iree_py_test(
+ name = "tf_test_utils_test",
+ srcs = [
+ "tf_test_utils.py",
+ "tf_test_utils_test.py",
+ ],
+ python_version = "PY3",
+ deps = INTREE_TENSORFLOW_PY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+)
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
index 37da253..963ee4b 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -21,6 +21,7 @@
import collections
import os
+import random
import re
import tempfile
@@ -29,7 +30,6 @@
import numpy as np
from pyiree import rt
from pyiree.tf import compiler
-import random
import tensorflow.compat.v2 as tf
flags.DEFINE_string(
@@ -83,11 +83,11 @@
pass_pipeline=())
# Save the input MLIR module.
- flattened_target_backends = re.sub("[^0-9a-zA-Z]+", "_",
+ flattened_target_backends = re.sub("[^0-9a-zA-Z_]+", "_",
"__".join(target_backends))
if global_debug_dir:
mlir_path = os.path.join(global_debug_dir,
- "raw_%s.mlir" % flattened_target_backends)
+ "raw__%s.mlir" % flattened_target_backends)
logging.info("Saving raw TF input MLIR to: %s", mlir_path)
with open(mlir_path, "w") as f:
f.write(compiler_module.to_asm())
@@ -97,7 +97,7 @@
if global_debug_dir:
mlir_path = os.path.join(global_debug_dir,
- "input_%s.mlir" % flattened_target_backends)
+ "input__%s.mlir" % flattened_target_backends)
logging.info("Saving IREE input MLIR to: %s", mlir_path)
with open(mlir_path, "w") as f:
f.write(compiler_module.to_asm())
@@ -105,7 +105,7 @@
compiled_module = compiler_module.compile(target_backends=target_backends)
if global_debug_dir:
compiled_path = os.path.join(
- global_debug_dir, "compiled_%s.vmfb" % flattened_target_backends)
+ global_debug_dir, "compiled__%s.vmfb" % flattened_target_backends)
logging.info("Saving compiled IREE module to: %s", compiled_path)
with open(compiled_path, "wb") as f:
f.write(compiled_module)
@@ -315,6 +315,68 @@
return _make_multi_result_class(results_tuple_class)(*all_results.values())
+def _recursive_check_same(result_ref, result_tgt, rtol=1e-6, atol=1e-6):
+ same = True
+ if not isinstance(result_tgt, type(result_ref)):
+ raise ValueError("Types of the outputs must be the same, but have '{}' and "
+ "'{}'".format(type(result_ref), type(result_tgt)))
+ if isinstance(result_ref, dict):
+ if result_ref.keys() != result_tgt.keys():
+ raise ValueError("Outputs must have the same structure, but have '{}' and"
+ " '{}'".format(result_ref.keys(), result_tgt.keys()))
+ for key in result_ref.keys():
+ same = same and _recursive_check_same(result_ref[key], result_tgt[key],
+ rtol, atol)
+ if not same:
+ return False # no need to go further; they are different
+ elif isinstance(result_ref, list):
+ if len(result_ref) != len(result_tgt):
+ raise ValueError("Outputs must have the same structure, but have '{}' and"
+ " '{}'".format(result_ref, result_tgt))
+ for i in range(len(result_ref)):
+ same = same and _recursive_check_same(result_ref[i], result_tgt[i], rtol,
+ atol)
+ if not same:
+ return False # no need to go further; they are different
+ elif isinstance(result_ref, np.ndarray):
+ if isinstance(result_ref.flat[0], np.floating):
+ return np.allclose(result_ref, result_tgt, rtol=rtol, atol=atol)
+ else:
+ return np.array_equal(result_ref, result_tgt)
+ else:
+ # this one needs more checks
+ return result_ref == result_tgt
+ return same
+
+
+def _collect_disagreements_recursively(mr, rtol=1e-6, atol=1e-6):
+ """Compare result structs recursively and search for disagreements.
+
+ Args:
+ mr: A MultiResults namedtuple where each entry corresponds to one backend's
+ set of results.
+ rtol: The relative tolerance parameter.
+ atol: The absolute tolerance parameter.
+
+ Returns:
+ An equivalent MultiResults where each entry is an array of result names
+ that disagree.
+ """
+ has_disagreement = False
+ disagreement_list = [list() for _ in mr]
+ for i in range(len(mr)):
+ result_ref = mr[i]
+ for j in range(len(mr)):
+ if i < j:
+ continue # Don't check self and reverse comparisons
+ result_tgt = mr[j]
+ if not _recursive_check_same(result_ref, result_tgt, rtol, atol):
+ has_disagreement = True
+ disagreement_list[i].append(mr._fields[j])
+ disagreements_tuple = collections.namedtuple("Disagreements", mr._fields)
+ return has_disagreement, disagreements_tuple(*disagreement_list)
+
+
def _collect_disagreements(mr, predicate):
"""Verifies that result structs.
@@ -363,10 +425,31 @@
(disagreements, self))
return self
+ def assert_all_close_and_equal(self, rtol=1e-6, atol=1e-6):
+ # This is a special case where the output can be a nested map of dict() and
+ # list() with different types: int, float, string. In this case int and
+ # string must be exactly equal, and for float we use rtol/atol.
+ has_disagreement, disagreements = _collect_disagreements_recursively(
+ self, rtol, atol)
+ assert not has_disagreement, ("Multiple backends disagree (%r):\n%r" %
+ (disagreements, self))
+ return self
+
def print(self):
print(self)
return self
+ def save(self):
+ if FLAGS.debug_dir:
+ for i in range(len(self)):
+ result = self[i] # output generated by a model
+ field = self._fields[i] # backend name
+ fname = os.path.join(FLAGS.debug_dir, "output_{}".format(field))
+ with open(fname, "w") as file:
+ # The content of the txt file can be converted to Python objects with eval(txt).
+ file.write(str(result))
+ return self
+
return MultiResults
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
new file mode 100644
index 0000000..20ba522
--- /dev/null
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
@@ -0,0 +1,83 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for pyiree.tf.support.tf_test_utils."""
+
+from absl.testing import parameterized
+import numpy as np
+from pyiree.tf.support import tf_test_utils
+import tensorflow as tf
+
+
+class UtilsTests(tf.test.TestCase, parameterized.TestCase):
+
+ @parameterized.named_parameters([
+ {
+ 'testcase_name': 'all the same',
+ 'array_c': np.array([0, 1, 2]),
+ 'array_d': np.array(['0', '1', '2']),
+ 'array_e': np.array([0.0, 0.1, 0.2]),
+ 'tgt_same': True,
+ },
+ {
+ 'testcase_name': 'wrong int',
+ 'array_c': np.array([1, 1, 2]),
+ 'array_d': np.array(['0', '1', '2']),
+ 'array_e': np.array([0.0, 0.1, 0.2]),
+ 'tgt_same': False,
+ },
+ {
+ 'testcase_name': 'wrong string',
+ 'array_c': np.array([0, 1, 2]),
+ 'array_d': np.array(['a', '1', '2']),
+ 'array_e': np.array([0.0, 0.1, 0.2]),
+ 'tgt_same': False,
+ },
+ {
+ 'testcase_name': 'wrong float',
+ 'array_c': np.array([0, 1, 2]),
+ 'array_d': np.array(['0', '1', '2']),
+ 'array_e': np.array([1.0, 0.1, 0.2]),
+ 'tgt_same': False,
+ },
+ ])
+ def test_recursive_check_same(self, array_c, array_d, array_e, tgt_same):
+
+ ref = {
+ 'a':
+ 1,
+ 'b': [{
+ 'c': np.array([0, 1, 2])
+ }, {
+ 'd': np.array(['0', '1', '2'])
+ }, {
+ 'e': np.array([0.0, 0.1, 0.2])
+ }],
+ }
+ tgt = {
+ 'a': 1,
+ 'b': [{
+ 'c': array_c
+ }, {
+ 'd': array_d
+ }, {
+ 'e': array_e
+ }],
+ }
+ same = tf_test_utils._recursive_check_same(ref, tgt)
+ self.assertEqual(tgt_same, same)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/integrations/tensorflow/e2e/BUILD b/integrations/tensorflow/e2e/BUILD
index f25c6b1..1dbecf7 100644
--- a/integrations/tensorflow/e2e/BUILD
+++ b/integrations/tensorflow/e2e/BUILD
@@ -28,13 +28,29 @@
licenses = ["notice"], # Apache 2.0
)
+# Create binaries for all test srcs to allow them to be run manually.
+[
+ py_binary(
+ name = src.replace(".py", "_manual"),
+ srcs = [src],
+ main = src,
+ python_version = "PY3",
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+ )
+ for src in glob(["*_test.py"])
+]
+
# Special cases to exclude from automatically expanding targets for all
# backends.
+# keep sorted
SPECIAL_CASES = [
"explicit_backend_test.py",
"linspace_test.py",
]
+# keep sorted
VMLA_FAILING = [
"fill_test.py",
"mandelbrot_test.py",
@@ -42,6 +58,7 @@
"strings_test.py",
]
+# keep sorted
LLVM_FAILING = [
"broadcasting_test.py",
"depth_conv_test.py",
@@ -56,6 +73,7 @@
"strings_test.py",
]
+# keep sorted
VULKAN_FAILING = [
"broadcasting_test.py",
"depth_conv_test.py",
@@ -92,7 +110,7 @@
)
iree_e2e_test_suite(
- name = "e2e",
+ name = "e2e_tests",
backends_to_srcs = {
"tf_also": TF_PASSING,
"iree_vmla": VMLA_PASSING,
@@ -106,7 +124,7 @@
)
iree_e2e_test_suite(
- name = "e2e_failing",
+ name = "e2e_tests_failing",
backends_to_srcs = {
"iree_vmla": VMLA_FAILING,
"iree_llvmjit": LLVM_FAILING,
@@ -125,10 +143,11 @@
# Special cases.
-# linspace_test passes internally, but fails in the OSS CI.
+# linspace_test passes internally, but fails in the OSS CI, so it needs
+# a "nokokoro" tag.
iree_e2e_test_suite(
# TODO(#2082): `linspace_test.py` fails in the `bazel-tensorflow` image.
- name = "linspace",
+ name = "linspace_tests",
backends_to_srcs = {
"tf_also": ["linspace_test.py"],
"iree_vmla": ["linspace_test.py"],
@@ -143,7 +162,7 @@
)
iree_e2e_test_suite(
- name = "linspace_failing",
+ name = "linspace_tests_failing",
backends_to_srcs = {
"iree_llvmjit": ["linspace_test.py"],
"iree_vulkan": ["linspace_test.py"],
@@ -159,7 +178,8 @@
],
)
-# This tests explicitly writing which backends to use in Python.
+# This test explicitly specifies which backends to use in Python,
+# so overriding the backends can break it.
iree_py_test(
name = "explicit_backend_test",
srcs = ["explicit_backend_test.py"],
diff --git a/integrations/tensorflow/e2e/README.md b/integrations/tensorflow/e2e/README.md
index 2f3d331..36e5e71 100644
--- a/integrations/tensorflow/e2e/README.md
+++ b/integrations/tensorflow/e2e/README.md
@@ -22,38 +22,67 @@
## Running tests
-NOTE: We are in the process of reworking how backend specification functions, so
-you have to specify the target name including the name of the test suite and
-using a specific backend pair even if you are overriding the backends. The
-override backends take precedence.
-
```shell
# For locally running tests and iterating on backend development,
# `bazel run` is preferred.
-bazel run :e2e_math_test_tf_tf_also -- --override_backends=iree_vulkan
+bazel run :math_test_manual -- --override_backends=iree_vmla
# Same as above, but add `tf` backend to cross-check numerical correctness.
-bazel run :e2e_math_test_tf_tf_also -- --override_backends=tf,iree_vulkan
+bazel run :math_test_manual -- --override_backends=tf,iree_vmla
# Run all tests with defaults and output on failure.
bazel test ... --test_output=errors
# Run an individual test interactively.
-bazel test simple_arithmetic_test --test_output=streamed
-
-# Run tests with an altered list of backends.
-bazel test ... --test_output=errors \
- --test_arg=--override_backends=tf,iree_vmla,iree_vulkan
+bazel run :math_test_manual -- --test_output=streamed
```
If you specify the same backend multiple times, for example
---override_backends=iree_vmla,iree_vmla. The same backends are grouped and in
-this example iree_vmla will run once. If you specify tf,iree_vmla as backends,
-then we will test both backends and compare them with each other. If you specify
-tf backend only, then we will also test tf vs tf to capture any model
-initialization/randomization issues (it is a special case for debug purpose).
-For reproducibility of the unit tests we set random seed of tf and numpy by
-calling tf_test_utils.set_random_seed() before model creation.
+`--override_backends=iree_vmla,iree_vmla`, duplicates are grouped, so in this
+example `iree_vmla` runs only once. If you specify `tf,iree_vmla` as backends,
+both backends are tested and their results are compared with each other. If you
+specify only the `tf` backend, `tf` is also compared against `tf` to catch any
+model initialization/randomization issues (a special case for debugging). For
+reproducibility, the unit tests set the random seeds of `tf` and `numpy` by
+calling `tf_test_utils.set_random_seed()` before model creation.
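+
+For reference, a test module roughly follows the sketch below. This is a
+minimal illustration rather than a drop-in test: the `SavedModelTestCase` base
+class name and the exact module wiring are assumptions here, so see the
+existing `*_test.py` files in this directory for the authoritative pattern.
+
+```python
+import numpy as np
+from pyiree.tf.support import tf_test_utils
+import tensorflow as tf
+
+
+class SimpleModule(tf.Module):
+
+  @tf.function(input_signature=[tf.TensorSpec([4], tf.float32)])
+  def add_one(self, x):
+    return x + 1.0
+
+
+# The decorator compiles the module once per backend under test; a `backends=`
+# keyword can be passed here, but `--override_backends` takes precedence.
+@tf_test_utils.compile_modules(simple=SimpleModule)
+class SimpleModuleTest(tf_test_utils.SavedModelTestCase):
+
+  def test_add_one(self):
+    input_data = np.arange(4, dtype=np.float32)
+    # Runs on every requested backend and cross-checks the results.
+    self.modules.simple.all.add_one(input_data).print().assert_all_close()
+
+
+if __name__ == "__main__":
+  tf.test.main()
+```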
+
+## Test Suites
+
+Test targets are automatically generated for each test file and for each backend
+to check numerical correctness against TensorFlow. Test targets that pass are
+placed into the `e2e_tests` test suite. Tests that fail on particular backends
+are recorded in lists in the `BUILD` files. For example, if
+`experimental_new_test.py` fails on the `iree_llvmjit` and `iree_vulkan`
+backends, then the following lines should be added to the `BUILD` file:
+
+```build
+LLVM_FAILING = [
+ ...
+ "experimental_new_test.py",
+ ...
+]
+
+VULKAN_FAILING = [
+ ...
+ "experimental_new_test.py",
+ ...
+]
+```
+
+Test targets for these backends are placed into the `e2e_tests_failing` test
+suite. Test targets in these test suites can be run as follows:
+
+```shell
+# Run all e2e tests that are expected to pass.
+bazel test :e2e_tests
+
+# Run all e2e tests that are expected to fail.
+bazel test :e2e_tests_failing
+
+# Run a specific failing e2e test target.
+# Note that generated test targets are prefixed with their test suite name.
+bazel test :e2e_tests_failing_broadcasting_test__tf__iree_vulkan
+```
## Debugging tests
@@ -74,15 +103,7 @@
### Limiting a test to only certain backends
The BUILD file specifies which targets work on which backends and controls which
-backends tests are run on by using the `--override_backends` flag. If you add a
-new test that does not work on some backends, list it as failing on those
-backends in the BUILD file.
-
-```build
-VULKAN_FAILING = [
- "my_experimental_new_test.py",
-]
-```
+backends tests are run on by using the `--override_backends` flag.
The `@tf_test_utils.compile_modules` decorator on tests also takes a `backends=`
keyword argument. Many tests still specify this, but it is ignored in the CI,
diff --git a/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl b/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl
index e9efa4c..7b19938 100644
--- a/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl
+++ b/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl
@@ -48,7 +48,7 @@
for backend, srcs in backends_to_srcs.items():
for src in srcs:
- test_name = "{}_{}_{}_{}".format(
+ test_name = "{}_{}__{}__{}".format(
name,
src[:-3],
reference_backend,
diff --git a/integrations/tensorflow/e2e/keras/BUILD b/integrations/tensorflow/e2e/keras/BUILD
index 6bff6e4..84d617b 100644
--- a/integrations/tensorflow/e2e/keras/BUILD
+++ b/integrations/tensorflow/e2e/keras/BUILD
@@ -16,7 +16,6 @@
"//bindings/python:build_defs.oss.bzl",
"INTREE_TENSORFLOW_PY_DEPS",
"NUMPY_DEPS",
- "iree_py_test",
)
load(
"//integrations/tensorflow/e2e/keras:iree_vision_test_suite.bzl",
@@ -32,6 +31,47 @@
licenses = ["notice"], # Apache 2.0
)
+# @unused
+DOC = """
+vision_model_test_manual is for manually testing all of the keras vision models.
+The test only runs manually, with all parameters specified explicitly, for example:
+bazel run -c opt integrations/tensorflow/e2e/keras:vision_model_test_manual -- \
+--override_backends=tf,iree_vmla,iree_llvmjit \
+--data=imagenet \
+--include_top=1 \
+--url=https://storage.googleapis.com/iree_models/ \
+--model=ResNet50
+
+Command arguments description:
+--override_backends: can be any combination of: tf,iree_vmla,iree_llvmjit
+--data: can be 'imagenet' or 'cifar10'.
+ imagenet - input image size (1, 224, 224, 3)
+ cifar10 - input image size (1, 32, 32, 3) - used for quick tests;
+ it needs pretrained weights. We provide pretrained weights for: ResNet50, MobileNet, MobileNetV2
+--include_top: can be 1 or 0. 1 includes the top layer, 0 does not.
+--url: only needed for cifar10 models, to load weights from https://storage.googleapis.com/iree_models/
+ The imagenet pretrained weights URL is specified by keras.
+--model: supports ResNet50, MobileNet, MobileNetV2, ResNet101, ResNet152,
+ ResNet50V2, ResNet101V2, ResNet152V2, VGG16, VGG19, Xception,
+ InceptionV3, InceptionResNetV2, DenseNet121, DenseNet169,
+ DenseNet201, NASNetMobile, NASNetLarge
+ All of the above models work with the 'imagenet' data set.
+ ResNet50, MobileNet, MobileNetV2 work with both 'imagenet' and 'cifar10' data sets.
+"""
+
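+# Create binaries for all test srcs to allow them to be run manually.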
+[
+ py_binary(
+ name = src.replace(".py", "_manual"),
+ srcs = [src],
+ main = src,
+ python_version = "PY3",
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+ )
+ for src in glob(["*_test.py"])
+]
+
SPECIAL_CASES = [
"vision_model_test.py",
]
@@ -71,7 +111,7 @@
)
iree_e2e_test_suite(
- name = "keras",
+ name = "keras_tests",
backends_to_srcs = {
"tf_also": TF_PASSING,
"iree_vmla": VMLA_PASSING,
@@ -85,7 +125,7 @@
)
iree_e2e_test_suite(
- name = "keras_failing",
+ name = "keras_tests_failing",
backends_to_srcs = {
"iree_vmla": VMLA_FAILING,
"iree_llvmjit": LLVM_FAILING,
@@ -102,52 +142,8 @@
],
)
-# @unused
-DOC = """
-vision_models_test is for manual testing of all keras vision models.
-Test will run only manually with all parameters specified manually, for example:
-bazel run -c opt integrations/tensorflow/e2e/keras/vision_models_test -- \
---override_backends=tf,iree_vmla,iree_llvmjit \
---data=imagenet \
---include_top=1 \
---url=https://storage.googleapis.com/iree_models/ \
---model=ResNet50
-
-Command arguments description:
---override_backends: can be combination of these: tf,iree_vmla,iree_llvmjit
---data: can be 'imagenet' or 'cifar10'.
- imagenet - input image size (1, 224, 224, 3)
- cifar10 - input image size (1, 32, 32, 3) - it is used for quick tests
- and needs pretrained weights, we pretrained models: ResNet50, MobileNet, MobileNetV2
---include_top: can be 1 or 0. Include top layer 1, not include top layer 0
---url: we need it only for cifar10 models to load weights from https://storage.googleapis.com/iree_models/
- imagenet pretrained weights url is specified by keras
---model: supports ResNet50, MobileNet, MobileNetV2, ResNet101, ResNet152,
- ResNet50V2, ResNet101V2, ResNet152V2, VGG16, VGG19, Xception,
- InceptionV3, InceptionResNetV2, DenseNet121, DenseNet169,
- DenseNet201, NASNetMobile, NASNetLarge
- All above models works with 'imagenet' data sets.
- ResNet50, MobileNet, MobileNetV2 work with both 'imagenet' and 'cifar10' data sets.
-"""
-
-iree_py_test(
- name = "vision_models_test",
- srcs = ["vision_model_test.py"],
- main = "vision_model_test.py",
- python_version = "PY3",
- tags = [
- "external",
- "manual",
- "nokokoro",
- "notap",
- ],
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
-)
-
iree_vision_test_suite(
- name = "vision_models",
+ name = "vision_internal_tests",
datasets = ["cifar10"],
models_to_backends = {
"ResNet50": [
@@ -165,7 +161,7 @@
)
iree_vision_test_suite(
- name = "vision_models_external",
+ name = "vision_external_tests",
datasets = [
"cifar10",
"imagenet",
@@ -197,9 +193,8 @@
)
iree_vision_test_suite(
- # TODO: Combine this suite with keras_vision_models_external once these
- # tests pass.
- name = "vision_models_external_failing",
+ # TODO: Combine this suite with vision_external_tests once these tests pass.
+ name = "vision_external_tests_failing",
datasets = [
"cifar10",
"imagenet",
diff --git a/integrations/tensorflow/e2e/keras/iree_vision_test_suite.bzl b/integrations/tensorflow/e2e/keras/iree_vision_test_suite.bzl
index b2609fb..bd6aae2 100644
--- a/integrations/tensorflow/e2e/keras/iree_vision_test_suite.bzl
+++ b/integrations/tensorflow/e2e/keras/iree_vision_test_suite.bzl
@@ -67,11 +67,11 @@
for backend in backends:
for dataset in datasets:
test_backends = [reference_backend, backend]
- test_name = "{}_{}_{}_{}".format(
+ test_name = "{}_{}_{}__{}".format(
name,
model,
dataset,
- "_".join(test_backends),
+ "__".join(test_backends),
)
tests.append(test_name)
diff --git a/integrations/tensorflow/e2e/keras/train/BUILD b/integrations/tensorflow/e2e/keras/train/BUILD
index 534ac3e..1160ab2 100644
--- a/integrations/tensorflow/e2e/keras/train/BUILD
+++ b/integrations/tensorflow/e2e/keras/train/BUILD
@@ -27,8 +27,9 @@
licenses = ["notice"], # Apache 2.0
)
+# TODO(meadowlark): Refactor this rule to match iree_vision_test_suite.bzl
iree_train_test_suite(
- name = "train",
+ name = "train_tests",
configurations = [
# tuples of (optimizer, backends)
("sgd", "tf"),
@@ -45,7 +46,7 @@
)
iree_train_test_suite(
- name = "train_failing",
+ name = "train_tests_failing",
configurations = [
# tuples of (optimizer, backends)
# TODO: Combine this suite with keras_model_train once these tests pass.
diff --git a/integrations/tensorflow/e2e/keras/vision_model_test.py b/integrations/tensorflow/e2e/keras/vision_model_test.py
index 0a1e812..9164d9b 100644
--- a/integrations/tensorflow/e2e/keras/vision_model_test.py
+++ b/integrations/tensorflow/e2e/keras/vision_model_test.py
@@ -136,7 +136,7 @@
np.float32)
input_data = input_data.reshape(input_shape)
self.modules.applications.all.predict(input_data).print().assert_all_close(
- atol=3e-5)
+ atol=1e-6)
if __name__ == '__main__':
diff --git a/iree/compiler/Conversion/CodegenUtils/BUILD b/iree/compiler/Conversion/CodegenUtils/BUILD
index 09daa26..ff2ef35 100644
--- a/iree/compiler/Conversion/CodegenUtils/BUILD
+++ b/iree/compiler/Conversion/CodegenUtils/BUILD
@@ -23,13 +23,16 @@
name = "CodegenUtils",
srcs = [
"FunctionUtils.cpp",
+ "MarkerUtils.cpp",
],
hdrs = [
"FunctionUtils.h",
+ "MarkerUtils.h",
],
deps = [
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
+ "@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:Support",
],
)
diff --git a/iree/compiler/Conversion/CodegenUtils/CMakeLists.txt b/iree/compiler/Conversion/CodegenUtils/CMakeLists.txt
index 8f00fc1..716b94b 100644
--- a/iree/compiler/Conversion/CodegenUtils/CMakeLists.txt
+++ b/iree/compiler/Conversion/CodegenUtils/CMakeLists.txt
@@ -19,11 +19,14 @@
CodegenUtils
HDRS
"FunctionUtils.h"
+ "MarkerUtils.h"
SRCS
"FunctionUtils.cpp"
+ "MarkerUtils.cpp"
DEPS
LLVMSupport
MLIRIR
+ MLIRLinalgTransforms
MLIRSupport
PUBLIC
)
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.cpp b/iree/compiler/Conversion/CodegenUtils/MarkerUtils.cpp
similarity index 89%
rename from iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.cpp
rename to iree/compiler/Conversion/CodegenUtils/MarkerUtils.cpp
index 251a686..cf641e1 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.cpp
+++ b/iree/compiler/Conversion/CodegenUtils/MarkerUtils.cpp
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h"
+#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/Attributes.h"
@@ -34,6 +34,8 @@
return attr && (marker == "" || attr.getValue() == marker);
}
+StringRef getNoTileMarker() { return "no-tile"; }
+
StringRef getWorkGroupMarker() { return "workgroup"; }
StringRef getWorkItemMarker() { return "workitem"; }
@@ -44,6 +46,10 @@
return checkMarkerValue(op, marker);
}
+bool hasNoTileMarker(Operation *op) {
+ return checkMarkerValue(op, getNoTileMarker());
+}
+
bool hasWorkGroupMarker(Operation *op) {
return checkMarkerValue(op, getWorkGroupMarker());
}
@@ -63,6 +69,8 @@
StringAttr::get(marker, op->getContext()));
}
+void setNoTileMarker(Operation *op) { setMarker(op, getNoTileMarker()); }
+
void setCooperativeMatrixMarker(Operation *op) {
op->setAttr(VectorTransforms::kVectorTransformMarker,
StringAttr::get(getCooperativeMatrixMarker(), op->getContext()));
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h b/iree/compiler/Conversion/CodegenUtils/MarkerUtils.h
similarity index 70%
rename from iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h
rename to iree/compiler/Conversion/CodegenUtils/MarkerUtils.h
index 633bca0..fa14263 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h
+++ b/iree/compiler/Conversion/CodegenUtils/MarkerUtils.h
@@ -19,8 +19,8 @@
//
//===----------------------------------------------------------------------===//
-#ifndef IREE_COMPILER_CONVERSION_LINALGTOSPIRV_MARKERUTILS_H_
-#define IREE_COMPILER_CONVERSION_LINALGTOSPIRV_MARKERUTILS_H_
+#ifndef IREE_COMPILER_CONVERSION_CODEGENUTILS_MARKERUTILS_H_
+#define IREE_COMPILER_CONVERSION_CODEGENUTILS_MARKERUTILS_H_
#include "llvm/ADT/StringRef.h"
#include "mlir/Support/LLVM.h"
@@ -30,34 +30,55 @@
class Operation;
namespace iree_compiler {
+/// Marker to denote that the linalg operation is not to be tiled.
+StringRef getNoTileMarker();
+
+/// Marker to denote that a linalg operation is to be partitioned to workgroups.
+StringRef getWorkGroupMarker();
+
/// Marker to denote that a linalg operation is to be partitioned to workitems.
StringRef getWorkItemMarker();
+/// Returns true if an operation has the specified `marker`. When `marker` is
+/// empty, returns true if the operation has any marker.
+bool hasMarker(Operation *, StringRef marker = "");
+
+/// Returns true if an operation has a marker to denote that it is not to be
+/// tiled.
+bool hasNoTileMarker(Operation *);
+
+/// Returns true if an operation has a marker to denote that it is to be
+/// partitioned to workgroups.
+bool hasWorkGroupMarker(Operation *);
+
+/// Returns true if an operation has a marker to denote that it is to be
+/// partitioned to workitems.
+bool hasWorkItemMarker(Operation *);
+
/// Returns true if an operation has a marker to denote that it will be mapped
/// to cooperative matrix operations. Markers need to be consistent as
/// cooperative matrices have their own type and load/store operations.
bool hasCooperativeMatrixMarker(Operation *);
-/// Returns true if an operation has the specified `marker`. When `marker` is
-/// empty, returns true if the operation has any marker.
-bool hasMarker(Operation *, StringRef marker = "");
-
-/// Returns true if an operation has marker to denote that it is to be
-/// partitioned to workitems.
-bool hasWorkItemMarker(Operation *);
-
-/// Sets marker to denote that a vector operation is to be execute on a
-/// cooperative matrix.
-void setCooperativeMatrixMarker(Operation *);
-
/// Sets a given marker on an operation.
void setMarker(Operation *, StringRef);
+/// Sets marker to prevent tiling of a linalg operation.
+void setNoTileMarker(Operation *);
+
+/// Sets marker to denote that a linalg operation is to be partitioned to
+/// workgroups.
+void setWorkGroupMarker(Operation *);
+
/// Sets marker to denote that a linalg operation is to be partitioned to
/// workitems.
void setWorkItemMarker(Operation *);
+/// Sets marker to denote that a vector operation is to be executed on a
+/// cooperative matrix.
+void setCooperativeMatrixMarker(Operation *);
+
} // namespace iree_compiler
} // namespace mlir
-#endif // IREE_COMPILER_CONVERSION_LINALGTOSPIRV_MARKERUTILS_H_
+#endif // IREE_COMPILER_CONVERSION_CODEGENUTILS_MARKERUTILS_H_
diff --git a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
index 01c0cba..6a54529 100644
--- a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
@@ -22,6 +22,7 @@
#include <cstddef>
+#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
#include "iree/compiler/Conversion/HLOToLinalg/Passes.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
@@ -611,6 +612,8 @@
cond ? rewriter.create<SelectOp>(loc, cond, inputVal, paddingVal)
: inputVal;
rewriter.create<linalg::YieldOp>(loc, result);
+
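+ // Mark the generated op so that the tiling pass does not tile it.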
+ setNoTileMarker(linalgOp);
return success();
}
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/BUILD b/iree/compiler/Conversion/LinalgToSPIRV/BUILD
index 70df4d6..d2308b4 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/BUILD
+++ b/iree/compiler/Conversion/LinalgToSPIRV/BUILD
@@ -23,18 +23,14 @@
"ConvertToGPUPass.cpp",
"ConvertToSPIRVPass.cpp",
"LinalgTileAndFusePass.cpp",
- "MarkerUtils.cpp",
"Passes.cpp",
"SplitDispatchFunctionPass.cpp",
- "Utils.cpp",
"VectorToGPUPass.cpp",
],
hdrs = [
"Attributes.h",
- "MarkerUtils.h",
"MemorySpace.h",
"Passes.h",
- "Utils.h",
],
deps = [
"//iree/compiler/Conversion/CodegenUtils",
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt b/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
index b8821e2..ccc694c 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
@@ -19,18 +19,14 @@
LinalgToSPIRV
HDRS
"Attributes.h"
- "MarkerUtils.h"
"MemorySpace.h"
"Passes.h"
- "Utils.h"
SRCS
"ConvertToGPUPass.cpp"
"ConvertToSPIRVPass.cpp"
"LinalgTileAndFusePass.cpp"
- "MarkerUtils.cpp"
"Passes.cpp"
"SplitDispatchFunctionPass.cpp"
- "Utils.cpp"
"VectorToGPUPass.cpp"
DEPS
LLVMSupport
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
index 9a4c203..8a85b6a 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
@@ -17,9 +17,8 @@
// Partition computation within dispatch function to workgroups/workitems.
//
//===----------------------------------------------------------------------===//
-#include "iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h"
+#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
-#include "iree/compiler/Conversion/LinalgToSPIRV/Utils.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
@@ -41,7 +40,7 @@
// Loop utilities
//===----------------------------------------------------------------------===//
-/// Builds an empty scf.for operation. The default builder adds an entry basic
+/// Builds an empty scf.for operation. The default builder adds an entry basic
/// block which needs to be avoided here.
static scf::ForOp buildEmptyForOp(Location loc, OpBuilder &builder, Value lb,
Value ub, Value step) {
@@ -51,15 +50,6 @@
return cast<scf::ForOp>(builder.createOperation(state));
}
-/// Builds an empty scf.if operation without the then and else blocks.
-static scf::IfOp buildEmptyIfOp(Location loc, OpBuilder &builder, Value cond) {
- OperationState state(loc, scf::IfOp::getOperationName());
- state.addOperands(cond);
- state.addRegion();
- state.addRegion();
- return cast<scf::IfOp>(builder.createOperation(state));
-}
-
namespace {
struct LoopBounds {
Value lb;
@@ -68,10 +58,10 @@
};
} // namespace
-/// Replaces a scf.parallelOp with an optional scf.parallel op and nested
-/// scf.for operations. To create the scf.parallel op as the outermost loop,
+/// Replaces a scf.parallelOp with an optional scf.parallel op and nested
+/// scf.for operations. To create the scf.parallel op as the outermost loop,
/// pass the lower bound, upper bound and steps in `newPLoopLbs`, `newPLoopUbs`,
-/// and `newPLoopStep` respectively. The bounds of the inner scf.for operations
+/// and `newPLoopStep` respectively. The bounds of the inner scf.for operations
/// to be created are passed in `forLbs`, `forUbs`, and `forStep`. The
/// `permutation` vector contains a mapping from the original loop order, to the
/// loop order to be generated.
@@ -80,21 +70,21 @@
ArrayRef<LoopBounds> newPLoopBounds,
ArrayRef<LoopBounds> forBounds,
ArrayRef<unsigned> permutation) {
- assert(!forBounds.empty() && "unhandled case of no scf.for created");
+ assert(!forBounds.empty() && "unhandled case of no scf.for created");
unsigned numLoops = pLoopOp.getNumLoops();
Location loc = pLoopOp.getLoc();
assert(forBounds.size() + newPLoopBounds.size() == numLoops &&
- "cannot drop loops when splitting scf.parallel operation");
+ "cannot drop loops when splitting scf.parallel operation");
assert(permutation.size() == numLoops);
OpBuilder::InsertionGuard guard(rewriter);
- // Need a signature conversion for the body of the scf.parallel operation,
+ // Need a signature conversion for the body of the scf.parallel operation,
// before can it can be used as the body of the innermost loop created here.
TypeConverter::SignatureConversion signatureConverter(numLoops);
Operation *outermostLoop = nullptr;
auto permuteIt = permutation.begin();
- // Create the scf.parallel operation as the outermost loop, if specified.
+ // Create the scf.parallel operation as the outermost loop, if specified.
if (!newPLoopBounds.empty()) {
auto lbs = llvm::to_vector<2>(llvm::map_range(
newPLoopBounds, [](LoopBounds bounds) -> Value { return bounds.lb; }));
@@ -111,7 +101,7 @@
outermostLoop = newPLoop.getOperation();
}
- // Generate the nested scf.for operations with the bounds passed.
+ // Generate the nested scf.for operations with the bounds passed.
for (auto it : enumerate(forBounds)) {
Value lb = it.value().lb, ub = it.value().ub, step = it.value().step;
if (it.index() != forBounds.size() - 1) {
@@ -120,7 +110,7 @@
signatureConverter.remapInput(*permuteIt, forOp.getInductionVar());
rewriter.setInsertionPointToStart(forOp.getBody());
} else {
- // For the last loop, move the body of the scf.parallel op as the body of
+ // For the last loop, move the body of the scf.parallel op as the body of
// the loop after signature conversion.
auto forOp = buildEmptyForOp(loc, rewriter, lb, ub, step);
if (!outermostLoop) outermostLoop = forOp.getOperation();
@@ -137,8 +127,8 @@
return outermostLoop;
}
-/// Serializes the dimensions of the scf.parallel specified in
-/// `serializedDimensions`, by creating an nested scf.for operation for each
+/// Serializes the dimensions of the scf.parallel specified in
+/// `serializedDimensions`, by creating a nested scf.for operation for each
/// dimension.
// TODO(ravishankarm): Move this into LoopUtils.h in MLIR.
static Operation *serializeDimensions(ConversionPatternRewriter &rewriter,
@@ -151,7 +141,7 @@
serializedDimSet.insert(serializedDimensions.begin(),
serializedDimensions.end());
assert(serializedDimSet.size() == serializedDimensions.size() &&
- "cannot repeat dimensions during serialization of scf.parallel");
+ "cannot repeat dimensions during serialization of loop.parallel");
SmallVector<LoopBounds, 2> newPLoopBounds, forBounds;
SmallVector<unsigned, 2> permutation;
auto lbs = pLoopOp.lowerBound();
@@ -184,85 +174,16 @@
return serializeDimensions(rewriter, pLoopOp, serializedDimensions);
}
-/// Collapses all loops in a scf.parallel into one scf.parallel operation. This
-/// is done by
-/// 1) Normalize the loop bounds to be [0, (ub - lb) / step)
-/// 2) Compute the total number of iterations.
-/// 3) From the induction variable of the modified loop, compute the values of
-/// the original induction variables by de-linearization.
-scf::ParallelOp collapseParallelLoops(ConversionPatternRewriter &rewriter,
- scf::ParallelOp pLoopOp) {
- if (pLoopOp.getNumReductions()) return nullptr;
-
- unsigned numLoops = pLoopOp.getNumLoops();
- if (numLoops == 1) return pLoopOp;
-
- // Compute the number of iterations of each loops starting from the innermost.
- Location loc = pLoopOp.getLoc();
- Value totalNumIterations = rewriter.create<ConstantIndexOp>(loc, 1);
-
- // Track the "stride" of each loop, i.e. product of the total number of
- // iterations of the inner loops.
- SmallVector<Value, 2> iterationStride;
- iterationStride.resize(pLoopOp.getNumLoops());
- auto lbs = pLoopOp.lowerBound();
- auto ubs = pLoopOp.upperBound();
- auto steps = pLoopOp.step();
- for (int i = numLoops - 1; i >= 0; --i) {
- Value lb = lbs[i], ub = ubs[i], step = steps[i];
- Value iterCount = rewriter.create<SignedDivIOp>(
- loc, rewriter.create<SubIOp>(loc, ub, lb), step);
- iterationStride[i] = totalNumIterations;
- totalNumIterations =
- rewriter.create<MulIOp>(loc, totalNumIterations, iterCount);
- }
-
- // Create the collapsed parallel loop op with lowerbound 0, step 1 and upper
- // bound being the totalNumIterations.
- Value newLb = rewriter.create<ConstantIndexOp>(loc, 0);
- Value newStep = rewriter.create<ConstantIndexOp>(loc, 1);
- scf::ParallelOp newPLoopOp =
- rewriter.create<scf::ParallelOp>(loc, newLb, totalNumIterations, newStep);
-
- // Build the body of the collapsed loop by cloning the original loop body. The
- // replacement value of the induction variables of the original loop body,
- // from the induction variable of the new loop, using
- // origLoopIv[i] = loopIv / iterationStride[i]
- // loopIv = loopIv % iterationStride[i]
- OpBuilder::InsertionGuard guard(rewriter);
- Block &pLoopBody = pLoopOp.getLoopBody().front();
- rewriter.setInsertionPointToStart(&newPLoopOp.getLoopBody().front());
- Value loopIv = *newPLoopOp.getInductionVars().begin();
- BlockAndValueMapping map;
- for (int i : llvm::seq<int>(0, numLoops)) {
- Value iterNum =
- rewriter.create<SignedDivIOp>(loc, loopIv, iterationStride[i]);
- Value newIv = rewriter.create<AddIOp>(
- loc, lbs[i], rewriter.create<MulIOp>(loc, iterNum, steps[i]));
- map.map(pLoopBody.getArgument(i), newIv);
- loopIv = rewriter.create<SignedRemIOp>(loc, loopIv, iterationStride[i]);
- }
- for (Operation &op : pLoopBody.without_terminator()) {
- rewriter.clone(op, map);
- }
- rewriter.eraseOp(pLoopOp);
- return newPLoopOp;
-}
-
//===----------------------------------------------------------------------===//
// GPU processor ID mapping utilities
//===----------------------------------------------------------------------===//
-/// Distributes scf.parallel to processors with the processors logically
+/// Distributes scf.parallel to processors with the processors logically
/// arranged with same dimensionality as the number of loops, i.e. a
-/// scf.parallel with 2 loops to a 2D grid of processors. `processorIDs` and
+/// scf.parallel with 2 loops to a 2D grid of processors. `processorIDs` and
/// `numProcessors` must be of same size as the number of loops and are the
/// values to use for process ID and number of processors along each dimension
/// in the distributed code.
-/// This method accounts for the case where the number of processors is not
-/// enough to execute the entire iteration space with one iteration mapped to
-/// each processor. So implements a block-cyclic distribution with each block
-/// size being equal to the number of processors.
static LogicalResult mapToProcessors(ConversionPatternRewriter &rewriter,
scf::ParallelOp pLoopOp,
ArrayRef<Value> processorIDs,
@@ -291,39 +212,6 @@
return success();
}
-/// Distributes scf.parallel to processors with the processors logically
-/// arranged with same dimensionality as the number of loops, i.e. a
-/// scf.parallel with 2 loops to a 2D grid of processors. `processorIDs` must be
-/// of same size as the number of loops and are the values to use for process ID
-/// and number of processors along each dimension in the distributed code. This
-/// method assumes that the number of processors is greater than or equal to the
-/// number of iterations. So just generates an if statement to mask of
-/// processors with no work.
-static LogicalResult mapToProcessorsAndGuard(
- ConversionPatternRewriter &rewriter, scf::ParallelOp pLoopOp,
- ArrayRef<Value> processorIDs) {
- unsigned numLoops = pLoopOp.getNumLoops();
- Location loc = pLoopOp.getLoc();
- assert(numLoops == processorIDs.size() &&
- "expected as many ids as number of loops");
- Value cond = nullptr;
- TypeConverter::SignatureConversion signatureConverter(numLoops);
- auto ubs = pLoopOp.upperBound();
- for (unsigned i : llvm::seq<unsigned>(0, numLoops)) {
- Value cmp = rewriter.create<CmpIOp>(loc, CmpIPredicate::slt,
- processorIDs[i], ubs[i]);
- cond = (cond ? rewriter.create<AndOp>(loc, cond, cmp) : cmp);
- signatureConverter.remapInput(i, processorIDs[i]);
- }
- scf::IfOp ifOp = buildEmptyIfOp(loc, rewriter, cond);
- Region &pLoopOpRegion = pLoopOp.getLoopBody();
- rewriter.applySignatureConversion(&pLoopOpRegion, signatureConverter);
- Region &ifOpRegion = ifOp.getRegion(0);
- rewriter.inlineRegionBefore(pLoopOpRegion, ifOpRegion, ifOpRegion.begin());
- rewriter.eraseOp(pLoopOp);
- return success();
-}
-
namespace {
struct ProcessorIdAndCount {
Value id;
@@ -363,24 +251,7 @@
rewriter.create<MulIOp>(loc, blockDim, gridDim)};
}
-template <typename GPUIdOp, typename GPUCountOp>
-static void getGPUProcessorIdsAndCounts(Location loc,
- ConversionPatternRewriter &rewriter,
- unsigned numDims,
- MutableArrayRef<Value> id,
- MutableArrayRef<Value> count) {
- ArrayRef<StringRef> dims = {"x", "y", "z"};
- assert(id.size() == numDims);
- assert(count.size() == numDims);
- for (unsigned i = 0; i < numDims; ++i) {
- ProcessorIdAndCount idAndCount =
- getGPUProcessorIdAndCount<GPUIdOp, GPUCountOp>(loc, dims[i], rewriter);
- id[numDims - 1 - i] = idAndCount.id;
- count[numDims - 1 - i] = idAndCount.count;
- }
-}
-
-/// Distributes scf.parallel to processors where `IdOp` is used to get the
+/// Distributes scf.parallel to processors where `IdOp` is used to get the
/// processor ID and `DimOp` is used to get the number of processors along a
/// dimension.
template <typename GPUIdOp, typename GPUCountOp>
@@ -392,51 +263,38 @@
cast<scf::ParallelOp>(serializeDimensionsFrom(rewriter, pLoopOp, 3));
numLoops = 3;
}
- SmallVector<Value, 2> id(numLoops), count(numLoops);
- getGPUProcessorIdsAndCounts<GPUIdOp, GPUCountOp>(pLoopOp.getLoc(), rewriter,
- numLoops, id, count);
+ SmallVector<Value, 2> id, count;
+ id.reserve(numLoops);
+ count.reserve(numLoops);
+ ArrayRef<StringRef> dims = {"x", "y", "z"};
+ Location loc = pLoopOp.getLoc();
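+ // Insert at the front so that the innermost parallel loop maps to the "x"
+ // dimension and outer loops map to "y" and "z".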
+ for (unsigned i = 0; i < numLoops; ++i) {
+ ProcessorIdAndCount idAndCount =
+ getGPUProcessorIdAndCount<GPUIdOp, GPUCountOp>(loc, dims[i], rewriter);
+ id.insert(id.begin(), idAndCount.id);
+ count.insert(count.begin(), idAndCount.count);
+ }
return mapToProcessors(rewriter, pLoopOp, id, count);
}
-/// Distributes scf.parallel to processors where `IdOp` is used to get the
-/// processor ID and `DimOp` is used to get the number of processors along a
-/// dimension. Assumes that the number of processors will be less than equal to
-/// the number of iterations of the pLoopOp along all dimensions.
-template <typename GPUIdOp, typename GPUCountOp>
-static LogicalResult mapToProcessorsAndGuard(
- ConversionPatternRewriter &rewriter, scf::ParallelOp pLoopOp) {
- unsigned numLoops = pLoopOp.getNumLoops();
- if (numLoops > 3) {
- pLoopOp =
- cast<scf::ParallelOp>(serializeDimensionsFrom(rewriter, pLoopOp, 3));
- numLoops = 3;
- }
- SmallVector<Value, 2> id(numLoops), count(numLoops);
- getGPUProcessorIdsAndCounts<GPUIdOp, GPUCountOp>(pLoopOp.getLoc(), rewriter,
- numLoops, id, count);
- return mapToProcessorsAndGuard(rewriter, pLoopOp, id);
-}
-
-/// Distribute the scf.parallel to workgroups.
+/// Distributes the scf.parallel to workgroups.
static LogicalResult mapToWorkgroups(ConversionPatternRewriter &rewriter,
scf::ParallelOp pLoopOp) {
return mapToProcessor<gpu::BlockIdOp, gpu::GridDimOp>(rewriter, pLoopOp);
}
-/// Distributes scf.parallel to workitems using local invocation ID.
+/// Distributes scf.parallel to workitems using local invocation ID.
static LogicalResult mapToLocalInvocationId(ConversionPatternRewriter &rewriter,
scf::ParallelOp pLoopOp) {
- return mapToProcessorsAndGuard<gpu::ThreadIdOp, gpu::BlockDimOp>(rewriter,
- pLoopOp);
+ return mapToProcessor<gpu::ThreadIdOp, gpu::BlockDimOp>(rewriter, pLoopOp);
}
-/// Distributes scf.parallel to workitems using global invocation ID. The GPU
+/// Distributes scf.parallel to workitems using global invocation ID. The GPU
/// dialect doesn't have a direct operation to do this. This could be done using
/// id = blockIdx * blockDim + gridIdx. count = blockDim * gridDim.
static LogicalResult mapToGlobalInvocationId(
ConversionPatternRewriter &rewriter, scf::ParallelOp pLoopOp) {
- return mapToProcessorsAndGuard<GPUGlobalId, GPUGlobalCount>(rewriter,
- pLoopOp);
+ return mapToProcessor<GPUGlobalId, GPUGlobalCount>(rewriter, pLoopOp);
}
//===----------------------------------------------------------------------===//
@@ -449,7 +307,7 @@
void runOnFunction() override;
};
-/// Pattern to map scf.parallel to workgroups.
+/// Pattern to map scf.parallel to workgroups.
struct PartitionPLoopToWorkgroups
: public OpConversionPattern<scf::ParallelOp> {
using OpConversionPattern<scf::ParallelOp>::OpConversionPattern;
@@ -460,7 +318,7 @@
}
};
-/// Map tiled linalg op to workitems by lowering it to scf.parallel and
+/// Map tiled linalg op to workitems by lowering it to scf.parallel and
/// partitioning it to workitems.
template <typename LinalgOpTy>
struct MapLinalgOpToLocalInvocationId : public OpConversionPattern<LinalgOpTy> {
@@ -493,29 +351,19 @@
LogicalResult matchAndRewrite(
LinalgOpTy linalgOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- // If marker exists do nothing.
- if (hasMarker(linalgOp)) return failure();
+ // If a marker exists and it is not the no-tile marker, do nothing.
+ if (hasMarker(linalgOp) && !hasNoTileMarker(linalgOp)) return failure();
Optional<linalg::LinalgLoops> loops =
linalg::linalgLowerOpToLoops<scf::ParallelOp>(rewriter, linalgOp);
if (!loops) return failure();
-
- SmallVector<int64_t, 3> workgroupSize(3, 1);
if (!loops.getValue().empty()) {
scf::ParallelOp pLoopOp = dyn_cast<scf::ParallelOp>(loops.getValue()[0]);
// If there are parallel loops partition them to threads using global
// invocation ID.
- if (pLoopOp) {
- pLoopOp = collapseParallelLoops(rewriter, pLoopOp);
- if (!pLoopOp) return failure();
- if (failed(mapToGlobalInvocationId(rewriter, pLoopOp)))
- return rewriter.notifyMatchFailure(
- linalgOp, "mapping to GlobalInvocationID failed");
- workgroupSize = {32, 1, 1};
- }
+ if (pLoopOp && failed(mapToGlobalInvocationId(rewriter, pLoopOp)))
+ return failure();
}
rewriter.eraseOp(linalgOp);
- FuncOp funcOp = linalgOp.template getParentOfType<FuncOp>();
- if (funcOp) updateWorkGroupSize(funcOp, workgroupSize);
return success();
}
};
@@ -544,7 +392,7 @@
MLIRContext *context = &getContext();
ConversionTarget target(*context);
- // After this pass Linalg and scf.parallel ops should be gone.
+ // After this pass Linalg and scf.parallel ops should be gone.
target.addIllegalOp<scf::ParallelOp>();
target.addIllegalDialect<linalg::LinalgDialect>();
// Reshape ops are treated legal since they just change the way the underlying
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
index 399b3f5..aeb0996 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
@@ -21,7 +21,7 @@
//
//===----------------------------------------------------------------------===//
-#include "iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h"
+#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
#include "llvm/ADT/STLExtras.h"
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
index 7125f8b..48e1d57 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
@@ -17,10 +17,9 @@
// Implements a pass to tile and fuse linalg operations on buffers.
//
//===----------------------------------------------------------------------===//
-#include "iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h"
+#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/MemorySpace.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
-#include "iree/compiler/Conversion/LinalgToSPIRV/Utils.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
@@ -72,6 +71,47 @@
.size();
}
+/// Updates the workgroup size used for the dispatch region.
+static LogicalResult updateWorkGroupSize(FuncOp funcOp,
+ ArrayRef<int64_t> workGroupSize) {
+ // Need to update both the surrounding FuncOp that has the spv.entry_point_abi
+ // attribute, and the hal.executable.
+ Region &body = funcOp.getBody();
+ if (!llvm::hasSingleElement(body))
+ return funcOp.emitError("unhandled dispatch function with multiple blocks");
+
+ SmallVector<int32_t, 3> workGroupSizeVec = llvm::to_vector<3>(llvm::map_range(
+ workGroupSize, [](int64_t v) { return static_cast<int32_t>(v); }));
+
+ // TODO(ravishankarm, antiagainst): We should have at most one scf.parallel
+ // op, but that is not the case till the splitting of kernels lands.
+ unsigned numParallelLoops = 0;
+ auto updateNumParallelLoops = [&numParallelLoops](unsigned nPar) {
+ numParallelLoops =
+ (!numParallelLoops ? nPar : std::min(numParallelLoops, nPar));
+ };
+ for (auto parallelLoop : body.front().getOps<scf::ParallelOp>()) {
+ updateNumParallelLoops(parallelLoop.getNumLoops());
+ }
+ // If there are no parallel loops, there might be linalg ops that aren't
+ // tiled. Use that to get the number of parallel loops.
+ for (auto linalgOp : body.front().getOps<linalg::LinalgOp>()) {
+ updateNumParallelLoops(getNumOuterParallelLoops(linalgOp));
+ }
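+ // Truncate to the number of parallel loops; the unused dimensions are padded
+ // back to 1 below.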
+ workGroupSizeVec.resize(numParallelLoops);
+ LLVM_DEBUG({
+ llvm::dbgs() << "--- IREE Linalg tile and fuse configuration ---\n";
+ llvm::dbgs() << "# workgroup sizes at end: [";
+ interleaveComma(workGroupSizeVec, llvm::dbgs());
+ llvm::dbgs() << "]\n";
+ });
+ MLIRContext *context = funcOp.getContext();
+ workGroupSizeVec.resize(3, 1);
+ funcOp.setAttr(spirv::getEntryPointABIAttrName(),
+ spirv::getEntryPointABIAttr(workGroupSizeVec, context));
+ return success();
+}
+
namespace {
/// Computes tile sizes (and workgroup size) to use based on operations within
@@ -79,13 +119,7 @@
class TileSizeCalculator {
public:
TileSizeCalculator(FuncOp funcOp)
- : resourceLimits(spirv::lookupTargetEnv(funcOp).getResourceLimits()) {
- if (DenseIntElementsAttr attr = spirv::lookupLocalWorkGroupSize(funcOp)) {
- for (auto val : attr.getValues<APInt>())
- workgroupSize.push_back(val.getSExtValue());
- }
- workgroupSize.resize(3, 1);
- }
+ : resourceLimits(spirv::lookupTargetEnv(funcOp).getResourceLimits()) {}
/// Compute the tile sizes based on workgroup size specified.
LogicalResult setTileSizesBasedOnWorkgroupSize(
@@ -105,10 +139,21 @@
/// Get the current tile size computed.
ArrayRef<int64_t> getTileSizes() const { return tileSizes; }
+ /// The Linalg convention is to use 0 for no tiling. If any of the tile
+ /// dimensions is set to 1, make it 0.
+ SmallVector<int64_t, 3> getTileSizesForLinalg() const {
+ return llvm::to_vector<3>(llvm::map_range(
+ tileSizes, [](int64_t v) -> int64_t { return v == 1 ? 0 : v; }));
+ }
+
/// Returns the workgroup size to use based on the tile sizes.
ArrayRef<int64_t> getWorkGroupSize() const { return workgroupSize; }
private:
+ /// Get the default tile sizes based only on the number of dimensions, i.e., "x",
+ /// "y", and "z".
+ void setTileSizesBasedOnDimensions(unsigned numDims);
+
/// Current tile size configuration.
SmallVector<int64_t, 4> tileSizes;
@@ -120,72 +165,67 @@
};
} // namespace
+void TileSizeCalculator::setTileSizesBasedOnDimensions(unsigned numDims) {
+ tileSizes.clear();
+ workgroupSize.clear();
+ tileSizes.reserve(3);
+ if (numDims == 0) {
+ // Scalar case.
+ workgroupSize = {1, 1, 1};
+ return;
+ }
+ unsigned maxWorkGroupSize =
+ resourceLimits.max_compute_workgroup_invocations().getInt();
+
+ // Make the tile size 32 along the x-dimension, and then split the remaining
+ // maxWorkGroupSize threads amongst the y-dimension or z-dimension.
+ unsigned tileSizeX = llvm::PowerOf2Floor(std::min(maxWorkGroupSize, 32u));
+ maxWorkGroupSize /= tileSizeX;
+ if (numDims == 1) {
+ tileSizes = {tileSizeX};
+ workgroupSize = {tileSizeX, 1, 1};
+ return;
+ }
+ if (numDims == 2) {
+ unsigned tileSizeY = llvm::PowerOf2Floor(maxWorkGroupSize);
+ tileSizes = {tileSizeY, tileSizeX};
+ workgroupSize = {tileSizeX, tileSizeY, 1};
+ return;
+ }
+ unsigned tileSizeYZ =
+ llvm::PowerOf2Floor(static_cast<unsigned>(std::sqrt(maxWorkGroupSize)));
+ tileSizes = {tileSizeYZ, tileSizeYZ, tileSizeX};
+ workgroupSize = {tileSizeX, tileSizeYZ, tileSizeYZ};
+}
+
LogicalResult TileSizeCalculator::setTileSizesBasedOnOps(
ArrayRef<linalg::LinalgOp> linalgOps) {
tileSizes.clear();
- if (linalgOps.empty()) {
- tileSizes = {1, 1, 1};
- workgroupSize = {1, 1, 1};
- return success();
- }
// The tile size will be driven by operations like matmul, conv, etc. within
// the list. So see what operation exists in the list to decide the tile size.
// If there are two such operations in the list, return error.
- enum OpInfo : uint32_t {
- None = 0x0,
- Convolution = 0x1,
- Matmul = 0x2,
- Pooling = 0x4,
- };
- uint32_t opInfo = OpInfo::None;
- for (linalg::LinalgOp linalgOp : linalgOps) {
- Operation *op = linalgOp.getOperation();
- if (isa<linalg::ConvOp>(op)) opInfo |= OpInfo::Convolution;
- if (isa<linalg::MatmulOp>(op)) opInfo |= OpInfo::Matmul;
- if (isa<linalg::PoolingMaxOp>(op)) opInfo |= OpInfo::Pooling;
- if (isa<linalg::PoolingMinOp>(op)) opInfo |= OpInfo::Pooling;
- if (isa<linalg::PoolingSumOp>(op)) opInfo |= OpInfo::Pooling;
+ bool hasMatmul = false;
+ unsigned numParallelLoops = kMaxWorkgroupRank;
+ for (linalg::LinalgOp op : linalgOps) {
+ // If there is no marker on this op (i.e. no marker to prevent tiling), an
+ // explicit marker is added to indicate that the op is to be tiled. This makes
+ // subsequent lowering simpler.
+ if (isa<linalg::MatmulOp>(op.getOperation())) {
+ if (hasMatmul)
+ return op.emitError(
+ "unhandled multiple matmuls within dispatch region");
+ hasMatmul = true;
+ }
+ numParallelLoops = std::min(numParallelLoops, getNumOuterParallelLoops(op));
}
- // If there are no tilable ops, there is nothing to do here.
- if (!opInfo) return success();
-
- Operation *linalgOp = *(linalgOps.begin());
- if (llvm::countPopulation(opInfo) != 1)
- return linalgOp->getParentOfType<FuncOp>().emitError(
- "unhandled fusion of ops in dispatch function");
-
- // TODO(ravishanarm, antiagainst): Only the maximum workgroup size is used
- // here for computing tile sizes. In reality we also need the maximum
- // workgroup memory size available (per workgroup) to compute the tile sizes
- // effectively.
- unsigned maxWorkgroupSize =
- resourceLimits.max_compute_workgroup_invocations().getInt();
- if (opInfo & OpInfo::Convolution) {
- // TODO(ravishankarm): This tiling is meant to enable promotion to workgroup
- // memory, but doesnt actually get us to a state where we can do this. The
- // promotion is possible only when the subviews created are constant
- // size. For now this doesnt really matter. Revisit this later.
- int64_t tileSizeX = 32;
- int64_t tileSizeY = maxWorkgroupSize / 32;
- tileSizes = {1, tileSizeY, tileSizeX};
- workgroupSize = {tileSizeX, tileSizeY, 1};
- return success();
- }
- if (opInfo & OpInfo::Matmul) {
+ if (hasMatmul) {
// TODO: For now just hard wire this, but we can do better.
tileSizes = {8, 8, 4};
workgroupSize = {8, 8, 1};
return success();
}
- if (opInfo & OpInfo::Pooling) {
- int64_t tileSizeX = 32;
- int64_t tileSizeY = maxWorkgroupSize / 32;
- tileSizes = {tileSizeY, tileSizeX};
- workgroupSize = {tileSizeX, tileSizeY, 1};
- return success();
- }
- return linalgOp->getParentOfType<FuncOp>().emitError(
- "unable to find tile size for ops in this dispatch function");
+ setTileSizesBasedOnDimensions(numParallelLoops);
+ return success();
}
//===----------------------------------------------------------------------===//
@@ -254,41 +294,22 @@
SmallVector<int64_t, 3> workGroupSize;
};
-/// Pattern for tiling operations. Updates the workgroup size in the surrounding
-/// function operation if tiling succeeds.
-template <typename OpTy>
-struct TilingPattern : public linalg::LinalgTilingPattern<OpTy> {
- using Base = linalg::LinalgTilingPattern<OpTy>;
- TilingPattern(MLIRContext *context, linalg::LinalgTilingOptions options,
- ArrayRef<int64_t> workgroupSize,
- linalg::LinalgMarker marker = linalg::LinalgMarker(),
- PatternBenefit benefit = 1)
- : Base(context, options, marker, benefit),
- workgroupSize(workgroupSize.begin(), workgroupSize.end()) {}
-
- virtual LogicalResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const {
- // Find the parent FuncOp before tiling. If tiling succeeds, the op will be
- // erased.
- FuncOp funcOp = op->getParentOfType<FuncOp>();
- return failure(!funcOp || failed(Base::matchAndRewrite(op, rewriter)) ||
- failed(updateWorkGroupSize(funcOp, workgroupSize)));
- }
-
- SmallVector<int64_t, 3> workgroupSize;
-};
-
-/// Pattern for tiling convolution and pooling operations. Currently is just a
-/// way to not tile when the operation has padding.
-template <typename OpTy>
-struct TileConvPoolPattern : public TilingPattern<OpTy> {
- using Base = TilingPattern<OpTy>;
- using Base::TilingPattern;
+/// Pattern to tile linalg operations if they have the workgroup marker.
+template <typename LinalgOp>
+struct TileLinalgOpPattern : public linalg::LinalgTilingPattern<LinalgOp> {
+ using linalg::LinalgTilingPattern<LinalgOp>::LinalgTilingPattern;
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
- if (cast<OpTy>(op).padding()) return failure();
- return Base::matchAndRewrite(op, rewriter);
+ if (!hasWorkGroupMarker(op)) return failure();
+ if (succeeded(linalg::LinalgTilingPattern<LinalgOp>::matchAndRewrite(
+ op, rewriter)))
+ return success();
+ // Tiling did not apply; mark the op as no-tile so it is instead mapped
+ // directly to the global invocation ID.
+ rewriter.startRootUpdate(op);
+ setNoTileMarker(op);
+ rewriter.finalizeRootUpdate(op);
+ return success();
}
};
@@ -327,7 +348,14 @@
auto linalgOps = block.getOps<linalg::LinalgOp>();
if (linalgOps.empty()) return;
+ // Go through all the Linalg ops and set the marker to trigger tiling.
+ // TODO(ravishankarm): Move this to HLOToLinalgOnBuffers so that it is added
+ // on op-creation.
+ for (auto op : linalgOps)
+ if (!hasMarker(op)) setWorkGroupMarker(op);
+
TileSizeCalculator tileSizeCalculator(funcOp);
+
if (workGroupSize.empty()) {
// Get the tile sizes to use for the lowering.
SmallVector<int64_t, 3> tileSizes;
@@ -348,17 +376,20 @@
});
OwningRewritePatternList tilingPatterns;
- tilingPatterns.insert<TileConvPoolPattern<linalg::ConvOp>,
- TilingPattern<linalg::MatmulOp>,
- TileConvPoolPattern<linalg::PoolingMaxOp>,
- TileConvPoolPattern<linalg::PoolingMinOp>,
- TileConvPoolPattern<linalg::PoolingSumOp>>(
+ tilingPatterns.insert<TileLinalgOpPattern<linalg::ConvOp>,
+ TileLinalgOpPattern<linalg::CopyOp>,
+ TileLinalgOpPattern<linalg::FillOp>,
+ TileLinalgOpPattern<linalg::GenericOp>,
+ TileLinalgOpPattern<linalg::IndexedGenericOp>,
+ TileLinalgOpPattern<linalg::MatmulOp>,
+ TileLinalgOpPattern<linalg::PoolingMaxOp>,
+ TileLinalgOpPattern<linalg::PoolingMinOp>,
+ TileLinalgOpPattern<linalg::PoolingSumOp>>(
context,
linalg::LinalgTilingOptions()
- .setTileSizes(tileSizeCalculator.getTileSizes())
+ .setTileSizes(tileSizeCalculator.getTileSizesForLinalg())
.setLoopType(linalg::LinalgTilingLoopType::ParallelLoops),
- tileSizeCalculator.getWorkGroupSize(),
- linalg::LinalgMarker(ArrayRef<Identifier>(),
+ linalg::LinalgMarker(Identifier::get(getWorkGroupMarker(), context),
Identifier::get(getWorkItemMarker(), context)));
applyPatternsAndFoldGreedily(getOperation(), tilingPatterns);
@@ -392,6 +423,15 @@
insertBarrierAfter(builder, linalgOp.getLoc(), linalgOp);
}
});
+
+ // Update the workgroup size to be consistent with the tile sizes used. Note
+ // that the tile sizes are ordered from the outermost to the innermost loop.
+ // The heuristic is to map the innermost loop to x, the next outer loop (if it
+ // exists) to y, and the next one (if it exists) to z, so the tile sizes are
+ // reversed to get the workgroup size.
+ if (failed(
+ updateWorkGroupSize(funcOp, tileSizeCalculator.getWorkGroupSize())))
+ return signalPassFailure();
}
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
index c996161..4ca5f2f 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
@@ -15,11 +15,11 @@
//===- SplitDispathFunctionPass.cpp ---------------------------------------===//
//
// This file implements a pass to split computation workload to multiple
-// sequential dispatch functions. This pass operates on Linalg ops and
-// scf.parallel op and prepares for lowering to GPU, where we need to tile the
-// workload to workgroups and workitems. If the workload involves computation A
-// and B, where B is dependent on A and A needs all workgroups to complete, then
-// we need to split A and B into different kernels because there is no mechanism
+// sequential dispatch functions. This pass operates on Linalg ops and prepares
+// for lowering to GPU, where we need to tile the workload to workgroups and
+// workitems. If the workload involves computation A and B, where B is
+// dependent on A and A needs all workgroups to complete, then we need
+// to split A and B into different kernels because there is no mechanism
// to perform cross-workgroup synchronization within a single kernel.
//
//===----------------------------------------------------------------------===//
@@ -35,7 +35,6 @@
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
-#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
@@ -52,20 +51,24 @@
namespace {
+/// Returns true if the given `block` contains 0 or 1 Linalg structured ops.
+bool hasZeroOrOneLinalgOp(Block &block) {
+ auto ops = block.getOps<linalg::LinalgOp>();
+ return std::distance(ops.begin(), ops.end()) <= 1;
+}
+
/// Returns true if the Linalg ops can be separated to multiple kernels.
-bool canSeparateOps(ArrayRef<Operation *> ops) {
- if (llvm::any_of(ops, [](Operation *op) {
- if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op))
- return !linalgOp.hasBufferSemantics();
- return false;
+bool canSeparateLinalgOps(MutableArrayRef<linalg::LinalgOp> linalgOps) {
+ if (llvm::any_of(linalgOps, [](linalg::LinalgOp op) {
+ return !op.hasBufferSemantics();
}))
return false;
// Require no other ops interleave with Linalg structured ops for now. This is
// the common case and it simplifies further analysis.
- for (auto currOp = ops.begin(), nextOp = std::next(ops.begin());
- nextOp != ops.end(); ++currOp, ++nextOp) {
- if ((*currOp)->getNextNode() != *nextOp) return false;
+ for (int i = 0, e = linalgOps.size() - 1; i < e; ++i) {
+ if (linalgOps[i].getOperation()->getNextNode() != linalgOps[i + 1])
+ return false;
}
return true;
@@ -141,20 +144,15 @@
return oldFn.emitError("expected only one block");
}
- // The dispatch function should have more than one separable ops. Otherwise
- // there is nothing to do.
- Block &fnBody = oldFn.getBlocks().front();
+ // The dispatch function should have more than one Linalg structured op.
+ // Otherwise there is nothing to do.
+ if (hasZeroOrOneLinalgOp(oldFn.getBlocks().front())) return success();
- // Collect all Linalg and scf.parallel ops for distributing.
- SmallVector<Operation *, 4> separableOps;
- for (Operation &op : fnBody)
- if (isa<linalg::LinalgOp>(op) || isa<scf::ParallelOp>(op))
- separableOps.push_back(&op);
-
- if (separableOps.size() <= 1) return success();
- if (!canSeparateOps(separableOps)) {
- return oldFn.emitError(
- "cannot separate Linalg/Parallel ops into multiple kernels");
+ // Collect all Linalg ops for distributing.
+ SmallVector<linalg::LinalgOp, 4> linalgOps =
+ llvm::to_vector<4>(oldFn.getBlocks().front().getOps<linalg::LinalgOp>());
+ if (!canSeparateLinalgOps(linalgOps)) {
+ return oldFn.emitError("cannot separate Linalg ops into multiple kernels");
}
ModuleOp moduleOp = cast<ModuleOp>(oldFn.getParentOp());
@@ -162,13 +160,13 @@
Location loc = oldFn.getLoc();
SmallVector<std::string, 4> splitKernels;
- splitKernels.reserve(separableOps.size());
+ splitKernels.reserve(linalgOps.size());
llvm::SmallPtrSet<Operation *, 16> closure;
- for (const auto &separableOp : llvm::enumerate(separableOps)) {
- // Create a new function for hosting this op.
- splitKernels.emplace_back(llvm::formatv("{0}_dispatch_{1}", oldFn.getName(),
- separableOp.index()));
+ for (const auto &linalgOp : llvm::enumerate(linalgOps)) {
+ // Create a new function for hosting this Linalg op.
+ splitKernels.emplace_back(
+ llvm::formatv("{0}_dispatch_{1}", oldFn.getName(), linalgOp.index()));
StringRef newFnName = splitKernels.back();
builder.setInsertionPointToStart(moduleOp.getBody());
auto newFn = builder.create<FuncOp>(loc, newFnName, oldFn.getType(),
@@ -183,7 +181,7 @@
// Collect the closure for the current Linalg op.
closure.clear();
- collectAllReferencedOps(separableOp.value(), closure);
+ collectAllReferencedOps(linalgOp.value(), closure);
// Clone all ops in the closure to the new function.
Block *newFnBlock = newFn.addEntryBlock();
@@ -192,14 +190,14 @@
for (Operation &op : oldFnBlock) {
if (closure.count(&op) == 0) continue;
builder.insert(op.clone(remapper));
- if (&op == separableOp.value()) break;
+ if (&op == linalgOp.value()) break;
}
builder.insert(oldFnBlock.getTerminator()->clone(remapper));
}
// Add the entry point schedule to the module op.
SmallVector<Attribute, 4> entryPoints;
- entryPoints.reserve(separableOps.size());
+ entryPoints.reserve(linalgOps.size());
for (const std::string &kernel : splitKernels) {
entryPoints.emplace_back(builder.getStringAttr(kernel));
}
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Utils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/Utils.cpp
deleted file mode 100644
index 6fed42e..0000000
--- a/iree/compiler/Conversion/LinalgToSPIRV/Utils.cpp
+++ /dev/null
@@ -1,51 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-//===- Utils.cpp - Utility functions used in Linalg to SPIR-V lowering ----===//
-//
-// Implementaiton of utility functions used while lowering from Linalg to SPIRV.
-//
-//===----------------------------------------------------------------------===//
-
-#include "iree/compiler/Conversion/LinalgToSPIRV/Utils.h"
-
-#include "mlir/Dialect/SPIRV/TargetAndABI.h"
-#include "mlir/IR/Function.h"
-#include "mlir/IR/Region.h"
-#include "mlir/Support/LogicalResult.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-LogicalResult updateWorkGroupSize(FuncOp funcOp,
- ArrayRef<int64_t> workGroupSize) {
- // Need to update both the surrounding FuncOp that has the spv.entry_point_abi
- // attribute, and the hal.executable.
- Region &body = funcOp.getBody();
- if (!llvm::hasSingleElement(body))
- return funcOp.emitError("unhandled dispatch function with multiple blocks");
-
- if (workGroupSize.size() != 3)
- return funcOp.emitError("expected workgroup size to have three entries");
- SmallVector<int32_t, 3> workGroupSizeVec = llvm::to_vector<3>(llvm::map_range(
- workGroupSize, [](int64_t v) { return static_cast<int32_t>(v); }));
-
- funcOp.setAttr(
- spirv::getEntryPointABIAttrName(),
- spirv::getEntryPointABIAttr(workGroupSizeVec, funcOp.getContext()));
- return success();
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Utils.h b/iree/compiler/Conversion/LinalgToSPIRV/Utils.h
deleted file mode 100644
index bdea68e..0000000
--- a/iree/compiler/Conversion/LinalgToSPIRV/Utils.h
+++ /dev/null
@@ -1,38 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-//===- Utils.h - Utility functions used in Linalg to SPIR-V lowering ------===//
-//
-// Utility functions used while lowering from Linalg to SPIRV.
-//
-//===----------------------------------------------------------------------===//
-#ifndef IREE_COMPILER_CONVERSION_LINALGTOSPIRV_UTILS_H_
-#define IREE_COMPILER_CONVERSION_LINALGTOSPIRV_UTILS_H_
-
-#include "mlir/Support/LLVM.h"
-
-namespace mlir {
-class FuncOp;
-struct LogicalResult;
-
-namespace iree_compiler {
-
-/// Updates the workgroup size used for the dispatch region.
-LogicalResult updateWorkGroupSize(FuncOp funcOp,
- ArrayRef<int64_t> workGroupSize);
-
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_CONVERSION_LINALGTOSPIRV_UTILS_H_
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir
index 09e4101..9d81a75 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir
@@ -1,201 +1,205 @@
-// RUN: iree-opt -iree-codegen-convert-to-gpu -canonicalize -cse -split-input-file %s | IreeFileCheck %s
+// RUN: iree-opt -iree-codegen-convert-to-gpu -canonicalize -split-input-file %s | IreeFileCheck %s
-#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-module attributes {
- spv.target_env =
- #spv.target_env<#spv.vce<v1.3,
- [Shader], [SPV_KHR_storage_buffer_storage_class]>,
- {max_compute_workgroup_invocations = 128 : i32,
- max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
- func @parallel_4D(%arg0: memref<?x?x?x?xf32>,
- %arg1 : memref<?x?x?x?xf32>,
- %arg2 : memref<?x?x?x?xf32>)
- attributes {iree.dispatch_fn_name = "parallel_4D"} {
- linalg.generic
- {args_in = 2 : i64, args_out = 1 : i64,
- indexing_maps = [#map0, #map0, #map0],
- iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
- %arg0, %arg1, %arg2 {
- ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
- %0 = addf %arg3, %arg4 : f32
- linalg.yield %0 : f32
- } : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1 - d2)>
+#map1 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+#map2 = affine_map<(d0, d1) -> (d0, d1)>
+
+module {
+ func @pw_add(%arg0: memref<4x8xi32>, %arg1: memref<4x8xi32>,
+ %arg2: memref<4x8xi32>)
+ attributes {iree.dispatch_fn_name = "pw_add"} {
+ %c32 = constant 32 : index
+ %c0 = constant 0 : index
+ %c4 = constant 4 : index
+ %c8 = constant 8 : index
+ %c1 = constant 1 : index
+ scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%c4, %c8) step (%c4, %c32) {
+ %0 = affine.min #map0(%c4, %c4, %arg3)
+ %1 = affine.min #map0(%c32, %c8, %arg4)
+ %2 = subview %arg0[%arg3, %arg4] [%0, %1] [%c1, %c1]
+ : memref<4x8xi32> to memref<?x?xi32, #map1>
+ %3 = affine.min #map0(%c4, %c4, %arg3)
+ %4 = affine.min #map0(%c32, %c8, %arg4)
+ %5 = subview %arg1[%arg3, %arg4] [%3, %4] [%c1, %c1]
+ : memref<4x8xi32> to memref<?x?xi32, #map1>
+ %6 = affine.min #map0(%c4, %c4, %arg3)
+ %7 = affine.min #map0(%c32, %c8, %arg4)
+ %8 = subview %arg2[%arg3, %arg4] [%6, %7] [%c1, %c1]
+ : memref<4x8xi32> to memref<?x?xi32, #map1>
+ linalg.generic
+ {args_in = 2 : i64, args_out = 1 : i64,
+ indexing_maps = [#map2, #map2, #map2],
+ iterator_types = ["parallel", "parallel"]}
+ {__internal_linalg_transform__ = "workitem"} %2, %5, %8 {
+ ^bb0(%arg5: i32, %arg6: i32, %arg7: i32): // no predecessors
+ %9 = addi %arg5, %arg6 : i32
+ linalg.yield %9 : i32
+ } : memref<?x?xi32, #map1>, memref<?x?xi32, #map1>, memref<?x?xi32, #map1>
+ scf.yield
+ }
return
}
}
-// CHECK-LABEL: func @parallel_4D
-// CHECK-SAME: local_size = dense<[32, 1, 1]>
-// CHECK-DAG: %[[C0:.+]] = constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[C2:.+]] = constant 2 : index
-// CHECK-DAG: %[[C3:.+]] = constant 3 : index
-// CHECK-DAG: %[[UB0:.+]] = dim %{{.*}}, %[[C0]]
-// CHECK-DAG: %[[UB1:.+]] = dim %{{.*}}, %[[C1]]
-// CHECK-DAG: %[[UB2:.+]] = dim %{{.*}}, %[[C2]]
-// CHECK-DAG: %[[UB3:.+]] = dim %{{.*}}, %[[C3]]
-// CHECK: %[[T4:.+]] = muli %[[UB3]], %[[UB2]]
-// CHECK: %[[T5:.+]] = muli %[[T4]], %[[UB1]]
-// CHECK: %[[UB:.+]] = muli %[[T5]], %[[UB0]]
-// CHECK-DAG: %[[BID:.+]] = "gpu.block_id"() {dimension = "x"}
-// CHECK-DAG: %[[BDIM:.+]] = "gpu.block_dim"() {dimension = "x"}
-// CHECK-DAG: %[[TID:.+]] = "gpu.thread_id"() {dimension = "x"}
-// CHECK: %[[BOFFSET:.+]] = muli %[[BID]], %[[BDIM]]
-// CHECK: %[[IV:.+]] = addi %[[BOFFSET]], %[[TID]]
-// CHECK: %[[COND:.+]] = cmpi "slt", %[[IV]], %[[UB]]
-// CHECK: scf.if %[[COND]]
-// CHECK: %[[IV0:.+]] = divi_signed %[[IV]], %[[T5]]
-// CHECK: %[[T14:.+]] = remi_signed %[[IV]], %[[T5]]
-// CHECK: %[[IV1:.+]] = divi_signed %[[T14]], %[[T4]]
-// CHECK: %[[T16:.+]] = remi_signed %[[T14]], %[[T4]]
-// CHECK: %[[IV2:.+]] = divi_signed %[[T16]], %[[UB3]]
-// CHECK: %[[IV3:.+]] = remi_signed %[[T16]], %[[UB3]]
-// CHECK: load %{{.*}}[%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
-// CHECK: load %{{.*}}[%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
-// CHECK: store %{{.*}}[%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
-
-
-// -----
-
-#map0 = affine_map<() -> ()>
-#accesses = [#map0, #map0, #map0]
-#trait = {
- args_in = 2 : i64,
- args_out = 1 : i64,
- indexing_maps = #accesses,
- iterator_types = []
-}
-
-module attributes {
- spv.target_env =
- #spv.target_env<#spv.vce<v1.3,
- [Shader], [SPV_KHR_storage_buffer_storage_class]>,
- {max_compute_workgroup_invocations = 128 : i32,
- max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
- func @scalar_add(%arg0 : memref<f32>, %arg1 : memref<f32>,
- %arg2 : memref<f32>)
- {
- linalg.generic #trait %arg0, %arg1, %arg2 {
- ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
- %0 = addf %arg3, %arg4 : f32
- linalg.yield %0 : f32
- } : memref<f32>, memref<f32>, memref<f32>
- return
- }
-}
-// CHECK-LABEL: func @scalar_add
-// CHECK-SAME: local_size = dense<1> : vector<3xi32>
-// CHECK-NEXT: load
-// CHECK-NEXT: load
-// CHECK-NEXT: addf
-// CHECK-NEXT: store
-// CHECK-NEXT: return
+// CHECK-DAG: %[[STEPY:.+]] = constant 4 : index
+// CHECK-DAG: %[[STEPX:.+]] = constant 32 : index
+// CHECK-DAG: %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
+// CHECK-DAG: %[[NBLOCKSX:.+]] = "gpu.grid_dim"() {dimension = "x"}
+// CHECK-DAG: %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
+// CHECK-DAG: %[[NBLOCKSY:.+]] = "gpu.grid_dim"() {dimension = "y"}
+// CHECK: %[[NEWLBY:.+]] = muli %[[BIDY]], %[[STEPY]]
+// CHECK: %[[NEWSTEPY:.+]] = muli %[[NBLOCKSY]], %[[STEPY]]
+// CHECK: %[[NEWLBX:.+]] = muli %[[BIDX]], %[[STEPX]]
+// CHECK: %[[NEWSTEPX:.+]] = muli %[[NBLOCKSX]], %[[STEPX]]
+// CHECK: scf.for %{{.+}} = %[[NEWLBY]] to %{{.+}} step %[[NEWSTEPY]]
+// CHECK: scf.for %{{.+}} = %[[NEWLBX]] to %{{.+}} step %[[NEWSTEPX]]
+// CHECK-DAG: %[[TIDX:.+]] = "gpu.thread_id"() {dimension = "x"}
+// CHECK-DAG: %[[NTHREADSX:.+]] = "gpu.block_dim"() {dimension = "x"}
+// CHECK-DAG: %[[TIDY:.+]] = "gpu.thread_id"() {dimension = "y"}
+// CHECK-DAG: %[[NTHREADSY:.+]] = "gpu.block_dim"() {dimension = "y"}
+// CHECK: scf.for %{{.+}} = %[[TIDY]] to %{{.+}} step %[[NTHREADSY]]
+// CHECK: scf.for %{{.+}} = %[[TIDX]] to %{{.+}} step %[[NTHREADSX]]
// -----
module {
- func @reduce_sum(%arg0: memref<?x?x?xf32>, %arg1: memref<f32>, %arg2: memref<?xf32>)
+ func @reduce_sum(%arg0: memref<4xf32>, %arg1: memref<f32>, %arg2: memref<f32>)
attributes {iree.dispatch_fn_name = "reduce_sum"} {
linalg.indexed_generic
{args_in = 2 : i64, args_out = 1 : i64,
- indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> ()>,
- affine_map<(d0, d1, d2) -> (d0)>],
- iterator_types = ["parallel", "parallel", "reduction"]} %arg0, %arg1, %arg2 {
- ^bb0(%arg3: index, %arg4: index, %arg5: index,
- %arg6: f32, %arg7: f32, %arg8: f32): // no predecessors
+ indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>,
+ affine_map<(d0) -> ()>],
+ iterator_types = ["reduction"]} %arg0, %arg1, %arg2 {
+ ^bb0(%arg3: index, %arg4: f32, %arg5: f32, %arg6: f32): // no predecessors
%c0 = constant 0 : index
%cst = constant true
- %0 = cmpi "eq", %arg5, %c0 : index
+ %0 = cmpi "eq", %arg3, %c0 : index
%1 = and %cst, %0 : i1
- %2 = select %1, %arg7, %arg8 : f32
- %3 = addf %arg6, %2 : f32
+ %2 = select %1, %arg5, %arg6 : f32
+ %3 = addf %arg4, %2 : f32
linalg.yield %3 : f32
- }: memref<?x?x?xf32>, memref<f32>, memref<?xf32>
+ }: memref<4xf32>, memref<f32>, memref<f32>
return
}
}
-
-// CHECK-LABEL: func @reduce_sum
-// CHECK-SAME: local_size = dense<[32, 1, 1]> : vector<3xi32>
-// CHECK-DAG: %[[C0:.+]] = constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[C2:.+]] = constant 2 : index
-// CHECK: %[[UB0:.+]] = dim %{{.*}}, %[[C0]]
-// CHECK: %[[UB1:.+]] = dim %{{.*}}, %[[C1]]
-// CHECK: %[[UB2:.+]] = dim %{{.*}}, %[[C2]]
-// CHECK: %[[UB:.+]] = muli %[[UB1]], %[[UB0]]
-// CHECK: %[[COND:.+]] = cmpi "slt", %{{.*}}, %[[UB]]
-// CHECK: scf.if %[[COND]]
-// CHECK: %[[IV0:.+]] = divi_signed %{{.*}}, %[[UB1]]
-// CHECK: %[[IV1:.+]] = remi_signed %{{.*}}, %[[UB1]]
-// CHECK: scf.for %[[IV:.+]] = %{{.*}} to %[[UB2]]
-// CHECK: %[[ISZERO:.+]] = cmpi "eq", %[[IV]], %[[C0]]
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C4:.+]] = constant 4 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK: scf.for %{{.+}} = %[[C0]] to %[[C4]] step %[[C1]]
+// CHECK-NOT: scf
// -----
-#map0 = affine_map<(d0)[s0] -> (8, -d0 + s0)>
-#map1 = affine_map<(d0)[s0] -> (4, -d0 + s0)>
-#map2 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+#map0 = affine_map<(d0)[s0] -> (2, -d0 + s0)>
+#map1 = affine_map<(d0)[s0] -> (32, -d0 + s0)>
+#map2 = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4)>
+#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
- func @matmul(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) attributes {spv.entry_point_abi = {local_size = dense<[8, 8, 1]> : vector<3xi32>}} {
+module {
+ func @parallel_4D(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?xf32>, %arg2: memref<?x?x?x?xf32>) attributes {iree.dispatch_fn_name = "parallel_4D", spv.entry_point_abi = {local_size = dense<[32, 2, 2]> : vector<3xi32>}} {
+ %c2 = constant 2 : index
+ %c32 = constant 32 : index
%c0 = constant 0 : index
%c1 = constant 1 : index
- %c4 = constant 4 : index
- %c8 = constant 8 : index
- %0 = dim %arg0, %c0 : memref<?x?xf32>
- %1 = dim %arg0, %c1 : memref<?x?xf32>
- %2 = dim %arg1, %c1 : memref<?x?xf32>
- scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%0, %2) step (%c8, %c8) {
- scf.for %arg5 = %c0 to %1 step %c4 {
- %3 = affine.min #map0(%arg3)[%0]
- %4 = affine.min #map1(%arg5)[%1]
- %5 = subview %arg0[%arg3, %arg5] [%3, %4] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map2>
- %6 = dim %arg1, %c0 : memref<?x?xf32>
- %7 = affine.min #map1(%arg5)[%6]
- %8 = affine.min #map0(%arg4)[%2]
- %9 = subview %arg1[%arg5, %arg4] [%7, %8] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map2>
- %10 = dim %arg2, %c0 : memref<?x?xf32>
- %11 = affine.min #map0(%arg3)[%10]
- %12 = dim %arg2, %c1 : memref<?x?xf32>
- %13 = affine.min #map0(%arg4)[%12]
- %14 = subview %arg2[%arg3, %arg4] [%11, %13] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map2>
- linalg.matmul %5, %9, %14 {__internal_linalg_transform__ = "workitem"} : (memref<?x?xf32, #map2>, memref<?x?xf32, #map2>, memref<?x?xf32, #map2>)
- }
+ %c3 = constant 3 : index
+ %0 = dim %arg0, %c0 : memref<?x?x?x?xf32>
+ %1 = dim %arg0, %c1 : memref<?x?x?x?xf32>
+ %2 = dim %arg0, %c2 : memref<?x?x?x?xf32>
+ %3 = dim %arg0, %c3 : memref<?x?x?x?xf32>
+ scf.parallel (%arg3, %arg4, %arg5, %arg6) = (%c0, %c0, %c0, %c0) to (%0, %1, %2, %3) step (%c2, %c2, %c2, %c32) {
+ %12 = affine.min #map0(%arg3)[%0]
+ %13 = affine.min #map0(%arg4)[%1]
+ %14 = affine.min #map0(%arg5)[%2]
+ %15 = affine.min #map1(%arg6)[%3]
+ %16 = subview %arg0[%arg3, %arg4, %arg5, %c0] [%12, %13, %14, %15] [%c1, %c1, %c1, %c1] : memref<?x?x?x?xf32> to memref<?x?x?x?xf32, #map2>
+ %17 = subview %arg1[%arg3, %arg4, %arg5, %c0] [%12, %13, %14, %15] [%c1, %c1, %c1, %c1] : memref<?x?x?x?xf32> to memref<?x?x?x?xf32, #map2>
+ %18 = subview %arg2[%arg3, %arg4, %arg5, %c0] [%12, %13, %14, %15] [%c1, %c1, %c1, %c1] : memref<?x?x?x?xf32> to memref<?x?x?x?xf32, #map2>
+ linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
+ indexing_maps = [#map3, #map3, #map3],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+ {__internal_linalg_transform__ = "workitem"}
+ %16, %17, %18
+ {
+ ^bb0(%arg7: f32, %arg8: f32, %arg9: f32): // no predecessors
+ %19 = addf %arg7, %arg8 : f32
+ linalg.yield %19 : f32
+ } : memref<?x?x?x?xf32, #map2>, memref<?x?x?x?xf32, #map2>, memref<?x?x?x?xf32, #map2>
scf.yield
}
return
}
}
-// CHECK-LABEL: func @matmul
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
-// CHECK-DAG: %[[C4:.+]] = constant 4 : index
-// CHECK-DAG: %[[C8:.+]] = constant 8 : index
-// CHECK-DAG: %[[C0:.+]] = constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[UB0:.+]] = dim %[[ARG0]], %[[C0]]
-// CHECK-DAG: %[[UB1:.+]] = dim %[[ARG1]], %[[C1]]
-// CHECK-DAG: %[[UB2:.+]] = dim %[[ARG0]], %[[C1]]
-// CHECK-DAG: %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
-// CHECK-DAG: %[[GDIMX:.+]] = "gpu.grid_dim"() {dimension = "x"}
-// CHECK-DAG: %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
-// CHECK-DAG: %[[GDIMY:.+]] = "gpu.grid_dim"() {dimension = "y"}
-// CHECK: %[[BOFFSETY:.+]] = muli %[[BIDY]], %[[C8]]
-// CHECK: %[[BSTEPY:.+]] = muli %[[GDIMY]], %[[C8]]
-// CHECK: %[[BOFFSETX:.+]] = muli %[[BIDX]], %[[C8]]
-// CHECK: %[[BSTEPX:.+]] = muli %[[GDIMX]], %[[C8]]
-// CHECK: scf.for %[[BIV0:.+]] = %[[BOFFSETY]] to %[[UB0]] step %[[BSTEPY]]
-// CHECK: scf.for %[[BIV1:.+]] = %[[BOFFSETX]] to %[[UB1]] step %[[BSTEPX]]
-// CHECK: scf.for %[[BIV2:.+]] = %[[C0]] to %[[UB2]] step %[[C4]]
-// CHECK-DAG: %[[VIEWUB0:.+]] = affine.min #{{.*}}(%[[BIV0]])[%[[UB0]]]
-// CHECK-DAG: %[[VIEWUB1:.+]] = affine.min #{{.*}}(%[[BIV1]])[%[[UB1]]]
-// CHECK-DAG: %[[VIEWUB2:.+]] = affine.min #{{.*}}(%[[BIV2]])[%[[UB2]]]
-// CHECK-DAG: %[[TIDX:.+]] = "gpu.thread_id"() {dimension = "x"}
-// CHECK-DAG: %[[TIDY:.+]] = "gpu.thread_id"() {dimension = "y"}
-// CHECK: %[[INBOUNDY:.+]] = cmpi "slt", %[[TIDY]], %[[VIEWUB0]]
-// CHECK: %[[INBOUNDX:.+]] = cmpi "slt", %[[TIDX]], %[[VIEWUB1]]
-// CHECK: %[[COND:.+]] = and %[[INBOUNDY]], %[[INBOUNDX]]
-// CHECK: scf.if %[[COND]]
-// CHECK: scf.for %{{.*}} = %[[C0]] to %[[VIEWUB2]] step %[[C1]]
+// CHECK-DAG: %[[C2:.+]] = constant 2 : index
+// CHECK-DAG: %[[C32:.+]] = constant 32 : index
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C3:.+]] = constant 3 : index
+// CHECK-DAG: %[[SERIALDIMOUTER:.+]] = dim %{{.+}}, %[[C3]]
+// CHECK-DAG: %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"} : () -> index
+// CHECK-DAG: %[[NBLOCKSX:.+]] = "gpu.grid_dim"() {dimension = "x"} : () -> index
+// CHECK-DAG: %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"} : () -> index
+// CHECK-DAG: %[[NBLOCKSY:.+]] = "gpu.grid_dim"() {dimension = "y"} : () -> index
+// CHECK-DAG: %[[BIDZ:.+]] = "gpu.block_id"() {dimension = "z"} : () -> index
+// CHECK-DAG: %[[NBLOCKSZ:.+]] = "gpu.grid_dim"() {dimension = "z"} : () -> index
+// CHECK-DAG: %[[LB0:.+]] = muli %[[BIDZ]], %[[C2]]
+// CHECK-DAG: %[[STEP0:.+]] = muli %[[NBLOCKSZ]], %[[C2]]
+// CHECK-DAG: %[[LB1:.+]] = muli %[[BIDY]], %[[C2]]
+// CHECK-DAG: %[[STEP1:.+]] = muli %[[NBLOCKSY]], %[[C2]]
+// CHECK-DAG: %[[LB2:.+]] = muli %[[BIDX]], %[[C2]]
+// CHECK-DAG: %[[STEP2:.+]] = muli %[[NBLOCKSX]], %[[C2]]
+// CHECK: scf.for %{{.+}} = %[[LB0]] to %{{.+}} step %[[STEP0]]
+// CHECK: scf.for %{{.+}} = %[[LB1]] to %{{.+}} step %[[STEP1]]
+// CHECK: scf.for %{{.+}} = %[[LB2]] to %{{.+}} step %[[STEP2]]
+// CHECK: scf.for %{{.+}} = %[[C0]] to %[[SERIALDIMOUTER]] step %[[C32]]
+// CHECK-DAG: %[[TIDX:.+]] = "gpu.thread_id"() {dimension = "x"} : () -> index
+// CHECK-DAG: %[[NTHREADSX:.+]] = "gpu.block_dim"() {dimension = "x"} : () -> index
+// CHECK-DAG: %[[TIDY:.+]] = "gpu.thread_id"() {dimension = "y"} : () -> index
+// CHECK-DAG: %[[NTHREADSY:.+]] = "gpu.block_dim"() {dimension = "y"} : () -> index
+// CHECK-DAG: %[[TIDZ:.+]] = "gpu.thread_id"() {dimension = "z"} : () -> index
+// CHECK-DAG: %[[NTHREADSZ:.+]] = "gpu.block_dim"() {dimension = "z"} : () -> index
+// CHECK: scf.for %{{.+}} = %[[TIDZ]] to %{{.+}} step %[[NTHREADSZ]]
+// CHECK: scf.for %{{.+}} = %[[TIDY]] to %{{.+}} step %[[NTHREADSY]]
+// CHECK: scf.for %{{.+}} = %[[TIDX]] to %{{.+}} step %[[NTHREADSX]]
+// CHECK: scf.for %{{.+}} = %[[C0]] to %{{.+}} step %[[C1]]
+
+// -----
+
+module {
+ func @no_tile(%arg0: memref<?x?xf32>, %arg1 : memref<?x?xf32>,
+ %arg2 : memref<?x?xf32>)
+ attributes {iree.dispatch_fn_name = "reduce_sum"} {
+ linalg.generic
+ {args_in = 2 : i64, args_out = 1 : i64,
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ {__internal_linalg_tranform__ = "no-tile"} %arg0, %arg1, %arg2 {
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+ %0 = addf %arg3, %arg4 : f32
+ linalg.yield %0 : f32
+ }: memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
+ return
+ }
+}
+
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = constant 1 : index
+// CHECK-DAG: %[[UBY:.+]] = dim %{{.*}}, %[[C0]]
+// CHECK-DAG: %[[UBX:.+]] = dim %{{.*}}, %[[C1]]
+// CHECK-DAG: %[[NBLOCKSX:.+]] = "gpu.grid_dim"() {dimension = "x"}
+// CHECK-DAG: %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
+// CHECK-DAG: %[[BLOCKSIZEX:.+]] = "gpu.block_dim"() {dimension = "x"}
+// CHECK-DAG: %[[TIDX:.+]] = "gpu.thread_id"() {dimension = "x"}
+// CHECK: %[[T6:.+]] = muli %[[BIDX]], %[[BLOCKSIZEX]]
+// CHECK: %[[GIDX:.+]] = addi %[[T6]], %[[TIDX]]
+// CHECK: %[[NPROCSX:.+]] = muli %[[BLOCKSIZEX]], %[[NBLOCKSX]]
+// CHECK-DAG: %[[NBLOCKSY:.+]] = "gpu.grid_dim"() {dimension = "y"}
+// CHECK-DAG: %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
+// CHECK-DAG: %[[BLOCKSIZEY:.+]] = "gpu.block_dim"() {dimension = "y"}
+// CHECK-DAG: %[[TIDY:.+]] = "gpu.thread_id"() {dimension = "y"}
+// CHECK: %[[T6:.+]] = muli %[[BIDY]], %[[BLOCKSIZEY]]
+// CHECK: %[[GIDY:.+]] = addi %[[T6]], %[[TIDY]]
+// CHECK: %[[NPROCSY:.+]] = muli %[[BLOCKSIZEY]], %[[NBLOCKSY]]
+// CHECK: scf.for %{{.+}} = %[[GIDY]] to %[[UBY]] step %[[NPROCSY]]
+// CHECK: scf.for %{{.+}} = %[[GIDX]] to %[[UBX]] step %[[NPROCSX]]
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
index 1b5ddda..70f3d17 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
@@ -1,14 +1,65 @@
// RUN: iree-opt -split-input-file -iree-codegen-linalg-tile-and-fuse %s | IreeFileCheck %s
-// Test to check that convolution with padding is not tiled.
module attributes {
spv.target_env =
#spv.target_env<#spv.vce<v1.3,
[Shader], [SPV_KHR_storage_buffer_storage_class]>,
{max_compute_workgroup_invocations = 128 : i32,
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+ // CHECK-LABEL: func @tile_only
+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<4x8xi32>
+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<4x8xi32>
+ // CHECK-SAME: %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<4x8xi32>
+ // CHECK-SAME: local_size = dense<[32, 4, 1]>
+ // CHECK: scf.parallel
+ // CHECK: %[[VIEW0:.+]] = subview %[[ARG0]]
+ // CHECK: %[[VIEW1:.+]] = subview %[[ARG1]]
+ // CHECK: %[[VIEW2:.+]] = subview %[[ARG2]]
+ // CHECK: linalg.generic
+ // CHECK-SAME: "workitem"
+ // CHECK-SAME: %[[VIEW0]]
+ // CHECK-SAME: %[[VIEW1]]
+ // CHECK-SAME: %[[VIEW2]]
+ func @tile_only(%arg0: memref<4x8xi32>, %arg1: memref<4x8xi32>,
+ %arg2: memref<4x8xi32>) {
+ linalg.generic
+ {args_in = 2 : i64, args_out = 1 : i64,
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]} %arg0, %arg1, %arg2 {
+ ^bb0(%arg3: i32, %arg4: i32, %arg5: i32):
+ %0 = addi %arg3, %arg4 : i32
+ linalg.yield %0 : i32
+ }: memref<4x8xi32>, memref<4x8xi32>, memref<4x8xi32>
+ return
+ }
+}
+
+// -----
+
+module attributes {
+ spv.target_env =
+ #spv.target_env<#spv.vce<v1.3,
+ [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+ {max_compute_workgroup_invocations = 128 : i32,
+ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+ // CHECK-LABEL: func @conv_padding
+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
+ // CHECK-SAME: %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
+ // CHECK-SAME: local_size = dense<[32, 1, 1]>
+ // CHECK: scf.parallel (%{{.+}})
+ // CHECK: %[[VIEW1:.+]] = subview %[[ARG1]]
+ // CHECK: %[[VIEW2:.+]] = subview %[[ARG2]]
+ // CHECK: linalg.conv
+ // CHECK-SAME: %[[VIEW1]]
+ // CHECK-SAME: %[[VIEW2]]
+ // CHECK-SAME: "workitem"
func @conv_padding(%arg0 : memref<?x?x?x?xf32>, %arg1 : memref<?x?x?x?xf32>,
- %arg2 : memref<?x?x?x?xf32>) {
+ %arg2 : memref<?x?x?x?xf32>)
+ attributes
+ {iree.dispatch_fn_name = "conv_padding"} {
linalg.conv(%arg0, %arg1, %arg2)
{dilations = [1, 1],
padding = dense<[[1, 1], [0, 1]]> : tensor<2x2xi64>, strides = [1, 1]} :
@@ -16,14 +67,6 @@
return
}
}
-// CHECK-LABEL: func @conv_padding
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
-// CHECK: linalg.conv
-// CHECK-SAME: %[[ARG0]]
-// CHECK-SAME: %[[ARG1]]
-// CHECK-SAME: %[[ARG2]]
// -----
@@ -33,24 +76,55 @@
[Shader], [SPV_KHR_storage_buffer_storage_class]>,
{max_compute_workgroup_invocations = 128 : i32,
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+ // CHECK-LABEL: func @conv_no_padding
+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
+ // CHECK-SAME: %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
+ // CHECK-SAME: local_size = dense<[32, 2, 2]>
+ // CHECK: scf.parallel (%{{.+}}, %{{.+}}, %{{.+}})
+ // CHECK: %[[VIEW1:.+]] = subview %[[ARG1]]
+ // CHECK: %[[VIEW2:.+]] = subview %[[ARG2]]
+ // CHECK: linalg.conv
+ // CHECK-SAME: %[[VIEW1]]
+ // CHECK-SAME: %[[VIEW2]]
+ // CHECK-SAME: "workitem"
func @conv_no_padding(%arg0 : memref<?x?x?x?xf32>, %arg1 : memref<?x?x?x?xf32>,
- %arg2 : memref<?x?x?x?xf32>) {
+ %arg2 : memref<?x?x?x?xf32>)
+ attributes
+ {iree.dispatch_fn_name = "conv_no_padding"} {
linalg.conv(%arg0, %arg1, %arg2) {dilations = [1, 1], strides = [1, 1]} :
memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
return
}
}
-// CHECK-LABEL: func @conv_no_padding
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
-// CHECK-SAME: local_size = dense<[32, 4, 1]>
-// CHECK: scf.parallel (%{{.+}}, %{{.+}}, %{{.+}})
-// CHECK: %[[VIEW1:.+]] = subview %[[ARG1]]
-// CHECK: %[[VIEW2:.+]] = subview %[[ARG2]]
-// CHECK: linalg.conv
-// CHECK-SAME: %[[ARG0]], %[[VIEW1]], %[[VIEW2]]
-// CHECK-SAME: "workitem"
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+module attributes {
+ spv.target_env =
+ #spv.target_env<#spv.vce<v1.3,
+ [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+ {max_compute_workgroup_invocations = 128 : i32,
+ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+ // CHECK-LABEL: func @parallel_4D
+ // CHECK: scf.parallel (%{{.+}}, %{{.+}}, %{{.+}})
+ func @parallel_4D(%arg0: memref<?x?x?x?xf32>,
+ %arg1 : memref<?x?x?x?xf32>,
+ %arg2 : memref<?x?x?x?xf32>)
+ attributes {iree.dispatch_fn_name = "parallel_4D"} {
+ linalg.generic
+ {args_in = 2 : i64, args_out = 1 : i64,
+ indexing_maps = [#map0, #map0, #map0],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+ %arg0, %arg1, %arg2 {
+ ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+ %0 = addf %arg3, %arg4 : f32
+ linalg.yield %0 : f32
+ } : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
+ return
+ }
+}
// -----
@@ -60,52 +134,54 @@
[Shader], [SPV_KHR_storage_buffer_storage_class]>,
{max_compute_workgroup_invocations = 128 : i32,
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
- func @matmul(%arg0: memref<?x?xf32>,
+ func @no_tile(%arg0: memref<?x?xf32>,
%arg1: memref<?x?xf32>,
%ret0: memref<?x?xf32>) {
- linalg.matmul %arg0, %arg1, %ret0 :
+ linalg.matmul %arg0, %arg1, %ret0 {__internal_linalg_transform__ = "no-tile"} :
(memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
return
}
}
-
-// CHECK-LABEL: func @matmul
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
+// CHECK-LABEL: func @no_tile
// CHECK-SAME: local_size = dense<[8, 8, 1]>
-// CHECK: scf.parallel (%{{.+}}, %{{.+}}, %{{.+}})
-// CHECK: %[[VIEW0:.+]] = subview %[[ARG0]]
-// CHECK: %[[VIEW1:.+]] = subview %[[ARG1]]
-// CHECK: %[[VIEW2:.+]] = subview %[[ARG2]]
-// CHECK: linalg.matmul
-// CHECK-SAME: "workitem"
-// CHECK-SAME: %[[VIEW0]], %[[VIEW1]], %[[VIEW2]]
+// CHECK-NOT: scf
+// CHECK: linalg.matmul
+// CHECK-NOT: scf
+// CHECK: return
// -----
+#map0 = affine_map<() -> ()>
+#accesses = [#map0, #map0]
+#trait = {
+ args_in = 2 : i64,
+ args_out = 1 : i64,
+ indexing_maps = #accesses,
+ iterator_types = []
+}
+
module attributes {
spv.target_env =
#spv.target_env<#spv.vce<v1.3,
[Shader], [SPV_KHR_storage_buffer_storage_class]>,
{max_compute_workgroup_invocations = 128 : i32,
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
- func @pooling_sum_no_padding(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>,
- %arg2 : memref<?x?xf32>) {
- linalg.pooling_max(%arg0, %arg1, %arg2) {dilations = [1, 1], strides = [1, 1]} :
- memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
- return
+ func @scalar_add(%arg0 : memref<f32>, %arg1 : memref<f32>,
+ %arg2 : memref<f32>)
+ {
+ linalg.generic #trait %arg0, %arg1, %arg2 {
+ ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+ %0 = addf %arg3, %arg4 : f32
+ linalg.yield %0 : f32
+ } : memref<f32>, memref<f32>, memref<f32>
+ return
}
}
-
-// CHECK-LABEL: func @pooling_sum_no_padding
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
-// CHECK-SAME: local_size = dense<[32, 4, 1]>
-// CHECK: scf.parallel (%{{.+}}, %{{.+}})
-// CHECK: %[[VIEW0:.+]] = subview %[[ARG0]]
-// CHECK: %[[VIEW2:.+]] = subview %[[ARG2]]
-// CHECK: linalg.pooling_max
-// CHECK-SAME: %[[VIEW0]], %[[ARG1]], %[[VIEW2]]
-// CHECK-SAME: "workitem"
+// CHECK-LABEL: func @scalar_add
+// CHECK-NOT: scf.parallel
+// CHECK-NOT: scf.for
+// CHECK: linalg.generic
+// CHECK-SAME: "no-tile"
+// CHECK-NOT: scf.parallel
+// CHECK-NOT: scf.for
+// CHECK: return
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
index 637fd7b..81db628 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
@@ -46,65 +46,6 @@
// -----
-// CHECK: module attributes {vkspv.entry_point_schedule = ["kernel_dispatch_0", "kernel_dispatch_1", "kernel_dispatch_2"]}
-module {
- // CHECK: func @kernel_dispatch_2()
- // CHECK: %[[DIM:.+]] = hal.interface.load.constant
- // CHECK: %[[SHAPE1:.+]] = shapex.make_ranked_shape %[[DIM]]
- // CHECK: %[[SHAPE2:.+]] = shapex.make_ranked_shape %[[DIM]]
- // CHECK: %[[IN1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x2x2x512xf32>
- // CHECK: %[[TS1:.+]] = shapex.tie_shape %[[IN1]], %[[SHAPE1]]
- // CHECK: %[[IN2:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x512x1xf32>
- // CHECK: %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x1x1x512xf32>
- // CHECK: %[[TS2:.+]] = shapex.tie_shape %[[OUT]], %[[SHAPE2]]
- // CHECK: linalg.conv(%[[IN2]], %[[TS1]], %[[TS2]])
- // CHECK: return
-
- // CHECK: func @kernel_dispatch_1() {
- // CHECK: %[[C0:.+]] = constant 0 : index
- // CHECK: %[[C1:.+]] = constant 1 : index
- // CHECK: scf.parallel (%{{.*}}) = (%[[C0]]) to (%[[C1]]) step (%[[C1]])
- // CHECK: scf.yield
- // CHECK: return
-
- // CHECK: func @kernel_dispatch_0()
- // CHECK: %[[ZERO:.+]] = constant
- // CHECK: %[[DIM:.+]] = hal.interface.load.constant
- // CHECK: %[[SHAPE:.+]] = shapex.make_ranked_shape %[[DIM]]
- // CHECK: %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x1x1x512xf32>
- // CHECK: %[[TS:.+]] = shapex.tie_shape %[[OUT]], %[[SHAPE]]
- // CHECK: linalg.fill(%[[TS]], %[[ZERO]])
- // CHECK: return
-
- func @kernel() {
- %cst = constant 0.000000e+00 : f32
- %c0 = constant 0 : index
- %c1 = constant 1 : index
- %dim = hal.interface.load.constant offset = 0 : index
- %shape1 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,2,2,512]>
- %shape2 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,1,1,512]>
- %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x2x2x512xf32>
- %ts1 = shapex.tie_shape %0, %shape1 : memref<?x2x2x512xf32>, !shapex.ranked_shape<[?,2,2,512]>
- %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x512x1xf32>
- %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x1x1x512xf32>
- %ts2 = shapex.tie_shape %2, %shape2 : memref<?x1x1x512xf32>, !shapex.ranked_shape<[?,1,1,512]>
- linalg.fill(%ts2, %cst) : memref<?x1x1x512xf32>, f32
- scf.parallel (%iv) = (%c0) to (%c1) step (%c1) {
- scf.yield
- }
- linalg.conv(%1, %ts1, %ts2) {dilations = [1, 1], padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, strides = [2, 2]} : memref<3x3x512x1xf32>, memref<?x2x2x512xf32>, memref<?x1x1x512xf32>
- return
- }
- hal.interface @legacy_io attributes {push_constants = 1 : i32, sym_visibility = "private"} {
- hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
- hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
- hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
- }
-}
-
-
-// -----
-
// Nothing to do if there is just one Linalg op.
// CHECK-NOT: vkspv.entry_point_schedule
@@ -130,7 +71,7 @@
// Do not split when Linalg and non-Linalg ops are interleaving each other.
module {
- // expected-error @+1 {{cannot separate Linalg/Parallel ops into multiple kernels}}
+ // expected-error @+1 {{cannot separate Linalg ops into multiple kernels}}
func @kernel() {
%cst = constant 0.000000e+00 : f32
%0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1x2x2x512xf32>
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir
index 76cfcb8..060dc5a 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir
@@ -5,7 +5,7 @@
%arg0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<96x96xf32>
%arg1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<96x96xf32>
%arg2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<96x96xf32>
- linalg.matmul %arg0, %arg1, %arg2 :
+ linalg.matmul %arg0, %arg1, %arg2 {__internal_linalg_transform__ = "workgroup"} :
(memref<96x96xf32>, memref<96x96xf32>, memref<96x96xf32>)
return
}
diff --git a/iree/hal/vmla/op_kernels.h b/iree/hal/vmla/op_kernels.h
index 093d52d..b7cb8b7 100644
--- a/iree/hal/vmla/op_kernels.h
+++ b/iree/hal/vmla/op_kernels.h
@@ -442,6 +442,6 @@
} // namespace iree
#include "iree/hal/vmla/op_kernels_generic.h" // IWYU pragma: export
-#include "iree/hal/vmla/op_kernels_ruy.h" // IWYU pragma: export
+#include "iree/hal/vmla/op_kernels_ruy.h" // IWYU pragma: export
#endif // IREE_HAL_VMLA_OP_KERNELS_H_
diff --git a/iree/hal/vulkan/BUILD b/iree/hal/vulkan/BUILD
index 8cf2112..d4c5753 100644
--- a/iree/hal/vulkan/BUILD
+++ b/iree/hal/vulkan/BUILD
@@ -37,6 +37,8 @@
},
)
+# TODO(antiagainst): expose configuration for emulated timeline semaphore
+
cc_library(
name = "api",
srcs = ["api.cc"],
@@ -91,6 +93,7 @@
":pipeline_executable",
":status_util",
":vma_allocator",
+ "//iree/base:alignment",
"//iree/base:arena",
"//iree/base:math",
"//iree/base:status",
@@ -187,6 +190,25 @@
)
cc_library(
+ name = "emulated_timeline_semaphore",
+ srcs = ["emulated_timeline_semaphore.cc"],
+ hdrs = ["emulated_timeline_semaphore.h"],
+ deps = [
+ ":handle_util",
+ ":status_util",
+ ":timepoint_util",
+ "//iree/base:intrusive_list",
+ "//iree/base:status",
+ "//iree/base:tracing",
+ "//iree/hal:semaphore",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ "@iree_vulkan_headers//:vulkan_headers_no_prototypes",
+ ],
+)
+
+cc_library(
name = "extensibility_util",
srcs = ["extensibility_util.cc"],
hdrs = ["extensibility_util.h"],
@@ -326,6 +348,24 @@
)
cc_library(
+ name = "serializing_command_queue",
+ srcs = ["serializing_command_queue.cc"],
+ hdrs = ["serializing_command_queue.h"],
+ deps = [
+ ":direct_command_buffer",
+ ":emulated_timeline_semaphore",
+ ":handle_util",
+ ":status_util",
+ ":timepoint_util",
+ "//iree/base:status",
+ "//iree/base:tracing",
+ "//iree/hal:command_queue",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/synchronization",
+ ],
+)
+
+cc_library(
name = "status_util",
srcs = ["status_util.cc"],
hdrs = ["status_util.h"],
@@ -336,6 +376,21 @@
)
cc_library(
+ name = "timepoint_util",
+ srcs = ["timepoint_util.cc"],
+ hdrs = ["timepoint_util.h"],
+ deps = [
+ ":handle_util",
+ "//iree/base:intrusive_list",
+ "//iree/base:ref_ptr",
+ "//iree/base:status",
+ "//iree/base:tracing",
+ "@com_google_absl//absl/synchronization",
+ "@iree_vulkan_headers//:vulkan_headers_no_prototypes",
+ ],
+)
+
+cc_library(
name = "vma_allocator",
srcs = [
"internal_vk_mem_alloc.cc",
@@ -384,6 +439,7 @@
":direct_command_buffer",
":direct_command_queue",
":dynamic_symbols",
+ ":emulated_timeline_semaphore",
":extensibility_util",
":handle_util",
":native_descriptor_set",
@@ -391,6 +447,7 @@
":native_timeline_semaphore",
":pipeline_cache",
":pipeline_executable_layout",
+ ":serializing_command_queue",
":status_util",
":vma_allocator",
"//iree/base:math",
diff --git a/iree/hal/vulkan/CMakeLists.txt b/iree/hal/vulkan/CMakeLists.txt
index 1904a44..663437d 100644
--- a/iree/hal/vulkan/CMakeLists.txt
+++ b/iree/hal/vulkan/CMakeLists.txt
@@ -12,6 +12,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+# TODO(antiagainst): We should probably always compile the emulation in and
+# probe at runtime, enabling it if the device does not support native timeline
+# semaphores.
+option(IREE_HAL_VULKAN_EMULATE_TIMELINE_SEMAPHORE
+ "Emulates timeline semaphore with binary semaphores and fences" OFF)
+
+# Unconditionally turn on the emulated timeline semaphore for Android.
+if(CMAKE_CROSSCOMPILING AND "${CMAKE_SYSTEM_NAME}" MATCHES "Android")
+ set(IREE_HAL_VULKAN_EMULATE_TIMELINE_SEMAPHORE ON CACHE BOOL "" FORCE)
+endif()
+# Unless we are not compiling the Vulkan HAL backend in.
+if(NOT IREE_HAL_DRIVER_VULKAN)
+ set(IREE_HAL_VULKAN_EMULATE_TIMELINE_SEMAPHORE OFF CACHE BOOL "" FORCE)
+endif()
+
+if(IREE_HAL_VULKAN_EMULATE_TIMELINE_SEMAPHORE)
+ set(IREE_EMULATE_TIMELINE_SEMAPHORE 1)
+else()
+ set(IREE_EMULATE_TIMELINE_SEMAPHORE 0)
+endif()
+
set(VMA_SRC_ROOT
"${IREE_ROOT_DIR}/third_party/vulkan_memory_allocator/src/"
)
@@ -86,6 +107,7 @@
COPTS
"-DVK_NO_PROTOTYPES"
DEPS
+ iree::base::alignment
iree::base::arena
iree::base::math
iree::base::status
@@ -198,6 +220,30 @@
iree_cc_library(
NAME
+ emulated_timeline_semaphore
+ HDRS
+ "emulated_timeline_semaphore.h"
+ SRCS
+ "emulated_timeline_semaphore.cc"
+ COPTS
+ "-DVK_NO_PROTOTYPES"
+ DEPS
+ ::handle_util
+ ::status_util
+ ::timepoint_util
+ absl::inlined_vector
+ absl::synchronization
+ absl::time
+ iree::base::intrusive_list
+ iree::base::status
+ iree::base::tracing
+ iree::hal::semaphore
+ Vulkan::Headers
+ PUBLIC
+)
+
+iree_cc_library(
+ NAME
extensibility_util
HDRS
"extensibility_util.h"
@@ -224,9 +270,11 @@
COPTS
"-DVK_NO_PROTOTYPES"
DEPS
+ absl::inlined_vector
absl::synchronization
absl::utility
iree::base::ref_ptr
+ iree::hal::command_queue
iree::hal::vulkan::dynamic_symbols
iree::hal::vulkan::extensibility_util
Vulkan::Headers
@@ -375,6 +423,30 @@
iree_cc_library(
NAME
+ serializing_command_queue
+ HDRS
+ "serializing_command_queue.h"
+ SRCS
+ "serializing_command_queue.cc"
+ COPTS
+ "-DVK_NO_PROTOTYPES"
+ DEPS
+ ::direct_command_buffer
+ ::emulated_timeline_semaphore
+ ::handle_util
+ ::status_util
+ ::timepoint_util
+ absl::inlined_vector
+ absl::synchronization
+ iree::base::status
+ iree::base::tracing
+ iree::hal::command_queue
+ Vulkan::Headers
+ PUBLIC
+)
+
+iree_cc_library(
+ NAME
status_util
HDRS
"status_util.h"
@@ -390,6 +462,26 @@
iree_cc_library(
NAME
+ timepoint_util
+ HDRS
+ "timepoint_util.h"
+ SRCS
+ "timepoint_util.cc"
+ COPTS
+ "-DVK_NO_PROTOTYPES"
+ DEPS
+ ::handle_util
+ absl::synchronization
+ iree::base::intrusive_list
+ iree::base::ref_ptr
+ iree::base::status
+ iree::base::tracing
+ Vulkan::Headers
+ PUBLIC
+)
+
+iree_cc_library(
+ NAME
vma_allocator
HDRS
"vma_allocator.h"
@@ -430,18 +522,21 @@
SRCS
"vulkan_device.cc"
COPTS
+ "-DIREE_HAL_VULKAN_EMULATE_TIMELINE_SEMAPHORES=${IREE_EMULATE_TIMELINE_SEMAPHORE}"
"-DVK_NO_PROTOTYPES"
DEPS
::descriptor_pool_cache
::direct_command_buffer
::direct_command_queue
::dynamic_symbols
+ ::emulated_timeline_semaphore
::extensibility_util
::handle_util
::native_descriptor_set
::native_timeline_semaphore
::pipeline_cache
::pipeline_executable_layout
+ ::serializing_command_queue
::status_util
::vma_allocator
absl::inlined_vector
@@ -497,6 +592,7 @@
SRCS
"vulkan_driver_module.cc"
COPTS
+ "-DIREE_HAL_VULKAN_EMULATE_TIMELINE_SEMAPHORES=${IREE_EMULATE_TIMELINE_SEMAPHORE}"
"-DVK_NO_PROTOTYPES"
DEPS
absl::flags
diff --git a/iree/hal/vulkan/descriptor_set_arena.cc b/iree/hal/vulkan/descriptor_set_arena.cc
index c51a410..54d8902 100644
--- a/iree/hal/vulkan/descriptor_set_arena.cc
+++ b/iree/hal/vulkan/descriptor_set_arena.cc
@@ -14,6 +14,7 @@
#include "iree/hal/vulkan/descriptor_set_arena.h"
+#include "iree/base/alignment.h"
#include "iree/base/math.h"
#include "iree/base/tracing.h"
#include "iree/hal/vulkan/status_util.h"
@@ -47,7 +48,24 @@
buffer_info.buffer = buffer->handle();
// TODO(benvanik): properly subrange (add to BufferBinding).
buffer_info.offset = binding.buffer->byte_offset();
- buffer_info.range = binding.buffer->byte_length();
+ // Round up to a multiple of 32 bits. 32-bit is the most natively supported
+ // bitwidth on GPUs; it has the best support compared to other bitwidths. We
+ // use VMA to manage GPU memory for us, and VMA should have already handled
+ // proper alignment when performing allocations; here we just need to provide
+ // the proper "view" over the allocated memory to Vulkan drivers.
+ //
+ // Note this is needed because we can see unusual buffers like tensor<3xi8>.
+ // Depending on GPU capabilities, this might not always be directly
+ // supported by the hardware. Under such circumstances, we need to emulate
+ // i8 support with i32. Shader CodeGen takes care of that: the shader will
+ // read the buffer as tensor<i32> and perform bit shifts to extract each
+ // byte and conduct computations. The extra byte is read but not really
+ // used by the shader. Here in the application we need to match the ABI and
+ // provide the buffer as 32-bit aligned, otherwise the whole read by the
+ // shader is considered out of bounds per the Vulkan spec.
+ // See https://github.com/google/iree/issues/2022#issuecomment-640617234
+ // for more details.
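+ // For example, a tensor<3xi8> binding with byte_length() == 3 would be
+ // exposed to the shader as a 4-byte range (one 32-bit word).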
+ buffer_info.range = iree_align(binding.buffer->byte_length(), 4);
auto& write_info = write_infos[i];
write_info.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
diff --git a/iree/hal/vulkan/dynamic_symbols.cc b/iree/hal/vulkan/dynamic_symbols.cc
index 2083031..c8f6d23 100644
--- a/iree/hal/vulkan/dynamic_symbols.cc
+++ b/iree/hal/vulkan/dynamic_symbols.cc
@@ -76,7 +76,9 @@
DEV_PFN_FUNCTION_PTR)};
static const char* kVulkanLoaderSearchNames[] = {
-#if defined(IREE_PLATFORM_WINDOWS)
+#if defined(IREE_PLATFORM_ANDROID)
+ "libvulkan.so",
+#elif defined(IREE_PLATFORM_WINDOWS)
"vulkan-1.dll",
#else
"libvulkan.so.1",
diff --git a/iree/hal/vulkan/dynamic_symbols.h b/iree/hal/vulkan/dynamic_symbols.h
index 4fc56e9..8983b0a 100644
--- a/iree/hal/vulkan/dynamic_symbols.h
+++ b/iree/hal/vulkan/dynamic_symbols.h
@@ -106,8 +106,8 @@
// Each required and optional function in the loader tables will expand to
// the following member, such as for example 'vkSomeFunction':
// PFN_vkSomeFunction vkSomeFunction;
-#define REQUIRED_PFN(function_name) PFN_##function_name function_name
-#define OPTIONAL_PFN(function_name) PFN_##function_name function_name
+#define REQUIRED_PFN(function_name) PFN_##function_name function_name = nullptr
+#define OPTIONAL_PFN(function_name) PFN_##function_name function_name = nullptr
#define EXCLUDED_PFN(function_name)
#define PFN_MEMBER(requirement, function_name) requirement##_PFN(function_name);
REQUIRED_PFN(vkGetInstanceProcAddr);
diff --git a/iree/hal/vulkan/emulated_timeline_semaphore.cc b/iree/hal/vulkan/emulated_timeline_semaphore.cc
new file mode 100644
index 0000000..9656310
--- /dev/null
+++ b/iree/hal/vulkan/emulated_timeline_semaphore.cc
@@ -0,0 +1,322 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/hal/vulkan/emulated_timeline_semaphore.h"
+
+#include "absl/container/inlined_vector.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/time/time.h"
+#include "absl/utility/utility.h"
+#include "iree/base/tracing.h"
+#include "iree/hal/vulkan/dynamic_symbols.h"
+#include "iree/hal/vulkan/status_util.h"
+
+namespace iree {
+namespace hal {
+namespace vulkan {
+
+// static
+StatusOr<ref_ptr<Semaphore>> EmulatedTimelineSemaphore::Create(
+ ref_ptr<VkDeviceHandle> logical_device,
+ std::function<Status(Semaphore*)> on_signal,
+ std::function<void(Semaphore*)> on_failure,
+ ref_ptr<TimePointSemaphorePool> semaphore_pool, uint64_t initial_value) {
+ IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::Create");
+ return make_ref<EmulatedTimelineSemaphore>(
+ std::move(logical_device), std::move(on_signal), std::move(on_failure),
+ std::move(semaphore_pool), initial_value);
+}
+
+EmulatedTimelineSemaphore::EmulatedTimelineSemaphore(
+ ref_ptr<VkDeviceHandle> logical_device,
+ std::function<Status(Semaphore*)> on_signal,
+ std::function<void(Semaphore*)> on_failure,
+ ref_ptr<TimePointSemaphorePool> semaphore_pool, uint64_t initial_value)
+ : signaled_value_(initial_value),
+ logical_device_(std::move(logical_device)),
+ on_signal_(std::move(on_signal)),
+ on_failure_(std::move(on_failure)),
+ semaphore_pool_(std::move(semaphore_pool)) {}
+
+EmulatedTimelineSemaphore::~EmulatedTimelineSemaphore() {
+ IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::dtor");
+ CHECK_OK(TryToAdvanceTimeline(UINT64_MAX).status());
+ absl::MutexLock lock(&mutex_);
+ CHECK(outstanding_semaphores_.empty())
+ << "Destroying an emulated timeline semaphore without first waiting on "
+ "outstanding signals";
+}
+
+StatusOr<uint64_t> EmulatedTimelineSemaphore::Query() {
+ RETURN_IF_ERROR(TryToAdvanceTimeline(UINT64_MAX).status());
+ uint64_t value = signaled_value_.load();
+ if (value == UINT64_MAX) {
+ absl::MutexLock lock(&mutex_);
+ return status_;
+ }
+ return value;
+}
+
+Status EmulatedTimelineSemaphore::Signal(uint64_t value) {
+ IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::Signal");
+ auto signaled_value = signaled_value_.exchange(value);
+ // Make sure the previous signaled value is smaller than the new value.
+ CHECK(signaled_value < value)
+ << "Attempting to signal a timeline value out of order; trying " << value
+ << " but " << signaled_value << " already signaled";
+
+ // Inform the device to make progress given we have a new value signaled now.
+ RETURN_IF_ERROR(on_signal_(this));
+
+ return OkStatus();
+}
+
+Status EmulatedTimelineSemaphore::Wait(uint64_t value, absl::Time deadline) {
+ IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::Wait");
+
+ VkFence fence = VK_NULL_HANDLE;
+ do {
+ IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::Wait#loop");
+ // First try to advance the timeline without blocking to see whether we've
+ // already reached the desired value.
+ ASSIGN_OR_RETURN(bool reached_desired_value, TryToAdvanceTimeline(value));
+ if (reached_desired_value) return OkStatus();
+
+ // We must wait now. Find the first emulated time point that has a value >=
+ // the desired value so we can wait on its associated signal fence to make
+ // sure the timeline is advanced to the desired value.
+ absl::MutexLock lock(&mutex_);
+ auto semaphore = outstanding_semaphores_.begin();
+ for (; semaphore != outstanding_semaphores_.end(); ++semaphore) {
+ if ((*semaphore)->value >= value) break;
+ }
+ if (semaphore != outstanding_semaphores_.end()) {
+ if (!(*semaphore)->signal_fence) {
+ return InternalErrorBuilder(IREE_LOC)
+ << "Timeline should have a signal fence for the first time "
+ "point beyond the signaled value";
+ }
+ fence = (*semaphore)->signal_fence->value();
+ // Found; we can break the loop and proceed to waiting now.
+ break;
+ }
+ // TODO(antiagainst): figure out a better way instead of the busy loop here.
+ } while (absl::Now() < deadline);
+
+ if (fence == VK_NULL_HANDLE) {
+ return DeadlineExceededErrorBuilder(IREE_LOC)
+ << "Deadline reached when waiting timeline semaphore";
+ }
+
+ uint64_t timeout_nanos;
+ if (deadline == absl::InfiniteFuture()) {
+ timeout_nanos = UINT64_MAX;
+ } else if (deadline == absl::InfinitePast()) {
+ timeout_nanos = 0;
+ } else {
+ auto relative_nanos = absl::ToInt64Nanoseconds(deadline - absl::Now());
+ timeout_nanos = relative_nanos < 0 ? 0 : relative_nanos;
+ }
+
+ VK_RETURN_IF_ERROR(logical_device_->syms()->vkWaitForFences(
+ *logical_device_, /*fenceCount=*/1, &fence, /*waitAll=*/true,
+ timeout_nanos));
+
+ RETURN_IF_ERROR(TryToAdvanceTimeline(value).status());
+ return OkStatus();
+}
+
+void EmulatedTimelineSemaphore::Fail(Status status) {
+ IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::Fail");
+ absl::MutexLock lock(&mutex_);
+ status_ = std::move(status);
+ signaled_value_.store(UINT64_MAX);
+}
+
+VkSemaphore EmulatedTimelineSemaphore::GetWaitSemaphore(
+ uint64_t value, const ref_ptr<TimePointFence>& wait_fence) {
+ IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::GetWaitSemaphore");
+ absl::MutexLock lock(&mutex_);
+
+ VkSemaphore semaphore = VK_NULL_HANDLE;
+ for (TimePointSemaphore* point : outstanding_semaphores_) {
+ if (point->value > value && point->wait_fence) {
+ point->wait_fence = add_ref(wait_fence);
+ semaphore = point->semaphore;
+ break;
+ }
+ }
+
+ return semaphore;
+}
+
+Status EmulatedTimelineSemaphore::CancelWaitSemaphore(VkSemaphore semaphore) {
+ IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::CancelWaitSemaphore");
+ absl::MutexLock lock(&mutex_);
+ for (TimePointSemaphore* point : outstanding_semaphores_) {
+ if (point->semaphore != semaphore) continue;
+
+ if (!point->wait_fence) {
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Time point wasn't waited before";
+ }
+ point->wait_fence = nullptr;
+ return OkStatus();
+ }
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "No time point for the given semaphore";
+}
+
+StatusOr<VkSemaphore> EmulatedTimelineSemaphore::GetSignalSemaphore(
+ uint64_t value, const ref_ptr<TimePointFence>& signal_fence) {
+ IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::GetSignalSemaphore");
+
+ if (signaled_value_.load() >= value) {
+ return FailedPreconditionErrorBuilder(IREE_LOC)
+ << "Timeline semaphore already signaled past " << value;
+ }
+
+ absl::MutexLock lock(&mutex_);
+
+ auto insertion_point = outstanding_semaphores_.begin();
+ while (insertion_point != outstanding_semaphores_.end()) {
+ if ((*insertion_point)->value > value) break;
+ ++insertion_point;
+ }
+
+ ASSIGN_OR_RETURN(TimePointSemaphore * semaphore, semaphore_pool_->Acquire());
+ semaphore->value = value;
+ semaphore->signal_fence = add_ref(signal_fence);
+ if (semaphore->wait_fence) {
+ return InternalErrorBuilder(IREE_LOC)
+ << "Newly acquired time point semaphore should not have waiters";
+ }
+ outstanding_semaphores_.insert(insertion_point, semaphore);
+
+ return semaphore->semaphore;
+}
+
+StatusOr<bool> EmulatedTimelineSemaphore::TryToAdvanceTimeline(
+ uint64_t to_upper_value) {
+ IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::TryToAdvanceTimeline");
+
+ // We hold the lock during the entire resolve process so that we can resolve
+ // to the furthest possible value.
+ absl::MutexLock lock(&mutex_);
+
+ uint64_t past_value = signaled_value_.load();
+
+ // Fast path for when already signaled past the desired value.
+ if (past_value >= to_upper_value) return true;
+
+ // The timeline has not signaled past the desired value and there is no
+ // binary semaphore pending on GPU yet: certainly the timeline cannot
+ // advance to the desired value.
+ if (outstanding_semaphores_.empty()) return false;
+
+ IntrusiveList<TimePointSemaphore> resolved_semaphores;
+
+ bool keep_resolving = true;
+ bool reached_desired_value = false;
+ while (keep_resolving && !outstanding_semaphores_.empty()) {
+ auto* semaphore = outstanding_semaphores_.front();
+
+ // If the current semaphore is for a value beyond our upper limit, then
+ // early exit so that we don't spend time dealing with signals we don't yet
+    // care about. This can prevent livelock where one thread is signaling
+    // fences as fast as or faster than another thread can consume them.
+ if (semaphore->value > to_upper_value) {
+ keep_resolving = false;
+ reached_desired_value = true;
+ break;
+ }
+
+ // If the current semaphore is for a value not greater than the past
+ // signaled value, then we know it was signaled previously. But there might
+ // be a waiter on it on GPU.
+ if (semaphore->value <= past_value) {
+ if (semaphore->signal_fence) {
+ return InternalErrorBuilder(IREE_LOC)
+               << "Timeline should have already signaled past this time point "
+                  "and cleared the signal fence";
+ }
+
+      // If there are no waiters, we can recycle this semaphore now. If there
+      // is one waiter, query its status and recycle on success. We only
+      // handle the success status here; other statuses will be handled when
+      // the fence is checked for other semaphores' signaling status for the
+      // same queue submission.
+ if (!semaphore->wait_fence ||
+ semaphore->wait_fence->GetStatus() == VK_SUCCESS) {
+ semaphore->signal_fence = nullptr;
+ semaphore->wait_fence = nullptr;
+ outstanding_semaphores_.erase(semaphore);
+ resolved_semaphores.push_back(semaphore);
+ }
+
+ continue;
+ }
+
+    // This semaphore represents a value greater than the previously known
+    // signaled value. We don't know its status, so we need to query it now.
+
+ if (!semaphore->signal_fence) {
+ return InternalErrorBuilder(IREE_LOC)
+ << "The status of this time point in the timeline should still be "
+                "pending with a signal fence";
+ }
+ VkResult signal_status = semaphore->signal_fence->GetStatus();
+
+ switch (signal_status) {
+ case VK_SUCCESS:
+ signaled_value_.store(semaphore->value);
+ semaphore->signal_fence = nullptr;
+ // If no waiters, we can recycle this semaphore now.
+ if (!semaphore->wait_fence) {
+ semaphore->signal_fence = nullptr;
+ semaphore->wait_fence = nullptr;
+ outstanding_semaphores_.erase(semaphore);
+ resolved_semaphores.push_back(semaphore);
+ }
+ break;
+ case VK_NOT_READY:
+ // The fence has not been signaled yet so this is the furthest time
+ // point we can go in this timeline.
+ keep_resolving = false;
+ break;
+ default:
+ // Fence indicates an error (device lost, out of memory, etc).
+ // Propagate this back to our status (and thus any waiters).
+ // Since we only take the first error we find we skip all remaining
+ // fences.
+ keep_resolving = false;
+ semaphore->signal_fence = nullptr;
+ status_ = VkResultToStatus(signal_status);
+ signaled_value_.store(UINT64_MAX);
+ break;
+ }
+ }
+
+ semaphore_pool_->ReleaseResolved(&resolved_semaphores);
+ if (!status_.ok()) {
+ on_failure_(this);
+ semaphore_pool_->ReleaseUnresolved(&outstanding_semaphores_);
+ return status_;
+ }
+
+ return reached_desired_value;
+}
+
+} // namespace vulkan
+} // namespace hal
+} // namespace iree
diff --git a/iree/hal/vulkan/emulated_timeline_semaphore.h b/iree/hal/vulkan/emulated_timeline_semaphore.h
new file mode 100644
index 0000000..cc13a09
--- /dev/null
+++ b/iree/hal/vulkan/emulated_timeline_semaphore.h
@@ -0,0 +1,223 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_VULKAN_EMULATED_TIMELINE_SEMAPHORE_H_
+#define IREE_HAL_VULKAN_EMULATED_TIMELINE_SEMAPHORE_H_
+
+#include <vulkan/vulkan.h>
+
+#include <atomic>
+#include <vector>
+
+#include "absl/base/thread_annotations.h"
+#include "absl/synchronization/mutex.h"
+#include "iree/base/intrusive_list.h"
+#include "iree/base/ref_ptr.h"
+#include "iree/base/status.h"
+#include "iree/hal/semaphore.h"
+#include "iree/hal/vulkan/handle_util.h"
+#include "iree/hal/vulkan/timepoint_util.h"
+
+namespace iree {
+namespace hal {
+namespace vulkan {
+
+// A timeline semaphore emulated via `VkFence`s and binary `VkSemaphore`s.
+//
+// Vulkan provides several explicit synchronization primitives: fences,
+// (binary/timeline) semaphores, events, pipeline barriers, and render passes.
+// See "6. Synchronization and Cache Control" of the Vulkan specification
+// for the details.
+//
+// Render passes are for graphics pipelines so IREE does not care about them.
+// Pipeline barriers synchronize control within a command buffer at a single
+// point. Fences, (binary/timeline) semaphores, and events are synchronization
+// primitives that have separate signal and wait operations. Events are more
+// fine-grained than fences and semaphores: they can be signaled or waited on
+// within a command buffer, while fences and semaphores operate at queue
+// submission granularity. Each of them has its own usage requirements:
+//
+// * Fences must be signaled on GPU and waited on CPU. Fences must be reset
+// before reuse.
+// * Binary semaphores must be signaled on GPU and waited on GPU. They do not
+// support wait-before-signal submission order. More importantly, binary
+// semaphore wait also unsignals the semaphore. So binary semaphore signals
+// and waits should occur in discrete 1:1 pairs.
+// * Timeline semaphores can be signaled on CPU or GPU and waited on CPU or GPU.
+// They support wait-before-signal submission order. Timeline semaphores do
+// not need to be reset.
+//
+// Timeline semaphores are clearly more flexible than fences and binary
+// semaphores: they unify GPU and CPU synchronization with a single primitive.
+// But they are not always available: they require the
+// VK_KHR_timeline_semaphore extension or Vulkan 1.2. When unavailable, they
+// can be emulated via `VkFence`s and binary `VkSemaphore`s. The emulation
+// needs to provide the functionality of timeline semaphores without violating
+// the usage requirements of `VkFence`s and binary `VkSemaphore`s.
+//
+// The basic idea is to create a timeline object with time points to emulate the
+// timeline semaphore, which consists of a monotonically increasing 64-bit
+// integer value. Each time point represents a specific signaled/waited integer
+// value of the timeline semaphore; each time point can be associated with
+// binary `VkSemaphore`s and/or `VkFence`s for emulating the synchronization.
+//
+// Concretely, for each of the possible signal -> wait scenarios a timeline
+// semaphore supports:
+//
+// ### GPU -> GPU (via `vkQueueSubmit`)
+//
+// Each `vkQueueSubmit` can attach a `VkTimelineSemaphoreSubmitInfo` to describe
+// the timeline semaphore values signaled and waited. Each of the signaled
+// values becomes a time point emulated by a binary `VkSemaphore`, which we
+// submit to the GPU under the hood. For the waited values, the situation is
+// more complicated because of the differences between binary and timeline
+// semaphores:
+//
+// * The binary semaphore signal-wait relationship is strictly 1:1, unlike
+//   timeline semaphores where we can have 1:N cases. This means that for a
+//   specific binary `VkSemaphore` used to emulate a signaled time point, at
+//   most one subsequent `vkQueueSubmit` may wait on it. We need other
+//   mechanisms for additional waits. A simple way is to involve the CPU and
+//   not submit the additional work to the queue until the desired value has
+//   already been signaled past. This requires `VkFence`s to let the CPU know
+//   the status of GPU progress, but `VkFence`s are needed anyway for
+//   GPU -> CPU synchronization.
+// * Binary semaphores do not support wait-before-signal submission order.
+//   This means we need to put the submission into a self-managed queue if the
+//   binary semaphores used to emulate the time points waited by the submission
+//   are not submitted to the GPU yet.
+//
+// ### GPU -> CPU (via `vkWaitSemaphores`)
+//
+// Without timeline semaphores, we need to use fences to let the CPU wait on
+// GPU progress. So this direction can be emulated via `vkWaitForFences`. It
+// means we need to associate a `VkFence` with the waited timeline semaphores.
+// Because we don't know beforehand whether a particular `vkQueueSubmit` with
+// timeline semaphores will later be waited on by the CPU, we need to bundle
+// each of them with a `VkFence` just in case it will be waited on later.
+//
+// ### CPU -> GPU (via `vkSignalSemaphore`)
+//
+// This direction can be handled by bumping the signaled timeline value and
+// scanning the self-managed queue to submit more work to the GPU if possible.
+//
+// ### CPU -> CPU (via `vkWaitSemaphores`)
+//
+// This is similar to the CPU -> GPU direction; we just need to wake other
+// waiting threads on the CPU side and let them make progress.
+//
+// The implementation is inspired by the Vulkan-ExtensionLayer project:
+// https://github.com/KhronosGroup/Vulkan-ExtensionLayer. We don't handle all
+// aspects of the full spec, though, given that IREE only uses a subset of
+// synchronization primitives. So this should not be treated as a full
+// emulation of the Vulkan spec and does not substitute for
+// Vulkan-ExtensionLayer.
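+//
+// For illustration, a rough sketch of how a client may drive this emulation
+// (`device` stands for a VulkanDevice, `queue` for one of its command queues,
+// and the MakeBatchSignaling() helper is hypothetical, shown only for
+// exposition):
+//
+//   // Created via VulkanDevice::CreateSemaphore() when emulation is enabled.
+//   ASSIGN_OR_RETURN(auto semaphore,
+//                    device->CreateSemaphore(/*initial_value=*/0));
+//   // GPU signal: the queue submission internally grabs a binary VkSemaphore
+//   // standing in for time point 1 (via GetSignalSemaphore()) and bundles a
+//   // VkFence with the submission.
+//   RETURN_IF_ERROR(queue->Submit({MakeBatchSignaling(semaphore.get(), 1)}));
+//   // CPU wait: blocks on the bundled VkFence until time point 1 resolves.
+//   RETURN_IF_ERROR(semaphore->Wait(/*value=*/1, absl::InfiniteFuture()));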
+class EmulatedTimelineSemaphore final : public Semaphore {
+ public:
+ // Creates a timeline semaphore with the given |initial_value|.
+ static StatusOr<ref_ptr<Semaphore>> Create(
+ ref_ptr<VkDeviceHandle> logical_device,
+ std::function<Status(Semaphore*)> on_signal,
+ std::function<void(Semaphore*)> on_failure,
+ ref_ptr<TimePointSemaphorePool> semaphore_pool, uint64_t initial_value);
+
+ EmulatedTimelineSemaphore(ref_ptr<VkDeviceHandle> logical_device,
+ std::function<Status(Semaphore*)> on_signal,
+ std::function<void(Semaphore*)> on_failure,
+ ref_ptr<TimePointSemaphorePool> semaphore_pool,
+                            uint64_t initial_value);
+
+ ~EmulatedTimelineSemaphore() override;
+
+ StatusOr<uint64_t> Query() override;
+
+ Status Signal(uint64_t value) override;
+
+ Status Wait(uint64_t value, absl::Time deadline) override;
+
+ void Fail(Status status) override;
+
+  // Gets a binary semaphore for waiting on the timeline to advance to the given
+  // |value|. The returned semaphore won't be waited on by anyone else. Returns
+  // VK_NULL_HANDLE if there is no available semaphore for the given |value|.
+  // |wait_fence| is the fence associated with the queue submission that waits
+  // on this semaphore.
+ VkSemaphore GetWaitSemaphore(uint64_t value,
+ const ref_ptr<TimePointFence>& wait_fence);
+
+ // Cancels the waiting attempt on the given binary |semaphore|. This allows
+  // the |semaphore| to be waited on by others.
+ Status CancelWaitSemaphore(VkSemaphore semaphore);
+
+ // Gets a binary semaphore for signaling the timeline to the given |value|.
+  // |value| must be greater than the current timeline value. |signal_fence| is
+ // the fence associated with the queue submission that signals this semaphore.
+ StatusOr<VkSemaphore> GetSignalSemaphore(
+ uint64_t value, const ref_ptr<TimePointFence>& signal_fence);
+
+ private:
+ // Tries to advance the timeline to the given |to_upper_value| without
+ // blocking and returns whether the |to_upper_value| is reached.
+ StatusOr<bool> TryToAdvanceTimeline(uint64_t to_upper_value)
+ ABSL_LOCKS_EXCLUDED(mutex_);
+
+ std::atomic<uint64_t> signaled_value_;
+
+ ref_ptr<VkDeviceHandle> logical_device_;
+
+ // Callback to inform that this timeline semaphore has signaled a new value.
+ std::function<Status(Semaphore*)> on_signal_;
+
+ // Callback to inform that this timeline semaphore has encountered a failure.
+ std::function<void(Semaphore*)> on_failure_;
+
+ ref_ptr<TimePointSemaphorePool> semaphore_pool_;
+
+ mutable absl::Mutex mutex_;
+
+ // A list of outstanding semaphores used to emulate time points.
+ //
+  // The lifetime of each semaphore is in one of the following states:
+ //
+ // * Unused state: value = UINT64_MAX, signal/wait fence = nullptr. This is
+ // the state of the semaphore when it's initially acquired from the pool and
+ // not put in the queue for emulating a time point yet.
+ // * Pending state: signaled value < value < UINT64_MAX, signal fence =
+ // <some-fence>, wait fence == nullptr. This is the state of the semaphore
+ // when it's put into the GPU queue for emulating a time point.
+ // * Pending and waiting state: signaled value < value < UINT64_MAX, signal
+ // fence = <some-fence>, wait fence == <some-fence>. This is the state of
+ // the semaphore when it's put into the GPU queue for emulating a time
+ // point and there is another queue submission waiting on it in GPU.
+  // * Signaled and never waited state: value <= signaled value, signal/wait
+  //   fence = nullptr. This is the state of the semaphore when we know it's
+  //   already signaled on GPU and there are no waiters for it.
+ // * Signaled and waiting state: value <= signaled value, signal fence =
+ // nullptr, wait fence = <some-fence>. This is the state of the semaphore
+  //   when we know it's already signaled on GPU and there is still one queue
+  //   submission on GPU waiting for it.
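+  //
+  // For illustration only, the intended transitions are roughly:
+  //
+  //   (pool) --Acquire()--> unused
+  //   unused --GetSignalSemaphore()--> pending
+  //   pending --GetWaitSemaphore()--> pending and waiting
+  //   pending [and waiting] --signal fence VK_SUCCESS--> signaled [and waiting]
+  //   signaled --no waiter, or wait fence VK_SUCCESS--> recycled to the pool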
+ IntrusiveList<TimePointSemaphore> outstanding_semaphores_
+ ABSL_GUARDED_BY(mutex_);
+
+ // NOTE: We only need to access this status (and thus take the lock) when we
+ // want to either signal failure or query the status in the case of the
+ // semaphore being set to UINT64_MAX.
+ Status status_ ABSL_GUARDED_BY(mutex_);
+};
+
+} // namespace vulkan
+} // namespace hal
+} // namespace iree
+
+#endif  // IREE_HAL_VULKAN_EMULATED_TIMELINE_SEMAPHORE_H_
diff --git a/iree/hal/vulkan/serializing_command_queue.cc b/iree/hal/vulkan/serializing_command_queue.cc
new file mode 100644
index 0000000..9d6d24c
--- /dev/null
+++ b/iree/hal/vulkan/serializing_command_queue.cc
@@ -0,0 +1,355 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/hal/vulkan/serializing_command_queue.h"
+
+#include <memory>
+
+#include "absl/time/clock.h"
+#include "absl/types/span.h"
+#include "iree/base/memory.h"
+#include "iree/base/source_location.h"
+#include "iree/base/tracing.h"
+#include "iree/hal/command_buffer.h"
+#include "iree/hal/command_queue.h"
+#include "iree/hal/semaphore.h"
+#include "iree/hal/vulkan/direct_command_buffer.h"
+#include "iree/hal/vulkan/emulated_timeline_semaphore.h"
+#include "iree/hal/vulkan/status_util.h"
+
+namespace iree {
+namespace hal {
+namespace vulkan {
+
+namespace {
+
+// Tries to prepare all necessary binary `VkSemaphore`s for emulating the time
+// points specified in the given submission |batch| and returns true if the
+// |batch| is ready to be submitted to the GPU.
+// |wait_semaphores| and |signal_semaphores| will be filled with the binary
+// `VkSemaphore`s on success. |fence| is the fence associated with the
+// submission |batch|.
+StatusOr<bool> TryToPrepareSemaphores(
+ const SubmissionBatch& batch, const ref_ptr<TimePointFence>& fence,
+ absl::InlinedVector<VkSemaphore, 4>* wait_semaphores,
+ absl::InlinedVector<VkSemaphore, 4>* signal_semaphores) {
+ IREE_TRACE_SCOPE0("TryToPrepareSemaphores");
+
+ wait_semaphores->clear();
+ for (const auto& timeline_semaphore : batch.wait_semaphores) {
+ // Query first to progress this timeline semaphore to the furthest.
+ ASSIGN_OR_RETURN(auto signaled_value,
+ timeline_semaphore.semaphore->Query());
+
+ // If it's already signaled to a value greater than we require here,
+ // we can just ignore this semaphore now.
+ if (signaled_value >= timeline_semaphore.value) continue;
+
+ // SerializingCommandQueue only works with EmulatedTimelineSemaphore.
+ auto* emulated_semaphore =
+ static_cast<EmulatedTimelineSemaphore*>(timeline_semaphore.semaphore);
+
+ // Otherwise try to get a binary semaphore for this time point so that
+    // we can wait on it.
+ VkSemaphore binary_semaphore =
+ emulated_semaphore->GetWaitSemaphore(timeline_semaphore.value, fence);
+
+ if (binary_semaphore == VK_NULL_HANDLE) {
+ // We cannot wait on this time point yet: there are no previous semaphores
+ // submitted to the GPU that can signal a value greater than what's
+ // desired here.
+
+ // Cancel the wait so others may make progress.
+ for (VkSemaphore semaphore : *wait_semaphores) {
+ RETURN_IF_ERROR(emulated_semaphore->CancelWaitSemaphore(semaphore));
+ }
+
+ // This batch cannot be submitted to GPU yet.
+ return false;
+ }
+
+ wait_semaphores->push_back(binary_semaphore);
+ }
+
+ // We've collected all necessary binary semaphores for each timeline we need
+ // to wait on. Now prepare binary semaphores for signaling.
+ signal_semaphores->clear();
+ for (const auto& timeline_semaphore : batch.signal_semaphores) {
+ // SerializingCommandQueue only works with EmulatedTimelineSemaphore.
+ auto* emulated_semaphore =
+ static_cast<EmulatedTimelineSemaphore*>(timeline_semaphore.semaphore);
+
+ ASSIGN_OR_RETURN(auto binary_semaphore,
+ emulated_semaphore->GetSignalSemaphore(
+ timeline_semaphore.value, fence));
+ signal_semaphores->push_back(binary_semaphore);
+ }
+
+ // Good to submit!
+ return true;
+}
+
+// Prepares a `VkSubmitInfo` to submit the given list of |command_buffers|,
+// waiting on |wait_semaphores| and signaling |signal_semaphores|. Necessary
+// structures are allocated from |arena| and the result `VkSubmitInfo` is
+// written to |submit_info|.
+void PrepareSubmitInfo(
+ const absl::InlinedVector<VkSemaphore, 4>& wait_semaphores,
+ absl::Span<CommandBuffer* const> command_buffers,
+ const absl::InlinedVector<VkSemaphore, 4>& signal_semaphores,
+ VkSubmitInfo* submit_info, Arena* arena) {
+ IREE_TRACE_SCOPE0("PrepareSubmitInfo");
+
+ // TODO(benvanik): see if we can go to finer-grained stages.
+ // For example, if this was just queue ownership transfers then we can use
+ // the pseudo-stage of VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT.
+ VkPipelineStageFlags dst_stage_mask =
+ VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT;
+
+ auto wait_semaphore_handles =
+ arena->AllocateSpan<VkSemaphore>(wait_semaphores.size());
+ auto wait_dst_stage_masks =
+ arena->AllocateSpan<VkPipelineStageFlags>(wait_semaphores.size());
+ for (int i = 0, e = wait_semaphores.size(); i < e; ++i) {
+ wait_semaphore_handles[i] = wait_semaphores[i];
+ wait_dst_stage_masks[i] = dst_stage_mask;
+ }
+
+ auto signal_semaphore_handles =
+ arena->AllocateSpan<VkSemaphore>(signal_semaphores.size());
+ for (int i = 0, e = signal_semaphores.size(); i < e; ++i) {
+ signal_semaphore_handles[i] = signal_semaphores[i];
+ }
+
+ auto command_buffer_handles =
+ arena->AllocateSpan<VkCommandBuffer>(command_buffers.size());
+ for (int i = 0, e = command_buffers.size(); i < e; ++i) {
+ const auto& command_buffer = command_buffers[i];
+ auto* direct_command_buffer =
+ static_cast<DirectCommandBuffer*>(command_buffer->impl());
+ command_buffer_handles[i] = direct_command_buffer->handle();
+ }
+
+ submit_info->sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
+ submit_info->pNext = nullptr;
+ submit_info->waitSemaphoreCount = wait_semaphore_handles.size();
+ submit_info->pWaitSemaphores = wait_semaphore_handles.data();
+ submit_info->pWaitDstStageMask = wait_dst_stage_masks.data();
+ submit_info->commandBufferCount = command_buffer_handles.size();
+ submit_info->pCommandBuffers = command_buffer_handles.data();
+ submit_info->signalSemaphoreCount = signal_semaphore_handles.size();
+ submit_info->pSignalSemaphores = signal_semaphore_handles.data();
+}
+
+} // namespace
+
+SerializingCommandQueue::SerializingCommandQueue(
+ std::string name, CommandCategoryBitfield supported_categories,
+ const ref_ptr<VkDeviceHandle>& logical_device,
+ const ref_ptr<TimePointFencePool>& fence_pool, VkQueue queue)
+ : CommandQueue(std::move(name), supported_categories),
+ logical_device_(add_ref(logical_device)),
+ fence_pool_(add_ref(fence_pool)),
+ queue_(queue) {}
+
+SerializingCommandQueue::~SerializingCommandQueue() {
+ IREE_TRACE_SCOPE0("SerializingCommandQueue::dtor");
+ absl::MutexLock lock(&mutex_);
+ syms()->vkQueueWaitIdle(queue_);
+}
+
+Status SerializingCommandQueue::Submit(
+ absl::Span<const SubmissionBatch> batches) {
+ IREE_TRACE_SCOPE0("SerializingCommandQueue::Submit");
+
+ absl::MutexLock lock(&mutex_);
+ for (const auto& batch : batches) {
+ // Grab a fence for this submission first. This will be used to check the
+ // progress of emulated timeline semaphores later.
+ ASSIGN_OR_RETURN(auto fence, fence_pool_->Acquire());
+ deferred_submissions_.push_back(
+ std::make_unique<FencedSubmission>(batch, std::move(fence)));
+ }
+
+ return ProcessDeferredSubmissions().status();
+}
+
+StatusOr<bool> SerializingCommandQueue::ProcessDeferredSubmissions() {
+ IREE_TRACE_SCOPE0("SerializingCommandQueue::ProcessDeferredSubmissions");
+
+ // Prepare `VkSubmitInfo`s for all submissions we are able to submit.
+
+  // Note that we must keep all referenced arrays alive until submission
+  // completes, and since there are a bunch of them we use an arena.
+ Arena arena(4 * 1024);
+
+ absl::InlinedVector<VkSubmitInfo, 4> submit_infos;
+ absl::InlinedVector<VkFence, 4> submit_fences;
+
+ absl::InlinedVector<VkSemaphore, 4> wait_semaphores;
+ absl::InlinedVector<VkSemaphore, 4> signal_semaphores;
+
+  // A list of submissions that still need to be deferred.
+ IntrusiveList<std::unique_ptr<FencedSubmission>> remaining_submissions;
+
+ while (!deferred_submissions_.empty()) {
+ wait_semaphores.clear();
+ signal_semaphores.clear();
+
+ auto submission = deferred_submissions_.take(deferred_submissions_.front());
+ const SubmissionBatch& batch = submission->batch;
+ ref_ptr<TimePointFence> fence(std::move(submission->fence));
+
+ ASSIGN_OR_RETURN(bool ready_to_submit,
+ TryToPrepareSemaphores(batch, fence, &wait_semaphores,
+ &signal_semaphores));
+
+ if (ready_to_submit) {
+ submit_infos.emplace_back();
+ PrepareSubmitInfo(wait_semaphores, batch.command_buffers,
+ signal_semaphores, &submit_infos.back(), &arena);
+ submit_fences.push_back(fence->value());
+ pending_fences_.emplace_back(std::move(fence));
+ } else {
+ // We need to defer the submission until later.
+ remaining_submissions.push_back(std::move(submission));
+ }
+ }
+
+ if (submit_infos.empty()) return false;
+
+ auto infos = arena.AllocateSpan<VkSubmitInfo>(submit_infos.size());
+ for (int i = 0, e = submit_infos.size(); i < e; ++i) {
+ infos[i] = submit_infos[i];
+ }
+
+ // Note: We might be able to batch the submission but it involves non-trivial
+ // fence handling. We can handle that if really needed.
+ for (int i = 0, e = submit_infos.size(); i < e; ++i) {
+ VK_RETURN_IF_ERROR(syms()->vkQueueSubmit(
+ queue_, /*submitCount=*/1, &submit_infos[i], submit_fences[i]));
+ }
+
+ while (!remaining_submissions.empty()) {
+ deferred_submissions_.push_back(
+ remaining_submissions.take(remaining_submissions.front()));
+ }
+
+ return true;
+}
+
+Status SerializingCommandQueue::WaitIdle(absl::Time deadline) {
+ absl::MutexLock lock(&mutex_);
+
+ if (deadline == absl::InfiniteFuture()) {
+ IREE_TRACE_SCOPE0("SerializingCommandQueue::WaitIdle#vkQueueWaitIdle");
+ // Fast path for using vkQueueWaitIdle, which is usually cheaper (as it
+ // requires fewer calls into the driver).
+
+ // Complete all pending work on the queue.
+ VK_RETURN_IF_ERROR(syms()->vkQueueWaitIdle(queue_));
+ pending_fences_.clear();
+
+ // Submit and complete all deferred work.
+ while (!deferred_submissions_.empty()) {
+ ASSIGN_OR_RETURN(bool work_submitted, ProcessDeferredSubmissions());
+ if (work_submitted) {
+ VK_RETURN_IF_ERROR(syms()->vkQueueWaitIdle(queue_));
+ pending_fences_.clear();
+ }
+ }
+
+ return OkStatus();
+ }
+
+ IREE_TRACE_SCOPE0("SerializingCommandQueue::WaitIdle#Fence");
+
+ // Keep trying to submit more workload to the GPU until reaching the deadline.
+ do {
+ RETURN_IF_ERROR(ProcessDeferredSubmissions().status());
+
+ uint64_t timeout_nanos;
+ if (deadline == absl::InfinitePast()) {
+ // Do not wait.
+ timeout_nanos = 0;
+ } else {
+ // Convert to relative time in nanoseconds.
+ // The implementation may not wait with this granularity (like, by
+ // 10000x).
+ absl::Time now = absl::Now();
+ if (deadline < now) {
+ return DeadlineExceededErrorBuilder(IREE_LOC)
+ << "Deadline exceeded waiting for idle";
+ }
+ timeout_nanos =
+ static_cast<uint64_t>(absl::ToInt64Nanoseconds(deadline - now));
+ }
+
+ if (pending_fences_.empty()) continue;
+
+ std::vector<VkFence> fences;
+ fences.reserve(pending_fences_.size());
+ for (const auto& fence : pending_fences_) fences.push_back(fence->value());
+
+ VkResult result =
+ syms()->vkWaitForFences(*logical_device_, fences.size(), fences.data(),
+ /*waitAll=*/VK_TRUE, timeout_nanos);
+
+ switch (result) {
+ case VK_SUCCESS:
+ pending_fences_.clear();
+ break;
+ case VK_TIMEOUT:
+ return DeadlineExceededErrorBuilder(IREE_LOC)
+ << "Deadline exceeded waiting for idle";
+ default:
+ return VkResultToStatus(result);
+ }
+ // As long as there is submitted or deferred work still pending.
+ } while (!pending_fences_.empty() || !deferred_submissions_.empty());
+
+ return OkStatus();
+}
+
+Status SerializingCommandQueue::AdvanceQueueSubmission() {
+ absl::MutexLock lock(&mutex_);
+  // The returned value just indicates whether any newly ready submissions
+  // were submitted to the GPU. Other callers might be interested in that
+  // information, but for this API we just want to advance queue submission
+  // if possible. So we ignore it here.
+ ASSIGN_OR_RETURN(std::ignore, ProcessDeferredSubmissions());
+ return OkStatus();
+}
+
+void SerializingCommandQueue::AbortQueueSubmission() {
+ absl::MutexLock lock(&mutex_);
+
+  // We have fences in deferred_submissions_ but they are not submitted to the
+  // GPU yet, so we don't need to reset them.
+ deferred_submissions_.clear();
+
+ std::vector<VkFence> fences;
+ fences.reserve(pending_fences_.size());
+ for (const auto& fence : pending_fences_) fences.push_back(fence->value());
+
+ syms()->vkWaitForFences(*logical_device_, fences.size(), fences.data(),
+ /*waitAll=*/VK_TRUE, /*timeout=*/UINT64_MAX);
+  // Clear the list. Fences will be automatically returned back to the pool
+  // when their refcount reaches 0.
+ pending_fences_.clear();
+}
+
+} // namespace vulkan
+} // namespace hal
+} // namespace iree
diff --git a/iree/hal/vulkan/serializing_command_queue.h b/iree/hal/vulkan/serializing_command_queue.h
new file mode 100644
index 0000000..e38643b
--- /dev/null
+++ b/iree/hal/vulkan/serializing_command_queue.h
@@ -0,0 +1,111 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_VULKAN_SERIALIZING_COMMAND_QUEUE_H_
+#define IREE_HAL_VULKAN_SERIALIZING_COMMAND_QUEUE_H_
+
+#include <vulkan/vulkan.h>
+
+#include <memory>
+#include <string>
+
+#include "absl/base/thread_annotations.h"
+#include "absl/container/inlined_vector.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/time/time.h"
+#include "iree/base/intrusive_list.h"
+#include "iree/base/ref_ptr.h"
+#include "iree/base/status.h"
+#include "iree/hal/command_buffer.h"
+#include "iree/hal/command_queue.h"
+#include "iree/hal/vulkan/dynamic_symbols.h"
+#include "iree/hal/vulkan/handle_util.h"
+#include "iree/hal/vulkan/timepoint_util.h"
+
+namespace iree {
+namespace hal {
+namespace vulkan {
+
+// A command queue that potentially defers and serializes command buffer
+// submission to the GPU.
+//
+// This command queue is designed to be used together with emulated timeline
+// semaphores. Timeline semaphores can follow wait-before-signal submission
+// order but binary `VkSemaphore`s cannot. So when emulating timeline
+// semaphores with binary `VkSemaphore`s and `VkFence`s, we need to make sure
+// no wait-before-signal submission order occurs for binary `VkSemaphore`s.
+// The way to enforce that is to defer the submission until we can be certain
+// that the `VkSemaphore`s emulating time points in the timeline are all
+// *submitted* to the GPU.
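+//
+// For illustration, a rough sketch of the deferral flow (simplified; the names
+// match the implementation in serializing_command_queue.cc):
+//
+//   Submit(batches):
+//     for each batch: acquire a TimePointFence and append to the deferred list
+//     ProcessDeferredSubmissions():
+//       for each deferred batch:
+//         if binary semaphores for all waited time points are available:
+//           vkQueueSubmit() the batch and track its fence in pending_fences_
+//         else:
+//           keep the batch deferred until a later signal triggers
+//           AdvanceQueueSubmission()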
+class SerializingCommandQueue final : public CommandQueue {
+ public:
+ SerializingCommandQueue(std::string name,
+ CommandCategoryBitfield supported_categories,
+ const ref_ptr<VkDeviceHandle>& logical_device,
+ const ref_ptr<TimePointFencePool>& fence_pool,
+ VkQueue queue);
+ ~SerializingCommandQueue() override;
+
+ const ref_ptr<DynamicSymbols>& syms() const {
+ return logical_device_->syms();
+ }
+
+ Status Submit(absl::Span<const SubmissionBatch> batches) override;
+
+ Status WaitIdle(absl::Time deadline) override;
+
+  // Submits all deferred submissions that are ready for the GPU.
+ Status AdvanceQueueSubmission();
+
+ // Aborts all deferred submissions and waits for submitted work to complete.
+ void AbortQueueSubmission();
+
+ private:
+  // A submission batch together with the fence to signal its status.
+ struct FencedSubmission : IntrusiveLinkBase<void> {
+ SubmissionBatch batch;
+ ref_ptr<TimePointFence> fence;
+
+ FencedSubmission(const SubmissionBatch& batch,
+ ref_ptr<TimePointFence> fence)
+ : batch(batch), fence(std::move(fence)) {}
+ };
+
+  // Processes deferred submissions in this queue and returns whether any new
+  // workload was submitted to the GPU if no errors occurred.
+ StatusOr<bool> ProcessDeferredSubmissions()
+ ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+
+ ref_ptr<VkDeviceHandle> logical_device_;
+
+ ref_ptr<TimePointFencePool> fence_pool_;
+
+ mutable absl::Mutex mutex_;
+
+ // A list of fences that are submitted to GPU.
+ absl::InlinedVector<ref_ptr<TimePointFence>, 4> pending_fences_
+ ABSL_GUARDED_BY(mutex_);
+ // A list of deferred submissions that haven't been submitted to GPU.
+ IntrusiveList<std::unique_ptr<FencedSubmission>> deferred_submissions_
+ ABSL_GUARDED_BY(mutex_);
+
+ // VkQueue needs to be externally synchronized.
+ VkQueue queue_ ABSL_GUARDED_BY(mutex_);
+};
+
+} // namespace vulkan
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_VULKAN_SERIALIZING_COMMAND_QUEUE_H_
diff --git a/iree/hal/vulkan/timepoint_util.cc b/iree/hal/vulkan/timepoint_util.cc
new file mode 100644
index 0000000..c212856
--- /dev/null
+++ b/iree/hal/vulkan/timepoint_util.cc
@@ -0,0 +1,226 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/hal/vulkan/timepoint_util.h"
+
+#include <memory>
+
+#include "absl/synchronization/mutex.h"
+#include "absl/time/time.h"
+#include "absl/utility/utility.h"
+#include "iree/base/tracing.h"
+#include "iree/hal/vulkan/dynamic_symbols.h"
+#include "iree/hal/vulkan/status_util.h"
+
+namespace iree {
+namespace hal {
+namespace vulkan {
+
+// static
+void TimePointFence::Delete(TimePointFence* ptr) {
+ ptr->pool()->ReleaseResolved(ptr);
+}
+
+VkResult TimePointFence::GetStatus() {
+ absl::MutexLock lock(&status_mutex_);
+ if (status_ == VK_NOT_READY) {
+ const auto& device = pool()->logical_device();
+ status_ = device->syms()->vkGetFenceStatus(*device, fence_);
+ }
+ return status_;
+}
+
+// static
+StatusOr<ref_ptr<TimePointFencePool>> TimePointFencePool::Create(
+ ref_ptr<VkDeviceHandle> logical_device) {
+ IREE_TRACE_SCOPE0("TimePointFencePool::Create");
+ ref_ptr<TimePointFencePool> pool(
+ new TimePointFencePool(std::move(logical_device)));
+ RETURN_IF_ERROR(pool->PreallocateFences());
+ return pool;
+}
+
+TimePointFencePool::~TimePointFencePool() {
+ IREE_TRACE_SCOPE0("TimePointFencePool::dtor");
+
+ absl::MutexLock lock(&mutex_);
+ int free_count = 0;
+ for (auto* fence : free_fences_) {
+ syms()->vkDestroyFence(*logical_device_, fence->value(),
+ logical_device_->allocator());
+ ++free_count;
+ }
+ DCHECK_EQ(free_count, kMaxInFlightFenceCount);
+ free_fences_.clear();
+}
+
+StatusOr<ref_ptr<TimePointFence>> TimePointFencePool::Acquire() {
+ IREE_TRACE_SCOPE0("TimePointFencePool::Acquire");
+
+ absl::MutexLock lock(&mutex_);
+ if (free_fences_.empty()) {
+ return ResourceExhaustedErrorBuilder(IREE_LOC)
+ << "Fence pool out of free fences";
+ }
+
+ auto* fence = free_fences_.front();
+ free_fences_.pop_front();
+ return add_ref(fence);
+}
+
+void TimePointFencePool::ReleaseResolved(TimePointFence* fence) {
+ IREE_TRACE_SCOPE0("TimePointFencePool::ReleaseResolved");
+ VkFence f = fence->value();
+ syms()->vkResetFences(*logical_device_, 1, &f);
+ absl::MutexLock lock(&mutex_);
+ free_fences_.push_back(fence);
+}
+
+TimePointFencePool::TimePointFencePool(ref_ptr<VkDeviceHandle> logical_device)
+ : logical_device_(std::move(logical_device)) {}
+
+const ref_ptr<DynamicSymbols>& TimePointFencePool::syms() const {
+ return logical_device_->syms();
+}
+
+Status TimePointFencePool::PreallocateFences() {
+ IREE_TRACE_SCOPE0("TimePointFencePool::PreallocateFences");
+
+ VkFenceCreateInfo create_info;
+ create_info.sType = VK_STRUCTURE_TYPE_FENCE_CREATE_INFO;
+ create_info.pNext = nullptr;
+ create_info.flags = 0;
+
+ std::array<std::unique_ptr<TimePointFence>, kMaxInFlightFenceCount> fences;
+ {
+ absl::MutexLock lock(&mutex_);
+ for (int i = 0; i < fences.size(); ++i) {
+ VkFence fence = VK_NULL_HANDLE;
+ VK_RETURN_IF_ERROR(syms()->vkCreateFence(*logical_device_, &create_info,
+ logical_device_->allocator(),
+ &fence));
+ fences[i].reset(new TimePointFence(this, fence));
+ }
+ }
+
+ for (int i = 0; i < fences.size(); ++i) {
+    // The `TimePointFence`s were created with an initial ref-count of one.
+    // Decrease it explicitly to zero so that later we can rely on the ref-count
+ // reaching zero to auto-release the `TimePointFence` back to the free
+ // list. As a nice side effect, this will also initialize the free list
+ // with all newly created fences.
+ // TODO: Might want to avoid acquiring and releasing the mutex for each
+ // fence.
+ fences[i].release()->ReleaseReference();
+ }
+
+ return OkStatus();
+}
+
+// static
+StatusOr<ref_ptr<TimePointSemaphorePool>> TimePointSemaphorePool::Create(
+ ref_ptr<VkDeviceHandle> logical_device) {
+ IREE_TRACE_SCOPE0("TimePointSemaphorePool::Create");
+ ref_ptr<TimePointSemaphorePool> pool(
+ new TimePointSemaphorePool(std::move(logical_device)));
+ RETURN_IF_ERROR(pool->PreallocateSemaphores());
+ return pool;
+}
+
+TimePointSemaphorePool::~TimePointSemaphorePool() {
+ IREE_TRACE_SCOPE0("TimePointSemaphorePool::dtor");
+
+ absl::MutexLock lock(&mutex_);
+
+ DCHECK_EQ(free_semaphores_.size(), kMaxInFlightSemaphoreCount);
+ free_semaphores_.clear();
+
+ for (auto& semaphore : storage_) {
+ syms()->vkDestroySemaphore(*logical_device_, semaphore.semaphore,
+ logical_device_->allocator());
+ }
+}
+
+StatusOr<TimePointSemaphore*> TimePointSemaphorePool::Acquire() {
+ IREE_TRACE_SCOPE0("TimePointSemaphorePool::Acquire");
+
+ absl::MutexLock lock(&mutex_);
+ if (free_semaphores_.empty()) {
+ return ResourceExhaustedErrorBuilder(IREE_LOC)
+ << "Semaphore pool out of free semaphores";
+ }
+
+ auto* semaphore = free_semaphores_.front();
+ free_semaphores_.pop_front();
+ return semaphore;
+}
+
+void TimePointSemaphorePool::ReleaseResolved(
+ IntrusiveList<TimePointSemaphore>* semaphores) {
+ IREE_TRACE_SCOPE0("TimePointSemaphorePool::ReleaseResolved");
+
+ for (auto* semaphore : *semaphores) {
+ DCHECK(!semaphore->signal_fence && !semaphore->wait_fence);
+ semaphore->value = UINT64_MAX;
+ }
+
+ absl::MutexLock lock(&mutex_);
+ free_semaphores_.merge_from(semaphores);
+}
+
+void TimePointSemaphorePool::ReleaseUnresolved(
+ IntrusiveList<TimePointSemaphore>* semaphores) {
+ IREE_TRACE_SCOPE0("TimePointSemaphorePool::ReleaseUnresolved");
+
+ for (auto* semaphore : *semaphores) {
+ semaphore->signal_fence = nullptr;
+ semaphore->wait_fence = nullptr;
+ semaphore->value = UINT64_MAX;
+ }
+
+ absl::MutexLock lock(&mutex_);
+ free_semaphores_.merge_from(semaphores);
+}
+
+TimePointSemaphorePool::TimePointSemaphorePool(
+ ref_ptr<VkDeviceHandle> logical_device)
+ : logical_device_(std::move(logical_device)) {}
+
+const ref_ptr<DynamicSymbols>& TimePointSemaphorePool::syms() const {
+ return logical_device_->syms();
+}
+
+Status TimePointSemaphorePool::PreallocateSemaphores() {
+ IREE_TRACE_SCOPE0("TimePointSemaphorePool::PreallocateSemaphores");
+
+ VkSemaphoreCreateInfo create_info;
+ create_info.sType = VK_STRUCTURE_TYPE_SEMAPHORE_CREATE_INFO;
+ create_info.pNext = nullptr;
+ create_info.flags = 0;
+
+ absl::MutexLock lock(&mutex_);
+ for (int i = 0; i < kMaxInFlightSemaphoreCount; ++i) {
+ auto* semaphore = &storage_[i];
+ VK_RETURN_IF_ERROR(syms()->vkCreateSemaphore(*logical_device_, &create_info,
+ logical_device_->allocator(),
+ &semaphore->semaphore));
+ free_semaphores_.push_back(semaphore);
+ }
+
+ return OkStatus();
+}
+
+} // namespace vulkan
+} // namespace hal
+} // namespace iree
diff --git a/iree/hal/vulkan/timepoint_util.h b/iree/hal/vulkan/timepoint_util.h
new file mode 100644
index 0000000..e2cb7df
--- /dev/null
+++ b/iree/hal/vulkan/timepoint_util.h
@@ -0,0 +1,210 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_VULKAN_TIMEPOINT_UTIL_H_
+#define IREE_HAL_VULKAN_TIMEPOINT_UTIL_H_
+
+#include <vulkan/vulkan.h>
+
+#include <atomic>
+#include <vector>
+
+#include "absl/base/thread_annotations.h"
+#include "absl/synchronization/mutex.h"
+#include "iree/base/intrusive_list.h"
+#include "iree/base/ref_ptr.h"
+#include "iree/base/status.h"
+#include "iree/hal/vulkan/handle_util.h"
+
+namespace iree {
+namespace hal {
+namespace vulkan {
+
+class TimePointFencePool;
+class TimePointSemaphorePool;
+
+// A fence used for tracking progress of timeline semaphores.
+//
+// Each queue submission gets a new `VkFence` associated with it so that we can
+// later query the `VkFence` on CPU to know what time points were signaled for
+// timeline semaphores.
+//
+// Ref-counting allows the fence to be associated with multiple time points from
+// different timelines without worrying about ownership complexity.
+//
+// This is expected to be used together with `TimePointFencePool` and must be
+// externally synchronized via `TimePointFencePool`'s mutex.
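+//
+// For illustration, a typical flow (simplified; `fence_pool` is assumed to be
+// a TimePointFencePool created elsewhere):
+//
+//   ASSIGN_OR_RETURN(ref_ptr<TimePointFence> fence, fence_pool->Acquire());
+//   // The raw VkFence (fence->value()) goes along with one queue submission.
+//   // Later, polling tells us whether that submission has finished:
+//   if (fence->GetStatus() == VK_SUCCESS) { /* resolve its time points */ }
+//   // When the last reference is released the fence is reset and returned to
+//   // the pool via TimePointFence::Delete().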
+class TimePointFence final : public RefObject<TimePointFence>,
+ public IntrusiveLinkBase<void> {
+ public:
+ TimePointFence(TimePointFencePool* pool, VkFence fence)
+ : pool_(pool), fence_(fence) {}
+
+ TimePointFence(TimePointFence&& that) = delete;
+ TimePointFence& operator=(TimePointFence&&) = delete;
+
+ TimePointFence(const TimePointFence&) = delete;
+ TimePointFence& operator=(const TimePointFence&) = delete;
+
+ // Returns this fence to the pool on destruction.
+ static void Delete(TimePointFence* ptr);
+
+ VkFence value() const noexcept { return fence_; }
+ operator VkFence() const noexcept { return fence_; }
+
+  // Gets the status of this fence object. This might issue a Vulkan API call
+ // under the hood.
+ VkResult GetStatus();
+
+ // Returns the pool from which this fence comes.
+ TimePointFencePool* pool() const { return pool_; }
+
+ private:
+ // The pool from which this fence comes.
+ TimePointFencePool* pool_;
+
+  // Allocated fence that is associated with one or more time points of one or
+  // more timelines. This is passed to queue submission so that we can track
+  // timeline progress on the CPU and schedule work.
+ VkFence fence_;
+
+ // The fence's status.
+ absl::Mutex status_mutex_;
+ VkResult status_ ABSL_GUARDED_BY(status_mutex_) = VK_NOT_READY;
+};
+
+// A semaphore used for emulating a specific time point of timeline semaphores.
+//
+// Each signaled time point in a timeline semaphore is emulated with a new
+// binary `VkSemaphore` associated with queue submission. These time point
+// semaphores are stored in `EmulatedTimelineSemaphore` to quickly scan and
+// process signaled values.
+//
+// This is expected to be used together with `TimePointSemaphorePool` and
+// `EmulatedTimelineSemaphore` and must be externally synchronized via their
+// mutexes.
+struct TimePointSemaphore final : public IntrusiveLinkBase<void> {
+ // Allocated binary semaphore that represents a time point in the timeline.
+ // This is passed to queue submission.
+ VkSemaphore semaphore = VK_NULL_HANDLE;
+
+  // The value the timeline should be at when the binary semaphore is signaled.
+ uint64_t value = UINT64_MAX;
+
+ // The fence associated with the queue submission signaling this semaphore.
+ // nullptr means this binary semaphore has not been submitted to GPU.
+ ref_ptr<TimePointFence> signal_fence = nullptr;
+
+  // The fence associated with the queue submission waiting on this semaphore.
+  // nullptr means this binary semaphore has not been waited on by any queue
+  // submission.
+ ref_ptr<TimePointFence> wait_fence = nullptr;
+};
+
+// A pool of `VkFence`s that can be used by `EmulatedTimelineSemaphore` to track
+// timeline progress on CPU. Each `VkFence` can be used to query the status of
+// all the semaphores in the same submission to a `VkQueue`.
+class TimePointFencePool final : public RefObject<TimePointFencePool> {
+ public:
+ static constexpr int kMaxInFlightFenceCount = 32;
+
+ // Creates a new pool and pre-allocates `kMaxInFlightFenceCount` fences.
+ static StatusOr<ref_ptr<TimePointFencePool>> Create(
+ ref_ptr<VkDeviceHandle> logical_device);
+
+ ~TimePointFencePool();
+
+ // Acquires a fence from the pool for use by the caller. The fence is
+ // guaranteed to be in unsignaled state and not in-flight on GPU.
+ //
+ // Returns RESOURCE_EXHAUSTED if the pool has no more available fences.
+ // Callers are expected to handle this by waiting on previous fences or for
+ // complete device idle. Yes, that's as bad as it sounds, and if we start
+ // seeing that we should bump up the max count.
+ StatusOr<ref_ptr<TimePointFence>> Acquire();
+
+ // Releases one fence back to the pool. The fence must either be signaled or
+ // not be in flight on GPU.
+ void ReleaseResolved(TimePointFence* fence);
+
+ const ref_ptr<VkDeviceHandle>& logical_device() const {
+ return logical_device_;
+ }
+
+ private:
+ explicit TimePointFencePool(ref_ptr<VkDeviceHandle> logical_device);
+
+ const ref_ptr<DynamicSymbols>& syms() const;
+
+ Status PreallocateFences() ABSL_LOCKS_EXCLUDED(mutex_);
+
+ ref_ptr<VkDeviceHandle> logical_device_;
+
+ absl::Mutex mutex_;
+
+ IntrusiveList<TimePointFence> free_fences_ ABSL_GUARDED_BY(mutex_);
+};
+
+// A pool of `VkSemaphore`s that can be used by `EmulatedTimelineSemaphore` to
+// simulate individual timeline value signaling.
+class TimePointSemaphorePool final : public RefObject<TimePointSemaphorePool> {
+ public:
+ static constexpr int kMaxInFlightSemaphoreCount = 64;
+
+ // Creates a new pool and pre-allocates `kMaxInFlightSemaphoreCount` binary
+ // semaphores.
+ static StatusOr<ref_ptr<TimePointSemaphorePool>> Create(
+ ref_ptr<VkDeviceHandle> logical_device);
+
+ ~TimePointSemaphorePool();
+
+ // Acquires a binary semaphore from the pool for use by the caller. The
+ // semaphore is guaranteed to be in unsignaled state and not in-flight on GPU.
+ //
+ // Returns RESOURCE_EXHAUSTED if the pool has no more available semaphores.
+ // Callers are expected to handle this by waiting on previous fences or for
+ // complete device idle. Yes, that's as bad as it sounds, and if we start
+ // seeing that we should bump up the max count.
+ StatusOr<TimePointSemaphore*> Acquire();
+
+  // Releases one or more semaphores back to the pool. The binary semaphores
+  // must be unsignaled and not in flight on the GPU.
+ void ReleaseResolved(IntrusiveList<TimePointSemaphore>* semaphores);
+
+ // Releases one or more semaphores back to the pool. These may be in any state
+ // and will be assumed as untouchable; the pool will unconditionally recycle
+ // them.
+ void ReleaseUnresolved(IntrusiveList<TimePointSemaphore>* semaphores);
+
+ private:
+ explicit TimePointSemaphorePool(ref_ptr<VkDeviceHandle> logical_device);
+
+ const ref_ptr<DynamicSymbols>& syms() const;
+
+ Status PreallocateSemaphores() ABSL_LOCKS_EXCLUDED(mutex_);
+
+ ref_ptr<VkDeviceHandle> logical_device_;
+
+ absl::Mutex mutex_;
+
+ std::array<TimePointSemaphore, kMaxInFlightSemaphoreCount> storage_
+ ABSL_GUARDED_BY(mutex_);
+ IntrusiveList<TimePointSemaphore> free_semaphores_ ABSL_GUARDED_BY(mutex_);
+};
+
+} // namespace vulkan
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_VULKAN_TIMEPOINT_UTIL_H_
diff --git a/iree/hal/vulkan/vulkan_device.cc b/iree/hal/vulkan/vulkan_device.cc
index 3c7e37b..2f2d16e 100644
--- a/iree/hal/vulkan/vulkan_device.cc
+++ b/iree/hal/vulkan/vulkan_device.cc
@@ -30,12 +30,14 @@
#include "iree/hal/vulkan/direct_command_buffer.h"
#include "iree/hal/vulkan/direct_command_queue.h"
#include "iree/hal/vulkan/dynamic_symbols.h"
+#include "iree/hal/vulkan/emulated_timeline_semaphore.h"
#include "iree/hal/vulkan/extensibility_util.h"
#include "iree/hal/vulkan/native_descriptor_set.h"
#include "iree/hal/vulkan/native_event.h"
#include "iree/hal/vulkan/native_timeline_semaphore.h"
#include "iree/hal/vulkan/pipeline_cache.h"
#include "iree/hal/vulkan/pipeline_executable_layout.h"
+#include "iree/hal/vulkan/serializing_command_queue.h"
#include "iree/hal/vulkan/status_util.h"
#include "iree/hal/vulkan/vma_allocator.h"
@@ -164,6 +166,7 @@
const DeviceInfo& device_info,
const ref_ptr<VkDeviceHandle>& logical_device,
const QueueSet& compute_queue_set, const QueueSet& transfer_queue_set,
+ const ref_ptr<TimePointFencePool>& fence_pool,
const ref_ptr<DynamicSymbols>& syms) {
absl::InlinedVector<std::unique_ptr<CommandQueue>, 4> command_queues;
@@ -175,10 +178,17 @@
syms->vkGetDeviceQueue(*logical_device,
compute_queue_set.queue_family_index, i, &queue);
std::string queue_name = absl::StrCat(device_info.name(), ":d", i);
+#if IREE_HAL_VULKAN_EMULATE_TIMELINE_SEMAPHORES
+ command_queues.push_back(absl::make_unique<SerializingCommandQueue>(
+ std::move(queue_name),
+ CommandCategory::kDispatch | CommandCategory::kTransfer, logical_device,
+ fence_pool, queue));
+#else
command_queues.push_back(absl::make_unique<DirectCommandQueue>(
std::move(queue_name),
CommandCategory::kDispatch | CommandCategory::kTransfer, logical_device,
queue));
+#endif // IREE_HAL_VULKAN_EMULATE_TIMELINE_SEMAPHORES
}
uint64_t transfer_queue_count = CountOnes64(transfer_queue_set.queue_indices);
@@ -189,9 +199,15 @@
syms->vkGetDeviceQueue(*logical_device,
transfer_queue_set.queue_family_index, i, &queue);
std::string queue_name = absl::StrCat(device_info.name(), ":t", i);
+#if IREE_HAL_VULKAN_EMULATE_TIMELINE_SEMAPHORES
+ command_queues.push_back(absl::make_unique<SerializingCommandQueue>(
+ std::move(queue_name), CommandCategory::kTransfer, logical_device,
+ fence_pool, queue));
+#else
command_queues.push_back(absl::make_unique<DirectCommandQueue>(
std::move(queue_name), CommandCategory::kTransfer, logical_device,
queue));
+#endif // IREE_HAL_VULKAN_EMULATE_TIMELINE_SEMAPHORES
}
return command_queues;
@@ -354,14 +370,27 @@
for (uint32_t i = 0; i < queue_family_info.transfer_queue_count; ++i) {
transfer_queue_set.queue_indices |= 1 << (i + base_queue_index);
}
- auto command_queues = CreateCommandQueues(
- device_info, logical_device, compute_queue_set, transfer_queue_set, syms);
+
+#if IREE_HAL_VULKAN_EMULATE_TIMELINE_SEMAPHORES
+ ASSIGN_OR_RETURN(auto semaphore_pool,
+ TimePointSemaphorePool::Create(add_ref(logical_device)));
+ ASSIGN_OR_RETURN(auto fence_pool,
+ TimePointFencePool::Create(add_ref(logical_device)));
+#else
+ ref_ptr<TimePointSemaphorePool> semaphore_pool = nullptr;
+ ref_ptr<TimePointFencePool> fence_pool = nullptr;
+#endif // IREE_HAL_VULKAN_EMULATE_TIMELINE_SEMAPHORES
+
+ auto command_queues =
+ CreateCommandQueues(device_info, logical_device, compute_queue_set,
+ transfer_queue_set, fence_pool, syms);
return assign_ref(new VulkanDevice(
std::move(driver), device_info, physical_device,
std::move(logical_device), std::move(allocator),
std::move(command_queues), std::move(dispatch_command_pool),
- std::move(transfer_command_pool), debug_capture_manager));
+ std::move(transfer_command_pool), std::move(semaphore_pool),
+ std::move(fence_pool), debug_capture_manager));
}
// static
@@ -421,13 +450,25 @@
device_handle, transfer_queue_set.queue_family_index));
}
- auto command_queues = CreateCommandQueues(
- device_info, device_handle, compute_queue_set, transfer_queue_set, syms);
+#if IREE_HAL_VULKAN_EMULATE_TIMELINE_SEMAPHORES
+ ASSIGN_OR_RETURN(auto semaphore_pool,
+ TimePointSemaphorePool::Create(add_ref(device_handle)));
+ ASSIGN_OR_RETURN(auto fence_pool,
+ TimePointFencePool::Create(add_ref(device_handle)));
+#else
+ ref_ptr<TimePointSemaphorePool> semaphore_pool = nullptr;
+ ref_ptr<TimePointFencePool> fence_pool = nullptr;
+#endif // IREE_HAL_VULKAN_EMULATE_TIMELINE_SEMAPHORES
+
+ auto command_queues =
+ CreateCommandQueues(device_info, device_handle, compute_queue_set,
+ transfer_queue_set, fence_pool, syms);
return assign_ref(new VulkanDevice(
std::move(driver), device_info, physical_device, std::move(device_handle),
std::move(allocator), std::move(command_queues),
std::move(dispatch_command_pool), std::move(transfer_command_pool),
+ std::move(semaphore_pool), std::move(fence_pool),
/*debug_capture_manager=*/nullptr));
}
@@ -438,6 +479,8 @@
absl::InlinedVector<std::unique_ptr<CommandQueue>, 4> command_queues,
ref_ptr<VkCommandPoolHandle> dispatch_command_pool,
ref_ptr<VkCommandPoolHandle> transfer_command_pool,
+ ref_ptr<TimePointSemaphorePool> semaphore_pool,
+ ref_ptr<TimePointFencePool> fence_pool,
DebugCaptureManager* debug_capture_manager)
: Device(device_info),
driver_(std::move(driver)),
@@ -449,6 +492,8 @@
make_ref<DescriptorPoolCache>(add_ref(logical_device_))),
dispatch_command_pool_(std::move(dispatch_command_pool)),
transfer_command_pool_(std::move(transfer_command_pool)),
+ semaphore_pool_(std::move(semaphore_pool)),
+ fence_pool_(std::move(fence_pool)),
debug_capture_manager_(debug_capture_manager) {
// Populate the queue lists based on queue capabilities.
for (auto& command_queue : command_queues_) {
@@ -650,8 +695,36 @@
StatusOr<ref_ptr<Semaphore>> VulkanDevice::CreateSemaphore(
uint64_t initial_value) {
IREE_TRACE_SCOPE0("VulkanDevice::CreateSemaphore");
+#if IREE_HAL_VULKAN_EMULATE_TIMELINE_SEMAPHORES
+ return EmulatedTimelineSemaphore::Create(
+ add_ref(logical_device_),
+      // Triggers necessary processing on all queues due to new values being
+      // signaled for the given timeline |semaphore|.
+ // Different clang-format versions disagree about the asterisk placement.
+ // clang-format off
+ [this](Semaphore* /*semaphore*/) -> Status {
+ // clang-format on
+ IREE_TRACE_SCOPE0("<lambda>::OnSemaphoreSignal");
+ for (const auto& queue : command_queues_) {
+ RETURN_IF_ERROR(static_cast<SerializingCommandQueue*>(queue.get())
+ ->AdvanceQueueSubmission());
+ }
+ return OkStatus();
+ },
+ // Triggers necessary processing on all queues due to failures for the
+ // given timeline |semaphore|.
+ [this](Semaphore* /*semaphore*/) {
+ IREE_TRACE_SCOPE0("<lambda>::OnSemaphoreFailure");
+ for (const auto& queue : command_queues_) {
+ static_cast<SerializingCommandQueue*>(queue.get())
+ ->AbortQueueSubmission();
+ }
+ },
+ add_ref(semaphore_pool_), initial_value);
+#else
return NativeTimelineSemaphore::Create(add_ref(logical_device_),
initial_value);
+#endif // IREE_HAL_VULKAN_EMULATE_TIMELINE_SEMAPHORES
}
Status VulkanDevice::WaitAllSemaphores(
@@ -672,6 +745,23 @@
VkSemaphoreWaitFlags wait_flags) {
IREE_TRACE_SCOPE0("VulkanDevice::WaitSemaphores");
+#if IREE_HAL_VULKAN_EMULATE_TIMELINE_SEMAPHORES
+
+  // TODO(antiagainst): We actually should get the fences associated with the
+  // emulated timeline semaphores so that we can wait on them in one batch.
+  // This implementation is problematic if we want to wait on any of them and
+  // the first semaphore takes extra long while the following ones signal
+  // quickly.
+ for (int i = 0; i < semaphores.size(); ++i) {
+ auto* semaphore =
+ static_cast<EmulatedTimelineSemaphore*>(semaphores[i].semaphore);
+ RETURN_IF_ERROR(semaphore->Wait(semaphores[i].value, deadline));
+ if (wait_flags & VK_SEMAPHORE_WAIT_ANY_BIT) return OkStatus();
+ }
+
+ return OkStatus();
+
+#else
+
absl::InlinedVector<VkSemaphore, 4> semaphore_handles(semaphores.size());
absl::InlinedVector<uint64_t, 4> semaphore_values(semaphores.size());
for (int i = 0; i < semaphores.size(); ++i) {
@@ -714,6 +804,8 @@
// semaphores we waited on (including those already expired above).
return OkStatus();
+
+#endif // IREE_HAL_VULKAN_EMULATE_TIMELINE_SEMAPHORES
}
Status VulkanDevice::WaitIdle(absl::Time deadline) {
diff --git a/iree/hal/vulkan/vulkan_device.h b/iree/hal/vulkan/vulkan_device.h
index cfceb4f..ce7c9d7 100644
--- a/iree/hal/vulkan/vulkan_device.h
+++ b/iree/hal/vulkan/vulkan_device.h
@@ -30,6 +30,7 @@
#include "iree/hal/semaphore.h"
#include "iree/hal/vulkan/descriptor_pool_cache.h"
#include "iree/hal/vulkan/dynamic_symbols.h"
+#include "iree/hal/vulkan/emulated_timeline_semaphore.h"
#include "iree/hal/vulkan/extensibility_util.h"
#include "iree/hal/vulkan/handle_util.h"
@@ -119,6 +120,8 @@
absl::InlinedVector<std::unique_ptr<CommandQueue>, 4> command_queues,
ref_ptr<VkCommandPoolHandle> dispatch_command_pool,
ref_ptr<VkCommandPoolHandle> transfer_command_pool,
+ ref_ptr<TimePointSemaphorePool> semaphore_pool,
+ ref_ptr<TimePointFencePool> fence_pool,
DebugCaptureManager* debug_capture_manager);
Status WaitSemaphores(absl::Span<const SemaphoreValue> semaphores,
@@ -139,6 +142,10 @@
ref_ptr<VkCommandPoolHandle> dispatch_command_pool_;
ref_ptr<VkCommandPoolHandle> transfer_command_pool_;
+ // Fields used for emulated timeline semaphores.
+ ref_ptr<TimePointSemaphorePool> semaphore_pool_;
+ ref_ptr<TimePointFencePool> fence_pool_;
+
DebugCaptureManager* debug_capture_manager_ = nullptr;
};
diff --git a/iree/hal/vulkan/vulkan_driver_module.cc b/iree/hal/vulkan/vulkan_driver_module.cc
index 4b98f58..f034127 100644
--- a/iree/hal/vulkan/vulkan_driver_module.cc
+++ b/iree/hal/vulkan/vulkan_driver_module.cc
@@ -67,9 +67,11 @@
// promoted to core, so we list it as optional even though we require it.
options.instance_extensibility.optional_extensions.push_back(
VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME);
+#if IREE_HAL_VULKAN_EMULATE_TIMELINE_SEMAPHORES == 0
// Timeline semaphore support is required.
options.device_extensibility.required_extensions.push_back(
VK_KHR_TIMELINE_SEMAPHORE_EXTENSION_NAME);
+#endif
if (absl::GetFlag(FLAGS_vulkan_validation_layers)) {
options.instance_extensibility.optional_layers.push_back(
diff --git a/iree/test/e2e/xla_ops/BUILD b/iree/test/e2e/xla_ops/BUILD
index c8ea2fb..6f95110 100644
--- a/iree/test/e2e/xla_ops/BUILD
+++ b/iree/test/e2e/xla_ops/BUILD
@@ -48,11 +48,6 @@
"convert.mlir",
"concatenate.mlir",
"constant.mlir",
-
- # TODO(#1687): Enable after casting from fp to int is handled
- # on structured ops path in vulkan
- # "convert.mlir",
- #
"cosine.mlir",
"divide.mlir",
"dot.mlir",
@@ -76,7 +71,8 @@
"reduce_window.mlir",
"remainder.mlir",
"reshape.mlir",
- "reverse.mlir",
+ # TODO(#1699): Enable after xla_hlo.reverse can be lowered to linalg.
+ # "reverse.mlir",
"rsqrt.mlir",
"select.mlir",
"sine.mlir",
diff --git a/iree/test/e2e/xla_ops/CMakeLists.txt b/iree/test/e2e/xla_ops/CMakeLists.txt
index 7f65b06..b2eec93 100644
--- a/iree/test/e2e/xla_ops/CMakeLists.txt
+++ b/iree/test/e2e/xla_ops/CMakeLists.txt
@@ -56,7 +56,6 @@
"reduce_window.mlir"
"remainder.mlir"
"reshape.mlir"
- "reverse.mlir"
"rsqrt.mlir"
"select.mlir"
"sine.mlir"
diff --git a/iree/tools/CMakeLists.txt b/iree/tools/CMakeLists.txt
index fce935b..5267dec 100644
--- a/iree/tools/CMakeLists.txt
+++ b/iree/tools/CMakeLists.txt
@@ -101,6 +101,7 @@
iree::compiler::Dialect::VM::Tools
LINKOPTS
"-lpthread"
+ HOSTONLY
)
endif()
@@ -253,6 +254,31 @@
PUBLIC
)
+ iree_cc_library(
+ NAME
+ iree_translate_main
+ SRCS
+ "translate_main.cc"
+ DEPS
+ ::init_compiler_modules
+ ::init_iree_passes_and_dialects
+ ::init_mlir_passes_and_dialects
+ ::init_targets
+ ::init_translations
+ ::init_xla_dialects
+ LLVMSupport
+ MLIRIR
+ MLIRSCFTransforms
+ MLIRPass
+ MLIRSupport
+ MLIRTranslation
+ iree::compiler::Conversion::init_conversions
+ iree::compiler::Dialect::VM::Target::Bytecode
+ iree::compiler::Dialect::VM::Target::init_targets
+ iree::compiler::Translation::IREEVM
+ PUBLIC
+ )
+
iree_cc_binary(
NAME
iree-opt
@@ -260,6 +286,7 @@
iree-opt
DEPS
::iree_opt_main
+ HOSTONLY
)
iree_cc_binary(
@@ -303,33 +330,14 @@
iree::vm::bytecode_module
iree::vm::value
${IREE_HAL_DRIVER_MODULES}
+ HOSTONLY
)
+endif(${IREE_BUILD_COMPILER})
- iree_cc_library(
- NAME
- iree_translate_main
- SRCS
- "translate_main.cc"
- DEPS
- ::init_compiler_modules
- ::init_iree_passes_and_dialects
- ::init_mlir_passes_and_dialects
- ::init_targets
- ::init_translations
- ::init_xla_dialects
- LLVMSupport
- MLIRIR
- MLIRSCFTransforms
- MLIRPass
- MLIRSupport
- MLIRTranslation
- iree::compiler::Conversion::init_conversions
- iree::compiler::Dialect::VM::Target::Bytecode
- iree::compiler::Dialect::VM::Target::init_targets
- iree::compiler::Translation::IREEVM
- PUBLIC
- )
-
+# If cross-compiling, we need to declare iree-translate under the host
+# configuration unconditionally because we need to run it on the host to
+# generate VM modules for tests.
+if(${IREE_BUILD_COMPILER} OR CMAKE_CROSSCOMPILING)
iree_cc_binary(
NAME
iree-translate
@@ -337,6 +345,7 @@
iree-translate
DEPS
::iree_translate_main
+ HOSTONLY
)
endif()
diff --git a/iree/vm/test/BUILD b/iree/vm/test/BUILD
index 28312bc..fb9848d 100644
--- a/iree/vm/test/BUILD
+++ b/iree/vm/test/BUILD
@@ -35,13 +35,11 @@
iree_bytecode_module(
name = "arithmetic_ops",
src = "arithmetic_ops.mlir",
- cc_namespace = "iree::vm::test",
flags = ["-iree-vm-ir-to-bytecode-module"],
)
iree_bytecode_module(
name = "control_flow_ops",
src = "control_flow_ops.mlir",
- cc_namespace = "iree::vm::test",
flags = ["-iree-vm-ir-to-bytecode-module"],
)
diff --git a/iree/vm/test/CMakeLists.txt b/iree/vm/test/CMakeLists.txt
index 7f9a4ae..6c189b6 100644
--- a/iree/vm/test/CMakeLists.txt
+++ b/iree/vm/test/CMakeLists.txt
@@ -35,8 +35,6 @@
arithmetic_ops
SRC
"arithmetic_ops.mlir"
- CC_NAMESPACE
- "iree::vm::test"
FLAGS
"-iree-vm-ir-to-bytecode-module"
PUBLIC
@@ -47,8 +45,6 @@
control_flow_ops
SRC
"control_flow_ops.mlir"
- CC_NAMESPACE
- "iree::vm::test"
FLAGS
"-iree-vm-ir-to-bytecode-module"
PUBLIC
diff --git a/kokoro/gcp_ubuntu/bazel/bindings/continuous.cfg b/kokoro/gcp_ubuntu/bazel/bindings/continuous.cfg
index af84216..50a7eed 100644
--- a/kokoro/gcp_ubuntu/bazel/bindings/continuous.cfg
+++ b/kokoro/gcp_ubuntu/bazel/bindings/continuous.cfg
@@ -15,5 +15,5 @@
# limitations under the License.
# Deliberately blank as everything necessary is configured in common files, but
-# file must still exist to match corresponding (upstream only) job
+# file must still exist to match corresponding (Google internal) job
# configurations that trigger the builds.
diff --git a/kokoro/gcp_ubuntu/bazel/bindings/google.cfg b/kokoro/gcp_ubuntu/bazel/bindings/google.cfg
new file mode 100644
index 0000000..50a7eed
--- /dev/null
+++ b/kokoro/gcp_ubuntu/bazel/bindings/google.cfg
@@ -0,0 +1,19 @@
+# Format: //devtools/kokoro/config/proto/build.proto
+
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Deliberately blank as everything necessary is configured in common files, but
+# file must still exist to match corresponding (Google internal) job
+# configurations that trigger the builds.
diff --git a/kokoro/gcp_ubuntu/bazel/bindings/main.cfg b/kokoro/gcp_ubuntu/bazel/bindings/main.cfg
new file mode 100644
index 0000000..50a7eed
--- /dev/null
+++ b/kokoro/gcp_ubuntu/bazel/bindings/main.cfg
@@ -0,0 +1,19 @@
+# Format: //devtools/kokoro/config/proto/build.proto
+
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Deliberately blank as everything necessary is configured in common files, but
+# file must still exist to match corresponding (Google internal) job
+# configurations that trigger the builds.
diff --git a/kokoro/gcp_ubuntu/bazel/bindings/presubmit.cfg b/kokoro/gcp_ubuntu/bazel/bindings/presubmit.cfg
index af84216..50a7eed 100644
--- a/kokoro/gcp_ubuntu/bazel/bindings/presubmit.cfg
+++ b/kokoro/gcp_ubuntu/bazel/bindings/presubmit.cfg
@@ -15,5 +15,5 @@
# limitations under the License.
# Deliberately blank as everything necessary is configured in common files, but
-# file must still exist to match corresponding (upstream only) job
+# file must still exist to match corresponding (Google internal) job
# configurations that trigger the builds.
diff --git a/kokoro/gcp_ubuntu/bazel/core/continuous.cfg b/kokoro/gcp_ubuntu/bazel/core/continuous.cfg
index af84216..50a7eed 100755
--- a/kokoro/gcp_ubuntu/bazel/core/continuous.cfg
+++ b/kokoro/gcp_ubuntu/bazel/core/continuous.cfg
@@ -15,5 +15,5 @@
# limitations under the License.
# Deliberately blank as everything necessary is configured in common files, but
-# file must still exist to match corresponding (upstream only) job
+# file must still exist to match corresponding (Google internal) job
# configurations that trigger the builds.
diff --git a/kokoro/gcp_ubuntu/bazel/core/google.cfg b/kokoro/gcp_ubuntu/bazel/core/google.cfg
new file mode 100755
index 0000000..50a7eed
--- /dev/null
+++ b/kokoro/gcp_ubuntu/bazel/core/google.cfg
@@ -0,0 +1,19 @@
+# Format: //devtools/kokoro/config/proto/build.proto
+
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Deliberately blank as everything necessary is configured in common files, but
+# file must still exist to match corresponding (Google internal) job
+# configurations that trigger the builds.
diff --git a/kokoro/gcp_ubuntu/bazel/core/main.cfg b/kokoro/gcp_ubuntu/bazel/core/main.cfg
new file mode 100755
index 0000000..50a7eed
--- /dev/null
+++ b/kokoro/gcp_ubuntu/bazel/core/main.cfg
@@ -0,0 +1,19 @@
+# Format: //devtools/kokoro/config/proto/build.proto
+
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Deliberately blank as everything necessary is configured in common files, but
+# file must still exist to match corresponding (Google internal) job
+# configurations that trigger the builds.
diff --git a/kokoro/gcp_ubuntu/bazel/core/presubmit.cfg b/kokoro/gcp_ubuntu/bazel/core/presubmit.cfg
index af84216..50a7eed 100755
--- a/kokoro/gcp_ubuntu/bazel/core/presubmit.cfg
+++ b/kokoro/gcp_ubuntu/bazel/core/presubmit.cfg
@@ -15,5 +15,5 @@
# limitations under the License.
# Deliberately blank as everything necessary is configured in common files, but
-# file must still exist to match corresponding (upstream only) job
+# file must still exist to match corresponding (Google internal) job
# configurations that trigger the builds.
diff --git a/kokoro/gcp_ubuntu/bazel/integrations/continuous.cfg b/kokoro/gcp_ubuntu/bazel/integrations/continuous.cfg
index af84216..50a7eed 100644
--- a/kokoro/gcp_ubuntu/bazel/integrations/continuous.cfg
+++ b/kokoro/gcp_ubuntu/bazel/integrations/continuous.cfg
@@ -15,5 +15,5 @@
# limitations under the License.
# Deliberately blank as everything necessary is configured in common files, but
-# file must still exist to match corresponding (upstream only) job
+# file must still exist to match corresponding (Google internal) job
# configurations that trigger the builds.
diff --git a/kokoro/gcp_ubuntu/bazel/integrations/google.cfg b/kokoro/gcp_ubuntu/bazel/integrations/google.cfg
new file mode 100644
index 0000000..50a7eed
--- /dev/null
+++ b/kokoro/gcp_ubuntu/bazel/integrations/google.cfg
@@ -0,0 +1,19 @@
+# Format: //devtools/kokoro/config/proto/build.proto
+
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Deliberately blank as everything necessary is configured in common files, but
+# file must still exist to match corresponding (Google internal) job
+# configurations that trigger the builds.
diff --git a/kokoro/gcp_ubuntu/bazel/integrations/main.cfg b/kokoro/gcp_ubuntu/bazel/integrations/main.cfg
new file mode 100644
index 0000000..50a7eed
--- /dev/null
+++ b/kokoro/gcp_ubuntu/bazel/integrations/main.cfg
@@ -0,0 +1,19 @@
+# Format: //devtools/kokoro/config/proto/build.proto
+
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Deliberately blank as everything necessary is configured in common files, but
+# file must still exist to match corresponding (Google internal) job
+# configurations that trigger the builds.
diff --git a/kokoro/gcp_ubuntu/bazel/integrations/presubmit.cfg b/kokoro/gcp_ubuntu/bazel/integrations/presubmit.cfg
index af84216..50a7eed 100644
--- a/kokoro/gcp_ubuntu/bazel/integrations/presubmit.cfg
+++ b/kokoro/gcp_ubuntu/bazel/integrations/presubmit.cfg
@@ -15,5 +15,5 @@
# limitations under the License.
# Deliberately blank as everything necessary is configured in common files, but
-# file must still exist to match corresponding (upstream only) job
+# file must still exist to match corresponding (Google internal) job
# configurations that trigger the builds.
diff --git a/kokoro/gcp_ubuntu/cmake/continuous.cfg b/kokoro/gcp_ubuntu/cmake/continuous.cfg
index d825b90..e4cc270 100644
--- a/kokoro/gcp_ubuntu/cmake/continuous.cfg
+++ b/kokoro/gcp_ubuntu/cmake/continuous.cfg
@@ -15,5 +15,5 @@
# limitations under the License.
# Deliberately blank as everything necessary is configured in common files, but
-# file must still exist to match corresponding (upstream only) job
+# file must still exist to match corresponding (Google internal) job
# configurations that trigger the builds.
diff --git a/kokoro/gcp_ubuntu/cmake/google.cfg b/kokoro/gcp_ubuntu/cmake/google.cfg
new file mode 100644
index 0000000..e4cc270
--- /dev/null
+++ b/kokoro/gcp_ubuntu/cmake/google.cfg
@@ -0,0 +1,19 @@
+# Format: //devtools/kokoro/config/proto/build.proto
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Deliberately blank as everything necessary is configured in common files, but
+# file must still exist to match corresponding (Google internal) job
+# configurations that trigger the builds.
diff --git a/kokoro/gcp_ubuntu/cmake/main.cfg b/kokoro/gcp_ubuntu/cmake/main.cfg
new file mode 100644
index 0000000..e4cc270
--- /dev/null
+++ b/kokoro/gcp_ubuntu/cmake/main.cfg
@@ -0,0 +1,19 @@
+# Format: //devtools/kokoro/config/proto/build.proto
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Deliberately blank as everything necessary is configured in common files, but
+# file must still exist to match corresponding (Google internal) job
+# configurations that trigger the builds.
diff --git a/kokoro/gcp_ubuntu/cmake/presubmit.cfg b/kokoro/gcp_ubuntu/cmake/presubmit.cfg
index d825b90..e4cc270 100644
--- a/kokoro/gcp_ubuntu/cmake/presubmit.cfg
+++ b/kokoro/gcp_ubuntu/cmake/presubmit.cfg
@@ -15,5 +15,5 @@
# limitations under the License.
# Deliberately blank as everything necessary is configured in common files, but
-# file must still exist to match corresponding (upstream only) job
+# file must still exist to match corresponding (Google internal) job
# configurations that trigger the builds.
diff --git a/third_party/llvm-project b/third_party/llvm-project
index 7e825ab..9fb7e98 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit 7e825abd5704ce28b166f9463d4bd304348fd2a9
+Subproject commit 9fb7e98db5aaef617878a127b663efa4d01aa834
diff --git a/third_party/tensorflow b/third_party/tensorflow
index b00a780..f74654a 160000
--- a/third_party/tensorflow
+++ b/third_party/tensorflow
@@ -1 +1 @@
-Subproject commit b00a7808a7b29a78762b54e29aac87a77254b4b6
+Subproject commit f74654ac7b314a212b1df6687c2f99800084e97f