Merge https://github.com/google/iree into build-with-cmake
diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index ad9635a..d5d5bb0 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -59,9 +59,10 @@
run: |
git diff -U0 "${BASE_REF?}" | python3 third_party/format_diff/format_diff.py yapf -i
git diff --exit-code
- - name: Instructions for fixing these linting errors
+ - name: Instructions for fixing the above linting errors
+ if: ${{ failure() }}
run: |
- printf "If the lint above failed it can be fixed by running\n"
+ printf "You can fix the lint errors above by running\n"
printf " git diff -U0 "${BASE_REF?}" | python3 third_party/format_diff/format_diff.py yapf -i\n"
clang-format:
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 12a72fe..b2c30b7 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -96,8 +96,7 @@
set( IREE_HAL_DRIVERS_TO_BUILD ${IREE_ALL_HAL_DRIVERS} )
# For cross compilation towords Android, we don't want LLVM JIT HAL driver.
if(ANDROID)
- # TODO(ataei): Enable dylib/dylib-llvm-aot for android.
- list(REMOVE_ITEM IREE_HAL_DRIVERS_TO_BUILD LLVM DyLib)
+ list(REMOVE_ITEM IREE_HAL_DRIVERS_TO_BUILD LLVM)
endif()
# For Apple platforms we need to use Metal instead of Vulkan.
diff --git a/SUBMODULE_VERSIONS b/SUBMODULE_VERSIONS
index 2aac894..e9e0de9 100644
--- a/SUBMODULE_VERSIONS
+++ b/SUBMODULE_VERSIONS
@@ -5,7 +5,7 @@
a5d9d0f7d368054fd1691aedf1db4116efcc233e third_party/flatbuffers
4fb0ff7069bd88ee85902f4d0bb62794e5f6d021 third_party/flatcc
f2fb48c3b3d79a75a88a99fba6576b25d42ec528 third_party/googletest
-51ff04567b2f8d06b2062bd3ed72eab2e93e4466 third_party/llvm-project
+f402e682d0ef5598eeffc9a21a691b03e602ff58 third_party/llvm-project
17b12a4481daa150e2d1ea3ada086b551b856707 third_party/marl
d2cdb70e038370b5e28f353fe98ccd70af1cbc25 third_party/mlir-emitc
d8c7ee00a687ac369e62e2032514a93a9b413502 third_party/pybind11
@@ -14,7 +14,7 @@
685f86471e9d26b3eb7676695a2e2cefb4551ae9 third_party/spirv_cross
f8bf11a0253a32375c32cad92c841237b96696c0 third_party/spirv_headers
57eb48aed36160c4876bc8310d9ca84d42ee9e2a third_party/swiftshader
-74747f0c6b017df8fe87d1912f8eb4e2e287fbda third_party/tensorflow
+2d6bdab3adb0b8949763d5c63426338f938c9efe third_party/tensorflow
a9a09ab0940408898fccfdcfe2bb8dc19b50f13c third_party/tracy
9bd3f561bcee3f01d22912de10bb07ce4e23d378 third_party/vulkan_headers
909f36b714c9239ee0b112a321220213a474ba53 third_party/vulkan_memory_allocator
diff --git a/build_tools/bazel/iree.bazelrc b/build_tools/bazel/iree.bazelrc
index 46922a4..0f246e7 100644
--- a/build_tools/bazel/iree.bazelrc
+++ b/build_tools/bazel/iree.bazelrc
@@ -298,10 +298,10 @@
# Another TensorFlow flag from their config script.
build:windows --define with_default_optimizations=true
-# TensorFlow builds depend on these flags so we will as well
+# TensorFlow builds depend on this flag, but it doesn't appear to work with
+# gmock in some of our unit tests, so only enable it for TensorFlow files.
# MSVC (Windows): Standards-conformant preprocessor mode
# See https://docs.microsoft.com/en-us/cpp/preprocessor/preprocessor-experimental-overview
-build:windows --copt=/experimental:preprocessor
-build:windows --host_copt=/experimental:preprocessor
+build:windows --per_file_copt=tensorflow@/experimental:preprocessor
###############################################################################
diff --git a/build_tools/bazel/run_binary_test.bzl b/build_tools/bazel/run_binary_test.bzl
index 79536f3..aa165ea 100644
--- a/build_tools/bazel/run_binary_test.bzl
+++ b/build_tools/bazel/run_binary_test.bzl
@@ -32,12 +32,10 @@
"""
def _run_binary_test_impl(ctx):
- ctx.actions.run_shell(
- inputs = [ctx.file.test_binary],
- outputs = [ctx.outputs.executable],
- command = "cp $1 $2",
- arguments = [ctx.file.test_binary.path, ctx.outputs.executable.path],
- mnemonic = "CopyExecutable",
+ ctx.actions.symlink(
+ target_file = ctx.executable.test_binary,
+ output = ctx.outputs.executable,
+ is_executable = True,
)
data_runfiles = ctx.runfiles(files = ctx.files.data)
@@ -54,7 +52,8 @@
attrs = {
"test_binary": attr.label(
mandatory = True,
- allow_single_file = True,
+ executable = True,
+ cfg = "target",
),
"data": attr.label_list(allow_empty = True, allow_files = True),
},
diff --git a/build_tools/bazel/third_party_import/llvm-project/overlay/mlir/BUILD.bazel b/build_tools/bazel/third_party_import/llvm-project/overlay/mlir/BUILD.bazel
index 4aeb2b8..d129c47 100644
--- a/build_tools/bazel/third_party_import/llvm-project/overlay/mlir/BUILD.bazel
+++ b/build_tools/bazel/third_party_import/llvm-project/overlay/mlir/BUILD.bazel
@@ -458,6 +458,7 @@
srcs = [
"include/mlir/Dialect/StandardOps/IR/Ops.td",
"include/mlir/IR/OpAsmInterface.td",
+ "include/mlir/IR/SymbolInterfaces.td",
"include/mlir/Interfaces/CallInterfaces.td",
"include/mlir/Interfaces/ControlFlowInterfaces.td",
"include/mlir/Interfaces/SideEffectInterfaces.td",
diff --git a/build_tools/cmake/iree_copts.cmake b/build_tools/cmake/iree_copts.cmake
index 9d99d9b..dd3b3c0 100644
--- a/build_tools/cmake/iree_copts.cmake
+++ b/build_tools/cmake/iree_copts.cmake
@@ -38,10 +38,6 @@
${PROJECT_BINARY_DIR}
)
-if(${IREE_ENABLE_RUNTIME_TRACING})
- set (CMAKE_EXE_LINKER_FLAGS -ldl)
-endif()
-
iree_select_compiler_opts(IREE_DEFAULT_COPTS
CLANG
# LINT.IfChange(clang_diagnostics)
diff --git a/docs/developing_iree/e2e_benchmarking.md b/docs/developing_iree/e2e_benchmarking.md
index 3fe8604..96523c4 100644
--- a/docs/developing_iree/e2e_benchmarking.md
+++ b/docs/developing_iree/e2e_benchmarking.md
@@ -178,12 +178,19 @@
# Enter the TensorFlow Bazel workspace.
cd third_party/tensorflow/
-# Build the benchmark_model binary without RUY...
+# Build the benchmark_model binary.
bazel build --copt=-mavx2 -c opt \
//tensorflow/lite/tools/benchmark:benchmark_model
-# ...or build the benchmark_model binary with RUY. This will overwrite the
+# By default, TFLite/x86 uses various matrix multiplication libraries.
+# It is possible to force it to only use Ruy for all matrix multiplications.
+# That is the default on ARM but not on x86. This will overwrite the
# previous binary unless you move it.
+#
+# Note that Ruy takes care of -mavx2 and other AVX extensions internally,
+# so passing this flag here isn't going to make a difference to
+# matrix multiplications. However, the rest of TFLite's kernels outside
+# of ruy will still benefit from -mavx2.
bazel build --copt=-mavx2 -c opt \
--define=tflite_with_ruy=true \
//tensorflow/lite/tools/benchmark:benchmark_model
@@ -274,6 +281,9 @@
```shell
# Build the benchmark_model binary without any add-ons.
+# Note that unlike TFLite/x86, TFLite/ARM uses Ruy by default for all
+# matrix multiplications (No need to pass tflite_with_ruy), except for some
+# matrix*vector products. Below we show how to force using ruy also for that.
bazel build -c opt \
--config=android_arm64 \
--cxxopt='--std=c++17' \
@@ -286,20 +296,21 @@
```
```shell
-# Build the benchmark_model binary with ruy.
-bazel build --copt=-mavx2 -c opt \
+# Build the benchmark_model binary using ruy even for matrix*vector
+# products. This is only worth trying in models that are heavy on matrix*vector
+# shapes, typically LSTMs and other RNNs.
+bazel build -c opt \
--config=android_arm64 \
--cxxopt='--std=c++17' \
- --define=tflite_with_ruy=true \
--copt=-DTFLITE_WITH_RUY_GEMV \
//tensorflow/lite/tools/benchmark:benchmark_model
# Rename the binary for comparison with the standard benchmark_model.
mv bazel-bin/tensorflow/lite/tools/benchmark/benchmark_model \
- bazel-bin/tensorflow/lite/tools/benchmark/benchmark_model_plus_ruy
-adb push bazel-bin/tensorflow/lite/tools/benchmark/benchmark_model_plus_ruy \
+ bazel-bin/tensorflow/lite/tools/benchmark/benchmark_model_plus_ruy_gemv
+adb push bazel-bin/tensorflow/lite/tools/benchmark/benchmark_model_plus_ruy_gemv \
/data/local/tmp/
-adb shell chmod +x /data/local/tmp/benchmark_model_plus_ruy
+adb shell chmod +x /data/local/tmp/benchmark_model_plus_ruy_gemv
```
```shell
@@ -336,17 +347,15 @@
--warmup_runs=1 \
--num_threads=1 \
- --num_runs=10 \
- --enable_op_profiling=true
+ --num_runs=10
```
```shell
-# Benchmark with TFLite + RUY.
-adb shell taskset f0 /data/local/tmp/benchmark_model_plus_ruy \
+# Benchmark with TFLite + RUY GEMV
+adb shell taskset f0 /data/local/tmp/benchmark_model_plus_ruy_gemv \
--graph=/data/local/tmp/MatrixOpsStaticModule/tflite/matmul_lhs_batch.tflite \
--warmup_runs=1 \
--num_threads=1 \
- --num_runs=10 \
- --enable_op_profiling=true
+ --num_runs=10
```
```shell
@@ -356,7 +365,6 @@
--warmup_runs=1 \
--num_threads=1 \
- --num_runs=10 \
- --enable_op_profiling=true
+ --num_runs=10
```
```shell
@@ -366,7 +374,6 @@
--warmup_runs=1 \
--num_threads=1 \
--num_runs=10 \
- --enable_op_profiling=true \
--use_gpu=true
```
@@ -384,3 +391,30 @@
name of the `.tflite` graph that you need to benchmark _may_ be different from
the name of the trace that you want to benchmark, but you can use `cat` on
the `graph_path` file to verify the correct `.tflite` filename if you're unsure.
+
+### Profile
+
+There are 2 profilers built into TFLite's `benchmark_model` program. Both of them impact latencies, so they should only be used to get a breakdown of the relative time spent in each operator type; they should not be enabled for the purpose of measuring latency.
+
+The first is `enable_op_profiling`. It's based on timestamps before and after each op. It's a runtime commandline flag taken by `benchmark_model`. Example:
+
+```
+adb shell taskset f0 /data/local/tmp/benchmark_model \
+ --graph=/data/local/tmp/MatrixOpsStaticModule/tflite/matmul_lhs_batch.tflite \
+ --warmup_runs=1 \
+ --num_threads=1 \
+ --num_runs=10 \
+ --enable_op_profiling=true
+```
+
+The second is `ruy_profiler`. Despite its name, it's available regardless of whether `ruy` is used for the matrix multiplications. It's a sampling profiler, which allows it to provide more detailed information, particularly on matrix multiplications. It's a build-time switch:
+
+```
+bazel build \
+ --define=ruy_profiler=true \
+ -c opt \
+ --config=android_arm64 \
+ //tensorflow/lite/tools/benchmark:benchmark_model
+```
+
+The binary thus built can be run like above, no commandline flag needed.
diff --git a/docs/get_started/getting_started_android_cmake.md b/docs/get_started/getting_started_android_cmake.md
index ba78b21..e740b27 100644
--- a/docs/get_started/getting_started_android_cmake.md
+++ b/docs/get_started/getting_started_android_cmake.md
@@ -263,16 +263,16 @@
### Dylib LLVM AOT backend
To compile IREE module for the target Android device (assume Android 10 AArc64)
-we need install the corresponding standalone toolchain and setting AOT linker
-path environment variable:
+we need to use the corresponding standalone toolchain (which can be found in
+ANDROID_NDK) and set the AOT linker path environment variable:
```shell
-$ export ANDROID_ARM64_TOOLCHAIN=/path/to/install/the/toolchain
-$ "${ANDROID_NDK?}/build/tools/make-standalone-toolchain.sh" --arch=arm64 --platform=android-29 \
- --install-dir="${ANDROID_ARM64_TOOLCHAIN?}"
-$ export IREE_LLVMAOT_LINKER_PATH="${ANDROID_ARM64_TOOLCHAIN?}/aarch64-linux-android/bin/ld"
+$ export IREE_LLVMAOT_LINKER_PATH="${ANDROID_NDK?}/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android29-clang++ -static-libstdc++ -O3"
```
+`-static-libstdc++` is needed because some dynamic libraries would otherwise
+fail to open.
+
Translate a source MLIR into an IREE module:
```shell
diff --git a/iree/base/BUILD b/iree/base/BUILD
index 2d91662..a1b46ba 100644
--- a/iree/base/BUILD
+++ b/iree/base/BUILD
@@ -395,7 +395,12 @@
name = "synchronization",
srcs = ["synchronization.c"],
hdrs = ["synchronization.h"],
- linkopts = ["-lpthread"],
+ linkopts = select({
+ "//iree:iree_is_msvc": [],
+ "//conditions:default": [
+ "-lpthread",
+ ],
+ }),
deps = [
":api",
":atomics",
@@ -460,10 +465,13 @@
copts = [
"-D_GNU_SOURCE=1",
],
- linkopts = [
- "-ldl",
- "-lpthread",
- ],
+ linkopts = select({
+ "//iree:iree_is_msvc": [],
+ "//conditions:default": [
+ "-ldl",
+ "-lpthread",
+ ],
+ }),
deps = [
":api",
":atomics",
diff --git a/iree/base/CMakeLists.txt b/iree/base/CMakeLists.txt
index 4e6b7ce..257618b 100644
--- a/iree/base/CMakeLists.txt
+++ b/iree/base/CMakeLists.txt
@@ -595,6 +595,11 @@
iree::testing::gtest_main
)
+iree_select_compiler_opts(IREE_LINKOPTS_TRACING
+ GCC_OR_CLANG
+ -ldl
+)
+
if(${IREE_ENABLE_RUNTIME_TRACING})
iree_cc_library(
NAME
@@ -605,6 +610,8 @@
"${IREE_ROOT_DIR}/third_party/tracy/TracyC.h"
SRCS
"tracing.cc"
+ LINKOPTS
+ ${IREE_LINKOPTS_TRACING}
DEPS
absl::core_headers
DEFINES
diff --git a/iree/build_defs.oss.bzl b/iree/build_defs.oss.bzl
index 5364f6b..4d359e7 100644
--- a/iree/build_defs.oss.bzl
+++ b/iree/build_defs.oss.bzl
@@ -85,11 +85,14 @@
def cc_binary(linkopts = [], **kwargs):
"""Wrapper around low-level cc_binary that adds flags."""
_cc_binary(
- linkopts = linkopts + [
- # Just include libraries that should be presumed in 2020.
- "-ldl",
- "-lpthread",
- ],
+ linkopts = linkopts + select({
+ "//iree:iree_is_msvc": [],
+ "//conditions:default": [
+ # Just include libraries that should be presumed in 2020.
+ "-ldl",
+ "-lpthread",
+ ],
+ }),
**kwargs
)
diff --git a/iree/compiler/Conversion/HLOToLinalg/test/pad.mlir b/iree/compiler/Conversion/HLOToLinalg/test/pad.mlir
index a42a079..29c700b 100644
--- a/iree/compiler/Conversion/HLOToLinalg/test/pad.mlir
+++ b/iree/compiler/Conversion/HLOToLinalg/test/pad.mlir
@@ -76,14 +76,10 @@
hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write"
}
}
-// TODO(hanchung): Make it just a copy op.
// CHECK_LABEL: @pad_no_op
-// CHECK-DAG: %[[CST:.+]] = constant 0.000000e+00 : f32
// CHECK-DAG: %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<12x4xf32>
// CHECK-DAG: %[[IN:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<12x4xf32>
-// CHECK: linalg.fill(%[[OUT]], %[[CST]])
-// CHECK: %[[SUBVIEW:.+]] = subview %[[OUT]][0, 0] [12, 4] [1, 1]
-// CHECK: linalg.copy(%[[IN]], %[[SUBVIEW]])
+// CHECK: linalg.copy(%[[IN]], %[[OUT]])
// -----
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
index 899a485..93b8809 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
@@ -204,7 +204,7 @@
if (vecType.getRank() != 2) return failure();
// TODO(thomasraoux): use coloumn major operand when TransfertRead +
// TransposeOp.
- if (!op.permutation_map().isIdentity()) return failure();
+ if (!op.permutation_map().isMinorIdentity()) return failure();
if (op.masked() &&
llvm::any_of(op.masked()->template cast<ArrayAttr>(),
[](mlir::Attribute maskedDim) {
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_spirv.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_spirv.mlir
index 52aeff7..ac83d70 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_spirv.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_spirv.mlir
@@ -104,3 +104,41 @@
return
}
}
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+module attributes {gpu.container_module, spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader, CooperativeMatrixNV, Int8, Float16, StorageUniform16, StorageBuffer8BitAccess, Float16Buffer], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix, SPV_KHR_8bit_storage, SPV_KHR_16bit_storage]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+ func @kernel_matmul_vector_memref(%arg0: memref<4096x256xvector<4xi32>>, %arg1: memref<4096x256xvector<4xi32>>, %arg2: memref<4096x1024xvector<4xi32>>) attributes {spv.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}} {
+ %c32 = constant 32 : index
+ %c4096 = constant 4096 : index
+ %c0 = constant 0 : index
+ %cst = constant dense<0> : vector<4xi32>
+ // CHECK: %[[C:.+]] = spv.CooperativeMatrixLoadNV %{{.*}}, %{{.*}}, %{{.*}}
+ %4 = vector.transfer_read %arg2[%c0, %c0], %cst : memref<4096x1024xvector<4xi32>>, vector<16x16xi32>
+ // CHECK: %[[ACC:.+]] = spv.Variable : !spv.ptr<!spv.coopmatrix<16x16xi32, Subgroup>, Function>
+ // CHECK: spv.loop {
+ // CHECK: spv.Branch ^[[BB:.+]](%{{.*}}, %[[C]] : i32, !spv.coopmatrix<16x16xi32, Subgroup>)
+ // CHECK: ^[[BB]](%{{.*}}: i32, %[[C1:.+]]: !spv.coopmatrix<16x16xi32, Subgroup>)
+ %5 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %4) -> (vector<16x16xi32>) {
+ // CHECK: %[[A:.+]] = spv.CooperativeMatrixLoadNV %{{.*}}, %{{.*}}, %{{.*}}
+ %6 = vector.transfer_read %arg0[%c0, %arg3], %cst : memref<4096x256xvector<4xi32>>, vector<16x32xi8>
+ // CHECK: %[[B:.+]] = spv.CooperativeMatrixLoadNV %{{.*}}, %{{.*}}, %{{.*}}
+ %7 = vector.transfer_read %arg1[%arg3, %c0], %cst : memref<4096x256xvector<4xi32>>, vector<32x16xi8>
+ // CHECK: %[[R:.+]] = spv.CooperativeMatrixMulAddNV %[[A]], %[[B]], %[[C1]]
+ %8 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %6, %7, %arg4 : vector<16x32xi8>, vector<32x16xi8> into vector<16x16xi32>
+ // CHECK: spv.Store "Function" %[[ACC]], %[[R]] : !spv.coopmatrix<16x16xi32, Subgroup>
+ // CHECK: spv.Branch ^[[BB]](%{{.*}}, %[[R]] : i32, !spv.coopmatrix<16x16xi32, Subgroup>)
+ scf.yield %8 : vector<16x16xi32>
+ }
+ // CHECK: %[[ACCv:.+]] = spv.Load "Function" %[[ACC]] : !spv.coopmatrix<16x16xi32, Subgroup>
+ // CHECK: spv.CooperativeMatrixStoreNV %{{.*}}, %[[ACCv]], %{{.*}}, %{{.*}}
+ vector.transfer_write %5, %arg2[%c0, %c0] : vector<16x16xi32>, memref<4096x1024xvector<4xi32>>
+ return
+ }
+}
+
diff --git a/iree/compiler/Dialect/HAL/Target/MetalSPIRV/SPIRVToMSL.cpp b/iree/compiler/Dialect/HAL/Target/MetalSPIRV/SPIRVToMSL.cpp
index 2f5abb7..5d89aa3 100644
--- a/iree/compiler/Dialect/HAL/Target/MetalSPIRV/SPIRVToMSL.cpp
+++ b/iree/compiler/Dialect/HAL/Target/MetalSPIRV/SPIRVToMSL.cpp
@@ -14,6 +14,9 @@
#include "iree/compiler/Dialect/HAL/Target/MetalSPIRV/SPIRVToMSL.h"
+#include <vector>
+
+#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "spirv_cross/spirv_msl.hpp"
@@ -37,6 +40,66 @@
if (workgroupSize.constant != 0) return {0, 0, 0};
return {workgroupSize.x, workgroupSize.y, workgroupSize.z};
}
+
+ // A resource descriptor's set and binding numbers; sorts by set, then binding.
+ struct Descriptor {
+ uint32_t set;
+ uint32_t binding;
+
+ Descriptor(uint32_t s, uint32_t b) : set(s), binding(b) {}
+
+ friend bool operator<(const Descriptor& l, const Descriptor& r) {
+ return std::tie(l.set, l.binding) < std::tie(r.set, r.binding);
+ }
+ };
+
+ // Returns the set and binding number pairs of all uniform/storage buffer
+ // resource variables, sorted in increasing order.
+ std::vector<Descriptor> getBufferSetBindingPairs() {
+ std::vector<Descriptor> descriptors;
+
+ // Iterate over all variables in the SPIR-V blob.
+ ir.for_each_typed_id<SPIRV_CROSS_NAMESPACE::SPIRVariable>(
+ [&](uint32_t id, SPIRV_CROSS_NAMESPACE::SPIRVariable& var) {
+ auto storage = var.storage;
+ switch (storage) {
+ // Non-interface variables. We don't care.
+ case spv::StorageClassFunction:
+ case spv::StorageClassPrivate:
+ case spv::StorageClassWorkgroup:
+ // Builtin variables. We don't care either.
+ case spv::StorageClassInput:
+ return;
+ default:
+ break;
+ }
+ if (storage == spv::StorageClassUniform ||
+ storage == spv::StorageClassStorageBuffer) {
+ uint32_t setNo = get_decoration(id, spv::DecorationDescriptorSet);
+ uint32_t bindingNo = get_decoration(id, spv::DecorationBinding);
+ descriptors.emplace_back(setNo, bindingNo);
+ return;
+ }
+ // TODO(antiagainst): push constant
+ llvm_unreachable("unsupported storage class in SPIRVToMSLCompiler");
+ });
+
+ llvm::sort(descriptors);
+ return descriptors;
+ }
+
+ Options getCompilationOptions() {
+ // TODO(antiagainst): fill out the following according to the Metal GPU
+ // family.
+ SPIRVToMSLCompiler::Options spvCrossOptions;
+ spvCrossOptions.platform = SPIRVToMSLCompiler::Options::Platform::macOS;
+ spvCrossOptions.msl_version =
+ SPIRVToMSLCompiler::Options::make_msl_version(2, 0);
+ // Enable using Metal argument buffers. It is more akin to Vulkan descriptor
+ // sets, which is how IREE HAL models resource bindings and mappings.
+ spvCrossOptions.argument_buffers = true;
+ return spvCrossOptions;
+ }
};
} // namespace
@@ -49,15 +112,25 @@
spvCrossCompiler.set_entry_point(
entryPoint, spv::ExecutionModel::ExecutionModelGLCompute);
- // TODO(antiagainst): fill out the following according to the Metal GPU
- // family.
- SPIRVToMSLCompiler::Options spvCrossOptions;
- spvCrossOptions.platform = SPIRVToMSLCompiler::Options::Platform::macOS;
- spvCrossOptions.msl_version =
- SPIRVToMSLCompiler::Options::make_msl_version(2, 0);
- // Eanble using Metal argument buffers. It is more akin to Vulkan descriptor
- // sets, which is how IREE HAL models resource bindings and mappings.
- spvCrossOptions.argument_buffers = true;
+ // Explicitly set the argument buffer index for each SPIR-V resource variable.
+ auto descriptors = spvCrossCompiler.getBufferSetBindingPairs();
+ for (const auto& descriptor : descriptors) {
+ if (descriptor.set != 0) {
+ llvm_unreachable(
+ "multiple descriptor set unimplemented in SPIRVToMSLCompiler");
+ }
+
+ SPIRV_CROSS_NAMESPACE::MSLResourceBinding binding = {};
+ binding.stage = spv::ExecutionModelGLCompute;
+ binding.desc_set = descriptor.set;
+ binding.binding = descriptor.binding;
+ // We only interact with buffers in IREE.
+ binding.msl_buffer = descriptor.binding;
+
+ spvCrossCompiler.add_msl_resource_binding(binding);
+ }
+
+ auto spvCrossOptions = spvCrossCompiler.getCompilationOptions();
spvCrossCompiler.set_msl_options(spvCrossOptions);
std::string mslSource = spvCrossCompiler.compile();
diff --git a/iree/test/e2e/xla_ops/BUILD b/iree/test/e2e/xla_ops/BUILD
index f60833f..a3c9060 100644
--- a/iree/test/e2e/xla_ops/BUILD
+++ b/iree/test/e2e/xla_ops/BUILD
@@ -37,24 +37,34 @@
"abs.mlir"
"add.mlir"
"broadcast.mlir"
+ "broadcast_add.mlir"
"broadcast_in_dim.mlir"
- "convert.mlir"
+ "clamp.mlir"
+ "compare.mlir"
"constant.mlir"
+ "convert.mlir"
"cosine.mlir"
+ "divide.mlir"
"exponential.mlir"
+ "gather.mlir"
"log.mlir"
"log_plus_one.mlir"
"maximum.mlir"
"minimum.mlir"
"multiply.mlir"
"negate.mlir"
+ "remainder.mlir"
"reshape.mlir"
"rsqrt.mlir"
+ "select.mlir"
"sine.mlir"
"slice.mlir"
"sqrt.mlir"
+ "subtract.mlir"
"tanh.mlir"
+ "torch_index_select.mlir"
"transpose.mlir"
+ "while.mlir"
TARGET_BACKEND
metal-spirv
DRIVER
diff --git a/iree/test/e2e/xla_ops/CMakeLists.txt b/iree/test/e2e/xla_ops/CMakeLists.txt
index 7d8a568..1f89648 100644
--- a/iree/test/e2e/xla_ops/CMakeLists.txt
+++ b/iree/test/e2e/xla_ops/CMakeLists.txt
@@ -19,24 +19,34 @@
"abs.mlir"
"add.mlir"
"broadcast.mlir"
+ "broadcast_add.mlir"
"broadcast_in_dim.mlir"
- "convert.mlir"
+ "clamp.mlir"
+ "compare.mlir"
"constant.mlir"
+ "convert.mlir"
"cosine.mlir"
+ "divide.mlir"
"exponential.mlir"
+ "gather.mlir"
"log.mlir"
"log_plus_one.mlir"
"maximum.mlir"
"minimum.mlir"
"multiply.mlir"
"negate.mlir"
+ "remainder.mlir"
"reshape.mlir"
"rsqrt.mlir"
+ "select.mlir"
"sine.mlir"
"slice.mlir"
"sqrt.mlir"
+ "subtract.mlir"
"tanh.mlir"
+ "torch_index_select.mlir"
"transpose.mlir"
+ "while.mlir"
TARGET_BACKEND
metal-spirv
DRIVER
diff --git a/scripts/prepare_doc_publication.py b/scripts/prepare_doc_publication.py
index 716e31a..352bcf0 100755
--- a/scripts/prepare_doc_publication.py
+++ b/scripts/prepare_doc_publication.py
@@ -86,7 +86,7 @@
# 'Getting Started' is 3.
# 'Developing IREE' is 4.
'design_roadmap.md': 5,
- 'milestones.md': 6,
+ 'objectives.md': 6,
'xla_op_coverage.md': 7,
'tf_e2e_coverage.md': 8,
'iree_community.md': 9,
diff --git a/third_party/llvm-project b/third_party/llvm-project
index 51ff045..f402e68 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit 51ff04567b2f8d06b2062bd3ed72eab2e93e4466
+Subproject commit f402e682d0ef5598eeffc9a21a691b03e602ff58
diff --git a/third_party/tensorflow b/third_party/tensorflow
index 74747f0..2d6bdab 160000
--- a/third_party/tensorflow
+++ b/third_party/tensorflow
@@ -1 +1 @@
-Subproject commit 74747f0c6b017df8fe87d1912f8eb4e2e287fbda
+Subproject commit 2d6bdab3adb0b8949763d5c63426338f938c9efe